Creating a Custom Layer
A traceTorch layer is a PyTorch module with a small amount of state-management structure. This tutorial recreates a minimal GRU-like layer to show the moving parts without drowning in SNN details.
The goal
We want a layer that:
- stores a hidden state H;
- creates H lazily from the input shape;
- works on any feature dimension through dim;
- integrates with tt.Model.zero_states() and detach_states().
Subclass a traceTorch layer
For RNN-style layers, subclass tt.rnn.Layer. It already inherits from tt.Layer.
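```python
import torch
from torch import nn
import tracetorch as tt


class MiniGRU(tt.rnn.Layer):
    def __init__(self, in_features: int, out_features: int, dim: int = -1):
        # Tell the base class the state size and which dimension holds features.
        super().__init__(num_neurons=out_features, dim=dim)
        # Register the hidden state H (starts as None, created lazily in forward).
        self._initialize_state("H")
        # Reset and update gates, computed together from [H, x].
        self.gates = nn.Linear(in_features + out_features, 2 * out_features)
        # Candidate state, computed from [H * reset, x].
        self.candidate = nn.Linear(in_features + out_features, out_features)
```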
Registering a state
_initialize_state("H") records the state name and sets self.H = None. That is enough for tt.Model to find and
manage the state later.
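To make this bookkeeping concrete, here is a plain-PyTorch sketch of how such a registry could work. It is not traceTorch's implementation. StatefulSketch, Accumulator and the free functions zero_states and detach_states are all hypothetical names made up for illustration. The assumed pattern is simple: each layer lists its state names, and a model-level helper walks model.modules() to reset or detach them.

```python
import torch
from torch import nn


class StatefulSketch(nn.Module):
    """Hypothetical base class: remembers which attributes are hidden states."""

    def __init__(self):
        super().__init__()
        self._state_names = []

    def _initialize_state(self, name):
        # Record the name and start the state as None; it is created lazily later.
        self._state_names.append(name)
        setattr(self, name, None)


class Accumulator(StatefulSketch):
    """Hypothetical layer whose state H is a running sum of its inputs."""

    def __init__(self):
        super().__init__()
        self._initialize_state("H")

    def forward(self, x):
        if self.H is None:
            # Lazy creation: the state takes its shape from the first input.
            self.H = torch.zeros_like(x)
        self.H = self.H + x
        return self.H


def zero_states(model):
    # Reset every registered state in every submodule back to None.
    for module in model.modules():
        for name in getattr(module, "_state_names", []):
            setattr(module, name, None)


def detach_states(model):
    # Keep the current state values but cut their autograd history.
    for module in model.modules():
        for name in getattr(module, "_state_names", []):
            state = getattr(module, name)
            if state is not None:
                setattr(module, name, state.detach())


net = nn.Sequential(Accumulator())
x = torch.ones(2, 4, requires_grad=True)
net(x)
out = net(x)          # running sum after two steps: every entry is 2
detach_states(net)    # values kept, graph dropped
zero_states(net)      # back to "no state"; the next call recreates H
```

Because the states are plain attributes rather than parameters or buffers, they do not show up in state_dict() or the optimizer. A model-level helper has to know their names, which is the job the recorded name list does here.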
Forward pass
The forward pass follows the same structure as the built-in layers:
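```python
    # Continues the MiniGRU class above.
    def forward(self, x):
        # Create H from x's shape if it is still None.
        self._ensure_states(x)
        # Put the feature dimension last so the linear layers can act on it.
        x = self._to_working_dim(x)
        H = self._to_working_dim(self.H)
        # Reset and update gates from the concatenated [H, x].
        H_x = torch.cat([H, x], dim=-1)
        reset, update = torch.sigmoid(self.gates(H_x)).chunk(2, dim=-1)
        # Candidate state uses the reset-scaled hidden state.
        candidate = torch.tanh(self.candidate(torch.cat([H * reset, x], dim=-1)))
        # Blend the old state and the candidate according to the update gate.
        H = H * (1 - update) + update * candidate
        # Store the state back in the caller's layout and return it.
        self.H = self._from_working_dim(H)
        return self.H
```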
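Written out as equations, this forward pass is the usual GRU-style update. The notation is introduced here and is not part of the traceTorch docs:

- h_{t-1} is the incoming state H and x_t is the input.
- [a, b] is concatenation along the feature axis.
- σ is the sigmoid and ⊙ is elementwise multiplication.
- W_g, b_g and W_c, b_c are the weights and biases of self.gates and self.candidate.
- The gate output is split in half: the first half is r_t (reset) and the second half is z_t (update).

$$
\begin{aligned}
[\,r_t,\ z_t\,] &= \sigma\!\left(W_g\,[h_{t-1},\ x_t] + b_g\right) \\
\tilde{h}_t &= \tanh\!\left(W_c\,[r_t \odot h_{t-1},\ x_t] + b_c\right) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
$$

Where z_t is close to 0, the state is carried over almost unchanged. Where it is close to 1, the state is replaced by the candidate.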
- _ensure_states(x) creates H if it is None. The new state has the same shape as x, except along the configured dim, where its size is out_features.
- _to_working_dim(x) moves the configured dim to the last position, so linear layers can operate on the features normally.
- _from_working_dim(H) moves the last dimension back to the configured dim.
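One plausible way to get this shape bookkeeping is torch.movedim. The sketch below uses hypothetical free functions (to_working_dim, from_working_dim, make_state) rather than traceTorch's methods. It assumes a new state starts as zeros and only illustrates the shapes involved. The example keeps features on dim=1 of a (batch, features, time) tensor.

```python
import torch
from torch import nn


def to_working_dim(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Move the configured feature dimension to the end.
    return t.movedim(dim, -1)


def from_working_dim(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Inverse of to_working_dim: move the last dimension back to `dim`.
    return t.movedim(-1, dim)


def make_state(x: torch.Tensor, dim: int, out_features: int) -> torch.Tensor:
    # Same shape as x, except `dim` has size out_features (zeros are an assumption).
    shape = list(x.shape)
    shape[dim] = out_features
    return torch.zeros(shape, dtype=x.dtype, device=x.device)


x = torch.randn(8, 64, 10)                 # (batch, features, time), dim=1
H = make_state(x, dim=1, out_features=32)
print(H.shape)                             # torch.Size([8, 32, 10])

x_w = to_working_dim(x, 1)                 # (8, 10, 64): features last
y_w = nn.Linear(64, 32)(x_w)               # Linear acts on the last dimension
y = from_working_dim(y_w, 1)               # back to (8, 32, 10)
print(y.shape)                             # torch.Size([8, 32, 10])
```

In the working layout, H becomes (8, 10, 32) and x becomes (8, 10, 64). Concatenating them on the last dimension gives the in_features + out_features = 96 inputs that MiniGRU's linear layers expect.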
Using the layer
```python
class Net(tt.Model):
    def __init__(self):
        super().__init__()
        self.layer = MiniGRU(64, 32)

    def forward(self, x):
        return self.layer(x)


model = Net()
model.zero_states()            # reset registered states before a new sequence
y = model(torch.rand(16, 64))  # features on the last dimension, matching dim=-1
print(y.shape)
# torch.Size([16, 32])
```
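For comparison, the sketch below does the state bookkeeping by hand in plain PyTorch. Judging by their names, this is the work that zero_states() and detach_states() handle inside a tt.Model. PyTorch's built-in nn.GRUCell stands in for MiniGRU here. The toy data, the loss and the window length k are arbitrary illustrative choices.

The pattern is truncated backpropagation through time:

- start each sequence with no state;
- after every k steps, take an optimizer step and detach the state, so gradients do not flow through the whole history.

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.GRUCell(64, 32)          # built-in stand-in for MiniGRU(64, 32)
readout = nn.Linear(32, 1)
opt = torch.optim.SGD(list(cell.parameters()) + list(readout.parameters()), lr=1e-2)

seq = torch.rand(20, 16, 64)       # toy input: (time, batch, features)
target = torch.rand(20, 16, 1)     # toy targets
k = 5                              # truncation window, chosen arbitrarily

h = None                           # fresh sequence: no state yet (cf. zero_states)
loss = 0.0
for t in range(seq.shape[0]):
    h = cell(seq[t], h)            # hx=None makes GRUCell start from zeros
    loss = loss + nn.functional.mse_loss(readout(h), target[t])
    if (t + 1) % k == 0:
        opt.zero_grad()
        loss.backward()
        opt.step()
        h = h.detach()             # keep values, drop history (cf. detach_states)
        loss = 0.0

print(h.shape)                     # torch.Size([16, 32])
```

With stateful layers such as MiniGRU, the state lives inside the modules instead of in a local variable like h. A model-level reset and detach is then the natural place for this logic.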
Adding constrained parameters
SNN and SSM layers often need constrained parameters. For example, a decay should stay between zero and one. SNN layers
use helper methods such as _register_decay:
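```python
class DecayLayer(tt.snn.Layer):
    def __init__(self, num_neurons: int, beta: float = 0.9):
        super().__init__(num_neurons)
        # Membrane state, created lazily from the first input.
        self._initialize_state("mem")
        # Learnable decay constrained to (0, 1), exposed as self.beta.
        self._register_decay("beta", beta, rank=1, learnable=True)

    def forward(self, x):
        self._ensure_states(x)
        # Leaky integration: decay the old membrane, then add the new input.
        self.mem = self.mem * self.beta + x
        return self.mem
```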
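Why the decay must stay below one: unroll the update above, assuming the membrane starts at zero. (This is an assumption; the text only says that _ensure_states creates it.) The result is a geometrically weighted sum of past inputs:

$$
\text{mem}_t = \beta\,\text{mem}_{t-1} + x_t = \sum_{k=1}^{t} \beta^{\,t-k}\, x_k, \qquad \text{mem}_0 = 0.
$$

For a constant input x this becomes $\text{mem}_t = x\,\frac{1-\beta^{t}}{1-\beta}$.

- When $0 < \beta < 1$, it settles at $x/(1-\beta)$. For example, β = 0.9 and x = 1 give a limit of 10.
- When β ≥ 1 and x is nonzero, it grows without bound.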
The public self.beta value is the sigmoid of the raw stored parameter, self.raw_beta, so it always stays in (0, 1) while the optimizer updates an unconstrained value. Keeping the raw parameter separate from its constrained view is what allows TTcompile() and TTdecompile() to optimize constrained parameters later.
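Below is a minimal sketch of this reparameterization in plain PyTorch. SigmoidDecay is a hypothetical class and not traceTorch's _register_decay. It stores one unconstrained raw_beta per neuron; that layout is an assumption, and the sketch does not model the rank argument. It initializes raw_beta with the logit of the requested value, so that beta starts at exactly that value, and exposes beta as a read-only property.

```python
import math

import torch
from torch import nn


class SigmoidDecay(nn.Module):
    """Hypothetical sketch: unconstrained raw_beta, exposed beta in (0, 1)."""

    def __init__(self, num_neurons: int, beta: float = 0.9):
        super().__init__()
        if not 0.0 < beta < 1.0:
            raise ValueError("beta must lie strictly between 0 and 1")
        # Logit (inverse sigmoid), so that sigmoid(raw_beta) == beta at init.
        raw = math.log(beta / (1.0 - beta))
        self.raw_beta = nn.Parameter(torch.full((num_neurons,), raw))

    @property
    def beta(self) -> torch.Tensor:
        # Sigmoid maps any real raw_beta into (0, 1); gradients flow through it.
        return torch.sigmoid(self.raw_beta)


decay = SigmoidDecay(4, beta=0.9)
print(decay.beta)                  # four values of (approximately) 0.9
mem = torch.zeros(4)
for _ in range(3):
    mem = mem * decay.beta + 1.0   # same update as DecayLayer.forward
```

Only raw_beta is registered as a parameter, so ordinary unconstrained optimizer steps can never push beta outside (0, 1).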
Checklist
When creating a custom traceTorch layer:
- call the superclass with num_neurons and dim;
- call _initialize_state for every hidden state;
- call _ensure_states(x) before using states in forward;
- use _to_working_dim before operations that expect features last;
- write updated states back through _from_working_dim;
- return one tensor, keeping hidden states internal.