# Source code for tracetorch.ssm._s_series

import torch
from torch import nn
from ._ssmlayer import Layer as SSMLayer


class S4(SSMLayer):
    r"""A diagonal S4-style state-space layer adapted to traceTorch.

    ``S4`` stores a per-feature latent state of size ``d_state`` and updates it
    one timestep at a time. It is designed for traceTorch-style composition,
    not as an optimized replacement for sequence-parallel S4 implementations.

    Every feature owns an independent diagonal system that is discretized with
    zero-order hold on each call::

        A     = -exp(A_log)
        bar_A = exp(dt * A)
        bar_B = (exp(dt * A) - 1) / A * B
        state = bar_A * state + bar_B * x
        y     = sum(C * state, over d_state) + D * x

    Args:
        num_neurons (int): number of features in the target dimension.
        d_state (int, default=64): latent state size per feature.
        dim (int, default=-1): dimension along which the layer operates.

    Attributes:
        state: per-feature latent SSM state.
        A_log: log-parameterized diagonal dynamics, ``[num_neurons, d_state]``.
        B: input projection into the state, ``[num_neurons, d_state]``.
        C: output projection from the state, ``[num_neurons, d_state]``.
        D: skip connection scale, ``[num_neurons]``.
        log_dt: log timestep scale, ``[num_neurons]``.

    Notes:
        - **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
          is at index ``dim``.
        - **Output**: tensor with the same shape as the input.
    """

    def __init__(self, num_neurons: int, d_state: int = 64, dim: int = -1):
        super().__init__(num_neurons, dim, d_state=d_state)
        # A starts at -(1, 2, ..., d_state) for every feature.
        A = torch.arange(1, d_state + 1).float().repeat(num_neurons, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Creation order of the random parameters is kept stable so a given
        # torch seed reproduces the same initialization.
        self.B = nn.Parameter(torch.randn(num_neurons, d_state))
        self.C = nn.Parameter(torch.randn(num_neurons, d_state))
        self.log_dt = nn.Parameter(torch.randn(num_neurons))
        self.D = nn.Parameter(torch.randn(num_neurons))
        self._initialize_state("state")

    def forward(self, x):
        """Advance the latent state by one timestep and return the output.

        Args:
            x (torch.Tensor): input with ``num_neurons`` features along ``dim``.

        Returns:
            torch.Tensor: output with the same shape as ``x``.
        """
        self._ensure_states(x)
        # Working layout puts features on the last axis (required by the
        # broadcast against D below).
        x_w = self._to_working_dim(x)

        dt = torch.exp(self.log_dt).unsqueeze(-1)  # [num_neurons, 1]
        A = -torch.exp(self.A_log)  # [num_neurons, d_state], negative
        dA = dt * A
        bar_A = torch.exp(dA)
        # expm1 avoids the cancellation in exp(dA) - 1 when |dA| is small; the
        # clamp keeps the divisor strictly negative even if exp(A_log)
        # underflows to zero (which would otherwise produce 0 / 0 = NaN).
        A_div = A.clamp(max=-torch.finfo(A.dtype).tiny)
        bar_B = torch.expm1(dA) / A_div * self.B

        # State is expected as [..., num_neurons, d_state] in working layout.
        state = self._state_to_working_dim(self.state)
        state = state * bar_A + bar_B * x_w.unsqueeze(-1)
        self.state = self._state_from_working_dim(state)

        y = torch.sum(state * self.C, dim=-1) + x_w * self.D
        return self._from_working_dim(y)
class S5(SSMLayer):
    r"""An S5-style state-space layer with a global latent state.

    ``S5`` projects the input features into a shared latent state of size
    ``d_state`` and projects that state back to ``num_neurons`` outputs. It
    processes one timestep per forward call and keeps the global state
    internal.

    Per call, with a single shared timestep ``dt``::

        A     = -exp(A_log)
        bar_A = exp(dt * A)
        state = bar_A * state + (exp(dt * A) - 1) / A * (x @ B)
        y     = state @ C + D * x

    Args:
        num_neurons (int): number of features in the target dimension.
        d_state (int, default=64): size of the shared latent state.
        dim (int, default=-1): dimension along which the layer operates.

    Attributes:
        global_state: shared latent state.
        A_log: log-parameterized diagonal dynamics, ``[d_state]``.
        B: input projection into the global state, ``[num_neurons, d_state]``.
        C: output projection from the global state, ``[d_state, num_neurons]``.
        D: skip connection scale, ``[num_neurons]``.
        log_dt: log timestep scale, ``[1]``.

    Notes:
        - **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
          is at index ``dim``.
        - **Output**: tensor with the same shape as the input.
    """

    def __init__(self, num_neurons: int, d_state: int = 64, dim: int = -1):
        # The base layer receives d_state=1; the real latent size is stored
        # afterwards and consumed by the ``_ensure_state`` override below.
        super().__init__(num_neurons, dim, d_state=1)
        self.d_state = d_state
        A = torch.arange(1, d_state + 1).float()
        self.A_log = nn.Parameter(torch.log(A))
        # Creation order is kept stable so a given seed reproduces the init.
        self.B = nn.Parameter(torch.randn(num_neurons, d_state))
        self.C = nn.Parameter(torch.randn(d_state, num_neurons))
        self.D = nn.Parameter(torch.randn(num_neurons))
        self.log_dt = nn.Parameter(torch.randn(1))
        self._initialize_state("global_state")

    def forward(self, x):
        """Advance the global state by one timestep and return the output.

        Args:
            x (torch.Tensor): input with ``num_neurons`` features along ``dim``.

        Returns:
            torch.Tensor: output with the same shape as ``x``.
        """
        self._ensure_states(x)
        x_w = self._to_working_dim(x)

        dt = torch.exp(self.log_dt)  # shared timestep, shape [1]
        A = -torch.exp(self.A_log)  # [d_state], negative
        dA = dt * A
        bar_A = torch.exp(dA)
        # expm1 avoids the cancellation in exp(dA) - 1 for small |dA|; the
        # clamp keeps the divisor nonzero if exp(A_log) underflows.
        bar_B = torch.expm1(dA) / A.clamp(max=-torch.finfo(A.dtype).tiny)

        g_state = self._to_working_dim(self.global_state)
        x_b = torch.matmul(x_w, self.B)  # project features into latent space
        g_state = g_state * bar_A + bar_B * x_b
        self.global_state = self._from_working_dim(g_state)

        y = torch.matmul(g_state, self.C) + x_w * self.D
        return self._from_working_dim(y)

    def _ensure_state(self, state_name: str, reference_tensor: torch.Tensor):
        """Lazily allocate the global state as zeros.

        The new state copies ``reference_tensor``'s shape, dtype and device,
        except that the size along ``self.dim`` becomes ``self.d_state``. A
        state that is already set (not ``None``) is left untouched.

        Args:
            state_name: attribute name of the state to allocate.
            reference_tensor: tensor whose shape/dtype/device are copied.
        """
        state = getattr(self, state_name)
        if state is None:
            shape = list(reference_tensor.shape)
            shape[self.dim] = self.d_state
            state = torch.zeros(shape, dtype=reference_tensor.dtype, device=reference_tensor.device)
            setattr(self, state_name, state)
class S6(SSMLayer):
    r"""A data-dependent S6 state-space layer adapted to traceTorch.

    ``S6`` is the selective SSM core associated with Mamba-style models,
    without the causal convolution and multiplicative block gate. The timestep,
    input, and output projections are computed from the current input, then
    applied to an internal per-feature state.

    Args:
        num_neurons (int): number of features in the target dimension.
        d_state (int, default=16): latent state size per feature.
        dt_rank (int, default=-1): rank of the timestep projection. ``-1`` uses
            ``max(1, num_neurons // 16)``. Any other value must be positive.
        dim (int, default=-1): dimension along which the layer operates.

    Attributes:
        state: per-feature latent SSM state.
        x_proj: input-dependent projection producing timestep, ``B``, and ``C``.
        dt_proj: projection from low-rank timestep features to per-feature
            timesteps.
        A_log: log-parameterized diagonal dynamics.
        D: skip connection scale.

    Raises:
        ValueError: if ``dt_rank`` resolves to a value smaller than 1.

    Notes:
        - **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
          is at index ``dim``.
        - **Output**: tensor with the same shape as the input.
    """

    def __init__(self, num_neurons: int, d_state: int = 16, dt_rank: int = -1, dim: int = -1):
        super().__init__(num_neurons, dim, d_state=d_state)
        if dt_rank == -1:
            dt_rank = max(1, num_neurons // 16)
        if dt_rank < 1:
            raise ValueError(f"dt_rank must be a positive integer or -1, got {dt_rank}")
        self.x_proj = nn.Linear(num_neurons, dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(dt_rank, num_neurons, bias=True)
        # softplus(-2) ~= 0.127: keeps the first timesteps from being too large.
        nn.init.constant_(self.dt_proj.bias, -2.0)
        A = torch.arange(1, d_state + 1).float().repeat(num_neurons, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(num_neurons))
        self._initialize_state("state")

    def forward(self, x):
        """Advance the selective state by one timestep and return the output.

        Args:
            x (torch.Tensor): input with ``num_neurons`` features along ``dim``.

        Returns:
            torch.Tensor: output with the same shape as ``x``.
        """
        self._ensure_states(x)
        # Working layout puts features on the last axis (required by nn.Linear).
        x_w = self._to_working_dim(x)

        # Input-dependent (selective) timestep, B and C for this step.
        dt_raw, B, C = torch.split(
            self.x_proj(x_w),
            [self.dt_proj.in_features, self.d_state, self.d_state],
            dim=-1,
        )
        dt = nn.functional.softplus(self.dt_proj(dt_raw)).unsqueeze(-1)
        A = -torch.exp(self.A_log)  # [num_neurons, d_state], negative
        dA = dt * A
        bar_A = torch.exp(dA)
        # expm1 avoids the cancellation in exp(dA) - 1 for small |dA|. A is
        # negative, so the divisor is clamped *away* from zero; adding a
        # positive epsilon would push it toward zero and could flip its sign.
        A_div = A.clamp(max=-torch.finfo(A.dtype).tiny)
        bar_B = torch.expm1(dA) / A_div * B.unsqueeze(-2)

        state = self._state_to_working_dim(self.state)
        state = state * bar_A + bar_B * x_w.unsqueeze(-1)
        self.state = self._state_from_working_dim(state)

        y = torch.sum(state * C.unsqueeze(-2), dim=-1) + x_w * self.D
        return self._from_working_dim(y)
class Mamba(SSMLayer):
    r"""A compact Mamba-style block adapted to traceTorch.

    ``Mamba`` combines an input projection, optional causal convolution buffer,
    SiLU gating, an S6-style selective SSM core, output projection, and
    residual connection. It keeps the convolution buffer and SSM state internal
    and processes one timestep per forward call.

    Args:
        num_neurons (int): number of features in the target dimension.
        d_state (int, default=16): latent SSM state size per feature.
        dim (int, default=-1): dimension along which the layer operates.
        dt_rank (int, default=-1): rank of the timestep projection. ``-1`` uses
            ``max(1, num_neurons // 16)``. Any other value must be positive.
        conv_kernel (int, default=4): causal convolution buffer length. Values
            ``<= 1`` disable the convolution buffer.

    Attributes:
        ssm_state: per-feature selective SSM state.
        conv_buffer: causal convolution buffer, present when
            ``conv_kernel > 1``.

    Raises:
        ValueError: if ``dt_rank`` resolves to a value smaller than 1.

    Notes:
        This is a traceTorch-compatible experimental implementation. It is not
        an optimized replacement for production Mamba kernels.
    """

    def __init__(self, num_neurons: int, d_state: int = 16, dim: int = -1, dt_rank: int = -1, conv_kernel: int = 4):
        # d_state is forwarded so the base ``_ensure_state`` can allocate
        # ``ssm_state``; ``conv_buffer`` is handled by the override below.
        super().__init__(num_neurons, dim, d_state=d_state)
        self.conv_kernel = conv_kernel

        # 1. Outer block projections (main + gate branches, then output).
        self.in_proj = nn.Linear(num_neurons, num_neurons * 2, bias=False)
        self.out_proj = nn.Linear(num_neurons, num_neurons, bias=False)

        # 2. Causal depthwise convolution over the last ``conv_kernel`` inputs.
        if conv_kernel > 1:
            self.conv_weights = nn.Parameter(torch.ones(num_neurons, conv_kernel) / conv_kernel)
            self.conv_bias = nn.Parameter(torch.zeros(num_neurons))
            self._initialize_state("conv_buffer")

        # 3. Inner S6 (selective SSM) core parameters.
        if dt_rank == -1:
            dt_rank = max(1, num_neurons // 16)
        if dt_rank < 1:
            raise ValueError(f"dt_rank must be a positive integer or -1, got {dt_rank}")
        self.ssm_proj = nn.Linear(num_neurons, dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(dt_rank, num_neurons, bias=True)
        # softplus(-2) ~= 0.127: prevents an exploding dt on the first steps.
        nn.init.constant_(self.dt_proj.bias, -2.0)
        A = torch.arange(1, d_state + 1).float().repeat(num_neurons, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(num_neurons))  # SSM skip connection
        self._initialize_state("ssm_state")

    def forward(self, x):
        """Run one timestep of the block and return the residual output.

        Args:
            x (torch.Tensor): input with ``num_neurons`` features along ``dim``.

        Returns:
            torch.Tensor: output with the same shape as ``x``.
        """
        self._ensure_states(x)
        x_w = self._to_working_dim(x)

        # 1. Project and split into main and gate branches.
        x_main, x_gate = self.in_proj(x_w).chunk(2, dim=-1)

        # 2. Causal convolution: shift the buffer left, append the newest
        #    input on the trailing kernel axis, then take a weighted sum.
        if self.conv_kernel > 1:
            conv_buf = self._state_to_working_dim(self.conv_buffer)
            conv_buf = torch.cat([conv_buf[..., 1:], x_main.unsqueeze(-1)], dim=-1)
            self.conv_buffer = self._state_from_working_dim(conv_buf)
            x_conv = torch.sum(conv_buf * self.conv_weights, dim=-1) + self.conv_bias
        else:
            x_conv = x_main
        x_conv = nn.functional.silu(x_conv)

        # 3. Data-dependent selective SSM (S6) step.
        dt_raw, B, C = torch.split(
            self.ssm_proj(x_conv),
            [self.dt_proj.in_features, self.d_state, self.d_state],
            dim=-1,
        )
        dt = nn.functional.softplus(self.dt_proj(dt_raw)).unsqueeze(-1)
        A = -torch.exp(self.A_log)  # [num_neurons, d_state], negative
        dA = dt * A
        bar_A = torch.exp(dA)
        # expm1 avoids the cancellation in exp(dA) - 1 for small |dA|. A is
        # negative, so the divisor is clamped *away* from zero; adding a
        # positive epsilon would push it toward zero and could flip its sign.
        A_div = A.clamp(max=-torch.finfo(A.dtype).tiny)
        bar_B = torch.expm1(dA) / A_div * B.unsqueeze(-2)

        state = self._state_to_working_dim(self.ssm_state)
        state = state * bar_A + bar_B * x_conv.unsqueeze(-1)
        self.ssm_state = self._state_from_working_dim(state)
        y_ssm = torch.sum(state * C.unsqueeze(-2), dim=-1) + x_conv * self.D

        # 4. Multiplicative SiLU gate, then output projection.
        y = self.out_proj(y_ssm * nn.functional.silu(x_gate))

        # 5. Block-level residual connection.
        return self._from_working_dim(y + x_w)

    def _ensure_state(self, state_name: str, reference_tensor: torch.Tensor):
        """Lazily allocate internal states.

        ``conv_buffer`` is allocated here as zeros with the reference input's
        shape (``num_neurons`` along ``self.dim``) plus a trailing axis of
        length ``conv_kernel``, matching the reference dtype and device. Any
        other state (``ssm_state``) is delegated to the base layer. A state
        that is already set (not ``None``) is left untouched.

        Args:
            state_name: attribute name of the state to allocate.
            reference_tensor: tensor whose shape/dtype/device are copied.
        """
        if state_name != "conv_buffer":
            super()._ensure_state(state_name, reference_tensor)
            return
        if getattr(self, state_name) is None:
            shape = list(reference_tensor.shape)
            shape[self.dim] = self.num_neurons
            shape.append(self.conv_kernel)
            state = torch.zeros(shape, dtype=reference_tensor.dtype, device=reference_tensor.device)
            setattr(self, state_name, state)