RNNModel

class RNNModel(input_size: int, decoder_length: int, encoder_length: int, num_layers: int = 2, hidden_size: int = 16, lr: float = 0.001, loss: Optional[torch.nn.modules.module.Module] = None, train_batch_size: int = 16, test_batch_size: int = 16, optimizer_params: Optional[dict] = None, trainer_params: Optional[dict] = None, train_dataloader_params: Optional[dict] = None, test_dataloader_params: Optional[dict] = None, val_dataloader_params: Optional[dict] = None, split_params: Optional[dict] = None)

Bases: etna.models.base.DeepBaseModel

RNN-based model built on an LSTM cell.
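For reference, the textbook LSTM cell update, written in the notation of the torch.nn.LSTM documentation, is given below. Here x_t has input_size entries, h_t and c_t have hidden_size entries, and with num_layers > 1 each layer's x_t is the previous layer's h_t. That the model's network uses this standard cell without modification is an assumption.

```latex
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```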
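To make the roles of input_size, hidden_size and num_layers concrete, the sketch below counts the weights of a stacked LSTM of that shape, assuming torch.nn.LSTM conventions (bias=True, unidirectional). The helper lstm_parameter_count is hypothetical, and the count leaves out any output layer the network may add on top of the LSTM.

```python
def lstm_parameter_count(input_size: int, hidden_size: int, num_layers: int) -> int:
    """Count trainable parameters of a torch.nn.LSTM-style stack (bias=True, unidirectional).

    Hypothetical helper for illustration; it ignores any projection layer after the LSTM.
    """
    total = 0
    for layer in range(num_layers):
        # The first layer reads the input features; deeper layers read the previous hidden state.
        layer_input = input_size if layer == 0 else hidden_size
        # Four gates (i, f, g, o), each with an input weight matrix and a recurrent weight matrix.
        weights = 4 * hidden_size * (layer_input + hidden_size)
        # Each gate also has two bias vectors (b_i* and b_h*).
        biases = 4 * hidden_size * 2
        total += weights + biases
    return total


# Default hidden_size and num_layers from the signature, with an arbitrary input_size of 3.
print(lstm_parameter_count(input_size=3, hidden_size=16, num_layers=2))  # 3520
```

Doubling hidden_size roughly quadruples the recurrent weights, so it usually dominates model size more than num_layers or input_size.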

Initialize the model. The parameter descriptions below are inherited from DeepBaseModel.

Parameters
  • net – network to train (not part of the RNNModel signature; RNNModel constructs its own network)

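  • encoder_length (int) – encoder length: number of past time steps fed to the encoder

  • decoder_length (int) – decoder length: number of time steps the decoder predicts per sample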

  • train_batch_size (int) – batch size for training

  • test_batch_size (int) – batch size for testing

  • trainer_params (Optional[dict]) – PyTorch Lightning trainer parameters (API reference: pytorch_lightning.trainer.trainer.Trainer)

  • train_dataloader_params (Optional[dict]) – parameters for the train dataloader, e.g. a sampler (API reference: torch.utils.data.DataLoader)

  • test_dataloader_params (Optional[dict]) – parameters for the test dataloader

  • val_dataloader_params (Optional[dict]) – parameters for the validation dataloader

  • split_params (Optional[dict]) –

    dictionary with parameters for torch.utils.data.random_split() for train-test splitting
    • train_size (float) – value from 0 to 1, the fraction of samples to use for training

    • generator (Optional[torch.Generator]) – generator for reproducible train-test splitting

    • torch_dataset_size (Optional[int]) – number of samples in the dataset, used when the dataset does not implement __len__

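  • input_size (int) – number of input features per time step (size of the LSTM input)

  • num_layers (int) – number of stacked LSTM layers

  • hidden_size (int) – size of the LSTM hidden state

  • lr (float) – learning rate

  • loss (Optional[torch.nn.Module]) – loss function to optimize; if None, the model's default loss is used

  • optimizer_params (Optional[dict]) – additional parameters for the optimizer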
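Conceptually, encoder_length and decoder_length fix the shape of each training sample: a window of encoder_length past values that the encoder reads, followed by decoder_length values that the decoder is trained to predict. The sketch below cuts a plain list into such pairs with stride 1. make_windows is a hypothetical helper, and the library's own sample construction (stride, extra features, handling of short segments) may differ.

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float], encoder_length: int, decoder_length: int
) -> List[Tuple[List[float], List[float]]]:
    """Cut a series into (history, target) pairs of fixed lengths with stride 1.

    Hypothetical illustration of encoder/decoder windows, not the library's sampler.
    """
    window = encoder_length + decoder_length
    samples = []
    # Every start position that leaves room for a full history plus a full target.
    for start in range(len(series) - window + 1):
        history = list(series[start : start + encoder_length])
        target = list(series[start + encoder_length : start + window])
        samples.append((history, target))
    return samples


samples = make_windows([1, 2, 3, 4, 5, 6], encoder_length=3, decoder_length=2)
print(len(samples))  # 2
print(samples[0])    # ([1, 2, 3], [4, 5])
```

A series shorter than encoder_length + decoder_length yields no samples at all, which is why both lengths constrain how much history each segment needs.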
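torch.utils.data.random_split takes a sequence of lengths and an optional generator; a split_params dict passed to this model might look like {"train_size": 0.8, "generator": torch.Generator().manual_seed(42)}. The sketch below mimics that split with the standard library so the effect of train_size and a fixed seed can be checked without PyTorch. split_indices is a hypothetical helper, and rounding the training count down is an assumption about how the fraction becomes integer lengths.

```python
import random
from typing import List, Optional, Tuple


def split_indices(n: int, train_size: float, seed: Optional[int] = None) -> Tuple[List[int], List[int]]:
    """Randomly split range(n) into train and validation indices.

    Hypothetical stdlib stand-in for torch.utils.data.random_split.
    """
    if not 0 <= train_size <= 1:
        raise ValueError("train_size must be between 0 and 1")
    # Assumed rounding rule: floor for the training part, remainder for validation.
    n_train = int(n * train_size)
    indices = list(range(n))
    # A seeded random.Random plays the role of torch.Generator for reproducibility.
    random.Random(seed).shuffle(indices)
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_indices(10, train_size=0.8, seed=0)
print(len(train_idx), len(val_idx))  # 8 2
```

With the same seed the split is identical on every run, which is the point of passing a generator.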

Methods

fit(ts)

Fit model.

forecast(ts, horizon)

Make predictions.

get_model()

Get model.

raw_fit(torch_dataset)

Fit the model on a torch-like Dataset.
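raw_fit and raw_predict take a "torch like" Dataset, i.e. a map-style dataset that supports integer indexing through __getitem__ and, ideally, __len__ (without __len__, split_params needs torch_dataset_size). The class below is a minimal, dependency-free sketch of that protocol. WindowDataset and its "history"/"target" keys are hypothetical and do not reproduce the sample format the model's network actually expects.

```python
from typing import Dict, List, Sequence


class WindowDataset:
    """Minimal map-style dataset: integer indexing plus a length (hypothetical example)."""

    def __init__(self, series: Sequence[float], encoder_length: int, decoder_length: int):
        self.series = list(series)
        self.encoder_length = encoder_length
        self.decoder_length = decoder_length

    def __len__(self) -> int:
        # Number of complete (history, target) windows that fit in the series.
        return max(0, len(self.series) - self.encoder_length - self.decoder_length + 1)

    def __getitem__(self, idx: int) -> Dict[str, List[float]]:
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        split = idx + self.encoder_length
        return {
            "history": self.series[idx:split],
            "target": self.series[split : split + self.decoder_length],
        }


ds = WindowDataset([1, 2, 3, 4, 5, 6], encoder_length=3, decoder_length=2)
print(len(ds))  # 2
print(ds[1])    # {'history': [2, 3, 4], 'target': [5, 6]}
```

Because it implements __len__ and __getitem__, an object like this can be wrapped by torch.utils.data.DataLoader, which collates individual samples into batches.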

raw_predict(torch_dataset)

Run inference on a torch-like Dataset.