S4BaseBlock
- class ssm.model.block.S4BaseBlock(method, **kwargs)
Bases: S4BlockInterface
Implementation of the basic S4 block.
This block supports two forward-pass methods: recurrent and convolutional.
Recurrent: applies the discretized dynamics step by step, for sequential processing.
Convolutional: computes the output as a convolution with the kernel \(K\), evaluated efficiently with the Fourier transform.
The block is defined by the following equations:
\[\dot{h}(t) = A h(t) + B x(t), \qquad y(t) = C h(t),\]
where \(h(t)\) is the hidden state, \(x(t)\) is the input, \(y(t)\) is the output, \(A\) is the hidden-to-hidden matrix, \(B\) is the input-to-hidden matrix, and \(C\) is the hidden-to-output matrix.
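The convolutional method can be illustrated with a minimal sketch, assuming a single-input single-output sequence `x` and a precomputed kernel `K`, both 1-D tensors of length L. Zero-padding both signals to 2L makes the FFT's circular convolution equal to the causal linear convolution \(y_t = \sum_{k \le t} K_k x_{t-k}\). The function names `causal_fft_conv` and `direct_conv` are hypothetical and not part of this package, whose actual implementation may differ.

```python
import torch


def causal_fft_conv(x: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Causal convolution y[t] = sum_{k<=t} K[k] * x[t-k], computed via FFT.

    x, K: 1-D tensors of length L (hypothetical helper). Returns shape (L,).
    """
    L = x.shape[-1]
    n = 2 * L  # pad to 2L so circular convolution == linear convolution
    X = torch.fft.rfft(x, n=n)
    Kf = torch.fft.rfft(K, n=n)
    y = torch.fft.irfft(X * Kf, n=n)
    return y[..., :L]  # keep only the first L (causal) outputs


def direct_conv(x: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Reference O(L^2) causal convolution, used only to check the FFT version."""
    L = x.shape[-1]
    y = torch.zeros(L, dtype=x.dtype)
    for t in range(L):
        for k in range(t + 1):
            y[t] += K[k] * x[t - k]
    return y


# Usage: both routes agree on random inputs
torch.manual_seed(0)
x = torch.randn(16, dtype=torch.float64)
K = torch.randn(16, dtype=torch.float64)
print(torch.allclose(causal_fft_conv(x, K), direct_conv(x, K)))
```

The FFT route costs \(O(L \log L)\) instead of the \(O(L^2)\) of the direct sum, which is the motivation for the convolutional method on long sequences.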
See also
Original reference: Gu, A., Goel, K., and Ré, C. (2021). “Efficiently Modeling Long Sequences with Structured State Spaces”. arXiv:2111.00396. DOI: https://doi.org/10.48550/arXiv.2111.00396.
- _compute_K(L)
Computes the kernel \(K\) used by the convolutional method, defined in terms of the discretized matrices as \(K = [C\bar{A}^0\bar{B}, C\bar{A}^1\bar{B}, \ldots, C\bar{A}^{L-1}\bar{B}]\).
- Parameters:
L (int) – The length of the sequence.
- Returns:
The convolution kernel \(K\).
- Return type:
torch.Tensor
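A minimal sketch of the kernel construction, assuming a single-input single-output SSM with shapes `A_bar` (N, N), `B_bar` (N, 1) and `C` (1, N). It builds each term \(C\bar{A}^k\bar{B}\) by repeatedly multiplying by \(\bar{A}\) rather than forming matrix powers. The name `compute_kernel` is hypothetical; the library's `_compute_K` may use a different (for example, structured or batched) computation.

```python
import torch


def compute_kernel(A_bar: torch.Tensor, B_bar: torch.Tensor,
                   C: torch.Tensor, L: int) -> torch.Tensor:
    """Return K = [C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar], shape (L,).

    Assumed shapes: A_bar (N, N), B_bar (N, 1), C (1, N).
    """
    terms = []
    v = B_bar  # invariant: v == A_bar^k @ B_bar at iteration k
    for _ in range(L):
        terms.append((C @ v).reshape(()))  # scalar C A_bar^k B_bar
        v = A_bar @ v
    return torch.stack(terms)


# Usage: scalar SSM with A_bar = 0.5, B_bar = 1, C = 2
A_bar = torch.tensor([[0.5]])
B_bar = torch.tensor([[1.0]])
C = torch.tensor([[2.0]])
K = compute_kernel(A_bar, B_bar, C, 4)  # -> K = [2.0, 1.0, 0.5, 0.25]
```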
- _discretize()
Discretizes the continuous-time dynamics to obtain the matrices \(\bar{A}\) and \(\bar{B}\).
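This entry does not state which discretization rule is used. As an assumption, the sketch below applies the bilinear (Tustin) transform described in the S4 paper, \(\bar{A} = (I - \tfrac{\Delta}{2}A)^{-1}(I + \tfrac{\Delta}{2}A)\) and \(\bar{B} = (I - \tfrac{\Delta}{2}A)^{-1}\Delta B\), with step size \(\Delta\). The function name `discretize_bilinear` and the parameter `step` are hypothetical.

```python
import torch


def discretize_bilinear(A: torch.Tensor, B: torch.Tensor, step: float):
    """Bilinear discretization (assumed rule, as in the S4 paper).

    A: (N, N) continuous state matrix, B: (N, 1) input matrix, step: Delta.
    Returns (A_bar, B_bar).
    """
    N = A.shape[0]
    I = torch.eye(N, dtype=A.dtype)
    left = I - (step / 2) * A  # (I - Delta/2 * A)
    # Solve left @ X = RHS instead of forming the inverse explicitly
    A_bar = torch.linalg.solve(left, I + (step / 2) * A)
    B_bar = torch.linalg.solve(left, step * B)
    return A_bar, B_bar


# Usage: scalar system dh/dt = -h + x with Delta = 0.1
A_bar, B_bar = discretize_bilinear(torch.tensor([[-1.0]]), torch.tensor([[1.0]]), 0.1)
```

Using `torch.linalg.solve` avoids an explicit matrix inverse, which is usually the more numerically stable choice.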
- static _preprocess(A_bar, B_bar, C)
Preprocesses the discretized matrices A_bar and B_bar; C is returned alongside them.
- Returns:
The preprocessed matrices A_bar, B_bar, and C.
- Return type:
tuple
- static _recurrent_step(A_bar, B_bar, C, x, y, h, t)
Computes a single step of the recurrence.
- Parameters:
A_bar (torch.Tensor) – The discretized hidden-to-hidden matrix.
B_bar (torch.Tensor) – The discretized input-to-hidden matrix.
C (torch.Tensor) – The hidden-to-output matrix.
x (torch.Tensor) – The input tensor.
y (torch.Tensor) – The output tensor.
h (torch.Tensor) – The hidden state tensor.
t (int) – The current time step.
- Returns:
The updated hidden state.
- Return type:
torch.Tensor
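A minimal sketch of the recurrence \(h_t = \bar{A}h_{t-1} + \bar{B}x_t\), \(y_t = Ch_t\), assuming the same single-input single-output shapes as above and a zero initial state. The documented `_recurrent_step(A_bar, B_bar, C, x, y, h, t)` also receives `y` and `t` and returns only the updated hidden state; the hypothetical `recurrent_step` below uses a simplified signature that returns `(y_t, h)` instead, and `run_recurrent` is likewise hypothetical.

```python
import torch


def recurrent_step(A_bar, B_bar, C, x_t, h):
    """One step: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t (hypothetical)."""
    h = A_bar @ h + B_bar * x_t  # h: (N, 1), x_t: scalar tensor
    y_t = (C @ h).reshape(())    # scalar output
    return y_t, h


def run_recurrent(A_bar, B_bar, C, x):
    """Scan the recurrence over a 1-D input x, starting from h = 0."""
    N = A_bar.shape[0]
    h = torch.zeros(N, 1, dtype=x.dtype)
    ys = []
    for t in range(x.shape[0]):
        y_t, h = recurrent_step(A_bar, B_bar, C, x[t], h)
        ys.append(y_t)
    return torch.stack(ys)


# Usage: unit impulse through a scalar SSM (A_bar = 0.5, B_bar = 1, C = 2)
x = torch.tensor([1.0, 0.0, 0.0, 0.0])
y = run_recurrent(torch.tensor([[0.5]]), torch.tensor([[1.0]]), torch.tensor([[2.0]]), x)
```

Feeding a unit impulse through the recurrence returns the first L entries of the kernel \(K = [C\bar{A}^0\bar{B}, \ldots, C\bar{A}^{L-1}\bar{B}]\), which is why the recurrent and convolutional methods produce the same outputs.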