Source code for lightning_uq_box.uq_methods.sgld

# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.

"""Stochastic Gradient Langevin Dynamics (SGLD) model."""
# TO DO:
# SGLD with ensembles

import os
from collections.abc import Callable, Iterator
from typing import Any, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.optimizer import Optimizer

from lightning_uq_box.uq_methods import DeterministicModel

from .utils import (
    _get_num_outputs,
    default_classification_metrics,
    default_regression_metrics,
    process_classification_prediction,
    process_regression_prediction,
    save_classification_predictions,
    save_regression_predictions,
)


# SGLD Optimizer from Izmailov, currently in __init__.py
class SGLD(Optimizer):
    """Stochastic Gradient Langevin Dynamics optimizer.

    Every step applies a plain gradient-descent update (optionally with
    L2 weight decay) and then perturbs each parameter with Gaussian noise
    scaled by ``noise_factor * sqrt(2 * lr)``.

    If you use this optimizer in your research, please cite the following paper:

    * https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf
    """

    def __init__(
        self,
        params: Iterator[nn.parameter.Parameter],
        lr: float,
        noise_factor: float,
        weight_decay: float = 0.0,
    ) -> None:
        """Initialize a new instance of the SGLD optimizer.

        Args:
            params: model parameters
            lr: initial learning rate
            noise_factor: parameter denoting how much noise to inject
                in the SGD update
            weight_decay: weight decay parameter for SGLD optimizer
        """
        defaults = dict(lr=lr, noise_factor=noise_factor, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.lr = lr

    def step(
        self, closure: Optional[Callable[[], Tensor]] = None
    ) -> Optional[Tensor]:
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model (and
                populates the gradients) and returns the loss

        Returns:
            the loss returned by ``closure``, or ``None`` when no closure
            was supplied
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            # standard deviation of the injected Langevin noise
            noise_scale = group["noise_factor"] * (2.0 * lr) ** 0.5

            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad.data
                if weight_decay != 0:
                    # out-of-place on purpose: the stored gradient must not
                    # be modified, otherwise the decay term would compound
                    # if step() runs again before gradients are recomputed
                    d_p = d_p.add(p.data, alpha=weight_decay)

                p.data.add_(d_p, alpha=-lr)
                p.data.add_(torch.randn_like(d_p), alpha=noise_scale)

        return loss
class SGLDBase(DeterministicModel):
    """Stochastic Gradient Langevin Dynamics (SGLD) base model.

    Trains the wrapped model with the :class:`SGLD` optimizer under manual
    optimization and stores model snapshots during the last
    ``n_sgld_samples`` epochs. The snapshots are used by the subclasses
    to form predictions.

    If you use this model in your research, please cite the following paper:

    * https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module,
        lr: float,
        weight_decay: float,
        noise_factor: float,
        n_sgld_samples: int,
    ) -> None:
        """Initialize a new instance of SGLD model.

        Args:
            model: pytorch model
            loss_fn: choice of loss function
            lr: initial learning rate for SGLD optimizer
            weight_decay: weight decay parameter for SGLD optimizer
            noise_factor: parameter denoting how much noise to inject
                in the SGD update
            n_sgld_samples: number of sgld samples to collect
        """
        # NOTE(review): the two trailing ``None`` arguments are presumably
        # the optimizer / lr-scheduler slots of DeterministicModel, which
        # are unused because the optimizer is built in
        # ``configure_optimizers`` -- confirm against the base class.
        super().__init__(model, loss_fn, None, None)

        self.save_hyperparameters(ignore=["model", "loss_fn"])
        self.models: list[nn.Module] = []
        # paths of the snapshots written in ``on_train_epoch_end``
        self.dir_list: list[str] = []

        # manual optimization with SGLD optimizer
        self.automatic_optimization = False

    def configure_optimizers(self) -> dict[str, Any]:
        """Initialize the optimizer.

        Returns:
            dictionary holding the SGLD optimizer under ``"optimizer"``
        """
        optimizer = SGLD(
            params=self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            noise_factor=self.hparams.noise_factor,
        )
        return {"optimizer": optimizer}

    def on_train_start(self) -> None:
        """Create the snapshot directory when training starts."""
        self.snapshot_dir = os.path.join(
            self.trainer.default_root_dir, "model_snapshots"
        )
        # exist_ok: fitting again in the same root dir must not crash
        os.makedirs(self.snapshot_dir, exist_ok=True)

    def on_train_epoch_end(self) -> None:
        """Save model ckpts after epoch and log training metrics."""
        # save ckpts for n_sgld_sample epochs before end (max_epochs)
        first_snapshot_epoch = self.trainer.max_epochs - self.hparams.n_sgld_samples
        if self.current_epoch >= first_snapshot_epoch:
            ckpt_path = os.path.join(
                self.snapshot_dir, f"{self.current_epoch}_model.ckpt"
            )
            torch.save(self.model.state_dict(), ckpt_path)
            self.dir_list.append(ckpt_path)

        # log train metrics
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def _predict_with_snapshots(self, X: Tensor) -> Tensor:
        """Run every stored snapshot on ``X`` and stack the outputs.

        Args:
            X: input tensor

        Returns:
            detached model outputs stacked along a new last dimension,
            one slice per snapshot in ``self.dir_list``

        Raises:
            RuntimeError: if no snapshot has been saved yet
        """
        if not self.dir_list:
            raise RuntimeError(
                "No SGLD model snapshots available for prediction; "
                "train the model so that snapshots are saved first."
            )

        preds: list[Tensor] = []
        for ckpt_path in self.dir_list:
            # self.model is overwritten in place; after the loop it holds
            # the weights of the last snapshot
            self.model.load_state_dict(torch.load(ckpt_path, weights_only=True))
            preds.append(self.model(X))

        return torch.stack(preds, dim=-1).detach()


class SGLDRegression(SGLDBase):
    """Stochastic Gradient Langevin Dynamics method for regression."""

    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module,
        lr: float,
        weight_decay: float,
        noise_factor: float,
        burnin_epochs: int,
        n_sgld_samples: int,
    ) -> None:
        """Initialize a new instance of SGLD model.

        Args:
            model: pytorch model
            loss_fn: choice of loss function
            lr: initial learning rate for SGLD optimizer
            weight_decay: weight decay parameter for SGLD optimizer
            noise_factor: parameter denoting how much noise to inject
                in the SGD update
            burnin_epochs: number of epochs to fit mse loss
            n_sgld_samples: number of sgld samples to collect
        """
        super().__init__(model, loss_fn, lr, weight_decay, noise_factor, n_sgld_samples)
        self.burnin_epochs = burnin_epochs

    def setup_task(self) -> None:
        """Set up task specific metrics."""
        self.train_metrics = default_regression_metrics("train")
        self.val_metrics = default_regression_metrics("val")
        self.test_metrics = default_regression_metrics("test")

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Adapt model output to be compatible for metric computation.

        Args:
            out: output from :meth:`self.forward` [batch_size x (mu, sigma)]

        Returns:
            extracted mean used for metric computation [batch_size x 1]
        """
        return out[:, 0:1]

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        During the first ``burnin_epochs`` epochs the MSE between the first
        output column and the target is optimized; afterwards ``loss_fn``
        is used.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        sgld_opt = self.optimizers()

        X, y = batch[self.input_key], batch[self.target_key]
        out = self.forward(X)

        def closure() -> Tensor:
            """Closure function for optimizer."""
            # clear stale gradients right before the backward pass
            sgld_opt.zero_grad()
            if self.current_epoch < self.hparams.burnin_epochs:
                loss = nn.functional.mse_loss(self.adapt_output_for_metrics(out), y)
            else:
                # after burn-in train with the configured loss (e.g. nll)
                loss = self.loss_fn(out, y)
            self.manual_backward(loss)
            return loss

        loss = sgld_opt.step(closure=closure)

        self.log("train_loss", loss, batch_size=X.shape[0])  # logging to Logger
        self.train_metrics(self.adapt_output_for_metrics(out), y)

        return loss

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict with the stored SGLD snapshots.

        Args:
            X: input tensor
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            output dictionary produced by
            ``process_regression_prediction`` from the stacked snapshot
            predictions

        Raises:
            RuntimeError: if no snapshot has been saved yet
        """
        # stacked along the last dim: [..., n_snapshots]
        preds = self._predict_with_snapshots(X)
        return process_regression_prediction(preds)

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_regression_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )


class SGLDClassification(SGLDBase):
    """Stochastic Gradient Langevin Dynamics method for classification."""

    valid_tasks = ["multiclass", "binary", "multilabel"]
    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module,
        lr: float,
        weight_decay: float,
        noise_factor: float,
        task: str = "multiclass",
        n_sgld_samples: int = 20,
    ) -> None:
        """Initialize a new instance of SGLD model.

        Args:
            model: pytorch model to train with SGLD
            loss_fn: choice of loss function
            lr: initial learning rate
            weight_decay: weight decay parameter for SGLD optimizer
            noise_factor: parameter denoting how much noise to inject
                in the SGD update
            task: classification task, one of
                ["multiclass", "binary", "multilabel"]
            n_sgld_samples: number of sgld samples to collect
        """
        assert task in self.valid_tasks, (
            f"task must be one of {self.valid_tasks}, got {task!r}"
        )
        # NOTE(review): these are set before super().__init__() presumably
        # because the base initializer triggers ``setup_task``, which reads
        # them -- confirm against DeterministicModel before reordering.
        self.task = task
        self.num_classes = _get_num_outputs(model)
        super().__init__(model, loss_fn, lr, weight_decay, noise_factor, n_sgld_samples)

    def setup_task(self) -> None:
        """Set up task specific metrics."""
        self.train_metrics = default_classification_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_classification_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_classification_metrics(
            "test", self.task, self.num_classes
        )

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Adapt model output to be compatible for metric computation.

        Args:
            out: output from :meth:`self.forward`

        Returns:
            ``out`` unchanged
        """
        return out

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        sgld_opt = self.optimizers()

        X, y = batch[self.input_key], batch[self.target_key]
        out = self.forward(X)

        def closure() -> Tensor:
            """Closure function for optimizer."""
            # clear stale gradients right before the backward pass
            sgld_opt.zero_grad()
            loss = self.loss_fn(out, y)
            self.manual_backward(loss)
            return loss

        loss = sgld_opt.step(closure=closure)

        self.log("train_loss", loss, batch_size=X.shape[0])  # logging to Logger
        self.train_metrics(self.adapt_output_for_metrics(out), y)

        return loss

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict with the stored SGLD snapshots.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            output dictionary produced by
            ``process_classification_prediction`` from the stacked
            snapshot predictions

        Raises:
            RuntimeError: if no snapshot has been saved yet
        """
        # stacked along the last dim: [..., n_snapshots]
        preds = self._predict_with_snapshots(X)
        return process_classification_prediction(preds)

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )