Source code for divik.sampler._core

from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from itertools import count

from sklearn.base import BaseEstimator, clone


class BaseSampler(BaseEstimator, metaclass=ABCMeta):
    """Abstract base for every sampler.

    A sampler is Pool-safe: it may simply hold a dataset, and that dataset
    will not be pickled when work is sent to another process, provided it
    is handled properly. Before a pool is spawned, the data must be moved
    to a module-level variable. A small contract simplifies that step:
    open a parallel context and work inside it:

    >>> with sampler.parallel() as sampler_, \\
    ...         Pool(initializer=sampler_.initializer,
    ...              initargs=sampler_.initargs) as pool:
    ...     pool.map(sampler_.get_sample, range(10))

    Note that ``__iter__`` and ``fit`` are not available in the parallel
    context. ``__iter__`` would yield the same values independently in
    every worker, so iteration has to be done consciously and in a
    well-thought manner. ``fit`` could lead to unpredictable behaviour.
    If the original sampler is needed, a clone of it (not fitted to the
    data) can be obtained.
    """

    def __iter__(self):
        """Yield `n_samples` samples, or samples forever if unspecified"""
        # A missing `n_samples` attribute is treated the same as None.
        limit = getattr(self, "n_samples", None)
        seeds = count() if limit is None else range(limit)
        for seed in seeds:
            yield self.get_sample(seed)

    @abstractmethod
    def get_sample(self, seed):
        """Draw the sample identified by `seed`

        Implementations are expected to satisfy:

        a) sampler.get_sample(x) == sampler.get_sample(x)
        b) x != y should yield
           sampler.get_sample(x) != sampler.get_sample(y)

        Parameters
        ----------
        seed : int
            The seed to use to draw the sample

        Returns
        -------
        sample : array_like, (*self.shape_)
            The drawn sample
        """
        raise NotImplementedError("get_sample is not implemented")

    def fit(self, X, y=None):
        """Fit the sampler to data

        Serves as the common base for supervised and unsupervised
        samplers; this implementation ignores its arguments and returns
        the sampler itself.
        """
        return self

    @contextmanager
    def parallel(self):
        """Open a parallel context in which the sampler can operate"""
        shared = ParallelSampler(self)
        yield shared
class ParallelSampler:
    """Thin wrapper exposing a sampler's functionality for sharing"""

    def __init__(self, sampler: BaseSampler):
        # Keep a reference to the wrapped sampler; every call delegates to it.
        self.sampler = sampler

    def get_sample(self, seed):
        """Return the sample the wrapped sampler draws for `seed`"""
        wrapped = self.sampler
        return wrapped.get_sample(seed)

    def initializer(self, *args):
        """Do nothing; any positional arguments are accepted and ignored"""
        return None

    @property
    def initargs(self):
        """Arguments to pass to `initializer`; always an empty tuple"""
        return tuple()

    def clone(self):
        """Return a clone of the wrapped sampler"""
        # Inside a method body, `clone` resolves to the module-level import,
        # not to this method.
        wrapped = self.sampler
        return clone(wrapped)