S4BlockInterface

class ssm.model.block.s4_block_interface.S4BlockInterface(method, **kwargs)

Bases: Module, ABC

Abstract interface for S4 blocks. Every S4 block should inherit from this interface and implement its abstract methods.

This block supports two forward-pass methods, recurrent and convolutional:

  • Recurrent: applies the discretized dynamics step by step, processing the sequence one element at a time.

  • Convolutional: computes the output as a convolution with the kernel \(K\), evaluated with the Fourier transform.

The block is defined by the following equations:

\[\dot{h}(t) = Ah(t) + Bx(t), \qquad y(t) = Ch(t),\]

where \(h(t)\) is the hidden state, \(x(t)\) is the input, \(y(t)\) is the output, \(A\) is the hidden-to-hidden matrix, \(B\) is the input-to-hidden matrix, and \(C\) is the hidden-to-output matrix.
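Before either forward method can run, the continuous-time matrices have to be discretized with a step size. This page does not state which discretization rule is used, so the following is only a minimal sketch assuming the bilinear rule from the S4 paper, \(\bar{A} = (I - \tfrac{\Delta}{2}A)^{-1}(I + \tfrac{\Delta}{2}A)\) and \(\bar{B} = (I - \tfrac{\Delta}{2}A)^{-1}\Delta B\). The helper `discretize_bilinear` and its `step` argument are hypothetical, not part of this interface.

```python
import torch


def discretize_bilinear(A: torch.Tensor, B: torch.Tensor, step: float):
    """Hypothetical helper: bilinear discretization of dh/dt = A h + B x.

    Assumption: the bilinear rule from the S4 paper; the interface itself
    does not specify a discretization.
    """
    N = A.shape[0]
    I = torch.eye(N, dtype=A.dtype)
    # (I - step/2 * A) is inverted for both A_bar and B_bar; solve() avoids
    # forming the inverse explicitly.
    left = I - (step / 2.0) * A
    A_bar = torch.linalg.solve(left, I + (step / 2.0) * A)
    B_bar = torch.linalg.solve(left, step * B)
    return A_bar, B_bar


# Toy usage: a stable 2-state system with one input channel.
A = -torch.eye(2, dtype=torch.float64)
B = torch.ones(2, 1, dtype=torch.float64)
A_bar, B_bar = discretize_bilinear(A, B, step=0.1)
```

The output matrix \(C\) is unchanged by this rule.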

See also

Original reference: Gu, A., Goel, K., and Ré, C. (2021). “Efficiently Modeling Long Sequences with Structured State Spaces”. arXiv:2111.00396. DOI: https://doi.org/10.48550/arXiv.2111.00396.

abstract _compute_K(L)

Computes the convolution kernel \(K\) used by the convolutional method.

Parameters:

L (int) – The length of the sequence.

Returns:

The convolution kernel \(K\).

Return type:

torch.Tensor
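For a single input and output channel, the convolution kernel of a discretized SSM is \(K = (C\bar{B}, C\bar{A}\bar{B}, \dots, C\bar{A}^{L-1}\bar{B})\). The sketch below computes it naively in \(O(LN^2)\) time. It is useful as a reference for testing a subclass, but it is not the structured algorithm an S4 block would normally use. The name `compute_kernel_naive` and the shapes \(\bar{A}\): (N, N), \(\bar{B}\): (N, 1), \(C\): (1, N) are assumptions.

```python
import torch


def compute_kernel_naive(A_bar: torch.Tensor, B_bar: torch.Tensor,
                         C: torch.Tensor, L: int) -> torch.Tensor:
    """Hypothetical reference: K[k] = C @ A_bar^k @ B_bar for k = 0..L-1.

    Assumed shapes (single channel): A_bar (N, N), B_bar (N, 1), C (1, N).
    """
    K = []
    v = B_bar  # holds A_bar^k @ B_bar, shape (N, 1)
    for _ in range(L):
        K.append((C @ v).squeeze())  # scalar C A_bar^k B_bar
        v = A_bar @ v
    return torch.stack(K)  # shape (L,)


# Toy usage with N = 1: K = [2.0, 1.0, 0.5, 0.25]
A_bar = torch.tensor([[0.5]], dtype=torch.float64)
B_bar = torch.tensor([[1.0]], dtype=torch.float64)
C = torch.tensor([[2.0]], dtype=torch.float64)
K = compute_kernel_naive(A_bar, B_bar, C, L=4)
```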

static _preprocess(A_bar, B_bar, C)

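Preprocesses the discretized matrices A_bar and B_bar together with the output matrix C.

Parameters:

A_bar – The discretized hidden-to-hidden matrix.

B_bar – The discretized input-to-hidden matrix.

C – The hidden-to-output matrix.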

Returns:

The preprocessed matrices A_bar, B_bar, and C.

Return type:

tuple

change_forward(method)

Switches the forward pass between the recurrent and convolutional methods.

Parameters:

method (str) – The forward computation method. Available options are: recurrent, convolutional.

Raises:

ValueError – If an invalid method is provided.
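A hypothetical sketch of how a string-based switch such as `change_forward` could dispatch the forward pass. `ToyBlock` is an illustration only, not the implementation of `S4BlockInterface`; the two option names and the `ValueError` are the only details taken from this page.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in showing a string-selected forward method."""

    _METHODS = ("recurrent", "convolutional")

    def __init__(self, method: str = "recurrent"):
        super().__init__()
        self.change_forward(method)

    def change_forward(self, method: str) -> None:
        # Reject unknown names so a typo fails loudly instead of silently.
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self._METHODS}"
            )
        self.method = method

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch to the currently selected implementation.
        if self.method == "recurrent":
            return self.forward_recurrent(x)
        return self.forward_convolutional(x)

    def forward_recurrent(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # a concrete block implements this

    def forward_convolutional(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # a concrete block implements this


block = ToyBlock()
block.change_forward("convolutional")  # later calls use the FFT path
```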

forward_convolutional(x)

Forward pass using the convolutional method.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor
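The convolutional method evaluates a causal convolution of the input with the kernel \(K\). A minimal sketch with `torch.fft`, assuming a single-channel input of shape (batch, L) and a kernel of shape (L,); `causal_fft_conv` is a hypothetical helper. Zero-padding both signals to length 2L prevents the FFT's circular convolution from wrapping the tail of the sequence back onto its start.

```python
import torch


def causal_fft_conv(x: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: y[..., k] = sum_{j<=k} K[j] * x[..., k-j].

    Assumed shapes: x (..., L), K (L,). Computed in O(L log L) via the FFT.
    """
    L = x.shape[-1]
    n = 2 * L  # pad so circular convolution equals linear convolution
    X = torch.fft.rfft(x, n=n)
    Kf = torch.fft.rfft(K, n=n)  # broadcasts over the batch dimension
    y = torch.fft.irfft(X * Kf, n=n)
    return y[..., :L]  # keep only the causal part of length L


x = torch.randn(3, 8, dtype=torch.float64)
K = torch.randn(8, dtype=torch.float64)
y = causal_fft_conv(x, K)  # shape (3, 8)
```

A direct convolution costs \(O(L^2)\) per channel; the FFT route reduces this to \(O(L \log L)\).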

forward_recurrent(x)

Forward pass using the recurrent method.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor
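A minimal sketch of the recurrent method, assuming the discretized matrices are already available and a single input and output channel with input shape (batch, L); `recurrent_scan` is a hypothetical name. It unrolls \(h_k = \bar{A}h_{k-1} + \bar{B}x_k\), \(y_k = Ch_k\) from a zero initial state.

```python
import torch


def recurrent_scan(A_bar: torch.Tensor, B_bar: torch.Tensor,
                   C: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sequential SSM recurrence with h_{-1} = 0.

    Assumed shapes: A_bar (N, N), B_bar (N, 1), C (1, N), x (batch, L).
    Returns y of shape (batch, L).
    """
    batch, L = x.shape
    N = A_bar.shape[0]
    h = x.new_zeros(batch, N)  # one hidden state row per batch element
    ys = []
    for k in range(L):
        # Row-vector form of h_k = A_bar h_{k-1} + B_bar x_k.
        h = h @ A_bar.T + x[:, k:k + 1] @ B_bar.T
        ys.append((h @ C.T).squeeze(-1))  # y_k = C h_k, shape (batch,)
    return torch.stack(ys, dim=-1)


# Impulse input: the output is the kernel C A_bar^k B_bar = [2.0, 1.0, 0.5, 0.25].
A_bar = torch.tensor([[0.5]], dtype=torch.float64)
B_bar = torch.tensor([[1.0]], dtype=torch.float64)
C = torch.tensor([[2.0]], dtype=torch.float64)
x = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float64)
y = recurrent_scan(A_bar, B_bar, C, x)
```

Given the same \(\bar{A}\), \(\bar{B}\), and \(C\), this loop and a causal convolution with \(K = (C\bar{B}, C\bar{A}\bar{B}, \dots)\) produce the same outputs up to floating-point error. The recurrence keeps only an \(N\)-dimensional state between steps, which suits step-by-step inference, while the convolution parallelizes over the sequence.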