# Source code for etna.auto.optuna.config_sampler

from functools import partial
from typing import List
from typing import Optional
from typing import Set

import numpy as np
from optuna.samplers import BaseSampler
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState

from etna.auto.utils import config_hash
from etna.auto.utils import retry


class ConfigSampler(BaseSampler):
    """Optuna based sampler for greedy search over different configurations."""

    def __init__(self, configs: List[dict], random_generator: Optional[np.random.Generator] = None, retries: int = 10):
        """Init Config sampler.

        Parameters
        ----------
        configs:
            pool of configs to sample from
        random_generator:
            numpy generator to get reproducible samples
        retries:
            number of retries to get new sample from storage. It could be useful if storage is not reliable.
        """
        self.configs = configs
        # Keyed by ``config_hash``: configs that produce the same hash collapse into a single entry.
        self.configs_hash = {config_hash(config=config): config for config in self.configs}
        self._rng = random_generator
        self.retries = retries

    def sample_independent(self, *args, **kwargs):  # noqa: D102
        """Sample independent. Not used."""
        return {}

    def infer_relative_search_space(self, *args, **kwargs):  # noqa: D102
        """Infer relative search space. Not used."""
        return {}

    def sample_relative(self, study: Study, trial: FrozenTrial, *args, **kwargs) -> dict:
        """Sample configuration to test.

        The chosen config's hash and the config itself are stored on the trial as the
        ``"hash"`` and ``"pipeline"`` user attributes.

        Parameters
        ----------
        study:
            current optuna study
        trial:
            optuna trial to use

        Return
        ------
        :
            sampled configuration to run objective on
        """
        unfinished_hashes = self._get_unfinished_hashes(study=study, current_trial=trial)
        if unfinished_hashes:
            # Walk ``configs_hash`` (insertion-ordered) instead of iterating the set directly: the iteration
            # order of a set of strings depends on per-process hash randomization, which would make the
            # choice irreproducible even with a seeded ``random_generator``.
            candidates = [h for h in self.configs_hash if h in unfinished_hashes]
        else:
            # TODO: this could cause job duplication
            # For some reason `_get_unfinished_hashes` does not return zero length list in `after_trial`
            candidates = list(self.configs_hash)

        idx = self.rng.choice(len(candidates))
        hash_to_sample = candidates[idx]
        map_to_objective = self.configs_hash[hash_to_sample]
        study._storage.set_trial_user_attr(trial._trial_id, "hash", hash_to_sample)
        study._storage.set_trial_user_attr(trial._trial_id, "pipeline", map_to_objective)
        return map_to_objective

    def after_trial(self, study: Study, trial: FrozenTrial, *args, **kwargs) -> None:  # noqa: D102
        """Stop study if all configs have been tested.

        Parameters
        ----------
        study:
            current optuna study
        trial:
            trial that has just ended
        """
        unfinished_hashes = self._get_unfinished_hashes(study=study, current_trial=trial)
        # The current trial is excluded from the lookup above, so its own hash can still be reported as
        # unfinished; discount it. ``.get`` avoids a KeyError for a trial that never had a hash stored.
        remaining = unfinished_hashes - {trial.user_attrs.get("hash")}
        if not remaining:
            study.stop()

    def _get_unfinished_hashes(self, study: Study, current_trial: Optional[FrozenTrial] = None) -> Set[str]:
        """Get unfinished config hashes.

        A config counts as taken when some finished or running trial (other than ``current_trial``)
        recorded its hash.

        Parameters
        ----------
        study:
            current optuna study
        current_trial:
            trial to leave out of the computation, if any

        Returns
        -------
        :
            hashes to run
        """
        trials = study._storage.get_all_trials(study._study_id, deepcopy=False)
        if current_trial is not None:
            trials = [trial for trial in trials if trial._trial_id != current_trial._trial_id]

        def _running_trial_hash(trial: FrozenTrial) -> str:
            # Re-read the trial from storage: the "hash" attribute of a running trial may not be
            # visible yet, and ``retry`` presumably re-invokes this until it succeeds or gives up.
            return study._storage.get_trial(trial._trial_id).user_attrs["hash"]

        taken_hashes: Set[str] = set()
        for t in trials:
            if t.state.is_finished():
                # A finished trial may lack a hash (e.g. it failed before ``sample_relative`` stored one);
                # such a trial does not occupy any config, so skip it instead of raising KeyError.
                finished_hash = t.user_attrs.get("hash")
                if finished_hash is not None:
                    taken_hashes.add(finished_hash)
            elif t.state == TrialState.RUNNING:
                taken_hashes.add(retry(partial(_running_trial_hash, trial=t), max_retries=self.retries))
            # Other states (e.g. WAITING) do not occupy a config.

        return set(self.configs_hash) - taken_hashes

    @property
    def rng(self):  # noqa: D102
        # Lazily create an unseeded generator when none was supplied.
        if self._rng is None:
            self._rng = np.random.default_rng()
        return self._rng

    def get_config_by_hash(self, hash: str):
        """Get config by hash.

        Parameters
        ----------
        hash:
            hash to get config for
        """
        return self.configs_hash[hash]