S4LowRankBlock

class ssm.model.block.S4LowRankBlock(method, **kwargs)

Bases: S4BlockInterface

Implementation of the low-rank S4 block.

This block supports only the convolutional method for the forward pass. It exploits the low-rank structure of the state matrix to compute the convolution kernel efficiently through Cauchy kernel products.
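To make the convolutional view concrete, here is a minimal NumPy sketch (the block itself works on torch tensors) using a small dense discrete SSM. The matrices `Abar`, `Bbar`, `C` and all function names are illustrative assumptions. The sketch checks that running the recurrence x_k = Abar x_{k-1} + Bbar u_k, y_k = C x_k gives the same output as a causal convolution with the kernel K_k = C Abar^k Bbar.

```python
import numpy as np

def ssm_recurrent(Abar, Bbar, C, u):
    """Hypothetical reference: x_k = Abar x_{k-1} + Bbar u_k, y_k = C x_k, x_{-1} = 0."""
    x = np.zeros(Abar.shape[0])
    ys = []
    for u_k in u:
        x = Abar @ x + Bbar * u_k
        ys.append(C @ x)
    return np.array(ys)

def ssm_kernel(Abar, Bbar, C, L):
    """Naive kernel K_k = C Abar^k Bbar for k = 0..L-1 (O(L N^2) work)."""
    K = np.empty(L)
    v = Bbar.copy()
    for k in range(L):
        K[k] = C @ v
        v = Abar @ v
    return K

def ssm_convolutional(Abar, Bbar, C, u):
    """Same output computed as the causal convolution y = K * u."""
    L = len(u)
    K = ssm_kernel(Abar, Bbar, C, L)
    return np.convolve(u, K)[:L]

rng = np.random.default_rng(0)
N, L = 3, 10
Abar = 0.5 * rng.normal(size=(N, N)) / np.sqrt(N)
Bbar = rng.normal(size=N)
C = rng.normal(size=N)
u = rng.normal(size=L)
print(np.allclose(ssm_recurrent(Abar, Bbar, C, u), ssm_convolutional(Abar, Bbar, C, u)))  # True
```

Computing K this way costs L matrix-vector products; the point of the low-rank structure is to avoid exactly this.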

See also

Original Reference: Gu, A., Goel, K., and Ré, C. (2021). “Efficiently Modeling Long Sequences with Structured State Spaces”. arXiv:2111.00396. DOI: https://doi.org/10.48550/arXiv.2111.00396

_cauchy_dot(a0, a1, b0, b1, denominator)

Compute the Cauchy kernel dot products between the a and b inputs over the given denominator.

Parameters:
  • a0 (torch.Tensor) – Matrix A0.

  • a1 (torch.Tensor) – Matrix A1.

  • b0 (torch.Tensor) – Matrix B0.

  • b1 (torch.Tensor) – Matrix B1.

  • denominator (torch.Tensor) – Denominator tensor.

Returns:

The Cauchy kernel dot products.

Return type:

tuple
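In the S4 paper, the low-rank correction is absorbed with the Woodbury identity, leaving sums of the form Σ_n a_n b_n / (z_l - λ_n) over evaluation points z_l and eigenvalues λ_n of the diagonal part, i.e. Cauchy matrix-vector products. The NumPy sketch below is one plausible reading of `_cauchy_dot`; the argument pairing, the shapes ((N,) vectors and an (L, N) denominator) and the order of the returned tuple are assumptions, not the library's documented behavior.

```python
import numpy as np

def cauchy_dot(a0, a1, b0, b1, denominator):
    """Hypothetical sketch of Cauchy kernel dot products (shapes and pairing assumed).

    a0, a1, b0, b1: complex vectors of shape (N,)
    denominator:    complex array of shape (L, N), entry (l, n) = z_l - lambda_n
    Returns (k00, k01, k10, k11), each of shape (L,), with
    k_ij[l] = sum_n a_i[n] * b_j[n] / denominator[l, n].
    """
    def k(a, b):
        # Broadcast a*b over the L rows of the denominator, then reduce over N.
        return (a * b / denominator).sum(axis=-1)

    return k(a0, b0), k(a0, b1), k(a1, b0), k(a1, b1)

# Usage: each sum is a Cauchy matrix M[l, n] = 1 / (z_l - lambda_n) applied to a vector.
rng = np.random.default_rng(1)
N, L = 4, 6
lam = -rng.uniform(0.1, 1.0, N) + 1j * rng.normal(size=N)
z = np.exp(-2j * np.pi * np.arange(L) / L)
denominator = z[:, None] - lam[None, :]
a0, a1, b0, b1 = (rng.normal(size=N) + 1j * rng.normal(size=N) for _ in range(4))
k00, k01, k10, k11 = cauchy_dot(a0, a1, b0, b1, denominator)
cauchy_matrix = 1.0 / denominator
print(np.allclose(k01, cauchy_matrix @ (a0 * b1)))  # True
```

Written this way each product costs O(L N); the S4 paper notes that specialized Cauchy kernel algorithms can lower this further.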

_compute_K(L)

Compute the convolution kernel K used by the convolutional method.

Parameters:

L (int) – The length of the sequence.

Returns:

The convolution kernel K.

Return type:

torch.Tensor
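For orientation, the S4 paper computes K without forming powers of the state matrix: it evaluates the kernel's truncated generating function at the L roots of unity, reduces the low-rank term to four Cauchy sums with the Woodbury identity, and recovers K with an inverse FFT. The NumPy sketch below follows that outline for A = diag(Λ) - P Q* under bilinear discretization. Assumptions: the function name and argument layout are hypothetical, P and Q are rank 1, C~ = C(I - Abar^L) is formed densely here for brevity (S4 learns C~ directly), and the Woodbury step is written multiplied through by (1 + z) so that z = -1 needs no special case.

```python
import numpy as np

def dplr_kernel(Lambda, P, Q, B, C, dt, L):
    """Hypothetical sketch (not the library's _compute_K): kernel K_k = C Abar^k Bbar,
    k = 0..L-1, for A = diag(Lambda) - P Q^* with rank-1 P, Q and bilinear discretization."""
    N = Lambda.shape[0]
    I = np.eye(N)
    A = np.diag(Lambda) - np.outer(P, Q.conj())
    # Dense C~ = C (I - Abar^L), formed here only for brevity (S4 learns C~ directly).
    Abar = np.linalg.solve(I - dt / 2 * A, I + dt / 2 * A)
    C_tilde = C @ (I - np.linalg.matrix_power(Abar, L))

    # Evaluation points: the L-th roots of unity z_l = exp(-2 pi i l / L).
    z = np.exp(-2j * np.pi * np.arange(L) / L)
    s = dt / 2 * (1 + z)
    # (I - Abar z)^{-1} Bbar = M^{-1} (dt B), where M = (1 - z) I - s A = D + s P Q^*
    # and D is diagonal with entries (1 - z) - s * Lambda_n; denom holds D, shape (L, N).
    denom = (1 - z)[:, None] - s[:, None] * Lambda[None, :]

    def cauchy(a, b):
        # sum_n a_n b_n / D_ln for every evaluation point l
        return (a * b / denom).sum(axis=-1)

    Qc = Q.conj()
    k00, k01 = cauchy(C_tilde, B), cauchy(C_tilde, P)
    k10, k11 = cauchy(Qc, B), cauchy(Qc, P)
    # Woodbury identity for the rank-1 correction s P Q^*.
    K_hat = dt * (k00 - s * k01 * k10 / (1 + s * k11))
    # K_hat[l] = sum_k K_k z_l^k is the DFT of K, so the inverse FFT recovers K.
    return np.fft.ifft(K_hat)

# Usage: compare against naive powering of Abar on a small random system.
rng = np.random.default_rng(0)
N, L, dt = 4, 16, 0.1
Lambda = -rng.uniform(0.5, 1.0, N) + 1j * rng.normal(size=N)
P, Q, B, C = (0.3 * (rng.normal(size=N) + 1j * rng.normal(size=N)) for _ in range(4))
K = dplr_kernel(Lambda, P, Q, B, C, dt, L)

I = np.eye(N)
A = np.diag(Lambda) - np.outer(P, Q.conj())
Abar = np.linalg.solve(I - dt / 2 * A, I + dt / 2 * A)
Bbar = np.linalg.solve(I - dt / 2 * A, dt * B)
K_naive = np.array([C @ np.linalg.matrix_power(Abar, k) @ Bbar for k in range(L)])
print(np.allclose(K, K_naive))  # True
```

Only the diagonal entries Λ enter the denominators, which is what keeps the per-point cost linear in N.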

static _compute_omega(L)

Compute the roots of unity for the FFT.

Parameters:

L (int) – Length of the sequence.

Returns:

The roots of unity.

Return type:

torch.Tensor
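Roots of unity are used because evaluating a length-L polynomial at all of them is exactly a discrete Fourier transform, which is what allows the kernel to be recovered with an (inverse) FFT. A small NumPy check, assuming the exp(-2πi l/L) sign convention (the block may use the conjugate convention and adjust its transform accordingly); `roots_of_unity` is a hypothetical stand-in for `_compute_omega`:

```python
import numpy as np

def roots_of_unity(L):
    """Hypothetical helper: z_l = exp(-2 pi i l / L), l = 0..L-1 (sign convention assumed)."""
    return np.exp(-2j * np.pi * np.arange(L) / L)

L = 8
z = roots_of_unity(L)
coeffs = np.arange(1.0, L + 1)                     # any length-L sequence, e.g. a kernel
# Evaluate the polynomial sum_k coeffs[k] * z_l**k at every root (Vandermonde product).
values = (z[:, None] ** np.arange(L)[None, :]) @ coeffs
print(np.allclose(z ** L, 1))                      # True: each z_l is an L-th root of unity
print(np.allclose(values, np.fft.fft(coeffs)))     # True: the evaluations form the DFT
```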

forward_convolutional(x)

Forward pass using the convolutional method.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor
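Once K is available, a convolutional forward pass typically applies it with an FFT. Zero-padding to length 2L turns the FFT's circular convolution into the causal linear convolution y_k = Σ_{j≤k} K_{k-j} u_j. In this NumPy sketch, the function name and the (batch, channels, length) input layout are assumptions, not the block's documented interface:

```python
import numpy as np

def fft_causal_conv(u, K):
    """Hypothetical sketch: y[..., k] = sum_{j<=k} K[k-j] u[..., j] along the last axis.

    u: input of shape (..., L); K: kernel of shape (L,), broadcast over leading axes.
    Padding both signals to 2L keeps the wrap-around of the circular FFT product out of
    the first L outputs.
    """
    L = u.shape[-1]
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(K, n=n), n=n)
    return y[..., :L]

rng = np.random.default_rng(0)
u = rng.normal(size=(2, 3, 32))    # assumed layout: (batch, channels, length)
K = rng.normal(size=32)
y = fft_causal_conv(u, K)
ref = np.convolve(u[1, 2], K)[:32]
print(y.shape, np.allclose(y[1, 2], ref))  # (2, 3, 32) True
```

The FFT route costs O(L log L) per channel instead of the O(L^2) of a direct convolution.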