MNIST Examples

The MNIST examples live in examples/mnist. They all train three comparable models: an SNN, a GRU-based RNN, and an S6-based SSM. This makes the examples useful for seeing how traceTorch layers can be swapped while the training loop stays nearly identical.

Setup:

cd examples/mnist
pip install -r requirements.txt

Rate-coded MNIST

Run:

python rate_coded.py

The rate-coded script presents the same MNIST image for several timesteps. At each timestep, the input image is converted to a binary sample:

for t in range(num_timesteps):
    spk_image = torch.bernoulli(image)
    output = model(spk_image)
    loss = loss_fn(output, label)
    running_loss += loss

The SNN model is:

class RateSNN(tt.Model):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 128),
            tt.snn.LIB(128, beta=torch.rand(128), threshold=torch.rand(128)),
            nn.Linear(128, 128),
            tt.snn.LIB(128, beta=torch.rand(128), threshold=torch.rand(128)),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.net(x)

The important traceTorch mechanics are model.zero_states() before each image batch and repeated forward calls over the timestep loop.

Sequential MNIST

Run:

python sequential.py

The sequential script turns each image into a sequence of patches. A 4x4 patch becomes one timestep with 16 input features. The model sees a scrambled sequence of local image observations and must produce a final classification.

The SNN uses recurrent binary layers:

self.net = nn.Sequential(
    nn.Linear(kernel_size ** 2, 128),
    tt.snn.RLIB(128, beta=torch.rand(128), gamma=torch.rand(128), threshold=torch.rand(128)),
    nn.Linear(128, 128),
    tt.snn.RLIB(128, beta=torch.rand(128), gamma=torch.rand(128), threshold=torch.rand(128)),
    nn.Linear(128, 10),
)

This is a good example of when R layers make sense: the output at one patch can influence the membrane update at the next patch.

Noisy MNIST

Run:

python noisy.py

The noisy script repeatedly corrupts the same input image and trains the model on several noisy observations:

noise_level = torch.rand_like(image) ** 0.5
noisy_image = image * noise_level + (torch.randn_like(image) + 0.5) * (1 - noise_level)
output = model(noisy_image)

This example is useful for understanding traceTorch as a temporal evidence accumulator. The model receives multiple imperfect observations and can use its hidden states to build a more stable representation over time.

Reading the plots

Each script plots training and evaluation loss/accuracy for the SNN, RNN, and SSM variants. The examples are intended for comparison and code clarity, not for state-of-the-art MNIST results.