The SNNs of traceTorch

traceTorch was designed around spiking neural networks, but it deliberately treats “spiking” as a configurable part of the layer rather than as an unavoidable hard-coded event. The SNN layers are recurrent dynamical systems first: they keep one or more traces, update those traces one timestep at a time, and optionally convert the membrane value into binary, ternary, or scaled ternary output.

This is the main idea to keep in mind when reading the API:

  • LI layers return the membrane trace directly. They are leaky integrators, not firing neurons.

  • LIB layers use one positive threshold and produce non-negative output.

  • LIT layers use positive and negative thresholds and can produce positive or negative output.

  • LITS layers are ternary layers with learnable or fixed positive and negative output scales.

  • D splits traces into positive and negative branches.

  • S adds a synaptic trace before the membrane trace.

  • R adds a recurrent trace from the previous output.

The membrane trace

The simplest traceTorch SNN layer is LI. It stores a membrane state called mem and updates it as:

mem = beta * mem + x
return mem

The decay beta is constrained to (0, 1) by the layer registration machinery, even when it is learnable. A value near zero makes the layer mostly react to the current input. A value near one gives the layer longer memory.

The LIEMA variants use an exponential-moving-average form instead:

mem = beta * mem + (1 - beta) * x
return mem

This keeps the magnitude bounded in a way that is useful when the layer is meant to smooth a signal instead of accumulate evidence.

Spikes, probabilities, and quantization

The firing layers first update a membrane trace and then compare it with one or more thresholds. For a binary LIB layer, the default calculation is:

mem = beta * mem + x
spike_prob = sigmoid4x(mem - threshold + bias)
spikes = quant_fn(spike_prob)
mem = mem - spikes * threshold
return spikes

The default spike_fn is sigmoid4x:

sigmoid4x(x) = sigmoid(4 * x)

This gives a smooth transition around threshold. It is steep enough to behave like a soft firing decision, but it still has useful gradients for optimization.

The default quant_fn is nn.Identity(). That means the default layer returns the smooth spike probability itself, not a hard 0 or 1 event. In this default mode, the SNN behaves more like a differentiable recurrent ODE-style system than a strictly discrete spiking mechanism. This is intentional: it gives stable gradients and makes the default layer easy to train.

To make the forward pass more spike-like, pass a quantization function explicitly:

layer = tt.snn.LIB(
    num_neurons=128,
    quant_fn=tt.functional.round_ste(),
)

traceTorch provides straight-through quantizers in tt.functional. These functions make a discrete or stochastic forward decision while defining a surrogate gradient for the backward pass.

Surrogate gradients

Hard spikes are difficult to train with ordinary gradient descent because a step function has zero derivative almost everywhere and is undefined at the threshold. Surrogate-gradient training keeps the forward computation spike-like while using a smoother backward approximation.

In traceTorch, the separation is explicit:

  • spike_fn turns membrane distance from threshold into a smooth firing probability.

  • quant_fn decides whether that probability stays continuous or becomes discrete.

This design makes the training behavior visible. If you keep quant_fn=nn.Identity(), gradients flow through the smooth probability. If you use round_ste(), stochastic_round_ste(), or probabilistic_ste(), the forward pass becomes more discrete while the backward pass remains trainable.

Synaptic traces

Synaptic variants add a trace called syn before the membrane:

syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn

This makes the input current smoother before it reaches the membrane. Dual synaptic layers use pos_syn and neg_syn so positive and negative signals can decay independently.

Recurrent traces

Recurrent variants add a trace of the previous output:

rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + x + rec_weight * rec

The recurrent trace is internal to the layer. The caller still passes one tensor in and receives one tensor out, which keeps traceTorch layers composable with ordinary PyTorch modules.

Choosing a layer

Use LI or LIEMA when you want a continuous trace. Use LIB when you want one-sided firing. Use LIT when positive and negative events should be represented separately. Use LITS when the positive and negative events should also have their own output magnitudes.

Then add prefixes only when they match the dynamics you need:

  • Add D when positive and negative history should be stored separately.

  • Add S when inputs should be smoothed before membrane integration.

  • Add R when the neuron’s previous output should influence its next membrane update.

The names look dense at first, but they are meant to be mechanical. DSRLITS is a dual, synaptic, recurrent, leaky integrator with ternary scaled output.

The layer families

The 32 SNN layers are easier to understand as four output families, each with the same optional dynamics layered on top.

Family

Output

Base layer

Variants

LI

Continuous membrane value

LI

LI, DLI, SLI, DSLI, LIEMA, DLIEMA, SLIEMA, DSLIEMA

LIB

One-sided non-negative firing value

LIB

LIB, DLIB, SLIB, RLIB, DSLIB, DRLIB, SRLIB, DSRLIB

LIT

Positive, zero, or negative firing value

LIT

LIT, DLIT, SLIT, RLIT, DSLIT, DRLIT, SRLIT, DSRLIT

LITS

Positive, zero, or negative firing value with separate output scales

LITS

LITS, DLITS, SLITS, RLITS, DSLITS, DRLITS, SRLITS, DSRLITS

The output family goes at the end of the name. Prefixes describe extra internal traces:

  • No prefix: one membrane trace.

  • D: positive and negative traces are stored separately.

  • S: a synaptic trace is added before the membrane.

  • R: a recurrent trace of the previous output is added before the membrane.

  • DS, DR, SR, DSR: combinations of the above.

This means that SRLIT is a synaptic recurrent ternary layer, while DLITS is a dual ternary-scaled layer with synaptic and recurrent traces.

Common parameters

Most SNN parameters follow the same rules across all layers.

num_neurons is the size of the dimension the layer operates on. The layer expects the input to already have this size at dim. For example, tt.snn.LIB(64, dim=-3) works on the channel dimension of a [B, C, H, W] tensor where C == 64.

Decay parameters are constrained to (0, 1):

  • alpha controls synaptic memory.

  • beta controls membrane memory.

  • gamma controls recurrent-output memory.

Threshold parameters are constrained to positive values:

  • threshold is used by binary layers.

  • pos_threshold and neg_threshold are used by ternary layers.

The *_rank arguments choose whether a parameter is shared or per-neuron. A rank of 0 creates a scalar. A rank of 1 creates a vector of length num_neurons. You can also pass a tensor directly; scalar tensors are accepted, and 1D tensors must have num_neurons elements.

The learn_* arguments choose whether a parameter is trainable. If learn_beta=False, for example, the raw decay is registered as a buffer rather than as an nn.Parameter.

Working on non-last dimensions

traceTorch layers do not require features to live in the last dimension. The layer temporarily moves dim to the end, runs the same vectorized update, then moves it back.

layer = tt.snn.LIB(num_neurons=32, dim=-3)
x = torch.rand(16, 32, 28, 28)
y = layer(x)
print(y.shape)
# torch.Size([16, 32, 28, 28])

This is why the same SNN layer can be placed after a convolution, inside an MLP, or anywhere else that has a clear feature dimension.

Practical choices

Start simple unless the task gives you a reason not to.

Use LI as a readout or continuous accumulator when you want to keep magnitude information. Use LIEMA when the trace should remain bounded like a smoothed signal.

Use LIB for ordinary one-sided SNN experiments. It is the traceTorch name for the common leaky integrate-and-fire shape, with the important caveat that the default output is continuous unless you pass a quantizer.

Use LIT when negative events are meaningful. This is often more natural for signed activations, residual streams, or signals where “below baseline” should be represented explicitly rather than merely suppressing positive firing.

Use LITS when the positive and negative events should have independent magnitudes. This lets the layer learn or fix the downstream strength of positive and negative events separately from the thresholds that produced them.

Add S when input should arrive as a current over time instead of as a direct membrane increment. Add R when a neuron’s previous output should influence its next update. Add D when positive and negative history should have separate time constants.

For a first model, something like this is usually easier to reason about than starting with the largest layer:

self.net = nn.Sequential(
    nn.Linear(784, 128),
    tt.snn.LIB(128),
    nn.Linear(128, 10),
    tt.snn.LI(10),
)

Once that works, the layer name gives a mechanical upgrade path: LIB to SLIB for input smoothing, LIB to RLIB for recurrent output memory, LIB to LIT for signed events, or LIT to LITS for signed events with separate output magnitudes.