Core
Core classes provide the state-management and parameter-management machinery used by every traceTorch layer.
Model
- class tracetorch.core.Model(*args: Any, **kwargs: Any)
Bases: Module
The superclass used for all traceTorch models. Handles zeroing and detaching states, compiling and decompiling, and saving and loading states across the entire model tree, including nested PyTorch modules and Python containers.
- save_states() → Dict[str, torch.Tensor]
Save all hidden states from all Layers in the model.
- Returns:
Dictionary mapping layer_state_name -> tensor, compatible with torch.save()
Examples:
>>> states = model.save_states()
>>> torch.save(states, "model_states.pt")
>>> # Keys look like "net.layer1.H", "net.layer2.C", and so on.
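The dotted keys mirror the module paths that PyTorch itself assigns to submodules. The following is a minimal sketch of how such a dictionary could be collected, assuming a hypothetical ToyLayer that stores a single state named H. It is not traceTorch's implementation, and unlike the documented traversal it only follows registered nn.Module children (through named_modules()), not plain Python containers.

```python
import io

import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a traceTorch layer with one hidden state H."""
    state_names = ("H",)

    def __init__(self, num_neurons: int):
        super().__init__()
        self.num_neurons = num_neurons
        self.H = None  # lazily initialized on the first forward pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.H is None:
            self.H = torch.zeros_like(x)
        self.H = 0.9 * self.H + x  # leaky accumulation of the input
        return self.H


def collect_states(model: nn.Module) -> dict:
    """Gather every initialized state, keyed by '<module path>.<state name>'."""
    states = {}
    for path, module in model.named_modules():
        if isinstance(module, ToyLayer):
            for name in module.state_names:
                value = getattr(module, name)
                if value is not None:
                    states[f"{path}.{name}"] = value
    return states


net = nn.Sequential(ToyLayer(4), nn.Linear(4, 4), ToyLayer(4))
with torch.no_grad():  # keep the example states free of autograd history
    net(torch.ones(2, 4))

states = collect_states(net)
print(sorted(states))  # ['0.H', '2.H']

# Round-trip through an in-memory file instead of "model_states.pt".
buffer = io.BytesIO()
torch.save(states, buffer)
buffer.seek(0)
restored = torch.load(buffer)
```

Because the result is a plain dictionary of tensors, it serializes with torch.save() like any state_dict.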
- load_states(states: Dict[str, torch.Tensor], strict: bool = True, device=None) → None
Load hidden states into the layers in the model.
- Parameters:
states (Dict) – dictionary returned by save_states() or loaded with torch.load().
strict (bool, default=True) – if True, raises an error for missing or extra states.
device (str, default=None) – target device for the loaded states. Automatically detected if set to None.
Examples:
>>> states = torch.load("model_states.pt")
>>> model.load_states(states)
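The strict check can be sketched as a set comparison between the keys a model expects and the keys it was given. In the sketch below, ToyLayer and apply_states are hypothetical names, and the device fallback (the device of the model's first parameter, else the CPU) is an assumption rather than traceTorch's documented detection rule.

```python
import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a traceTorch layer with one hidden state H."""
    state_names = ("H",)

    def __init__(self):
        super().__init__()
        self.H = None


def apply_states(model, states, strict=True, device=None):
    """Hypothetical loader: copy states into layers, optionally checking the keys."""
    layers = {path: m for path, m in model.named_modules() if isinstance(m, ToyLayer)}
    expected = {f"{path}.{name}" for path, m in layers.items() for name in m.state_names}
    given = set(states)
    missing, unexpected = expected - given, given - expected
    if strict and (missing or unexpected):
        raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
    if device is None:
        # Assumption: follow the model's first parameter, falling back to the CPU.
        device = next(model.parameters(), torch.empty(0)).device
    for key in expected & given:  # in non-strict mode, unknown keys are ignored
        path, _, name = key.rpartition(".")
        setattr(layers[path], name, states[key].to(device))


net = nn.Sequential(ToyLayer(), nn.Linear(3, 3), ToyLayer())
apply_states(net, {"0.H": torch.ones(3), "2.H": torch.zeros(3)})

try:
    apply_states(net, {"0.H": torch.ones(3)})  # "2.H" is missing
except KeyError as err:
    print("strict load rejected:", err)

apply_states(net, {"0.H": torch.ones(3)}, strict=False)  # loads only "0.H"
```

The strict check runs before any state is written, so a rejected load leaves the model unchanged.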
- zero_states() → None
Set all hidden states to None across the entire model tree.
Recursively traverses the model hierarchy to find all traceTorch layers and sets their hidden states to None. On the next forward pass, the states are lazily re-initialized with the correct tensor shapes.
Notes
Traverses traceTorch models, PyTorch modules and Python containers.
Only affects traceTorch layers that implement state management.
Used for resetting model states between batches or episodes.
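A minimal sketch of the reset behavior, assuming a hypothetical ToyLayer with one state H and a hypothetical iter_layers helper that walks registered nn.Module children as well as plain lists, tuples, and dicts stored as attributes. This only illustrates the idea; traceTorch's own traversal may differ. After the reset, the next forward pass rebuilds H from the shape of the new input, so a different batch size is accepted.

```python
import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a traceTorch layer with one hidden state H."""

    def __init__(self):
        super().__init__()
        self.H = None  # created lazily from the first input's shape

    def forward(self, x):
        if self.H is None:
            self.H = torch.zeros_like(x)
        self.H = 0.9 * self.H + x
        return self.H


def iter_layers(obj, seen=None):
    """Hypothetical helper: yield every ToyLayer reachable from obj."""
    seen = set() if seen is None else seen
    if id(obj) in seen:
        return
    seen.add(id(obj))
    if isinstance(obj, ToyLayer):
        yield obj
    if isinstance(obj, nn.Module):
        # Registered submodules plus container-valued attributes (e.g. plain lists).
        children = list(obj.children()) + [
            v for v in vars(obj).values() if isinstance(v, (list, tuple, dict))
        ]
    elif isinstance(obj, (list, tuple)):
        children = list(obj)
    elif isinstance(obj, dict):
        children = list(obj.values())
    else:
        children = []
    for child in children:
        yield from iter_layers(child, seen)


def zero_all(model):
    """Set every reachable hidden state to None."""
    for layer in iter_layers(model):
        layer.H = None


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.core = ToyLayer()      # registered submodule
        self.extras = [ToyLayer()]  # plain list: not registered with nn.Module


model = Wrapper()
model.core(torch.ones(8, 3))
model.extras[0](torch.ones(8, 3))
zero_all(model)               # both states are now None
model.core(torch.ones(2, 3))  # H is rebuilt with the new batch size of 2
```

The seen set prevents visiting the same object twice, which matters because a module's registered children also appear inside its internal _modules dictionary.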
- detach_states() → None
Detach all hidden states from the computation graph across the entire model tree.
Recursively traverses the model hierarchy to find all traceTorch layers and detaches their hidden states from the computation graph. This enables online learning by stopping gradients from flowing back into earlier time steps.
Notes
Traverses traceTorch models, PyTorch modules and Python containers.
Only affects traceTorch layers that implement state management.
Used for online learning or truncated backpropagation through time, when temporal gradients should be cut.
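The following sketch shows where a detach fits in a truncated backpropagation-through-time loop. ToyRecurrent and its detach_states method are hypothetical stand-ins for a traceTorch layer; with a real traceTorch model, the corresponding call would be model.detach_states().

```python
import torch
from torch import nn


class ToyRecurrent(nn.Module):
    """Hypothetical stateful layer: H <- decay * H + linear(x)."""

    def __init__(self, num_neurons):
        super().__init__()
        self.linear = nn.Linear(num_neurons, num_neurons)
        self.decay = nn.Parameter(torch.full((num_neurons,), 0.5))
        self.H = None

    def forward(self, x):
        if self.H is None:
            self.H = torch.zeros(x.shape[0], self.decay.numel())
        self.H = self.decay * self.H + self.linear(x)
        return self.H

    def detach_states(self):
        if self.H is not None:
            self.H = self.H.detach()


torch.manual_seed(0)
layer = ToyRecurrent(4)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)
sequence = torch.randn(20, 2, 4)  # (time, batch, features)
chunk = 5                         # truncation length in time steps

for start in range(0, sequence.shape[0], chunk):
    optimizer.zero_grad()
    loss = sum(layer(x_t).pow(2).mean() for x_t in sequence[start:start + chunk])
    loss.backward()
    optimizer.step()
    layer.detach_states()  # the next chunk's backward pass stops here
```

Without the detach, the second chunk's backward pass would try to propagate into the first chunk's graph, which the earlier backward() call already freed, and PyTorch raises an error.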
Layer
- class tracetorch.core.Layer(*args: Any, **kwargs: Any)
Bases: Module
The superclass used for all traceTorch layers. Handles state management, parameter initialization, compilation and decompilation, and moving tensors so that the layer operates along the target dimension.
- Parameters:
num_neurons (int) – the number of neurons in the layer. This value sets the size used when initializing hidden states or registering parameters through the traceTorch methods.
dim (int, default=-1) – the dimension along which the layer operates.
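One plausible way for a layer to honor num_neurons and dim is to move the target dimension to the end with torch.movedim, apply per-neuron parameters of shape (num_neurons,), and move the dimension back. The sketch below assumes this approach for a hypothetical ToyDimLayer; traceTorch's actual mechanism may differ.

```python
import torch
from torch import nn


class ToyDimLayer(nn.Module):
    """Hypothetical layer: a per-neuron leaky state along an arbitrary dimension."""

    def __init__(self, num_neurons: int, dim: int = -1):
        super().__init__()
        self.num_neurons = num_neurons
        self.dim = dim
        # One decay value per neuron, sized by num_neurons.
        self.decay = nn.Parameter(torch.full((num_neurons,), 0.9))
        self.H = None

    def forward(self, x):
        # Move the neuron dimension to the end so the (num_neurons,) parameter broadcasts.
        x = torch.movedim(x, self.dim, -1)
        if x.shape[-1] != self.num_neurons:
            raise ValueError(
                f"expected {self.num_neurons} neurons along dim {self.dim}, got {x.shape[-1]}"
            )
        if self.H is None:
            self.H = torch.zeros_like(x)
        self.H = self.decay * self.H + x
        # Move the neuron dimension back to where the caller had it.
        return torch.movedim(self.H, -1, self.dim)


layer = ToyDimLayer(num_neurons=16, dim=1)
images = torch.randn(8, 16, 5, 5)  # (batch, channels, height, width)
out = layer(images)
print(out.shape)  # torch.Size([8, 16, 5, 5])
```

Storing the state in the moved layout keeps the per-neuron arithmetic a plain broadcast; only the returned tensor is moved back to the caller's layout.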
- detach_states() → None
Detach every initialized (non-None) state tensor from the computation graph.