Saving and Loading States

PyTorch already gives you state_dict() for parameters and buffers. traceTorch adds a separate state dictionary for hidden states: membrane traces, recurrent traces, RNN hidden states, SSM states, and so on.

Why states are separate

Parameters describe the model. Hidden states describe where the model currently is in a sequence.

For many training jobs, you do not need to save hidden states at all: you call zero_states() before each independent sequence and train from a fresh state. State saving becomes useful when:

  • doing streaming inference;

  • checkpointing the middle of a long sequence;

  • resuming online learning;

  • comparing two models from the exact same hidden state.
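The split between parameters and hidden states can be sketched with a toy layer. This is a hypothetical, stdlib-only illustration, not traceTorch's implementation: ToyLeakyLayer, decay, and mem are made-up names, with decay playing the role of a parameter and mem the role of a hidden state, and zero_states() clearing only the latter.

```python
# Toy illustration only: NOT traceTorch's implementation.
# `ToyLeakyLayer`, `decay`, and `mem` are hypothetical names.


class ToyLeakyLayer:
    """A leaky integrator with one parameter and one hidden state."""

    def __init__(self, decay=0.5):
        self.decay = decay  # parameter: describes the model
        self.mem = None     # hidden state: where the model is in a sequence

    def zero_states(self):
        # Forget the position in the current sequence; keep the parameter.
        self.mem = None

    def __call__(self, x):
        if self.mem is None:  # lazily initialize the state on the first step
            self.mem = 0.0
        self.mem = self.decay * self.mem + x
        return self.mem


layer = ToyLeakyLayer(decay=0.5)
for x in [1.0, 1.0, 1.0]:
    out = layer(x)
print(out)                     # 1.75
layer.zero_states()
print(layer.mem, layer.decay)  # None 0.5
```

After zero_states() the next sequence starts from scratch, but the learned behavior (decay) is untouched, which is why the two kinds of data are saved separately.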

Saving

Run the model first so states exist. States are lazily initialized, so a freshly constructed model has no tensors to save.

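import torch

# Assumes `model` is a traceTorch model and `sequence` is time-major,
# i.e. sequence[t] is the input at timestep t.
model.eval()
model.zero_states()  # start from a fresh state

with torch.no_grad():
    for t in range(sequence.size(0)):
        output = model(sequence[t])  # the first call lazily creates the states

weights = model.state_dict()  # parameters and buffers
states = model.save_states()  # hidden states, keyed by path-like names

torch.save(weights, "weights.pt")
torch.save(states, "states.pt")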

The keys in states are dotted, path-like names that identify the layer and the state within it, such as "net.2.mem" (the mem state of the layer at net.2).
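Both points, lazy creation and path-like keys, can be sketched with stdlib-only toy classes. ToyNet, ToyLeaky, ToyIdentity, and the key-building loop are hypothetical and only approximate what such a naming scheme could look like; this is not traceTorch's code.

```python
# Hypothetical sketch: lazily created states and dotted keys like "net.2.mem".
# Not traceTorch's implementation.


class ToyLeaky:
    """Leaky integrator whose `mem` state is created on the first step."""

    def __init__(self):
        self.mem = None

    def __call__(self, x):
        if self.mem is None:
            self.mem = 0.0
        self.mem = 0.5 * self.mem + x
        return self.mem


class ToyIdentity:
    """Layer without hidden state (stands in for something like nn.Linear)."""

    def __call__(self, x):
        return x


class ToyNet:
    def __init__(self):
        # Mirrors `self.net = nn.Sequential(...)`: the stateful layer is at index 2.
        self.net = [ToyIdentity(), ToyIdentity(), ToyLeaky()]

    def __call__(self, x):
        for layer in self.net:
            x = layer(x)
        return x

    def save_states(self):
        """Collect initialized states under "<attribute>.<index>.<state>" keys."""
        states = {}
        for index, layer in enumerate(self.net):
            mem = getattr(layer, "mem", None)
            if mem is not None:  # states that were never created are skipped
                states[f"net.{index}.mem"] = mem
        return states


model = ToyNet()
print(model.save_states())  # {}
model(1.0)
print(model.save_states())  # {'net.2.mem': 1.0}
```

A freshly constructed ToyNet returns an empty dictionary, which is the toy analogue of why the model must be run before saving.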

Loading

Load parameters and states separately:

import torch

# `Net` and `device` are the same model class and target device used when saving.
model = Net().to(device)
model.load_state_dict(torch.load("weights.pt", map_location=device))
model.load_states(torch.load("states.pt", map_location=device), strict=False, device=device)

strict=False is useful when not every state exists yet or when you intentionally load only part of a model’s state. Use strict=True when you expect the saved state dictionary to exactly match the current model.
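One plausible reading of the two modes, as a stdlib-only sketch over plain dicts. The rules here are assumptions inferred from the paragraph above, not traceTorch's documented behavior: a slot holding None stands for a state that has not been lazily created yet, strict=True demands that the saved keys exactly match the states that currently exist, and strict=False loads every saved entry that names a known slot and reports the rest. load_states_sketch is a hypothetical name.

```python
# Hypothetical sketch of strict vs. non-strict state loading.
# The rules below are assumptions, not traceTorch's actual logic.


def load_states_sketch(slots, saved, strict=False):
    """Load `saved` (name -> value) into `slots` (name -> value or None).

    A slot holding None is a state that has not been created yet.
    Returns the sorted list of saved names that match no slot.
    """
    existing = {name for name, value in slots.items() if value is not None}
    unknown = sorted(name for name in saved if name not in slots)
    if strict and set(saved) != existing:
        raise RuntimeError(f"expected states {sorted(existing)}, got {sorted(saved)}")
    for name, value in saved.items():
        if name in slots:  # known slot: load it, creating it if it was lazy
            slots[name] = value
    return unknown


saved = {"net.2.mem": 1.0}

fresh = {"net.2.mem": None}  # model constructed but never run
print(load_states_sketch(fresh, saved))  # []
print(fresh)                             # {'net.2.mem': 1.0}

try:
    load_states_sketch({"net.2.mem": None}, saved, strict=True)
except RuntimeError as err:
    print("strict load failed:", err)  # strict load failed: expected states [], got ['net.2.mem']
```

Raising RuntimeError on a strict mismatch mirrors PyTorch's load_state_dict(strict=True); which error traceTorch raises in that case is not stated here.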

Shape checking

If a layer already has a state tensor and the loaded state has a different shape, traceTorch raises a ValueError. This prevents silent state corruption when the batch size, target dimension, or model structure has changed since the states were saved.
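A minimal sketch of that check, assuming shapes are compared as tuples (torch.Size is a tuple subclass, so a tensor's .shape would work the same way). check_state_shape is a hypothetical helper, not part of traceTorch; None stands for a state that does not exist yet, which the paragraph above implies is not checked.

```python
# Hypothetical sketch of the shape check; not traceTorch's code.
# Shapes are plain tuples so the example runs without tensors.


def check_state_shape(name, current_shape, loaded_shape):
    """Raise ValueError if an existing state would change shape on load.

    current_shape is None when the state has not been created yet,
    in which case any loaded shape is accepted.
    """
    if current_shape is not None and tuple(current_shape) != tuple(loaded_shape):
        raise ValueError(
            f"state {name!r}: existing shape {tuple(current_shape)} "
            f"!= loaded shape {tuple(loaded_shape)}"
        )


check_state_shape("net.2.mem", (32, 10), (32, 10))  # same shape: accepted
check_state_shape("net.2.mem", None, (32, 10))      # no state yet: accepted
try:
    # e.g. states saved with batch size 32, model now running batch size 8
    check_state_shape("net.2.mem", (8, 10), (32, 10))
except ValueError as err:
    print(err)  # state 'net.2.mem': existing shape (8, 10) != loaded shape (32, 10)
```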

Continuing a sequence

After loading states, continue calling the model normally:

model.eval()
with torch.no_grad():
    next_output = model(next_timestep)

Do not call zero_states() after loading unless you intentionally want to discard the loaded states.
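Why resuming from loaded states is safe can be shown with a toy, stdlib-only check (not traceTorch code): a leaky integrator run over a whole sequence produces the same outputs as one that is stopped halfway, checkpointed, and continued in a fresh instance, while calling zero_states() after loading changes the result. ToyLeaky and its save_states/load_states methods are hypothetical stand-ins; copy.deepcopy plays the role that tensor.clone() would play for real state tensors.

```python
# Toy check (not traceTorch): restoring a checkpointed hidden state
# reproduces the uninterrupted run exactly.
import copy


class ToyLeaky:
    def __init__(self, decay=0.5):
        self.decay = decay
        self.mem = None  # hidden state, created lazily

    def zero_states(self):
        self.mem = None

    def save_states(self):
        # deepcopy stands in for tensor.clone() on real state tensors
        return {} if self.mem is None else {"mem": copy.deepcopy(self.mem)}

    def load_states(self, states):
        self.mem = copy.deepcopy(states["mem"])

    def __call__(self, x):
        if self.mem is None:
            self.mem = 0.0
        self.mem = self.decay * self.mem + x
        return self.mem


sequence = [1.0, 2.0, 3.0, 4.0]

# Uninterrupted run over the whole sequence.
full = ToyLeaky()
full_outputs = [full(x) for x in sequence]

# Run the first half, checkpoint, then continue in a fresh instance.
first = ToyLeaky()
first_outputs = [first(x) for x in sequence[:2]]
checkpoint = first.save_states()

resumed = ToyLeaky()
resumed.load_states(checkpoint)
resumed_outputs = [resumed(x) for x in sequence[2:]]
print(first_outputs + resumed_outputs == full_outputs)  # True

# Calling zero_states() after loading throws the checkpoint away.
discarded = ToyLeaky()
discarded.load_states(checkpoint)
discarded.zero_states()
print(discarded(sequence[2]) == full_outputs[2])  # False
```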