S4ShiftBlock

class ssm.model.block.S4ShiftBlock(method, **kwargs)

Bases: S4BaseBlock

Implementation of the S4 block with shift dynamics.

This block is a variant of the S4 block that uses a shift matrix as the hidden-to-hidden matrix A, so that the hidden state retains a memory of the previous state. The matrix A is fixed and is not trainable.
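One common way to keep a matrix fixed in PyTorch is to store it as a buffer instead of a parameter. The sketch below shows this pattern. It is not necessarily how S4ShiftBlock stores A. The class FixedShiftSSM is hypothetical, and the subdiagonal shift convention is an assumption.

```python
import torch
import torch.nn as nn


class FixedShiftSSM(nn.Module):
    """Hypothetical module: fixed shift matrix A, trainable B and C."""

    def __init__(self, input_dim: int, hid_dim: int, output_dim: int):
        super().__init__()
        # Assumed convention: ones on the subdiagonal, so (A @ h)[i] == h[i - 1].
        A = torch.diag(torch.ones(hid_dim - 1), diagonal=-1)
        # A buffer follows .to(device) and is saved in state_dict(),
        # but .parameters() does not return it, so optimizers never update it.
        self.register_buffer("A", A)
        # B and C are ordinary trainable parameters.
        self.B = nn.Parameter(torch.randn(hid_dim, input_dim) / hid_dim**0.5)
        self.C = nn.Parameter(torch.randn(output_dim, hid_dim) / hid_dim**0.5)


model = FixedShiftSSM(input_dim=2, hid_dim=4, output_dim=3)
print(sorted(name for name, _ in model.named_parameters()))  # ['B', 'C']
print("A" in model.state_dict())  # True
```

Because A is a buffer, an optimizer built from model.parameters() only ever sees B and C.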

This block supports two forward-pass methods: recurrent and convolutional.

  • Recurrent: It applies discretized dynamics for sequential processing.

  • Convolutional: It uses the Fourier transform to compute convolutions.
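The two methods compute the same input-to-output map. The sketch below checks this for a single-channel, discrete-time SSM with a shift matrix. It assumes the recurrence h_k = A h_{k-1} + B x_k, y_k = C h_k with A used as-is, because this page does not document the block's discretization. The helpers shift_matrix, ssm_recurrent, and ssm_convolutional are hypothetical and are not part of the ssm package.

```python
import torch


def shift_matrix(n: int, dtype=torch.float64) -> torch.Tensor:
    # Assumed convention: ones on the subdiagonal, so (A @ h)[i] == h[i - 1].
    return torch.diag(torch.ones(n - 1, dtype=dtype), diagonal=-1)


def ssm_recurrent(A, B, C, x):
    """Step h_k = A h_{k-1} + B x_k, y_k = C h_k over a 1-D sequence x."""
    h = torch.zeros(A.shape[0], dtype=x.dtype)
    ys = []
    for xk in x:
        h = A @ h + B * xk
        ys.append(C @ h)
    return torch.stack(ys)


def ssm_convolutional(A, B, C, x):
    """Same map, computed as a causal convolution with K_j = C A^j B via the FFT."""
    L = x.shape[0]
    # Build the kernel K = (CB, CAB, CA^2B, ...) of length L.
    K = []
    v = B.clone()
    for _ in range(L):
        K.append(C @ v)
        v = A @ v
    K = torch.stack(K)
    # Zero-pad to 2L so the circular FFT convolution equals the linear one.
    n = 2 * L
    y = torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(K, n=n), n=n)
    return y[:L]  # keep the causal part


torch.manual_seed(0)
N, L = 4, 16
A = shift_matrix(N)
B = torch.randn(N, dtype=torch.float64)
C = torch.randn(N, dtype=torch.float64)
x = torch.randn(L, dtype=torch.float64)

y_rec = ssm_recurrent(A, B, C, x)
y_conv = ssm_convolutional(A, B, C, x)
print(torch.allclose(y_rec, y_conv))  # True
```

The recurrent form processes one step at a time, which suits streaming inference. Once the kernel is built, the FFT form computes the whole sequence at once in O(L log L) time.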

The block is defined by the following equations:

\[\dot{h}(t) = A h(t) + B x(t), \qquad y(t) = C h(t),\]

where \(h(t)\) is the hidden state, \(x(t)\) is the input, \(y(t)\) is the output, \(A\) is the hidden-to-hidden matrix, \(B\) is the input-to-hidden matrix, and \(C\) is the hidden-to-output matrix.
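As a worked sketch, assume that the shift matrix serves directly as the discrete-time state matrix. This page does not state how the block discretizes the equation above. Starting from \(h_{-1} = 0\), the recurrence unrolls into a causal convolution:

\[h_k = A h_{k-1} + B x_k, \qquad y_k = C h_k \quad\Longrightarrow\quad y_k = \sum_{j=0}^{k} K_j \, x_{k-j}, \qquad K_j = C A^{j} B.\]

Take the \(N \times N\) shift matrix, where \(N\) is hid_dim, with the assumed convention \([A h]_0 = 0\) and \([A h]_i = [h]_{i-1}\) for \(1 \le i \le N-1\). Then \(A^{N} = 0\), so \(K_j = 0\) for \(j \ge N\), and each output depends only on the last \(N\) inputs.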

See also

Original reference: Fu, D., Dao, T., et al. (2023). “Hungry Hungry Hippos: Towards Language Modeling with State Space Models”. arXiv:2212.14052. DOI: https://doi.org/10.48550/arXiv.2212.14052.

initialize_A(hid_dim)

Initialize the shift matrix A.

Parameters:

hid_dim (int) – The hidden state dimension.

Returns:

The initialized shift matrix A.

Return type:

torch.Tensor
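This page does not show the exact matrix that initialize_A returns. The following is a minimal sketch under the common convention of ones on the subdiagonal, which makes A shift each state entry down by one position. The function initialize_shift_A is a hypothetical stand-in, not the library method. The input vector B = e_1, used to show the memory effect, is also an illustrative assumption.

```python
import torch


def initialize_shift_A(hid_dim: int) -> torch.Tensor:
    """Hypothetical stand-in for initialize_A: a (hid_dim, hid_dim) shift matrix.

    Assumes ones on the subdiagonal, so (A @ h)[i] == h[i - 1] and (A @ h)[0] == 0.
    """
    return torch.diag(torch.ones(hid_dim - 1), diagonal=-1)


A = initialize_shift_A(4)
print(A)
# tensor([[0., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.]])

# With B = e_1 (an illustrative assumption), the state acts as a buffer
# holding the most recent inputs, newest first.
B = torch.zeros(4)
B[0] = 1.0
h = torch.zeros(4)
for x in [1.0, 2.0, 3.0, 4.0, 5.0]:
    h = A @ h + B * x
print(h)  # tensor([5., 4., 3., 2.])
```

After five inputs, the oldest value (1.0) has been shifted out of the four-dimensional state.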