Base Solver
- class torchidl.sim.solvers.solver.BaseSolver(*args, **kwargs)
Base class for all solvers. All solver implementations must inherit from this class and implement the solve method.
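The following is a minimal sketch of what a subclass might look like. So that it runs without torchidl installed, it defines a local stand-in `BaseSolver` with the documented `solve` signature instead of importing the real class. `LeastSquaresSolver` is a hypothetical solver, not part of the library. The sketch assumes the row-major convention implied by the documented shapes, i.e. Z ≈ X @ A.T + U @ B.T and Y ≈ X @ C.T + U @ D.T. It fits both relations with unconstrained least squares and ignores `model_config`, so unlike a real solver it does not enforce the `kappa` well-posedness condition.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

import numpy as np

Weights = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]


class BaseSolver(ABC):
    """Local stand-in with the documented solve() signature (not torchidl's class)."""

    @abstractmethod
    def solve(self, X: np.ndarray, U: np.ndarray, Z: np.ndarray,
              Y: np.ndarray, model_config: Dict[str, Any]) -> Weights:
        """Return (A, B, C, D)."""


class LeastSquaresSolver(BaseSolver):
    """Hypothetical solver: plain unconstrained least squares.

    Fits Z ~ X A^T + U B^T and Y ~ X C^T + U D^T. It ignores model_config,
    so the kappa well-posedness constraint is NOT enforced.
    """

    def solve(self, X, U, Z, Y, model_config):
        hidden_dim = X.shape[1]
        # Regressors [X U] of shape (n_samples, hidden_dim + input_dim).
        XU = np.hstack([X, U])
        # Each lstsq solution stacks the transposed blocks vertically.
        W_state, *_ = np.linalg.lstsq(XU, Z, rcond=None)  # [A B]^T
        W_out, *_ = np.linalg.lstsq(XU, Y, rcond=None)    # [C D]^T
        A, B = W_state[:hidden_dim].T, W_state[hidden_dim:].T
        C, D = W_out[:hidden_dim].T, W_out[hidden_dim:].T
        return A, B, C, D


# Usage on synthetic, noise-free data with known weights.
rng = np.random.default_rng(0)
n_samples, input_dim, hidden_dim, output_dim = 200, 3, 4, 2
A_true = rng.standard_normal((hidden_dim, hidden_dim))
B_true = rng.standard_normal((hidden_dim, input_dim))
C_true = rng.standard_normal((output_dim, hidden_dim))
D_true = rng.standard_normal((output_dim, input_dim))
X = rng.standard_normal((n_samples, hidden_dim))
U = rng.standard_normal((n_samples, input_dim))
Z = X @ A_true.T + U @ B_true.T
Y = X @ C_true.T + U @ D_true.T

A, B, C, D = LeastSquaresSolver().solve(X, U, Z, Y, model_config={})
```

The synthetic arrays satisfy only the linear part of the model (X is not computed as activation_fn(Z)), which is what makes exact recovery of the known weights easy to check. A real subclass would read its hyperparameters from `model_config` and constrain A.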
- abstract solve(X, U, Z, Y, model_config)
Solve for the implicit model weight matrices A, B, C and D from the given state, input and output data.
- Parameters:
X (np.ndarray) – Post-activation array of shape (n_samples, hidden_dim).
U (np.ndarray) – Input array of shape (n_samples, input_dim).
Z (np.ndarray) – Pre-activation array of shape (n_samples, hidden_dim).
Y (np.ndarray) – Output array of shape (n_samples, output_dim).
model_config (Dict[str, Any]) – Model configuration containing:
  - activation_fn (Callable): Activation function used by the implicit model
  - device (str): Device to run computations on ('cpu' or 'cuda')
  - atol (float): Absolute tolerance for convergence
  - kappa (float): Well-posedness condition parameter
- Returns:
- A tuple containing:
A (np.ndarray): Hidden-to-hidden weight matrix of shape (hidden_dim, hidden_dim)
B (np.ndarray): Input-to-hidden weight matrix of shape (hidden_dim, input_dim)
C (np.ndarray): Hidden-to-output weight matrix of shape (output_dim, hidden_dim)
D (np.ndarray): Input-to-output weight matrix of shape (output_dim, input_dim)
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
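The sketch below shows one way to consume the returned tuple: check each matrix against the documented shapes, test a well-posedness bound, and run the implicit forward pass by fixed-point iteration. Several assumptions apply. `check_solution` and `implicit_forward` are hypothetical helpers, not torchidl functions. The model is taken to be X = activation_fn(X @ A.T + U @ B.T) and Y = X @ C.T + U @ D.T in row form, and `activation_fn` is assumed to be a NumPy callable. `kappa` is treated as an upper bound on the induced infinity norm of A (its largest absolute row sum); for componentwise 1-Lipschitz activations such as ReLU, keeping that norm below 1 is a sufficient condition for the iteration to converge. The library may define or enforce the condition differently.

```python
from typing import Any, Callable, Dict, Tuple

import numpy as np

Weights = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]


def check_solution(weights: Weights, input_dim: int, hidden_dim: int,
                   output_dim: int, kappa: float) -> bool:
    """Hypothetical helper: validate shapes and an assumed ||A||_inf <= kappa bound."""
    expected = [(hidden_dim, hidden_dim), (hidden_dim, input_dim),
                (output_dim, hidden_dim), (output_dim, input_dim)]
    for name, mat, shape in zip("ABCD", weights, expected):
        if mat.shape != shape:
            raise ValueError(f"{name} has shape {mat.shape}, expected {shape}")
    A = weights[0]
    # Induced infinity norm of A = largest absolute row sum.
    return float(np.abs(A).sum(axis=1).max()) <= kappa


def implicit_forward(weights: Weights, U: np.ndarray,
                     activation_fn: Callable[[np.ndarray], np.ndarray],
                     atol: float = 1e-6, max_iter: int = 1000
                     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical helper: solve X = activation_fn(X A^T + U B^T) by iteration."""
    A, B, C, D = weights
    UB = U @ B.T  # input term, constant across iterations
    X = np.zeros((U.shape[0], A.shape[0]))
    for _ in range(max_iter):
        X_next = activation_fn(X @ A.T + UB)
        converged = np.max(np.abs(X_next - X)) < atol
        X = X_next
        if converged:
            break
    Z = X @ A.T + UB       # pre-activation, shape (n_samples, hidden_dim)
    Y = X @ C.T + U @ D.T  # output, shape (n_samples, output_dim)
    return X, Z, Y


# Usage with random weights scaled so that ||A||_inf = 0.5.
rng = np.random.default_rng(1)
n_samples, input_dim, hidden_dim, output_dim = 5, 3, 4, 2
A = rng.standard_normal((hidden_dim, hidden_dim))
A *= 0.5 / np.abs(A).sum(axis=1).max()
B = rng.standard_normal((hidden_dim, input_dim))
C = rng.standard_normal((output_dim, hidden_dim))
D = rng.standard_normal((output_dim, input_dim))
U = rng.standard_normal((n_samples, input_dim))

model_config: Dict[str, Any] = {
    "activation_fn": lambda z: np.maximum(z, 0.0),  # ReLU
    "device": "cpu",  # unused by this NumPy sketch
    "atol": 1e-8,
    "kappa": 0.99,
}
weights = (A, B, C, D)
well_posed = check_solution(weights, input_dim, hidden_dim, output_dim,
                            model_config["kappa"])
X, Z, Y = implicit_forward(weights, U, model_config["activation_fn"],
                           atol=model_config["atol"])
```

Because the iteration is a contraction with factor at most 0.5 here, the returned X satisfies X ≈ activation_fn(Z) to within roughly `atol`, matching the post-activation and pre-activation roles of X and Z in the parameter list.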