# Source code for torchidl.sim.solvers.solver
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple
import numpy as np
[docs]
class BaseSolver(ABC):
    """Abstract interface that every solver implementation derives from.

    Subclasses must override :meth:`solve`. Because ``solve`` is declared
    with ``@abstractmethod``, ``BaseSolver`` itself cannot be instantiated;
    only subclasses that provide ``solve`` can.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Accept and discard any constructor arguments.

        The base class holds no state of its own, so every positional and
        keyword argument is ignored. ``super().__init__()`` is not invoked.
        """
        # NOTE(review): arguments are dropped silently here; a subclass that
        # needs configuration has to consume it in its own ``__init__``.

    @abstractmethod
    def solve(
        self,
        X: np.ndarray,
        U: np.ndarray,
        Z: np.ndarray,
        Y: np.ndarray,
        model_config: Dict[str, Any],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Compute the weight matrices of an implicit model.

        Args:
            X (np.ndarray): Post-activation states, shape (n_samples, hidden_dim).
            U (np.ndarray): Model inputs, shape (n_samples, input_dim).
            Z (np.ndarray): Pre-activation states, shape (n_samples, hidden_dim).
            Y (np.ndarray): Target outputs, shape (n_samples, output_dim).
            model_config (Dict[str, Any]): Solver settings, expected to hold:
                - activation_fn (Callable): activation of the implicit model
                - device (str): compute device, 'cpu' or 'cuda'
                - atol (float): absolute convergence tolerance
                - kappa (float): well-posedness condition parameter

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ``(A, B, C, D)``:
                - A: hidden-to-hidden weights, shape (hidden_dim, hidden_dim)
                - B: input-to-hidden weights, shape (hidden_dim, input_dim)
                - C: hidden-to-output weights, shape (output_dim, hidden_dim)
                - D: input-to-output weights, shape (output_dim, input_dim)

        Raises:
            NotImplementedError: Always, when this base version is reached
                (e.g. through ``super().solve(...)``).
        """
        raise NotImplementedError()