Implicit Models
Implicit deep learning, formalized in “Implicit Deep Learning”, introduces a class of models whose predictions are defined by a fixed-point (equilibrium) equation rather than by an explicit stack of layers.
Theoretical Foundation
Given the following dimensions:
\(p\): input features dimension
\(q\): output features dimension
\(n\): hidden state dimension
\(m\): batch size
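An implicit model processes an input matrix \(U \in \mathbb{R}^{p \times m}\) to produce an output matrix \(\hat{Y} \in \mathbb{R}^{q \times m}\) by solving the equilibrium equation for \(X\) and then applying a linear read-out:
\[
X = \phi(AX + BU), \qquad \hat{Y} = CX + DU,
\]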
where:
\(A \in \mathbb{R}^{n \times n}, B \in \mathbb{R}^{n \times p}, C \in \mathbb{R}^{q \times n}, D \in \mathbb{R}^{q \times p}\) are learnable parameters,
\(X \in \mathbb{R}^{n \times m}\) is the matrix containing the hidden states of the model,
\(\phi: \mathbb{R}^{n \times m} \to \mathbb{R}^{n \times m}\) is a non-linear activation function applied element-wise (default: ReLU).
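As an illustration, the sketch below solves \(X = \phi(AX + BU)\) by plain fixed-point iteration starting from \(X_0 = 0\), with \(\phi\) = ReLU, and then computes \(\hat{Y} = CX + DU\). It is a minimal pure-Python sketch under stated assumptions and is not torchidl's solver. Matrices are lists of rows in the column layout used above. The helper names (`matmul`, `add`, `relu`, `implicit_forward`) are hypothetical, and `max_iter`/`tol` only play roles analogous to the `mitr`/`tol` arguments. Convergence relies on the well-posedness condition, and the toy \(A\) below has \(\|A\|_\infty = 0.5\).

```python
# Hypothetical helpers: an illustrative sketch, not the torchidl implementation.
# Matrices are lists of rows; columns of U and X are individual samples.

def matmul(P, Q):
    """Multiply P (a x b) by Q (b x c)."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def add(P, Q):
    """Element-wise sum of two equally shaped matrices."""
    return [[p + q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

def relu(P):
    """Element-wise ReLU, the default activation phi."""
    return [[max(0.0, v) for v in row] for row in P]

def implicit_forward(A, B, C, D, U, max_iter=300, tol=3e-6):
    """Solve X = relu(A X + B U) by fixed-point iteration, then return (X, C X + D U)."""
    n, m = len(A), len(U[0])
    X = [[0.0] * m for _ in range(n)]       # assumed initial state X0 = 0
    BU = matmul(B, U)                       # input term is constant across iterations
    for _ in range(max_iter):
        X_next = relu(add(matmul(A, X), BU))
        # Largest absolute change between consecutive iterates.
        diff = max(abs(a - b) for ra, rb in zip(X_next, X) for a, b in zip(ra, rb))
        X = X_next
        if diff < tol:
            break
    return X, add(matmul(C, X), matmul(D, U))

# Toy problem with ||A||_inf = 0.5 < 1 (n=2, p=1, q=1, m=2).
A = [[0.5, 0.0], [0.0, 0.5]]
B = [[1.0], [-1.0]]
C = [[1.0, 1.0]]
D = [[0.0]]
U = [[2.0, -2.0]]
X, Y_hat = implicit_forward(A, B, C, D, U)
print(Y_hat)  # both entries close to 4.0
```

For the first sample (\(u = 2\)) the first hidden unit solves \(x_1 = 0.5x_1 + 2\), so \(x_1 = 4\). The iteration approaches this value geometrically, halving the error at each step.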
Low-Rank Parameterization
For memory efficiency and improved generalization, the weight matrix \(A\) can be parameterized as a low-rank product:
\[
A = L R^\top,
\]
where \(L, R \in \mathbb{R}^{n \times r}\) with rank \(r \ll n\).
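The memory saving is easy to quantify. Storing \(A\) densely takes \(n^2\) entries, while storing \(L\) and \(R\) takes \(2nr\). The pure-Python sketch below forms \(A = LR^\top\) entry by entry as \(A_{ij} = \langle L_i, R_j \rangle\) (rows of \(L\) and \(R\)) and compares the two counts. The helper names and the sizes \(n = 128\), \(r = 8\) are illustrative assumptions, not library defaults.

```python
# Hypothetical helpers illustrating A = L R^T with L, R of shape (n, r).

def low_rank_A(L, R):
    """Form A = L R^T: entry (i, j) is the dot product of row i of L and row j of R."""
    return [[sum(li * rj for li, rj in zip(L[i], R[j])) for j in range(len(R))]
            for i in range(len(L))]

def param_counts(n, r):
    """Entries needed to store A densely (n*n) versus via L and R (2*n*r)."""
    return n * n, 2 * n * r

# Tiny example: n = 3, r = 2, so A is 3 x 3 but has rank at most 2.
L = [[1, 2], [3, 4], [5, 6]]
R = [[1, 0], [0, 1], [1, 1]]
A = low_rank_A(L, R)

full, factored = param_counts(n=128, r=8)
print(full, factored)  # 16384 2048
```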
Well-Posedness Condition
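To guarantee that the equilibrium equation has a unique solution for every input, the model must satisfy the sufficient condition
\[
\|A\|_\infty \le \kappa < 1,
\]
where \(\|A\|_\infty = \max_i \sum_j |A_{ij}|\) is the induced \(\ell_\infty\) norm (maximum absolute row sum). The activation \(\phi\) (ReLU by default) is applied element-wise and is 1-Lipschitz. Under this condition, the map \(X \mapsto \phi(AX + BU)\) is therefore a contraction on each column, so the fixed-point iteration converges to the unique solution from any starting point.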
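The condition \(\|A\|_\infty \le \kappa < 1\) is easy to check, because \(\|A\|_\infty\) is the largest absolute row sum of \(A\). One simple way to enforce it is to rescale every row whose absolute sum exceeds \(\kappa\), as the pure-Python sketch below does. This row rescaling was chosen for clarity and is not necessarily what `ImplicitFunctionInf` does. A Euclidean projection onto the norm ball, for example, would change \(A\) differently. The helper names are hypothetical.

```python
# Hypothetical helpers: check and enforce ||A||_inf <= kappa by row rescaling.
# Not necessarily the method used by torchidl's ImplicitFunctionInf.

def inf_norm(A):
    """Induced l-infinity norm: the maximum absolute row sum."""
    return max(sum(abs(v) for v in row) for row in A)

def rescale_rows(A, kappa=0.99):
    """Scale down each row whose absolute sum exceeds kappa; leave other rows as-is."""
    out = []
    for row in A:
        s = sum(abs(v) for v in row)
        factor = kappa / s if s > kappa else 1.0
        out.append([v * factor for v in row])
    return out

A = [[0.9, -0.6], [0.2, 0.3]]      # row sums 1.5 and 0.5: violates the condition
A_wp = rescale_rows(A, kappa=0.99)  # first row scaled so its sum is 0.99
print(inf_norm(A), inf_norm(A_wp))
```

Only the offending first row changes. The second row already satisfies the bound and is left untouched.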
API Reference
class torchidl.implicit_base_model.ImplicitModel(input_dim, output_dim, hidden_dim, f=<class 'torchidl.implicit_function.ImplicitFunctionInf'>, no_D=False, bias=False, mitr=300, grad_mitr=300, tol=3e-06, grad_tol=3e-06, kappa=0.99, is_low_rank=False, rank=None)
The most basic form of an Implicit Model.
Note: In conventional deep learning, the batch dimension typically comes first for inputs \(U\), hidden states \(X\), and outputs \(\hat{Y}\). This API follows that convention. Internally, the model transposes these matrices to the column layout used in the definitions above before solving the fixed-point equation. Users can therefore pass data as (batch_size, features), and the output is returned in the same batch-first layout.
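Transposing is harmless because both equations keep their form in the batch-first layout. The output satisfies \((CX + DU)^\top = X^\top C^\top + U^\top D^\top\). Because \(\phi\) acts element-wise, the equilibrium becomes \(X^\top = \phi(X^\top A^\top + U^\top B^\top)\). The pure-Python sketch below checks the output identity on small hand-picked matrices. The helper names are hypothetical, and this is not the model's internal code.

```python
# Hypothetical sketch: column layout (as in the math) vs. batch-first layout (as users pass data).

def T(P):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*P)]

def matmul(P, Q):
    """Multiply P (a x b) by Q (b x c)."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def add(P, Q):
    """Element-wise sum of two equally shaped matrices."""
    return [[p + q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

C = [[1.0, 2.0]]                             # q=1, n=2
D = [[3.0, -1.0, 0.5]]                       # q=1, p=3
X = [[1.0, 0.0], [2.0, 1.0]]                 # n=2, m=2 (hidden states as columns)
U = [[1.0, 2.0], [0.0, 1.0], [4.0, -2.0]]    # p=3, m=2 (samples as columns)

Y_cols = add(matmul(C, X), matmul(D, U))               # (q x m): C X + D U
Y_batch = add(matmul(T(X), T(C)), matmul(T(U), T(D)))  # (m x q): X^T C^T + U^T D^T
print(Y_cols, Y_batch)  # [[10.0, 6.0]] [[10.0], [6.0]]
```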
Parameters:
input_dim (int) – Number of input features (\(p\)).
output_dim (int) – Number of output features (\(q\)).
hidden_dim (int) – Number of hidden features (\(n\)).
is_low_rank (bool, optional) – Whether to use low-rank approximation (default: False).
rank (int, optional) – Rank for low-rank approximation (\(r\)), required if is_low_rank is True.
f (Type[ImplicitFunction], optional) – The implicit function to use (default: ImplicitFunctionInf for well-posedness).
kappa (float, optional) – Radius \(\kappa\) of the L-infinity norm ball that constrains \(A\) for well-posedness (default: 0.99).
no_D (bool, optional) – Whether to exclude matrix D (default: False).
bias (bool, optional) – Whether to include a bias term (default: False).
mitr (int, optional) – Max iterations for the forward pass (default: 300).
grad_mitr (int, optional) – Max iterations for gradient computation (default: 300).
tol (float, optional) – Convergence tolerance for the forward pass (default: 3e-6).
grad_tol (float, optional) – Convergence tolerance for gradients (default: 3e-6).
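The shapes in the theoretical section determine how many weight entries a given configuration implies. The hypothetical helper below counts the entries of \(A\), or of \(L\) and \(R\) when `is_low_rank=True`, together with \(B\), \(C\), and \(D\). \(D\) is skipped when `no_D=True`. This is a sketch derived from the math, not from the torchidl source. Bias terms are excluded because their shapes are not documented here, so the result may differ from summing the module's actual parameters.

```python
# Hypothetical helper: weight entries implied by the ImplicitModel constructor
# arguments, using the matrix shapes from the theory section (bias excluded).

def count_weights(input_dim, output_dim, hidden_dim,
                  no_D=False, is_low_rank=False, rank=None):
    p, q, n = input_dim, output_dim, hidden_dim
    if is_low_rank:
        if rank is None:
            raise ValueError("rank is required when is_low_rank is True")
        a = 2 * n * rank           # L and R, each n x r
    else:
        a = n * n                  # dense A: n x n
    b = n * p                      # B: n x p
    c = q * n                      # C: q x n
    d = 0 if no_D else q * p       # D: q x p, omitted when no_D=True
    return a + b + c + d

print(count_weights(64, 10, 128))                          # 26496
print(count_weights(64, 10, 128, is_low_rank=True, rank=8))  # 12160
```

In this configuration, the dense \(A\) alone accounts for 16384 of the 26496 entries. That share is the part that the low-rank option shrinks.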
forward(U, X0=None)
Forward pass of ImplicitModel.
Parameters:
U (torch.Tensor) – Input tensor of shape (batch_size, input_dim).
X0 (torch.Tensor, optional) – Initial hidden state tensor of shape (batch_size, hidden_dim).
Returns:
The output tensor of shape (batch_size, output_dim).
Return type:
torch.Tensor
Example usage:
```python
import torch
from torchidl import ImplicitModel

x = torch.randn(5, 64)  # (batch_size=5, input_dim=64)
model = ImplicitModel(input_dim=64,
                      output_dim=10,
                      hidden_dim=128)
output = model(x)  # (batch_size=5, output_dim=10)
```