ADMM Solvers

Alternating Direction Method of Multipliers (ADMM) is an optimization algorithm for convex problems with linear equality constraints. It splits the objective into simpler subproblems that are solved in turn. This section derives the ADMM update rules used for SIM training. The algorithm is implemented in torchidl.sim.solvers.admm_consensus.ADMMSolver() and torchidl.sim.solvers.admm_consensus_multigpu.ADMMMultiGPUSolver().

Introduction to ADMM

Dual Problem

Consider the constrained convex optimization problem:

\[ \begin{align}\begin{aligned}\text{minimize} \quad f(x)\\\text{subject to} \quad Ax = b\end{aligned}\end{align} \]

The dual approach involves:

  • Lagrangian: \(L(x, y) = f(x) + y^T (Ax - b)\)

  • Dual function: \(g(y) = \inf_x L(x, y)\)

  • Dual problem: \(\text{maximize} \ g(y)\)

  • Primal recovery: \(x^\star = \arg\min_x L(x, y^\star)\)

Dual Ascent

To solve the dual problem, we apply gradient ascent with step size \(\alpha_k\):

\[y_{k+1} = y_k + \alpha_k \nabla g(y_k)\]

The gradient of the dual function is:

\[ \begin{align}\begin{aligned}\nabla g(y_k) = Ax^\ast - b\\x^\ast = \arg\min_x L(x, y_k)\end{aligned}\end{align} \]

Therefore, the update rules to solve the problem are:

\[\begin{split}x_{k+1} &:= \arg\min_x L(x, y_k) \\ y_{k+1} &:= y_k + \alpha_k (Ax_{k+1} - b)\end{split}\]

We continue updating until \(Ax_{k+1} - b \rightarrow 0\).
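The two update rules can be sketched on a small equality-constrained quadratic program, \(f(x) = \tfrac{1}{2} x^T P x + q^T x\) with \(P\) positive definite. In this case the x-update has the closed form \(x = -P^{-1}(q + A^T y)\). This is an illustrative NumPy sketch and is not part of torchidl. The function name `dual_ascent`, the fixed step size `alpha`, and the toy problem are all assumptions.

```python
import numpy as np


def dual_ascent(P, q, A, b, alpha=0.1, num_iters=2000):
    """Dual ascent for: minimize 0.5 x^T P x + q^T x  subject to  Ax = b.

    Assumes P is symmetric positive definite, so argmin_x L(x, y) is unique.
    """
    y = np.zeros(A.shape[0])
    for _ in range(num_iters):
        # x-update: solve grad_x L(x, y) = P x + q + A^T y = 0
        x = np.linalg.solve(P, -(q + A.T @ y))
        # y-update: gradient ascent step, using grad g(y) = A x - b
        y = y + alpha * (A @ x - b)
    return x, y


# Toy problem: minimize ||x||^2 + q^T x  subject to  x_1 + x_2 + x_3 = 1
P = 2.0 * np.eye(3)
q = np.array([1.0, -1.0, 0.0])
A = np.array([[1.0, 1.0, 1.0]])
b = np.array([1.0])
x_star, y_star = dual_ascent(P, q, A, b)
```

Plain dual ascent needs a step size that is small enough. It also needs \(\arg\min_x L(x, y)\) to be well defined, which is why the toy objective is strictly convex.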

Method of Multipliers

Powell (1969) introduced the augmented Lagrangian, with hyperparameter \(\rho > 0\):

\[L_{\rho}(x, y) = f(x) + y^T (Ax - b) + (\rho / 2) \|Ax - b\|_2^2\]

With this augmented Lagrangian, the update rules become:

\[\begin{split}x_{k+1} &:= \arg\min_x L_{\rho}(x, y_k) \\ y_{k+1} &:= y_k + \rho (Ax_{k+1} - b)\end{split}\]

We continue updating until \(Ax_{k+1} - b \rightarrow 0\).
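The sketch below applies the method of multipliers to an illustrative quadratic program, \(f(x) = \tfrac{1}{2} x^T P x + q^T x\). Compared with dual ascent, the x-update picks up an extra \(\rho A^T A\) term, and the dual step size is fixed to \(\rho\). The names and the toy problem are assumptions and do not come from torchidl.

```python
import numpy as np


def method_of_multipliers(P, q, A, b, rho=1.0, num_iters=100):
    """Method of multipliers for: minimize 0.5 x^T P x + q^T x  subject to  Ax = b."""
    y = np.zeros(A.shape[0])
    for _ in range(num_iters):
        # x-update: grad_x L_rho = P x + q + A^T y + rho A^T (A x - b) = 0
        x = np.linalg.solve(P + rho * A.T @ A, -q - A.T @ y + rho * A.T @ b)
        # y-update: the dual step size is rho
        y = y + rho * (A @ x - b)
    return x, y


# Toy problem: minimize ||x||^2 + q^T x  subject to  x_1 + x_2 + x_3 = 1
P = 2.0 * np.eye(3)
q = np.array([1.0, -1.0, 0.0])
A = np.array([[1.0, 1.0, 1.0]])
b = np.array([1.0])
x_star, y_star = method_of_multipliers(P, q, A, b)
```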

Alternating Direction Method of Multipliers (ADMM)

ADMM addresses problems of the form:

\[ \begin{align}\begin{aligned}\text{minimize} \quad f(x) + g(z)\\\text{subject to} \quad Ax + Bz = c\end{aligned}\end{align} \]

The augmented Lagrangian, with \(\rho > 0\), is:

\[L_{\rho}(x, z, y) = f(x) + g(z) + y^T (Ax + Bz - c) + (\rho / 2) \|Ax + Bz - c\|_2^2\]

Instead of minimizing over \(x\) and \(z\) jointly, ADMM performs one Gauss-Seidel pass per iteration. It minimizes first over \(x\), then over \(z\), and then updates the dual variable:

\[\begin{split}x_{k+1} &:= \arg\min_x L_{\rho}(x, z_k, y_k)\\ z_{k+1} &:= \arg\min_z L_{\rho}(x_{k+1}, z, y_k) \\ y_{k+1} &:= y_k + \rho (Ax_{k+1} + Bz_{k+1} - c)\end{split}\]

We continue updating until \(Ax_k + Bz_k - c \to 0\).
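A minimal instance of these three updates uses \(f(x) = \tfrac{1}{2}\|x - v\|_2^2\) and \(g(z) = \lambda \|z\|_1\), with \(A = I\), \(B = -I\) and \(c = 0\). The exact solution of this problem is the soft-thresholded vector \(\operatorname{sign}(v)\max(|v| - \lambda, 0)\), so convergence is easy to check. The example problem and the function name `admm_toy` are illustrative assumptions. The sketch keeps the unscaled dual variable \(y\) from the update rules above.

```python
import numpy as np


def admm_toy(v, lam, rho=1.0, num_iters=200):
    """ADMM for: minimize 0.5 ||x - v||^2 + lam ||z||_1  subject to  x - z = 0."""
    x = np.zeros_like(v)
    z = np.zeros_like(v)
    y = np.zeros_like(v)
    for _ in range(num_iters):
        # x-update: solve (x - v) + y + rho (x - z) = 0
        x = (v - y + rho * z) / (1.0 + rho)
        # z-update: argmin lam ||z||_1 + (rho/2) ||z - (x + y/rho)||^2 (soft-thresholding)
        w = x + y / rho
        z = np.sign(w) * np.maximum(np.abs(w) - lam / rho, 0.0)
        # y-update: dual ascent step with step size rho on the residual x - z
        y = y + rho * (x - z)
    return x, z


v = np.array([3.0, -0.5, 1.5, -2.0])
x_star, z_star = admm_toy(v, lam=1.0)
```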

State-driven Implicit Modeling (SIM)

Implicit Deep Learning Model

An implicit model is defined as:

\[ \begin{align}\begin{aligned}x = \phi(Ax + Bu) \quad \text{[equilibrium equation]}\\\hat{y}(u) = Cx + Du \quad \text{[prediction equation]}\end{aligned}\end{align} \]
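The equilibrium equation can be solved by fixed-point iteration. Suppose \(\phi\) is applied componentwise and is 1-Lipschitz, as ReLU is, and \(\|A\|_\infty < 1\). Then \(x \mapsto \phi(Ax + Bu)\) is a contraction in the \(\ell_\infty\) norm, so the iteration converges to a unique \(x\). The NumPy sketch below assumes ReLU and rescales a random \(A\) so that \(\|A\|_\infty = 0.5\). The function name `implicit_forward` is hypothetical, and the function is not the library's own solver.

```python
import numpy as np


def implicit_forward(A, B, C, D, u, phi=lambda v: np.maximum(v, 0.0), tol=1e-10, max_iter=1000):
    """Solve x = phi(Ax + Bu) by fixed-point iteration, then return (x, Cx + Du)."""
    x = np.zeros(A.shape[0])
    for _ in range(max_iter):
        x_next = phi(A @ x + B @ u)
        converged = np.linalg.norm(x_next - x, ord=np.inf) < tol
        x = x_next
        if converged:
            break
    return x, C @ x + D @ u  # equilibrium state and prediction


rng = np.random.default_rng(0)
n, p, q = 5, 3, 2  # state, input and output dimensions
A = rng.standard_normal((n, n))
A *= 0.5 / np.linalg.norm(A, ord=np.inf)  # max absolute row sum becomes 0.5 < 1
B = rng.standard_normal((n, p))
C = rng.standard_normal((q, n))
D = rng.standard_normal((q, p))
u = rng.standard_normal(p)
x, y_hat = implicit_forward(A, B, C, D, u)
```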

SIM Training Problem Formulation

Given an explicit neural network, we can transform it into an implicit model by solving the following optimization problem:

\[ \begin{align}\begin{aligned}\min_{M} \quad f(M)\\\text{s.t.} \quad Z = AX + BU,\\\quad \quad \hat{Y} = CX + DU,\\\quad \quad \|A\|_\infty \leq \kappa.\end{aligned}\end{align} \]

Here, \(M\) is the stacked weight matrix, \(M = \begin{pmatrix} A & B \\ C & D \end{pmatrix}\). The matrices \(A, B, C, D\) are the modified weights of the explicit model. \(Z\) represents the pre-activation matrix, \(U\) is the input matrix, and \(X\) is the post-activation matrix (\(X = \phi(Z)\)).

To achieve sparsity, the objective function can be set to \(f(M) = \|M\|_\infty = \left\| \begin{pmatrix} A & B \\ C & D \end{pmatrix} \right\|_\infty\).

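The problem decouples by rows. The constraints for row \(i\) involve only row \(i\) of \(M\), so the rows can be solved independently and in parallel. Moreover, the induced \(\ell_\infty\) norm of a matrix is its maximum absolute row sum, \(\|M\|_\infty = \max_i \|m_i\|_1\), where \(m_i\) is the \(i\)-th row of \(M\). Minimizing this maximum therefore reduces to minimizing the \(\ell_1\) norm of each row separately. Likewise, the constraint \(\|A\|_\infty \leq \kappa\) splits into one \(\ell_1\) constraint per row of \(A\).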
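The identity \(\|M\|_\infty = \max_i \|m_i\|_1\) and the row-wise form of the well-posedness check can be verified numerically. This short NumPy check uses a random matrix and an arbitrary \(\kappa\), both of which are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((4, 7))

row_l1 = np.abs(M).sum(axis=1)            # ell_1 norm of every row
inf_norm = np.linalg.norm(M, ord=np.inf)  # induced ell_inf norm of M

# The matrix constraint ||M||_inf <= kappa holds exactly when every row satisfies ||m_i||_1 <= kappa
kappa = 3.0
rowwise_ok = bool(np.all(row_l1 <= kappa))
matrix_ok = bool(inf_norm <= kappa)
```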

We solve for \(A, B\) first, and then for \(C, D\) later (the update rules are identical):

\[ \begin{align}\begin{aligned}\begin{split}\min_{a,b} \quad & f\left(\begin{pmatrix} a \\ b \end{pmatrix}\right) = \left\|\begin{pmatrix} a \\ b \end{pmatrix}\right\|_1 \\\end{split}\\\begin{split}\text{s.t.} \quad & z = \begin{pmatrix} X^T & U^T \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix}, \\\end{split}\\& \|a\|_1 \leq \kappa.\end{aligned}\end{align} \]
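The row form relies on the fact that row \(i\) of \(Z = AX + BU\), written as a column vector, equals \(\begin{pmatrix} X^T & U^T \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix}\), where \(a\) and \(b\) are row \(i\) of \(A\) and \(B\). The NumPy check below assumes samples are stored column-wise, so that \(Z = AX + BU\) holds as written. The dimensions are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, m = 4, 3, 10  # state dimension, input dimension, number of samples
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, p))
X = rng.standard_normal((n, m))
U = rng.standard_normal((p, m))
Z = A @ X + B @ U

Phi = np.hstack([X.T, U.T])             # (X^T  U^T), shape (m, n + p)
i = 2
beta_i = np.concatenate([A[i], B[i]])   # (a; b) for row i
z_i = Phi @ beta_i                      # should equal row i of Z
```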

Solving SIM with ADMM

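The row-form problem can be approximated by the following problem. The equality constraint is relaxed into a least-squares penalty, and the inequality constraint is imposed on the full row \(\begin{pmatrix} a \\ b \end{pmatrix}\). Any feasible point of this problem still satisfies the original well-posedness constraint, because \(\|a\|_1 \leq \left \| \begin{pmatrix} a \\ b \end{pmatrix} \right\|_1 = \|a\|_1 + \|b\|_1 \leq \kappa.\)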

\[ \begin{align}\begin{aligned}\begin{split}\min_{a,b} & \quad \frac{1}{2\lambda_1} \left \| \begin{pmatrix} a \\ b \end{pmatrix} \right\|_1 + \frac{1}{2} \left\| z - \begin{pmatrix} X^T & U^T \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix} \right\|_2^2\end{split}\\\begin{split}\text{s.t.}&\quad \left\| \begin{pmatrix} a \\ b \end{pmatrix} \right\|_1 \leq \kappa.\end{split}\end{aligned}\end{align} \]

Setting \(\lambda = 1/(2\lambda_1)\) and reordering the terms, this is equivalent to:

\[ \begin{align}\begin{aligned}\begin{split}\min_{a,b} & \quad \frac{1}{2} \left\| \begin{pmatrix} X^T & U^T \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix} - z \right\|_2^2 + \lambda \left \| \begin{pmatrix} a \\ b \end{pmatrix} \right\|_1\end{split}\\\begin{split}\text{s.t.}&\quad \left\| \begin{pmatrix} a \\ b \end{pmatrix} \right\|_1 \leq \kappa.\end{split}\end{aligned}\end{align} \]

Global consensus form:

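With \(\beta = \begin{pmatrix} a \\ b \end{pmatrix}\) and \(\Phi = \begin{pmatrix} X^T & U^T \end{pmatrix}\), we express the problem in global consensus form. Here \(I_C\) is the indicator function of the constraint set \(C\): \(I_C(\beta) = 0\) if \(\beta \in C\) and \(+\infty\) otherwise.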

\[\begin{split}\min & \quad \frac{1}{2}\|\Phi \beta_1 - z\|_2^2 + \lambda \|\beta_2\|_1 + I_C(\beta_3) \\ \text{s.t.} & \quad \beta_1 = \beta_2 = \beta_3 \\\end{split}\]

ADMM update rules:

\[ \begin{align}\begin{aligned}\beta_1^{k+1} &= \arg\min_{\beta_1} \left( \frac{1}{2}\|\Phi \beta_1 - z\|_2^2 + \frac{\rho}{2} \|\beta_1 - \bar{\beta}^k + u_1^k\|_2^2 \right)\\\beta_2^{k+1} &= \arg\min_{\beta_2} \left( \lambda \|\beta_2\|_1 + \frac{\rho}{2} \|\beta_2 - \bar{\beta}^k + u_2^k\|_2^2 \right)\\\beta_3^{k+1} &= \arg\min_{\beta_3} \left( I_C(\beta_3) + \frac{\rho}{2} \|\beta_3 - \bar{\beta}^k + u_3^k\|_2^2 \right)\\\bar{\beta}^{k+1} &= \frac{1}{3} \sum_{i=1}^{3} \beta_i^{k+1}\\u_i^{k+1} &= u_i^k + \beta_i^{k+1} - \bar{\beta}^{k+1} \quad (i = 1,2,3)\end{aligned}\end{align} \]

The closed-form update rules are derived as:

\[ \begin{align}\begin{aligned}\beta_1^{k+1} &= \left( \Phi^T \Phi + \rho I \right)^{-1} \left( \Phi^T z + \rho \left( \bar{\beta}^k - u_1^k \right) \right)\\\beta_2^{k+1} &= \mathcal{S} \left( \bar{\beta}^{k} - u_2^k, \frac{\lambda}{\rho} \right), \quad \text{where } \mathcal{S}(v, t) = v - \max \left( \min(v, t), -t \right) \text{ elementwise}\\\beta_3^{k+1} &= \text{Proj}_C (\bar{\beta}^{k} - u_3^k)\end{aligned}\end{align} \]
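The first two closed-form updates translate directly into NumPy. The soft-thresholding operator \(\mathcal{S}\) is the same as the familiar \(\operatorname{sign}(v)\max(|v| - t, 0)\). The \(\beta_1\) update solves the stationarity condition \(\Phi^T(\Phi\beta_1 - z) + \rho(\beta_1 - \bar{\beta}^k + u_1^k) = 0\). The helper names `soft_threshold` and `beta1_update` are hypothetical and illustrative, not torchidl functions.

```python
import numpy as np


def soft_threshold(v, t):
    """S(v, t) = v - max(min(v, t), -t), applied entrywise."""
    return v - np.clip(v, -t, t)


def beta1_update(Phi, z, beta_bar, u1, rho):
    """Closed-form beta_1 update: (Phi^T Phi + rho I)^{-1} (Phi^T z + rho (beta_bar - u1))."""
    d = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + rho * np.eye(d), Phi.T @ z + rho * (beta_bar - u1))


v = np.array([3.0, -0.5, 1.5, -2.0, 0.0])
shrunk = soft_threshold(v, 1.0)  # each entry moves toward zero by 1, small entries become 0
```

Using `np.linalg.solve` avoids forming the inverse explicitly. Because \(\Phi^T \Phi + \rho I\) does not change between iterations, a real solver would typically factor it once.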

The projection can be computed efficiently. Project the first \(n\) entries of \((\bar{\beta}^{k} - u_3^k)\), which correspond to \(a\) (here \(n\) is the length of \(a\)), onto the \(\ell_1\)-norm ball of radius \(\kappa\). Then copy the remaining entries, which correspond to \(b\), to \(\beta_3^{k+1}\) unchanged.
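A single-row NumPy sketch can combine the three closed-form updates with the averaging and dual steps. It assumes \(C = \{\beta : \|\beta_{1:n}\|_1 \leq \kappa\}\), so that only the \(a\)-part is constrained, as the projection paragraph describes. It projects onto the \(\ell_1\) ball with the standard sort-based method. The names `project_l1_ball`, `project_C` and `admm_consensus_row`, the default `rho` and iteration count, and the synthetic data are assumptions for illustration. This is not torchidl's implementation, which may differ.

```python
import numpy as np


def project_l1_ball(v, kappa):
    """Euclidean projection of v onto {w : ||w||_1 <= kappa} (sort-based method, kappa > 0)."""
    if np.abs(v).sum() <= kappa:
        return v.copy()
    mu = np.sort(np.abs(v))[::-1]                 # magnitudes in decreasing order
    cssv = np.cumsum(mu)
    j = np.arange(1, v.size + 1)
    r = np.nonzero(mu * j > cssv - kappa)[0][-1]  # last index that stays nonzero
    theta = (cssv[r] - kappa) / (r + 1.0)         # shrinkage threshold
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)


def project_C(v, n, kappa):
    """Project the first n entries (the a-part) onto the ell_1 ball; copy the b-part."""
    out = v.copy()
    out[:n] = project_l1_ball(v[:n], kappa)
    return out


def admm_consensus_row(Phi, z, n, lam, kappa, rho=1.0, num_iters=5000):
    """Consensus ADMM for one row: 0.5 ||Phi beta - z||^2 + lam ||beta||_1 + I_C(beta)."""
    d = Phi.shape[1]
    betas = np.zeros((3, d))            # beta_1, beta_2, beta_3
    u = np.zeros((3, d))                # scaled dual variables u_1, u_2, u_3
    beta_bar = np.zeros(d)
    lhs = Phi.T @ Phi + rho * np.eye(d)  # fixed across iterations
    rhs0 = Phi.T @ z
    for _ in range(num_iters):
        betas[0] = np.linalg.solve(lhs, rhs0 + rho * (beta_bar - u[0]))
        v = beta_bar - u[1]
        betas[1] = v - np.clip(v, -lam / rho, lam / rho)  # soft-thresholding
        betas[2] = project_C(beta_bar - u[2], n, kappa)
        beta_bar = betas.mean(axis=0)   # average of the new beta_i
        u += betas - beta_bar           # beta_bar broadcasts over the three rows
    return beta_bar


rng = np.random.default_rng(0)
n, p, m = 4, 2, 40  # length of a, length of b, number of samples
Phi = rng.standard_normal((m, n + p)) / np.sqrt(m)
z = Phi @ np.array([1.0, 0.0, -0.5, 0.0, 0.3, 0.0])
beta = admm_consensus_row(Phi, z, n, lam=0.01, kappa=0.5)
a, b = beta[:n], beta[n:]
```

With a small \(\kappa\), as in this example, the constraint on \(a\) is active. With a large \(\kappa\), the result reduces to an ordinary lasso solution.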

API Reference

class torchidl.sim.solvers.admm_consensus.ADMMSolver(num_epoch_ab=1000, num_epoch_cd=100, lambda_y=1e-06, lambda_z=1e-06, rho_ab=10.0, rho_cd=10.0, batch_feature_size=100, regen_states=False, tol=1e-06)

ADMM Consensus Solver on a single GPU.

Parameters:
  • num_epoch_ab (int, optional) – Number of epochs for solving A and B. Defaults to 1000.

  • num_epoch_cd (int, optional) – Number of epochs for solving C and D. Defaults to 100.

  • lambda_y (float, optional) – Lasso regularization parameter for Y. Defaults to 1e-6.

  • lambda_z (float, optional) – Lasso regularization parameter for Z. Defaults to 1e-6.

  • rho_ab (float, optional) – ADMM’s rho parameter for A and B. Defaults to 10.0.

  • rho_cd (float, optional) – ADMM’s rho parameter for C and D. Defaults to 10.0.

  • batch_feature_size (int, optional) – Number of columns to solve in each solving iteration. This is used to control the memory usage. The solver is performed (total_rows // batch_feature_size + 1) times. Defaults to 100.

  • regen_states (bool, optional) – Whether to regenerate states. Defaults to False.

  • tol (float, optional) – Tolerance for zeroing out weights. Defaults to 1e-6.

solve(X, U, Z, Y, model_config, plot_loss=False)

Solve State-driven Implicit Model using the ADMM consensus algorithm.

Parameters:
  • X (np.ndarray) – Post-activation array.

  • U (np.ndarray) – Input array.

  • Z (np.ndarray) – Pre-activation array.

  • Y (np.ndarray) – Output array.

  • model_config (Dict[str, Any]) – Model configuration with the following keys:
      - activation_fn (Callable): Activation function used by the implicit model.
      - device (str): Implicit model’s device.
      - atol (float): Equilibrium function’s tolerance.
      - kappa (float): Well-posedness condition parameter.

  • plot_loss (Optional[bool], optional) – Whether to plot loss curve. Useful for tuning hyperparameters and debugging. Defaults to False.

Returns:

Implicit model’s parameters.

Return type:

A, B, C, D (np.ndarray)
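Once solve returns, a quick sanity check of the matrices can help when tuning hyperparameters. The hypothetical helper `check_sim_solution` below is not part of torchidl. It assumes samples are stored column-wise, so the constraints read \(Z = AX + BU\) and \(\hat{Y} = CX + DU\); if the arrays use a row-wise layout, transpose them first. The data are synthetic and built to satisfy the constraints exactly.

```python
import numpy as np


def check_sim_solution(A, B, C, D, X, U, Z, Y, kappa):
    """Hypothetical helper: report how well (A, B, C, D) satisfy the SIM constraints."""
    def rel_err(residual, target):
        # Relative Frobenius-norm error, guarded against an all-zero target
        return float(np.linalg.norm(residual) / max(np.linalg.norm(target), 1e-12))

    a_norm = float(np.linalg.norm(A, ord=np.inf))  # max absolute row sum of A
    M = np.block([[A, B], [C, D]])
    return {
        "a_inf_norm": a_norm,
        "well_posed": a_norm <= kappa,
        "state_rel_err": rel_err(A @ X + B @ U - Z, Z),
        "output_rel_err": rel_err(C @ X + D @ U - Y, Y),
        "zero_fraction": float(np.mean(M == 0.0)),  # sparsity of the stacked weights
    }


rng = np.random.default_rng(0)
n, p, q, m = 5, 3, 2, 20  # state, input, output dimensions and number of samples
A = rng.standard_normal((n, n))
A *= 0.5 / np.linalg.norm(A, ord=np.inf)
B = rng.standard_normal((n, p))
C = rng.standard_normal((q, n))
D = rng.standard_normal((q, p))
X = np.maximum(rng.standard_normal((n, m)), 0.0)
U = rng.standard_normal((p, m))
Z = A @ X + B @ U
Y = C @ X + D @ U
report = check_sim_solution(A, B, C, D, X, U, Z, Y, kappa=0.9)
```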

class torchidl.sim.solvers.admm_consensus_multigpu.ADMMMultiGPUSolver(gpu_ids, num_epoch_ab=1000, num_epoch_cd=100, lambda_y=1e-06, lambda_z=1e-06, rho_ab=10.0, rho_cd=10.0, batch_feature_size=100, regen_states=False, tol=1e-06)

ADMM Consensus Solver with computation distributed over multiple GPUs. It is recommended for large-scale problems when multiple GPUs are available.

Parameters:
  • gpu_ids (List[Union[int, torch.device]]) – List of GPU IDs or devices.

  • num_epoch_ab (int, optional) – Number of epochs for solving A and B. Defaults to 1000.

  • num_epoch_cd (int, optional) – Number of epochs for solving C and D. Defaults to 100.

  • lambda_y (float, optional) – Lasso regularization parameter for Y. Defaults to 1e-6.

  • lambda_z (float, optional) – Lasso regularization parameter for Z. Defaults to 1e-6.

  • rho_ab (float, optional) – ADMM’s rho parameter for A and B. Defaults to 10.0.

  • rho_cd (float, optional) – ADMM’s rho parameter for C and D. Defaults to 10.0.

  • batch_feature_size (int, optional) – Number of columns to solve in each solving iteration. This is used to control the memory usage. The solver is performed (total_rows // batch_feature_size + 1) times. Defaults to 100.

  • regen_states (bool, optional) – Whether to regenerate states. Defaults to False.

  • tol (float, optional) – Tolerance for zeroing out weights. Defaults to 1e-6.
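Each row of \((A, B)\) or \((C, D)\) is an independent subproblem, so batches of rows can be handed to different devices without any communication between them. The sketch below shows one possible split: fixed-size batches assigned round-robin. It only illustrates the idea. The function name `assign_row_batches` and the round-robin policy are hypothetical, and torchidl's actual scheduling may differ. Devices are treated as plain labels here, so integers or `torch.device` objects would work as dictionary keys as well.

```python
def assign_row_batches(total_rows, batch_feature_size, devices):
    """Hypothetical sketch: split row indices into batches and assign them round-robin to devices."""
    assignments = {dev: [] for dev in devices}
    for k, start in enumerate(range(0, total_rows, batch_feature_size)):
        rows = list(range(start, min(start + batch_feature_size, total_rows)))
        assignments[devices[k % len(devices)]].append(rows)  # batch k goes to device k mod len(devices)
    return assignments


plan = assign_row_batches(total_rows=250, batch_feature_size=100, devices=["cuda:0", "cuda:1"])
```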

solve(X, U, Z, Y, model_config, plot_loss=False)

Solve State-driven Implicit Model using the ADMM consensus algorithm.

Parameters:
  • X (np.ndarray) – Post-activation array.

  • U (np.ndarray) – Input array.

  • Z (np.ndarray) – Pre-activation array.

  • Y (np.ndarray) – Output array.

  • model_config (Dict[str, Any]) – Model configuration with the following keys:
      - activation_fn (Callable): Activation function used by the implicit model.
      - device (str): Implicit model’s device.
      - atol (float): Equilibrium function’s tolerance.
      - kappa (float): Well-posedness condition parameter.

  • plot_loss (Optional[bool], optional) – Whether to plot loss curve. Useful for tuning hyperparameters and debugging. Defaults to False.

Returns:

Implicit model’s parameters.

Return type:

A, B, C, D (np.ndarray)

Example usage:

import torch
from torchidl import SIM
from torchidl import ADMMSolver

# Load dataset
dataloader = ...

# Load a pretrained explicit model
explicit_model = ...
explicit_model.load_state_dict(torch.load("checkpoint.pt"))

# Define the SIM model
sim = SIM(activation_fn=torch.nn.functional.relu, device="cuda", dtype=torch.float32)

# Define the solver and solve the state-driven training problem
solver = ADMMSolver()
sim.train(solver=solver, model=explicit_model, dataloader=dataloader)