# Source code for etna.transforms.feature_selection.base

import warnings
from abc import ABC
from typing import List
from typing import Optional
from typing import Union

import pandas as pd
from typing_extensions import Literal

from etna.transforms import ReversibleTransform

class BaseFeatureSelectionTransform(ReversibleTransform, ABC):
    """Base class for feature selection transforms.

    This base class applies a selection stored in ``selected_features`` (which it initialises
    empty; presumably subclasses fill it when fitting). Optionally, it keeps the dropped columns
    so that the inverse transform can put them back.

    Parameters
    ----------
    features_to_use:
        names of the features among which the selection is performed;
        ``"all"`` means every feature of the dataframe except ``"target"``
    return_features:
        if True, columns dropped by the transform are stored and re-attached by the inverse transform
    """

    def __init__(
        self,
        features_to_use: Union[List[str], Literal["all"]] = "all",
        return_features: bool = False,
    ):
        super().__init__(required_features="all")
        self.features_to_use = features_to_use
        # Features kept by the selection; empty until populated (presumably by a subclass).
        self.selected_features: List[str] = []
        self.return_features = return_features
        # Columns removed by the last ``_transform`` call; only set when ``return_features`` is True.
        self._df_removed: Optional[pd.DataFrame] = None

    def get_regressors_info(self) -> List[str]:
        """Return the list with regressors created by the transform.

        This transform never adds new columns, so the list is always empty.
        """
        return []

    def _get_features_to_use(self, df: pd.DataFrame) -> List[str]:
        """Get list of features from the dataframe to perform the selection on.

        Parameters
        ----------
        df:
            dataframe whose columns have a ``"feature"`` level

        Returns
        -------
        :
            sorted candidate feature names; ``"target"`` is never a candidate

        Warns
        -----
        UserWarning
            if some names from ``features_to_use`` are not among the dataframe features
        """
        features = set(df.columns.get_level_values("feature")) - {"target"}
        if self.features_to_use != "all":
            requested = set(self.features_to_use)
            features &= requested
            # Compare as sets: duplicated names in ``features_to_use`` must not trigger the warning.
            if features != requested:
                warnings.warn("Columns from feature_to_use which are out of dataframe columns will be dropped!")
        return sorted(features)

    def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Keep only the selected features among the candidate ones.

        Features that are not candidates for selection (including ``"target"``) are kept as is.

        Parameters
        ----------
        df:
            dataframe with all segments data

        Returns
        -------
        result: pd.DataFrame
            Dataframe with only selected features and the non-candidate features
        """
        candidates = set(self._get_features_to_use(df))
        rest_columns = set(df.columns.get_level_values("feature")) - candidates
        selected_columns = sorted(self.selected_features + list(rest_columns))
        result = df.loc[:, pd.IndexSlice[:, selected_columns]]
        if self.return_features:
            # Remember what was dropped so the inverse transform can restore it.
            self._df_removed = df.drop(result.columns, axis=1)
        return result

    def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply inverse transform to the data.

        Parameters
        ----------
        df:
            dataframe to apply inverse transformation

        Returns
        -------
        result: pd.DataFrame
            dataframe before transformation; unchanged if no removed columns were stored
        """
        if self._df_removed is None:
            # Nothing stored: ``return_features`` is False or ``_transform`` has not run yet.
            return df
        return pd.concat([df, self._df_removed], axis=1)