S4ShiftBlock
- class ssm.model.block.S4ShiftBlock(method, **kwargs)
Bases: S4BaseBlock
Implementation of the S4 block with shift dynamics.
This block is a variant of the S4 block that uses a shift matrix for the hidden-to-hidden dynamics, so that each new hidden state keeps a shifted copy of the previous state. The matrix A is fixed: it is not trainable.
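One common way to keep a matrix fixed in PyTorch is to register it as a buffer rather than a parameter. The sketch below is a hypothetical module (`ShiftDynamics` is not part of the library), assuming A is stored as a buffer with ones on the first subdiagonal. The actual `S4ShiftBlock` may freeze A by another mechanism, such as `requires_grad=False`.

```python
import torch
import torch.nn as nn


class ShiftDynamics(nn.Module):
    """Hypothetical module: a fixed shift matrix A plus trainable B and C."""

    def __init__(self, hid_dim: int):
        super().__init__()
        # Ones on the first subdiagonal (assumed convention):
        # (A @ h)[i] == h[i - 1] for i >= 1, and (A @ h)[0] == 0.
        A = torch.diag(torch.ones(hid_dim - 1), diagonal=-1)
        # A buffer moves with .to(device) and is saved in state_dict,
        # but it is not returned by .parameters(), so optimizers never update it.
        self.register_buffer("A", A)
        self.B = nn.Parameter(torch.randn(hid_dim))
        self.C = nn.Parameter(torch.randn(hid_dim))


module = ShiftDynamics(hid_dim=4)
param_names = [name for name, _ in module.named_parameters()]
print(param_names)  # ['B', 'C']
print(module.A.requires_grad)  # False
```

Only B and C would receive gradient updates; A is still part of the saved state.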
This block supports two forward-pass methods: recurrent and convolutional.
- Recurrent: applies the discretized dynamics one time step at a time.
- Convolutional: computes the output as a convolution over the whole sequence, using the Fourier transform.
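The two methods compute the same output for a linear time-invariant system. The sketch below checks this for a single-channel system, assuming the shift matrix is used directly as the discrete-time transition matrix and B = e_1 (the choice made for the shift SSM in H3). The helper names `shift_matrix`, `recurrent_forward`, and `convolutional_forward` are hypothetical, and the library's own forward pass may discretize or batch differently.

```python
import torch


def shift_matrix(n: int) -> torch.Tensor:
    # Ones on the first subdiagonal (assumed convention).
    return torch.diag(torch.ones(n - 1), diagonal=-1)


def recurrent_forward(A, B, C, x):
    """Sequential scan: h_k = A h_{k-1} + B x_k, y_k = C h_k, with h_{-1} = 0."""
    h = torch.zeros(A.shape[0])
    ys = []
    for x_k in x:
        h = A @ h + B * x_k
        ys.append(C @ h)
    return torch.stack(ys)


def convolutional_forward(A, B, C, x):
    """Same output via the kernel K_j = C A^j B and an FFT-based causal convolution."""
    L = x.shape[0]
    kernel = []
    v = B.clone()
    for _ in range(L):
        kernel.append(C @ v)  # K_j = C A^j B
        v = A @ v
    K = torch.stack(kernel)
    # Zero-pad to 2L so the circular FFT convolution equals the linear (causal) one.
    n = 2 * L
    y = torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(K, n=n), n=n)
    return y[:L]


torch.manual_seed(0)
N, L = 4, 10
A = shift_matrix(N)
B = torch.zeros(N)
B[0] = 1.0  # B = e_1 (assumption, following H3's shift SSM)
C = torch.randn(N)
x = torch.randn(L)

y_rec = recurrent_forward(A, B, C, x)
y_conv = convolutional_forward(A, B, C, x)
print(torch.allclose(y_rec, y_conv, atol=1e-5))  # True
```

The recurrent form costs one matrix-vector product per step, while the convolutional form processes the whole sequence at once, which is why it is typically preferred for training.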
The block is defined by the following equations:
\[\dot{h}(t) = Ah(t) + Bx(t), \quad y(t) = Ch(t),\]
where \(h(t)\) is the hidden state, \(x(t)\) is the input, \(y(t)\) is the output, \(A\) is the hidden-to-hidden matrix, \(B\) is the input-to-hidden matrix, and \(C\) is the hidden-to-output matrix.
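For the shift matrix, the convolution kernel has a simple closed form. The derivation below assumes a discrete-time recurrence that uses A directly as the transition matrix and B = e_1, as in the H3 shift SSM. The library may apply a discretization step first, in which case the kernel entries differ.

```latex
% Assumed discrete-time recurrence (A used directly, zero initial state)
h_k = A h_{k-1} + B x_k, \qquad y_k = C h_k, \qquad h_{-1} = 0

% Unrolling the recurrence yields a causal convolution with kernel K
y_k = \sum_{j=0}^{k} C A^{j} B \, x_{k-j} = (K * x)_k, \qquad K_j = C A^{j} B

% N x N shift matrix (ones on the first subdiagonal) with B = e_1
A^{j} e_1 = e_{j+1} \quad (0 \le j < N), \qquad A^{N} = 0
\Rightarrow \quad K_j = C_{j+1} \ \text{for } 0 \le j < N, \qquad K_j = 0 \ \text{for } j \ge N
```

Under these assumptions the block acts as a causal convolution of length N whose taps are the entries of C.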
See also
Original Reference: Fu, D., Dao, T., et al. (2023). “Hungry Hungry Hippos: Towards Language Modeling with State Space Models”. arXiv:2212.14052. DOI: https://doi.org/10.48550/arXiv.2212.14052.
- initialize_A(hid_dim)
Initialize the shift matrix A.
- Parameters:
hid_dim (int) – The hidden state dimension.
- Returns:
The initialized shift matrix A.
- Return type:
torch.Tensor
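A minimal sketch of what `initialize_A` might return, assuming the shift matrix has ones on the first subdiagonal (the convention in H3). `initialize_shift_matrix` is a hypothetical stand-in, not the library's implementation; it also shows that, with B = e_1, the hidden state stores the most recent `hid_dim` inputs.

```python
import torch


def initialize_shift_matrix(hid_dim: int) -> torch.Tensor:
    """Hypothetical stand-in for initialize_A: a (hid_dim x hid_dim) shift matrix.

    Entry (i, i - 1) is 1 and all other entries are 0, so multiplying a state
    vector by this matrix moves each entry one slot down and drops the last one.
    """
    if hid_dim < 1:
        raise ValueError("hid_dim must be positive")
    return torch.diag(torch.ones(hid_dim - 1), diagonal=-1)


A = initialize_shift_matrix(3)
print(A)
# tensor([[0., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.]])

# With B = e_1 (assumption), the state after each step holds the latest hid_dim inputs.
B = torch.tensor([1.0, 0.0, 0.0])
h = torch.zeros(3)
for x_k in [1.0, 2.0, 3.0, 4.0]:
    h = A @ h + B * x_k
print(h)  # tensor([4., 3., 2.])
```

Because A is nilpotent (A raised to the power `hid_dim` is the zero matrix), an input influences the state for at most `hid_dim` steps.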