Functional

tt.functional (tracetorch.functional) contains small helper functions that traceTorch layers use internally and that are also useful in experiments.

Decay Helpers

tracetorch.functional.halflife_to_decay(num_steps: float)

Convert a halflife (how many timesteps are needed to lose half the signal) into a decay (per-timestep multiplicative value).

Mathematically, solves \(d^{\text{num steps}}=0.5\) for \(d\), giving \(d = 0.5^{1/\text{num steps}}\).

Parameters:

num_steps (float) – the number of steps needed for the signal to halve in magnitude.

Examples:

>>> print(halflife_to_decay(1))
0.5

>>> print(halflife_to_decay(10))
0.933032991537

>>> print(halflife_to_decay(100))
0.993092495437

tracetorch.functional.decay_to_halflife(decay: float)

Convert a decay (per-timestep multiplicative value) into a halflife (how many timesteps are needed to lose half the signal).

Mathematically, solves \(d^{\text{num steps}}=0.5\) for \(\text{num steps}\), giving \(\text{num steps} = \ln 0.5 / \ln d\).

Parameters:

decay (float) – the per-timestep multiplicative value.

Examples:

>>> print(decay_to_halflife(0.5))
1

>>> print(decay_to_halflife(0.9))
6.57881347896

>>> print(decay_to_halflife(0.99))
68.9675639365
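Both halflife conversions follow from the equation \(d^{\text{num steps}} = 0.5\). The sketch below re-implements them locally with Python's math module to show the round trip; it is written from the formulas above, not copied from traceTorch, so the library's own functions may differ in return type or rounding.

```python
import math

def halflife_to_decay(num_steps: float) -> float:
    # Solve d ** num_steps == 0.5 for d
    return 0.5 ** (1.0 / num_steps)

def decay_to_halflife(decay: float) -> float:
    # Solve decay ** n == 0.5 for n: n = ln(0.5) / ln(decay)
    return math.log(0.5) / math.log(decay)

# Round trip: a halflife of 10 steps survives conversion to a decay and back
d = halflife_to_decay(10)
print(round(d, 12))                      # 0.933032991537
print(round(decay_to_halflife(d), 9))    # 10.0
# After 10 steps of multiplying by d, half of the signal remains
print(round(d ** 10, 12))                # 0.5
```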

tracetorch.functional.timesteps_to_decay(time: float)

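Convert a time horizon (how many timesteps a signal persists) into a decay. This is not the same quantity as the halflife: a decay of 0.5 has a halflife of 1 step but a time horizon of 2 steps. The examples below are consistent with \(d = 1 - 1/\text{time}\).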

Parameters:

time (float) – how many timesteps a signal persists for.

Examples:

>>> print(timesteps_to_decay(2))
0.5

>>> print(timesteps_to_decay(10))
0.9

>>> print(timesteps_to_decay(100))
0.99

tracetorch.functional.decay_to_timesteps(decay: float)

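Convert a decay (per-timestep multiplicative value) into a time horizon (how many timesteps the signal persists). This is not the same quantity as the halflife. The examples below are consistent with \(\text{time} = 1/(1 - d)\), which equals the geometric sum \(1 + d + d^2 + \dots\) of a unit signal decayed every step.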

Parameters:

decay (float) – the per-timestep multiplicative value.

Examples:

>>> print(decay_to_timesteps(0.5))
2

>>> print(decay_to_timesteps(0.9))
10

>>> print(decay_to_timesteps(0.99))
100
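The time-horizon examples above are consistent with \(d = 1 - 1/\text{time}\), or equivalently \(\text{time} = 1/(1-d) = 1 + d + d^2 + \dots\), the total weight a unit signal accumulates while decaying by \(d\) every step. The sketch below assumes that formula (it is inferred from the examples, not taken from traceTorch's source) and contrasts the horizon with the halflife of the same decay.

```python
import math

def timesteps_to_decay(time: float) -> float:
    # Assumed formula, inferred from the examples: d = 1 - 1/time
    return 1.0 - 1.0 / time

def decay_to_timesteps(decay: float) -> float:
    # Total weight of a unit signal decayed every step: 1 + d + d^2 + ... = 1/(1 - d)
    return 1.0 / (1.0 - decay)

d = 0.9
# Summing the decayed signal over many steps approaches the time horizon
total = sum(d ** k for k in range(1000))
print(round(total, 6))                        # 10.0
print(round(decay_to_timesteps(d), 6))        # 10.0
# The same decay has a noticeably shorter halflife
print(round(math.log(0.5) / math.log(d), 3))  # 6.579
```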

Parameter Transforms

tracetorch.functional.sigmoid_inverse(x: torch.Tensor) → torch.Tensor

Return the logit transform of a tensor in (0, 1).

traceTorch uses this when registering constrained decay parameters. The raw stored value can be optimized freely, while torch.sigmoid(raw) recovers the user-facing decay.

Parameters:

x (torch.Tensor) – tensor whose values must lie strictly between zero and one.

Returns:

unconstrained tensor such that sigmoid(output) == x up to numerical precision.

Return type:

torch.Tensor
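A minimal sketch of the reparameterization, written as a local function for illustration; traceTorch's own implementation may differ, for example in how it treats inputs very close to 0 or 1. PyTorch's built-in torch.logit computes the same transform, which the last line uses as a cross-check.

```python
import torch

def sigmoid_inverse(x: torch.Tensor) -> torch.Tensor:
    # logit(x) = log(x / (1 - x)); log1p(-x) computes log(1 - x) accurately
    return torch.log(x) - torch.log1p(-x)

decay = torch.tensor([0.5, 0.9, 0.99])
# Store the unconstrained value as the trainable parameter
raw = torch.nn.Parameter(sigmoid_inverse(decay))
# torch.sigmoid maps any real raw value back into (0, 1)
print(torch.allclose(torch.sigmoid(raw), decay))  # True
print(torch.allclose(raw, torch.logit(decay)))    # True
```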

tracetorch.functional.softplus_inverse(x: torch.Tensor) → torch.Tensor

Return the inverse softplus transform of a positive tensor.

traceTorch uses this when registering positivity-constrained parameters such as SNN thresholds. The raw stored value can be optimized freely, while softplus(raw) recovers the positive user-facing value.

Parameters:

x (torch.Tensor) – tensor whose values must be strictly positive.

Returns:

unconstrained tensor such that softplus(output) == x up to numerical precision.

Return type:

torch.Tensor
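The inverse of softplus(y) = log(1 + exp(y)) is log(exp(x) - 1). The local sketch below rewrites it as x + log(-expm1(-x)) so that large thresholds do not overflow exp(x); this is a common numerically stable form and an assumption, not necessarily the formula traceTorch uses. It also assumes PyTorch's default softplus (beta=1).

```python
import torch
import torch.nn.functional as F

def softplus_inverse(x: torch.Tensor) -> torch.Tensor:
    # log(exp(x) - 1) == x + log(1 - exp(-x)); expm1 keeps precision for small x
    return x + torch.log(-torch.expm1(-x))

threshold = torch.tensor([0.1, 1.0, 50.0])
# The raw parameter may take any real value during optimization
raw = torch.nn.Parameter(softplus_inverse(threshold))
# F.softplus maps it back to a strictly positive threshold
print(torch.allclose(F.softplus(raw), threshold))  # True
```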

tracetorch.functional.mamba_scale(x: torch.Tensor) → torch.Tensor

Return the inverse transform for Mamba-style exponential decay scales.

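The forward transform used by the corresponding SSM helper is exp(log(2) * -exp(raw)), which equals 0.5 ** exp(raw). This maps unconstrained raw values into (0, 1), and the resulting decay has a halflife of exactly exp(-raw) timesteps. Solving for raw gives the inverse, log(-log2(x)).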

Parameters:

x (torch.Tensor) – tensor whose values must lie strictly between zero and one.

Returns:

unconstrained raw scale for the Mamba-style transform.

Return type:

torch.Tensor
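Under the forward transform quoted above, the inverse has the closed form log(-log2(x)). The sketch below implements both directions locally (mamba_forward is a hypothetical name for the SSM helper's transform) and checks that exp(-raw) reproduces the halflives listed for decay_to_halflife.

```python
import math
import torch

def mamba_forward(raw: torch.Tensor) -> torch.Tensor:
    # Forward transform quoted above: exp(log(2) * -exp(raw)) == 0.5 ** exp(raw)
    return torch.exp(math.log(2) * -torch.exp(raw))

def mamba_scale(x: torch.Tensor) -> torch.Tensor:
    # Invert x = 0.5 ** exp(raw): exp(raw) = -log2(x), so raw = log(-log2(x))
    return torch.log(-torch.log2(x))

decay = torch.tensor([0.5, 0.9, 0.99], dtype=torch.float64)
raw = mamba_scale(decay)
print(torch.allclose(mamba_forward(raw), decay))  # True
# exp(-raw) is the halflife in steps: approximately 1, 6.579, 68.97
print(torch.exp(-raw))
```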

Spike Functions

tracetorch.functional.sigmoid4x(x: torch.Tensor) → torch.Tensor

Apply the default traceTorch smooth spike function.

sigmoid4x is a steeper sigmoid used to turn membrane distance from threshold into a differentiable firing probability. SNN firing layers pass values such as mem - threshold + bias through this function before applying their quant_fn.

Parameters:

x (torch.Tensor) – membrane distance from threshold.

Returns:

smooth firing probability in (0, 1).

Return type:

torch.Tensor
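The formula is not spelled out here; the name suggests torch.sigmoid(4 * x), and the sketch below assumes exactly that. Under this assumption the slope at threshold is 4 * 0.25 = 1, so the gradient is largest, and equal to one, when the membrane sits exactly at threshold. The mem, threshold, and bias values are made up for illustration.

```python
import torch

def sigmoid4x(x: torch.Tensor) -> torch.Tensor:
    # Assumed definition: a sigmoid with its input scaled by 4 (steeper around 0)
    return torch.sigmoid(4 * x)

mem = torch.tensor([0.2, 1.0, 1.5], requires_grad=True)
threshold, bias = 1.0, 0.0
# Distance from threshold becomes a smooth firing probability in (0, 1)
p = sigmoid4x(mem - threshold + bias)
p.sum().backward()
print(p)
# Gradient is 4 * s * (1 - s); it peaks at exactly 1.0 where mem == threshold
print(mem.grad)
```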

Quantizers

tracetorch.functional.round_ste(step_size=1.0)

Create a deterministic rounding straight-through quantizer.

Parameters:

step_size (float, default=1.0) – quantization interval. A value of 1 maps probabilities to integer-like events.

Returns:

function that rounds in the forward pass and passes gradients through unchanged.

Return type:

Callable[[torch.Tensor], torch.Tensor]
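Straight-through estimators are commonly built with the detach trick: the forward value is the rounded tensor, while the backward pass sees the identity. The factory below is a generic sketch of that pattern (quant_fn is a hypothetical name), assuming that rounding snaps to the nearest multiple of step_size; traceTorch may implement it differently, for example with a custom torch.autograd.Function.

```python
import torch

def round_ste(step_size: float = 1.0):
    def quant_fn(x: torch.Tensor) -> torch.Tensor:
        # Forward: snap to the nearest multiple of step_size
        q = torch.round(x / step_size) * step_size
        # Backward: (q - x).detach() carries no gradient, so d(out)/dx == 1
        return x + (q - x).detach()
    return quant_fn

quant = round_ste(1.0)
p = torch.tensor([0.2, 0.5, 0.7, 1.6], requires_grad=True)
out = quant(p)
out.sum().backward()
print(out)     # values 0, 0, 1, 2
print(p.grad)  # all ones: the gradient passes straight through
```

Note that torch.round rounds halves to the nearest even value, so a probability of exactly 0.5 becomes 0 rather than 1.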

tracetorch.functional.stochastic_round_ste(step_size=1.0)

Create a stochastic rounding straight-through quantizer.

Parameters:

step_size (float, default=1.0) – quantization interval.

Returns:

function that stochastically rounds in the forward pass and passes gradients through unchanged.

Return type:

Callable[[torch.Tensor], torch.Tensor]
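Stochastic rounding rounds x up with probability equal to its fractional part (measured in units of step_size) and down otherwise, so the output is unbiased in expectation. The sketch below is a generic version of that technique with the same detach-based straight-through gradient; it is an assumption about the behavior, not traceTorch's code.

```python
import torch

def stochastic_round_ste(step_size: float = 1.0):
    def quant_fn(x: torch.Tensor) -> torch.Tensor:
        scaled = x / step_size
        # floor(s + U) with U ~ Uniform[0, 1) rounds up with probability frac(s)
        q = torch.floor(scaled + torch.rand_like(scaled)) * step_size
        # Straight-through: forward uses q, backward sees the identity
        return x + (q - x).detach()
    return quant_fn

torch.manual_seed(0)
quant = stochastic_round_ste(1.0)
x = torch.full((100_000,), 0.3)
out = quant(x)
print(out.unique())  # only 0. and 1.
print(out.mean())    # close to 0.3: unbiased in expectation
```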

tracetorch.functional.probabilistic_ste()

Create the probabilistic straight-through quantizer.

Returns:

function that samples probability-weighted events in the forward pass and applies the custom probabilistic surrogate gradient in the backward pass.

Return type:

Callable[[torch.Tensor], torch.Tensor]
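The custom surrogate gradient is not described here, so the sketch below shows only the forward idea: each probability p emits a spike with probability p via torch.bernoulli. Its backward pass is a plain identity straight-through estimator, a deliberate simplification that is NOT traceTorch's probabilistic surrogate; bernoulli_ste is a hypothetical name.

```python
import torch

def bernoulli_ste():
    # Hypothetical stand-in for probabilistic_ste with a plain identity gradient,
    # not traceTorch's custom probabilistic surrogate
    def quant_fn(p: torch.Tensor) -> torch.Tensor:
        # Forward: emit a 1 with probability p, otherwise 0
        events = torch.bernoulli(p.detach().clamp(0.0, 1.0))
        # Backward: gradient flows to p as if the sampling were the identity
        return p + (events - p).detach()
    return quant_fn

torch.manual_seed(0)
quant = bernoulli_ste()
p = torch.full((100_000,), 0.25, requires_grad=True)
spikes = quant(p)
spikes.sum().backward()
print(spikes.mean())  # close to 0.25
print(p.grad[:3])     # ones
```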