MNIST Classification with SIMΒΆ

The following example shows how to use the idl.sim.sim.SIM class to train a simple classification model on the MNIST dataset.

First, we need to train a simple Feedforward Neural Network on the MNIST dataset.

import torch
import torchvision

# Load MNIST dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', train=True, download=True),
    batch_size=32
)

# Define model
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 10)
)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train model
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()

# Save model
torch.save(model.state_dict(), 'model.pt')

Now, we can use the idl.sim.sim.SIM class to train the Implicit Model on the MNIST dataset.

import torch
import torchvision
from idl.sim import SIM
from idl.sim.solvers import CVXSolver

# Load MNIST dataset. The dataset is quite large, but we only need a small subset to train our implicit model with the state-driven method.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', train=True, download=True),
    batch_size=32
)
selected_indices = random.sample(
    range(len(train_loader.dataset)), 2000
)
subset = Subset(train_loader.dataset, selected_indices)
subset_loader = DataLoader(subset, batch_size=1000, shuffle=True)

# Load pretrained explicit model
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 10)
)
model.load_state_dict(torch.load('model.pt'))

# Define SIM model
sim = SIM(activation_fn=torch.nn.functional.relu, device="cuda", dtype=torch.float32)

# Define solver
solver = CVXSolver()

# Train SIM
sim.train(solver=solver, model=model, dataloader=subset_loader)