Utils
- ssm.utils.compute_hippo(N)
Construct the HiPPO hidden-to-hidden matrix A.
- Parameters:
N (int) – The size of the HiPPO matrix.
- Returns:
An (N, N) matrix initialized using the HiPPO method.
- Return type:
torch.Tensor
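For reference, here is a minimal sketch of the HiPPO-LegS matrix from the S4 paper, written with the negative sign convention S4 uses. It assumes that `compute_hippo` builds the LegS variant. The helper `hippo_legs` is hypothetical and is not part of `ssm.utils`.

```python
import torch

def hippo_legs(N: int) -> torch.Tensor:
    """Hypothetical sketch of the (negated) HiPPO-LegS matrix.

    A[n, k] = -sqrt(2n+1) * sqrt(2k+1)  if n > k
            = -(n + 1)                  if n == k
            = 0                         if n < k
    """
    n = torch.arange(N, dtype=torch.float64)
    s = torch.sqrt(2 * n + 1)
    # Lower triangle of the outer product sqrt(2n+1) * sqrt(2k+1); its diagonal is 2n+1.
    M = torch.tril(torch.outer(s, s))
    # Subtract n on the diagonal so it becomes (2n+1) - n = n+1.
    M = M - torch.diag(n)
    return -M

A = hippo_legs(4)
```

Some implementations store the positive matrix and negate it later, so check the sign and dtype before comparing this sketch with `compute_hippo`.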
- ssm.utils.compute_S4DInv(N, real_random=False, imag_random=False)
Construct the S4D-Inv matrix A.
- Parameters:
N (int) – The size of the matrix.
real_random (bool) – If True, the real part of the A matrix is initialized at random between 0 and 1. Default is False.
imag_random (bool) – If True, the imaginary part of the A matrix is initialized at random between 0 and 1. Default is False.
- Returns:
The computed matrix A.
- Return type:
torch.Tensor
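The S4D paper defines the S4D-Inv initialization elementwise as A_n = -1/2 + i (N/π)(N/(2n+1) - 1). The sketch below computes only the diagonal entries as a complex vector and leaves out the `real_random`/`imag_random` options. `s4d_inv_diag` is a hypothetical helper, and the library may instead return a full diagonal matrix or only N/2 conjugate-pair entries.

```python
import math
import torch

def s4d_inv_diag(N: int) -> torch.Tensor:
    # Hypothetical sketch: diagonal of S4D-Inv, A_n = -1/2 + i * (N/pi) * (N/(2n+1) - 1)
    n = torch.arange(N, dtype=torch.float32)
    real = torch.full((N,), -0.5)
    imag = (N / math.pi) * (N / (2 * n + 1) - 1)
    return torch.complex(real, imag)

A_diag = s4d_inv_diag(4)
A = torch.diag(A_diag)  # embed as an (N, N) diagonal matrix if a full matrix is needed
```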
- ssm.utils.compute_S4DLin(N, real_random=False, imag_random=False)
Construct the S4D-Lin matrix A.
- Parameters:
N (int) – The size of the matrix.
real_random (bool) – If True, the real part of the A matrix is initialized at random between 0 and 1. Default is False.
imag_random (bool) – If True, the imaginary part of the A matrix is initialized at random between 0 and 1. Default is False.
- Returns:
The computed matrix A.
- Return type:
torch.Tensor
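In the S4D paper, S4D-Lin keeps the real part at -1/2 and spaces the imaginary parts linearly: A_n = -1/2 + iπn. Here is a minimal sketch of the diagonal under that formula. As above, the random options are omitted and `s4d_lin_diag` is a hypothetical name.

```python
import math
import torch

def s4d_lin_diag(N: int) -> torch.Tensor:
    # Hypothetical sketch: diagonal of S4D-Lin, A_n = -1/2 + i * pi * n
    n = torch.arange(N, dtype=torch.float32)
    return torch.complex(torch.full((N,), -0.5), math.pi * n)

A_diag = s4d_lin_diag(4)
```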
- ssm.utils.compute_S4DQuad(N, real_random=False)
Construct the S4D-Quad matrix A.
- Parameters:
N (int) – The size of the matrix.
real_random (bool) – If True, the real part of the A matrix is initialized at random between 0 and 1. Default is False.
- Returns:
The computed matrix A.
- Return type:
torch.Tensor
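The docstring does not state the S4D-Quad formula. To our understanding, the "quadratic" option in the reference state-spaces implementation of S4D uses A_n = -1/2 + i (1 + 2n)^2 / π. The sketch below assumes that formula, which is not confirmed by this page. `s4d_quad_diag` is hypothetical.

```python
import math
import torch

def s4d_quad_diag(N: int) -> torch.Tensor:
    # Hypothetical sketch (assumed formula): A_n = -1/2 + i * (1 + 2n)^2 / pi
    n = torch.arange(N, dtype=torch.float32)
    imag = (1 + 2 * n) ** 2 / math.pi
    return torch.complex(torch.full((N,), -0.5), imag)

A_diag = s4d_quad_diag(4)
```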
- ssm.utils.compute_S4DReal(N, real_random=False)
Construct the S4D-Real matrix A.
- Parameters:
N (int) – The size of the matrix.
real_random (bool) – If True, the real part of the A matrix is initialized at random between 0 and 1. Default is False.
- Returns:
The computed matrix A.
- Return type:
torch.Tensor
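S4D-Real uses a purely real diagonal, A_n = -(n + 1). This is also the diagonal of the negated HiPPO-LegS matrix, which is how the S4D paper motivates it. A minimal sketch, with `s4d_real_diag` as a hypothetical helper:

```python
import torch

def s4d_real_diag(N: int) -> torch.Tensor:
    # Hypothetical sketch: S4D-Real diagonal, A_n = -(n + 1)
    return -(torch.arange(N, dtype=torch.float32) + 1)

A_diag = s4d_real_diag(4)
A = torch.diag(A_diag)  # (N, N) diagonal matrix
```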
- ssm.utils.compute_dplr(A)
Construct the diagonal plus low-rank (DPLR) form of matrix A. The matrix A is decomposed into a diagonal matrix Lambda plus a low-rank matrix given by the outer product of two vectors p and q.
- Parameters:
A (torch.Tensor) – The input matrix.
- Returns:
The diagonal plus low-rank form of A.
- Return type:
tuple
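`compute_dplr` accepts a generic A. To illustrate the decomposition only, the sketch below derives the DPLR form of the HiPPO-LegS matrix. It uses the low-rank correction from the S4 paper, P_n = sqrt(n + 1/2), for which A + P P^T is normal. Diagonalizing that normal part with a unitary V gives V^* A V = diag(Lambda) - p p^*, where p = V^* P. In this sketch the two low-rank vectors coincide up to conjugation. Whether `compute_dplr` picks the same P, works in the same basis, or returns the same tuple order is an assumption, and every name here is hypothetical.

```python
import torch

def hippo_legs(N: int) -> torch.Tensor:
    # Negated HiPPO-LegS matrix (S4 sign convention)
    n = torch.arange(N, dtype=torch.float64)
    s = torch.sqrt(2 * n + 1)
    return -(torch.tril(torch.outer(s, s)) - torch.diag(n))

def dplr_hippo_legs(N: int):
    """Hypothetical sketch: DPLR form of HiPPO-LegS via its normal part."""
    A = hippo_legs(N)
    P = torch.sqrt(torch.arange(N, dtype=torch.float64) + 0.5)
    # S = A + P P^T is normal: -1/2 on the diagonal plus a skew-symmetric part K
    S = A + torch.outer(P, P)
    lambda_real = torch.diagonal(S).mean()
    K = S - lambda_real * torch.eye(N, dtype=torch.float64)
    # -i K is Hermitian, so eigh gives real eigenvalues mu and a unitary V
    # with K = V diag(i * mu) V^*
    mu, V = torch.linalg.eigh(-1j * K.to(torch.complex128))
    Lambda = lambda_real + 1j * mu
    # Rotate the low-rank term into the eigenbasis: V^* A V = diag(Lambda) - p p^*
    p = V.conj().T @ P.to(torch.complex128)
    return A, Lambda, p, V

A, Lambda, p, V = dplr_hippo_legs(8)
# Rebuild A from its DPLR factors to check the decomposition
A_rebuilt = V @ (torch.diag(Lambda) - torch.outer(p, p.conj())) @ V.conj().T
```

Working in the eigenbasis of the normal part is what makes the diagonal term exact. The low-rank term then carries all of the non-normality of A.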
- ssm.utils.initialize_dt(dim, dt_min, dt_max, inverse_softplus=False)
Initialize the time step dt for the S4 and S6 blocks.
- Parameters:
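dim (int) – The number of time steps to initialize, i.e. the length of the returned tensor.
dt_min (float) – The minimum time step for discretization.
dt_max (float) – The maximum time step for discretization.
inverse_softplus (bool) – Whether to apply the inverse softplus transform to the sampled time steps. Default is False.
- Returns:
Initialized time step dt tensor of shape (dim,).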
- Return type:
torch.Tensor
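The sampling scheme is not documented above. A common choice in S4 and Mamba reference code is to draw dt log-uniformly between dt_min and dt_max. When the model later applies softplus to a learned parameter, the stored value is the inverse softplus of dt. The sketch below assumes both behaviours for `initialize_dt`. `initialize_dt_sketch` is a hypothetical stand-in, not the library function.

```python
import math
import torch
import torch.nn.functional as F

def initialize_dt_sketch(dim: int, dt_min: float, dt_max: float,
                         inverse_softplus: bool = False) -> torch.Tensor:
    # Hypothetical sketch: sample dt log-uniformly in [dt_min, dt_max]
    log_min, log_max = math.log(dt_min), math.log(dt_max)
    dt = torch.exp(torch.rand(dim) * (log_max - log_min) + log_min)
    if inverse_softplus:
        # Return x with softplus(x) = dt, i.e. x = log(exp(dt) - 1),
        # computed stably as dt + log(1 - exp(-dt))
        dt = dt + torch.log(-torch.expm1(-dt))
    return dt

torch.manual_seed(0)
dt = initialize_dt_sketch(16, 1e-3, 1e-1)
torch.manual_seed(0)
dt_raw = initialize_dt_sketch(16, 1e-3, 1e-1, inverse_softplus=True)
recovered = F.softplus(dt_raw)  # should match dt up to float32 rounding
```

Log-uniform sampling spreads the initial time steps evenly across orders of magnitude rather than clustering them near dt_max.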