S4BlockInterface
- class ssm.model.block.s4_block_interface.S4BlockInterface(method, **kwargs)
Bases: Module, ABC
Abstract interface for S4 blocks. Every S4 block should inherit from this interface and implement its abstract methods.
This block supports two forward-pass methods, recurrent and convolutional:
Recurrent: applies the discretized dynamics one time step at a time, for sequential processing.
Convolutional: computes the output as a convolution of the input with a kernel \(K\), evaluated with the fast Fourier transform (FFT).
The block is defined by the following continuous-time state-space equations:
\[\dot{h}(t) = Ah(t) + Bx(t), \qquad y(t) = Ch(t),\]
where \(h(t)\) is the hidden state, \(x(t)\) is the input, \(y(t)\) is the output, \(A\) is the hidden-to-hidden matrix, \(B\) is the input-to-hidden matrix, and \(C\) is the hidden-to-output matrix.
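For reference, under the bilinear discretization with step size \(\Delta\) used in the original S4 paper (this interface does not itself fix a discretization rule, so treat this as an assumption about subclasses), the discretized matrices \(\bar{A}\) (A_bar) and \(\bar{B}\) (B_bar) and the two forward methods correspond to:
\[\bar{A} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\left(I + \tfrac{\Delta}{2}A\right), \qquad \bar{B} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\Delta B,\]
\[h_k = \bar{A}h_{k-1} + \bar{B}x_k, \qquad y_k = Ch_k \quad \text{(recurrent)},\]
\[K = \left(C\bar{B},\ C\bar{A}\bar{B},\ \dots,\ C\bar{A}^{L-1}\bar{B}\right), \qquad y = K * x \quad \text{(convolutional)}.\]
Unrolling the recurrence from \(h_{-1} = 0\) gives \(y_k = \sum_{j=0}^{k} C\bar{A}^{j}\bar{B}\,x_{k-j}\), which is exactly the causal convolution of \(x\) with \(K\); this is why both methods produce the same output.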
See also
Original reference: Gu, A., Goel, K., and Ré, C. (2021). “Efficiently Modeling Long Sequences with Structured State Spaces”. arXiv:2111.00396. DOI: https://doi.org/10.48550/arXiv.2111.00396.
- abstract _compute_K(L)
Computes the convolution kernel \(K\) used by the convolutional forward method.
- Parameters:
L (int) – The length of the sequence.
- Returns:
The convolution kernel \(K\).
- Return type:
torch.Tensor
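As a rough illustration of what the kernel contains, assuming a single-input, single-output system with already-discretized matrices, \(K\) stacks the terms \(C\bar{A}^k\bar{B}\) for \(k = 0, \dots, L-1\). The sketch below computes them naively with NumPy; `compute_kernel_naive` is a hypothetical helper, not part of this class, and the original S4 paper avoids this \(O(N^2 L)\) loop through a structured parameterization of \(A\).

```python
import numpy as np


def compute_kernel_naive(A_bar: np.ndarray, B_bar: np.ndarray, C: np.ndarray, L: int) -> np.ndarray:
    """Return K = (C B_bar, C A_bar B_bar, ..., C A_bar^(L-1) B_bar).

    Hypothetical helper for a SISO SSM: A_bar is (N, N), B_bar and C are (N,).
    """
    K = np.empty(L, dtype=np.result_type(A_bar, B_bar, C))
    v = B_bar.copy()  # invariant: v == A_bar^k @ B_bar at the start of iteration k
    for k in range(L):
        K[k] = C @ v
        v = A_bar @ v
    return K


# Usage: a 2-state diagonal system
A_bar = np.diag([0.5, 0.25])
B_bar = np.array([1.0, 1.0])
C = np.array([1.0, 2.0])
print(compute_kernel_naive(A_bar, B_bar, C, 4))  # values 3.0, 1.0, 0.375, 0.15625
```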
- static _preprocess(A_bar, B_bar, C)
Preprocesses the discretized matrices A_bar and B_bar and the output matrix C.
- Returns:
The preprocessed matrices A_bar, B_bar, and C.
- Return type:
tuple
- change_forward(method)
Change the forward method.
- Parameters:
method (str) – The forward computation method. Available options are "recurrent" and "convolutional".
- Raises:
ValueError – If an invalid method is provided.
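The sketch below shows one plausible way to implement this kind of runtime switch in plain Python. `ForwardDispatchSketch` and its placeholder `forward_*` bodies are hypothetical and do not reproduce the actual class, which subclasses Module and performs real tensor computations.

```python
class ForwardDispatchSketch:
    """Hypothetical stand-in showing how a block can switch its forward method at runtime."""

    _METHODS = ("recurrent", "convolutional")

    def __init__(self, method: str = "convolutional") -> None:
        self.change_forward(method)

    def change_forward(self, method: str) -> None:
        # Validate first so an invalid name leaves the current method untouched
        if method not in self._METHODS:
            raise ValueError(f"Unknown forward method {method!r}; expected one of {self._METHODS}")
        # Bind the matching implementation, e.g. "recurrent" -> self.forward_recurrent
        self._forward_impl = getattr(self, f"forward_{method}")
        self.method = method

    def forward(self, x):
        return self._forward_impl(x)

    # Placeholder bodies: they only report which path was taken
    def forward_recurrent(self, x):
        return ("recurrent", x)

    def forward_convolutional(self, x):
        return ("convolutional", x)


block = ForwardDispatchSketch("recurrent")
print(block.forward(1))  # ('recurrent', 1)
```

Validating before rebinding keeps the object in a consistent state when a bad name is passed.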
- forward_convolutional(x)
Forward pass using the convolutional method.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
torch.Tensor
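A minimal NumPy sketch of the FFT-based causal convolution, assuming a 1-D input x and a kernel K of the same length L; `causal_conv_fft` is a hypothetical helper, not this method's actual code. Zero-padding both signals to length 2L makes the circular convolution computed by the FFT equal the ordinary (linear) convolution on the first L outputs.

```python
import numpy as np


def causal_conv_fft(x: np.ndarray, K: np.ndarray) -> np.ndarray:
    """Compute y[k] = sum_{j<=k} K[j] * x[k-j] for k = 0..L-1 via the FFT.

    Hypothetical helper; x and K are 1-D arrays of the same length L.
    """
    L = x.shape[-1]
    n = 2 * L  # padding to 2L prevents wrap-around, since the linear result has length 2L-1
    X = np.fft.rfft(x, n=n)
    Kf = np.fft.rfft(K, n=n)
    # Pointwise product in frequency space == convolution in time; keep the causal part
    return np.fft.irfft(X * Kf, n=n)[..., :L]


# Usage: agrees with direct linear convolution truncated to the first L outputs
rng = np.random.default_rng(0)
x = rng.standard_normal(8)
K = rng.standard_normal(8)
print(np.allclose(causal_conv_fft(x, K), np.convolve(x, K)[:8]))  # True
```

The FFT route costs \(O(L \log L)\) instead of the \(O(L^2)\) of a direct convolution, which is what makes the convolutional method attractive for long sequences.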
- forward_recurrent(x)
Forward pass using the recurrent method.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
torch.Tensor
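For comparison, a minimal NumPy sketch of the recurrent computation, assuming a single-input, single-output system with discretized matrices A_bar of shape (N, N), B_bar and C of length N, a zero initial state, and no feedthrough term; `recurrent_scan` is a hypothetical helper, not this method's actual code. Feeding a unit impulse recovers the convolution kernel.

```python
import numpy as np


def recurrent_scan(A_bar: np.ndarray, B_bar: np.ndarray, C: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Run h_k = A_bar h_{k-1} + B_bar x_k, y_k = C h_k, starting from h_{-1} = 0.

    Hypothetical helper for a SISO SSM: A_bar (N, N), B_bar (N,), C (N,), x (L,).
    """
    dtype = np.result_type(A_bar, B_bar, C, x)
    h = np.zeros(A_bar.shape[0], dtype=dtype)
    y = np.empty(x.shape[0], dtype=dtype)
    for k, x_k in enumerate(x):
        h = A_bar @ h + B_bar * x_k  # one step of the discretized dynamics
        y[k] = C @ h                 # read out the current hidden state
    return y


# Usage: the impulse response of a 2-state diagonal system equals its kernel
A_bar = np.diag([0.5, 0.25])
B_bar = np.array([1.0, 1.0])
C = np.array([1.0, 2.0])
impulse = np.array([1.0, 0.0, 0.0, 0.0])
print(recurrent_scan(A_bar, B_bar, C, impulse))  # values 3.0, 1.0, 0.375, 0.15625
```

Each step costs \(O(N^2)\) and the L steps must run sequentially, whereas the convolutional path processes the whole sequence at once; the recurrent form is convenient when inputs arrive one step at a time.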