Models

Models are used to make predictions. Let’s look at a basic usage example:

>>> import pandas as pd
>>> from etna.datasets import TSDataset, generate_ar_df
>>> from etna.transforms import LagTransform
>>> from etna.models import LinearPerSegmentModel
>>>
>>> df = generate_ar_df(periods=100, start_time="2021-01-01", ar_coef=[1/2], n_segments=2)
>>> ts = TSDataset(TSDataset.to_dataset(df), "D")
>>> lag_transform = LagTransform(in_column="target", lags=[3, 4, 5])
>>> ts.fit_transform(transforms=[lag_transform])
>>> future_ts = ts.make_future(3)
>>> model = LinearPerSegmentModel()
>>> model.fit(ts)
LinearPerSegmentModel(fit_intercept = True, normalize = False, )
>>> forecast_ts = model.forecast(future_ts)
>>> forecast_ts
segment                 segment_0  ... segment_1
feature    regressor_target_lag_3  ...    target
timestamp                          ...
2021-04-11              -0.090673  ...  0.286764
2021-04-12              -0.665337  ...  0.295589
2021-04-13               0.365363  ...  0.374554
[3 rows x 8 columns]

One key point: future_ts and forecast_ts are the same object. The forecast method only fills the ‘target’ columns in future_ts and returns a reference to it.

>>> forecast_ts is future_ts
True
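The same fill-in-place-and-return-a-reference behaviour can be reproduced with plain Python objects. This is a toy sketch, not ETNA code: the function fill_target is a hypothetical stand-in for forecast, and the dict stands in for a TSDataset. In this toy, taking a deep copy beforehand is one way to keep an untouched version of the input.

```python
import copy


def fill_target(frame, predictions):
    """Hypothetical stand-in for `forecast`: fills 'target' in place and returns the same object."""
    frame["target"] = list(predictions)  # mutate the input object
    return frame  # return a reference to it, not a new object


future = {"timestamp": ["2021-04-11", "2021-04-12", "2021-04-13"], "target": [None, None, None]}
untouched = copy.deepcopy(future)  # independent copy taken before filling

forecast = fill_target(future, [1.0, 2.0, 3.0])

print(forecast is future)   # True: both names refer to one object
print(future["target"])     # [1.0, 2.0, 3.0]: the input itself was modified
print(untouched["target"])  # [None, None, None]: the copy is unaffected
```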

Details and available models

See the API documentation for further details on available models:

etna.models.autoarima.AutoARIMAModel(**kwargs)

Class for holding auto ARIMA model.

etna.models.tbats.BATSModel([use_box_cox, ...])

Class for holding segment interval BATS model.

etna.models.base.BaseAdapter()

Base class for models adapter.

etna.models.catboost.CatBoostMultiSegmentModel([...])

Class for holding Catboost model for all segments.

etna.models.catboost.CatBoostPerSegmentModel([...])

Class for holding per segment Catboost model.

etna.models.deadline_ma.DeadlineMovingAverageModel([...])

Moving average model that uses exact previous dates to predict.

etna.models.nn.deepar.DeepARModel([...])

Wrapper for pytorch_forecasting.models.deepar.DeepAR.

etna.models.nn.deepstate.deepstate.DeepStateModel(...)

DeepState model.

etna.models.linear.ElasticMultiSegmentModel([...])

Class holding sklearn.linear_model.ElasticNet for all segments.

etna.models.linear.ElasticPerSegmentModel([...])

Class holding per segment sklearn.linear_model.ElasticNet.

etna.models.holt_winters.HoltModel([...])

Holt etna model.

etna.models.holt_winters.HoltWintersModel([...])

Holt-Winters' etna model.

etna.models.linear.LinearMultiSegmentModel([...])

Class holding sklearn.linear_model.LinearRegression for all segments.

etna.models.linear.LinearPerSegmentModel([...])

Class holding per segment sklearn.linear_model.LinearRegression.

etna.models.nn.mlp.MLPModel(input_size, ...)

Multilayer perceptron (MLP) model.

etna.models.moving_average.MovingAverageModel([...])

MovingAverageModel averages previous series values to forecast future values.

etna.models.nn.nbeats.nbeats.NBeatsGenericModel(...)

Generic N-BEATS model.

etna.models.nn.nbeats.nbeats.NBeatsInterpretableModel(...)

Interpretable N-BEATS model.

etna.models.naive.NaiveModel([lag])

Naive model predicts the t-th value of a series with its (t - lag)-th value.

etna.models.base.NonPredictionIntervalContextIgnorantAbstractModel()

Interface for models that don't support prediction intervals and don't need context for prediction.

etna.models.base.NonPredictionIntervalContextRequiredAbstractModel()

Interface for models that don't support prediction intervals and need context for prediction.

etna.models.nn.patchts.PatchTSModel(...[, ...])

PatchTS model using PyTorch layers.

etna.models.base.PredictionIntervalContextIgnorantAbstractModel()

Interface for models that support prediction intervals and don't need context for prediction.

etna.models.base.PredictionIntervalContextRequiredAbstractModel()

Interface for models that support prediction intervals and need context for prediction.

etna.models.prophet.ProphetModel([growth, ...])

Class for holding Prophet model.

etna.models.nn.utils.PytorchForecastingDatasetBuilder([...])

Builder for PytorchForecasting dataset.

etna.models.nn.rnn.RNNModel(input_size, ...)

RNN-based model on an LSTM cell.

etna.models.sarimax.SARIMAXModel([order, ...])

Class for holding SARIMAX model.

etna.settings.SETTINGS

etna settings.

etna.models.seasonal_ma.SeasonalMovingAverageModel([...])

Seasonal moving average.

etna.models.holt_winters.SimpleExpSmoothingModel([...])

Exponential smoothing etna model.

etna.models.sklearn.SklearnMultiSegmentModel(...)

Class for holding Sklearn model for all segments.

etna.models.sklearn.SklearnPerSegmentModel(...)

Class for holding per segment Sklearn model.

etna.models.statsforecast.StatsForecastARIMAModel([...])

Class for holding statsforecast.models.ARIMA.

etna.models.statsforecast.StatsForecastAutoARIMAModel([...])

Class for holding statsforecast.models.AutoARIMA.

etna.models.statsforecast.StatsForecastAutoCESModel([...])

Class for holding statsforecast.models.AutoCES.

etna.models.statsforecast.StatsForecastAutoETSModel([...])

Class for holding statsforecast.models.AutoETS.

etna.models.statsforecast.StatsForecastAutoThetaModel([...])

Class for holding statsforecast.models.AutoTheta.

etna.models.tbats.TBATSModel([use_box_cox, ...])

Class for holding segment interval TBATS model.

etna.models.nn.tft.TFTModel([...])

Wrapper for pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer.
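DeadlineMovingAverageModel is described in the table as a moving average that uses exact previous dates. A minimal stdlib sketch of that idea follows, assuming monthly seasonality: the value for a target date is the mean of the values observed on the same calendar date 1, 2, ..., window months earlier. The helpers shift_months and deadline_moving_average are hypothetical, and clamping days such as the 31st to the last day of a shorter month is our assumption, not something the table states.

```python
import calendar
from datetime import date, timedelta


def shift_months(day, months):
    """Move `day` back by `months` calendar months, clamping to the month's last day."""
    month_index = day.year * 12 + (day.month - 1) - months
    year, month = divmod(month_index, 12)
    month += 1
    last_day = calendar.monthrange(year, month)[1]
    return date(year, month, min(day.day, last_day))


def deadline_moving_average(series, target_day, window=3):
    """Average the values observed on the same calendar date 1..window months earlier."""
    lagged = [series[shift_months(target_day, k)] for k in range(1, window + 1)]
    return sum(lagged) / window


# Toy daily series from 2021-01-01 to 2021-04-30; each value encodes its date as month * 100 + day.
start = date(2021, 1, 1)
days = [start + timedelta(days=i) for i in range(120)]
series = {d: float(d.month * 100 + d.day) for d in days}

print(deadline_moving_average(series, date(2021, 4, 15)))  # 215.0: mean of 315, 215 and 115
print(shift_months(date(2021, 3, 31), 1))                  # 2021-02-28 (clamped)
```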
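Several entries in the table come in pairs, for example LinearPerSegmentModel and LinearMultiSegmentModel, or CatBoostPerSegmentModel and CatBoostMultiSegmentModel. As their descriptions say, the per-segment variant holds a separate estimator for each segment, while the multi-segment variant holds one estimator for all segments. The NumPy sketch below illustrates that difference with ordinary least squares; it is not ETNA's implementation, and the helper fit_ols and the toy data are hypothetical.

```python
import numpy as np


def fit_ols(x, y):
    """Ordinary least squares with an intercept; returns [intercept, slope]."""
    design = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef


# Toy data: two segments that follow different linear relationships.
x = np.arange(10, dtype=float)
segments = {
    "segment_0": 1.0 + 2.0 * x,
    "segment_1": 5.0 - 1.0 * x,
}

# Per-segment: one estimator per segment, each fitted only on its own data.
per_segment = {name: fit_ols(x, y) for name, y in segments.items()}

# Multi-segment: a single estimator fitted on all segments stacked together.
x_all = np.concatenate([x for _ in segments])
y_all = np.concatenate(list(segments.values()))
multi_segment = fit_ols(x_all, y_all)

print(per_segment["segment_0"])  # approximately [1. 2.]
print(per_segment["segment_1"])  # approximately [5. -1.]
print(multi_segment)             # approximately [3. 0.5]: one compromise line for both
```

With pooled data, the single estimator settles on a line between the two segment-specific fits, which shows why the choice between the two variants matters when segments behave differently.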
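MovingAverageModel is described as averaging previous series values, and SeasonalMovingAverageModel as a seasonal moving average. One common formulation, assumed here rather than confirmed by the table, predicts each point as the mean of the window values lying 1, 2, ..., window seasons back, so that seasonality=1 reduces to a plain moving average, with forecasts further ahead reusing earlier forecasts. The function name and its parameters are illustrative only.

```python
def seasonal_moving_average(history, horizon, window=3, seasonality=1):
    """Predict y[t] as the mean of y[t - k * seasonality] for k = 1..window."""
    needed = window * seasonality
    if len(history) < needed:
        raise ValueError(f"need at least {needed} points of history")
    values = [float(v) for v in history]
    for _ in range(horizon):
        t = len(values)  # index of the point being predicted
        lagged = [values[t - k * seasonality] for k in range(1, window + 1)]
        values.append(sum(lagged) / window)  # forecasts are appended and reused
    return values[len(history):]


# Plain moving average (seasonality=1) over the last 3 points, applied recursively.
print(seasonal_moving_average([1, 2, 3, 4, 5, 6], horizon=2, window=3))
# [5.0, 5.333333333333333]

# Seasonal version on a toy pattern with period 2, averaging the last 2 seasons.
print(seasonal_moving_average([10, 0, 12, 0, 14, 0], horizon=2, window=2, seasonality=2))
# [13.0, 0.0]
```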
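The NaiveModel entry states the rule directly: the t-th value is predicted with the (t - lag)-th value. A minimal pure-Python sketch of that rule follows; the function name naive_forecast is hypothetical, and applying the rule step by step (so that horizons longer than lag reuse earlier forecasts) is our assumption.

```python
def naive_forecast(history, horizon, lag=1):
    """Predict y[t] as y[t - lag], extending the series one step at a time."""
    if lag < 1 or lag > len(history):
        raise ValueError("lag must be between 1 and len(history)")
    values = list(history)
    for _ in range(horizon):
        values.append(values[-lag])  # the value `lag` steps before the new point
    return values[len(history):]


print(naive_forecast([1, 2, 3, 4, 5], horizon=3))         # [5, 5, 5]
print(naive_forecast([1, 2, 3, 4, 5], horizon=4, lag=2))  # [4, 5, 4, 5]
```

With lag greater than 1, the forecast repeats the last lag observed values in a cycle, which is a simple way to carry over a fixed-period pattern.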
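The four abstract interfaces in the table differ along two axes: whether a model supports prediction intervals, and whether it needs context (recent history) to make a prediction. The toy classes below are hypothetical and are not the ETNA interfaces; they only contrast the second axis. A regression on known future features can predict from the future rows alone, while a rule that repeats the last observation cannot predict without it. The context_size attribute is an illustrative name.

```python
from typing import List, Sequence


class ToyContextIgnorantModel:
    """Hypothetical: predicts from future feature values only, like a fitted regression."""

    def __init__(self, intercept: float, slope: float):
        self.intercept = intercept
        self.slope = slope

    def predict(self, future_features: Sequence[float]) -> List[float]:
        return [self.intercept + self.slope * x for x in future_features]


class ToyContextRequiredModel:
    """Hypothetical: repeats the last observed value, so it needs history as context."""

    context_size = 1  # number of past observations required before predicting

    def predict(self, context: Sequence[float], horizon: int) -> List[float]:
        if len(context) < self.context_size:
            raise ValueError("not enough context to predict")
        return [context[-1]] * horizon


print(ToyContextIgnorantModel(1.0, 2.0).predict([3.0, 4.0]))     # [7.0, 9.0]
print(ToyContextRequiredModel().predict([5.0, 6.0], horizon=2))  # [6.0, 6.0]
```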