Least Square Solver
- class torchidl.sim.solvers.least_square.LeastSquareSolver(regen_states=False, tol=1e-06)
Solves the state-driven training problem with numpy.linalg.lstsq. Note: this solver is fast, but it cannot enforce the well-posedness condition (norm(A) <= kappa).
- Parameters:
regen_states (bool) – Defaults to False.
tol (float) – Defaults to 1e-06.
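Because an unconstrained least-squares fit places no bound on A, it can be worth checking the well-posedness condition after solving. The sketch below assumes that norm(A) means the induced infinity norm (maximum absolute row sum), a common choice for implicit models; the helper name `is_well_posed` is hypothetical and not part of torchidl.

```python
import numpy as np


def is_well_posed(A: np.ndarray, kappa: float) -> bool:
    """Hypothetical helper: check norm(A) <= kappa.

    Assumes norm(A) is the induced infinity norm, i.e. the maximum
    absolute row sum of A (np.linalg.norm with ord=np.inf).
    """
    return bool(np.linalg.norm(A, ord=np.inf) <= kappa)


# The absolute row sums of A are 0.7 and 0.9, so norm(A) is 0.9.
A = np.array([[0.5, -0.2],
              [0.1, 0.8]])
print(is_well_posed(A, kappa=0.95))  # True
print(is_well_posed(A, kappa=0.5))   # False
```

If the check fails, the matrices returned by this solver violate the condition, and a solver that enforces the constraint would be needed instead.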
- solve(X, U, Z, Y, model_config)
Solve for the implicit model’s parameters with numpy.linalg.lstsq.
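For reference, here is a sketch of the fitting problem, assuming the standard implicit-model form in which the pre-activation Z, post-activation X, input U and output Y are related as below (one column per sample, with φ the activation_fn). Under that assumption, fitting the four matrices splits into two ordinary least-squares problems that share the regressor [X; U]:

```latex
\begin{aligned}
Z &= AX + BU, \qquad X = \phi(Z), \qquad Y = CX + DU, \\
[A \;\; B] &\in \arg\min_{M} \left\| M \begin{bmatrix} X \\ U \end{bmatrix} - Z \right\|_F^2,
\qquad
[C \;\; D] \in \arg\min_{N} \left\| N \begin{bmatrix} X \\ U \end{bmatrix} - Y \right\|_F^2 .
\end{aligned}
```

Neither problem places any constraint on A, which is why the condition norm(A) <= kappa is not enforced by this solver.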
- Parameters:
X (np.ndarray) – Post-activation array.
U (np.ndarray) – Input array.
Z (np.ndarray) – Pre-activation array.
Y (np.ndarray) – Output array.
model_config (Dict[str, Any]) – Model configuration with the following keys:
  - activation_fn (Callable): Activation function used by the implicit model.
  - device (str): Implicit model’s device.
  - atol (float): Equilibrium function’s tolerance.
  - kappa (float): Well-posedness condition parameter.
- Returns:
Implicit model’s parameters.
- Return type:
A, B, C, D (np.ndarray)
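A minimal NumPy sketch of the two least-squares fits, assuming a features-by-samples layout (X and Z are (n, m), U is (p, m), Y is (q, m)). This is an illustration of the technique, not torchidl's implementation; the function name `solve_least_squares` is hypothetical and it ignores model_config entirely.

```python
import numpy as np


def solve_least_squares(X, U, Z, Y):
    """Hypothetical helper: fit A, B, C, D so that Z ~ A X + B U and Y ~ C X + D U.

    Assumed layout: one column per sample, so X and Z are (n, m),
    U is (p, m) and Y is (q, m).
    """
    n = X.shape[0]
    # Both problems share the same regressor: the stacked matrix [X; U].
    W = np.vstack([X, U])  # shape (n + p, m)
    # np.linalg.lstsq solves a @ x ~ b, so fit the transposed systems
    # W.T @ [A B].T ~ Z.T and W.T @ [C D].T ~ Y.T.
    AB_T, *_ = np.linalg.lstsq(W.T, Z.T, rcond=None)
    CD_T, *_ = np.linalg.lstsq(W.T, Y.T, rcond=None)
    AB, CD = AB_T.T, CD_T.T
    return AB[:, :n], AB[:, n:], CD[:, :n], CD[:, n:]


# Synthetic check: generate data from known matrices and recover them.
rng = np.random.default_rng(0)
n, p, q, m = 4, 3, 2, 50
A_true = rng.standard_normal((n, n))
B_true = rng.standard_normal((n, p))
C_true = rng.standard_normal((q, n))
D_true = rng.standard_normal((q, p))
X = rng.standard_normal((n, m))  # stand-in for post-activation states
U = rng.standard_normal((p, m))
Z = A_true @ X + B_true @ U
Y = C_true @ X + D_true @ U

A, B, C, D = solve_least_squares(X, U, Z, Y)
print(all(np.allclose(E, E_true) for E, E_true in
          [(A, A_true), (B, B_true), (C, C_true), (D, D_true)]))  # True
```

With more samples than stacked features (here 50 versus 7) and noise-free data, the fit recovers the generating matrices; in practice X would be the post-activation states of the explicit model rather than random data.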
Example usage:
```python
import torch
from torchidl import SIM, LeastSquareSolver

# Load dataset
dataloader = ...

# Load a pretrained explicit model
explicit_model = ...
explicit_model.load_state_dict(torch.load("checkpoint.pt"))

# Define the SIM model
sim = SIM(activation_fn=torch.nn.functional.relu, device="cuda", dtype=torch.float32)

# Define the solver and solve the state-driven training problem
solver = LeastSquareSolver()
sim.train(solver=solver, model=explicit_model, dataloader=dataloader)
```
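The matrices A, B, C, D returned by solve define an implicit model. As an illustration of how such a model can be evaluated (not torchidl's own inference code), the sketch below runs plain fixed-point iteration on X = activation_fn(A X + B U) and stops once successive iterates differ by less than atol, mirroring the activation_fn and atol entries of model_config. The helper name `implicit_forward`, the zero initial state and the one-column-per-sample layout are assumptions.

```python
import numpy as np


def implicit_forward(A, B, C, D, U, activation_fn, atol=1e-6, max_iter=1000):
    """Hypothetical helper: solve X = activation_fn(A @ X + B @ U) by
    fixed-point iteration, then return (X, Y) with Y = C @ X + D @ U.

    Layout assumption: one column of U per sample.
    """
    X = np.zeros((A.shape[0], U.shape[1]))
    BU = B @ U  # the input term does not change between iterations
    for _ in range(max_iter):
        X_next = activation_fn(A @ X + BU)
        # Stop once successive iterates agree to within atol.
        if np.max(np.abs(X_next - X)) < atol:
            return X_next, C @ X_next + D @ U
        X = X_next
    raise RuntimeError("fixed-point iteration did not converge")


def relu(z):
    return np.maximum(z, 0.0)


A = np.array([[0.5, -0.2],
              [0.1, 0.8]])  # infinity norm 0.9 < 1
B = np.eye(2)
C = np.array([[1.0, 1.0]])
D = np.zeros((1, 2))
U = np.array([[1.0, -1.0, 2.0],
              [0.5, 3.0, -2.0]])

X, Y = implicit_forward(A, B, C, D, U, relu)
# The returned state approximately satisfies the equilibrium equation.
print(np.allclose(X, relu(A @ X + B @ U), atol=1e-5))  # True
```

For a 1-Lipschitz componentwise activation such as ReLU, this iteration is a contraction whenever the infinity norm of A is below 1. Since the least-squares solver does not enforce norm(A) <= kappa, the iteration may fail to converge for the matrices it returns.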