Core

Core classes provide the state-management and parameter-management machinery used by every traceTorch layer.

Model

class tracetorch.core.Model(*args: Any, **kwargs: Any)

Bases: Module

The base class for all traceTorch models. It zeroes, detaches, saves, and loads hidden states, and compiles and decompiles layers, across the entire model tree, including traceTorch layers nested inside PyTorch modules and Python containers.

save_states() → Dict[str, torch.Tensor]

Collect all hidden states from every Layer in the model into a single dictionary.

Returns:

Dictionary mapping layer_state_name -> tensor, compatible with torch.save()

Examples:

>>> states = model.save_states()
>>> torch.save(states, "model_states.pt")
# Keys look like: "net.layer1.H", "net.layer2.C", et cetera.
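A minimal sketch of how keys such as "net.layer1.H" could be formed, assuming each layer lists its state attribute names in a hypothetical `_state_names` tuple and that only states that are not None are collected. `ToyLayer`, `ToyModel`, and `save_states_sketch` are illustrative stand-ins, not traceTorch's implementation.

```python
from collections import OrderedDict
import io

import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a traceTorch layer with two hidden states, H and C."""

    _state_names = ("H", "C")  # assumed convention for listing state attributes

    def __init__(self):
        super().__init__()
        self.H = None  # states stay None until something assigns them
        self.C = None


def save_states_sketch(model: nn.Module) -> dict:
    """Map '<module path>.<state name>' to every state tensor that is not None."""
    states = {}
    for path, module in model.named_modules():
        for name in getattr(module, "_state_names", ()):
            value = getattr(module, name)
            if value is not None:
                key = f"{path}.{name}" if path else name
                states[key] = value.detach()  # keep autograd history out of the file
    return states


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            OrderedDict([("layer1", ToyLayer()), ("layer2", ToyLayer())])
        )


model = ToyModel()
model.net.layer1.H = torch.zeros(2, 4)
model.net.layer2.C = torch.ones(2, 4)

states = save_states_sketch(model)
buffer = io.BytesIO()  # in-memory stand-in for "model_states.pt"
torch.save(states, buffer)
```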

load_states(states: Dict[str, torch.Tensor], strict: bool = True, device=None) → None

Load hidden states into the layers in the model.

Parameters:
  • states (Dict) – dictionary returned by save_states() or read back with torch.load().

  • strict (bool, default=True) – if True, raise an error when any states are missing or unexpected.

  • device (str, default=None) – target device for the loaded states. Automatically detected if set to None.

Examples:

>>> states = torch.load("model_states.pt")
>>> model.load_states(states)
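The strict and device semantics can be sketched the same way. Here `strict=True` rejects both missing and unexpected keys, and `device=None` falls back to the device of the model's first parameter (CPU if it has none). That fallback rule is an assumption, as are the names `ToyLayer`, `_state_names`, and `load_states_sketch`; none of them are traceTorch internals.

```python
import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a traceTorch layer with one hidden state, H."""

    _state_names = ("H",)  # assumed convention for listing state attributes

    def __init__(self, num_neurons: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_neurons))
        self.H = None


def load_states_sketch(model, states, strict=True, device=None):
    # Every state slot the model declares, keyed the way save_states() names them.
    slots = {}
    for path, module in model.named_modules():
        for name in getattr(module, "_state_names", ()):
            slots[f"{path}.{name}" if path else name] = (module, name)

    if strict:
        missing = sorted(slots.keys() - states.keys())
        unexpected = sorted(states.keys() - slots.keys())
        if missing or unexpected:
            raise KeyError(f"missing states: {missing}, unexpected states: {unexpected}")

    if device is None:
        # Assumption: use the device the model's parameters live on (CPU if none).
        param = next(model.parameters(), None)
        device = param.device if param is not None else torch.device("cpu")

    for key, tensor in states.items():
        if key in slots:  # with strict=False, unknown keys are silently skipped
            module, name = slots[key]
            setattr(module, name, tensor.to(device))


model = nn.Sequential(ToyLayer(4), ToyLayer(4))  # state keys "0.H" and "1.H"
states = {"0.H": torch.zeros(2, 4), "1.H": torch.ones(2, 4)}
load_states_sketch(model, states)
```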

zero_states() → None

Set all hidden states to None across the entire model tree.

Recursively traverses the model hierarchy to find all traceTorch layers and sets their hidden states to None, so that each state is lazily re-initialized with the correct tensor shape on the next forward pass.

Notes

  • Traverses traceTorch models, PyTorch modules and Python containers.

  • Only affects traceTorch layers that implement state management.

  • Used for resetting model states between batches or episodes.
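The traversal and lazy re-initialization can be sketched as follows. This assumes a hypothetical `_state_names` convention and a layer that creates its state with `torch.zeros_like` on first use; `ToyLeakyLayer`, `iter_stateful`, and `zero_states_sketch` are illustrative names, not traceTorch internals. The walk covers registered submodules as well as layers held in plain Python lists, tuples, and dicts.

```python
import torch
from torch import nn


class ToyLeakyLayer(nn.Module):
    """Hypothetical stand-in for a traceTorch layer with one hidden state, H."""

    _state_names = ("H",)  # assumed convention for listing state attributes

    def __init__(self, decay: float = 0.9):
        super().__init__()
        self.decay = decay
        self.H = None  # uninitialized until the first forward pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.H is None:
            # Lazy initialization: the state takes its shape from the input.
            self.H = torch.zeros_like(x)
        self.H = self.decay * self.H + x
        return self.H


def iter_stateful(obj, seen=None):
    """Yield every stateful layer reachable through modules, lists, tuples and dicts."""
    seen = set() if seen is None else seen
    if id(obj) in seen:
        return
    seen.add(id(obj))
    if isinstance(obj, nn.Module):
        if hasattr(obj, "_state_names"):
            yield obj
        # vars() covers registered submodules (via _modules) and plain attributes.
        for value in vars(obj).values():
            if isinstance(value, (nn.Module, list, tuple, dict)):
                yield from iter_stateful(value, seen)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from iter_stateful(item, seen)
    elif isinstance(obj, dict):
        for item in obj.values():
            yield from iter_stateful(item, seen)


def zero_states_sketch(model) -> None:
    for layer in iter_stateful(model):
        for name in layer._state_names:
            setattr(layer, name, None)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.first = ToyLeakyLayer()
        self.extra = [ToyLeakyLayer()]  # plain Python list: not registered by nn.Module

    def forward(self, x):
        return self.extra[0](self.first(x))


model = ToyModel()
model(torch.ones(8, 4))  # states now have shape (8, 4)
zero_states_sketch(model)
model(torch.ones(2, 4))  # states re-created lazily with shape (2, 4)
```

Because the state is rebuilt from the incoming tensor, the batch size can change between episodes without any manual resizing.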

detach_states() → None

Detach all hidden states from the computation graph across the entire model tree.

Recursively traverses the model hierarchy to find all traceTorch layers and detaches their hidden states from the computation graph. This stops gradients from flowing back through earlier time steps, which enables online learning.

Notes

  • Traverses traceTorch models, PyTorch modules and Python containers.

  • Only affects traceTorch layers that implement state management.

  • Used for online learning or truncated backpropagation, when you want to break temporal gradients.
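A typical truncated-backpropagation loop detaches the states after each optimizer step, so that the next window starts a fresh graph. The sketch below uses a hypothetical recurrent layer (`ToyRecurrent`) and a hypothetical `detach_states_sketch` helper in place of traceTorch classes; the window length of 5 is arbitrary.

```python
import torch
from torch import nn


class ToyRecurrent(nn.Module):
    """Hypothetical layer: H_t = tanh(W x_t + U H_{t-1}); H is its only state."""

    _state_names = ("H",)  # assumed convention for listing state attributes

    def __init__(self, num_neurons: int):
        super().__init__()
        self.inp = nn.Linear(num_neurons, num_neurons)
        self.rec = nn.Linear(num_neurons, num_neurons, bias=False)
        self.H = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.H is None:
            self.H = torch.zeros_like(x)
        self.H = torch.tanh(self.inp(x) + self.rec(self.H))
        return self.H


def detach_states_sketch(model: nn.Module) -> None:
    for module in model.modules():
        for name in getattr(module, "_state_names", ()):
            value = getattr(module, name)
            if value is not None:
                setattr(module, name, value.detach())


torch.manual_seed(0)
layer = ToyRecurrent(4)
opt = torch.optim.SGD(layer.parameters(), lr=0.01)
sequence = torch.randn(20, 3, 4)  # (time, batch, features)
window = 5  # truncation length

for start in range(0, sequence.shape[0], window):
    loss = 0.0
    for x_t in sequence[start:start + window]:
        loss = loss + layer(x_t).pow(2).mean()
    opt.zero_grad()
    loss.backward()  # gradients reach back at most `window` steps
    opt.step()
    detach_states_sketch(layer)  # cut the graph before the next window
```

Without the detach call, the second window's backward() would reach into the previous window's graph, which PyTorch has already freed, and fail.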

TTcompile()

Compiles all layers for inference by pre-computing parameters.

Recursively traverses the model hierarchy to find all traceTorch layers and compiles their parameters, so that a trained model no longer recomputes them on every forward pass.

TTdecompile()

Decompiles all layers to restore training capabilities.

Recursively traverses the model hierarchy to find all traceTorch layers and decompiles their parameters. This allows a compiled model to be trained once again.

Layer

class tracetorch.core.Layer(*args: Any, **kwargs: Any)

Bases: Module

The base class for all traceTorch layers. It handles state management, parameter initialization, compilation and decompilation, and moving tensors so that the layer operates along the target dimension.

Parameters:
  • num_neurons (int) – the number of neurons in the layer. This value sets the size of any hidden states the layer initializes and of any parameters registered via the traceTorch methods.

  • dim (int, default=-1) – the dimension along which the layer operates.
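A minimal sketch of how a layer could honor `dim`: move the neuron axis to the end with `torch.movedim`, apply a per-neuron weight of shape `(num_neurons,)` by broadcasting, then move the axis back. `ToyScale` is a hypothetical layer, not a traceTorch class; with `dim=1` it treats the channel axis of a `(batch, channels, height, width)` tensor as the neurons.

```python
import torch
from torch import nn


class ToyScale(nn.Module):
    """Hypothetical per-neuron layer: multiplies each neuron along `dim` by its own weight."""

    def __init__(self, num_neurons: int, dim: int = -1):
        super().__init__()
        self.num_neurons = num_neurons
        self.dim = dim
        # One weight per neuron; num_neurons fixes the parameter size.
        self.weight = nn.Parameter(torch.arange(1.0, num_neurons + 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[self.dim] != self.num_neurons:
            raise ValueError(
                f"expected {self.num_neurons} neurons along dim {self.dim}, "
                f"got {x.shape[self.dim]}"
            )
        x = torch.movedim(x, self.dim, -1)  # neuron axis last, so the weight broadcasts
        y = x * self.weight
        return torch.movedim(y, -1, self.dim)  # restore the caller's layout


layer = ToyScale(num_neurons=3, dim=1)
x = torch.ones(2, 3, 5, 5)  # (batch, channels, height, width)
y = layer(x)  # channel c is scaled by c + 1
```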

detach_states() → None

Detach every state tensor that is not None from the computation graph.

zero_states() → None

Set all initialized states to None.

TTcompile() → None

Compile the layer for inference by pre-computing parameters.

Notes

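All parameters registered via _register_parameter are precomputed: each parameter's activation_fn is applied once and the result is baked into the stored parameter. Proper training is not possible on a compiled layer, since updates would act on the activated value rather than on the raw parameter.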
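A minimal sketch of compilation, assuming a parameter whose activation_fn is `torch.sigmoid` and assuming the activated value is written back into the same Parameter. `ToyDecayLayer`, `raw_decay`, `decay()`, and `compile_params` are hypothetical names. Compiling leaves the effective value unchanged; only the per-call sigmoid disappears.

```python
import torch
from torch import nn


class ToyDecayLayer(nn.Module):
    """Hypothetical layer with one parameter whose activation_fn is a sigmoid."""

    def __init__(self, num_neurons: int):
        super().__init__()
        # Raw, unconstrained parameter; the effective decay is sigmoid(raw).
        self.raw_decay = nn.Parameter(torch.zeros(num_neurons))
        self.activation_fn = torch.sigmoid
        self.compiled = False

    def decay(self) -> torch.Tensor:
        # Uncompiled: apply the activation on every call.
        # Compiled: the stored value is already activated.
        return self.raw_decay if self.compiled else self.activation_fn(self.raw_decay)

    @torch.no_grad()
    def compile_params(self) -> None:
        if not self.compiled:
            # Bake the activation in: the parameter now holds sigmoid(raw).
            self.raw_decay.copy_(self.activation_fn(self.raw_decay))
            self.compiled = True

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        return self.decay() * h + x


layer = ToyDecayLayer(num_neurons=3)
before = layer.decay().detach().clone()
layer.compile_params()
after = layer.decay().detach()  # same values, no sigmoid call
```

This also shows why a compiled layer should not be trained: an optimizer step would update sigmoid(raw) directly, so nothing would keep the decay inside (0, 1).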

TTdecompile() → None

Decompile the layer to restore training capabilities.

Notes

All parameters registered via _register_parameter are decompiled: each parameter's inverse_fn is applied to the compiled value to re-create its raw version.
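Decompiling reverses the bake with inverse_fn. The sketch assumes `torch.logit` as the inverse of a sigmoid activation; `ToyDecayLayer`, `raw_decay`, `compile_params`, and `decompile_params` are hypothetical names, not traceTorch methods.

```python
import torch
from torch import nn


class ToyDecayLayer(nn.Module):
    """Hypothetical layer: activation_fn = sigmoid, inverse_fn = logit."""

    def __init__(self, raw: torch.Tensor):
        super().__init__()
        self.raw_decay = nn.Parameter(raw.clone())
        self.activation_fn = torch.sigmoid
        self.inverse_fn = torch.logit  # inverse of sigmoid on (0, 1)
        self.compiled = False

    @torch.no_grad()
    def compile_params(self) -> None:
        if not self.compiled:
            self.raw_decay.copy_(self.activation_fn(self.raw_decay))
            self.compiled = True

    @torch.no_grad()
    def decompile_params(self) -> None:
        if self.compiled:
            # Recover the raw value so updates act on the unconstrained parameter again.
            self.raw_decay.copy_(self.inverse_fn(self.raw_decay))
            self.compiled = False


raw = torch.tensor([-2.0, 0.0, 3.0])
layer = ToyDecayLayer(raw)
layer.compile_params()    # raw_decay now holds sigmoid(raw)
layer.decompile_params()  # raw_decay is back to raw, up to rounding
```

The round trip is exact only up to floating-point rounding, and a value that saturates to exactly 0 or 1 after the activation cannot be recovered, because its logit is infinite.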