Functional
tt.functional contains small helper functions used by traceTorch layers and useful in experiments.
Decay Helpers
- tracetorch.functional.halflife_to_decay(num_steps: float)[source]
Convert a halflife (how many timesteps are needed to lose half the signal) into a decay (per-timestep multiplicative value).
Mathematically, solves \(d^{\text{num steps}}=0.5\) for \(d\).
- Parameters:
num_steps (float) – the number of steps needed for the signal to halve in magnitude.
Examples:
>>> print(halflife_to_decay(1))
0.5
>>> print(halflife_to_decay(10))
0.933032991537
>>> print(halflife_to_decay(100))
0.993092495437
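Solving the equation above gives \(d = 0.5^{1/\text{num steps}}\). The following is a minimal standalone sketch of that closed form, assuming the helper computes exactly this expression. The name halflife_to_decay_sketch is hypothetical and is not part of traceTorch.

```python
def halflife_to_decay_sketch(num_steps: float) -> float:
    """Solve d ** num_steps == 0.5 for d (illustrative, not the library code)."""
    if num_steps <= 0:
        raise ValueError("num_steps must be positive")
    return 0.5 ** (1.0 / num_steps)


# After `num_steps` multiplications by the decay, half of the signal is left.
decay = halflife_to_decay_sketch(10)
remaining = decay ** 10
```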
- tracetorch.functional.decay_to_halflife(decay: float)[source]
Convert a decay (per-timestep multiplicative value) into a halflife (how many timesteps are needed to lose half the signal).
Mathematically, solves \(d^{\text{num steps}}=0.5\) for \(\text{num steps}\).
- Parameters:
decay (float) – the per-timestep multiplicative value.
Examples:
>>> print(decay_to_halflife(0.5))
1
>>> print(decay_to_halflife(0.9))
6.57881347896
>>> print(decay_to_halflife(0.99))
68.9675639365
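Taking logarithms of the same equation gives \(\text{num steps} = \ln 0.5 / \ln d\). A minimal sketch under the assumption that the helper uses this formula; decay_to_halflife_sketch is a hypothetical name used only for illustration.

```python
import math


def decay_to_halflife_sketch(decay: float) -> float:
    """Solve decay ** n == 0.5 for n (illustrative, not the library code)."""
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must lie strictly between 0 and 1")
    return math.log(0.5) / math.log(decay)


halflife = decay_to_halflife_sketch(0.9)
# Round trip: converting the halflife back should reproduce the decay.
recovered_decay = 0.5 ** (1.0 / halflife)
```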
- tracetorch.functional.timesteps_to_decay(time: float)[source]
Convert a time horizon (how many timesteps a signal persists) into a decay. This is not equivalent to halflife.
- Parameters:
time (float) – how many timesteps a signal persists for.
Examples:
>>> print(timesteps_to_decay(2))
0.5
>>> print(timesteps_to_decay(10))
0.9
>>> print(timesteps_to_decay(100))
0.99
- tracetorch.functional.decay_to_timesteps(decay: float)[source]
Convert a decay into a time horizon: over how many timesteps the signal persists. This is not equivalent to halflife.
- Parameters:
decay (float) – the per-timestep multiplicative value.
Examples:
>>> print(decay_to_timesteps(0.5))
2
>>> print(decay_to_timesteps(0.9))
10
>>> print(decay_to_timesteps(0.99))
100
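The listed outputs of the two time-horizon helpers are consistent with \(d = 1 - 1/\text{time}\) and \(\text{time} = 1/(1 - d)\). That relation is inferred from the examples only, not from the library source, so the sketch below is an assumption. The function names are hypothetical.

```python
def timesteps_to_decay_sketch(time: float) -> float:
    """ASSUMED relation inferred from the documented examples: d = 1 - 1 / time."""
    if time < 1:
        raise ValueError("time must be at least 1 timestep")
    return 1.0 - 1.0 / time


def decay_to_timesteps_sketch(decay: float) -> float:
    """Inverse of the assumed relation: time = 1 / (1 - d)."""
    if not 0.0 <= decay < 1.0:
        raise ValueError("decay must lie in [0, 1)")
    return 1.0 / (1.0 - decay)


# The same number of steps yields different decays in the two helper families.
horizon_decay = timesteps_to_decay_sketch(10)  # time-horizon reading
halflife_decay = 0.5 ** (1.0 / 10)             # halflife reading
```

Under this reading, a time horizon of 10 leaves about 0.9 ** 10 ≈ 0.35 of the signal after 10 steps, whereas a halflife of 10 leaves exactly half.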
Parameter Transforms
- tracetorch.functional.sigmoid_inverse(x: torch.Tensor) → torch.Tensor[source]
Return the logit transform of a tensor in (0, 1).
traceTorch uses this when registering constrained decay parameters. The raw stored value can be optimized freely, while torch.sigmoid(raw) recovers the user-facing decay.
- Parameters:
x (torch.Tensor) – tensor whose values must lie strictly between zero and one.
- Returns:
unconstrained tensor such that sigmoid(output) == x up to numerical precision.
- Return type:
torch.Tensor
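The round trip described for sigmoid_inverse can be sketched on plain Python floats. This is a scalar illustration of the logit, \(\ln(p / (1 - p))\), rather than the tensor implementation; sigmoid_inverse_sketch and sigmoid are hypothetical helper names.

```python
import math


def sigmoid(raw: float) -> float:
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-raw))


def sigmoid_inverse_sketch(p: float) -> float:
    """Logit: the raw value whose sigmoid is p (scalar illustration)."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    return math.log(p) - math.log1p(-p)


# Store the unconstrained raw value, recover the decay with sigmoid.
raw = sigmoid_inverse_sketch(0.9)
decay = sigmoid(raw)
```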
- tracetorch.functional.softplus_inverse(x: torch.Tensor) → torch.Tensor[source]
Return the inverse softplus transform of a positive tensor.
traceTorch uses this when registering positive constrained parameters such as SNN thresholds. The raw stored value can be optimized freely, while softplus(raw) recovers the positive user-facing value.
- Parameters:
x (torch.Tensor) – tensor whose values must be strictly positive.
- Returns:
unconstrained tensor such that softplus(output) == x up to numerical precision.
- Return type:
torch.Tensor
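Since softplus is \(\ln(1 + e^{\text{raw}})\), its inverse is \(\ln(e^{x} - 1)\). A scalar sketch of that round trip follows, written in a form that avoids overflow for large inputs. Whether traceTorch uses the same numerically stable rewrite is not stated in the source; the helper names are hypothetical.

```python
import math


def softplus(raw: float) -> float:
    """Stable form of log(1 + exp(raw))."""
    return max(raw, 0.0) + math.log1p(math.exp(-abs(raw)))


def softplus_inverse_sketch(x: float) -> float:
    """Raw value whose softplus is x, i.e. log(exp(x) - 1), written stably."""
    if x <= 0.0:
        raise ValueError("x must be strictly positive")
    # log(exp(x) - 1) == x + log(1 - exp(-x))
    return x + math.log(-math.expm1(-x))


# Example: a threshold of 1.0 stored as an unconstrained raw value.
raw = softplus_inverse_sketch(1.0)
threshold = softplus(raw)
```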
- tracetorch.functional.mamba_scale(x: torch.Tensor) → torch.Tensor[source]
Return the inverse transform for Mamba-style exponential decay scales.
The forward transform used by the corresponding SSM helper is exp(log(2) * -exp(raw)). This maps unconstrained raw values into (0, 1) with a half-life-like interpretation.
- Parameters:
x (torch.Tensor) – tensor whose values must lie strictly between zero and one.
- Returns:
unconstrained raw scale for the Mamba-style transform.
- Return type:
torch.Tensor
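Inverting the quoted forward transform gives \(\text{raw} = \ln(-\log_2 x)\). Because the forward transform equals \(0.5^{e^{\text{raw}}}\), \(e^{-\text{raw}}\) is the halflife in timesteps. The scalar sketch below assumes the library inverse matches this algebra; mamba_forward and mamba_scale_sketch are hypothetical names.

```python
import math


def mamba_forward(raw: float) -> float:
    """Forward transform quoted in the docs: exp(log(2) * -exp(raw))."""
    return math.exp(math.log(2.0) * -math.exp(raw))


def mamba_scale_sketch(decay: float) -> float:
    """Invert the forward transform: raw = log(-log2(decay))."""
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must lie strictly between 0 and 1")
    return math.log(-math.log2(decay))


raw = mamba_scale_sketch(0.9)
recovered = mamba_forward(raw)
# exp(-raw) is the number of steps after which the signal has halved.
halflife = math.exp(-raw)
```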
Spike Functions
- tracetorch.functional.sigmoid4x(x: torch.Tensor) → torch.Tensor[source]
Apply the default traceTorch smooth spike function.
sigmoid4x is a steeper sigmoid used to turn membrane distance from threshold into a differentiable firing probability. SNN firing layers pass values such as mem - threshold + bias through this function before applying their quant_fn.
- Parameters:
x (torch.Tensor) – membrane distance from threshold.
- Returns:
smooth firing probability in (0, 1).
- Return type:
torch.Tensor
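The exact formula of sigmoid4x is not given here. Guessing from the name alone, one plausible reading is \(\sigma(4x)\); the scalar sketch below is built on that assumption and may not match the library. The membrane values are made-up numbers.

```python
import math


def sigmoid4x_sketch(x: float) -> float:
    """ASSUMPTION: sigmoid(4 * x), guessed from the function name only."""
    return 1.0 / (1.0 + math.exp(-4.0 * x))


# Hypothetical membrane state: all three numbers are invented for illustration.
mem, threshold, bias = 1.2, 1.0, 0.0
fire_prob = sigmoid4x_sketch(mem - threshold + bias)
at_threshold = sigmoid4x_sketch(0.0)
```

If the assumption holds, a membrane sitting exactly at threshold fires with probability 0.5, and the slope at that point is 1 rather than the 0.25 of a plain sigmoid.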
Quantizers
- tracetorch.functional.round_ste(step_size=1.0)[source]
Create a deterministic rounding straight-through quantizer.
- Parameters:
step_size (float, default=1.0) – quantization interval. A value of 1 maps probabilities to integer-like events.
- Returns:
function that rounds in the forward pass and passes gradients through unchanged.
- Return type:
Callable[[torch.Tensor], torch.Tensor]
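A common way to build a rounding straight-through estimator in PyTorch is the detach trick: the forward value is the rounded tensor, while autograd only sees the identity. The sketch below uses that generic technique. It is not taken from the traceTorch source, and make_round_ste_sketch is a hypothetical name.

```python
import torch


def make_round_ste_sketch(step_size: float = 1.0):
    """Return a quantizer that rounds forward and is the identity backward."""

    def quantize(x: torch.Tensor) -> torch.Tensor:
        rounded = torch.round(x / step_size) * step_size
        # The detached difference carries the rounding but no gradient.
        return x + (rounded - x).detach()

    return quantize


quant = make_round_ste_sketch(1.0)
probs = torch.tensor([0.2, 0.7, 0.9], requires_grad=True)
events = quant(probs)
events.sum().backward()  # gradients reach `probs` as if no rounding happened
```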
- tracetorch.functional.stochastic_round_ste(step_size=1.0)[source]
Create a stochastic rounding straight-through quantizer.
- Parameters:
step_size (float, default=1.0) – quantization interval.
- Returns:
function that stochastically rounds in the forward pass and passes gradients through unchanged.
- Return type:
Callable[[torch.Tensor], torch.Tensor]
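Stochastic rounding is commonly implemented as floor(v + u) with u drawn uniformly from [0, 1), which rounds up with probability equal to the fractional part of v and therefore preserves the expected value. The sketch below combines that rule with the detach-based straight-through trick. It is a generic illustration under those assumptions, not the library code, and the factory name is hypothetical.

```python
import torch


def make_stochastic_round_ste_sketch(step_size: float = 1.0):
    """Return a quantizer that rounds stochastically forward, identity backward."""

    def quantize(x: torch.Tensor) -> torch.Tensor:
        scaled = x / step_size
        # floor(v + u), u ~ U[0, 1), rounds up with probability frac(v).
        rounded = torch.floor(scaled + torch.rand_like(scaled)) * step_size
        return x + (rounded - x).detach()

    return quantize


torch.manual_seed(0)
quant = make_stochastic_round_ste_sketch(1.0)
probs = torch.full((100_000,), 0.3, requires_grad=True)
events = quant(probs)
events.sum().backward()
mean_rate = events.mean().item()  # should be close to 0.3 on average
```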
- tracetorch.functional.probabilistic_ste()[source]
Create the probabilistic straight-through quantizer.
- Returns:
function that samples probability-weighted events in the forward pass and applies the custom probabilistic surrogate gradient in the backward pass.
- Return type:
Callable[[torch.Tensor], torch.Tensor]