S6Block

class ssm.model.block.S6Block(model_dim, hid_dim, dt_min=0.001, dt_max=0.1, real_random=False, dt_rank=None, scan_type='parallel', **kwargs)

Bases: Module

Implementation of the S6 block.

This block models long sequences efficiently with a selective state space model. Its selection mechanism lets it focus on the relevant parts of the input sequence, which makes it suitable for tasks such as selective copying.

The output is computed efficiently with a parallel scan, which is the default (scan_type='parallel').
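As a rough illustration of why such a scan can be parallelized, the sketch below rewrites a first-order linear recurrence \(h_t = a_t h_{t-1} + b_t\) (zero state before the first step) in terms of cumulative sums. The form of the recurrence is an assumption, not something this page states, and the helper name `scan_via_cumsum` is hypothetical. The sketch also assumes strictly positive `a` and is not the block's actual implementation.

```python
import torch

def scan_via_cumsum(a, b):
    """Evaluate h_t = a_t * h_{t-1} + b_t without a Python loop.

    Hypothetical helper for illustration only.
    a, b: tensors of shape (batch, length, ...), with a > 0 elementwise.
    The state before the first step is assumed to be zero.
    """
    # log of the running product P_t = a_1 * ... * a_t along the sequence axis
    log_p = torch.cumsum(torch.log(a), dim=1)
    # Closed form: h_t = P_t * sum_{k <= t} b_k / P_k
    return torch.exp(log_p) * torch.cumsum(b * torch.exp(-log_p), dim=1)

torch.manual_seed(0)
a = torch.rand(2, 6, 4, 3, dtype=torch.float64) * 0.5 + 0.5  # values in [0.5, 1)
b = torch.randn(2, 6, 4, 3, dtype=torch.float64)
h = scan_via_cumsum(a, b)
print(h.shape)  # torch.Size([2, 6, 4, 3])
```

Dividing by the running product can underflow or overflow on long sequences, so this closed form is only suitable for short examples. The log-space formulation in the Heinsen reference is intended to address that.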

See also

Original Reference: Gu, A., Dao, T. (2024). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”. arXiv:2312.00752. URL: https://arxiv.org/abs/2312.00752.

Original Reference: Heinsen, F. A. (2023). “Efficient Parallelization of a Ubiquitous Sequential Computation”. arXiv:2311.06281. URL: https://arxiv.org/abs/2311.06281.

_discretize(A, B, dt)

Discretization of the continuous-time dynamics to obtain the matrices \(\bar{A}\) and \(\bar{B}\).

Parameters:
  • A (torch.Tensor) – The hidden-to-hidden matrix.

  • B (torch.Tensor) – The input-to-hidden matrix.

  • dt (torch.Tensor) – The time step for discretization.

Returns:

The discretized matrices \(\bar{A}\) and \(\bar{B}\).

Return type:

tuple
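This page does not say which discretization rule the method applies. As one possibility, the sketch below uses the zero-order-hold rule for a diagonal state matrix, \(\bar{A} = \exp(\Delta A)\) and \(\bar{B} = (\exp(\Delta A) - 1) / A \cdot B\). The helper name `discretize_zoh`, the diagonal parameterization, and all tensor shapes are assumptions chosen for illustration. The time steps are drawn from the default range given by dt_min=0.001 and dt_max=0.1.

```python
import torch

def discretize_zoh(A, B, dt):
    """Zero-order-hold discretization for a diagonal state matrix.

    Hypothetical helper; shapes are assumptions:
      A:  (D, N)             diagonal entries, one row per channel
      B:  (batch, length, N) input-dependent input matrix
      dt: (batch, length, D) input-dependent time step
    Returns A_bar and B_bar, both of shape (batch, length, D, N).
    """
    dA = dt.unsqueeze(-1) * A                    # broadcast to (batch, length, D, N)
    A_bar = torch.exp(dA)
    B_bar = (A_bar - 1.0) / A * B.unsqueeze(2)   # elementwise, since A is diagonal
    return A_bar, B_bar

torch.manual_seed(0)
batch, length, D, N = 2, 5, 4, 3
A = -(torch.rand(D, N) + 0.1)                       # negative real entries for stability
B = torch.randn(batch, length, N)
dt = torch.rand(batch, length, D) * 0.099 + 0.001   # within [dt_min, dt_max)
A_bar, B_bar = discretize_zoh(A, B, dt)
print(A_bar.shape, B_bar.shape)
```

For small time steps, \(\bar{B}\) is close to \(\Delta B\), which is the simplified rule some implementations use in place of the exact expression.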

forward(x)

Forward pass of the S6 block.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor

static sequential_scan(A, B)

Sequential scan along the sequence dimension using the given tensors A and B.

Parameters:
  • A (torch.Tensor) – A tensor of shape (B, L, D, N).

  • B (torch.Tensor) – Another tensor of shape (B, L, D, N).

Returns:

The output tensor after sequential scan, of shape (B, L, D, N).

Return type:

torch.Tensor
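A minimal sketch of what a sequential scan over tensors of this shape could look like, assuming it evaluates the elementwise recurrence \(h_t = A_t h_{t-1} + B_t\) with a zero initial state. The recurrence and the name `sequential_scan_sketch` are assumptions, not the library's confirmed behavior.

```python
import torch

def sequential_scan_sketch(A, B):
    """Step through the sequence axis one position at a time.

    Hypothetical helper. A, B: tensors of shape (batch, length, D, N).
    Returns the stacked states, also of shape (batch, length, D, N).
    """
    h = torch.zeros_like(B[:, 0])        # assumed zero initial state, shape (batch, D, N)
    states = []
    for t in range(A.shape[1]):
        h = A[:, t] * h + B[:, t]        # elementwise update at step t
        states.append(h)
    return torch.stack(states, dim=1)    # restore the sequence axis

A = torch.full((2, 5, 4, 3), 0.5)
B = torch.ones(2, 5, 4, 3)
out = sequential_scan_sketch(A, B)
print(out[0, :, 0, 0])  # values 1.0, 1.5, 1.75, 1.875, 1.9375
```

The loop runs in time linear in the sequence length and cannot be parallelized across time steps, which is the cost a parallel scan avoids.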