Creating a Custom Layer

A traceTorch layer is an ordinary PyTorch module with a small amount of extra structure for managing hidden state. This tutorial builds a minimal GRU-like layer to show the moving parts without drowning in SNN details.

The goal

We want a layer that:

  • stores a hidden state H;

  • creates H lazily from the input shape;

  • works on any feature dimension through dim;

  • integrates with tt.Model.zero_states() and detach_states().

Subclass a traceTorch layer

For RNN-style layers, subclass tt.rnn.Layer. It already inherits from tt.Layer.

import torch
from torch import nn
import tracetorch as tt


class MiniGRU(tt.rnn.Layer):
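    def __init__(self, in_features: int, out_features: int, dim: int = -1):
        # num_neurons is the size of the hidden state along `dim`.
        super().__init__(num_neurons=out_features, dim=dim)

        # Register H as a managed state; it starts as None.
        self._initialize_state("H")
        # One linear layer produces both the reset and update gate logits.
        self.gates = nn.Linear(in_features + out_features, 2 * out_features)
        self.candidate = nn.Linear(in_features + out_features, out_features)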

Registering a state

_initialize_state("H") records the state name and sets self.H = None. That is enough for tt.Model to find and manage the state later.
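To make that bookkeeping concrete, here is a minimal pure-PyTorch sketch of how a base class could record state names and how a model could then reach every registered state. It is not traceTorch's implementation: StateLayer, zero_states and detach_states are hypothetical stand-ins, and resetting states to None (so they are recreated lazily) is an assumption about what zeroing means; the real zero_states() may instead fill existing tensors with zeros.

```python
import torch
from torch import nn


class StateLayer(nn.Module):
    """Hypothetical base class: remembers which attributes are hidden states."""

    def __init__(self):
        super().__init__()
        self._state_names = []

    def _initialize_state(self, name):
        # Record the name and start the state empty; it is created lazily later.
        self._state_names.append(name)
        setattr(self, name, None)


def zero_states(model):
    # Hypothetical: walk every submodule and clear each registered state.
    for module in model.modules():
        if isinstance(module, StateLayer):
            for name in module._state_names:
                setattr(module, name, None)


def detach_states(model):
    # Hypothetical: cut the autograd graph at the current states, keeping values.
    for module in model.modules():
        if isinstance(module, StateLayer):
            for name in module._state_names:
                state = getattr(module, name)
                if state is not None:
                    setattr(module, name, state.detach())


# Usage: the model finds the state purely through the recorded names.
layer = StateLayer()
layer._initialize_state("H")
model = nn.Sequential(layer)
layer.H = torch.ones(2, 3, requires_grad=True) * 2
detach_states(model)  # same values, no autograd history
zero_states(model)    # H is None again
```

Because the names live on the layer, the model-level helpers need no knowledge of any particular layer class.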

Forward pass

Next, add a forward method to MiniGRU. It follows the same structure as the built-in layers:

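def forward(self, x):
    # Create H on the first call (or after the states have been reset).
    self._ensure_states(x)

    # Move the feature dimension last so nn.Linear can act on it.
    x = self._to_working_dim(x)
    H = self._to_working_dim(self.H)

    # Reset and update gates, computed together from [H, x].
    H_x = torch.cat([H, x], dim=-1)
    reset, update = torch.sigmoid(self.gates(H_x)).chunk(2, dim=-1)

    # Candidate state uses the reset-gated previous state.
    candidate = torch.tanh(self.candidate(torch.cat([H * reset, x], dim=-1)))
    # Interpolate between the old state and the candidate.
    H = H * (1 - update) + update * candidate

    # Store the state in the original layout and return it.
    self.H = self._from_working_dim(H)
    return self.H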
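Written as equations, one call of the forward pass computes the following (along the working dimension). Here $W_g, b_g$ and $W_c, b_c$ are the weights and biases of self.gates and self.candidate, $[\cdot\,;\,\cdot]$ is concatenation along the feature dimension, $\sigma$ is the sigmoid, $\odot$ is element-wise multiplication, and $r$, $z$ are the reset and update gates (the first and second halves of the gate output):

```latex
\begin{aligned}
[\, r \;\; z \,] &= \sigma\!\left(W_g\,[H;\,x] + b_g\right) \\
\tilde{H} &= \tanh\!\left(W_c\,[r \odot H;\,x] + b_c\right) \\
H &\leftarrow (1 - z) \odot H + z \odot \tilde{H}
\end{aligned}
```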
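
The forward pass relies on three helper methods:

_ensure_states(x)

Creates H if it is None. The new state has the same shape as x, except that the configured dim has size out_features.

_to_working_dim(x)

Moves the configured dim to the last position, because nn.Linear, torch.cat(..., dim=-1) and chunk(..., dim=-1) all operate on the last dimension.

_from_working_dim(H)

Moves the last dimension back to the configured dim, so the stored state keeps the same layout as the input.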
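The three helpers can be approximated in plain PyTorch. This is a sketch under assumptions, not traceTorch's code: the function names are hypothetical, torch.movedim is assumed to be an adequate stand-in for the dimension shuffling, and a new state is assumed to start as zeros with the input's dtype and device.

```python
import torch


def ensure_state(state, x, num_neurons, dim):
    # Hypothetical stand-in for _ensure_states: build a zero state lazily.
    if state is None:
        shape = list(x.shape)
        shape[dim] = num_neurons  # only the target dimension changes size
        state = torch.zeros(shape, dtype=x.dtype, device=x.device)
    return state


def to_working_dim(t, dim):
    # Hypothetical stand-in for _to_working_dim: move `dim` to the last position.
    return t.movedim(dim, -1)


def from_working_dim(t, dim):
    # Hypothetical stand-in for _from_working_dim: move the last dim back to `dim`.
    return t.movedim(-1, dim)


# Example: features on dim 1, e.g. a (batch, features, time) input.
x = torch.randn(4, 64, 10)
H = ensure_state(None, x, num_neurons=32, dim=1)
print(H.shape)                         # torch.Size([4, 32, 10])
print(to_working_dim(H, dim=1).shape)  # torch.Size([4, 10, 32])
```

The two dimension helpers are exact inverses, so a state can be moved to the working layout, updated, and moved back without changing how it is stored.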

Using the layer

class Net(tt.Model):
    def __init__(self):
        super().__init__()
        self.layer = MiniGRU(64, 32)

    def forward(self, x):
        return self.layer(x)


model = Net()
model.zero_states()
y = model(torch.rand(16, 64))
print(y.shape)
# torch.Size([16, 32])
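Because the layer keeps H between calls, a sequence is processed by calling the model once per time step. The sketch below shows that loop with truncated backpropagation through time, using a hypothetical plain-PyTorch mirror of MiniGRU (StandaloneMiniGRU, features-last only) so it runs without traceTorch. Its reset_state and detach_state methods stand in for what model.zero_states() and model.detach_states() are presumably used for; the 5-step truncation, the SGD optimizer and the squared-output loss are arbitrary placeholders.

```python
import torch
from torch import nn


class StandaloneMiniGRU(nn.Module):
    # Hypothetical plain-PyTorch mirror of MiniGRU (features on the last dim).
    def __init__(self, in_features, out_features):
        super().__init__()
        self.out_features = out_features
        self.gates = nn.Linear(in_features + out_features, 2 * out_features)
        self.candidate = nn.Linear(in_features + out_features, out_features)
        self.H = None

    def reset_state(self):
        self.H = None  # recreated lazily on the next call

    def detach_state(self):
        if self.H is not None:
            self.H = self.H.detach()  # keep the value, drop the graph

    def forward(self, x):
        if self.H is None:
            self.H = x.new_zeros((*x.shape[:-1], self.out_features))
        H = self.H
        gates = torch.sigmoid(self.gates(torch.cat([H, x], dim=-1)))
        reset, update = gates.chunk(2, dim=-1)
        candidate = torch.tanh(self.candidate(torch.cat([H * reset, x], dim=-1)))
        self.H = H * (1 - update) + update * candidate
        return self.H


torch.manual_seed(0)
layer = StandaloneMiniGRU(64, 32)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)
sequence = torch.rand(20, 16, 64)  # (time, batch, features)

layer.reset_state()  # start of a new sequence
for t, x_t in enumerate(sequence):
    y = layer(x_t)
    if (t + 1) % 5 == 0:  # truncated BPTT every 5 steps
        loss = y.pow(2).mean()  # placeholder loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        layer.detach_state()  # later steps won't backprop into this chunk
```

Detaching after each update keeps the autograd graph bounded to the last few steps while still carrying the state's value forward.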

Adding constrained parameters

SNN and SSM layers often need constrained parameters. For example, a decay should stay between zero and one. SNN layers use helper methods such as _register_decay:

class DecayLayer(tt.snn.Layer):
    def __init__(self, num_neurons: int, beta: float = 0.9):
        super().__init__(num_neurons)
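        # Membrane potential, created lazily like H above.
        self._initialize_state("mem")
        # Learnable decay; its public value is constrained to (0, 1).
        self._register_decay("beta", beta, rank=1, learnable=True)

    def forward(self, x):
        self._ensure_states(x)
        # Leaky integration: decay the old membrane, then add the input.
        self.mem = self.mem * self.beta + x
        return self.mem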

The public self.beta value is the sigmoid of the raw stored parameter self.raw_beta, so it stays in (0, 1) however the optimizer changes self.raw_beta. Keeping the unconstrained raw parameter separate from the constrained public value is what allows TTcompile() and TTdecompile() to optimize constrained parameters later.
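The raw/public split can be sketched in plain PyTorch. The module below is a hypothetical illustration, not _register_decay itself: it stores an unconstrained raw_beta initialized to logit(beta) and exposes beta as a read-only property that applies the sigmoid, so gradient steps act on raw_beta while beta stays inside (0, 1).

```python
import torch
from torch import nn


class SigmoidDecay(nn.Module):
    # Hypothetical sketch of a sigmoid-constrained decay parameter.
    def __init__(self, num_neurons, beta=0.9):
        super().__init__()
        # Invert the sigmoid so that the public value starts at `beta`.
        raw = torch.logit(torch.full((num_neurons,), beta))
        self.raw_beta = nn.Parameter(raw)

    @property
    def beta(self):
        # Public value: the sigmoid keeps it between 0 and 1.
        return torch.sigmoid(self.raw_beta)


decay = SigmoidDecay(4, beta=0.9)
print(decay.beta)  # ~0.9 for every neuron

# A large update to the raw parameter still leaves beta inside (0, 1).
with torch.no_grad():
    decay.raw_beta += 10.0
print(decay.beta)
```

Only raw_beta is registered as a parameter, so an optimizer never sees the constrained value directly.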

Checklist

When creating a custom traceTorch layer:

  • call the superclass with num_neurons and dim;

  • call _initialize_state for every hidden state;

  • call _ensure_states(x) before using states in forward;

  • use _to_working_dim before operations that expect features last;

  • write updated states back through _from_working_dim;

  • return one tensor, keeping hidden states internal.