import logging
import torch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, Dict, Optional, Tuple
from .solver import BaseSolver
from ..utils import fixpoint_iteration
logger = logging.getLogger(__name__)
[docs]
class ADMMSolver(BaseSolver):
r"""
ADMM Consensus Solver on a single GPU.
Args:
num_epoch_ab (int, optional): Number of epochs for solving A and B. Defaults to 1000.
num_epoch_cd (int, optional): Number of epochs for solving C and D. Defaults to 100.
lambda_y (float, optional): Lasso regularization parameter for Y. Defaults to 1.0.
lambda_z (float, optional): Lasso regularization parameter for Z. Defaults to 1.0.
rho_ab (float, optional): ADMM's rho parameter for A and B. Defaults to 10.0.
rho_cd (float, optional): ADMM's rho parameter for C and D. Defaults to 10.0.
batch_feature_size (int, optional): Number of columns to solve in each solving iteration. This is used to control the memory usage.
The solver is performed (total_rows // batch_feature_size + 1) times. Defaults to 100.
regen_states (bool, optional): Whether to regenerate states. Defaults to False.
tol (float, optional): Tolerance for zeroing out weights. Defaults to 1e-6.
"""
def __init__(
self,
num_epoch_ab : int = 1000,
num_epoch_cd : int = 100,
lambda_y : float = 1e-6,
lambda_z : float = 1e-6,
rho_ab : float = 10.0,
rho_cd : float = 10.0,
batch_feature_size : int = 100,
regen_states : bool = False,
tol : float = 1e-6,
):
self.num_epoch_ab = num_epoch_ab
self.num_epoch_cd = num_epoch_cd
self.lambda_y = lambda_y
self.lambda_z = lambda_z
self.rho_ab = rho_ab
self.rho_cd = rho_cd
self.batch_feature_size = batch_feature_size
self.regen_states = regen_states
self.tol = tol
[docs]
def solve(
self,
X : np.ndarray,
U : np.ndarray,
Z : np.ndarray,
Y : np.ndarray,
model_config : Dict[str, Any],
plot_loss : Optional[bool] = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Solve State-driven Implicit Model using the ADMM consensus algorithm.
Args:
X (np.ndarray): Post-activation array.
U (np.ndarray): Input array.
Z (np.ndarray): Pre-activation array.
Y (np.ndarray): Output array.
model_config (Dict[str, Any]): Model configuration.
- activation_fn (Callable): Activation function used by the implicit model.
- device (str): Implicit model's device.
- atol (float): Equilibrium function's tolerance.
- kappa (float): Wellposedness condition parameter.
plot_loss (Optional[bool], optional): Whether to plot loss curve. Useful for tuning hyperparameters and debugging. Defaults to False.
Returns:
A, B, C, D (np.ndarray): Implicit model's parameters.
"""
n, m, p, q = X.shape[0], X.shape[1], U.shape[0], Y.shape[0]
X, U, Z, Y = torch.tensor(X), torch.tensor(U), torch.tensor(Z), torch.tensor(Y)
logger.info("===== Start solving A and B =====")
AB = self.parallel_solve_matrix(torch.hstack([X.T, U.T]), Z.T, is_y=False, n=n, kappa=model_config['kappa'], plot_loss=plot_loss, device=model_config['device'])
A = AB[:, :n]
B = AB[:, n:]
if self.regen_states:
X = fixpoint_iteration(A, B, U, model_config['activation_fn'], model_config['device'], atol=model_config['atol']).cpu()
logger.info("===== Start solving C and D =====")
CD = self.parallel_solve_matrix(torch.hstack([X.T, U.T]), Y.T, is_y=True, n=n, kappa=model_config['kappa'], plot_loss=plot_loss, device=model_config['device'])
C = CD[:, :n]
D = CD[:, n:]
return A, B, C, D
def parallel_solve_matrix(self, X, Y, is_y, n, kappa, plot_loss, device):
"""
Dividing the data matrix for memory controlled solve.
"""
total_rows = Y.shape[1]
batch_rows_length = self.batch_feature_size
num_batches = total_rows // (batch_rows_length) + 1
W = None
loss = 0.0
for k in range(num_batches):
logger.info(f"Solving batch feature {k+1}/{num_batches}")
start_idx = k * batch_rows_length
end_idx = min((k + 1) * batch_rows_length, total_rows)
Y_batch = Y[:, start_idx:end_idx]
W_k, loss_k = self.run_solve_opt_problem(X, Y_batch, is_y, n, k, kappa, plot_loss=plot_loss, device=device)
W = np.vstack([W, W_k]) if W is not None else W_k
loss += loss_k
logger.info(f"Total Lasso loss: {loss}")
return W
def run_solve_opt_problem(self, X, Y, is_y, n, k, kappa, plot_loss, device):
"""
ADMM Solve Wrapper.
"""
if is_y:
num_epoch = self.num_epoch_cd
rho = self.rho_cd
lambda_yz = self.lambda_y
else:
num_epoch = self.num_epoch_ab
rho = self.rho_ab
lambda_yz = self.lambda_z
if is_y:
admm = ADMM_CD(X.shape[1], Y.shape[1], rho, lambda_yz, device=device)
else:
admm = ADMM_AB(X.shape[1], Y.shape[1], n, rho, lambda_yz, kappa, device=device)
losses = []
with torch.no_grad():
for i in tqdm(range(num_epoch), desc="Training Epochs"):
admm.step(X, Y)
loss = admm.LassoObjective(X, Y)
tqdm.write(f"Loss: {loss}")
losses.append(loss)
# Plot losses
if plot_loss:
plt.figure()
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Lasso Objective")
plt.yscale("log")
plt.title(f"Training Loss")
plt.savefig(f"loss_k_{k}_isy_{is_y}.png")
# Save the loss trace
np.save(f"loss_trace_k_{k}_isy_{is_y}.npy", losses)
if is_y:
result = admm.CD.T.clone().detach().cpu().numpy()
else:
result = admm.avg.T.clone().detach().cpu().numpy()
result[np.abs(result) <= self.tol] = 0
return result, losses[-1]
class ADMM_AB:
def __init__(
self,
D : int,
Q : int,
n : int,
rho : float,
lambda_yz : float,
kappa : float,
device : str
):
"""
Solver class to solve A and B (concatenated) using consensus ADMM.
Args:
D (int): Number of features.
Q (int): Number of samples.
n (int): Number of rows of A.
rho (float): ADMM's rho parameter.
lambda_yz (float): Lasso regularization parameter.
kappa (float): Wellposedness condition parameter.
device (str): Device.
"""
self.D = D
self.Q = Q
self.n = n
self.device = device
self.nu_AB = torch.zeros(self.D, self.Q, device=device, requires_grad=False) # ADMM's dual variable for AB
self.nu_Z = torch.zeros(self.D, self.Q, device=device, requires_grad=False) # ADMM's dual variable for Z
self.nu_M = torch.zeros(self.D, self.Q, device=device, requires_grad=False) # ADMM's dual variable for M
self.rho = rho
self.AB = torch.randn(self.D, self.Q, device=device, requires_grad=False) # running estimate of AB
self.Z = torch.zeros(self.D, self.Q, device=device, requires_grad=False) # Z
self.M = torch.zeros(self.D, self.Q, device=device, requires_grad=False) # M
self.avg = torch.zeros(self.D, self.Q, device=device, requires_grad=False) # Average of AB, Z, M => Final solution of AB
self.lambda_yz = lambda_yz
self.kappa = kappa
@torch.no_grad()
def step(self, X, y):
"""
ADMM's update step.
"""
X = X.to(self.device)
y = y.to(self.device)
t1 = X.T.matmul(X) + self.rho * torch.eye(self.D, device=self.device)
t2 = X.T.matmul(y) + self.rho * (self.avg - self.nu_AB)
self.AB = torch.linalg.solve(t1, t2)
self.Z = torch.sign(self.avg - self.nu_Z) * torch.clamp(torch.abs(self.avg - self.nu_Z) - self.lambda_yz / self.rho, min=0)
self.M[:self.n,:] = self.project_w((self.avg - self.nu_M)[:self.n,:].T).T
self.M[self.n:,:] = (self.avg - self.nu_M)[self.n:,:]
self.avg = (self.AB + self.Z + self.M) / 3
self.nu_AB = self.nu_AB + (self.AB - self.avg)
self.nu_Z = self.nu_Z + (self.Z - self.avg)
self.nu_M = self.nu_M + (self.M - self.avg)
def project_w(self, matrix):
"""
Project the matrix to the L1 norm ball.
"""
A_np = matrix.clone().cpu().numpy()
v = self.kappa
x = np.abs(A_np).sum(axis=-1)
for idx in np.where(x > v)[0]:
a_orig = A_np[idx, :]
a_sign = np.sign(a_orig)
a_abs = np.abs(a_orig)
a = np.sort(a_abs)
s = np.sum(a) - v
l = float(len(a))
for i in range(len(a)):
if s / l > a[i]:
s -= a[i]
l -= 1
else:
break
alpha = s / l if l > 0 else np.max(a_abs)
a = a_sign * np.maximum(a_abs - alpha, 0)
# assert np.isclose(np.abs(a).sum(), v)
A_np[idx, :] = a
proj = torch.tensor(A_np, dtype=self.avg.dtype, device=self.device)
return proj
@torch.no_grad()
def LassoObjective(self, X, y):
"""
Evaluate the Lasso objective.
"""
X = X.to(self.device)
y = y.to(self.device)
return (0.5 * torch.norm(X.matmul(self.avg) - y)**2 + self.lambda_yz * torch.sum(torch.abs(self.avg))).item()
class ADMM_CD:
def __init__(
self,
D : int,
Q : int,
rho : float,
lambda_yz : float,
device : str
):
"""
Solver class to solve C and D (concatenated) using ADMM.
Args:
D (int): Number of features.
Q (int): Number of samples.
rho (float): ADMM's rho parameter.
lambda_yz (float): Lasso regularization parameter.
device (str): Device.
"""
self.D = D
self.Q = Q
self.device = device
self.nu = torch.zeros(self.D, self.Q, device=device)
self.rho = rho
self.CD = torch.randn(self.D, self.Q, device=device)
self.Z = torch.zeros(self.D, self.Q, device=device)
self.lambda_yz = lambda_yz
@torch.no_grad()
def step(self, X, y):
"""
ADMM's update step.
"""
X = X.to(self.device)
y = y.to(self.device)
t1 = X.T.matmul(X) + self.rho * torch.eye(self.D, device=self.device)
t2 = X.T.matmul(y) + self.rho * self.Z - self.nu
self.CD = torch.linalg.solve(t1, t2)
self.Z = self.CD + self.nu / self.rho - (self.lambda_yz / self.rho) * torch.sign(self.Z).to(self.device)
self.nu = self.nu + self.rho * (self.CD - self.Z)
@torch.no_grad()
def LassoObjective(self, X, y):
"""
Evaluate the Lasso objective.
"""
X = X.to(self.device)
y = y.to(self.device)
return (0.5 * torch.norm(X.matmul(self.CD) - y)**2 + self.lambda_yz * torch.sum(torch.abs(self.CD))).item()