Compiling and Decompiling
Some traceTorch parameters are stored in an unconstrained raw form and transformed when used. Decays, for example, are
stored as raw values and passed through a sigmoid so the effective decay stays in (0, 1). Thresholds are passed
through a softplus so they stay positive.
This is useful for training, but a trained inference model does not need to recompute those transforms every forward
pass. TTcompile() bakes the transformed values into buffers. TTdecompile() restores the raw trainable form.
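To make the idea concrete, here is a minimal framework-free sketch of what baking a transform amounts to. The ToyDecay class and its method names are hypothetical and are not traceTorch's API. They only illustrate storing an unconstrained raw value, transforming it on every use, and caching the transformed value once for inference.

```python
import math


def sigmoid(raw: float) -> float:
    """Map any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-raw))


def softplus(raw: float) -> float:
    """Map any real number to a positive value: log(1 + exp(raw))."""
    # Numerically stable form of log(1 + exp(raw)).
    return max(raw, 0.0) + math.log1p(math.exp(-abs(raw)))


class ToyDecay:
    """Hypothetical stand-in for a layer holding one raw decay value."""

    def __init__(self, raw: float) -> None:
        self.raw = raw      # unconstrained value an optimizer would update
        self.baked = None   # cached effective value, set by compile()

    def effective(self) -> float:
        # Uncompiled: recompute the sigmoid on every call.
        # Compiled: return the cached value without recomputing.
        return sigmoid(self.raw) if self.baked is None else self.baked

    def compile(self) -> None:
        # Bake the transformed value once.
        self.baked = sigmoid(self.raw)


layer = ToyDecay(raw=2.0)
before = layer.effective()   # transform applied on use
layer.compile()
after = layer.effective()    # cached value, same number
threshold = softplus(-3.0)   # positive even though the raw value is negative
```

The effective decay is identical before and after compiling in this sketch. Only where the value lives, and how much work each call does, changes.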
Basic use
model = Net().to(device)
# train normally
...
model.eval()
model.TTcompile()
with torch.no_grad():
    output = model(x)
After compilation, parameters registered through traceTorch’s parameter helpers are no longer trainable raw parameters. Use compiled models for inference, not training.
Returning to training
model.TTdecompile()
model.train()
# Create the optimizer after decompiling so it tracks the recreated raw parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
TTdecompile() uses each parameter’s inverse transform to recreate the raw parameter or buffer.
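As an illustration of what an inverse transform looks like, the sketch below round-trips a raw decay and a raw threshold through sigmoid and softplus and back. The functions are the standard mathematical inverses (logit and inverse softplus); whether traceTorch uses exactly these forms internally is an assumption, and all names here are illustrative.

```python
import math


def sigmoid(raw: float) -> float:
    return 1.0 / (1.0 + math.exp(-raw))


def softplus(raw: float) -> float:
    # Stable form of log(1 + exp(raw)).
    return max(raw, 0.0) + math.log1p(math.exp(-abs(raw)))


def logit(p: float) -> float:
    """Inverse of sigmoid, defined for 0 < p < 1."""
    return math.log(p) - math.log1p(-p)


def inverse_softplus(y: float) -> float:
    """Inverse of softplus, defined for y > 0: log(exp(y) - 1)."""
    # Rewritten as y + log(1 - exp(-y)) to avoid overflow for large y.
    return y + math.log(-math.expm1(-y))


raw_decay, raw_threshold = 1.5, -0.7

# Forward transform (what compiling would bake), then inverse (decompiling).
recovered_decay = logit(sigmoid(raw_decay))
recovered_threshold = inverse_softplus(softplus(raw_threshold))
```

The recovered values match the originals only up to floating-point error, which is why a tolerance-based comparison is the appropriate way to check a compile and decompile round trip.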
What gets compiled
Only parameters registered through traceTorch’s registration helpers are compiled. Ordinary PyTorch parameters, such as
nn.Linear.weight, are unaffected.
For an SNN layer, this usually includes values such as:
beta/alpha/gamma decays;
thresholds;
biases;
some scale or recurrent-weight parameters, depending on the layer.
State and compilation are separate
Compilation changes parameter storage, not hidden state values. You can compile a model with or without existing hidden states.
For clean inference, a common pattern is:
model.eval()
model.zero_states()
model.TTcompile()
with torch.no_grad():
    for t in range(sequence.size(0)):
        output = model(sequence[t])
Checking equivalence
Compiled and decompiled models should produce the same outputs for the same weights and same initial states. A minimal sanity check is:
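model.eval()  # avoid train-only randomness such as dropout
# Reference output from the uncompiled model.
model.zero_states()
baseline = model(x)
# Same input after compiling.
model.TTcompile()
model.zero_states()
compiled = model(x)
assert torch.allclose(baseline, compiled)
# Same input after restoring the raw trainable form.
model.TTdecompile()
model.zero_states()
decompiled = model(x)
assert torch.allclose(baseline, decompiled)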