SNN

tt.snn contains traceTorch’s leaky-integrator-based layers: continuous integrators that return membrane traces, and firing layers that return spikes.

Base Layer

class tracetorch.snn.Layer(*args: Any, **kwargs: Any)

Bases: tt.Layer

Base class for traceTorch SNN layers.

This class extends tt.Layer with parameter registration helpers commonly used by spiking layers:

  • decays are constrained to (0, 1) through a sigmoid transform;

  • thresholds are constrained to positive values through a softplus transform;

  • biases use a smooth unconstrained transform.
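The decay and threshold constraints above can be pictured as storing an unconstrained raw value and activating it on use. The sketch below is a minimal illustration, assuming the standard sigmoid and softplus functions and their inverses; the helper names are hypothetical, and how traceTorch actually stores raw parameters (or which transform it uses for biases) is not specified here.

```python
import math

# Hypothetical helpers illustrating constrained parameterizations:
# a raw, unconstrained value is stored and the activated value is computed on use.

def sigmoid(raw: float) -> float:
    """Map any real number into (0, 1); used here for decays."""
    return 1.0 / (1.0 + math.exp(-raw))

def logit(p: float) -> float:
    """Inverse of sigmoid: the raw value that activates to p, for 0 < p < 1."""
    return math.log(p / (1.0 - p))

def softplus(raw: float) -> float:
    """Map any real number into (0, inf); used here for thresholds."""
    return math.log1p(math.exp(raw))

def softplus_inv(y: float) -> float:
    """Inverse of softplus for y > 0."""
    return math.log(math.expm1(y))

# Initializing beta=0.9 and threshold=1.0 means storing their inverses.
raw_beta = logit(0.9)
raw_threshold = softplus_inv(1.0)

print(round(sigmoid(raw_beta), 6))        # 0.9
print(round(softplus(raw_threshold), 6))  # 1.0
```

Storing the inverse lets a user pass the natural value (for example beta=0.9) while an optimizer updates an unbounded raw value whose activation always stays in range.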

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • dim (int, default=-1) – dimension along which the layer operates.

Notes

Users normally instantiate concrete layers such as LIB or LIT. Subclass this base when creating a custom SNN layer that should integrate with tt.Model state management and traceTorch parameter compilation.

Continuous Leaky Integrators

class tracetorch.snn.LI(*args: Any, **kwargs: Any)

Bases: Layer

A leaky integrator layer with continuous membrane output.

LI stores a membrane trace and returns that trace directly. It does not spike, threshold, or reset; it is useful as a readout-style accumulator or as a smooth recurrent feature transform.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane decay. The activated value is constrained to (0, 1).

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – 0 for a scalar decay shared by all neurons, 1 for one decay per neuron.

  • learn_beta (bool, default=True) – whether beta is trainable.

Variables:
  • mem – membrane state. Lazily initialized to zeros with the input shape, except the target dimension is set to num_neurons.

  • beta – activated membrane decay.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Updates the membrane by exponentially decaying the previous value and adding the current input. Pseudocode looks as follows:

mem = beta * mem + x
return mem
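A scalar, pure-Python sketch of this update for a single neuron shows why LI behaves as an unnormalized accumulator: under a constant input x, mem approaches x / (1 - beta), which is 10x for the default beta=0.9. This is a hand translation of the pseudocode above, not a call into traceTorch.

```python
def li_run(inputs, beta=0.9, mem=0.0):
    """Apply mem = beta * mem + x over a sequence and return every output."""
    outputs = []
    for x in inputs:
        mem = beta * mem + x  # exponential decay plus the unscaled input
        outputs.append(mem)
    return outputs

trace = li_run([1.0] * 200)
print(round(trace[0], 4), round(trace[-1], 4))  # 1.0 10.0
```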

Examples:

>>> import torch
>>> import tracetorch as tt
>>> layer = tt.snn.LI(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.DLI(*args: Any, **kwargs: Any)

Bases: Layer

A dual leaky integrator layer with continuous membrane output.

DLI splits the membrane trace into separate positive and negative branches. Positive input updates pos_mem and negative input updates neg_mem; the returned membrane is their sum.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – decay for the positive membrane branch.

  • neg_beta (float or torch.Tensor, default=0.9) – decay for the negative membrane branch.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

Variables:
  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_beta – activated positive membrane decay, constrained to (0, 1).

  • neg_beta – activated negative membrane decay, constrained to (0, 1).

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Dual traces let positive and negative evidence retain different time constants. Pseudocode looks as follows:

pos_mem = pos_beta * pos_mem + where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + where(x <= 0, x, 0)
return pos_mem + neg_mem
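The scalar sketch below hand-translates this pseudocode (it is not traceTorch code) and feeds one positive and one negative pulse of equal size, with pos_beta=0.9 and neg_beta=0.5 chosen only for illustration. Because the negative branch decays faster, the summed output turns positive again shortly after the negative pulse.

```python
def dli_run(inputs, pos_beta=0.9, neg_beta=0.5):
    """Dual leaky integrator on plain floats: route input by sign into two traces."""
    pos_mem, neg_mem = 0.0, 0.0
    outputs = []
    for x in inputs:
        pos_mem = pos_beta * pos_mem + max(x, 0.0)  # where(x >= 0, x, 0)
        neg_mem = neg_beta * neg_mem + min(x, 0.0)  # where(x <= 0, x, 0)
        outputs.append(pos_mem + neg_mem)
    return outputs

# Equal-magnitude pulses of opposite sign, then silence.
out = dli_run([1.0, -1.0, 0.0, 0.0, 0.0])
print([round(v, 4) for v in out])  # [1.0, -0.1, 0.31, 0.479, 0.5311]
```

With a single LI trace at beta=0.9, the same input would stay slightly negative after the second step; the dual traces let the two signs fade on separate time scales.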

Examples:

>>> layer = tt.snn.DLI(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.SLI(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic leaky integrator layer with continuous membrane output.

SLI adds a synaptic trace before the membrane trace. The synaptic trace smooths the input with decay alpha before the membrane integrates it with decay beta.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay, constrained to (0, 1).

  • beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

Variables:
  • syn – synaptic state.

  • mem – membrane state.

  • alpha – activated synaptic decay.

  • beta – activated membrane decay.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

The synaptic trace is an exponential moving average of the input. The membrane then accumulates that smoothed current. Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
return mem
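The effect of the synaptic stage is visible in the impulse response. In this scalar hand translation of the pseudocode (default alpha=0.5, beta=0.9; not traceTorch code), a single unit impulse is spread over several steps, so the membrane keeps rising after the input has stopped and peaks two steps later.

```python
def sli_run(inputs, alpha=0.5, beta=0.9):
    """Synaptic leaky integrator on plain floats."""
    syn, mem = 0.0, 0.0
    outputs = []
    for x in inputs:
        syn = alpha * syn + (1 - alpha) * x  # synaptic EMA of the input
        mem = beta * mem + syn               # membrane accumulates the smoothed current
        outputs.append(mem)
    return outputs

impulse = [1.0] + [0.0] * 9
out = sli_run(impulse)
peak_step = max(range(len(out)), key=out.__getitem__)
print(peak_step)  # 2
```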

Examples:

>>> layer = tt.snn.SLI(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.DSLI(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic leaky integrator layer with continuous membrane output.

DSLI combines dual positive/negative traces with a synaptic stage. It keeps separate positive and negative synaptic traces, then integrates their sum into separate positive and negative membrane traces.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

Variables:
  • pos_syn – positive synaptic state.

  • neg_syn – negative synaptic state.

  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * where(x >= 0, x, 0)
neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * where(x <= 0, x, 0)
syn = pos_syn + neg_syn
pos_mem = pos_beta * pos_mem + where(syn >= 0, syn, 0)
neg_mem = neg_beta * neg_mem + where(syn <= 0, syn, 0)
return pos_mem + neg_mem
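One consequence of this pseudocode is easy to miss: the membrane branches are selected by the sign of the net synaptic current syn, not by the sign of the raw input. In the scalar sketch below (a hand translation with the default decays, not traceTorch code), a small negative input that arrives while the positive synaptic trace is still large never reaches neg_mem.

```python
def dsli_step(state, x, pos_alpha=0.5, neg_alpha=0.5, pos_beta=0.9, neg_beta=0.9):
    """One DSLI step on plain floats; state = (pos_syn, neg_syn, pos_mem, neg_mem)."""
    pos_syn, neg_syn, pos_mem, neg_mem = state
    pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * max(x, 0.0)
    neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * min(x, 0.0)
    syn = pos_syn + neg_syn                       # net synaptic current
    pos_mem = pos_beta * pos_mem + max(syn, 0.0)  # routed by the sign of syn
    neg_mem = neg_beta * neg_mem + min(syn, 0.0)
    return (pos_syn, neg_syn, pos_mem, neg_mem)

state = (0.0, 0.0, 0.0, 0.0)
for x in [1.0, 1.0, -0.2]:
    state = dsli_step(state, x)
pos_syn, neg_syn, pos_mem, neg_mem = state
print(round(neg_syn, 4), round(pos_mem, 4), neg_mem)  # -0.1 1.355 0.0
```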

Examples:

>>> layer = tt.snn.DSLI(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.LIEMA(*args: Any, **kwargs: Any)

Bases: Layer

A leaky integrator layer with exponential-moving-average output.

LIEMA is the bounded counterpart to LI. It stores one membrane trace and updates it as an exponential moving average of the input instead of as an unnormalized accumulator.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane EMA decay, constrained to (0, 1).

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron decay.

  • learn_beta (bool, default=True) – whether beta is trainable.

Variables:
  • mem – membrane EMA state.

  • beta – activated membrane decay.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

mem = beta * mem + (1 - beta) * x
return mem
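The difference from LI shows up under a constant input. In this scalar hand translation of the pseudocode (not traceTorch code), each update is a convex combination of the previous state and the input, so with inputs in [0, 1] and a zero initial state the trace stays in [0, 1] and settles at x rather than at x / (1 - beta).

```python
def liema_run(inputs, beta=0.9, mem=0.0):
    """Exponential moving average on plain floats."""
    outputs = []
    for x in inputs:
        mem = beta * mem + (1 - beta) * x  # convex combination keeps mem within the input range
        outputs.append(mem)
    return outputs

trace = liema_run([1.0] * 200)
print(round(trace[0], 4), round(trace[-1], 4))  # 0.1 1.0
```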

Examples:

>>> layer = tt.snn.LIEMA(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.DLIEMA(*args: Any, **kwargs: Any)

Bases: Layer

A dual leaky integrator layer with exponential-moving-average output.

DLIEMA keeps positive and negative membrane EMA traces separately and returns their sum. This is useful when positive and negative evidence should decay independently without allowing unbounded accumulation.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane EMA decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane EMA decay.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

Variables:
  • pos_mem – positive membrane EMA state.

  • neg_mem – negative membrane EMA state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

pos_mem = pos_beta * pos_mem + (1 - pos_beta) * where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + (1 - neg_beta) * where(x <= 0, x, 0)
return pos_mem + neg_mem

Examples:

>>> layer = tt.snn.DLIEMA(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.SLIEMA(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic leaky integrator layer with exponential-moving-average output.

SLIEMA smooths the input through a synaptic EMA and then updates the membrane as an EMA of that synaptic current.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay, constrained to (0, 1).

  • beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

Variables:
  • syn – synaptic EMA state.

  • mem – membrane EMA state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + (1 - beta) * syn
return mem
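Because both stages are EMAs whose weights sum to one, SLIEMA has unit gain for a constant input: like LIEMA it settles at x, but it rises more gradually, starting at (1 - alpha) * (1 - beta) * x, which is 0.05 * x with the defaults. A scalar hand translation of the pseudocode (not traceTorch code):

```python
def sliema_run(inputs, alpha=0.5, beta=0.9):
    """Two cascaded exponential moving averages on plain floats."""
    syn, mem = 0.0, 0.0
    outputs = []
    for x in inputs:
        syn = alpha * syn + (1 - alpha) * x  # first EMA: synaptic stage
        mem = beta * mem + (1 - beta) * syn  # second EMA: membrane stage
        outputs.append(mem)
    return outputs

trace = sliema_run([1.0] * 300)
print(round(trace[0], 4), round(trace[-1], 4))  # 0.05 1.0
```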

Examples:

>>> layer = tt.snn.SLIEMA(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.DSLIEMA(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic leaky integrator layer with EMA membrane output.

DSLIEMA is the dual, synaptic, bounded-output variant of the leaky integrator family. It keeps positive and negative synaptic traces, then updates positive and negative membrane EMA traces from their combined synaptic current.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

Variables:
  • pos_syn – positive synaptic EMA state.

  • neg_syn – negative synaptic EMA state.

  • pos_mem – positive membrane EMA state.

  • neg_mem – negative membrane EMA state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * where(x >= 0, x, 0)
neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * where(x <= 0, x, 0)
syn = pos_syn + neg_syn
pos_mem = pos_beta * pos_mem + (1 - pos_beta) * where(syn >= 0, syn, 0)
neg_mem = neg_beta * neg_mem + (1 - neg_beta) * where(syn <= 0, syn, 0)
return pos_mem + neg_mem

Examples:

>>> layer = tt.snn.DSLIEMA(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

Binary Firing Layers

class tracetorch.snn.LIB(*args: Any, **kwargs: Any)

Bases: Layer

A leaky integrate-and-binary-fire layer.

LIB is traceTorch’s one-sided firing layer. It stores one membrane trace, converts the distance from threshold into a firing probability, optionally quantizes that probability, subtracts the threshold-scaled output from the membrane, and returns the output.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold, constrained to positive values.

  • bias (float or torch.Tensor, default=0.0) – additive bias applied before the spike function.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – function that maps membrane distance from threshold to a firing probability.

  • quant_fn (Callable, default=nn.Identity()) – function that maps firing probability to the returned spike value.

Variables:
  • mem – membrane state.

  • beta – activated membrane decay.

  • threshold – activated positive threshold.

  • bias – activated bias.

  • spike_fn – spike probability function.

  • quant_fn – output quantization function.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

With the default quant_fn=nn.Identity(), the layer returns smooth firing probabilities. Pass a straight-through quantizer such as tt.functional.round_ste() for harder binary events. Pseudocode looks as follows:

mem = beta * mem + x
spike_prob = spike_fn(mem - threshold + bias)
spikes = quant_fn(spike_prob)
mem = mem - spikes * threshold
return spikes
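The scalar sketch below hand-translates this pseudocode to show reset by subtraction. It assumes tt.functional.sigmoid4x computes sigmoid(4 * d), which is inferred only from the name, and uses plain rounding as the forward pass of a straight-through quantizer (gradients are not modeled). Under a constant sub-threshold input of 0.4, the hard version fires periodically, while the default identity quant_fn returns values strictly between 0 and 1.

```python
import math

def sigmoid4x(d):
    # Assumed form of tt.functional.sigmoid4x, inferred from its name only.
    return 1.0 / (1.0 + math.exp(-4.0 * d))

def lib_run(inputs, beta=0.9, threshold=1.0, bias=0.0, quant_fn=lambda p: p):
    """Leaky integrate-and-fire on plain floats with reset by subtraction."""
    mem, outputs = 0.0, []
    for x in inputs:
        mem = beta * mem + x
        spikes = quant_fn(sigmoid4x(mem - threshold + bias))
        mem = mem - spikes * threshold  # subtract the threshold-scaled output
        outputs.append(spikes)
    return outputs

hard = lib_run([0.4] * 10, quant_fn=lambda p: float(round(p)))
print(hard)  # [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]

soft = lib_run([0.4] * 10)  # identity quant_fn: outputs are firing probabilities
print(all(0.0 < s < 1.0 for s in soft))  # True
```

Subtracting rather than zeroing the membrane keeps the overshoot above threshold, which is why the spacing between hard spikes is not perfectly fixed for every input level.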

Examples:

>>> layer = tt.snn.LIB(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.DLIB(*args: Any, **kwargs: Any)

Bases: Layer

A dual leaky integrate-and-binary-fire layer.

DLIB splits membrane integration into positive and negative branches, but still emits a one-sided binary-style output. The two membrane branches are summed for thresholding, then the reset is split evenly across both branches.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold.

  • bias (float or torch.Tensor, default=0.0) – additive bias before firing.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • threshold – activated positive threshold.

  • bias – activated bias.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

pos_mem = pos_beta * pos_mem + where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + where(x <= 0, x, 0)
mem = pos_mem + neg_mem
spikes = quant_fn(spike_fn(mem - threshold + bias))
pos_mem = pos_mem - 0.5 * spikes * threshold
neg_mem = neg_mem - 0.5 * spikes * threshold
return spikes
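The split reset lowers each branch by threshold / 2, so the summed membrane still drops by the full threshold. A side effect, shown in the scalar sketch below, is that a spike driven purely by positive input pushes neg_mem below zero, where it then decays with neg_beta. The sketch hand-translates the pseudocode and assumes spike_fn computes sigmoid(4 * d) followed by a rounding quant_fn; neither choice is taken from traceTorch's source.

```python
import math

def dlib_step(state, x, pos_beta=0.9, neg_beta=0.9, threshold=1.0, bias=0.0):
    """One DLIB step on plain floats; state = (pos_mem, neg_mem)."""
    pos_mem, neg_mem = state
    pos_mem = pos_beta * pos_mem + max(x, 0.0)
    neg_mem = neg_beta * neg_mem + min(x, 0.0)
    mem = pos_mem + neg_mem
    # Assumed spike_fn (sigmoid of 4 * distance) followed by hard rounding.
    spikes = float(round(1.0 / (1.0 + math.exp(-4.0 * (mem - threshold + bias)))))
    pos_mem = pos_mem - 0.5 * spikes * threshold  # half of the reset per branch
    neg_mem = neg_mem - 0.5 * spikes * threshold
    return (pos_mem, neg_mem), spikes

(pos_mem, neg_mem), spikes = dlib_step((0.0, 0.0), 2.0)
print(spikes, pos_mem, neg_mem)  # 1.0 1.5 -0.5
```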

Examples:

>>> layer = tt.snn.DLIB(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.SLIB(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic leaky integrate-and-binary-fire layer.

SLIB smooths the input through a synaptic trace before membrane integration and one-sided firing. This is useful when the input should behave like a current with its own time constant rather than an instantaneous membrane increment.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold.

  • bias (float or torch.Tensor, default=0.0) – additive bias before firing.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • syn – synaptic state.

  • mem – membrane state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
spikes = quant_fn(spike_fn(mem - threshold + bias))
mem = mem - spikes * threshold
return spikes

Examples:

>>> layer = tt.snn.SLIB(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.RLIB(*args: Any, **kwargs: Any)

Bases: Layer

A recurrent leaky integrate-and-binary-fire layer.

RLIB adds a recurrent trace of the previous output. The recurrent trace is decayed with gamma, scaled by rec_weight, and added to the current input before membrane integration.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold.

  • bias (float or torch.Tensor, default=0.0) – additive bias before firing.

  • rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_gamma (bool, default=True) – whether gamma is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_rec_weight (bool, default=True) – whether rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • mem – membrane state.

  • rec – recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + x + rec_weight * rec
spikes = quant_fn(spike_fn(mem - threshold + bias))
mem = mem - spikes * threshold
prev_output = spikes
return spikes
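With the default rec_weight=0.0, the recurrent term vanishes and the membrane update is the same as LIB's. The scalar sketch below hand-translates the pseudocode (not traceTorch code), assuming spike_fn computes sigmoid(4 * d) and a rounding quant_fn. A positive rec_weight (0.5 here, chosen for illustration) lets recent spikes excite the neuron, so it fires more often under the same constant input.

```python
import math

def rlib_run(inputs, rec_weight=0.0, beta=0.9, gamma=0.9, threshold=1.0, bias=0.0):
    """Recurrent leaky integrate-and-fire on plain floats."""
    mem, rec, prev_output = 0.0, 0.0, 0.0
    outputs = []
    for x in inputs:
        rec = gamma * rec + (1 - gamma) * prev_output  # EMA of past outputs
        mem = beta * mem + x + rec_weight * rec
        # Assumed spike_fn: sigmoid(4 * distance); hard rounding as quant_fn.
        p = 1.0 / (1.0 + math.exp(-4.0 * (mem - threshold + bias)))
        spikes = float(round(p))
        mem = mem - spikes * threshold
        prev_output = spikes
        outputs.append(spikes)
    return outputs

no_rec = rlib_run([0.4] * 10)
with_rec = rlib_run([0.4] * 10, rec_weight=0.5)
print(sum(no_rec), sum(with_rec))  # 3.0 4.0
```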

Examples:

>>> layer = tt.snn.RLIB(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.DSLIB(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic leaky integrate-and-binary-fire layer.

DSLIB combines dual positive/negative traces with a synaptic stage and a one-sided firing output. Positive and negative inputs are smoothed separately, summed, integrated into dual membrane traces, and thresholded as one combined membrane.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold.

  • bias (float or torch.Tensor, default=0.0) – additive bias before firing.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_syn – positive synaptic state.

  • neg_syn – negative synaptic state.

  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

The reset is split evenly across the dual membrane branches. Pseudocode looks as follows:

pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * where(x >= 0, x, 0)
neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * where(x <= 0, x, 0)
syn = pos_syn + neg_syn
pos_mem = pos_beta * pos_mem + where(syn >= 0, syn, 0)
neg_mem = neg_beta * neg_mem + where(syn <= 0, syn, 0)
spikes = quant_fn(spike_fn(pos_mem + neg_mem - threshold + bias))
pos_mem = pos_mem - 0.5 * spikes * threshold
neg_mem = neg_mem - 0.5 * spikes * threshold
return spikes

Examples:

>>> layer = tt.snn.DSLIB(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.DRLIB(*args: Any, **kwargs: Any)

Bases: Layer

A dual recurrent leaky integrate-and-binary-fire layer.

DRLIB keeps positive and negative membrane traces and positive and negative recurrent traces. The previous output is split by sign, smoothed into recurrent traces, scaled, and integrated with the current input before a one-sided binary firing decision.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.

  • neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold.

  • bias (float or torch.Tensor, default=0.0) – additive bias before firing.

  • pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.

  • neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.

  • neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.

  • neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.

  • learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.

  • learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_rec – positive recurrent trace state.

  • neg_rec – negative recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

pos_rec = pos_gamma * pos_rec + (1 - pos_gamma) * where(prev_output >= 0, prev_output, 0)
neg_rec = neg_gamma * neg_rec + (1 - neg_gamma) * where(prev_output <= 0, prev_output, 0)
pos_mem = pos_beta * pos_mem + pos_rec_weight * where(pos_rec + neg_rec >= 0, pos_rec + neg_rec, 0) + where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + neg_rec_weight * where(pos_rec + neg_rec <= 0, pos_rec + neg_rec, 0) + where(x <= 0, x, 0)
spikes = quant_fn(spike_fn(pos_mem + neg_mem - threshold + bias))
prev_output = spikes
return spikes

Examples:

>>> layer = tt.snn.DRLIB(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Advances the layer by one time step: updates its internal state from input x and returns the output.

class tracetorch.snn.SRLIB(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic recurrent leaky integrate-and-binary-fire layer.

SRLIB combines an input synaptic trace with a recurrent trace of the previous output. The membrane receives both the smoothed input and the scaled recurrent current before one-sided firing.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold.

  • bias (float or torch.Tensor, default=0.0) – additive bias before firing.

  • rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_gamma (bool, default=True) – whether gamma is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_rec_weight (bool, default=True) – whether rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • syn – synaptic state.

  • mem – membrane state.

  • rec – recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + syn + rec_weight * rec
spikes = quant_fn(spike_fn(mem - threshold + bias))
mem = mem - spikes * threshold
prev_output = spikes
return spikes
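
The following pure-Python sketch runs the SRLIB update above for one neuron over a short input sequence. It is an illustration rather than the library's code. srlib_run is a hypothetical helper, and a hard step again stands in for quant_fn(spike_fn(v)). It shows the subtractive reset, and how a negative rec_weight makes each spike inhibit the next.

```python
def srlib_run(xs, alpha=0.5, beta=0.9, gamma=0.9, threshold=1.0, bias=0.0, rec_weight=0.0):
    """Simulate one SRLIB neuron over the inputs `xs`; returns the spike train."""
    syn = mem = rec = prev_output = 0.0
    spikes = []
    for x in xs:
        syn = alpha * syn + (1 - alpha) * x                 # synaptic low-pass of the input
        rec = gamma * rec + (1 - gamma) * prev_output       # trace of past spikes
        mem = beta * mem + syn + rec_weight * rec
        spike = 1.0 if mem - threshold + bias > 0 else 0.0  # hard step (assumption)
        mem = mem - spike * threshold                       # subtractive reset
        prev_output = spike
        spikes.append(spike)
    return spikes


print(srlib_run([1.0] * 5))                   # no recurrence
print(srlib_run([1.0] * 5, rec_weight=-2.0))  # self-inhibition delays the second spike
```

With the default rec_weight of 0.0, the two spikes come on consecutive steps. With rec_weight=-2.0, the trace of the first spike pushes the second one a step later.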

Examples:

>>> layer = tt.snn.SRLIB(num_neurons=32)
>>> input = torch.rand(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DSRLIB(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic recurrent leaky integrate-and-binary-fire layer.

DSRLIB is the most expressive binary SNN layer in traceTorch. It combines dual positive/negative synaptic traces, dual positive/negative recurrent traces, dual positive/negative membrane traces, and a one-sided binary firing output.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.

  • neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.

  • threshold (float or torch.Tensor, default=1.0) – positive firing threshold.

  • bias (float or torch.Tensor, default=0.0) – additive bias before firing.

  • pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.

  • neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.

  • neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.

  • threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.

  • neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.

  • learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.

  • learn_threshold (bool, default=True) – whether threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.

  • learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_syn – positive synaptic state.

  • neg_syn – negative synaptic state.

  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_rec – positive recurrent trace state.

  • neg_rec – negative recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

DSRLIB is useful when both the input history and the recurrent output history should keep separate positive and negative memories. The firing output is still one-sided: negative internal evidence can suppress firing, but the returned output is non-negative unless a custom quant_fn changes that convention.

Examples:

>>> layer = tt.snn.DSRLIB(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

Ternary Firing Layers

class tracetorch.snn.LIT(*args: Any, **kwargs: Any)

Bases: Layer

A leaky integrate-and-ternary-fire layer.

LIT stores one membrane trace and can emit positive, zero, or negative output. Positive firing is controlled by pos_threshold; negative firing is controlled by neg_threshold. The negative branch returns negative values by convention.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).

  • pos_threshold (float or torch.Tensor, default=1.0) – positive firing threshold, constrained to positive values.

  • neg_threshold (float or torch.Tensor, default=1.0) – magnitude of the negative firing threshold, constrained to positive values.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • mem – membrane state.

  • beta – activated membrane decay.

  • pos_threshold – activated positive threshold.

  • neg_threshold – activated negative threshold magnitude.

  • bias – activated bias.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

With the default identity quantizer, positive and negative outputs are smooth probabilities with opposite signs. Pseudocode looks as follows:

mem = beta * mem + x
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold
mem = mem - neg * neg_threshold
return pos + neg
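
Reading the pseudocode, a positive event is favored when mem + bias > pos_threshold and a negative event when mem + bias < -neg_threshold. bias therefore shifts both boundaries in the same direction. The sketch below checks this for a single neuron in plain Python. lit_step is a hypothetical helper, and a hard step replaces quant_fn(spike_fn(v)); with the default identity quantizer, the outputs would instead be smooth values.

```python
def lit_step(x, mem, beta=0.9, pos_threshold=1.0, neg_threshold=1.0, bias=0.0):
    """One LIT step for a single neuron; returns (output, new_mem)."""
    mem = beta * mem + x
    # Hard steps stand in for quant_fn(spike_fn(.)) (assumption for illustration).
    pos = 1.0 if mem - pos_threshold + bias > 0 else 0.0
    neg = -1.0 if -neg_threshold - mem - bias > 0 else 0.0
    # Subtractive reset: a negative event raises mem back toward zero.
    mem = mem - pos * pos_threshold
    mem = mem - neg * neg_threshold
    return pos + neg, mem


print(lit_step(1.5, 0.0))            # (1.0, 0.5): positive event
print(lit_step(-1.5, 0.0))           # (-1.0, -0.5): negative event
print(lit_step(0.5, 0.0))            # (0.0, 0.5): inside both boundaries
print(lit_step(0.5, 0.0, bias=0.6))  # (1.0, -0.5): bias moves 0.5 past the positive boundary
```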

Examples:

>>> layer = tt.snn.LIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DLIT(*args: Any, **kwargs: Any)

Bases: Layer

A dual leaky integrate-and-ternary-fire layer.

DLIT splits membrane integration into positive and negative branches and emits ternary-style output. The summed membrane is thresholded, and each reset is split evenly across the positive and negative membrane branches.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

pos_mem = pos_beta * pos_mem + where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + where(x < 0, x, 0)
mem = pos_mem + neg_mem
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
pos_mem = pos_mem - 0.5 * pos * pos_threshold - 0.5 * neg * neg_threshold
neg_mem = neg_mem - 0.5 * pos * pos_threshold - 0.5 * neg * neg_threshold
return pos + neg
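
One consequence of the even split follows directly from the pseudocode. A positive event lowers both branches, so neg_mem can become negative even when only positive input has arrived. The summed membrane, however, is reset exactly as in LIT. The following plain-Python sketch for one neuron illustrates this. dlit_step is a hypothetical helper, and a hard step stands in for quant_fn(spike_fn(v)).

```python
def dlit_step(x, pos_mem, neg_mem, pos_beta=0.9, neg_beta=0.9,
              pos_threshold=1.0, neg_threshold=1.0, bias=0.0):
    """One DLIT step for a single neuron; returns (output, pos_mem, neg_mem)."""
    pos_mem = pos_beta * pos_mem + (x if x >= 0 else 0.0)
    neg_mem = neg_beta * neg_mem + (x if x < 0 else 0.0)
    mem = pos_mem + neg_mem
    pos = 1.0 if mem - pos_threshold + bias > 0 else 0.0   # hard step (assumption)
    neg = -1.0 if -neg_threshold - mem - bias > 0 else 0.0
    # The reset is shared half-and-half between the two branches.
    reset = 0.5 * pos * pos_threshold + 0.5 * neg * neg_threshold
    return pos + neg, pos_mem - reset, neg_mem - reset


out, p, n = dlit_step(1.5, 0.0, 0.0)
print(out, p, n, p + n)   # 1.0 1.0 -0.5 0.5
```

The summed membrane after the event, 0.5, matches what LIT would leave for the same input.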

Examples:

>>> layer = tt.snn.DLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.SLIT(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic leaky integrate-and-ternary-fire layer.

SLIT smooths the input through a synaptic trace before membrane integration and ternary firing. It is the ternary counterpart of SLIB.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • syn – synaptic state.

  • mem – membrane state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
return pos + neg
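
The synaptic trace low-passes the input, so a brief pulse reaches the membrane spread over several steps and fires later than it would in LIT. In the pseudocode, alpha = 0 makes syn equal to x and recovers LIT. Because decays are kept strictly inside (0, 1), the library can only approach this limit. Below is a single-neuron plain-Python sketch. slit_run is a hypothetical helper, and a hard step replaces quant_fn(spike_fn(v)).

```python
def slit_run(xs, alpha=0.5, beta=0.9, pos_threshold=1.0, neg_threshold=1.0, bias=0.0):
    """Simulate one SLIT neuron over the inputs `xs`; returns the output train."""
    syn = mem = 0.0
    out = []
    for x in xs:
        syn = alpha * syn + (1 - alpha) * x
        mem = beta * mem + syn
        pos = 1.0 if mem - pos_threshold + bias > 0 else 0.0   # hard step (assumption)
        neg = -1.0 if -neg_threshold - mem - bias > 0 else 0.0
        mem = mem - pos * pos_threshold - neg * neg_threshold
        out.append(pos + neg)
    return out


pulse = [1.5, 0.0, 0.0, 0.0]
print(slit_run(pulse))                # the pulse fires one step late
print(slit_run(pulse, alpha=0.0))     # LIT limit: fires on arrival
print(slit_run([-x for x in pulse]))  # mirrored input gives a negative event
```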

Examples:

>>> layer = tt.snn.SLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.RLIT(*args: Any, **kwargs: Any)

Bases: Layer

A recurrent leaky integrate-and-ternary-fire layer.

RLIT adds a recurrent trace of the previous ternary output. The recurrent trace is scaled by rec_weight and added to the input before membrane integration and ternary firing.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_gamma (bool, default=True) – whether gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_rec_weight (bool, default=True) – whether rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • mem – membrane state.

  • rec – recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + x + rec_weight * rec
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
prev_output = pos + neg
return prev_output
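
Because rec_weight defaults to 0.0, RLIT initially behaves like LIT. A large positive rec_weight feeds each event back as input and can keep a neuron firing after its drive is gone. The following is a plain-Python sketch for one neuron. rlit_run is a hypothetical helper, and a hard step replaces quant_fn(spike_fn(v)).

```python
def rlit_run(xs, beta=0.9, gamma=0.9, pos_threshold=1.0, neg_threshold=1.0,
             bias=0.0, rec_weight=0.0):
    """Simulate one RLIT neuron over the inputs `xs`; returns the output train."""
    mem = rec = prev_output = 0.0
    out = []
    for x in xs:
        rec = gamma * rec + (1 - gamma) * prev_output   # trace of past ternary outputs
        mem = beta * mem + x + rec_weight * rec
        pos = 1.0 if mem - pos_threshold + bias > 0 else 0.0   # hard step (assumption)
        neg = -1.0 if -neg_threshold - mem - bias > 0 else 0.0
        mem = mem - pos * pos_threshold - neg * neg_threshold
        prev_output = pos + neg
        out.append(prev_output)
    return out


pulse = [1.5, 0.0, 0.0, 0.0, 0.0]
print(rlit_run(pulse))                   # single event, then silence
print(rlit_run(pulse, rec_weight=10.0))  # self-excitation sustains firing
```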

Examples:

>>> layer = tt.snn.RLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DSLIT(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic leaky integrate-and-ternary-fire layer.

DSLIT combines dual positive/negative traces with a synaptic stage and ternary output. It smooths positive and negative input separately before integrating the combined synaptic current into dual membrane traces.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_syn – positive synaptic state.

  • neg_syn – negative synaptic state.

  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

The membrane reset is split evenly across the two membrane branches.

Examples:

>>> layer = tt.snn.DSLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DRLIT(*args: Any, **kwargs: Any)

Bases: Layer

A dual recurrent leaky integrate-and-ternary-fire layer.

DRLIT keeps dual membrane traces and dual recurrent traces for ternary output. The previous output is split by sign into recurrent branches, then reintegrated with the current input before the ternary firing decision.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.

  • neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.

  • neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.

  • neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.

  • neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.

  • learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.

  • learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_rec – positive recurrent trace state.

  • neg_rec – negative recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

DRLIT is the ternary recurrent layer to reach for when positive and negative recurrent history should use different decays and gains.

Examples:

>>> layer = tt.snn.DRLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.SRLIT(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic recurrent leaky integrate-and-ternary-fire layer.

SRLIT combines a synaptic input trace with a recurrent trace of the previous ternary output. The membrane receives both the smoothed input and the recurrent current before the ternary firing decision.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_gamma (bool, default=True) – whether gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_rec_weight (bool, default=True) – whether rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • syn – synaptic state.

  • mem – membrane state.

  • rec – recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + syn + rec_weight * rec
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
prev_output = pos + neg
return prev_output

Examples:

>>> layer = tt.snn.SRLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DSRLIT(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic recurrent leaky integrate-and-ternary-fire layer.

DSRLIT combines every ternary trace mechanism: dual synaptic traces, dual recurrent traces, dual membrane traces, and positive/negative firing thresholds. It is the most expressive unscaled ternary SNN layer.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.

  • neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.

  • neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.

  • neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.

  • neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.

  • learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.

  • learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_syn – positive synaptic state.

  • neg_syn – negative synaptic state.

  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_rec – positive recurrent trace state.

  • neg_rec – negative recurrent trace state.

  • prev_output – previous returned output.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Use DSRLIT when both input history and output history should retain sign-specific dynamics before ternary firing.

Examples:

>>> layer = tt.snn.DSRLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

Ternary Scaled Firing Layers

class tracetorch.snn.LITS(*args: Any, **kwargs: Any)

Bases: Layer

A leaky integrate-and-ternary-scaled-fire layer.

LITS is the scaled-output variant of LIT. It computes positive and negative ternary firing decisions, resets the membrane with the unscaled spike signs, then multiplies positive and negative outputs by separate learnable or fixed scales.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – multiplier applied to positive output events.

  • neg_scale (float or torch.Tensor, default=1.0) – multiplier applied to negative output events.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • mem – membrane state.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

mem = beta * mem + x
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
return pos * pos_scale + neg * neg_scale
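
The pseudocode above can be run directly for a single neuron. The sketch below is a plain-Python transcription, not traceTorch's implementation: heaviside and identity are hypothetical stand-ins for spike_fn and quant_fn (the real default spike_fn is tt.functional.sigmoid4x, which produces soft probabilities rather than hard 0/1 events). It makes the firing conditions explicit: a positive event occurs when mem + bias > pos_threshold, and a negative event when mem + bias < -neg_threshold.

```python
def heaviside(v: float) -> float:
    """Hypothetical hard spike_fn: 1.0 when v > 0, else 0.0."""
    return 1.0 if v > 0 else 0.0


def identity(s: float) -> float:
    """Stand-in for the default quant_fn (nn.Identity())."""
    return s


def lits_step(x, mem, *, beta=0.9, pos_threshold=1.0, neg_threshold=1.0,
              bias=0.0, pos_scale=1.0, neg_scale=1.0,
              spike_fn=heaviside, quant_fn=identity):
    """One LITS timestep for a single neuron; returns (output, new_mem)."""
    mem = beta * mem + x                                    # leaky integration
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))    # fires if mem + bias > pos_threshold
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))  # -1 if mem + bias < -neg_threshold
    # Reset with the unscaled events: remove pos_threshold, or add back neg_threshold.
    mem = mem - pos * pos_threshold - neg * neg_threshold
    # Only the returned value is scaled.
    return pos * pos_scale + neg * neg_scale, mem


mem = 0.0
outputs = []
for x in [0.6, 0.6, -3.0]:
    out, mem = lits_step(x, mem, pos_scale=2.0, neg_scale=0.5)
    outputs.append(out)
print(outputs)  # [0.0, 2.0, -0.5]
```

The second step crosses pos_threshold and emits 2.0 (pos_scale), while the membrane only loses 1.0 (pos_threshold); the third step crosses the negative boundary and emits -0.5 (neg_scale).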

Examples:

>>> layer = tt.snn.LITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DLITS(*args: Any, **kwargs: Any)

Bases: Layer

A dual leaky integrate-and-ternary-scaled-fire layer.

DLITS combines dual membrane traces with scaled ternary output. The summed membrane is thresholded, the reset is split across the dual traces, and the returned positive and negative outputs are scaled independently.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – positive output scale.

  • neg_scale (float or torch.Tensor, default=1.0) – negative output scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

The output scales affect what downstream layers receive; membrane reset still uses the unscaled threshold crossing events.
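
The claim that scales change only the returned value, never the membrane trajectory, can be checked with a small single-neuron sketch. How DLITS routes input into pos_mem and neg_mem and how it splits the reset is not specified here, so this sketch assumes (hypothetically) that the positive part of the input feeds pos_mem, the negative part feeds neg_mem, a positive event resets pos_mem, and a negative event resets neg_mem. Only the "summed membrane is thresholded" step comes from the description above; dlits_step and run are illustrative names.

```python
def dlits_step(x, state, *, pos_beta=0.9, neg_beta=0.8, pos_threshold=1.0,
               neg_threshold=1.0, bias=0.0, pos_scale=1.0, neg_scale=1.0):
    pos_mem, neg_mem = state
    # ASSUMPTION: sign-split routing of the input into the two traces.
    pos_mem = pos_beta * pos_mem + max(x, 0.0)
    neg_mem = neg_beta * neg_mem + min(x, 0.0)
    mem = pos_mem + neg_mem                       # summed membrane is thresholded
    pos = 1.0 if mem - pos_threshold + bias > 0 else 0.0
    neg = -1.0 if -neg_threshold - mem - bias > 0 else 0.0
    # ASSUMPTION: each event resets its own trace, using unscaled thresholds.
    pos_mem -= pos * pos_threshold
    neg_mem -= neg * neg_threshold
    return pos * pos_scale + neg * neg_scale, (pos_mem, neg_mem)


def run(xs, **kwargs):
    """Run a sequence; return the outputs and the membrane state after each step."""
    state, outputs, states = (0.0, 0.0), [], []
    for x in xs:
        out, state = dlits_step(x, state, **kwargs)
        outputs.append(out)
        states.append(state)
    return outputs, states


xs = [1.5, -3.0, 0.5]
out_a, states_a = run(xs)                                 # unit scales
out_b, states_b = run(xs, pos_scale=3.0, neg_scale=0.25)  # different scales
print(out_a, out_b)           # [1.0, -1.0, 0.0] [3.0, -0.25, 0.0]
print(states_a == states_b)   # True: the membranes never see the scales
```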

Examples:

>>> layer = tt.snn.DLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.SLITS(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic leaky integrate-and-ternary-scaled-fire layer.

SLITS smooths the input with a synaptic trace before membrane integration and scaled ternary output.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – positive output scale.

  • neg_scale (float or torch.Tensor, default=1.0) – negative output scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • syn – synaptic state.

  • mem – membrane state.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
return pos * pos_scale + neg * neg_scale
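
A direct single-neuron transcription of this pseudocode shows what the synaptic stage adds: an input impulse is spread over several timesteps, halving each step when alpha = 0.5. Because the update uses (1 - alpha) * x, a constant input x eventually drives syn to x itself. This is an illustrative sketch rather than the library's code; heaviside is a hypothetical hard stand-in for spike_fn, and quant_fn is omitted because its default is the identity.

```python
def heaviside(v):
    """Hypothetical hard spike_fn used in place of tt.functional.sigmoid4x."""
    return 1.0 if v > 0 else 0.0


def slits_step(x, syn, mem, *, alpha=0.5, beta=0.9, pos_threshold=1.0,
               neg_threshold=1.0, bias=0.0, pos_scale=1.0, neg_scale=1.0):
    syn = alpha * syn + (1 - alpha) * x   # low-pass filter with unit steady-state gain
    mem = beta * mem + syn                # membrane integrates the smoothed input
    pos = heaviside(mem - pos_threshold + bias)
    neg = -heaviside(-neg_threshold - mem - bias)
    mem = mem - pos * pos_threshold - neg * neg_threshold
    return pos * pos_scale + neg * neg_scale, syn, mem


syn, mem = 0.0, 0.0
syn_trace, outputs = [], []
for x in [3.0, 0.0, 0.0, 0.0]:            # a single input impulse at t = 0
    out, syn, mem = slits_step(x, syn, mem)
    syn_trace.append(syn)
    outputs.append(out)
print(syn_trace)  # [1.5, 0.75, 0.375, 0.1875]
print(outputs)    # [1.0, 1.0, 0.0, 0.0]
```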

Examples:

>>> layer = tt.snn.SLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.RLITS(*args: Any, **kwargs: Any)

Bases: Layer

A recurrent leaky integrate-and-ternary-scaled-fire layer.

RLITS adds a recurrent trace of the previous scaled ternary output. The trace is multiplied by rec_weight and added to the input during the membrane update; the current output is then scaled separately for positive and negative events.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – positive output scale.

  • neg_scale (float or torch.Tensor, default=1.0) – negative output scale.

  • rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_gamma (bool, default=True) – whether gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_rec_weight (bool, default=True) – whether rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • mem – membrane state.

  • rec – recurrent trace state.

  • prev_output – previous returned output.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

prev_output stores the scaled output, so recurrence sees the same value that downstream layers saw at the previous timestep.
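
RLITS has no pseudocode of its own in this section, so the sketch below assumes it follows the SRLITS pseudocode with the synaptic stage removed (mem = beta * mem + x + rec_weight * rec). Under that assumption, it shows how the recurrent trace of the scaled output feeds back: with a negative rec_weight, a recent positive event inhibits the next one. rlits_step, run, and heaviside are illustrative names, not traceTorch API.

```python
def heaviside(v):
    """Hypothetical hard spike_fn."""
    return 1.0 if v > 0 else 0.0


def rlits_step(x, mem, rec, prev_output, *, beta=0.9, gamma=0.9,
               pos_threshold=1.0, neg_threshold=1.0, bias=0.0,
               pos_scale=1.0, neg_scale=1.0, rec_weight=0.0):
    # ASSUMPTION: RLITS = SRLITS pseudocode without the synaptic trace.
    rec = gamma * rec + (1 - gamma) * prev_output  # trace of the *scaled* past output
    mem = beta * mem + x + rec_weight * rec
    pos = heaviside(mem - pos_threshold + bias)
    neg = -heaviside(-neg_threshold - mem - bias)
    mem = mem - pos * pos_threshold - neg * neg_threshold
    prev_output = pos * pos_scale + neg * neg_scale
    return prev_output, mem, rec


def run(xs, **kwargs):
    mem = rec = prev = 0.0
    outputs = []
    for x in xs:
        prev, mem, rec = rlits_step(x, mem, rec, prev, **kwargs)
        outputs.append(prev)
    return outputs


settings = dict(gamma=0.5, pos_scale=2.0)
print(run([1.2, 1.2, 1.2], **settings))                   # [2.0, 2.0, 2.0]
print(run([1.2, 1.2, 1.2], rec_weight=-0.5, **settings))  # [2.0, 0.0, 2.0]
```

Because prev_output carries the scaled value (2.0 here), changing pos_scale also changes how strongly recurrence acts, in addition to what downstream layers receive.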

Examples:

>>> layer = tt.snn.RLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DSLITS(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic leaky integrate-and-ternary-scaled-fire layer.

DSLITS combines dual synaptic traces, dual membrane traces, and scaled ternary output. It is the scaled counterpart of DSLIT.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – positive output scale.

  • neg_scale (float or torch.Tensor, default=1.0) – negative output scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_syn – positive synaptic state.

  • neg_syn – negative synaptic state.

  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Use this layer when sign-specific input smoothing and sign-specific output magnitudes are both useful, but no recurrent output trace is needed.
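
To illustrate what "sign-specific input smoothing" can buy, the sketch below implements only the dual synaptic stage. It assumes, hypothetically, that the positive part of the input feeds pos_syn and the negative part feeds neg_syn, each with the same (1 - alpha) form as the SLITS pseudocode; the actual DSLITS routing may differ. dual_syn_step is an illustrative name.

```python
def dual_syn_step(x, pos_syn, neg_syn, *, pos_alpha=0.5, neg_alpha=0.5):
    # ASSUMPTION: sign-split routing of the input into two synaptic traces.
    pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * max(x, 0.0)
    neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * min(x, 0.0)
    return pos_syn, neg_syn


pos_syn = neg_syn = 0.0
drive = []
for x in [1.0, -1.0, 0.0, 0.0]:   # one positive pulse, then one negative pulse
    pos_syn, neg_syn = dual_syn_step(x, pos_syn, neg_syn, pos_alpha=0.5, neg_alpha=0.9)
    drive.append(pos_syn + neg_syn)  # hypothetical net input to the membrane stage
print([round(d, 4) for d in drive])  # approximately [0.5, 0.15, 0.035, -0.0185]
```

With neg_alpha closer to 1, the negative pulse enters more weakly but decays more slowly, so by the fourth step it outweighs the earlier positive pulse. A single shared alpha could not produce this asymmetry.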

Examples:

>>> layer = tt.snn.DSLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DRLITS(*args: Any, **kwargs: Any)

Bases: Layer

A dual recurrent leaky integrate-and-ternary-scaled-fire layer.

DRLITS combines dual membrane traces, dual recurrent traces, and scaled ternary output. It is useful when positive and negative recurrent history should have different decays, gains, and downstream output magnitudes.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.

  • neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – positive output scale.

  • neg_scale (float or torch.Tensor, default=1.0) – negative output scale.

  • pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.

  • neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.

  • neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.

  • neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.

  • learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.

  • learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_rec – positive recurrent trace state.

  • neg_rec – negative recurrent trace state.

  • prev_output – previous returned output.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

The previous output stored in prev_output is already scaled, matching the value that downstream layers received.
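
The sketch below shows one way separate recurrent decays and gains could act on a scaled previous output. It assumes, hypothetically, that prev_output is split by sign into pos_rec and neg_rec, using the trace form from the SRLITS pseudocode, and that the membrane receives pos_rec_weight * pos_rec + neg_rec_weight * neg_rec. dual_rec_drive is an illustrative name, not part of traceTorch.

```python
def dual_rec_drive(prev_output, pos_rec, neg_rec, *, pos_gamma=0.9, neg_gamma=0.9,
                   pos_rec_weight=0.0, neg_rec_weight=0.0):
    # ASSUMPTION: the (already scaled) previous output is split by sign.
    pos_rec = pos_gamma * pos_rec + (1 - pos_gamma) * max(prev_output, 0.0)
    neg_rec = neg_gamma * neg_rec + (1 - neg_gamma) * min(prev_output, 0.0)
    # Recurrent contribution to the membrane update.
    drive = pos_rec_weight * pos_rec + neg_rec_weight * neg_rec
    return drive, pos_rec, neg_rec


pos_rec = neg_rec = 0.0
drives = []
for prev_output in [2.0, -1.0]:   # scaled outputs from two earlier timesteps
    drive, pos_rec, neg_rec = dual_rec_drive(
        prev_output, pos_rec, neg_rec,
        pos_gamma=0.5, neg_gamma=0.5,
        pos_rec_weight=-1.0,      # recent positive output inhibits
        neg_rec_weight=0.5,       # recent negative output pushes mem further negative
    )
    drives.append(drive)
print(drives)  # [-1.0, -0.75]
```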

Examples:

>>> layer = tt.snn.DRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.SRLITS(*args: Any, **kwargs: Any)

Bases: Layer

A synaptic recurrent leaky integrate-and-ternary-scaled-fire layer.

SRLITS combines a synaptic input trace, a recurrent output trace, and scaled ternary firing. It is the scaled counterpart of SRLIT.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • alpha (float or torch.Tensor, default=0.5) – synaptic decay.

  • beta (float or torch.Tensor, default=0.9) – membrane decay.

  • gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – positive output scale.

  • neg_scale (float or torch.Tensor, default=1.0) – negative output scale.

  • rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.

  • beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.

  • gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.

  • learn_alpha (bool, default=True) – whether alpha is trainable.

  • learn_beta (bool, default=True) – whether beta is trainable.

  • learn_gamma (bool, default=True) – whether gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_rec_weight (bool, default=True) – whether rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • syn – synaptic state.

  • mem – membrane state.

  • rec – recurrent trace state.

  • prev_output – previous returned output.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

Pseudocode looks as follows:

syn = alpha * syn + (1 - alpha) * x
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + syn + rec_weight * rec
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
prev_output = pos * pos_scale + neg * neg_scale
return prev_output
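
The pseudocode above translates line for line into a runnable single-neuron step. In this sketch, heaviside is a hypothetical hard stand-in for the default spike_fn, and the state tuple (syn, mem, rec, prev_output) mirrors the Variables listed above. It is illustrative, not traceTorch's implementation. With an excitatory rec_weight, the first output event feeds back and helps the neuron fire again on the next step, even though the synaptic input alone is already decaying.

```python
def heaviside(v):
    """Hypothetical hard spike_fn."""
    return 1.0 if v > 0 else 0.0


def srlits_step(x, state, *, alpha=0.5, beta=0.9, gamma=0.9, pos_threshold=1.0,
                neg_threshold=1.0, bias=0.0, pos_scale=1.0, neg_scale=1.0,
                rec_weight=0.0, spike_fn=heaviside, quant_fn=lambda s: s):
    syn, mem, rec, prev_output = state
    syn = alpha * syn + (1 - alpha) * x               # synaptic input trace
    rec = gamma * rec + (1 - gamma) * prev_output     # recurrent trace of scaled output
    mem = beta * mem + syn + rec_weight * rec         # membrane integration
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold - neg * neg_threshold  # reset with unscaled events
    prev_output = pos * pos_scale + neg * neg_scale
    return prev_output, (syn, mem, rec, prev_output)


state = (0.0, 0.0, 0.0, 0.0)   # syn, mem, rec, prev_output
outputs = []
for x in [2.4, 0.0, 0.0, 0.0]:
    out, state = srlits_step(x, state, gamma=0.5, rec_weight=0.5)
    outputs.append(out)
print(outputs)  # [1.0, 1.0, 0.0, 0.0]
```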

Examples:

>>> layer = tt.snn.SRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.

class tracetorch.snn.DSRLITS(*args: Any, **kwargs: Any)

Bases: Layer

A dual synaptic recurrent leaky integrate-and-ternary-scaled-fire layer.

DSRLITS combines every SNN mechanism provided by traceTorch: dual sign-specific synaptic traces, dual sign-specific recurrent traces, dual sign-specific membrane traces, ternary firing, and independent positive and negative output scales.

Parameters:
  • num_neurons (int) – number of neurons in the target dimension.

  • pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.

  • neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.

  • pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.

  • neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.

  • pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.

  • neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.

  • pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.

  • neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.

  • bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.

  • pos_scale (float or torch.Tensor, default=1.0) – positive output scale.

  • neg_scale (float or torch.Tensor, default=1.0) – negative output scale.

  • pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.

  • neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.

  • dim (int, default=-1) – the dimension along which the layer operates.

  • pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.

  • neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.

  • pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.

  • neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.

  • pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.

  • neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.

  • pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.

  • neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.

  • pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.

  • neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.

  • bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.

  • pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.

  • neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.

  • learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.

  • learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.

  • learn_pos_beta (bool, default=True) – whether pos_beta is trainable.

  • learn_neg_beta (bool, default=True) – whether neg_beta is trainable.

  • learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.

  • learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.

  • learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.

  • learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.

  • learn_pos_scale (bool, default=True) – whether pos_scale is trainable.

  • learn_neg_scale (bool, default=True) – whether neg_scale is trainable.

  • learn_bias (bool, default=True) – whether bias is trainable.

  • learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.

  • learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.

  • spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.

  • quant_fn (Callable, default=nn.Identity()) – output quantization function.

Variables:
  • pos_syn – positive synaptic state.

  • neg_syn – negative synaptic state.

  • pos_mem – positive membrane state.

  • neg_mem – negative membrane state.

  • pos_rec – positive recurrent trace state.

  • neg_rec – negative recurrent trace state.

  • prev_output – previous returned output.

  • pos_scale – positive output scale.

  • neg_scale – negative output scale.

Notes

  • Input: tensor of shape [*,num_neurons,*] where num_neurons is at index dim.

  • Output: tensor with the same shape as the input.

This layer is intentionally dense: it is meant for cases where sign, input history, output history, thresholding, and output magnitude all need separate degrees of freedom.

Examples:

>>> layer = tt.snn.DSRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])

forward(x)

Computes the forward pass.