Compiling and Decompiling

Some traceTorch parameters are stored in an unconstrained raw form and transformed when used. Decays, for example, are stored as raw values and passed through a sigmoid so the effective decay stays in (0, 1). Thresholds are passed through a softplus so they stay positive.

This is useful for training, but a trained inference model does not need to recompute those transforms every forward pass. TTcompile() bakes the transformed values into buffers. TTdecompile() restores the raw trainable form.

Basic use

model = Net().to(device)

# train normally
...

model.eval()
model.TTcompile()

with torch.no_grad():
    output = model(x)

After compilation, parameters registered through traceTorch’s parameter helpers are no longer trainable raw parameters. Use compiled models for inference, not training.

Returning to training

model.TTdecompile()
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

TTdecompile() uses each parameter’s inverse transform to recreate the raw parameter or buffer.

What gets compiled

Only parameters registered through traceTorch’s registration helpers are compiled. Ordinary PyTorch parameters, such as nn.Linear.weight, are unaffected.

For an SNN layer, this usually includes values such as:

  • beta / alpha / gamma decays;

  • thresholds;

  • biases;

  • some scale or recurrent-weight parameters, depending on the layer.

State and compilation are separate

Compilation changes parameter storage, not hidden state values. You can compile a model with or without existing hidden states.

For clean inference, a common pattern is:

model.eval()
model.zero_states()
model.TTcompile()

with torch.no_grad():
    for t in range(sequence.size(0)):
        output = model(sequence[t])

Checking equivalence

Compiled and decompiled models should produce the same outputs for the same weights and same initial states. A minimal sanity check is:

model.zero_states()
baseline = model(x)

model.TTcompile()
model.zero_states()
compiled = model(x)

assert torch.allclose(baseline, compiled)

model.TTdecompile()
model.zero_states()
decompiled = model(x)

assert torch.allclose(baseline, decompiled)