Source code for torchidl.implicit_base_model

import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, Type
from .implicit_function import ImplicitFunctionInf, ImplicitFunction, transpose, project_onto_Linf_ball


[docs] class ImplicitModel(nn.Module): r""" The most basic form of an Implicit Model. **Note**: In conventional deep learning, the batch size typically comes first for inputs :math:`U`, hidden states :math:`X`, and outputs :math:`Y`. We follow this convention, but the model internally transposes these matrices to solve the fixed-point equation. Users can input their data in the usual format, and the output will be returned in the standard format. Args: input_dim (int): Number of input features (:math:`p`). output_dim (int): Number of output features (:math:`q`). hidden_dim (int): Number of hidden features (:math:`n`). is_low_rank (bool, optional): Whether to use low-rank approximation (default: False). rank (int, optional): Rank for low-rank approximation (:math:`r`), required if `is_low_rank` is True. f (Type[ImplicitFunction], optional): The implicit function to use (default: ImplicitFunctionInf for well-posedness). kappa (float, optional): Radius of the L-infinity norm ball (:math:`\kappa`) for well-posedness. (default: 0.99). no_D (bool, optional): Whether to exclude matrix D (default: False). bias (bool, optional): Whether to include a bias term (default: False). mitr (int, optional): Max iterations for the forward pass. (default: 300). grad_mitr (int, optional): Max iterations for gradient computation. (default: 300). tol (float, optional): Convergence tolerance for the forward pass. (default: 3e-6). grad_tol (float, optional): Convergence tolerance for gradients. (default: 3e-6). """ def __init__( self, input_dim: int, output_dim: int, hidden_dim: int, f: Type[ImplicitFunction] = ImplicitFunctionInf, no_D: bool = False, bias: bool = False, mitr: int = 300, grad_mitr: int = 300, tol: float = 3e-6, grad_tol: float = 3e-6, kappa: float = 0.99, is_low_rank: bool = False, rank: Optional[int] = None ) -> None: super().__init__() if is_low_rank and rank is None: raise ValueError("Parameter 'k' is required when 'is_low_rank' is True.") if bias: input_dim += 1 self.hidden_dim: int = hidden_dim self.input_dim: int = input_dim self.output_dim: int = output_dim self.is_low_rank: bool = is_low_rank self.no_D: bool = no_D self.bias: bool = bias if self.is_low_rank: if rank is None: raise ValueError("Rank must be specified when using low-rank approximation.") self.L: nn.Parameter = nn.Parameter(torch.randn(hidden_dim, rank) / hidden_dim) self.R: nn.Parameter = nn.Parameter(torch.randn(hidden_dim, rank) / hidden_dim) else: self.A: nn.Parameter = nn.Parameter(torch.randn(hidden_dim, hidden_dim) / hidden_dim) self.B: nn.Parameter = nn.Parameter(torch.randn(hidden_dim, input_dim) / hidden_dim) self.C: nn.Parameter = nn.Parameter(torch.randn(output_dim, hidden_dim) / hidden_dim) if not self.no_D: self.D: nn.Parameter = nn.Parameter(torch.randn(output_dim, input_dim) / hidden_dim) else: self.D: torch.Tensor = torch.zeros((output_dim, input_dim), requires_grad=False) self.f: ImplicitFunction = f() self.f.set_parameters(mitr=mitr, grad_mitr=grad_mitr, tol=tol, grad_tol=grad_tol, kappa=kappa)
[docs] def forward( self, U: torch.Tensor, X0: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Forward pass of `ImplicitModel`. Args: U (torch.Tensor): Input tensor of shape (batch_size, input_dim). X0 (torch.Tensor, optional): Initial hidden state tensor of shape (batch_size, hidden_dim). Returns: torch.Tensor: The output tensor of shape (batch_size, output_dim). """ if (len(U.size()) == 3): U = U.flatten(1, -1) U = transpose(U) if self.bias: U = F.pad(U, (0, 0, 0, 1), value=1) assert U.shape[0] == self.input_dim, f'Given input size {U.shape[0]} does not match expected input size {self.p}.' m = U.shape[1] X_shape = torch.Size([self.hidden_dim, m]) if X0 is not None: X0 = transpose(X0) assert X0.shape == X_shape else: X0 = torch.zeros(X_shape, dtype=U.dtype, device=U.device) if self.is_low_rank: L_projected = project_onto_Linf_ball(self.L, self.f.kappa) RT_projected = project_onto_Linf_ball(transpose(self.R), self.f.kappa) X = self.f.apply(L_projected @ RT_projected, self.B, X0, U) else: X = self.f.apply(self.A, self.B, X0, U) return transpose(self.C @ X + self.D @ U)