Implicit RNN
Implicit Recurrent Neural Networks (Implicit RNNs) extend the implicit modeling framework to sequential data. Unlike traditional RNNs, which update the hidden state with an explicit transformation of the current input and the previous hidden state, ImplicitRNN computes each recurrent update with an implicit layer, whose internal state is defined by a fixed-point equation, inside an otherwise standard RNN loop.
Theoretical Foundation
- Given the following dimensions:
\(p\): input features dimension
\(q\): output features dimension
\(n\): hidden state dimension
\(T\): sequence length
An Implicit RNN processes a sequence of inputs \(\{u_t\}_{t=1}^T\), where \(u_t \in \mathbb{R}^p\), to produce an output \(\hat{y} \in \mathbb{R}^q\). At each timestep \(t\) it solves a fixed-point equation for an implicit state \(X_t\) and uses the result to update the recurrent hidden state \(H_t\).
The final hidden state \(H_T\) is projected to the output: \(\hat{Y} = \text{Linear}(H_T)\)
- where:
\(A \in \mathbb{R}^{n \times n}, B \in \mathbb{R}^{n \times (p+n)}, C \in \mathbb{R}^{n \times n}, D \in \mathbb{R}^{n \times p}\) are learnable parameters,
\(U_t \in \mathbb{R}^{m \times p}\) is the input at timestep \(t\), with \(m\) the batch size,
\(X_t \in \mathbb{R}^{m \times n}\) is the implicit hidden state solved via a fixed-point equation,
\(H_t \in \mathbb{R}^{m \times n}\) is the RNN hidden state at timestep \(t\),
\(\phi: \mathbb{R}^{m \times n} \to \mathbb{R}^{m \times n}\) is an activation function (ReLU by default),
\(\text{Linear}\) is a linear transformation that maps \(H_T\) to the output \(\hat{Y}\).
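The fixed-point solve behind \(X_t\) can be illustrated with a small NumPy sketch. It is not torchidl's implementation: it assumes the implicit state satisfies \(X_t = \phi(X_t A^\top + [U_t, H_{t-1}] B^\top)\) and that the hidden state is updated as \(H_t = X_t C^\top + U_t D^\top\), a form chosen only because it matches the shapes listed above. The function names, the plain iteration scheme, and the rescaling of \(A\) to infinity norm 0.5 are likewise assumptions made for the example.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def solve_implicit_state(A, B, Z, phi=relu, tol=1e-8, max_iter=500):
    """Iterate X <- phi(X A^T + Z B^T) from X = 0 until the update is below tol.

    Z has shape (m, p + n); the returned X has shape (m, n).
    Hypothetical helper, not part of torchidl.
    """
    ZB = Z @ B.T                      # constant term, computed once
    X = np.zeros_like(ZB)
    for _ in range(max_iter):
        X_next = phi(X @ A.T + ZB)
        if np.max(np.abs(X_next - X)) < tol:
            return X_next
        X = X_next
    raise RuntimeError("fixed-point iteration did not converge")


def implicit_rnn_step(U_t, H_prev, A, B, C, D):
    """One timestep under the assumed recurrence (see the lead-in)."""
    Z = np.concatenate([U_t, H_prev], axis=1)   # (m, p + n)
    X_t = solve_implicit_state(A, B, Z)         # (m, n)
    H_t = X_t @ C.T + U_t @ D.T                 # (m, n)
    return X_t, H_t


rng = np.random.default_rng(0)
m, p, n = 4, 3, 5

A = rng.standard_normal((n, n))
# Assumption: scale A so its max absolute row sum is 0.5, which makes the
# iteration a contraction when phi is ReLU.
A *= 0.5 / np.abs(A).sum(axis=1).max()
B = rng.standard_normal((n, p + n))
C = rng.standard_normal((n, n))
D = rng.standard_normal((n, p))

U_t = rng.standard_normal((m, p))
H_prev = np.zeros((m, n))

X_t, H_t = implicit_rnn_step(U_t, H_prev, A, B, C, D)

# How far X_t is from satisfying the fixed-point equation exactly.
Z = np.concatenate([U_t, H_prev], axis=1)
residual = np.max(np.abs(X_t - relu(X_t @ A.T + Z @ B.T)))
print(X_t.shape, H_t.shape, residual < 1e-7)
```

Because ReLU is 1-Lipschitz, keeping the maximum absolute row sum of \(A\) below 1 makes each iteration shrink the error by at least that factor, so the loop converges from any starting point in this setting.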
API Reference
- class torchidl.implicit_rnn_model.ImplicitRNN(input_dim, output_dim, implicit_hidden_dim, hidden_dim, **kwargs)
Implicit Recurrent Neural Networks.
- Parameters:
input_dim (int) – Number of input features (\(p\)).
output_dim (int) – Number of output features (\(q\)).
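implicit_hidden_dim (int) – Dimension of the implicit state \(X_t\) (\(n\)).
hidden_dim (int) – Dimension of the recurrent hidden state \(H_t\). The Theoretical Foundation section also writes this as \(n\), but it is set independently of implicit_hidden_dim.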
**kwargs (Any) – Additional keyword arguments for ImplicitModel.
Example usage:
import torch
from torchidl import ImplicitRNN

x = torch.randn(100, 60, 1)  # (batch_size=100, seq_len=60, input_dim=1)
model = ImplicitRNN(
    input_dim=1,
    output_dim=1,
    hidden_dim=128,
    implicit_hidden_dim=64,
)
output = model(x)  # (batch_size=100, output_dim=1)
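To make the shapes in this example concrete, the following NumPy sketch unrolls a recurrence of the same kind over a (batch_size, seq_len, input_dim) array and projects only the final hidden state. It is a stand-in, not torchidl's code: the recurrence form, the zero initial hidden state, the parameter shapes for differing hidden_dim and implicit_hidden_dim, and the helper names rescale, solve_fixed_point and implicit_rnn_forward are all assumptions made for illustration.

```python
import numpy as np


def rescale(M, kappa=0.5):
    """Rescale M so its infinity norm (max absolute row sum) equals kappa."""
    return M * (kappa / np.abs(M).sum(axis=1).max())


def solve_fixed_point(A, ZB, tol=1e-8, max_iter=500):
    """Solve X = relu(X A^T + ZB) by simple iteration starting from zeros."""
    X = np.zeros_like(ZB)
    for _ in range(max_iter):
        X_next = np.maximum(X @ A.T + ZB, 0.0)
        if np.max(np.abs(X_next - X)) < tol:
            return X_next
        X = X_next
    raise RuntimeError("fixed-point iteration did not converge")


def implicit_rnn_forward(x, A, B, C, D, W_out, b_out):
    """Map x of shape (batch, seq_len, input_dim) to (batch, output_dim)."""
    batch, seq_len, _ = x.shape
    H = np.zeros((batch, C.shape[0]))          # assumed zero initial hidden state
    for t in range(seq_len):
        U_t = x[:, t, :]                       # (batch, input_dim)
        Z = np.concatenate([U_t, H], axis=1)   # (batch, input_dim + hidden_dim)
        X_t = solve_fixed_point(A, Z @ B.T)    # (batch, implicit_hidden_dim)
        H = X_t @ C.T + U_t @ D.T              # (batch, hidden_dim)
    return H @ W_out.T + b_out                 # only the final H_T is projected


rng = np.random.default_rng(0)
batch, seq_len, p, q = 100, 60, 1, 1
hidden, implicit_hidden = 128, 64

# Assumed shapes; every recurrent matrix is rescaled to keep the loop stable.
A = rescale(rng.standard_normal((implicit_hidden, implicit_hidden)))
B = rescale(rng.standard_normal((implicit_hidden, p + hidden)))
C = rescale(rng.standard_normal((hidden, implicit_hidden)))
D = rescale(rng.standard_normal((hidden, p)))
W_out = rng.standard_normal((q, hidden))
b_out = np.zeros(q)

x = rng.standard_normal((batch, seq_len, p))
y = implicit_rnn_forward(x, A, B, C, D, W_out, b_out)
print(y.shape)  # (100, 1)
```

The sequence dimension disappears because the loop overwrites the hidden state at every step and the output projection is applied once, after the last timestep.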