S4

class ssm.model.S4(model_dim, hid_dim, method, n_layers=2, block_type='S4', activation=<class 'torch.nn.modules.activation.GELU'>, normalization=True, **kwargs)

Bases: Module

Implementation of the Structured State Space Sequence (S4) model.

The S4 model is designed to efficiently model long-range dependencies in sequential data using structured state space representations. Compared with traditional recurrent architectures, it scales better to long sequences and achieves better performance.

The model is a stack of S4 blocks, each followed by an activation function and a linear layer, as described in the referenced paper. The block type (block_type) can be the basic S4 block or one of its low-rank and diagonal variants.
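The layer stacking described above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the library's implementation. `ToyS4Stack`, `diagonal_ssm_kernel` and `causal_conv` are hypothetical names. The block is a toy diagonal state space model evaluated in convolutional mode, and GELU uses its tanh approximation. The input is assumed to have shape (sequence length, width). The `normalization` option and any projection between `model_dim` and `hid_dim` are omitted because their placement is not documented here.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU; torch.nn.GELU defaults to the exact erf form
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def diagonal_ssm_kernel(log_rate, C, L):
    # Toy diagonal SSM: state n of channel h is scaled by a_bar[h, n] in (0, 1) each step
    a_bar = np.exp(-np.exp(log_rate))                    # (H, N)
    powers = a_bar[None] ** np.arange(L)[:, None, None]  # (L, H, N)
    return (powers * C[None]).sum(axis=-1)                # kernel K, shape (L, H)


def causal_conv(u, K):
    # Per-channel causal convolution along time (axis 0) via zero-padded FFT
    L = u.shape[0]
    n = 2 * L  # padding avoids circular wrap-around
    y = np.fft.irfft(np.fft.rfft(u, n=n, axis=0) * np.fft.rfft(K, n=n, axis=0), n=n, axis=0)
    return y[:L]


class ToyS4Stack:
    """Hypothetical stack of n_layers x (SSM block -> activation -> linear layer)."""

    def __init__(self, width, n_layers=2, state_size=8, seed=0):
        rng = np.random.default_rng(seed)
        self.layers = []
        for _ in range(n_layers):
            log_rate = rng.uniform(-4.0, -1.0, size=(width, state_size))
            C = rng.standard_normal((width, state_size)) / state_size
            W = rng.standard_normal((width, width)) / np.sqrt(width)
            b = np.zeros(width)
            self.layers.append((log_rate, C, W, b))

    def forward(self, x):
        # x: (L, width), time along axis 0
        for log_rate, C, W, b in self.layers:
            x = causal_conv(x, diagonal_ssm_kernel(log_rate, C, x.shape[0]))  # SSM block
            x = gelu(x)                                                      # activation
            x = x @ W.T + b                                                  # linear layer
        return x


model = ToyS4Stack(width=4, n_layers=2)
x = np.random.default_rng(1).standard_normal((50, 4))
y = model.forward(x)
```

Every stage is either causal in time (the SSM block) or applied independently at each time step (activation and linear layer). The output at step k therefore depends only on inputs up to step k. The sketch keeps a single width throughout for simplicity.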

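Each block supports two forward computation methods:

  • Recurrent: applies the discretized state space dynamics one time step at a time, processing the sequence sequentially.

  • Convolutional: computes the equivalent convolution kernel and applies it to the whole sequence using the fast Fourier transform (FFT).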
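The two forward methods compute the same output, which can be checked numerically on a toy single-input, single-output state space model. The sketch below uses NumPy and the bilinear discretization described in the S4 paper. It uses a small dense state matrix rather than the structured matrices of the actual S4 blocks. The function names are hypothetical and do not reflect the library's internals.

```python
import numpy as np


def discretize_bilinear(A, B, dt):
    # Bilinear discretization of x' = A x + B u with step size dt
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (dt / 2.0) * A)
    A_bar = inv @ (I + (dt / 2.0) * A)
    B_bar = inv @ (dt * B)
    return A_bar, B_bar


def forward_recurrent(A_bar, B_bar, C, u):
    # x_k = A_bar x_{k-1} + B_bar u_k,  y_k = C x_k,  starting from x_{-1} = 0
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar[:, 0] * u_k
        ys.append(C[0] @ x)
    return np.array(ys)


def ssm_kernel(A_bar, B_bar, C, L):
    # K = (C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar)
    K = np.empty(L)
    v = B_bar[:, 0]
    for k in range(L):
        K[k] = C[0] @ v
        v = A_bar @ v
    return K


def forward_convolutional(A_bar, B_bar, C, u):
    # Causal convolution y = K * u computed with the FFT
    L = len(u)
    K = ssm_kernel(A_bar, B_bar, C, L)
    n_fft = 2 * L  # zero-pad so circular convolution equals linear convolution
    y = np.fft.irfft(np.fft.rfft(K, n=n_fft) * np.fft.rfft(u, n=n_fft), n=n_fft)
    return y[:L]


rng = np.random.default_rng(0)
N, L = 4, 64
A = -np.eye(N) + 0.1 * rng.standard_normal((N, N))  # toy stable state matrix
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
u = rng.standard_normal(L)

A_bar, B_bar = discretize_bilinear(A, B, dt=0.1)
y_rec = forward_recurrent(A_bar, B_bar, C, u)
y_conv = forward_convolutional(A_bar, B_bar, C, u)
print(np.allclose(y_rec, y_conv))  # True
```

The recurrent form needs only the current state, which suits step-by-step processing. The convolutional form processes the whole sequence at once. Building the kernel naively, as done here, costs O(L·N²). Making that step cheap is the purpose of S4's structured parameterization.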

Warning

The low-rank S4 block supports only the convolutional forward pass.

See also

Original Reference: Gu, A., Goel, K., and Ré, C. (2021). “Efficiently Modeling Long Sequences with Structured State Spaces”. arXiv:2111.00396. DOI: https://doi.org/10.48550/arXiv.2111.00396.

change_forward(method)

Set the forward computation method of every block to the chosen method.

Parameters:

method (str) – The forward computation method. Available options are: “recurrent”, “convolutional”.
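Switching methods works as follows: every block keeps both computation paths plus a flag that selects one, and the model-level change_forward propagates the choice to every block. The toy sketch below illustrates this. ToyBlock, ToyLowRankBlock and ToyModel are hypothetical classes, not the library's. Several details are assumptions made for illustration: the scalar state, the direct np.convolve call (used instead of an FFT for brevity), and the ValueError raised by the convolution-only block. How the library itself handles an unsupported method is not documented here.

```python
import numpy as np


class ToyBlock:
    """Hypothetical scalar-state SSM block with two interchangeable forward modes."""
    supported = ("recurrent", "convolutional")

    def __init__(self, a_bar=0.9, b_bar=1.0, c=0.5):
        self.a_bar, self.b_bar, self.c = a_bar, b_bar, c
        self.method = "convolutional"

    def change_forward(self, method):
        if method not in self.supported:
            raise ValueError(f"{type(self).__name__} does not support {method!r}")
        self.method = method

    def forward(self, u):
        return self._recurrent(u) if self.method == "recurrent" else self._convolutional(u)

    def _recurrent(self, u):
        # x_k = a_bar * x_{k-1} + b_bar * u_k,  y_k = c * x_k
        x, ys = 0.0, []
        for u_k in u:
            x = self.a_bar * x + self.b_bar * u_k
            ys.append(self.c * x)
        return np.array(ys)

    def _convolutional(self, u):
        # Kernel K_i = c * b_bar * a_bar**i, applied as a causal convolution
        L = len(u)
        K = self.c * self.b_bar * self.a_bar ** np.arange(L)
        return np.convolve(u, K)[:L]


class ToyLowRankBlock(ToyBlock):
    """Mirrors the documented restriction: convolutional mode only."""
    supported = ("convolutional",)


class ToyModel:
    def __init__(self, blocks):
        self.blocks = blocks

    def change_forward(self, method):
        # Propagate the chosen method to every block
        for block in self.blocks:
            block.change_forward(method)

    def forward(self, u):
        for block in self.blocks:
            u = block.forward(u)
        return u


u = np.sin(np.linspace(0.0, 3.0, 32))
model = ToyModel([ToyBlock(), ToyBlock(a_bar=0.5)])
model.change_forward("convolutional")
y_conv = model.forward(u)
model.change_forward("recurrent")
y_rec = model.forward(u)
print(np.allclose(y_conv, y_rec))  # True

try:
    ToyModel([ToyLowRankBlock()]).change_forward("recurrent")
except ValueError as err:
    print(err)  # ToyLowRankBlock does not support 'recurrent'
```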

forward(x)

Forward pass of the S4 model.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor