Least Square Solver

class torchidl.sim.solvers.least_square.LeastSquareSolver(regen_states=False, tol=1e-06)

Solve using numpy.linalg.lstsq. Note: this solver is fast, but it cannot enforce the wellposedness condition (norm(A) <= kappa).
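The note above can be illustrated with a minimal sketch. It is not the library's implementation: it assumes the states are stored with samples as columns, that the pre-activations satisfy Z ≈ A X + B U, and that norm(A) refers to the infinity norm. The helper names fit_ab_lstsq and satisfies_wellposedness are hypothetical. The sketch shows that an unconstrained least-squares fit returns whatever A minimizes the residual, so the condition can only be checked afterwards, not enforced.

```python
import numpy as np

def fit_ab_lstsq(X, U, Z):
    """Fit Z ≈ A @ X + B @ U by ordinary least squares.

    Assumed layout (not confirmed by the docs): columns are samples,
    so X is (n, m), U is (p, m) and Z is (n, m).
    """
    n = X.shape[0]
    M = np.vstack([X, U])                            # (n + p, m)
    # lstsq solves M.T @ W.T ≈ Z.T for W = [A B]
    W_T, *_ = np.linalg.lstsq(M.T, Z.T, rcond=None)  # (n + p, n)
    W = W_T.T                                        # (n, n + p)
    return W[:, :n], W[:, n:]

def satisfies_wellposedness(A, kappa):
    # Assumption: norm(A) is the infinity norm (max absolute row sum).
    return np.linalg.norm(A, ord=np.inf) <= kappa

# Synthetic data generated from known A_true and B_true
rng = np.random.default_rng(0)
n, p, m = 4, 3, 50
A_true = rng.normal(size=(n, n))
B_true = rng.normal(size=(n, p))
X = rng.normal(size=(n, m))
U = rng.normal(size=(p, m))
Z = A_true @ X + B_true @ U

A, B = fit_ab_lstsq(X, U, Z)
print("recovered A, B:", np.allclose(A, A_true), np.allclose(B, B_true))
print("wellposed for kappa=0.99:", satisfies_wellposedness(A, kappa=0.99))
```

Nothing in the fit constrains the norm of A, so when the condition matters it has to be verified after solving or handled by a different solver.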

Parameters:
  • regen_states (bool, optional) – Whether to regenerate the state data after solving for A and B, so that C and D are solved exactly. Defaults to False.

  • tol (float, optional) – Zero out weights that are less than tol. Defaults to 1e-6.
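As a rough illustration of the tol parameter, the sketch below zeroes out small entries of a weight matrix. It assumes that "less than tol" is meant in absolute value, which the documentation does not state explicitly; zero_small_weights is a hypothetical helper, not part of torchidl.

```python
import numpy as np

def zero_small_weights(W, tol=1e-6):
    """Return a copy of W with entries of magnitude below tol set to zero.

    Assumption: the comparison uses the absolute value of each weight.
    """
    W = W.copy()
    W[np.abs(W) < tol] = 0.0
    return W

W = np.array([[0.5, 3e-7],
              [-2e-9, -1.2]])
W_sparse = zero_small_weights(W)
print(W_sparse)
```

Thresholding like this removes numerical noise left by the least-squares solve and makes the resulting matrices sparser.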

solve(X, U, Z, Y, model_config)

Solve with numpy.linalg.lstsq to get an implicit model.

Parameters:
  • X (np.ndarray) – Post-activation array.

  • U (np.ndarray) – Input array.

  • Z (np.ndarray) – Pre-activation array.

  • Y (np.ndarray) – Output array.

  • model_config (Dict[str, Any]) – Model configuration.
      - activation_fn (Callable): Activation function used by the implicit model.
      - device (str): Implicit model’s device.
      - atol (float): Equilibrium function’s tolerance.
      - kappa (float): Wellposedness condition parameter.

Returns:

Implicit model’s parameters.

Return type:

A, B, C, D (np.ndarray)
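To make the role of the returned matrices concrete, here is a hedged sketch of how A, B, C, D could be used for prediction. It assumes the usual implicit-model form, in which the state solves X = phi(A X + B U) and the output is Y = C X + D U, with samples stored as columns. The function implicit_forward, the fixed-point loop, and the max_iter argument are illustrative assumptions, not the library's API.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def implicit_forward(A, B, C, D, U, phi=relu, atol=1e-6, max_iter=500):
    """Evaluate an implicit model by fixed-point iteration (illustrative only).

    Iterates X <- phi(A @ X + B @ U) until successive iterates differ by
    less than atol, then returns Y = C @ X + D @ U and the state X.
    """
    X = np.zeros((A.shape[0], U.shape[1]))
    for _ in range(max_iter):
        X_next = phi(A @ X + B @ U)
        converged = np.max(np.abs(X_next - X)) < atol
        X = X_next
        if converged:
            break
    return C @ X + D @ U, X

rng = np.random.default_rng(0)
n, p, q, m = 5, 3, 2, 8
A = rng.uniform(-1.0, 1.0, size=(n, n))
A *= 0.5 / np.linalg.norm(A, ord=np.inf)  # rescale so the infinity norm is 0.5
B = rng.normal(size=(n, p))
C = rng.normal(size=(q, n))
D = rng.normal(size=(q, p))
U = rng.normal(size=(p, m))

Y_hat, X_eq = implicit_forward(A, B, C, D, U)
print(Y_hat.shape)
```

In this sketch A is rescaled so that its infinity norm is below 1, which makes the ReLU iteration a contraction. A matrix returned by an unconstrained least-squares fit carries no such guarantee, so the iteration may fail to converge for it.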

Example usage:

import torch
from torchidl import SIM
from torchidl import LeastSquareSolver

# Load dataset
dataloader = ...

# Load a pretrained explicit model
explicit_model = ...
explicit_model.load_state_dict(torch.load("checkpoint.pt"))

# Define the SIM model
sim = SIM(activation_fn=torch.nn.functional.relu, device="cuda", dtype=torch.float32)

# Define the solver and solve the state-driven training problem
solver = LeastSquareSolver()
sim.train(solver=solver, model=explicit_model, dataloader=dataloader)