Source code for divik.core.io._model_io

import json
import logging
import os
import pickle
from functools import partial

import numpy as np
import pandas as pd
from sklearn.pipeline import Pipeline

from divik.core import configurable, visualize

_SAVERS = set()


def saver(fn):
    """Register the function as a handler for saving a model and related summaries.

    The saver function should be reusable across different models exhibiting
    the required variables. Prefer checking for the required attributes over
    checking the model class.

    Examples
    --------
    >>> from divik.core.io import saver
    >>> @saver
    ... def my_saver(model, destination, **kwargs):
    ...     if not hasattr(model, 'my_custom_field_'):
    ...         return
    ...     if 'my_param' not in kwargs:
    ...         return
    ...     # custom saving logic comes here

    You can also make this function configurable:

    >>> import gin
    >>> from divik.core.io import saver
    >>> @saver
    ... @gin.configurable(allowlist=['my_param'])
    ... def configurable_saver(model, destination, my_param=None, **kwargs):
    ...     if not hasattr(model, 'my_custom_field_'):
    ...         return
    ...     if my_param is None:
    ...         return
    ...     # custom saving logic comes here
    """
    _SAVERS.add(fn)
    return fn
def save(model, destination, **kwargs):
    """Save the model and related summaries into the specified destination directory."""
    if isinstance(destination, partial):
        fname_fn = destination
    else:
        fname_fn = partial(os.path.join, destination)
    for save_fn in _SAVERS:
        save_fn(model, fname_fn, **kwargs)
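
# Usage sketch (illustrative, not part of the module): ``fitted_model``,
# ``xy``, and the output directory below are assumptions. Note that ``save``
# hands each registered saver a path-building partial (``fname_fn``), not the
# raw destination string, so custom savers receive a callable:
#
#     from divik.core.io import save, saver
#
#     @saver
#     def save_cluster_count(model, fname_fn, **kwargs):
#         if not hasattr(model, "labels_"):
#             return
#         # fname_fn("clusters.txt") resolves to "<destination>/clusters.txt"
#         with open(fname_fn("clusters.txt"), "w") as out:
#             out.write(str(np.unique(model.labels_).size))
#
#     save(fitted_model, "/tmp/model-output", xy=xy)
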
@saver
@configurable(allowlist=["enabled"])
def save_pickle(model, fname_fn, enabled=True, **kwargs):
    if not enabled:
        return
    logging.info("Saving model pickle.")
    with open(fname_fn("model.pkl"), "wb") as pkl:
        pickle.dump(model, pkl)


@saver
def save_summary(model, fname_fn, **kwargs):
    if not hasattr(model, "labels_"):
        return
    logging.info("Saving JSON summary.")
    n_clusters = getattr(model, "n_clusters_", np.unique(model.labels_).size)
    with open(fname_fn("summary.json"), "w") as smr:
        json.dump(
            {
                "depth": getattr(model, "depth_", 1),
                "number_of_clusters": n_clusters,
                "mean_cluster_size": model.labels_.size / float(n_clusters),
            },
            smr,
        )


@saver
def save_labels(model, fname_fn, **kwargs):
    if not hasattr(model, "labels_"):
        return
    logging.info("Saving final partition.")
    np.save(fname_fn("final_partition.npy"), model.labels_)
    np.savetxt(
        fname_fn("final_partition.csv"), model.labels_, delimiter=", ", fmt="%i"
    )
    if "xy" in kwargs:
        import skimage.io

        visualization = visualize(model.labels_, xy=kwargs["xy"])
        skimage.io.imsave(fname_fn("final_partition.png"), visualization)


@saver
def save_multiple_labels(model, fname_fn, **kwargs):
    if not hasattr(model, "estimators_") or not hasattr(
        model.estimators_[0], "labels_"
    ):
        return
    logging.info("Saving all considered partitions.")
    part = np.hstack([e.labels_.reshape(-1, 1) for e in model.estimators_])
    np.save(fname_fn("partitions.npy"), part)
    np.savetxt(fname_fn("partitions.csv"), part, delimiter=", ", fmt="%i")
    import skimage.io

    for i in range(part.shape[1]):
        np.savetxt(
            fname_fn("partitions.{0}.csv").format(i),
            part[:, i].reshape(-1, 1),
            delimiter=", ",
            fmt="%i",
        )
        if "xy" in kwargs:
            # Visualize the i-th partition separately.
            visualization = visualize(part[:, i], xy=kwargs["xy"])
            skimage.io.imsave(fname_fn("partitions.{0}.png").format(i), visualization)


@saver
def save_centroids(model, fname_fn, **kwargs):
    if not hasattr(model, "centroids_"):
        return
    logging.info("Saving centroids.")
    np.save(fname_fn("centroids.npy"), model.centroids_)
    np.savetxt(fname_fn("centroids.csv"), model.centroids_, delimiter=", ")


@saver
def save_filters(model, fname_fn, **kwargs):
    if not hasattr(model, "filters_"):
        return
    logging.info("Saving filters.")
    np.save(fname_fn("filters.npy"), model.filters_)
    np.savetxt(fname_fn("filters.csv"), model.filters_, delimiter=", ", fmt="%i")


@saver
def save_cluster_paths(model, fname_fn, **kwargs):
    if not hasattr(model, "reverse_paths_"):
        return
    logging.info("Saving cluster paths.")
    rev = ["_".join(map(str, path)) for path in model.reverse_paths_]
    pd.DataFrame(
        {"path": rev, "cluster_number": list(model.reverse_paths_.values())}
    ).to_csv(fname_fn("paths.csv"))


@saver
def save_pipeline(model, fname_fn, **kwargs):
    if not isinstance(model, Pipeline):
        return
    feature_selector = model[:-1]
    clustering = model[-1]
    if isinstance(clustering, Pipeline):
        logging.info("Saving pre-extractor pickle.")
        with open(fname_fn("feature_pre_extractor.pkl"), "wb") as pkl:
            pickle.dump(feature_selector, pkl)
        return save(clustering, fname_fn, **kwargs)
    logging.info("Saving feature selector pickle.")
    with open(fname_fn("feature_selector.pkl"), "wb") as pkl:
        pickle.dump(feature_selector, pkl)
    save(clustering, fname_fn, **kwargs)
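
# End-to-end sketch (illustrative): saving a fitted scikit-learn ``Pipeline``
# triggers ``save_pipeline``, which pickles the non-final steps separately and
# recurses into the final step, so the label- and summary-oriented savers run
# on it. ``X`` and the output path are assumptions; any final estimator
# exposing ``labels_`` after ``fit`` behaves the same way:
#
#     from sklearn.cluster import KMeans
#     from sklearn.pipeline import make_pipeline
#     from sklearn.preprocessing import StandardScaler
#     from divik.core.io import save
#
#     pipeline = make_pipeline(StandardScaler(), KMeans(n_clusters=3)).fit(X)
#     save(pipeline, "/tmp/divik-output")
#     # writes: model.pkl, feature_selector.pkl, summary.json,
#     #         final_partition.npy, final_partition.csv, ...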