The SNNs of traceTorch
traceTorch was designed around spiking neural networks, but it deliberately treats “spiking” as a configurable part of the layer rather than as an unavoidable hard-coded event. The SNN layers are recurrent dynamical systems first: they keep one or more traces, update those traces one timestep at a time, and optionally convert the membrane value into binary, ternary, or scaled ternary output.
This is the main idea to keep in mind when reading the API:
LIlayers return the membrane trace directly. They are leaky integrators, not firing neurons.LIBlayers use one positive threshold and produce non-negative output.LITlayers use positive and negative thresholds and can produce positive or negative output.LITSlayers are ternary layers with learnable or fixed positive and negative output scales.Dsplits traces into positive and negative branches.Sadds a synaptic trace before the membrane trace.Radds a recurrent trace from the previous output.
The membrane trace
The simplest traceTorch SNN layer is LI. It stores a membrane state called mem and updates it as:
mem = beta * mem + x
return mem
The decay beta is constrained to (0, 1) by the layer registration machinery, even when it is learnable. A value
near zero makes the layer mostly react to the current input. A value near one gives the layer longer memory.
The LIEMA variants use an exponential-moving-average form instead:
mem = beta * mem + (1 - beta) * x
return mem
This keeps the magnitude bounded in a way that is useful when the layer is meant to smooth a signal instead of accumulate evidence.
Spikes, probabilities, and quantization
The firing layers first update a membrane trace and then compare it with one or more thresholds. For a binary LIB
layer, the default calculation is:
mem = beta * mem + x
spike_prob = sigmoid4x(mem - threshold + bias)
spikes = quant_fn(spike_prob)
mem = mem - spikes * threshold
return spikes
The default spike_fn is sigmoid4x:
sigmoid4x(x) = sigmoid(4 * x)
This gives a smooth transition around threshold. It is steep enough to behave like a soft firing decision, but it still has useful gradients for optimization.
The default quant_fn is nn.Identity(). That means the default layer returns the smooth spike probability itself,
not a hard 0 or 1 event. In this default mode, the SNN behaves more like a differentiable recurrent ODE-style
system than a strictly discrete spiking mechanism. This is intentional: it gives stable gradients and makes the default
layer easy to train.
To make the forward pass more spike-like, pass a quantization function explicitly:
layer = tt.snn.LIB(
num_neurons=128,
quant_fn=tt.functional.round_ste(),
)
traceTorch provides straight-through quantizers in tt.functional. These functions make a discrete or stochastic
forward decision while defining a surrogate gradient for the backward pass.
Surrogate gradients
Hard spikes are difficult to train with ordinary gradient descent because a step function has zero derivative almost everywhere and is undefined at the threshold. Surrogate-gradient training keeps the forward computation spike-like while using a smoother backward approximation.
In traceTorch, the separation is explicit:
spike_fnturns membrane distance from threshold into a smooth firing probability.quant_fndecides whether that probability stays continuous or becomes discrete.
This design makes the training behavior visible. If you keep quant_fn=nn.Identity(), gradients flow through the
smooth probability. If you use round_ste(), stochastic_round_ste(), or probabilistic_ste(), the forward pass
becomes more discrete while the backward pass remains trainable.
Synaptic traces
Synaptic variants add a trace called syn before the membrane:
syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
This makes the input current smoother before it reaches the membrane. Dual synaptic layers use pos_syn and
neg_syn so positive and negative signals can decay independently.
Recurrent traces
Recurrent variants add a trace of the previous output:
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + x + rec_weight * rec
The recurrent trace is internal to the layer. The caller still passes one tensor in and receives one tensor out, which keeps traceTorch layers composable with ordinary PyTorch modules.
Choosing a layer
Use LI or LIEMA when you want a continuous trace. Use LIB when you want one-sided firing. Use LIT when
positive and negative events should be represented separately. Use LITS when the positive and negative events should
also have their own output magnitudes.
Then add prefixes only when they match the dynamics you need:
Add
Dwhen positive and negative history should be stored separately.Add
Swhen inputs should be smoothed before membrane integration.Add
Rwhen the neuron’s previous output should influence its next membrane update.
The names look dense at first, but they are meant to be mechanical. DSRLITS is a dual, synaptic, recurrent, leaky
integrator with ternary scaled output.
The layer families
The 32 SNN layers are easier to understand as four output families, each with the same optional dynamics layered on top.
Family |
Output |
Base layer |
Variants |
|---|---|---|---|
|
Continuous membrane value |
|
|
|
One-sided non-negative firing value |
|
|
|
Positive, zero, or negative firing value |
|
|
|
Positive, zero, or negative firing value with separate output scales |
|
|
The output family goes at the end of the name. Prefixes describe extra internal traces:
No prefix: one membrane trace.
D: positive and negative traces are stored separately.S: a synaptic trace is added before the membrane.R: a recurrent trace of the previous output is added before the membrane.DS,DR,SR,DSR: combinations of the above.
This means that SRLIT is a synaptic recurrent ternary layer, while DLITS is a dual ternary-scaled layer with
synaptic and recurrent traces.
Common parameters
Most SNN parameters follow the same rules across all layers.
num_neurons is the size of the dimension the layer operates on. The layer expects the input to already have this
size at dim. For example, tt.snn.LIB(64, dim=-3) works on the channel dimension of a [B, C, H, W] tensor
where C == 64.
Decay parameters are constrained to (0, 1):
alphacontrols synaptic memory.betacontrols membrane memory.gammacontrols recurrent-output memory.
Threshold parameters are constrained to positive values:
thresholdis used by binary layers.pos_thresholdandneg_thresholdare used by ternary layers.
The *_rank arguments choose whether a parameter is shared or per-neuron. A rank of 0 creates a scalar. A rank of
1 creates a vector of length num_neurons. You can also pass a tensor directly; scalar tensors are accepted, and
1D tensors must have num_neurons elements.
The learn_* arguments choose whether a parameter is trainable. If learn_beta=False, for example, the raw decay is
registered as a buffer rather than as an nn.Parameter.
Working on non-last dimensions
traceTorch layers do not require features to live in the last dimension. The layer temporarily moves dim to the end,
runs the same vectorized update, then moves it back.
layer = tt.snn.LIB(num_neurons=32, dim=-3)
x = torch.rand(16, 32, 28, 28)
y = layer(x)
print(y.shape)
# torch.Size([16, 32, 28, 28])
This is why the same SNN layer can be placed after a convolution, inside an MLP, or anywhere else that has a clear feature dimension.
Practical choices
Start simple unless the task gives you a reason not to.
Use LI as a readout or continuous accumulator when you want to keep magnitude information. Use LIEMA when the
trace should remain bounded like a smoothed signal.
Use LIB for ordinary one-sided SNN experiments. It is the traceTorch name for the common leaky integrate-and-fire
shape, with the important caveat that the default output is continuous unless you pass a quantizer.
Use LIT when negative events are meaningful. This is often more natural for signed activations, residual streams, or
signals where “below baseline” should be represented explicitly rather than merely suppressing positive firing.
Use LITS when the positive and negative events should have independent magnitudes. This lets the layer learn or fix
the downstream strength of positive and negative events separately from the thresholds that produced them.
Add S when input should arrive as a current over time instead of as a direct membrane increment. Add R when a
neuron’s previous output should influence its next update. Add D when positive and negative history should have
separate time constants.
For a first model, something like this is usually easier to reason about than starting with the largest layer:
self.net = nn.Sequential(
nn.Linear(784, 128),
tt.snn.LIB(128),
nn.Linear(128, 10),
tt.snn.LI(10),
)
Once that works, the layer name gives a mechanical upgrade path: LIB to SLIB for input smoothing, LIB to
RLIB for recurrent output memory, LIB to LIT for signed events, or LIT to LITS for signed events with
separate output magnitudes.