"""Projected gradient descent solver forcing A to be low-rank and well-posed (torchidl.sim.solvers.projected_gd_lowrank)."""

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, Dict, Tuple, Optional

from .solver import BaseSolver
from ..utils import fixpoint_iteration

logger = logging.getLogger(__name__)

class LowRankModel(torch.nn.Module):
    """Low-rank parameterisation ``A = L @ R.T`` together with a dense input matrix ``B``.

    The forward pass evaluates ``A @ X + B @ U`` without materialising ``A``.
    ``project_LR`` keeps every row of ``L`` and every column of ``R`` inside the
    l1 ball of radius ``kappa``. Because the infinity norm is submultiplicative,
    this bounds the infinity norm of ``L @ R.T`` by ``kappa ** 2``.

    Args:
        n: State dimension (rows of ``L``, ``R`` and ``B``).
        p: Input dimension (columns of ``B``).
        rank: Number of columns of ``L`` and ``R``. ``None``, or a value larger
            than ``n``, falls back to ``n`` (full rank).
        kappa: Radius of the l1 ball used by ``project_LR``.
        device: Device on which parameters are allocated and inputs are placed.
    """

    def __init__(self, n, p, rank, kappa, device):
        super().__init__()
        self.device = device
        self.kappa = kappa

        if rank is None:
            rank = n
        elif rank > n:
            logger.warning("Rank is set to be larger than max rank, setting it back to full rank")
            rank = n

        self.L = nn.Parameter(torch.randn(n, rank, device=device))
        self.R = nn.Parameter(torch.randn(n, rank, device=device))
        self.B = nn.Parameter(torch.randn(n, p, device=device))

    def forward(self, X, U):
        """Return ``L @ (R.T @ X) + B @ U``.

        Inputs are moved to the model's device and cast to the parameters'
        dtype. ``torch.matmul`` does not promote mixed dtypes, so without the
        cast, float64 tensors created from NumPy arrays would raise.
        """
        X = X.to(device=self.device, dtype=self.L.dtype)
        U = U.to(device=self.device, dtype=self.B.dtype)
        return self.L @ (self.R.T @ X) + self.B @ U

    def project_LR(self):
        """Project ``L`` (row-wise) and ``R`` (column-wise) onto the l1 ball of radius ``kappa``.

        The parameters are updated in place, so the optimizer keeps tracking
        the same ``nn.Parameter`` objects.
        """
        with torch.no_grad():
            self.L.copy_(self.project_w(self.L, self.kappa))
            # Columns of R are the rows of R.T.
            self.R.copy_(self.project_w(self.R.T, self.kappa).T)

    def project_w(self, matrix, v=0.99):
        """Project each row of ``matrix`` onto the l1 ball of radius ``v``.

        Rows whose l1 norm is already ``<= v`` are returned unchanged. Every
        other row ``a`` is replaced by ``sign(a) * max(|a| - theta, 0)``, where
        the threshold ``theta`` makes the row's l1 norm equal to ``v``. It is
        found with the sort + cumulative-sum method, vectorised over rows and
        computed on ``matrix``'s own device (no CPU round trip).

        Args:
            matrix: 2-D tensor; it is not modified and no gradient flows through the result.
            v: Radius of the l1 ball.

        Returns:
            A new tensor with the same shape, dtype and device as ``matrix``.
        """
        mat = matrix.detach()
        abs_mat = mat.abs()
        over = abs_mat.sum(dim=-1) > v
        if not bool(over.any()):
            # Nothing to project (also covers matrices with zero columns).
            return mat.clone()

        sorted_abs, _ = torch.sort(abs_mat, dim=-1, descending=True)
        cssv = sorted_abs.cumsum(dim=-1) - v
        idx = torch.arange(1, mat.shape[-1] + 1, device=mat.device)
        # Entry j (1-based) stays in the support iff u_j - cssv_j / j > 0, i.e. j * u_j > cssv_j.
        active = sorted_abs * idx.to(mat.dtype) > cssv
        # rho = last 1-based index still in the support; clamp guards the v <= 0 corner case.
        rho = torch.where(active, idx, torch.zeros_like(idx)).amax(dim=-1, keepdim=True).clamp(min=1)
        theta = cssv.gather(-1, rho - 1) / rho.to(mat.dtype)

        projected = mat.sign() * (abs_mat - theta).clamp(min=0)
        # Only rows that exceeded the radius are replaced.
        return torch.where(over.unsqueeze(-1), projected, mat)

class Trainer:
    """Full-batch projected gradient descent loop for a low-rank model.

    After every Adam step, the model's ``project_LR`` method is called, which
    turns plain gradient descent into projected gradient descent.

    Args:
        num_epoch: Number of optimisation steps (one full-batch step per epoch).
        lambda_z: Weight of the Frobenius-norm penalties on ``model.L`` and ``model.R``.
        lr: Adam learning rate.
        verbose_epoch: Log the loss every ``verbose_epoch`` epochs. ``0`` (or any
            falsy value) disables the periodic log instead of raising
            ``ZeroDivisionError``.
    """

    def __init__(self, num_epoch, lambda_z, lr, verbose_epoch):
        self.num_epoch = num_epoch
        self.lamb = lambda_z
        self.lr = lr
        self.verbose_epoch = verbose_epoch

    def train(self, model, X, U, Z):
        """Fit ``model(X, U)`` to ``Z`` with Adam followed by ``model.project_LR()``.

        Args:
            model: Module exposing ``L`` and ``R`` parameters, ``forward(X, U)`` and ``project_LR()``.
            X: First input forwarded to ``model``.
            U: Second input forwarded to ``model``.
            Z: Target of the MSE term. It is moved to the output's device and
                cast to the output's dtype, so float64 targets work with a
                float32 model.

        Returns:
            ``(model, losses)``, where ``losses[i]`` is the training loss
            (MSE + penalties) of epoch ``i``, computed before that epoch's
            projection.
        """
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        losses = []
        target = None
        for epoch in tqdm(range(self.num_epoch)):
            optimizer.zero_grad()
            output = model(X, U)
            if target is None:
                # Convert once: the output's device/dtype is only known after the first forward pass.
                target = Z.to(device=output.device, dtype=output.dtype)
            loss = F.mse_loss(output, target) + self.lamb * torch.norm(model.L, p=2) + self.lamb * torch.norm(model.R, p=2)
            loss.backward()
            optimizer.step()
            model.project_LR()
            loss_value = loss.item()  # a single host sync per epoch
            losses.append(loss_value)
            if self.verbose_epoch and epoch % self.verbose_epoch == 0:
                logger.info(f"Loss at epoch {epoch}: {loss_value}")
        return model, losses

class ProjectedGDLowRankSolver(BaseSolver):
    r"""
    Train State-driven Implicit Model using projected gradient descent to force A low-rank and well-posed.

    A = L @ R.T and B are fitted by projected gradient descent (see ``LowRankModel``
    and ``Trainer``). C and D are then solved with NumPy's least-squares solver.

    Args:
        rank (int, optional): Rank of A (number of columns of L and R). Defaults to full rank.
        num_epoch (int): Number of epochs to train A, B.
        lambda_z (float): Weight of the Frobenius-norm penalty on the low-rank factors L and R of A.
        lr (float): Learning rate.
        verbose_epoch (int): Log the loss every ``verbose_epoch`` epochs.
        regen_states (bool, optional): Whether to regenerate the states X from the fitted A, B
            before solving C, D. Defaults to False.
        tol (float, optional): Entries of C, D with magnitude ``<= tol`` are zeroed out. Defaults to 1e-6.
    """

    def __init__(
        self,
        rank: Optional[int] = None,
        num_epoch: int = 10000,
        lambda_z: float = 1e-6,
        lr: float = 1e-3,
        verbose_epoch: int = 100,
        regen_states: bool = False,
        tol: float = 1e-6,
    ):
        self.rank = rank
        self.num_epoch = num_epoch
        self.lambda_z = lambda_z
        self.lr = lr
        self.verbose_epoch = verbose_epoch
        self.regen_states = regen_states
        self.tol = tol

    def solve(
        self,
        X: np.ndarray,
        U: np.ndarray,
        Z: np.ndarray,
        Y: np.ndarray,
        model_config: Dict[str, Any],
        plot_loss: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Solve the implicit model and force A to be low-rank.

        Args:
            X (np.ndarray): Post-activation array, one column per sample.
            U (np.ndarray): Input array, one column per sample.
            Z (np.ndarray): Pre-activation array.
            Y (np.ndarray): Output array.
            model_config (Dict[str, Any]): Implicit model configuration.
                - activation_fn (Callable): Activation function used by the implicit model.
                - device (str): Implicit model's device.
                - atol (float): Equilibrium function's tolerance.
                - kappa (float): Wellposedness condition parameter.
            plot_loss (bool, optional): Whether to plot the loss. Defaults to False.

        Returns:
            A, B, C, D (np.ndarray): Implicit model's parameters.
        """
        n, p = X.shape[0], U.shape[0]
        X, U, Z = torch.tensor(X), torch.tensor(U), torch.tensor(Z)

        logger.info("===== Start parallel solve for A and B =====")
        L, R, B = self.lowrank_solve_matrix(
            X, U, Z,
            n=n,
            p=p,
            kappa=model_config["kappa"],
            device=model_config["device"],
            plot_loss=plot_loss,
        )
        A = L @ R.T
        logger.info(f"Rank A: {np.linalg.matrix_rank(A)}")
        logger.info(f"Rank B: {np.linalg.matrix_rank(B)}")

        if self.regen_states:
            X = fixpoint_iteration(
                A, B, U, model_config["activation_fn"], model_config["device"], atol=model_config["atol"]
            ).cpu()

        logger.info("===== Start parallel solve for C and D =====")
        # Rows of the design matrix are samples: [x_i^T, u_i^T].
        CD = self.state_matching(np.hstack([X.cpu().numpy().T, U.cpu().numpy().T]), Y.T)
        C = CD[:, :n]
        D = CD[:, n:]

        return A, B, C, D

    def state_matching(self, X, Y):
        """
        Solve ``X @ W ≈ Y`` in the least-squares sense with ``numpy.linalg.lstsq``.

        Entries of ``W`` with magnitude ``<= self.tol`` are zeroed out before
        ``W.T`` is returned.
        """
        W, _, r, _ = np.linalg.lstsq(X, Y, rcond=None)
        loss = np.mean(np.square(X @ W - Y))
        # NOTE(review): the label says "Lasso", but no l1 penalty is applied here (plain least squares).
        logger.info(f"Total Lasso loss: {loss}")
        logger.info(f"Data rank: {r}")
        W[np.abs(W) <= self.tol] = 0
        return W.T

    def lowrank_solve_matrix(self, X, U, Z, n, p, kappa, device, plot_loss):
        """
        Solve the state matching problem with projected gradient descent to ensure A is low-rank and well-posed.

        Returns:
            L, R, B as NumPy arrays (A = L @ R.T).
        """
        model = LowRankModel(n=n, p=p, rank=self.rank, kappa=kappa, device=device)
        trainer = Trainer(
            num_epoch=self.num_epoch,
            lambda_z=self.lambda_z,
            lr=self.lr,
            verbose_epoch=self.verbose_epoch,
        )
        model, losses = trainer.train(model, X, U, Z)

        if plot_loss:
            fig, ax = plt.subplots()
            try:
                ax.plot(losses)
                ax.set_xlabel("Epoch")
                ax.set_ylabel("MSE loss")
                ax.set_yscale("log")
                ax.set_title("Training Loss")
                fig.savefig("loss_AB.png")
            finally:
                # Release the figure; otherwise each call leaks one open pyplot figure.
                plt.close(fig)
            # Save the loss trace
            np.save("loss_AB_trace.npy", losses)

        # clone() so the returned arrays never alias the CPU parameter storage.
        L = model.L.clone().detach().cpu().numpy()
        R = model.R.clone().detach().cpu().numpy()
        B = model.B.clone().detach().cpu().numpy()

        if losses:  # empty when num_epoch == 0
            logger.info(f"Total loss: {losses[-1]}")

        return L, R, B