Projected GD Solver with Low-rank Regularization
- class torchidl.sim.solvers.projected_gd_lowrank.ProjectedGDLowRankSolver(rank=None, num_epoch=10000, lambda_z=1e-06, lr=0.001, verbose_epoch=100, regen_states=False, tol=1e-06)
Trains a State-driven Implicit Model (SIM) using projected gradient descent, which forces A to be low-rank and well-posed. A and B are solved with projected gradient descent; C and D are solved with NumPy's least-squares solver.
- Parameters:
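rank (int, optional) – Rank of A. Defaults to None.
num_epoch (int, optional) – Number of epochs used to train A and B. Defaults to 10000.
lambda_z (float, optional) – Lasso regularization parameter for Z. Defaults to 1e-6.
lr (float, optional) – Learning rate. Defaults to 0.001.
verbose_epoch (int, optional) – Interval, in epochs, at which the loss is printed. Defaults to 100.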
regen_states (bool, optional) – Whether to regenerate states. Defaults to False.
tol (float, optional) – Tolerance for zeroing out weights. Defaults to 1e-6.
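The projection step described above can be illustrated with a minimal NumPy sketch. It is not the solver's actual implementation: it assumes that "low-rank" means a truncated SVD to `rank` and that "well-posed" means bounding the infinity norm of A (maximum absolute row sum) by `kappa`. The helper names `project_low_rank` and `rescale_rows_inf_norm` are hypothetical and are not part of torchidl.

```python
import numpy as np


def project_low_rank(A, rank):
    """Closest matrix of rank <= `rank` in Frobenius norm (truncated SVD)."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    # Keep only the `rank` largest singular values and their vectors.
    return (U[:, :rank] * s[:rank]) @ Vt[:rank, :]


def rescale_rows_inf_norm(A, kappa):
    """Shrink rows so that the max absolute row sum is at most `kappa`.

    Assumption: well-posedness is enforced as ||A||_inf <= kappa.
    Scaling rows by positive factors cannot increase the rank.
    """
    row_sums = np.abs(A).sum(axis=1, keepdims=True)
    scale = np.minimum(1.0, kappa / np.maximum(row_sums, 1e-12))
    return A * scale


# One projection after a (hypothetical) gradient step on A.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 8))
A_proj = rescale_rows_inf_norm(project_low_rank(A, rank=3), kappa=0.95)

print("rank:", np.linalg.matrix_rank(A_proj))
print("inf-norm:", np.linalg.norm(A_proj, ord=np.inf))
```

Applying the row rescaling after the truncated SVD yields a matrix that satisfies both constraints, although it is not necessarily the exact Euclidean projection onto their intersection.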
- solve(X, U, Z, Y, model_config, plot_loss=False)
Solve the implicit model and force A to be low-rank.
- Parameters:
X (np.ndarray) – Post-activation array.
U (np.ndarray) – Input array.
Z (np.ndarray) – Pre-activation array.
Y (np.ndarray) – Output array.
model_config (Dict[str, Any]) – Implicit model configuration, with the following keys:
  - activation_fn (Callable): Activation function used by the implicit model.
  - device (str): Implicit model’s device.
  - atol (float): Equilibrium function’s tolerance.
  - kappa (float): Well-posedness condition parameter.
plot_loss (bool, optional) – Whether to plot the loss. Defaults to False.
- Returns:
Implicit model’s parameters.
- Return type:
A, B, C, D (np.ndarray)
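The least-squares step for C and D can be sketched as follows. This is an illustration under stated assumptions, not the solver's code: it assumes the output relation Y ≈ C X + D U and that samples are stored as columns, neither of which is specified on this page. The function name `solve_output_matrices` is hypothetical.

```python
import numpy as np


def solve_output_matrices(X, U, Y):
    """Least-squares fit of Y ≈ C X + D U (assumed layout: one sample per column)."""
    n = X.shape[0]
    XU = np.vstack([X, U])  # stacked regressors, shape (n + p, m)
    # lstsq solves XU.T @ W ≈ Y.T, so W holds [C D] transposed.
    W, *_ = np.linalg.lstsq(XU.T, Y.T, rcond=None)
    CD = W.T  # shape (q, n + p)
    return CD[:, :n], CD[:, n:]


# Synthetic check: recover known C and D from noiseless data.
rng = np.random.default_rng(0)
X = np.maximum(rng.standard_normal((4, 50)), 0.0)  # post-activation states (ReLU-like)
U = rng.standard_normal((3, 50))                   # inputs
C_true = rng.standard_normal((2, 4))
D_true = rng.standard_normal((2, 3))
Y = C_true @ X + D_true @ U

C_hat, D_hat = solve_output_matrices(X, U, Y)
print("max error in C:", np.abs(C_hat - C_true).max())
print("max error in D:", np.abs(D_hat - D_true).max())
```

Because this subproblem is linear in C and D, it has a closed-form least-squares solution and needs no gradient iterations, which is consistent with the solver handling C and D separately from A and B.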
Example usage:
import torch
from torchidl import SIM
from torchidl import ProjectedGDLowRankSolver

# Load dataset
dataloader = ...

# Load a pretrained explicit model
explicit_model = ...
explicit_model.load_state_dict(torch.load("checkpoint.pt"))

# Define the SIM model
sim = SIM(activation_fn=torch.nn.functional.relu, device="cuda", dtype=torch.float32)

# Define the solver and solve the state-driven training problem
solver = ProjectedGDLowRankSolver()
sim.train(solver=solver, model=explicit_model, dataloader=dataloader)