S4DiagonalBlock

class ssm.model.block.S4DBlock(method, **kwargs)

Bases: S4BlockInterface

Implementation of the diagonal S4 block.

This block is a variant of the S4 block in which the hidden-to-hidden matrix \(A\) is diagonal. The diagonal structure simplifies both the logic and the implementation of the S4 block while providing the same functionality.

This block supports two forward-pass methods: recurrent and convolutional.

  • Recurrent: applies the discretized dynamics step by step along the sequence.

  • Convolutional: computes the output as a convolution of the input with the kernel K, evaluated with the Fourier transform.

The block is defined by the following equations:

\[\dot{h}(t) = A h(t) + B x(t), \qquad y(t) = C h(t),\]

where \(h(t)\) is the hidden state, \(x(t)\) is the input, \(y(t)\) is the output, \(A\) is the hidden-to-hidden matrix, \(B\) is the input-to-hidden matrix, and \(C\) is the hidden-to-output matrix.
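
Because \(A\) is diagonal in this block, the state equation decouples into independent scalar equations. As a sketch, assuming a single-input, single-output system with \(A = \mathrm{diag}(\lambda_1, \dots, \lambda_N)\) and \(B\), \(C\) treated as vectors with entries \(B_n\), \(C_n\) (this notation is introduced here for illustration and is not taken from the implementation):

\[\dot{h}_n(t) = \lambda_n h_n(t) + B_n x(t), \quad n = 1, \dots, N, \qquad y(t) = \sum_{n=1}^{N} C_n h_n(t).\]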

See also

Original Reference: Gu, A., Gupta, A., Goel, K., and Ré, C. (2022). “On the Parameterization and Initialization of Diagonal State Space Models”. arXiv:2206.11893. URL: https://arxiv.org/pdf/2206.11893

_compute_K(L)

Computation of the kernel K used in the convolutional method.
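
To illustrate what such a kernel looks like, the sketch below assumes a diagonal, single-input, single-output system in which `A_bar`, `B_bar`, and `C` are plain lists of per-state coefficients. Under that assumption the kernel entry at lag l is the sum over states of `C[n] * B_bar[n] * A_bar[n] ** l`. The function names are hypothetical, and the convolution is evaluated by direct summation rather than with the Fourier transform, purely to keep the example short. It is not the block's actual code.

```python
def compute_kernel(L, A_bar, B_bar, C):
    """Kernel K[l] = sum_n C[n] * B_bar[n] * A_bar[n]**l for l = 0..L-1."""
    return [
        sum(c * b * a ** l for a, b, c in zip(A_bar, B_bar, C))
        for l in range(L)
    ]


def causal_conv(K, x):
    """Direct causal convolution y[t] = sum_{j<=t} K[j] * x[t-j]."""
    return [
        sum(K[j] * x[t - j] for j in range(t + 1))
        for t in range(len(x))
    ]


def recurrent_scan(A_bar, B_bar, C, x):
    """Reference output from the step-by-step recurrence, starting at h = 0."""
    h = [0.0] * len(A_bar)
    ys = []
    for x_t in x:
        h = [a * h_n + b * x_t for a, b, h_n in zip(A_bar, B_bar, h)]
        ys.append(sum(c * h_n for c, h_n in zip(C, h)))
    return ys


# Illustrative values only.
A_bar = [0.5, 0.9 + 0.1j]
B_bar = [1.0, 0.5]
C = [1.0, 2.0 - 1.0j]
x = [1.0, -2.0, 0.5, 3.0]

K = compute_kernel(len(x), A_bar, B_bar, C)
y_conv = causal_conv(K, x)
y_rec = recurrent_scan(A_bar, B_bar, C, x)
```

With a zero initial state, `y_conv` and `y_rec` agree up to floating-point error, which is why the two forward-pass methods can be used interchangeably.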

_discretize_bilinear()

Discretization of the continuous-time dynamics with the bilinear transform to obtain the matrices \(\bar{A}\) and \(\bar{B}\).
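
For a diagonal system, the standard bilinear transform can be applied entry by entry. The sketch below assumes `A` and `B` are lists of per-state coefficients and that a step size `delta` is available. The method signature above takes no arguments, so how the block stores these values is not shown here, and both the argument list and the function name are hypothetical.

```python
def discretize_bilinear(A, B, delta):
    """Elementwise bilinear transform for a diagonal state matrix.

    A_bar[n] = (1 + delta * A[n] / 2) / (1 - delta * A[n] / 2)
    B_bar[n] = delta * B[n] / (1 - delta * A[n] / 2)
    """
    A_bar, B_bar = [], []
    for a, b in zip(A, B):
        denom = 1.0 - 0.5 * delta * a
        A_bar.append((1.0 + 0.5 * delta * a) / denom)
        B_bar.append(delta * b / denom)
    return A_bar, B_bar


# Illustrative values: one real pole and one complex pole.
A_bar, B_bar = discretize_bilinear([-2.0, -0.5 + 1.0j], [1.0, 1.0], delta=1.0)
```

Any continuous-time pole with a negative real part is mapped to a discrete-time pole inside the unit circle, so stability is preserved.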

_discretize_zoh()

Discretization of the continuous-time dynamics with the zero-order hold (ZOH) method to obtain the matrices \(\bar{A}\) and \(\bar{B}\).
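
For a diagonal system, the standard zero-order hold formulas also reduce to entrywise operations. As with the previous sketch, the arguments `A`, `B`, and `delta` and the function name are assumptions made for illustration. The sketch also assumes every entry of `A` is nonzero.

```python
import cmath


def discretize_zoh(A, B, delta):
    """Elementwise zero-order hold for a diagonal state matrix.

    A_bar[n] = exp(delta * A[n])
    B_bar[n] = (exp(delta * A[n]) - 1) / A[n] * B[n]
    """
    A_bar = [cmath.exp(delta * a) for a in A]
    B_bar = [(a_bar - 1.0) / a * b for a_bar, a, b in zip(A_bar, A, B)]
    return A_bar, B_bar


# Illustrative values only.
A_bar, B_bar = discretize_zoh([-1.0], [1.0], delta=1.0)
A_small, B_small = discretize_zoh([-0.5 + 2.0j], [1.0], delta=1e-3)
```

For small step sizes, `B_bar` approaches `delta * B`, which matches the forward Euler approximation.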

static _recurrent_step(A_bar, B_bar, C, x, y, h, t)

Recurrent step computation.

Parameters:
  • A_bar (torch.Tensor) – The discretized hidden-to-hidden matrix.

  • B_bar (torch.Tensor) – The discretized input-to-hidden matrix.

  • C (torch.Tensor) – The hidden-to-output matrix.

  • x (torch.Tensor) – The input tensor.

  • y (torch.Tensor) – The output tensor.

  • h (torch.Tensor) – The hidden state tensor.

  • t (int) – The current time step.

Returns:

The updated hidden state.

Return type:

torch.Tensor
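
A minimal sketch of one recurrent step follows, written with plain Python lists instead of tensors. It assumes a scalar input sequence `x`, a diagonal `A_bar` stored as a list, and that the output for step `t` is written into `y` in place while the new hidden state is returned. These details are inferred from the signature and may differ from the actual implementation.

```python
def recurrent_step(A_bar, B_bar, C, x, y, h, t):
    """One step: h_t = A_bar * h_{t-1} + B_bar * x_t, then y_t = C . h_t."""
    # Elementwise update, valid because A_bar is diagonal.
    h = [a * h_n + b * x[t] for a, b, h_n in zip(A_bar, B_bar, h)]
    # Assumption: the output buffer is filled in place at index t.
    y[t] = sum(c * h_n for c, h_n in zip(C, h))
    return h


# Illustrative values: a single state driven by a unit impulse.
x = [1.0, 0.0]
y = [0.0, 0.0]
h = [0.0]
for t in range(len(x)):
    h = recurrent_step([0.5], [1.0], [2.0], x, y, h, t)
```

After the loop, `y` holds the impulse response of the system sampled at the first two time steps.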

static initialize_A(hid_dim, init_method, real_random=False, imag_random=False)

Initialization of the A matrix.

Parameters:
  • hid_dim (int) – The hidden state dimension.

  • init_method (str) – The method for initializing the A matrix. Options are: S4D-Inv, S4D-Lin, S4D-Quad, S4D-Real.

  • real_random (bool) – If True, the real part of the A matrix is initialized at random between 0 and 1. Default is False.

  • imag_random (bool) – If True, the imaginary part of the A matrix is initialized at random between 0 and 1. Default is False.

Returns:

The initialized A matrix.

Return type:

torch.Tensor

Raises:

ValueError – If an unknown initialization method is provided.
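
As an illustration, the sketch below implements two of the listed options using the formulas from the referenced paper: S4D-Lin sets the n-th diagonal entry to -1/2 + iπn, and S4D-Inv sets it to -1/2 + i(N/π)(N/(2n+1) - 1), where N is the hidden dimension. The remaining options and the `real_random` and `imag_random` flags are omitted, and the function returns a plain list rather than a tensor. Treat it as a sketch under these assumptions, not as the library's code.

```python
import math


def initialize_A(hid_dim, init_method):
    """Diagonal entries of A for two S4D initializations (sketch)."""
    if init_method == "S4D-Lin":
        return [complex(-0.5, math.pi * n) for n in range(hid_dim)]
    if init_method == "S4D-Inv":
        return [
            complex(-0.5, hid_dim / math.pi * (hid_dim / (2 * n + 1) - 1))
            for n in range(hid_dim)
        ]
    # Other options (S4D-Quad, S4D-Real) are not covered by this sketch.
    raise ValueError(f"Unknown initialization method: {init_method}")


A_lin = initialize_A(3, "S4D-Lin")
A_inv = initialize_A(2, "S4D-Inv")
```

Both schemes fix the real part at -1/2, so every mode decays at the same rate, and they differ only in how the imaginary parts (the frequencies) are spaced.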

static vandermonde_matrix(L, A_bar)

Compute the Vandermonde matrix for the diagonal S4 block.

Parameters:
  • L (int) – The length of the sequence.

  • A_bar (torch.Tensor) – The discretized hidden-to-hidden matrix.

Returns:

The Vandermonde matrix.

Return type:

torch.Tensor
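
A Vandermonde matrix built from the diagonal of `A_bar` collects its successive powers. The sketch below assumes `A_bar` is given as a list of N diagonal entries and returns an N-by-L nested list whose entry (n, l) is `A_bar[n] ** l`. The orientation and the use of lists instead of tensors are assumptions made for illustration.

```python
def vandermonde_matrix(L, A_bar):
    """Return V with V[n][l] = A_bar[n] ** l for l = 0..L-1."""
    return [[a ** l for l in range(L)] for a in A_bar]


# Illustrative values only.
V = vandermonde_matrix(3, [2.0, 0.5])
```

Multiplying the row vector of products `C[n] * B_bar[n]` by this matrix yields the length-L convolution kernel in one step.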