SNN
tt.snn contains traceTorch’s leaky-integrator-based spiking layers.
Base Layer
- class tracetorch.snn.Layer(*args: Any, **kwargs: Any)
Bases: Layer
Base class for traceTorch SNN layers.
This class extends tt.Layer with parameter registration helpers commonly used by spiking layers: decays are constrained to (0, 1) through a sigmoid transform; thresholds are constrained to positive values through a softplus transform; biases use a smooth unconstrained transform.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
dim (int, default=-1) – dimension along which the layer operates.
Notes
Users normally instantiate concrete layers such as LIB or LIT. Subclass this base when creating a custom SNN layer that should integrate with tt.Model state management and traceTorch parameter compilation.
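The parameter constraints above can be illustrated with a minimal sketch. It assumes each constrained parameter is stored as an unconstrained raw value and initialized through the inverse transform. The helper names `sigmoid`, `logit`, `softplus` and `inv_softplus` are hypothetical and are not traceTorch functions. The bias transform is not modeled, because its exact form is not documented here.

```python
import math


def sigmoid(z: float) -> float:
    # Maps any real number into (0, 1); used here for decays.
    return 1.0 / (1.0 + math.exp(-z))


def logit(p: float) -> float:
    # Inverse of sigmoid, defined only for 0 < p < 1.
    return math.log(p / (1.0 - p))


def softplus(z: float) -> float:
    # Numerically stable log(1 + exp(z)); always positive, used here for thresholds.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def inv_softplus(y: float) -> float:
    # Inverse of softplus, defined only for y > 0: log(exp(y) - 1).
    return y + math.log(-math.expm1(-y))


# Hypothetical raw (unconstrained) values that an optimizer could update freely.
raw_beta = logit(0.9)
raw_threshold = inv_softplus(1.0)

# Activated values that the layer dynamics would actually use.
beta = sigmoid(raw_beta)
threshold = softplus(raw_threshold)
print(round(beta, 6), round(threshold, 6))  # 0.9 1.0
```

Whatever value the optimizer assigns to the raw parameters, the activated decay stays in (0, 1) and the activated threshold stays positive.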
Continuous Leaky Integrators
- class tracetorch.snn.LI(*args: Any, **kwargs: Any)
Bases: Layer
A leaky integrator layer with continuous membrane output.
LI stores a membrane trace and returns that trace directly. It does not spike, threshold, or reset, which makes it useful as a readout-style accumulator or as a smooth recurrent feature transform.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane decay. The activated value is constrained to (0, 1).
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – 0 for a scalar decay shared by all neurons, 1 for one decay per neuron.
learn_beta (bool, default=True) – whether beta is trainable.
- Variables:
mem – membrane state. Lazily initialized to zeros with the input shape, except that the target dimension is set to num_neurons.
beta – activated membrane decay.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
LI updates the membrane by exponentially decaying the previous value and adding the current input. Pseudocode looks as follows:
mem = beta * mem + x
return mem
Examples:
>>> import torch
>>> import tracetorch as tt
>>> layer = tt.snn.LI(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
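To make the LI recurrence concrete, here is a plain-Python sketch for a single neuron. A scalar `beta` stands in for one entry of a per-neuron decay. This is not traceTorch code. Under a constant input the membrane converges to x / (1 - beta), which grows without bound as beta approaches 1.

```python
def li_step(mem: float, x: float, beta: float) -> float:
    # Decay the previous membrane, then add the raw input (no normalization).
    return beta * mem + x


mem = 0.0
for _ in range(200):
    mem = li_step(mem, x=1.0, beta=0.9)

# Fixed point of mem = beta * mem + x is x / (1 - beta) = 10.0 here.
print(round(mem, 6))  # 10.0
```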
- class tracetorch.snn.DLI(*args: Any, **kwargs: Any)
Bases: Layer
A dual leaky integrator layer with continuous membrane output.
DLI splits the membrane trace into separate positive and negative branches. Positive input updates pos_mem and negative input updates neg_mem; the returned membrane is their sum.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – decay for the positive membrane branch.
neg_beta (float or torch.Tensor, default=0.9) – decay for the negative membrane branch.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
- Variables:
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_beta – activated positive membrane decay, constrained to (0, 1).
neg_beta – activated negative membrane decay, constrained to (0, 1).
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Dual traces let positive and negative evidence retain different time constants. Pseudocode looks as follows:
pos_mem = pos_beta * pos_mem + where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + where(x <= 0, x, 0)
return pos_mem + neg_mem
Examples:
>>> layer = tt.snn.DLI(num_neurons=32)
>>> x = torch.randn(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
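The following plain-Python sketch (one neuron, not traceTorch code) shows what separate time constants buy. It applies a positive pulse followed by an equally large negative pulse, with a fast positive decay and a slow negative decay. The output then stays negative and grows more negative for several steps, because the positive branch fades faster than the negative one. The decay values are illustrative choices.

```python
def dli_step(pos_mem, neg_mem, x, pos_beta, neg_beta):
    # max/min reproduce where(x >= 0, x, 0) and where(x <= 0, x, 0) for scalars.
    pos_mem = pos_beta * pos_mem + max(x, 0.0)
    neg_mem = neg_beta * neg_mem + min(x, 0.0)
    return pos_mem, neg_mem


pos_mem, neg_mem = 0.0, 0.0
outputs = []
for x in [1.0, -1.0, 0.0, 0.0, 0.0]:
    # Fast positive decay, slow negative decay (illustrative values).
    pos_mem, neg_mem = dli_step(pos_mem, neg_mem, x, pos_beta=0.5, neg_beta=0.95)
    outputs.append(pos_mem + neg_mem)

print([round(v, 3) for v in outputs])
```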
- class tracetorch.snn.SLI(*args: Any, **kwargs: Any)
Bases: Layer
A synaptic leaky integrator layer with continuous membrane output.
SLI adds a synaptic trace before the membrane trace. The synaptic trace smooths the input with decay alpha before the membrane integrates it with decay beta.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay, constrained to (0, 1).
beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
- Variables:
syn – synaptic state.
mem – membrane state.
alpha – activated synaptic decay.
beta – activated membrane decay.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
The synaptic trace is an exponential moving average of the input. The membrane then accumulates that smoothed current. Pseudocode looks as follows:
syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
return mem
Examples:
>>> layer = tt.snn.SLI(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
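Here is a one-neuron plain-Python sketch of the SLI update (not traceTorch code). The synaptic stage is an EMA with gain (1 - alpha), so a unit impulse enters the membrane gradually: the first output is only (1 - alpha), not 1. Under a constant input, however, the synaptic trace converges to the input, so the membrane settles at the same x / (1 - beta) as LI.

```python
def sli_step(syn, mem, x, alpha, beta):
    # Synaptic EMA of the input, then unnormalized leaky integration of syn.
    syn = alpha * syn + (1.0 - alpha) * x
    mem = beta * mem + syn
    return syn, mem


# Impulse response: a single unit input at t = 0.
syn, mem = sli_step(0.0, 0.0, x=1.0, alpha=0.5, beta=0.9)
first_output = mem  # equals (1 - alpha) * 1.0 = 0.5

# Constant input: syn -> x, so mem -> x / (1 - beta) = 10.0.
syn, mem = 0.0, 0.0
for _ in range(400):
    syn, mem = sli_step(syn, mem, x=1.0, alpha=0.5, beta=0.9)
```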
- class tracetorch.snn.DSLI(*args: Any, **kwargs: Any)
Bases: Layer
A dual synaptic leaky integrator layer with continuous membrane output.
DSLI combines dual positive/negative traces with a synaptic stage. It keeps separate positive and negative synaptic traces, then integrates their sum into separate positive and negative membrane traces.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
- Variables:
pos_syn – positive synaptic state.
neg_syn – negative synaptic state.
pos_mem – positive membrane state.
neg_mem – negative membrane state.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * where(x >= 0, x, 0)
neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * where(x <= 0, x, 0)
syn = pos_syn + neg_syn
pos_mem = pos_beta * pos_mem + where(syn >= 0, syn, 0)
neg_mem = neg_beta * neg_mem + where(syn <= 0, syn, 0)
return pos_mem + neg_mem
Examples:
>>> layer = tt.snn.DSLI(num_neurons=32)
>>> x = torch.randn(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
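One way to read the DSLI pseudocode is as a generalization of SLI. The plain-Python sketch below (one neuron, not traceTorch code) checks a consequence that follows directly from the update rules. When pos_alpha == neg_alpha and pos_beta == neg_beta, the sum pos_mem + neg_mem evolves exactly like the SLI membrane, because the positive and negative parts of a value always add back up to that value. The input sequence is arbitrary.

```python
def dsli_step(state, x, pos_alpha, neg_alpha, pos_beta, neg_beta):
    pos_syn, neg_syn, pos_mem, neg_mem = state
    # Sign-split synaptic EMAs.
    pos_syn = pos_alpha * pos_syn + (1.0 - pos_alpha) * max(x, 0.0)
    neg_syn = neg_alpha * neg_syn + (1.0 - neg_alpha) * min(x, 0.0)
    syn = pos_syn + neg_syn
    # The combined synaptic current is split by sign again for the membranes.
    pos_mem = pos_beta * pos_mem + max(syn, 0.0)
    neg_mem = neg_beta * neg_mem + min(syn, 0.0)
    return (pos_syn, neg_syn, pos_mem, neg_mem), pos_mem + neg_mem


def sli_step(syn, mem, x, alpha, beta):
    syn = alpha * syn + (1.0 - alpha) * x
    mem = beta * mem + syn
    return syn, mem


inputs = [0.8, -1.2, 0.3, 0.0, -0.5, 1.5, -0.1]
state = (0.0, 0.0, 0.0, 0.0)
syn, mem = 0.0, 0.0
max_gap = 0.0
for x in inputs:
    state, dual_out = dsli_step(state, x, 0.5, 0.5, 0.9, 0.9)
    syn, mem = sli_step(syn, mem, x, 0.5, 0.9)
    max_gap = max(max_gap, abs(dual_out - mem))

print(max_gap < 1e-12)  # True
```

The dual form therefore only changes behavior when the positive and negative decays differ.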
- class tracetorch.snn.LIEMA(*args: Any, **kwargs: Any)
Bases: Layer
A leaky integrator layer with exponential-moving-average output.
LIEMA is the bounded counterpart to LI. It stores one membrane trace and updates it as an exponential moving average of the input instead of as an unnormalized accumulator.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane EMA decay, constrained to (0, 1).
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron decay.
learn_beta (bool, default=True) – whether beta is trainable.
- Variables:
mem – membrane EMA state.
beta – activated membrane decay.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
mem = beta * mem + (1 - beta) * x
return mem
Examples:
>>> layer = tt.snn.LIEMA(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
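The following one-neuron plain-Python sketch (not traceTorch code) contrasts the EMA update with LI. Under the same constant input and decay, the LIEMA membrane converges to the input itself, whereas the LI membrane converges to x / (1 - beta).

```python
def li_step(mem, x, beta):
    return beta * mem + x  # unnormalized accumulation


def liema_step(mem, x, beta):
    return beta * mem + (1.0 - beta) * x  # convex combination of old state and input


li_mem, ema_mem = 0.0, 0.0
for _ in range(200):
    li_mem = li_step(li_mem, 1.0, 0.9)
    ema_mem = liema_step(ema_mem, 1.0, 0.9)

print(round(li_mem, 6), round(ema_mem, 6))  # 10.0 1.0
```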
- class tracetorch.snn.DLIEMA(*args: Any, **kwargs: Any)
Bases: Layer
A dual leaky integrator layer with exponential-moving-average output.
DLIEMA keeps positive and negative membrane EMA traces separately and returns their sum. This is useful when positive and negative evidence should decay independently without allowing unbounded accumulation.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane EMA decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane EMA decay.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
- Variables:
pos_mem – positive membrane EMA state.
neg_mem – negative membrane EMA state.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
pos_mem = pos_beta * pos_mem + (1 - pos_beta) * where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + (1 - neg_beta) * where(x <= 0, x, 0)
return pos_mem + neg_mem
Examples:
>>> layer = tt.snn.DLIEMA(num_neurons=32)
>>> x = torch.randn(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
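Here is a plain-Python sketch of the DLIEMA update (one neuron, not traceTorch code). It checks a property that follows from the pseudocode: with pos_beta == neg_beta, the summed output is exactly an LIEMA trace. The dual form only behaves differently when the two decays differ.

```python
def dliema_step(pos_mem, neg_mem, x, pos_beta, neg_beta):
    # Separate EMAs for the positive and negative parts of the input.
    pos_mem = pos_beta * pos_mem + (1.0 - pos_beta) * max(x, 0.0)
    neg_mem = neg_beta * neg_mem + (1.0 - neg_beta) * min(x, 0.0)
    return pos_mem, neg_mem


def liema_step(mem, x, beta):
    return beta * mem + (1.0 - beta) * x


inputs = [0.7, -0.4, -1.1, 0.0, 2.0, -0.3]
pos_mem, neg_mem, mem = 0.0, 0.0, 0.0
gaps = []
for x in inputs:
    pos_mem, neg_mem = dliema_step(pos_mem, neg_mem, x, 0.8, 0.8)
    mem = liema_step(mem, x, 0.8)
    gaps.append(abs((pos_mem + neg_mem) - mem))

print(max(gaps) < 1e-12)  # True
```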
- class tracetorch.snn.SLIEMA(*args: Any, **kwargs: Any)
Bases: Layer
A synaptic leaky integrator layer with exponential-moving-average output.
SLIEMA smooths the input through a synaptic EMA and then updates the membrane as an EMA of that synaptic current.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay, constrained to (0, 1).
beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
- Variables:
syn – synaptic EMA state.
mem – membrane EMA state.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + (1 - beta) * syn
return mem
Examples:
>>> layer = tt.snn.SLIEMA(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
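Here is a one-neuron plain-Python sketch of SLIEMA (not traceTorch code). Each of the two cascaded EMAs has unit gain, so in the long run a constant input is reproduced exactly. A single impulse, by contrast, is attenuated to (1 - alpha) * (1 - beta) on the first step.

```python
def sliema_step(syn, mem, x, alpha, beta):
    syn = alpha * syn + (1.0 - alpha) * x   # synaptic EMA of the input
    mem = beta * mem + (1.0 - beta) * syn   # membrane EMA of the synaptic current
    return syn, mem


syn, mem = sliema_step(0.0, 0.0, x=1.0, alpha=0.5, beta=0.9)
first_output = mem  # (1 - 0.5) * (1 - 0.9) = 0.05

syn, mem = 0.0, 0.0
for _ in range(400):
    syn, mem = sliema_step(syn, mem, x=1.0, alpha=0.5, beta=0.9)
steady_output = mem  # approaches the input value 1.0
```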
- class tracetorch.snn.DSLIEMA(*args: Any, **kwargs: Any)
Bases: Layer
A dual synaptic leaky integrator layer with EMA membrane output.
DSLIEMA is the dual, synaptic, bounded-output variant of the leaky integrator family. It keeps positive and negative synaptic traces, then updates positive and negative membrane EMA traces from their combined synaptic current.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
- Variables:
pos_syn – positive synaptic EMA state.
neg_syn – negative synaptic EMA state.
pos_mem – positive membrane EMA state.
neg_mem – negative membrane EMA state.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * where(x >= 0, x, 0)
neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * where(x <= 0, x, 0)
syn = pos_syn + neg_syn
pos_mem = pos_beta * pos_mem + (1 - pos_beta) * where(syn >= 0, syn, 0)
neg_mem = neg_beta * neg_mem + (1 - neg_beta) * where(syn <= 0, syn, 0)
return pos_mem + neg_mem
Examples:
>>> layer = tt.snn.DSLIEMA(num_neurons=32)
>>> x = torch.randn(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
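Every stage of DSLIEMA is a convex combination of its previous state and a sign-split input. Starting from a zero state, the positive branches therefore stay in [0, 1] and the negative branches in [-1, 0] whenever the inputs lie in [-1, 1]. The plain-Python sketch below (one neuron, not traceTorch code) checks this on random inputs. The decay values are arbitrary illustrative choices.

```python
import random


def dsliema_step(state, x, pos_alpha, neg_alpha, pos_beta, neg_beta):
    pos_syn, neg_syn, pos_mem, neg_mem = state
    pos_syn = pos_alpha * pos_syn + (1.0 - pos_alpha) * max(x, 0.0)
    neg_syn = neg_alpha * neg_syn + (1.0 - neg_alpha) * min(x, 0.0)
    syn = pos_syn + neg_syn
    pos_mem = pos_beta * pos_mem + (1.0 - pos_beta) * max(syn, 0.0)
    neg_mem = neg_beta * neg_mem + (1.0 - neg_beta) * min(syn, 0.0)
    return (pos_syn, neg_syn, pos_mem, neg_mem), pos_mem + neg_mem


rng = random.Random(0)
state = (0.0, 0.0, 0.0, 0.0)
outputs = []
for _ in range(1000):
    x = rng.uniform(-1.0, 1.0)
    state, out = dsliema_step(state, x, 0.3, 0.7, 0.6, 0.95)
    outputs.append(out)

# Positive branches stay in [0, 1], negative branches in [-1, 0], so the sum stays in [-1, 1].
print(min(outputs) >= -1.0 and max(outputs) <= 1.0)  # True
```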
Binary Firing Layers
- class tracetorch.snn.LIB(*args: Any, **kwargs: Any)
Bases: Layer
A leaky integrate-and-binary-fire layer.
LIB is traceTorch’s one-sided firing layer. It stores one membrane trace, converts the distance from threshold into a firing probability, optionally quantizes that probability, subtracts the threshold-scaled output from the membrane, and returns the output.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).
threshold (float or torch.Tensor, default=1.0) – firing threshold, constrained to positive values.
bias (float or torch.Tensor, default=0.0) – additive bias applied before the spike function.
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_beta (bool, default=True) – whether beta is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – function that maps membrane distance from threshold to a firing probability.
quant_fn (Callable, default=nn.Identity()) – function that maps firing probability to the returned spike value.
- Variables:
mem – membrane state.
beta – activated membrane decay.
threshold – activated positive threshold.
bias – activated bias.
spike_fn – spike probability function.
quant_fn – output quantization function.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
With the default quant_fn=nn.Identity(), the layer returns smooth firing probabilities. Pass a straight-through quantizer such as tt.functional.round_ste() for harder binary events. Pseudocode looks as follows:
mem = beta * mem + x
spike_prob = spike_fn(mem - threshold + bias)
spikes = quant_fn(spike_prob)
mem = mem - spikes * threshold
return spikes
Examples:
>>> layer = tt.snn.LIB(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
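The sketch below walks through one LIB step for a single neuron in plain Python. It is not traceTorch code. The documented default spike_fn is tt.functional.sigmoid4x, whose exact definition is not given here, so the hypothetical stand-in `sigmoid_4x` assumes a logistic sigmoid of 4 * z. `hard_round` covers only the forward pass of a rounding quantizer. A straight-through estimator would also pass gradients through unchanged in the backward pass, which plain Python cannot show.

```python
import math


def sigmoid_4x(z: float) -> float:
    # Hypothetical stand-in for tt.functional.sigmoid4x: logistic sigmoid of 4 * z.
    return 1.0 / (1.0 + math.exp(-4.0 * z))


def identity(p: float) -> float:
    # Mirrors the documented default quant_fn=nn.Identity().
    return p


def hard_round(p: float) -> float:
    # Forward pass of a rounding quantizer: 0.0 or 1.0.
    return float(round(p))


def lib_step(mem, x, beta, threshold, bias, spike_fn, quant_fn):
    mem = beta * mem + x                            # leaky integration
    spike_prob = spike_fn(mem - threshold + bias)   # distance from threshold -> probability
    spikes = quant_fn(spike_prob)
    mem = mem - spikes * threshold                  # subtractive reset
    return mem, spikes


soft_mem, soft_spike = lib_step(0.0, 1.5, 0.9, 1.0, 0.0, sigmoid_4x, identity)
hard_mem, hard_spike = lib_step(0.0, 1.5, 0.9, 1.0, 0.0, sigmoid_4x, hard_round)
print(hard_spike, hard_mem)  # 1.0 0.5
```

With identity quantization the reset is fractional (about 0.88 here), so the membrane keeps more charge than with the hard spike.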
- class tracetorch.snn.DLIB(*args: Any, **kwargs: Any)
Bases: Layer
A dual leaky integrate-and-binary-fire layer.
DLIB splits membrane integration into positive and negative branches but still emits a one-sided binary-style output. The two membrane branches are summed for thresholding, and the reset is then split evenly across both branches.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
threshold (float or torch.Tensor, default=1.0) – positive firing threshold.
bias (float or torch.Tensor, default=0.0) – additive bias before firing.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_mem – positive membrane state.
neg_mem – negative membrane state.
threshold – activated positive threshold.
bias – activated bias.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
pos_mem = pos_beta * pos_mem + where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + where(x <= 0, x, 0)
mem = pos_mem + neg_mem
spikes = quant_fn(spike_fn(mem - threshold + bias))
pos_mem = pos_mem - 0.5 * spikes * threshold
neg_mem = neg_mem - 0.5 * spikes * threshold
return spikes
Examples:
>>> layer = tt.snn.DLIB(num_neurons=32)
>>> x = torch.randn(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
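Here is a single-neuron plain-Python sketch of one DLIB step (not traceTorch code). The hypothetical `hard_fire` helper stands in for spike_fn followed by a rounding quant_fn. The sketch shows a consequence of the split reset. The summed membrane drops by the full threshold, exactly as in LIB. Half of that reset, however, lands on neg_mem, which can become negative even when the input was purely positive.

```python
def hard_fire(z: float) -> float:
    # Hypothetical stand-in for quant_fn(spike_fn(z)): fire when above threshold.
    return 1.0 if z > 0.0 else 0.0


def dlib_step(pos_mem, neg_mem, x, pos_beta, neg_beta, threshold, bias):
    pos_mem = pos_beta * pos_mem + max(x, 0.0)
    neg_mem = neg_beta * neg_mem + min(x, 0.0)
    spikes = hard_fire(pos_mem + neg_mem - threshold + bias)
    # Split the reset evenly across the two branches.
    pos_mem = pos_mem - 0.5 * spikes * threshold
    neg_mem = neg_mem - 0.5 * spikes * threshold
    return pos_mem, neg_mem, spikes


pos_mem, neg_mem, spikes = dlib_step(0.0, 0.0, 1.5, 0.9, 0.9, 1.0, 0.0)
print(spikes, pos_mem, neg_mem, pos_mem + neg_mem)  # 1.0 1.0 -0.5 0.5
```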
- class tracetorch.snn.SLIB(*args: Any, **kwargs: Any)
Bases: Layer
A synaptic leaky integrate-and-binary-fire layer.
SLIB smooths the input through a synaptic trace before membrane integration and one-sided firing. This is useful when the input should behave like a current with its own time constant rather than an instantaneous membrane increment.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay.
beta (float or torch.Tensor, default=0.9) – membrane decay.
threshold (float or torch.Tensor, default=1.0) – positive firing threshold.
bias (float or torch.Tensor, default=0.0) – additive bias before firing.
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
syn – synaptic state.
mem – membrane state.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
spikes = quant_fn(spike_fn(mem - threshold + bias))
mem = mem - spikes * threshold
return spikes
Examples:
>>> layer = tt.snn.SLIB(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
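The following plain-Python comparison (one neuron, not traceTorch code) runs LIB and SLIB under the same constant input. The hypothetical `hard_fire` helper stands in for spike_fn plus a rounding quant_fn, and bias is left at zero. Because the synaptic EMA ramps up toward the input, the SLIB membrane charges more slowly. With these illustrative settings its first spike arrives one step later.

```python
def hard_fire(z: float) -> float:
    # Hypothetical stand-in for quant_fn(spike_fn(z)).
    return 1.0 if z > 0.0 else 0.0


def first_spike_step(use_synapse, x=0.3, alpha=0.5, beta=0.9, threshold=1.0, steps=20):
    syn, mem = 0.0, 0.0
    for t in range(steps):
        if use_synapse:
            syn = alpha * syn + (1.0 - alpha) * x
            current = syn  # SLIB: smoothed current
        else:
            current = x    # LIB: raw input
        mem = beta * mem + current
        spikes = hard_fire(mem - threshold)
        mem = mem - spikes * threshold
        if spikes:
            return t
    return None


print(first_spike_step(False), first_spike_step(True))  # 3 4
```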
- class tracetorch.snn.RLIB(*args: Any, **kwargs: Any)
Bases: Layer
A recurrent leaky integrate-and-binary-fire layer.
RLIB adds a recurrent trace of the previous output. The recurrent trace is decayed with gamma, scaled by rec_weight, and added to the current input before membrane integration.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane decay.
gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.
threshold (float or torch.Tensor, default=1.0) – positive firing threshold.
bias (float or torch.Tensor, default=0.0) – additive bias before firing.
rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.
learn_beta (bool, default=True) – whether beta is trainable.
learn_gamma (bool, default=True) – whether gamma is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_rec_weight (bool, default=True) – whether rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
mem – membrane state.
rec – recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + x + rec_weight * rec
spikes = quant_fn(spike_fn(mem - threshold + bias))
mem = mem - spikes * threshold
prev_output = spikes
return spikes
Examples:
>>> layer = tt.snn.RLIB(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
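The recurrent path is easiest to see by toggling rec_weight, whose documented default of 0.0 disables it. The plain-Python sketch below (one neuron, not traceTorch code) counts spikes under a constant input. The hypothetical `hard_fire` helper stands in for spike_fn plus a rounding quant_fn, and bias is zero. A negative rec_weight feeds the smoothed spike history back as inhibition and lowers the firing count.

```python
def hard_fire(z: float) -> float:
    # Hypothetical stand-in for quant_fn(spike_fn(z)).
    return 1.0 if z > 0.0 else 0.0


def count_spikes(rec_weight, x=0.5, beta=0.9, gamma=0.9, threshold=1.0, steps=50):
    mem, rec, prev_output = 0.0, 0.0, 0.0
    total = 0.0
    for _ in range(steps):
        rec = gamma * rec + (1.0 - gamma) * prev_output  # smoothed spike history
        mem = beta * mem + x + rec_weight * rec           # recurrent current joins the input
        spikes = hard_fire(mem - threshold)
        mem = mem - spikes * threshold
        prev_output = spikes
        total += spikes
    return total


baseline = count_spikes(rec_weight=0.0)    # recurrence disabled (documented default)
inhibited = count_spikes(rec_weight=-1.0)  # self-inhibition
print(baseline, inhibited)
```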
- class tracetorch.snn.DSLIB(*args: Any, **kwargs: Any)
Bases: Layer
A dual synaptic leaky integrate-and-binary-fire layer.
DSLIB combines dual positive/negative traces with a synaptic stage and a one-sided firing output. Positive and negative inputs are smoothed separately, summed, integrated into dual membrane traces, and thresholded as one combined membrane.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
threshold (float or torch.Tensor, default=1.0) – positive firing threshold.
bias (float or torch.Tensor, default=0.0) – additive bias before firing.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_syn – positive synaptic state.
neg_syn – negative synaptic state.
pos_mem – positive membrane state.
neg_mem – negative membrane state.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
The reset is split evenly across the dual membrane branches. Pseudocode looks as follows:
pos_syn = pos_alpha * pos_syn + (1 - pos_alpha) * where(x >= 0, x, 0)
neg_syn = neg_alpha * neg_syn + (1 - neg_alpha) * where(x <= 0, x, 0)
syn = pos_syn + neg_syn
pos_mem = pos_beta * pos_mem + where(syn >= 0, syn, 0)
neg_mem = neg_beta * neg_mem + where(syn <= 0, syn, 0)
spikes = quant_fn(spike_fn(pos_mem + neg_mem - threshold + bias))
pos_mem = pos_mem - 0.5 * spikes * threshold
neg_mem = neg_mem - 0.5 * spikes * threshold
return spikes
Examples:
>>> layer = tt.snn.DSLIB(num_neurons=32)
>>> x = torch.randn(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.DRLIB(*args: Any, **kwargs: Any)
Bases: Layer
A dual recurrent leaky integrate-and-binary-fire layer.
DRLIB keeps positive and negative membrane traces and positive and negative recurrent traces. The previous output is split by sign, smoothed into recurrent traces, scaled, and integrated with the current input before a one-sided binary firing decision.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.
threshold (float or torch.Tensor, default=1.0) – positive firing threshold.
bias (float or torch.Tensor, default=0.0) – additive bias before firing.
pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.
neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.
learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.
learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_rec – positive recurrent trace state.
neg_rec – negative recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
pos_rec = pos_gamma * pos_rec + (1 - pos_gamma) * where(prev_output >= 0, prev_output, 0)
neg_rec = neg_gamma * neg_rec + (1 - neg_gamma) * where(prev_output <= 0, prev_output, 0)
pos_mem = pos_beta * pos_mem + pos_rec_weight * where(pos_rec + neg_rec >= 0, pos_rec + neg_rec, 0) + where(x >= 0, x, 0)
neg_mem = neg_beta * neg_mem + neg_rec_weight * where(pos_rec + neg_rec <= 0, pos_rec + neg_rec, 0) + where(x <= 0, x, 0)
spikes = quant_fn(spike_fn(pos_mem + neg_mem - threshold + bias))
prev_output = spikes
return spikes
Examples:
>>> layer = tt.snn.DRLIB(num_neurons=32)
>>> x = torch.randn(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.SRLIB(*args: Any, **kwargs: Any)
Bases: Layer
A synaptic recurrent leaky integrate-and-binary-fire layer.
SRLIB combines an input synaptic trace with a recurrent trace of the previous output. The membrane receives both the smoothed input and the scaled recurrent current before one-sided firing.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay.
beta (float or torch.Tensor, default=0.9) – membrane decay.
gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.
threshold (float or torch.Tensor, default=1.0) – positive firing threshold.
bias (float or torch.Tensor, default=0.0) – additive bias before firing.
rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
learn_gamma (bool, default=True) – whether gamma is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_rec_weight (bool, default=True) – whether rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
syn – synaptic state.
mem – membrane state.
rec – recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *], where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
syn = alpha * syn + (1 - alpha) * x
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + syn + rec_weight * rec
spikes = quant_fn(spike_fn(mem - threshold + bias))
mem = mem - spikes * threshold
prev_output = spikes
return spikes
Examples:
>>> layer = tt.snn.SRLIB(num_neurons=32)
>>> x = torch.rand(16, 32)
>>> output = layer(x)
>>> print(output.shape)
torch.Size([16, 32])
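In SRLIB, rec is updated from prev_output before the membrane is updated. Assuming prev_output starts at zero, the recurrent term can only influence the membrane from the second step onward. The plain-Python sketch below (one neuron, not traceTorch code) runs the same input with and without an excitatory rec_weight and compares the membrane traces. The hypothetical `hard_fire` helper stands in for spike_fn plus a rounding quant_fn, and the zero initial state is an assumption.

```python
def hard_fire(z: float) -> float:
    # Hypothetical stand-in for quant_fn(spike_fn(z)).
    return 1.0 if z > 0.0 else 0.0


def srlib_run(inputs, rec_weight, alpha=0.5, beta=0.9, gamma=0.9, threshold=1.0):
    syn, mem, rec, prev_output = 0.0, 0.0, 0.0, 0.0  # assumed zero initial state
    mems = []
    for x in inputs:
        syn = alpha * syn + (1.0 - alpha) * x
        rec = gamma * rec + (1.0 - gamma) * prev_output
        mem = beta * mem + syn + rec_weight * rec
        spikes = hard_fire(mem - threshold)
        mem = mem - spikes * threshold
        prev_output = spikes
        mems.append(mem)
    return mems


inputs = [3.0, 0.0, 0.0]
plain = srlib_run(inputs, rec_weight=0.0)
excited = srlib_run(inputs, rec_weight=2.0)
print(plain[0] == excited[0], plain[1] < excited[1])  # True True
```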
- class tracetorch.snn.DSRLIB(*args: Any, **kwargs: Any)
Bases: Layer
A dual synaptic recurrent leaky integrate-and-binary-fire layer.
DSRLIB is the most expressive binary SNN layer in traceTorch. It combines dual positive/negative synaptic traces, dual positive/negative recurrent traces, dual positive/negative membrane traces, and a one-sided binary firing output.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.
threshold (float or torch.Tensor, default=1.0) – positive firing threshold.
bias (float or torch.Tensor, default=0.0) – additive bias before firing.
pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.
neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.
threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron threshold.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.
learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.
learn_threshold (bool, default=True) – whether threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.
learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_syn – positive synaptic state.
neg_syn – negative synaptic state.
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_rec – positive recurrent trace state.
neg_rec – negative recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
DSRLIB is useful when positive and negative input, as well as positive and negative recurrent history, should each keep an independent memory. The firing output is still one-sided: negative internal evidence can suppress firing, but the returned output is non-negative unless a custom quant_fn changes that convention.
Examples:
>>> layer = tt.snn.DSRLIB(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
Ternary Firing Layers
- class tracetorch.snn.LIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A leaky integrate-and-ternary-fire layer.
LIT stores one membrane trace and can emit positive, zero, or negative output. Positive firing is controlled by pos_threshold; negative firing is controlled by neg_threshold. The negative branch returns negative values by convention.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).
pos_threshold (float or torch.Tensor, default=1.0) – positive firing threshold, constrained to positive values.
neg_threshold (float or torch.Tensor, default=1.0) – magnitude of the negative firing threshold, constrained to positive values.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_beta (bool, default=True) – whether beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
mem – membrane state.
beta – activated membrane decay.
pos_threshold – activated positive threshold.
neg_threshold – activated negative threshold magnitude.
bias – activated bias.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
With the default identity quantizer, positive and negative outputs are smooth probabilities with opposite signs. Pseudocode looks as follows:
    mem = beta * mem + x
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold
    mem = mem - neg * neg_threshold
    return pos + neg
Examples:
>>> layer = tt.snn.LIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
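To see how the two thresholds and the signed reset interact, here is a minimal single-neuron sketch of the LIT pseudocode in plain Python. It illustrates the update rule only and is not traceTorch's tensor implementation: lit_step and hard_step are hypothetical helpers, and sigmoid4x is an assumed stand-in for tt.functional.sigmoid4x (a logistic sigmoid with slope 4).

```python
import math
from typing import Callable, Tuple


def sigmoid4x(v: float) -> float:
    """Assumed stand-in for tt.functional.sigmoid4x: sigmoid(4 * v), computed stably."""
    z = 4.0 * v
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def hard_step(v: float) -> float:
    """Deterministic stand-in for a spike function (not the library default)."""
    return 1.0 if v > 0 else 0.0


def lit_step(
    mem: float,
    x: float,
    beta: float = 0.9,
    pos_threshold: float = 1.0,
    neg_threshold: float = 1.0,
    bias: float = 0.0,
    spike_fn: Callable[[float], float] = sigmoid4x,
    quant_fn: Callable[[float], float] = lambda s: s,
) -> Tuple[float, float]:
    """One LIT update for a single neuron; returns (new_mem, output)."""
    mem = beta * mem + x
    # Positive event: mem + bias rises above pos_threshold.
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    # Negative event, returned with a minus sign: mem + bias falls below -neg_threshold.
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    # Subtractive reset; since neg <= 0, the second line pushes mem back up.
    mem = mem - pos * pos_threshold
    mem = mem - neg * neg_threshold
    return mem, pos + neg


print(lit_step(0.0, 1.5, spike_fn=hard_step))   # (0.5, 1.0): positive event
print(lit_step(0.0, -1.5, spike_fn=hard_step))  # (-0.5, -1.0): negative event
print(lit_step(0.0, 0.0))                       # (0.0, 0.0): smooth branches cancel
```

With the smooth default, a membrane sitting exactly between the two boundaries produces equal positive and negative probabilities, which cancel in the returned sum.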
- class tracetorch.snn.DLIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A dual leaky integrate-and-ternary-fire layer.
DLIT routes the positive and negative parts of the input into separate membrane branches, each with its own decay, and emits ternary output. The summed membrane is thresholded, and each reset is split evenly across the positive and negative membrane branches.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_mem – positive membrane state.
neg_mem – negative membrane state.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
    pos_mem = pos_beta * pos_mem + where(x >= 0, x, 0)
    neg_mem = neg_beta * neg_mem + where(x < 0, x, 0)
    mem = pos_mem + neg_mem
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    pos_mem = pos_mem - 0.5 * pos * pos_threshold - 0.5 * neg * neg_threshold
    neg_mem = neg_mem - 0.5 * pos * pos_threshold - 0.5 * neg * neg_threshold
    return pos + neg
Examples:
>>> layer = tt.snn.DLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
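The sign routing and the evenly split reset are easiest to see on a single neuron. The plain-Python sketch below transcribes the DLIT pseudocode; dlit_step and hard_step are hypothetical helpers written for illustration, not traceTorch APIs, and hard_step replaces the smooth default spike function so the numbers are exact.

```python
from typing import Callable, Tuple


def dlit_step(
    pos_mem: float,
    neg_mem: float,
    x: float,
    pos_beta: float = 0.9,
    neg_beta: float = 0.9,
    pos_threshold: float = 1.0,
    neg_threshold: float = 1.0,
    bias: float = 0.0,
    *,
    spike_fn: Callable[[float], float],
    quant_fn: Callable[[float], float] = lambda s: s,
) -> Tuple[float, float, float]:
    """One DLIT update for a single neuron; returns (pos_mem, neg_mem, output)."""
    # Route the input by sign; each branch leaks with its own decay.
    pos_mem = pos_beta * pos_mem + (x if x >= 0 else 0.0)
    neg_mem = neg_beta * neg_mem + (x if x < 0 else 0.0)
    # Firing is decided on the summed membrane.
    mem = pos_mem + neg_mem
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    # The total reset is split evenly: both branches subtract half of it.
    half_reset = 0.5 * pos * pos_threshold + 0.5 * neg * neg_threshold
    return pos_mem - half_reset, neg_mem - half_reset, pos + neg


def hard_step(v: float) -> float:
    """Deterministic stand-in for a spike function (not the library default)."""
    return 1.0 if v > 0 else 0.0


print(dlit_step(0.0, 0.0, 1.5, spike_fn=hard_step))   # (1.0, -0.5, 1.0)
print(dlit_step(0.0, 0.0, -1.5, spike_fn=hard_step))  # (0.5, -1.0, -1.0)
```

The second call shows that a negative event also raises the positive branch, because the reset is shared between the branches rather than routed by sign.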
- class tracetorch.snn.SLIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A synaptic leaky integrate-and-ternary-fire layer.
SLIT smooths the input through a synaptic trace before membrane integration and ternary firing. It is the ternary counterpart of SLIB.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay.
beta (float or torch.Tensor, default=0.9) – membrane decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
syn – synaptic state.
mem – membrane state.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
    syn = alpha * syn + (1 - alpha) * x
    mem = beta * mem + syn
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold - neg * neg_threshold
    return pos + neg
Examples:
>>> layer = tt.snn.SLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
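Because the synaptic trace is an exponential moving average, a constant input x reaches the membrane only gradually: starting from syn = 0, the update syn = alpha * syn + (1 - alpha) * x gives syn = (1 - alpha^n) * x after n steps. The plain-Python sketch below checks this against the SLIT pseudocode for one neuron; slit_step and hard_step are hypothetical helpers for illustration, and hard_step replaces the smooth default spike function.

```python
from typing import Callable, List, Tuple


def slit_step(
    syn: float,
    mem: float,
    x: float,
    alpha: float = 0.5,
    beta: float = 0.9,
    pos_threshold: float = 1.0,
    neg_threshold: float = 1.0,
    bias: float = 0.0,
    *,
    spike_fn: Callable[[float], float],
    quant_fn: Callable[[float], float] = lambda s: s,
) -> Tuple[float, float, float]:
    """One SLIT update for a single neuron; returns (syn, mem, output)."""
    syn = alpha * syn + (1 - alpha) * x  # synaptic low-pass filter
    mem = beta * mem + syn  # leaky membrane integration
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold - neg * neg_threshold  # subtractive reset
    return syn, mem, pos + neg


def hard_step(v: float) -> float:
    """Deterministic stand-in for a spike function (not the library default)."""
    return 1.0 if v > 0 else 0.0


syn, mem = 0.0, 0.0
syn_trace: List[float] = []
outputs: List[float] = []
for _ in range(3):
    syn, mem, out = slit_step(syn, mem, 1.0, spike_fn=hard_step)
    syn_trace.append(syn)
    outputs.append(out)

print(syn_trace)  # [0.5, 0.75, 0.875], i.e. 1 - 0.5**n
print(outputs)    # [0.0, 1.0, 1.0]
```

A smaller alpha lets the synaptic trace follow the input faster; with alpha close to 1 the layer responds only to sustained input.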
- class tracetorch.snn.RLIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A recurrent leaky integrate-and-ternary-fire layer.
RLIT adds a recurrent trace of the previous ternary output. The recurrent trace is scaled by rec_weight and added to the input before membrane integration and ternary firing.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane decay.
gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.
learn_beta (bool, default=True) – whether beta is trainable.
learn_gamma (bool, default=True) – whether gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_rec_weight (bool, default=True) – whether rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
mem – membrane state.
rec – recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
    rec = gamma * rec + (1 - gamma) * prev_output
    mem = beta * mem + x + rec_weight * rec
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold - neg * neg_threshold
    prev_output = pos + neg
    return prev_output
Examples:
>>> layer = tt.snn.RLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
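Since rec_weight defaults to 0.0, the recurrent trace initially has no effect on the membrane; it only matters once rec_weight moves away from zero, for example through training. The plain-Python sketch below illustrates this for a single neuron under constant input: with an inhibitory (negative) rec_weight, the neuron's own recent output suppresses its next event. rlit_step, run and hard_step are hypothetical helpers for this illustration, and the value -10.0 is chosen only to make the effect visible.

```python
from typing import Callable, List, Tuple


def rlit_step(
    mem: float,
    rec: float,
    prev_output: float,
    x: float,
    beta: float = 0.9,
    gamma: float = 0.9,
    pos_threshold: float = 1.0,
    neg_threshold: float = 1.0,
    bias: float = 0.0,
    rec_weight: float = 0.0,
    *,
    spike_fn: Callable[[float], float],
    quant_fn: Callable[[float], float] = lambda s: s,
) -> Tuple[float, float, float]:
    """One RLIT update for a single neuron; returns (mem, rec, output)."""
    rec = gamma * rec + (1 - gamma) * prev_output  # trace of past outputs
    mem = beta * mem + x + rec_weight * rec  # input plus recurrent current
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold - neg * neg_threshold
    return mem, rec, pos + neg


def hard_step(v: float) -> float:
    """Deterministic stand-in for a spike function (not the library default)."""
    return 1.0 if v > 0 else 0.0


def run(inputs: List[float], rec_weight: float) -> List[float]:
    """Feed a sequence through one neuron, starting from zero state."""
    mem = rec = prev_output = 0.0
    outputs = []
    for x in inputs:
        mem, rec, prev_output = rlit_step(
            mem, rec, prev_output, x, rec_weight=rec_weight, spike_fn=hard_step
        )
        outputs.append(prev_output)
    return outputs


print(run([1.5] * 4, rec_weight=0.0))    # [1.0, 1.0, 1.0, 1.0]
print(run([1.5] * 4, rec_weight=-10.0))  # [1.0, 0.0, 1.0, 0.0]
```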
- class tracetorch.snn.DSLIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A dual synaptic leaky integrate-and-ternary-fire layer.
DSLIT combines dual positive/negative traces with a synaptic stage and ternary output. It smooths positive and negative input separately before integrating the combined synaptic current into dual membrane traces.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_syn – positive synaptic state.
neg_syn – negative synaptic state.
pos_mem – positive membrane state.
neg_mem – negative membrane state.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
The membrane reset is split evenly across the two membrane branches.
Examples:
>>> layer = tt.snn.DSLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.DRLIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A dual recurrent leaky integrate-and-ternary-fire layer.
DRLIT keeps dual membrane traces and dual recurrent traces for ternary output. The previous output is split by sign into recurrent branches, then reintegrated with the current input before the ternary firing decision.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.
neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.
learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.
learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_rec – positive recurrent trace state.
neg_rec – negative recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
DRLIT is the ternary recurrent layer to reach for when positive and negative recurrent history should use different decays and gains.
Examples:
>>> layer = tt.snn.DRLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.SRLIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A synaptic recurrent leaky integrate-and-ternary-fire layer.
SRLIT combines a synaptic input trace with a recurrent trace of the previous ternary output. The membrane receives both the smoothed input and the recurrent current before the ternary firing decision.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay.
beta (float or torch.Tensor, default=0.9) – membrane decay.
gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
learn_gamma (bool, default=True) – whether gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_rec_weight (bool, default=True) – whether rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
syn – synaptic state.
mem – membrane state.
rec – recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
    syn = alpha * syn + (1 - alpha) * x
    rec = gamma * rec + (1 - gamma) * prev_output
    mem = beta * mem + syn + rec_weight * rec
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold - neg * neg_threshold
    prev_output = pos + neg
    return prev_output
Examples:
>>> layer = tt.snn.SRLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
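The SRLIT update can be sketched for a single neuron in plain Python as follows. This sketch assumes the membrane is reset by the thresholds after firing, in the same way as in RLIT and SLIT; SRLITState, srlit_step and hard_step are hypothetical names used only for illustration. The demo shows the synaptic lag: after the input flips from positive to negative, the neuron first goes silent and only then emits a negative event.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class SRLITState:
    """Hypothetical per-neuron state; all traces start at zero."""
    syn: float = 0.0
    mem: float = 0.0
    rec: float = 0.0
    prev_output: float = 0.0


def srlit_step(
    s: SRLITState,
    x: float,
    alpha: float = 0.5,
    beta: float = 0.9,
    gamma: float = 0.9,
    pos_threshold: float = 1.0,
    neg_threshold: float = 1.0,
    bias: float = 0.0,
    rec_weight: float = 0.0,
    *,
    spike_fn: Callable[[float], float],
    quant_fn: Callable[[float], float] = lambda v: v,
) -> float:
    """One SRLIT update for a single neuron; mutates s and returns the output."""
    s.syn = alpha * s.syn + (1 - alpha) * x  # smoothed input
    s.rec = gamma * s.rec + (1 - gamma) * s.prev_output  # trace of past outputs
    s.mem = beta * s.mem + s.syn + rec_weight * s.rec
    pos = quant_fn(spike_fn(s.mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - s.mem - bias))
    # Assumed subtractive reset, matching the other ternary layers.
    s.mem = s.mem - pos * pos_threshold - neg * neg_threshold
    s.prev_output = pos + neg
    return s.prev_output


def hard_step(v: float) -> float:
    """Deterministic stand-in for a spike function (not the library default)."""
    return 1.0 if v > 0 else 0.0


state = SRLITState()
print([srlit_step(state, x, spike_fn=hard_step) for x in [3.0, 3.0, -4.0, -4.0]])
# [1.0, 1.0, 0.0, -1.0]
```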
- class tracetorch.snn.DSRLIT(*args: Any, **kwargs: Any)[source]
Bases: Layer
A dual synaptic recurrent leaky integrate-and-ternary-fire layer.
DSRLIT combines every ternary trace mechanism: dual synaptic traces, dual recurrent traces, dual membrane traces, and positive/negative firing thresholds. It is the most expressive unscaled ternary SNN layer.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.
neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.
learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.
learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_syn – positive synaptic state.
neg_syn – negative synaptic state.
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_rec – positive recurrent trace state.
neg_rec – negative recurrent trace state.
prev_output – previous returned output.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Use DSRLIT when both input history and output history should retain sign-specific dynamics before ternary firing.
Examples:
>>> layer = tt.snn.DSRLIT(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
Ternary Scaled Firing Layers
- class tracetorch.snn.LITS(*args: Any, **kwargs: Any)[source]
Bases: Layer
A leaky integrate-and-ternary-scaled-fire layer.
LITS is the scaled-output variant of LIT. It computes positive and negative ternary firing decisions, resets the membrane with the unscaled spikes, and then multiplies the positive and negative outputs by separate scales, each of which can be learned or fixed.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane decay, constrained to (0, 1).
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – multiplier applied to positive output events.
neg_scale (float or torch.Tensor, default=1.0) – multiplier applied to negative output events.
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_beta (bool, default=True) – whether beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
mem – membrane state.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
    mem = beta * mem + x
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    mem = mem - pos * pos_threshold - neg * neg_threshold
    return pos * pos_scale + neg * neg_scale
Examples:
>>> layer = tt.snn.LITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
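The central point of LITS, that pos_scale and neg_scale change only what is returned while the reset still uses the unscaled spikes, can be checked with a single-neuron plain-Python sketch of the pseudocode above. lits_step and hard_step are hypothetical helpers for illustration; hard_step replaces the smooth default spike function so the values are exact.

```python
from typing import Callable, Tuple


def lits_step(
    mem: float,
    x: float,
    beta: float = 0.9,
    pos_threshold: float = 1.0,
    neg_threshold: float = 1.0,
    bias: float = 0.0,
    pos_scale: float = 1.0,
    neg_scale: float = 1.0,
    *,
    spike_fn: Callable[[float], float],
    quant_fn: Callable[[float], float] = lambda s: s,
) -> Tuple[float, float]:
    """One LITS update for a single neuron; returns (new_mem, scaled_output)."""
    mem = beta * mem + x
    pos = quant_fn(spike_fn(mem - pos_threshold + bias))
    neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
    # Reset with the unscaled events, as in LIT.
    mem = mem - pos * pos_threshold - neg * neg_threshold
    # Scales are applied only to the returned output.
    return mem, pos * pos_scale + neg * neg_scale


def hard_step(v: float) -> float:
    """Deterministic stand-in for a spike function (not the library default)."""
    return 1.0 if v > 0 else 0.0


print(lits_step(0.0, 1.5, spike_fn=hard_step))                   # (0.5, 1.0)
print(lits_step(0.0, 1.5, pos_scale=3.0, spike_fn=hard_step))    # (0.5, 3.0)
print(lits_step(0.0, -1.5, neg_scale=0.25, spike_fn=hard_step))  # (-0.5, -0.25)
```

The residual membrane is identical in the first two calls, so changing a scale alters only what downstream layers receive, not the neuron's own dynamics.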
- class tracetorch.snn.DLITS(*args: Any, **kwargs: Any)[source]
Bases: Layer
A dual leaky integrate-and-ternary-scaled-fire layer.
DLITS combines dual membrane traces with scaled ternary output. The summed membrane is thresholded, the reset is split across the dual traces, and the returned positive and negative outputs are scaled independently.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – positive output scale.
neg_scale (float or torch.Tensor, default=1.0) – negative output scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
The output scales affect what downstream layers receive; membrane reset still uses the unscaled threshold-crossing events.
Examples:
>>> layer = tt.snn.DLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.SLITS(*args: Any, **kwargs: Any)[source]
Bases:
Layer
A synaptic leaky integrate-and-ternary-scaled-fire layer.
SLITS smooths the input with a synaptic trace before membrane integration and emits a scaled ternary output.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay.
beta (float or torch.Tensor, default=0.9) – membrane decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – positive output scale.
neg_scale (float or torch.Tensor, default=1.0) – negative output scale.
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
syn – synaptic state.
mem – membrane state.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
return pos * pos_scale + neg * neg_scale
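The SLITS pseudocode can be transcribed into plain Python for a single neuron. This is a minimal sketch rather than traceTorch's implementation: states start at zero, a hard step replaces the default spike_fn (tt.functional.sigmoid4x), and quant_fn defaults to the identity. slits_run and heaviside are hypothetical names.

```python
def heaviside(v):
    """Hard step used as a stand-in for the library's spike_fn."""
    return 1.0 if v > 0 else 0.0


def slits_run(xs, alpha=0.5, beta=0.9, pos_threshold=1.0, neg_threshold=1.0,
              bias=0.0, pos_scale=1.0, neg_scale=1.0,
              spike_fn=heaviside, quant_fn=lambda p: p):
    """Run the SLITS pseudocode over a sequence of scalar inputs for one neuron."""
    syn = mem = 0.0  # both states start at zero
    outputs = []
    for x in xs:
        syn = alpha * syn + (1 - alpha) * x  # synaptic low-pass filter of the input
        mem = beta * mem + syn               # leaky membrane integration
        pos = quant_fn(spike_fn(mem - pos_threshold + bias))
        neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
        mem = mem - pos * pos_threshold - neg * neg_threshold  # subtractive reset
        outputs.append(pos * pos_scale + neg * neg_scale)      # scaled ternary output
    return outputs


print(slits_run([1.0, 1.0, 1.0, 1.0]))   # [0.0, 1.0, 1.0, 0.0]
print(slits_run([-3.0], neg_scale=0.5))  # [-0.5]
```

With a constant input of 1.0 the synaptic trace ramps up over a few steps, so the neuron fires on steps 2 and 3 and stays silent on step 4 while the membrane recovers from the reset. Per the pseudocode, if a smooth spike_fn is used with quant_fn left as the identity, pos and neg are spike probabilities rather than hard 0/±1 events.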
Examples:
>>> layer = tt.snn.SLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.RLITS(*args: Any, **kwargs: Any)[source]
Bases:
Layer
A recurrent leaky integrate-and-ternary-scaled-fire layer.
RLITS adds a recurrent trace of the previous scaled ternary output. The recurrent trace is scaled by rec_weight and added before membrane integration; the current output is then scaled separately for positive and negative events.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9) – membrane decay.
gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – positive output scale.
neg_scale (float or torch.Tensor, default=1.0) – negative output scale.
rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.
learn_beta (bool, default=True) – whether beta is trainable.
learn_gamma (bool, default=True) – whether gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_rec_weight (bool, default=True) – whether rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
mem – membrane state.
rec – recurrent trace state.
prev_output – previous returned output.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
prev_output stores the scaled output, so the recurrence sees the same value that downstream layers received at the previous timestep.
Examples:
>>> layer = tt.snn.RLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
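To see why it matters that prev_output holds the scaled output, consider the recurrent drive on its own. The RLITS update is not written out in this section, so this sketch assumes the trace rule from the SRLITS pseudocode (rec = gamma * rec + (1 - gamma) * prev_output, with rec_weight * rec added to the membrane). recurrent_drive is a hypothetical helper.

```python
def recurrent_drive(prev_outputs, gamma=0.9, rec_weight=0.0):
    """Recurrent input added to the membrane at each step (hypothetical helper).

    prev_outputs is the sequence of values held in prev_output, i.e. the
    already-scaled outputs. The trace rule is borrowed from the SRLITS
    pseudocode and is assumed, not confirmed, for RLITS.
    """
    rec = 0.0
    drives = []
    for prev in prev_outputs:
        rec = gamma * rec + (1 - gamma) * prev  # low-pass filter of past scaled outputs
        drives.append(rec_weight * rec)
    return drives


# One positive event emitted with pos_scale=2.0, followed by silence.
history = [2.0, 0.0, 0.0]
print(recurrent_drive(history))  # [0.0, 0.0, 0.0] with the default rec_weight=0.0
print([round(d, 3) for d in recurrent_drive(history, rec_weight=1.0)])  # [0.2, 0.18, 0.162]
```

Because rec_weight defaults to 0.0, under this rule a freshly constructed layer adds no recurrent drive until rec_weight is changed or learned. Once it is nonzero, the drive is linear in prev_output, so an event emitted with pos_scale=2.0 feeds back twice as strongly as the same event with pos_scale=1.0.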
- class tracetorch.snn.DSLITS(*args: Any, **kwargs: Any)[source]
Bases:
Layer
A dual synaptic leaky integrate-and-ternary-scaled-fire layer.
DSLITS combines dual synaptic traces, dual membrane traces, and scaled ternary output. It is the scaled counterpart of DSLIT.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – positive output scale.
neg_scale (float or torch.Tensor, default=1.0) – negative output scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_syn – positive synaptic state.
neg_syn – negative synaptic state.
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Use this layer when sign-specific input smoothing and sign-specific output magnitudes are both useful, but no recurrent output trace is needed.
Examples:
>>> layer = tt.snn.DSLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.DRLITS(*args: Any, **kwargs: Any)[source]
Bases:
Layer
A dual recurrent leaky integrate-and-ternary-scaled-fire layer.
DRLITS combines dual membrane traces, dual recurrent traces, and scaled ternary output. It is useful when positive and negative recurrent history should have different decays, gains, and downstream output magnitudes.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – positive output scale.
neg_scale (float or torch.Tensor, default=1.0) – negative output scale.
pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.
neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.
learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.
learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_rec – positive recurrent trace state.
neg_rec – negative recurrent trace state.
prev_output – previous returned output.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
The value stored in prev_output is already scaled, matching what downstream layers received.
Examples:
>>> layer = tt.snn.DRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.SRLITS(*args: Any, **kwargs: Any)[source]
Bases:
Layer
A synaptic recurrent leaky integrate-and-ternary-scaled-fire layer.
SRLITS combines a synaptic input trace, a recurrent output trace, and scaled ternary firing. It is the scaled counterpart of SRLIT.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5) – synaptic decay.
beta (float or torch.Tensor, default=0.9) – membrane decay.
gamma (float or torch.Tensor, default=0.9) – recurrent trace decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – positive output scale.
neg_scale (float or torch.Tensor, default=1.0) – negative output scale.
rec_weight (float or torch.Tensor, default=0.0) – recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron synaptic decay.
beta_rank (Literal[0, 1], default=1) – scalar or per-neuron membrane decay.
gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron recurrent scale.
learn_alpha (bool, default=True) – whether alpha is trainable.
learn_beta (bool, default=True) – whether beta is trainable.
learn_gamma (bool, default=True) – whether gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_rec_weight (bool, default=True) – whether rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
syn – synaptic state.
mem – membrane state.
rec – recurrent trace state.
prev_output – previous returned output.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
Pseudocode looks as follows:
syn = alpha * syn + (1 - alpha) * x
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + syn + rec_weight * rec
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
prev_output = pos * pos_scale + neg * neg_scale
return prev_output
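The SRLITS pseudocode can likewise be run step by step for one neuron. This is an illustrative transcription, not traceTorch's implementation: it keeps state on a plain Python object, starts all states at zero, uses a hard step in place of tt.functional.sigmoid4x, and treats quant_fn as the identity. SRLITSNeuron is a hypothetical name.

```python
class SRLITSNeuron:
    """Single-neuron transcription of the SRLITS pseudocode (illustrative only)."""

    def __init__(self, alpha=0.5, beta=0.9, gamma=0.9, pos_threshold=1.0,
                 neg_threshold=1.0, bias=0.0, pos_scale=1.0, neg_scale=1.0,
                 rec_weight=0.0):
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.pos_threshold, self.neg_threshold, self.bias = pos_threshold, neg_threshold, bias
        self.pos_scale, self.neg_scale, self.rec_weight = pos_scale, neg_scale, rec_weight
        self.syn = self.mem = self.rec = self.prev_output = 0.0  # all states start at zero

    @staticmethod
    def spike_fn(v):
        # Hard step stand-in for the default spike_fn; quant_fn is the identity.
        return 1.0 if v > 0 else 0.0

    def step(self, x):
        self.syn = self.alpha * self.syn + (1 - self.alpha) * x
        self.rec = self.gamma * self.rec + (1 - self.gamma) * self.prev_output
        self.mem = self.beta * self.mem + self.syn + self.rec_weight * self.rec
        pos = self.spike_fn(self.mem - self.pos_threshold + self.bias)
        neg = -self.spike_fn(-self.neg_threshold - self.mem - self.bias)
        self.mem = self.mem - pos * self.pos_threshold - neg * self.neg_threshold
        # The scaled output goes downstream and is also what the recurrence sees next step.
        self.prev_output = pos * self.pos_scale + neg * self.neg_scale
        return self.prev_output


recurrent = SRLITSNeuron(pos_scale=2.0, rec_weight=1.0)
print([recurrent.step(1.0) for _ in range(4)])  # [0.0, 2.0, 2.0, 2.0]

feedforward = SRLITSNeuron(pos_scale=2.0)  # default rec_weight=0.0
print([feedforward.step(1.0) for _ in range(4)])  # [0.0, 2.0, 2.0, 0.0]
```

With rec_weight=1.0, the scaled positive events fed back through rec keep the neuron firing on step 4, where the same input without recurrence leaves it silent.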
Examples:
>>> layer = tt.snn.SRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
- class tracetorch.snn.DSRLITS(*args: Any, **kwargs: Any)[source]
Bases:
Layer
A dual synaptic recurrent leaky integrate-and-ternary-scaled-fire layer.
DSRLITS combines every SNN mechanism provided by traceTorch: dual sign-specific synaptic traces, dual sign-specific recurrent traces, dual sign-specific membrane traces, ternary firing, and independent positive and negative output scales.
- Parameters:
num_neurons (int) – number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5) – positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5) – negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9) – positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9) – negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9) – positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9) – negative recurrent decay.
pos_threshold (float or torch.Tensor, default=1.0) – positive threshold.
neg_threshold (float or torch.Tensor, default=1.0) – negative threshold magnitude.
bias (float or torch.Tensor, default=0.0) – bias that shifts both firing boundaries.
pos_scale (float or torch.Tensor, default=1.0) – positive output scale.
neg_scale (float or torch.Tensor, default=1.0) – negative output scale.
pos_rec_weight (float or torch.Tensor, default=0.0) – positive recurrent input scale.
neg_rec_weight (float or torch.Tensor, default=0.0) – negative recurrent input scale.
dim (int, default=-1) – the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron positive synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1) – scalar or per-neuron negative synaptic decay.
pos_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron positive membrane decay.
neg_beta_rank (Literal[0, 1], default=1) – scalar or per-neuron negative membrane decay.
pos_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron positive threshold.
neg_threshold_rank (Literal[0, 1], default=1) – scalar or per-neuron negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron positive scale.
neg_scale_rank (Literal[0, 1], default=1) – scalar or per-neuron negative scale.
bias_rank (Literal[0, 1], default=1) – scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1) – scalar or per-neuron negative recurrent scale.
learn_pos_alpha (bool, default=True) – whether pos_alpha is trainable.
learn_neg_alpha (bool, default=True) – whether neg_alpha is trainable.
learn_pos_beta (bool, default=True) – whether pos_beta is trainable.
learn_neg_beta (bool, default=True) – whether neg_beta is trainable.
learn_pos_gamma (bool, default=True) – whether pos_gamma is trainable.
learn_neg_gamma (bool, default=True) – whether neg_gamma is trainable.
learn_pos_threshold (bool, default=True) – whether pos_threshold is trainable.
learn_neg_threshold (bool, default=True) – whether neg_threshold is trainable.
learn_pos_scale (bool, default=True) – whether pos_scale is trainable.
learn_neg_scale (bool, default=True) – whether neg_scale is trainable.
learn_bias (bool, default=True) – whether bias is trainable.
learn_pos_rec_weight (bool, default=True) – whether pos_rec_weight is trainable.
learn_neg_rec_weight (bool, default=True) – whether neg_rec_weight is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x) – spike probability function.
quant_fn (Callable, default=nn.Identity()) – output quantization function.
- Variables:
pos_syn – positive synaptic state.
neg_syn – negative synaptic state.
pos_mem – positive membrane state.
neg_mem – negative membrane state.
pos_rec – positive recurrent trace state.
neg_rec – negative recurrent trace state.
prev_output – previous returned output.
pos_scale – positive output scale.
neg_scale – negative output scale.
Notes
Input: tensor of shape [*, num_neurons, *] where num_neurons is at index dim.
Output: tensor with the same shape as the input.
This layer is intentionally dense: it is meant for cases where sign, input history, output history, thresholding, and output magnitude all need separate degrees of freedom.
Examples:
>>> layer = tt.snn.DSRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])