from typing import TypedDict, Optional, Literal, Union, Dict, Any
import torch
from torch import nn
from ._snnlayer import Layer as SNNLayer
from .. import functional
[docs]
class LITS(SNNLayer):
r"""A leaky integrate-and-ternary-scaled-fire layer.
``LITS`` is the scaled-output variant of ``LIT``. It computes positive and
negative ternary firing decisions, resets the membrane with the unscaled
spike signs, then multiplies positive and negative outputs by separate
learnable or fixed scales.
Args:
num_neurons (int): number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9): membrane decay, constrained
to ``(0, 1)``.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): multiplier applied to
positive output events.
neg_scale (float or torch.Tensor, default=1.0): multiplier applied to
negative output events.
dim (int, default=-1): the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1): scalar or per-neuron membrane
decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
learn_beta (bool, default=True): whether ``beta`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
mem: membrane state.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
Pseudocode looks as follows:
::
mem = beta * mem + x
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
return pos * pos_scale + neg * neg_scale
Examples::
>>> layer = tt.snn.LITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
beta: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
dim: int = -1,
beta_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
learn_beta: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("mem")
self._register_decay("beta", beta, beta_rank, learn_beta)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
mem = self._to_working_dim(self.mem)
mem = mem * self.beta + x
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
mem = mem - pos_spikes * self.pos_threshold
mem = mem - neg_spikes * self.neg_threshold
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.mem = self._from_working_dim(mem)
return spikes
[docs]
class DLITS(SNNLayer):
r"""A dual leaky integrate-and-ternary-scaled-fire layer.
``DLITS`` combines dual membrane traces with scaled ternary output. The
summed membrane is thresholded, the reset is split across the dual traces,
and the returned positive and negative outputs are scaled independently.
Args:
num_neurons (int): number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9): positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9): negative membrane decay.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): positive output scale.
neg_scale (float or torch.Tensor, default=1.0): negative output scale.
dim (int, default=-1): the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1): scalar or per-neuron positive
decay.
neg_beta_rank (Literal[0, 1], default=1): scalar or per-neuron negative
decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
learn_pos_beta (bool, default=True): whether ``pos_beta`` is trainable.
learn_neg_beta (bool, default=True): whether ``neg_beta`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
pos_mem: positive membrane state.
neg_mem: negative membrane state.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
The output scales affect what downstream layers receive; membrane reset
still uses the unscaled threshold crossing events.
Examples::
>>> layer = tt.snn.DLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
pos_beta: Union[float, torch.Tensor] = 0.9,
neg_beta: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
dim: int = -1,
pos_beta_rank: Literal[0, 1] = 1,
neg_beta_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
learn_pos_beta: bool = True,
learn_neg_beta: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("pos_mem")
self._initialize_state("neg_mem")
self._register_decay("pos_beta", pos_beta, pos_beta_rank, learn_pos_beta)
self._register_decay("neg_beta", neg_beta, neg_beta_rank, learn_neg_beta)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
pos_mem = self._to_working_dim(self.pos_mem)
neg_mem = self._to_working_dim(self.neg_mem)
pos_mem = pos_mem * self.pos_beta + torch.where(x >= 0, x, 0.0)
neg_mem = neg_mem * self.neg_beta + torch.where(x <= 0, x, 0.0)
mem = pos_mem + neg_mem
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
pos_mem = pos_mem - pos_spikes * self.pos_threshold * 0.5
neg_mem = neg_mem - pos_spikes * self.pos_threshold * 0.5
pos_mem = pos_mem - neg_spikes * self.neg_threshold * 0.5
neg_mem = neg_mem - neg_spikes * self.neg_threshold * 0.5
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.pos_mem = self._from_working_dim(pos_mem)
self.neg_mem = self._from_working_dim(neg_mem)
return spikes
[docs]
class SLITS(SNNLayer):
r"""A synaptic leaky integrate-and-ternary-scaled-fire layer.
``SLITS`` smooths the input with a synaptic trace before membrane integration
and scaled ternary output.
Args:
num_neurons (int): number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5): synaptic decay.
beta (float or torch.Tensor, default=0.9): membrane decay.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): positive output scale.
neg_scale (float or torch.Tensor, default=1.0): negative output scale.
dim (int, default=-1): the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1): scalar or per-neuron synaptic
decay.
beta_rank (Literal[0, 1], default=1): scalar or per-neuron membrane
decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
learn_alpha (bool, default=True): whether ``alpha`` is trainable.
learn_beta (bool, default=True): whether ``beta`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
syn: synaptic state.
mem: membrane state.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
Pseudocode looks as follows:
::
syn = alpha * syn + (1 - alpha) * x
mem = beta * mem + syn
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
return pos * pos_scale + neg * neg_scale
Examples::
>>> layer = tt.snn.SLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
alpha: Union[float, torch.Tensor] = 0.5,
beta: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
dim: int = -1,
alpha_rank: Literal[0, 1] = 1,
beta_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
learn_alpha: bool = True,
learn_beta: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("syn")
self._register_decay("alpha", alpha, alpha_rank, learn_alpha)
self._initialize_state("mem")
self._register_decay("beta", beta, beta_rank, learn_beta)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
syn = self._to_working_dim(self.syn)
syn = syn * self.alpha + x * (1 - self.alpha)
mem = self._to_working_dim(self.mem)
mem = mem * self.beta + syn
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
mem = mem - pos_spikes * self.pos_threshold
mem = mem - neg_spikes * self.neg_threshold
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.syn = self._from_working_dim(syn)
self.mem = self._from_working_dim(mem)
return spikes
[docs]
class RLITS(SNNLayer):
r"""A recurrent leaky integrate-and-ternary-scaled-fire layer.
``RLITS`` adds a recurrent trace of the previous scaled ternary output. The
recurrent trace is scaled by ``rec_weight`` and added before membrane
integration; the current output is then scaled separately for positive and
negative events.
Args:
num_neurons (int): number of neurons in the target dimension.
beta (float or torch.Tensor, default=0.9): membrane decay.
gamma (float or torch.Tensor, default=0.9): recurrent trace decay.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): positive output scale.
neg_scale (float or torch.Tensor, default=1.0): negative output scale.
rec_weight (float or torch.Tensor, default=0.0): recurrent input scale.
dim (int, default=-1): the dimension along which the layer operates.
beta_rank (Literal[0, 1], default=1): scalar or per-neuron membrane
decay.
gamma_rank (Literal[0, 1], default=1): scalar or per-neuron recurrent
decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1): scalar or per-neuron
recurrent scale.
learn_beta (bool, default=True): whether ``beta`` is trainable.
learn_gamma (bool, default=True): whether ``gamma`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
learn_rec_weight (bool, default=True): whether ``rec_weight`` is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
mem: membrane state.
rec: recurrent trace state.
prev_output: previous returned output.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
``prev_output`` stores the scaled output, so recurrence sees the same
value that downstream layers saw at the previous timestep.
Examples::
>>> layer = tt.snn.RLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
beta: Union[float, torch.Tensor] = 0.9,
gamma: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
rec_weight: Union[float, torch.Tensor] = 0.0,
dim: int = -1,
beta_rank: Literal[0, 1] = 1,
gamma_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
rec_weight_rank: Literal[0, 1] = 1,
learn_beta: bool = True,
learn_gamma: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
learn_rec_weight: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("mem")
self._register_decay("beta", beta, beta_rank, learn_beta)
self._initialize_state("rec")
self._initialize_state("prev_output")
self._register_decay("gamma", gamma, gamma_rank, learn_gamma)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
self._register_parameter("rec_weight", rec_weight, rec_weight_rank, learn_rec_weight)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
rec = self._to_working_dim(self.rec)
prev_output = self._to_working_dim(self.prev_output)
rec = rec * self.gamma + prev_output * (1 - self.gamma)
mem_delta = rec * self.rec_weight + x
mem = self._to_working_dim(self.mem)
mem = mem * self.beta + mem_delta
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
mem = mem - pos_spikes * self.pos_threshold
mem = mem - neg_spikes * self.neg_threshold
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.rec = self._from_working_dim(rec)
self.mem = self._from_working_dim(mem)
self.prev_output = spikes
return spikes
[docs]
class DSLITS(SNNLayer):
r"""A dual synaptic leaky integrate-and-ternary-scaled-fire layer.
``DSLITS`` combines dual synaptic traces, dual membrane traces, and scaled
ternary output. It is the scaled counterpart of ``DSLIT``.
Args:
num_neurons (int): number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5): positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5): negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9): positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9): negative membrane decay.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): positive output scale.
neg_scale (float or torch.Tensor, default=1.0): negative output scale.
dim (int, default=-1): the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1): scalar or per-neuron positive
synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1): scalar or per-neuron negative
synaptic decay.
pos_beta_rank (Literal[0, 1], default=1): scalar or per-neuron positive
membrane decay.
neg_beta_rank (Literal[0, 1], default=1): scalar or per-neuron negative
membrane decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
learn_pos_alpha (bool, default=True): whether ``pos_alpha`` is trainable.
learn_neg_alpha (bool, default=True): whether ``neg_alpha`` is trainable.
learn_pos_beta (bool, default=True): whether ``pos_beta`` is trainable.
learn_neg_beta (bool, default=True): whether ``neg_beta`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
pos_syn: positive synaptic state.
neg_syn: negative synaptic state.
pos_mem: positive membrane state.
neg_mem: negative membrane state.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
Use this layer when sign-specific input smoothing and sign-specific
output magnitudes are both useful, but no recurrent output trace is
needed.
Examples::
>>> layer = tt.snn.DSLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
pos_alpha: Union[float, torch.Tensor] = 0.5,
neg_alpha: Union[float, torch.Tensor] = 0.5,
pos_beta: Union[float, torch.Tensor] = 0.9,
neg_beta: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
dim: int = -1,
pos_alpha_rank: Literal[0, 1] = 1,
neg_alpha_rank: Literal[0, 1] = 1,
pos_beta_rank: Literal[0, 1] = 1,
neg_beta_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
learn_pos_alpha: bool = True,
learn_neg_alpha: bool = True,
learn_pos_beta: bool = True,
learn_neg_beta: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("pos_syn")
self._initialize_state("neg_syn")
self._register_decay("pos_alpha", pos_alpha, pos_alpha_rank, learn_pos_alpha)
self._register_decay("neg_alpha", neg_alpha, neg_alpha_rank, learn_neg_alpha)
self._initialize_state("pos_mem")
self._initialize_state("neg_mem")
self._register_decay("pos_beta", pos_beta, pos_beta_rank, learn_pos_beta)
self._register_decay("neg_beta", neg_beta, neg_beta_rank, learn_neg_beta)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
pos_syn = self._to_working_dim(self.pos_syn)
neg_syn = self._to_working_dim(self.neg_syn)
pos_syn = pos_syn * self.pos_alpha + torch.where(x >= 0, x, 0.0) * (1 - self.pos_alpha)
neg_syn = neg_syn * self.neg_alpha + torch.where(x <= 0, x, 0.0) * (1 - self.neg_alpha)
self.pos_syn = self._from_working_dim(pos_syn)
self.neg_syn = self._from_working_dim(neg_syn)
syn = pos_syn + neg_syn
pos_mem = self._to_working_dim(self.pos_mem)
neg_mem = self._to_working_dim(self.neg_mem)
pos_mem = pos_mem * self.pos_beta + torch.where(syn >= 0, syn, 0.0)
neg_mem = neg_mem * self.neg_beta + torch.where(syn <= 0, syn, 0.0)
mem = pos_mem + neg_mem
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
pos_mem = pos_mem - pos_spikes * self.pos_threshold * 0.5
neg_mem = neg_mem - pos_spikes * self.pos_threshold * 0.5
pos_mem = pos_mem - neg_spikes * self.neg_threshold * 0.5
neg_mem = neg_mem - neg_spikes * self.neg_threshold * 0.5
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.pos_mem = self._from_working_dim(pos_mem)
self.neg_mem = self._from_working_dim(neg_mem)
return spikes
[docs]
class DRLITS(SNNLayer):
r"""A dual recurrent leaky integrate-and-ternary-scaled-fire layer.
``DRLITS`` combines dual membrane traces, dual recurrent traces, and scaled
ternary output. It is useful when positive and negative recurrent history
should have different decays, gains, and downstream output magnitudes.
Args:
num_neurons (int): number of neurons in the target dimension.
pos_beta (float or torch.Tensor, default=0.9): positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9): negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9): positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9): negative recurrent decay.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): positive output scale.
neg_scale (float or torch.Tensor, default=1.0): negative output scale.
pos_rec_weight (float or torch.Tensor, default=0.0): positive recurrent
input scale.
neg_rec_weight (float or torch.Tensor, default=0.0): negative recurrent
input scale.
dim (int, default=-1): the dimension along which the layer operates.
pos_beta_rank (Literal[0, 1], default=1): scalar or per-neuron positive
membrane decay.
neg_beta_rank (Literal[0, 1], default=1): scalar or per-neuron negative
membrane decay.
pos_gamma_rank (Literal[0, 1], default=1): scalar or per-neuron positive
recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1): scalar or per-neuron negative
recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1): scalar or per-neuron
positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1): scalar or per-neuron
negative recurrent scale.
learn_pos_beta (bool, default=True): whether ``pos_beta`` is trainable.
learn_neg_beta (bool, default=True): whether ``neg_beta`` is trainable.
learn_pos_gamma (bool, default=True): whether ``pos_gamma`` is trainable.
learn_neg_gamma (bool, default=True): whether ``neg_gamma`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
learn_pos_rec_weight (bool, default=True): whether ``pos_rec_weight`` is
trainable.
learn_neg_rec_weight (bool, default=True): whether ``neg_rec_weight`` is
trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
pos_mem: positive membrane state.
neg_mem: negative membrane state.
pos_rec: positive recurrent trace state.
neg_rec: negative recurrent trace state.
prev_output: previous returned output.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
The previous output stored in ``prev_output`` is already scaled, matching
the value that downstream layers received.
Examples::
>>> layer = tt.snn.DRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
pos_beta: Union[float, torch.Tensor] = 0.9,
neg_beta: Union[float, torch.Tensor] = 0.9,
pos_gamma: Union[float, torch.Tensor] = 0.9,
neg_gamma: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
pos_rec_weight: Union[float, torch.Tensor] = 0.0,
neg_rec_weight: Union[float, torch.Tensor] = 0.0,
dim: int = -1,
pos_beta_rank: Literal[0, 1] = 1,
neg_beta_rank: Literal[0, 1] = 1,
pos_gamma_rank: Literal[0, 1] = 1,
neg_gamma_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
pos_rec_weight_rank: Literal[0, 1] = 1,
neg_rec_weight_rank: Literal[0, 1] = 1,
learn_pos_beta: bool = True,
learn_neg_beta: bool = True,
learn_pos_gamma: bool = True,
learn_neg_gamma: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
learn_pos_rec_weight: bool = True,
learn_neg_rec_weight: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("pos_mem")
self._initialize_state("neg_mem")
self._register_decay("pos_beta", pos_beta, pos_beta_rank, learn_pos_beta)
self._register_decay("neg_beta", neg_beta, neg_beta_rank, learn_neg_beta)
self._initialize_state("pos_rec")
self._initialize_state("neg_rec")
self._initialize_state("prev_output")
self._register_decay("pos_gamma", pos_gamma, pos_gamma_rank, learn_pos_gamma)
self._register_decay("neg_gamma", neg_gamma, neg_gamma_rank, learn_neg_gamma)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
self._register_parameter("pos_rec_weight", pos_rec_weight, pos_rec_weight_rank, learn_pos_rec_weight)
self._register_parameter("neg_rec_weight", neg_rec_weight, neg_rec_weight_rank, learn_neg_rec_weight)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
pos_rec = self._to_working_dim(self.pos_rec)
neg_rec = self._to_working_dim(self.neg_rec)
prev_output = self._to_working_dim(self.prev_output)
pos_rec = pos_rec * self.pos_gamma + torch.where(prev_output >= 0, prev_output, 0.0) * (1 - self.pos_gamma)
neg_rec = neg_rec * self.neg_gamma + torch.where(prev_output <= 0, prev_output, 0.0) * (1 - self.neg_gamma)
self.pos_rec = self._from_working_dim(pos_rec)
self.neg_rec = self._from_working_dim(neg_rec)
rec = pos_rec + neg_rec
pos_mem_delta = torch.where(rec >= 0, rec, 0.0) * self.pos_rec_weight + torch.where(x >= 0, x, 0.0)
neg_mem_delta = torch.where(rec <= 0, rec, 0.0) * self.neg_rec_weight + torch.where(x <= 0, x, 0.0)
pos_mem = self._to_working_dim(self.pos_mem)
neg_mem = self._to_working_dim(self.neg_mem)
pos_mem = pos_mem * self.pos_beta + pos_mem_delta
neg_mem = neg_mem * self.neg_beta + neg_mem_delta
mem = pos_mem + neg_mem
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
pos_mem = pos_mem - pos_spikes * self.pos_threshold * 0.5
neg_mem = neg_mem - pos_spikes * self.pos_threshold * 0.5
pos_mem = pos_mem - neg_spikes * self.neg_threshold * 0.5
neg_mem = neg_mem - neg_spikes * self.neg_threshold * 0.5
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.pos_mem = self._from_working_dim(pos_mem)
self.neg_mem = self._from_working_dim(neg_mem)
self.prev_output = spikes
return spikes
[docs]
class SRLITS(SNNLayer):
r"""A synaptic recurrent leaky integrate-and-ternary-scaled-fire layer.
``SRLITS`` combines a synaptic input trace, a recurrent output trace, and
scaled ternary firing. It is the scaled counterpart of ``SRLIT``.
Args:
num_neurons (int): number of neurons in the target dimension.
alpha (float or torch.Tensor, default=0.5): synaptic decay.
beta (float or torch.Tensor, default=0.9): membrane decay.
gamma (float or torch.Tensor, default=0.9): recurrent trace decay.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): positive output scale.
neg_scale (float or torch.Tensor, default=1.0): negative output scale.
rec_weight (float or torch.Tensor, default=0.0): recurrent input scale.
dim (int, default=-1): the dimension along which the layer operates.
alpha_rank (Literal[0, 1], default=1): scalar or per-neuron synaptic
decay.
beta_rank (Literal[0, 1], default=1): scalar or per-neuron membrane
decay.
gamma_rank (Literal[0, 1], default=1): scalar or per-neuron recurrent
decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
rec_weight_rank (Literal[0, 1], default=1): scalar or per-neuron
recurrent scale.
learn_alpha (bool, default=True): whether ``alpha`` is trainable.
learn_beta (bool, default=True): whether ``beta`` is trainable.
learn_gamma (bool, default=True): whether ``gamma`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
learn_rec_weight (bool, default=True): whether ``rec_weight`` is trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
syn: synaptic state.
mem: membrane state.
rec: recurrent trace state.
prev_output: previous returned output.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
Pseudocode looks as follows:
::
syn = alpha * syn + (1 - alpha) * x
rec = gamma * rec + (1 - gamma) * prev_output
mem = beta * mem + syn + rec_weight * rec
pos = quant_fn(spike_fn(mem - pos_threshold + bias))
neg = -quant_fn(spike_fn(-neg_threshold - mem - bias))
mem = mem - pos * pos_threshold - neg * neg_threshold
prev_output = pos * pos_scale + neg * neg_scale
return prev_output
Examples::
>>> layer = tt.snn.SRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
alpha: Union[float, torch.Tensor] = 0.5,
beta: Union[float, torch.Tensor] = 0.9,
gamma: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
rec_weight: Union[float, torch.Tensor] = 0.0,
dim: int = -1,
alpha_rank: Literal[0, 1] = 1,
beta_rank: Literal[0, 1] = 1,
gamma_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
rec_weight_rank: Literal[0, 1] = 1,
learn_alpha: bool = True,
learn_beta: bool = True,
learn_gamma: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
learn_rec_weight: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("syn")
self._register_decay("alpha", alpha, alpha_rank, learn_alpha)
self._initialize_state("mem")
self._register_decay("beta", beta, beta_rank, learn_beta)
self._initialize_state("rec")
self._initialize_state("prev_output")
self._register_decay("gamma", gamma, gamma_rank, learn_gamma)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
self._register_parameter("rec_weight", rec_weight, rec_weight_rank, learn_rec_weight)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
syn = self._to_working_dim(self.syn)
syn = syn * self.alpha + x * (1 - self.alpha)
rec = self._to_working_dim(self.rec)
prev_output = self._to_working_dim(self.prev_output)
rec = rec * self.gamma + prev_output * (1 - self.gamma)
self.rec = self._from_working_dim(rec)
mem_delta = rec * self.rec_weight + syn
mem = self._to_working_dim(self.mem)
mem = mem * self.beta + mem_delta
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
mem = mem - pos_spikes * self.pos_threshold
mem = mem - neg_spikes * self.neg_threshold
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.syn = self._from_working_dim(syn)
self.mem = self._from_working_dim(mem)
self.prev_output = spikes
return spikes
[docs]
class DSRLITS(SNNLayer):
r"""A dual synaptic recurrent leaky integrate-and-ternary-scaled-fire layer.
``DSRLITS`` combines every SNN mechanism provided by traceTorch: dual
sign-specific synaptic traces, dual sign-specific recurrent traces, dual
sign-specific membrane traces, ternary firing, and independent positive and
negative output scales.
Args:
num_neurons (int): number of neurons in the target dimension.
pos_alpha (float or torch.Tensor, default=0.5): positive synaptic decay.
neg_alpha (float or torch.Tensor, default=0.5): negative synaptic decay.
pos_beta (float or torch.Tensor, default=0.9): positive membrane decay.
neg_beta (float or torch.Tensor, default=0.9): negative membrane decay.
pos_gamma (float or torch.Tensor, default=0.9): positive recurrent decay.
neg_gamma (float or torch.Tensor, default=0.9): negative recurrent decay.
pos_threshold (float or torch.Tensor, default=1.0): positive threshold.
neg_threshold (float or torch.Tensor, default=1.0): negative threshold
magnitude.
bias (float or torch.Tensor, default=0.0): bias that shifts both firing
boundaries.
pos_scale (float or torch.Tensor, default=1.0): positive output scale.
neg_scale (float or torch.Tensor, default=1.0): negative output scale.
pos_rec_weight (float or torch.Tensor, default=0.0): positive recurrent
input scale.
neg_rec_weight (float or torch.Tensor, default=0.0): negative recurrent
input scale.
dim (int, default=-1): the dimension along which the layer operates.
pos_alpha_rank (Literal[0, 1], default=1): scalar or per-neuron positive
synaptic decay.
neg_alpha_rank (Literal[0, 1], default=1): scalar or per-neuron negative
synaptic decay.
pos_beta_rank (Literal[0, 1], default=1): scalar or per-neuron positive
membrane decay.
neg_beta_rank (Literal[0, 1], default=1): scalar or per-neuron negative
membrane decay.
pos_gamma_rank (Literal[0, 1], default=1): scalar or per-neuron positive
recurrent decay.
neg_gamma_rank (Literal[0, 1], default=1): scalar or per-neuron negative
recurrent decay.
pos_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
positive threshold.
neg_threshold_rank (Literal[0, 1], default=1): scalar or per-neuron
negative threshold magnitude.
pos_scale_rank (Literal[0, 1], default=1): scalar or per-neuron positive
scale.
neg_scale_rank (Literal[0, 1], default=1): scalar or per-neuron negative
scale.
bias_rank (Literal[0, 1], default=1): scalar or per-neuron bias.
pos_rec_weight_rank (Literal[0, 1], default=1): scalar or per-neuron
positive recurrent scale.
neg_rec_weight_rank (Literal[0, 1], default=1): scalar or per-neuron
negative recurrent scale.
learn_pos_alpha (bool, default=True): whether ``pos_alpha`` is trainable.
learn_neg_alpha (bool, default=True): whether ``neg_alpha`` is trainable.
learn_pos_beta (bool, default=True): whether ``pos_beta`` is trainable.
learn_neg_beta (bool, default=True): whether ``neg_beta`` is trainable.
learn_pos_gamma (bool, default=True): whether ``pos_gamma`` is trainable.
learn_neg_gamma (bool, default=True): whether ``neg_gamma`` is trainable.
learn_pos_threshold (bool, default=True): whether ``pos_threshold`` is
trainable.
learn_neg_threshold (bool, default=True): whether ``neg_threshold`` is
trainable.
learn_pos_scale (bool, default=True): whether ``pos_scale`` is trainable.
learn_neg_scale (bool, default=True): whether ``neg_scale`` is trainable.
learn_bias (bool, default=True): whether ``bias`` is trainable.
learn_pos_rec_weight (bool, default=True): whether ``pos_rec_weight`` is
trainable.
learn_neg_rec_weight (bool, default=True): whether ``neg_rec_weight`` is
trainable.
spike_fn (Callable, default=tt.functional.sigmoid4x): spike probability
function.
quant_fn (Callable, default=nn.Identity()): output quantization function.
Attributes:
pos_syn: positive synaptic state.
neg_syn: negative synaptic state.
pos_mem: positive membrane state.
neg_mem: negative membrane state.
pos_rec: positive recurrent trace state.
neg_rec: negative recurrent trace state.
prev_output: previous returned output.
pos_scale: positive output scale.
neg_scale: negative output scale.
Notes:
- **Input**: tensor of shape ``[*,num_neurons,*]`` where ``num_neurons``
is at index ``dim``.
- **Output**: tensor with the same shape as the input.
This layer is intentionally dense: it is meant for cases where sign,
input history, output history, thresholding, and output magnitude all
need separate degrees of freedom.
Examples::
>>> layer = tt.snn.DSRLITS(num_neurons=32)
>>> input = torch.randn(16, 32)
>>> output = layer(input)
>>> print(output.shape)
torch.Size([16, 32])
"""
def __init__(
self,
num_neurons: int,
pos_alpha: Union[float, torch.Tensor] = 0.5,
neg_alpha: Union[float, torch.Tensor] = 0.5,
pos_beta: Union[float, torch.Tensor] = 0.9,
neg_beta: Union[float, torch.Tensor] = 0.9,
pos_gamma: Union[float, torch.Tensor] = 0.9,
neg_gamma: Union[float, torch.Tensor] = 0.9,
pos_threshold: Union[float, torch.Tensor] = 1.0,
neg_threshold: Union[float, torch.Tensor] = 1.0,
bias: Union[float, torch.Tensor] = 0.0,
pos_scale: Union[float, torch.Tensor] = 1.0,
neg_scale: Union[float, torch.Tensor] = 1.0,
pos_rec_weight: Union[float, torch.Tensor] = 0.0,
neg_rec_weight: Union[float, torch.Tensor] = 0.0,
dim: int = -1,
pos_alpha_rank: Literal[0, 1] = 1,
neg_alpha_rank: Literal[0, 1] = 1,
pos_beta_rank: Literal[0, 1] = 1,
neg_beta_rank: Literal[0, 1] = 1,
pos_gamma_rank: Literal[0, 1] = 1,
neg_gamma_rank: Literal[0, 1] = 1,
pos_threshold_rank: Literal[0, 1] = 1,
neg_threshold_rank: Literal[0, 1] = 1,
pos_scale_rank: Literal[0, 1] = 1,
neg_scale_rank: Literal[0, 1] = 1,
bias_rank: Literal[0, 1] = 1,
pos_rec_weight_rank: Literal[0, 1] = 1,
neg_rec_weight_rank: Literal[0, 1] = 1,
learn_pos_alpha: bool = True,
learn_neg_alpha: bool = True,
learn_pos_beta: bool = True,
learn_neg_beta: bool = True,
learn_pos_gamma: bool = True,
learn_neg_gamma: bool = True,
learn_pos_threshold: bool = True,
learn_neg_threshold: bool = True,
learn_pos_scale: bool = True,
learn_neg_scale: bool = True,
learn_bias: bool = True,
learn_pos_rec_weight: bool = True,
learn_neg_rec_weight: bool = True,
spike_fn=functional.sigmoid4x,
quant_fn=nn.Identity(),
):
super().__init__(num_neurons, dim)
self._initialize_state("pos_syn")
self._initialize_state("neg_syn")
self._register_decay("pos_alpha", pos_alpha, pos_alpha_rank, learn_pos_alpha)
self._register_decay("neg_alpha", neg_alpha, neg_alpha_rank, learn_neg_alpha)
self._initialize_state("pos_mem")
self._initialize_state("neg_mem")
self._register_decay("pos_beta", pos_beta, pos_beta_rank, learn_pos_beta)
self._register_decay("neg_beta", neg_beta, neg_beta_rank, learn_neg_beta)
self._initialize_state("pos_rec")
self._initialize_state("neg_rec")
self._initialize_state("prev_output")
self._register_decay("pos_gamma", pos_gamma, pos_gamma_rank, learn_pos_gamma)
self._register_decay("neg_gamma", neg_gamma, neg_gamma_rank, learn_neg_gamma)
self.spike_fn = spike_fn
self.quant_fn = quant_fn
self._register_threshold("pos_threshold", pos_threshold, pos_threshold_rank, learn_pos_threshold)
self._register_threshold("neg_threshold", neg_threshold, neg_threshold_rank, learn_neg_threshold)
self._register_bias("bias", bias, bias_rank, learn_bias)
self._register_parameter("pos_scale", pos_scale, pos_scale_rank, learn_pos_scale)
self._register_parameter("neg_scale", neg_scale, neg_scale_rank, learn_neg_scale)
self._register_parameter("pos_rec_weight", pos_rec_weight, pos_rec_weight_rank, learn_pos_rec_weight)
self._register_parameter("neg_rec_weight", neg_rec_weight, neg_rec_weight_rank, learn_neg_rec_weight)
[docs]
def forward(self, x):
"""Computes the forward pass."""
self._ensure_states(x)
x = self._to_working_dim(x)
pos_syn = self._to_working_dim(self.pos_syn)
neg_syn = self._to_working_dim(self.neg_syn)
pos_syn = pos_syn * self.pos_alpha + torch.where(x >= 0, x, 0.0) * (1 - self.pos_alpha)
neg_syn = neg_syn * self.neg_alpha + torch.where(x <= 0, x, 0.0) * (1 - self.neg_alpha)
self.pos_syn = self._from_working_dim(pos_syn)
self.neg_syn = self._from_working_dim(neg_syn)
syn = pos_syn + neg_syn
pos_rec = self._to_working_dim(self.pos_rec)
neg_rec = self._to_working_dim(self.neg_rec)
prev_output = self._to_working_dim(self.prev_output)
pos_rec = pos_rec * self.pos_gamma + torch.where(prev_output >= 0, prev_output, 0.0) * (1 - self.pos_gamma)
neg_rec = neg_rec * self.neg_gamma + torch.where(prev_output <= 0, prev_output, 0.0) * (1 - self.neg_gamma)
self.pos_rec = self._from_working_dim(pos_rec)
self.neg_rec = self._from_working_dim(neg_rec)
rec = pos_rec + neg_rec
pos_mem_delta = torch.where(rec >= 0, rec, 0.0) * self.pos_rec_weight + torch.where(syn >= 0, syn, 0.0)
neg_mem_delta = torch.where(rec <= 0, rec, 0.0) * self.neg_rec_weight + torch.where(syn <= 0, syn, 0.0)
pos_mem = self._to_working_dim(self.pos_mem)
neg_mem = self._to_working_dim(self.neg_mem)
pos_mem = pos_mem * self.pos_beta + pos_mem_delta
neg_mem = neg_mem * self.neg_beta + neg_mem_delta
mem = pos_mem + neg_mem
pos_spike_prob = self.spike_fn(mem - self.pos_threshold + self.bias)
neg_spike_prob = self.spike_fn(-self.neg_threshold - mem - self.bias)
pos_spikes = self.quant_fn(pos_spike_prob)
neg_spikes = -self.quant_fn(neg_spike_prob)
pos_mem = pos_mem - pos_spikes * self.pos_threshold * 0.5
neg_mem = neg_mem - pos_spikes * self.pos_threshold * 0.5
pos_mem = pos_mem - neg_spikes * self.neg_threshold * 0.5
neg_mem = neg_mem - neg_spikes * self.neg_threshold * 0.5
pos_spikes = pos_spikes * self.pos_scale
neg_spikes = neg_spikes * self.neg_scale
spikes = pos_spikes + neg_spikes
spikes = self._from_working_dim(spikes)
self.pos_mem = self._from_working_dim(pos_mem)
self.neg_mem = self._from_working_dim(neg_mem)
self.prev_output = spikes
return spikes