ADMM Solvers
Alternating Direction Method of Multipliers (ADMM) is an optimization algorithm for constrained convex problems that splits the variables into blocks and minimizes over each block in turn.
This section derives the ADMM update rules for SIM training. The algorithm is implemented in torchidl.sim.solvers.admm_consensus.ADMMSolver() and torchidl.sim.solvers.admm_consensus_multigpu.ADMMMultiGPUSolver().
Introduction to ADMM
Dual Problem
Consider the equality-constrained convex optimization problem:
\[\text{minimize} \ f(x) \quad \text{subject to} \ Ax = b\]
The dual approach involves:
Lagrangian: \(L(x, y) = f(x) + y^T (Ax - b)\)
Dual function: \(g(y) = \inf_x L(x, y)\)
Dual problem: \(\text{maximize} \ g(y)\)
Primal recovery: \(x^\star = \arg\min_x L(x, y^\star)\)
Dual Ascent
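To solve the dual problem, we apply gradient ascent with step size \(\alpha_k > 0\):
\[y_{k+1} = y_k + \alpha_k \nabla g(y_k)\]
The gradient of the dual function is:
\[\nabla g(y_k) = A\tilde{x} - b, \quad \tilde{x} = \arg\min_x L(x, y_k)\]
Therefore, the update rules to solve the problem are:
\[x_{k+1} = \arg\min_x L(x, y_k)\]
\[y_{k+1} = y_k + \alpha_k (Ax_{k+1} - b)\]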
We continue updating until \(Ax_{k+1} - b \rightarrow 0\).
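The dual ascent iteration can be illustrated on a toy problem. The sketch below is dependency-free and assumes the hypothetical objective \(f(x) = \tfrac{1}{2}\|x\|_2^2\) with a single constraint \(a^T x = b\); neither choice comes from the SIM setting. For this objective the \(x\)-update has the closed form \(x = -y a\), and the iteration converges when \(0 < \alpha < 2 / \|a\|_2^2\).

```python
def dual_ascent(a, b, alpha, num_iters=100):
    """Dual ascent for: minimize 0.5*||x||^2 subject to a^T x = b."""
    y = 0.0  # dual variable (a scalar, since there is one constraint)
    x = [0.0] * len(a)
    residual = -b
    for _ in range(num_iters):
        # x-update: argmin_x 0.5*||x||^2 + y*(a^T x - b) is x = -y * a
        x = [-y * ai for ai in a]
        # The dual gradient is the constraint residual a^T x - b
        residual = sum(ai * xi for ai, xi in zip(a, x)) - b
        # y-update: gradient ascent step on the dual function
        y = y + alpha * residual
    return x, y, residual


# Toy instance: minimize 0.5*(x1^2 + x2^2) subject to x1 + x2 = 1
x, y, residual = dual_ascent(a=[1.0, 1.0], b=1.0, alpha=0.4)
```

On this instance the iterates approach \(x = (0.5, 0.5)\) and \(y = -0.5\), and the residual \(a^T x - b\) goes to zero, which is the stopping condition stated above.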
Method of Multipliers
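Powell (1969) introduced the augmented Lagrangian, with hyperparameter \(\rho > 0\):
\[L_\rho(x, y) = f(x) + y^T (Ax - b) + \frac{\rho}{2} \|Ax - b\|_2^2\]
With this augmented Lagrangian, the update rules become:
\[x_{k+1} = \arg\min_x L_\rho(x, y_k)\]
\[y_{k+1} = y_k + \rho (Ax_{k+1} - b)\]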
We continue updating until \(Ax_{k+1} - b \rightarrow 0\).
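A minimal sketch of the method of multipliers, again on an assumed toy problem (\(f(x) = \tfrac{1}{2}\|x\|_2^2\), one constraint \(a^T x = b\)) rather than the SIM problem. Setting the gradient of the augmented Lagrangian to zero shows that the minimizer is a multiple of \(a\), which gives the closed-form \(x\)-update used below. The dual step size is \(\rho\) itself, so no separate step size has to be tuned.

```python
def method_of_multipliers(a, b, rho, num_iters=50):
    """Method of multipliers for: minimize 0.5*||x||^2 subject to a^T x = b."""
    y = 0.0
    a_sq = sum(ai * ai for ai in a)
    x = [0.0] * len(a)
    residual = -b
    for _ in range(num_iters):
        # x-update: solve x + y*a + rho*a*(a^T x - b) = 0 with x = t * a
        t = (rho * b - y) / (1.0 + rho * a_sq)
        x = [t * ai for ai in a]
        residual = sum(ai * xi for ai, xi in zip(a, x)) - b
        # y-update: the penalty parameter rho doubles as the step size
        y = y + rho * residual
    return x, y, residual


x, y, residual = method_of_multipliers(a=[1.0, 1.0], b=1.0, rho=10.0)
```

For this particular example the dual error shrinks by a factor of \(1 / (1 + \rho \|a\|_2^2)\) per iteration, so every \(\rho > 0\) converges and larger values converge faster.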
Alternating Direction Method of Multipliers (ADMM)
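ADMM addresses problems of the form:
\[\text{minimize} \ f(x) + g(z) \quad \text{subject to} \ Ax + Bz = c\]
The augmented Lagrangian, with \(\rho > 0\), is:
\[L_\rho(x, z, y) = f(x) + g(z) + y^T (Ax + Bz - c) + \frac{\rho}{2} \|Ax + Bz - c\|_2^2\]
Instead of minimizing over \(x\) and \(z\) jointly, ADMM performs a single Gauss-Seidel pass that minimizes over them one at a time:
\[x_{k+1} = \arg\min_x L_\rho(x, z_k, y_k)\]
\[z_{k+1} = \arg\min_z L_\rho(x_{k+1}, z, y_k)\]
\[y_{k+1} = y_k + \rho (Ax_{k+1} + Bz_{k+1} - c)\]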
We continue updating until \(Ax_k + Bz_k - c \to 0\).
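To make the alternating updates concrete, here is a small sketch on an assumed example that is not the SIM problem: \(f(x) = \tfrac{1}{2}\|x - v\|_2^2\), \(g(z) = \lambda \|z\|_1\), with the constraint \(x - z = 0\) (that is, \(A = I\), \(B = -I\), \(c = 0\)). Both subproblems then have elementwise closed forms, and the known solution is the soft-thresholding of \(v\) at level \(\lambda\), which makes the result easy to check.

```python
def soft_threshold(v, t):
    """Elementwise soft-thresholding, the proximal operator of t * ||.||_1."""
    return [max(abs(vi) - t, 0.0) * (1.0 if vi > 0 else -1.0) for vi in v]


def admm_l1_denoise(v, lam, rho=1.0, num_iters=200):
    """ADMM for: minimize 0.5*||x - v||^2 + lam*||z||_1 subject to x - z = 0."""
    n = len(v)
    x, z, y = [0.0] * n, [0.0] * n, [0.0] * n
    for _ in range(num_iters):
        # x-update: minimize 0.5*||x - v||^2 + y^T (x - z) + (rho/2)*||x - z||^2
        x = [(vi - yi + rho * zi) / (1.0 + rho) for vi, yi, zi in zip(v, y, z)]
        # z-update: minimize lam*||z||_1 - y^T z + (rho/2)*||x - z||^2
        z = soft_threshold([xi + yi / rho for xi, yi in zip(x, y)], lam / rho)
        # y-update: ascent step on the constraint residual x - z
        y = [yi + rho * (xi - zi) for yi, xi, zi in zip(y, x, z)]
    return x, z, y


v = [3.0, -0.5, 1.2]
x, z, y = admm_l1_denoise(v, lam=1.0)
```

The iterate `z` approaches `soft_threshold(v, 1.0)`, and the residual `x - z` goes to zero, matching the stopping condition above.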
State-driven Implicit Modeling (SIM)
Implicit Deep Learning Model
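An implicit model is defined as:
\[x = \sigma(Ax + Bu), \quad \hat{y}(u) = Cx + Du\]
Here \(u\) is the input, \(x\) is the hidden state obtained as the solution of the fixed-point equation, \(\sigma\) is the activation function, and \(\hat{y}(u)\) is the prediction.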
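A minimal sketch of evaluating an implicit model by fixed-point iteration, assuming the form \(x = \sigma(Ax + Bu)\), \(\hat{y} = Cx + Du\) with a ReLU activation. The function name `implicit_forward`, the tolerance, and the small matrices are illustrative assumptions, not part of torchidl. The iteration is a contraction when \(\|A\|_\infty < 1\), because ReLU is 1-Lipschitz.

```python
def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def relu(v):
    return [max(t, 0.0) for t in v]


def implicit_forward(A, B, C, D, u, tol=1e-10, max_iter=1000):
    """Solve x = relu(A x + B u) by fixed-point iteration, then y = C x + D u."""
    Bu = matvec(B, u)
    x = [0.0] * len(A)
    for _ in range(max_iter):
        x_new = relu([s + t for s, t in zip(matvec(A, x), Bu)])
        converged = max(abs(p - q) for p, q in zip(x_new, x)) < tol
        x = x_new
        if converged:
            break
    y = [s + t for s, t in zip(matvec(C, x), matvec(D, u))]
    return x, y


# Each row of |A| sums to 0.5, so ||A||_inf = 0.5 < 1
A = [[0.2, -0.3], [0.1, 0.4]]
B = [[1.0], [-0.5]]
C = [[1.0, 1.0]]
D = [[0.0]]
x, y = implicit_forward(A, B, C, D, u=[2.0])
```

The returned `x` satisfies the equilibrium equation up to the tolerance.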
SIM Training Problem Formulation
Given an explicit neural network, we can transform it into an implicit model by solving the following optimization problem:
\[\min_{M} \ f(M) \quad \text{subject to} \ Z = AX + BU, \ Y = CX + DU, \ \|A\|_\infty \leq \kappa\]
Here, \(M\) is the stacked weight matrix, \(M = \begin{pmatrix} A & B \\ C & D \end{pmatrix}\). The matrices \(A, B, C, D\) are the modified weights of the explicit model. \(Z\) represents the pre-activation matrix, \(U\) is the input matrix, \(X\) is the post-activation matrix (\(X = \sigma(Z)\)), \(Y\) is the output matrix, and \(\kappa\) is the well-posedness parameter.
To achieve sparsity, the objective function can be set to \(f(M) = \|M\|_\infty = \left\| \begin{pmatrix} A & B \\ C & D \end{pmatrix} \right\|_\infty\).
The problem can be solved row by row: each row’s computation is independent of the others, which enables a parallel implementation. Moreover, because the matrix \(l_\infty\) norm is the maximum over rows of the row-wise \(l_1\) norm, it can be minimized by minimizing the \(l_1\) norm of each row separately.
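A short illustration of the row-wise structure, assuming \(\|M\|_\infty\) denotes the induced matrix norm (maximum absolute row sum). The matrix values are made up for the example.

```python
def row_l1_norms(M):
    """l1 norm of each row of a matrix stored as a list of rows."""
    return [sum(abs(v) for v in row) for row in M]


def inf_norm(M):
    """Induced infinity norm: the largest row-wise l1 norm."""
    return max(row_l1_norms(M))


M = [[0.5, -1.0, 0.0],
     [0.0, 0.25, 0.0],
     [-2.0, 0.0, 1.0]]

norms = row_l1_norms(M)   # one independent quantity per row
total = inf_norm(M)       # determined by the worst row only

# Shrinking one row changes only that row's norm, so each row can be
# handled by a separate subproblem (and in parallel).
M_shrunk = [M[0], M[1], [-1.0, 0.0, 0.0]]
norms_shrunk = row_l1_norms(M_shrunk)
```

Because the rows do not interact, driving every row's \(l_1\) norm down independently also drives down their maximum.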
We solve for \(A, B\) first, and then for \(C, D\) later (the update rules are identical):
Solving SIM with ADMM
The row-form problem is almost identical to this problem, where we relax the inequality constraint. Note that \(\|a\|_1 \leq \left \| \begin{pmatrix} a \\ b \end{pmatrix} \right\|_1 = \|a\|_1 + \|b\|_1 \leq \kappa.\)
This is equivalent to:
Global consensus form:
With \(\beta = \begin{pmatrix} a \\ b \end{pmatrix}\) and \(\Phi = \begin{pmatrix} X^T & U^T \end{pmatrix}\), we express the problem in global consensus form:
ADMM update rules:
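As a reference point, the standard scaled-form updates for global consensus ADMM are given below, written with one local copy \(\beta_i\) per objective term \(f_i\). The number of terms \(N\) and the functions \(f_i\) are generic placeholders here, not a statement of the solver's specific choices.

\[
\begin{aligned}
\beta_i^{k+1} &= \arg\min_{\beta_i} \ f_i(\beta_i) + \frac{\rho}{2} \left\| \beta_i - \bar{\beta}^{k} + u_i^{k} \right\|_2^2, \quad i = 1, \dots, N \\
\bar{\beta}^{k+1} &= \frac{1}{N} \sum_{i=1}^{N} \beta_i^{k+1} \\
u_i^{k+1} &= u_i^{k} + \beta_i^{k+1} - \bar{\beta}^{k+1}
\end{aligned}
\]

When \(f_i\) is the indicator function of a convex set, the first update reduces to the Euclidean projection of \(\bar{\beta}^{k} - u_i^{k}\) onto that set.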
The closed-form update rules are derived as:
The projection can be computed efficiently: project the first \(n\) entries of \((\bar{\beta}^{k} - u_3^k)\) (the \(a\) block) onto the \(l_1\)-norm ball of radius \(\kappa\), and copy the remaining entries of \((\bar{\beta}^{k} - u_3^k)\) unchanged into \(\beta_3^{k+1}\).
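The block-wise projection can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: the \(l_1\)-ball projection uses the well-known sort-based method of Duchi et al. (2008), and the helper names `project_l1_ball` and `project_first_block`, as well as the argument `n` for the length of the \(a\) block, are hypothetical.

```python
import math


def project_l1_ball(v, radius):
    """Euclidean projection of v onto {w : ||w||_1 <= radius}, radius > 0."""
    if sum(abs(x) for x in v) <= radius:
        return list(v)  # already inside the ball
    # Find the soft-threshold level theta from the sorted magnitudes
    u = sorted((abs(x) for x in v), reverse=True)
    cumsum, theta = 0.0, 0.0
    for j, uj in enumerate(u, start=1):
        cumsum += uj
        t = (cumsum - radius) / j
        if uj - t > 0:
            theta = t  # the last index satisfying the condition wins
    # Shrink every magnitude by theta and keep the original signs
    return [math.copysign(max(abs(x) - theta, 0.0), x) for x in v]


def project_first_block(w, n, kappa):
    """Project the first n entries onto the l1 ball; copy the rest unchanged."""
    return project_l1_ball(w[:n], kappa) + list(w[n:])


w = [3.0, -1.0, 0.5, 7.0, -4.0]
projected = project_first_block(w, n=3, kappa=2.0)
```

Copying the tail is exact rather than an approximation: the constraint only involves the first block, so the Euclidean projection leaves the remaining entries untouched.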
API Reference
- class torchidl.sim.solvers.admm_consensus.ADMMSolver(num_epoch_ab=1000, num_epoch_cd=100, lambda_y=1e-06, lambda_z=1e-06, rho_ab=10.0, rho_cd=10.0, batch_feature_size=100, regen_states=False, tol=1e-06)
ADMM Consensus Solver on a single GPU.
- Parameters:
num_epoch_ab (int, optional) – Number of epochs for solving A and B. Defaults to 1000.
num_epoch_cd (int, optional) – Number of epochs for solving C and D. Defaults to 100.
lambda_y (float, optional) – Lasso regularization parameter for Y. Defaults to 1e-6.
lambda_z (float, optional) – Lasso regularization parameter for Z. Defaults to 1e-6.
rho_ab (float, optional) – ADMM’s rho parameter for A and B. Defaults to 10.0.
rho_cd (float, optional) – ADMM’s rho parameter for C and D. Defaults to 10.0.
batch_feature_size (int, optional) – Number of columns to solve in each solving iteration. This controls memory usage. The solver runs (total_rows // batch_feature_size + 1) times. Defaults to 100.
regen_states (bool, optional) – Whether to regenerate states. Defaults to False.
tol (float, optional) – Tolerance for zeroing out weights. Defaults to 1e-6.
- solve(X, U, Z, Y, model_config, plot_loss=False)
Solve State-driven Implicit Model using the ADMM consensus algorithm.
- Parameters:
X (np.ndarray) – Post-activation array.
U (np.ndarray) – Input array.
Z (np.ndarray) – Pre-activation array.
Y (np.ndarray) – Output array.
model_config (Dict[str, Any]) – Model configuration, with the following keys:
- activation_fn (Callable): Activation function used by the implicit model.
- device (str): Implicit model’s device.
- atol (float): Equilibrium function’s tolerance.
- kappa (float): Well-posedness condition parameter.
plot_loss (Optional[bool], optional) – Whether to plot loss curve. Useful for tuning hyperparameters and debugging. Defaults to False.
- Returns:
Implicit model’s parameters.
- Return type:
A, B, C, D (np.ndarray)
- class torchidl.sim.solvers.admm_consensus_multigpu.ADMMMultiGPUSolver(gpu_ids, num_epoch_ab=1000, num_epoch_cd=100, lambda_y=1e-06, lambda_z=1e-06, rho_ab=10.0, rho_cd=10.0, batch_feature_size=100, regen_states=False, tol=1e-06)
ADMM Consensus Solver with distributed computation over multiple GPUs. This is highly recommended for large-scale problems and when multiple GPUs are available.
- Parameters:
gpu_ids (List[Union[int, torch.device]]) – List of GPU IDs or devices.
num_epoch_ab (int, optional) – Number of epochs for solving A and B. Defaults to 1000.
num_epoch_cd (int, optional) – Number of epochs for solving C and D. Defaults to 100.
lambda_y (float, optional) – Lasso regularization parameter for Y. Defaults to 1e-6.
lambda_z (float, optional) – Lasso regularization parameter for Z. Defaults to 1e-6.
rho_ab (float, optional) – ADMM’s rho parameter for A and B. Defaults to 10.0.
rho_cd (float, optional) – ADMM’s rho parameter for C and D. Defaults to 10.0.
batch_feature_size (int, optional) – Number of columns to solve in each solving iteration. This controls memory usage. The solver runs (total_rows // batch_feature_size + 1) times. Defaults to 100.
regen_states (bool, optional) – Whether to regenerate states. Defaults to False.
tol (float, optional) – Tolerance for zeroing out weights. Defaults to 1e-6.
- solve(X, U, Z, Y, model_config, plot_loss=False)
Solve State-driven Implicit Model using the ADMM consensus algorithm.
- Parameters:
X (np.ndarray) – Post-activation array.
U (np.ndarray) – Input array.
Z (np.ndarray) – Pre-activation array.
Y (np.ndarray) – Output array.
model_config (Dict[str, Any]) – Model configuration, with the following keys:
- activation_fn (Callable): Activation function used by the implicit model.
- device (str): Implicit model’s device.
- atol (float): Equilibrium function’s tolerance.
- kappa (float): Well-posedness condition parameter.
plot_loss (Optional[bool], optional) – Whether to plot loss curve. Useful for tuning hyperparameters and debugging. Defaults to False.
- Returns:
Implicit model’s parameters.
- Return type:
A, B, C, D (np.ndarray)
Example usage:
import torch
from torchidl import SIM
from torchidl import ADMMSolver
# Load dataset
dataloader = ...
# Load a pretrained explicit model
explicit_model = ...
explicit_model.load_state_dict(torch.load("checkpoint.pt"))
# Define the SIM model
sim = SIM(activation_fn=torch.nn.functional.relu, device="cuda", dtype=torch.float32)
# Define the solver and solve the state-driven training problem
solver = ADMMSolver()
sim.train(solver=solver, model=explicit_model, dataloader=dataloader)