Projected GD Solver with Low-rank Regularization

class torchidl.sim.solvers.projected_gd_lowrank.ProjectedGDLowRankSolver(rank=None, num_epoch=10000, lambda_z=1e-06, lr=0.001, verbose_epoch=100, regen_states=False, tol=1e-06)

Trains a State-driven Implicit Model (SIM) using projected gradient descent, where the projection forces A to be low-rank and well-posed. A and B are solved with projected gradient descent; C and D are solved with NumPy's least-squares solver.
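The procedure described above can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation. It assumes the usual implicit-model state equation Z = A X + B U with samples stored as columns, reads "well-posed" as every row of A having L1 norm at most kappa, and omits the lambda_z regularization term. The helper names (truncate_rank, rescale_rows, projected_gd) are hypothetical. The projection here is a simple feasibility map (truncated SVD followed by row rescaling), not an exact Euclidean projection.

```python
import numpy as np

def truncate_rank(A, rank):
    """Best rank-`rank` approximation of A via truncated SVD."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    return (U[:, :rank] * s[:rank]) @ Vt[:rank, :]

def rescale_rows(A, kappa):
    """Shrink any row whose L1 norm exceeds kappa (assumed well-posedness test).

    Row scaling is a left multiplication by a diagonal matrix, so it cannot
    increase the rank of A.
    """
    row_l1 = np.abs(A).sum(axis=1, keepdims=True)
    scale = np.minimum(1.0, kappa / np.maximum(row_l1, 1e-12))
    return A * scale

def projected_gd(X, U, Z, rank, kappa=0.9, lr=0.05, num_epoch=200):
    """Fit Z ~ A X + B U, mapping A back to the feasible set after every step."""
    n, m = X.shape
    q = U.shape[0]
    A = np.zeros((n, n))
    B = np.zeros((n, q))
    for _ in range(num_epoch):
        R = A @ X + B @ U - Z            # residual, shape (n, m)
        A = A - lr * (R @ X.T) / m       # gradient step on A
        B = B - lr * (R @ U.T) / m       # gradient step on B
        A = rescale_rows(truncate_rank(A, rank), kappa)
    return A, B

# Synthetic data generated from a feasible (rank-2, row-bounded) A.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 50))
U = rng.standard_normal((3, 50))
A_true = rescale_rows(truncate_rank(rng.standard_normal((5, 5)), 2), 0.9)
B_true = rng.standard_normal((5, 3))
Z = A_true @ X + B_true @ U

A_hat, B_hat = projected_gd(X, U, Z, rank=2, kappa=0.9)
```

Because the row rescaling runs last, the returned A_hat satisfies both constraints by construction, whatever the quality of the fit.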

Parameters:
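  • rank (int, optional) – Rank of A. Defaults to None.

  • num_epoch (int, optional) – Number of epochs used to train A and B. Defaults to 10000.

  • lambda_z (float, optional) – Lasso regularization parameter for Z. Defaults to 1e-6.

  • lr (float, optional) – Learning rate. Defaults to 0.001.

  • verbose_epoch (int, optional) – Interval, in epochs, at which the loss is printed. Defaults to 100.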

  • regen_states (bool, optional) – Whether to regenerate states. Defaults to False.

  • tol (float, optional) – Tolerance for zeroing out weights. Defaults to 1e-6.
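As an illustration of what a zeroing tolerance such as tol typically does, the sketch below sets entries whose magnitude falls below the threshold to exactly zero. The helper zero_small_entries is hypothetical, and the strict comparison is an assumption; the library may apply the threshold differently.

```python
import numpy as np

def zero_small_entries(W, tol=1e-6):
    """Return a copy of W with entries of magnitude below tol set to zero."""
    W = np.asarray(W, dtype=float).copy()
    W[np.abs(W) < tol] = 0.0
    return W

A = np.array([[0.5, 3e-7],
              [-8e-9, -0.2]])
A_sparse = zero_small_entries(A, tol=1e-6)
```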

solve(X, U, Z, Y, model_config, plot_loss=False)

Solve the implicit model and force A to be low-rank.

Parameters:
  • X (np.ndarray) – Post-activation array.

  • U (np.ndarray) – Input array.

  • Z (np.ndarray) – Pre-activation array.

  • Y (np.ndarray) – Output array.

  • model_config (Dict[str, Any]) – Implicit model configuration, with the following keys:
      - activation_fn (Callable): Activation function used by the implicit model.
      - device (str): Implicit model’s device.
      - atol (float): Equilibrium function’s tolerance.
      - kappa (float): Well-posedness condition parameter.

  • plot_loss (bool, optional) – Whether to plot the loss. Defaults to False.
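A model_config dictionary with the four documented keys might look like the following. The specific values for device, atol and kappa are illustrative assumptions, not defaults taken from the library.

```python
import torch

model_config = {
    "activation_fn": torch.nn.functional.relu,  # activation used by the implicit model
    "device": "cpu",                            # illustrative; e.g. "cuda" on a GPU machine
    "atol": 1e-6,                               # illustrative equilibrium tolerance
    "kappa": 0.99,                              # illustrative well-posedness parameter
}
```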

Returns:

Implicit model’s parameters.

Return type:

A, B, C, D (np.ndarray)
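The class description states that C and D are obtained with NumPy's least-squares solver. A minimal sketch of such a solve is shown below, assuming the output equation Y = C X + D U with samples stored as columns. The function name solve_output_maps and the array layout are assumptions made for illustration, not the library's internals.

```python
import numpy as np

def solve_output_maps(X, U, Y):
    """Least-squares fit of Y ~ C X + D U, with samples stored as columns."""
    n = X.shape[0]
    W = np.vstack([X, U])                          # stacked regressors, (n + q, m)
    # Solve W.T @ [C D].T ~ Y.T in the least-squares sense.
    CD_T, *_ = np.linalg.lstsq(W.T, Y.T, rcond=None)
    CD = CD_T.T                                    # shape (p, n + q)
    return CD[:, :n], CD[:, n:]

# Noise-free synthetic data, so the fit should recover C and D.
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 40))
U = rng.standard_normal((3, 40))
C_true = rng.standard_normal((2, 4))
D_true = rng.standard_normal((2, 3))
Y = C_true @ X + D_true @ U

C, D = solve_output_maps(X, U, Y)
```

Stacking X and U lets a single lstsq call recover both matrices at once; the split at column n separates the state-to-output map C from the input-to-output map D.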

Example usage:

import torch
from torchidl import SIM
from torchidl import ProjectedGDLowRankSolver

# Load dataset
dataloader = ...

# Load a pretrained explicit model
explicit_model = ...
explicit_model.load_state_dict(torch.load("checkpoint.pt"))

# Define the SIM model
sim = SIM(activation_fn=torch.nn.functional.relu, device="cuda", dtype=torch.float32)

# Define the solver and solve the state-driven training problem
solver = ProjectedGDLowRankSolver()
sim.train(solver=solver, model=explicit_model, dataloader=dataloader)