MambaBlock

class ssm.model.block.MambaBlock(model_dim, expansion_factor=2, kernel_size=4, ssm_type='S4', **kwargs)

Bases: Module

Implementation of the Mamba block. It combines a linear layer, a convolutional layer, and an SSM block, as described in the original Mamba paper and the official repository.
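The following is a minimal PyTorch sketch of the Mamba-style dataflow. It is not this library's implementation. `MambaBlockSketch` and `DiagonalSSM` are hypothetical names. To keep the example short, the SSM is a simple non-selective diagonal recurrence rather than S4, S4D, S4LowRank, or S6. The sketch shows the input projection into an SSM branch and a gate branch, the causal depthwise convolution, SiLU gating, and the output projection back to `model_dim`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiagonalSSM(nn.Module):
    """Hypothetical stand-in for the SSM block (NOT selective S6):
    h_t = a * h_{t-1} + b * x_t,  y_t = c * h_t, applied per channel."""

    def __init__(self, dim: int):
        super().__init__()
        # Sigmoid keeps the decay a in (0, 1), so the recurrence is stable.
        self.a_logit = nn.Parameter(torch.zeros(dim))
        self.b = nn.Parameter(torch.ones(dim))
        self.c = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, L, D)
        a = torch.sigmoid(self.a_logit)
        h = x.new_zeros(x.shape[0], x.shape[2])  # state: (B, D)
        ys = []
        for t in range(x.shape[1]):  # sequential scan over the sequence
            h = a * h + self.b * x[:, t]
            ys.append(self.c * h)
        return torch.stack(ys, dim=1)  # (B, L, D)


class MambaBlockSketch(nn.Module):
    """Hypothetical Mamba-style block; mirrors the constructor arguments above."""

    def __init__(self, model_dim: int, expansion_factor: int = 2, kernel_size: int = 4):
        super().__init__()
        inner = expansion_factor * model_dim
        # One linear layer produces both the SSM branch and the gate branch.
        self.in_proj = nn.Linear(model_dim, 2 * inner)
        # Depthwise conv over time; padding + trimming in forward() makes it causal.
        self.conv = nn.Conv1d(inner, inner, kernel_size, groups=inner,
                              padding=kernel_size - 1)
        self.ssm = DiagonalSSM(inner)
        self.out_proj = nn.Linear(inner, model_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, L, H)
        seq_len = x.shape[1]
        u, gate = self.in_proj(x).chunk(2, dim=-1)  # each (B, L, inner)
        # Conv1d expects (B, C, L); keep only the first L outputs to stay causal.
        u = self.conv(u.transpose(1, 2))[..., :seq_len].transpose(1, 2)
        y = self.ssm(F.silu(u))
        # Multiplicative SiLU gating, then project back to model_dim.
        return self.out_proj(y * F.silu(gate))


block = MambaBlockSketch(model_dim=16)
out = block(torch.randn(2, 10, 16))
print(out.shape)  # torch.Size([2, 10, 16])
```

In the real block, the `DiagonalSSM` stand-in would be replaced by whichever SSM module `ssm_type` selects. Every other step acts per position or looks only backward in time, so the whole block remains causal.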

See also

Original Reference: Gu, A. and Dao, T. (2024). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”. arXiv:2312.00752. DOI: https://doi.org/10.48550/arXiv.2312.00752. Official GitHub Repository: https://github.com/state-spaces/mamba.

_initialize_ssm_block(ssm_type, **kwargs)

Initialize the SSM block based on the specified type.

Parameters:
  • ssm_type (str) – The type of SSM block to use. Available options are “S4”, “S4D”, “S4LowRank”, “S6”.

  • kwargs (dict) – Additional arguments for the SSM block constructor.

Raises:

ValueError – If an invalid ssm_type is provided.
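The dispatch described above can be sketched as a name-to-constructor lookup that forwards `kwargs` and raises `ValueError` for unknown names. This is an illustrative sketch, not the library's code. The placeholder classes `S4Block`, `S4DBlock`, `S4LowRankBlock`, `S6Block`, the `SSM_REGISTRY` dict, and the `state_dim` keyword used below are all hypothetical.

```python
from typing import Any, Callable, Dict


# Hypothetical placeholder classes; the real library provides its own SSM modules.
class S4Block:
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs  # store constructor arguments for inspection


class S4DBlock(S4Block): ...
class S4LowRankBlock(S4Block): ...
class S6Block(S4Block): ...


# Map each accepted ssm_type string to its constructor.
SSM_REGISTRY: Dict[str, Callable[..., Any]] = {
    "S4": S4Block,
    "S4D": S4DBlock,
    "S4LowRank": S4LowRankBlock,
    "S6": S6Block,
}


def initialize_ssm_block(ssm_type: str, **kwargs: Any) -> Any:
    """Look up ssm_type and forward kwargs to the matching constructor."""
    try:
        cls = SSM_REGISTRY[ssm_type]
    except KeyError:
        options = ", ".join(repr(k) for k in SSM_REGISTRY)
        raise ValueError(
            f"Invalid ssm_type {ssm_type!r}; expected one of {options}"
        ) from None
    return cls(**kwargs)


block = initialize_ssm_block("S4D", state_dim=64)  # state_dim is hypothetical
print(type(block).__name__)  # S4DBlock
```

A registry keeps the list of valid options and the error message in one place, so adding a new SSM variant only requires one new dictionary entry.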

forward(x)

Forward pass of the Mamba block.

Parameters:
  • x (torch.Tensor) – The input tensor with shape (B, L, H), where B is the batch size, L is the sequence length, and H is the model dimension.

Returns:

The output tensor.

Return type:

torch.Tensor
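`forward` takes its input in (B, L, H) layout, but `torch.nn.Conv1d` expects channels-first (B, C, L). A Mamba-style forward pass therefore typically transposes around its convolution. The sketch below assumes a depthwise causal convolution as in the Mamba paper. `causal_depthwise_conv` is a hypothetical helper, not part of `ssm.model.block`.

```python
import torch
import torch.nn as nn


def causal_depthwise_conv(x: torch.Tensor, conv: nn.Conv1d) -> torch.Tensor:
    """Apply a depthwise Conv1d built with padding=kernel_size-1 causally to x of shape (B, L, H)."""
    seq_len = x.shape[1]
    y = conv(x.transpose(1, 2))  # (B, H, L + kernel_size - 1)
    # Keeping the first L outputs means output t only sees inputs at positions <= t.
    return y[..., :seq_len].transpose(1, 2)  # back to (B, L, H)


H, K = 8, 4
conv = nn.Conv1d(H, H, kernel_size=K, groups=H, padding=K - 1)
x = torch.randn(2, 10, H)  # (B, L, H)
y = causal_depthwise_conv(x, conv)
print(y.shape)  # torch.Size([2, 10, 8])
```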