# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
"""Laplace Approximation model."""
import copy
import os
from typing import Any
import torch
from laplace import Laplace
from torch import Tensor
from tqdm import trange
from lightning_uq_box.uq_methods import BaseModule
from .utils import (
_get_num_inputs,
_get_num_outputs,
default_classification_metrics,
default_regression_metrics,
save_classification_predictions,
save_regression_predictions,
)
# TODO check whether Laplace fitting procedure can be implemented as working
# over training_step in lightning
def tune_prior_precision_and_sigma(
model: Laplace,
tune_precision_lr: float,
n_epochs_tune_precision: int,
tune_prior_precision: bool,
tune_sigma_noise: bool,
):
"""Tune the prior precision and sigma noise via Empirical Bayes.
Args:
model: laplace model
tune_precision_lr: learning rate for tuning prior precision
n_epochs_tune_precision: number of epochs to tune prior precision
tune_prior_precision: whether to tune prior precision
tune_sigma_noise: whether to tune sigma noise
"""
with torch.inference_mode(False):
optim_params = []
if tune_prior_precision:
log_prior = torch.ones(1, requires_grad=True)
optim_params.append(log_prior)
else:
log_prior = torch.log(model.prior_precision)
if tune_sigma_noise:
log_sigma = torch.ones(1, requires_grad=True)
optim_params.append(log_sigma)
else:
log_sigma = torch.log(model.sigma_noise)
# import pdb
# pdb.set_trace()
hyper_optimizer = torch.optim.Adam(optim_params, lr=tune_precision_lr)
bar = trange(n_epochs_tune_precision)
for _ in bar:
hyper_optimizer.zero_grad()
neg_marglik = -model.log_marginal_likelihood(
log_prior.exp(), log_sigma.exp()
)
neg_marglik.backward()
hyper_optimizer.step()
bar.set_postfix(neg_marglik=f"{neg_marglik.detach().cpu().item()}")
[docs]
class LaplaceBase(BaseModule):
"""Laplace Approximation Method.
This is a lightning module wrapper for the `Laplace library <https://aleximmer.github.io/Laplace/>`_. # noqa: E501
If you use this model in your research, please cite the following papers:
* https://arxiv.org/abs/2106.14806
"""
pred_file_name = "preds.csv"
[docs]
def __init__(
self,
laplace_model: Laplace,
pred_type: str = "glm",
link_approx: str = "probit",
num_samples: int | None = None,
) -> None:
"""Initialize a new instance of Laplace Model Wrapper.
Args:
laplace_model: initialized Laplace model
pred_type: prediction type, one of ['glm', 'nn']
link_approx: link function approximation, one of ['mc', 'probit', 'bridge']
for `pred_type='nn'` only 'mc' is supported
num_samples: number of samples for prediction, if specified
will call `predictive_samples` instead of `predictive` method in
Laplace library
.. versionchanged:: 0.2
Add 'pred_type' and 'link_approx' arguments.
"""
super().__init__()
if pred_type == "nn":
assert link_approx == "mc", "For nn prediction only mc link is supported"
self.pred_type = pred_type
self.link_approx = link_approx
self.num_samples = num_samples
self.save_hyperparameters(ignore=["laplace_model"])
# reinitialize the model with the correct device because cannot set device
# to laplace model afterwards
LaplaceClass = type(laplace_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
init_args = laplace_model.__init__.__code__.co_varnames
model_copy = copy.deepcopy(laplace_model.model)
model = model_copy.to(device)
args_dict = {
arg: getattr(laplace_model, arg)
for arg in init_args
if hasattr(laplace_model, arg) and arg != "model"
}
args_dict["model"] = model
# Create a new instance of the LaplaceClass with the same arguments,
# but with the model on the CUDA device
self.laplace_model = LaplaceClass(**args_dict)
self.laplace_fitted = False
self.setup_task()
[docs]
def setup_task(self) -> None:
"""Set up task."""
pass
@property
def num_input_features(self) -> int:
"""Retrieve input dimension to the model.
Returns:
number of input dimension to the model
"""
return _get_num_inputs(self.model.model)
@property
def num_outputs(self) -> int:
"""Retrieve output dimension to the model.
Returns:
number of output dimension to model
"""
return _get_num_outputs(self.model.model)
[docs]
def on_test_start(self) -> None:
"""Fit the Laplace approximation before testing."""
self.train_loader = self.trainer.datamodule.train_dataloader()
def collate_fn_laplace_torch(batch):
"""Collate function to for laplace torch tuple convention.
Args:
batch: input batch
Returns:
renamed batch
"""
# Extract images and labels from the batch dictionary
if isinstance(batch[0], dict):
images = [item[self.input_key] for item in batch]
labels = [item[self.target_key] for item in batch]
else:
images = [item[0] for item in batch]
labels = [item[1] for item in batch]
# Stack images and labels into tensors
inputs = torch.stack(images)
targets = torch.stack(labels)
# apply datamodule augmentation
aug_batch = self.trainer.datamodule.on_after_batch_transfer(
{self.input_key: inputs, self.target_key: targets}, dataloader_idx=0
)
return (aug_batch[self.input_key], aug_batch[self.target_key])
self.train_loader.collate_fn = collate_fn_laplace_torch
if not self.laplace_fitted:
# take the deterministic model we trained and fit laplace
# laplace needs a nn.Module ant not a lightning module
# also lightning automatically disables gradient computation during test
# but need it for laplace so set inference mode to false with cntx manager
with torch.inference_mode(False):
# fit the laplace approximation
self.laplace_model.fit(self.train_loader)
self.laplace_fitted = True
# save this laplace fitted model as a checkpoint?!
[docs]
def test_step(
self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Test step."""
out_dict = self.predict_step(batch[self.input_key])
out_dict[self.target_key] = batch[self.target_key].detach().squeeze(-1)
self.log(
"test_loss",
self.loss_fn(out_dict["pred"], batch[self.target_key].squeeze(-1)),
batch_size=batch[self.input_key].shape[0],
) # logging to Logger
if batch[self.input_key].shape[0] > 1:
self.test_metrics(out_dict["pred"], batch[self.target_key].squeeze(-1))
out_dict["pred"] = out_dict["pred"].detach().squeeze(-1)
# save metadata
out_dict = self.add_aux_data_to_dict(out_dict, batch)
return out_dict
[docs]
def on_test_epoch_end(self):
"""Log epoch-level test metrics."""
self.log_dict(self.test_metrics.compute())
self.test_metrics.reset()
[docs]
class LaplaceRegression(LaplaceBase):
"""Laplace Approximation Wrapper for regression.
This is a lightning module wrapper for the
`Laplace library <https://aleximmer.github.io/Laplace/>`_.
If you use this model in your research, please cite the following paper:
* https://arxiv.org/abs/2106.14806
"""
[docs]
def __init__(
self,
laplace_model: Laplace,
pred_type: str = "glm",
link_approx: str = "probit",
num_samples: int | None = None,
tune_prior_precision: bool = True,
tune_sigma_noise: bool = False,
tuning_lr: float = 1e-3,
n_epochs_tuning: int = 100,
) -> None:
"""Initialize a new instance of Laplace Model Wrapper for regression.
Args:
laplace_model: initialized Laplace model
pred_type: prediction type, one of ['glm', 'nn']
link_approx: link function approximation, one of ['mc', 'probit', 'bridge']
for `pred_type='nn'` only 'mc' is supported
num_samples: number of samples for prediction, if specified
will call `predictive_samples` instead of `predictive` method in
Laplace library
tune_prior_precision: whether to tune prior precision
tune_sigma_noise: whether to tune sigma noise
tuning_lr: learning rate for tuning prior precision and sigma
n_epochs_tuning: number of epochs to tune prior precision and sigma
"""
super().__init__(laplace_model, pred_type, link_approx, num_samples)
assert self.laplace_model.likelihood == "regression"
self.loss_fn = torch.nn.MSELoss()
self.tuning_lr = tuning_lr
self.n_epochs_tuning = n_epochs_tuning
self.tune_prior_precision = tune_prior_precision
self.tune_sigma_noise = tune_sigma_noise
[docs]
def on_test_start(self) -> None:
"""Fit the Laplace approximation before testing."""
super().on_test_start()
if self.tune_prior_precision or self.tune_sigma_noise:
tune_prior_precision_and_sigma(
self.laplace_model,
self.tuning_lr,
self.n_epochs_tuning,
self.tune_prior_precision,
self.tune_sigma_noise,
)
[docs]
def setup_task(self) -> None:
"""Set up task specific attributes."""
self.test_metrics = default_regression_metrics("test")
[docs]
def forward(self, X: Tensor) -> dict[str, Tensor]:
"""Fitted Laplace Model Forward Pass.
Args:
X: tensor of data to run through the model [batch_size, input_dim]
Returns:
output from the laplace model
"""
if not self.laplace_fitted:
self.on_test_start()
pred_dict: dict[str, Tensor] = {}
if self.num_samples:
fsamples = self.laplace_model.predictive_samples(
X, pred_type=self.pred_type, n_samples=self.num_samples
)
mean = fsamples.mean(0).squeeze()
# return samples as shape [batch_size, out_dim, num_samples]
pred_dict["samples"] = fsamples.permute(1, 2, 0)
laplace_epistemic = fsamples.std(0).squeeze()
laplace_aleatoric = (
torch.ones_like(laplace_epistemic)
* self.laplace_model.sigma_noise.detach().item()
)
pred_std = torch.sqrt(laplace_epistemic + laplace_aleatoric**2)
else:
mean, var = self.laplace_model(
X, pred_type=self.pred_type, link_approx=self.link_approx
)
mean = mean.squeeze().detach()
laplace_epistemic = var.squeeze().sqrt()
laplace_aleatoric = (
torch.ones_like(laplace_epistemic)
* self.laplace_model.sigma_noise.detach().item()
)
pred_std = torch.sqrt(laplace_epistemic**2 + laplace_aleatoric**2)
pred_dict["epistemic_uct"] = laplace_epistemic
pred_dict["aleatoric_uct"] = laplace_aleatoric
pred_dict["pred"] = mean
pred_dict["pred_uct"] = pred_std
return pred_dict
[docs]
def predict_step(
self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
) -> dict[str, Tensor]:
"""Predict step with Laplace Approximation.
Args:
X: prediction batch of shape [batch_size x input_dims]
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
prediction dictionary
"""
if not self.laplace_fitted:
self.on_test_start()
# also lightning automatically disables gradient computation during test
# but need it for laplace so set inference mode to false with context manager
with torch.inference_mode(False):
# inference tensors are not saved for backward so need to create
# a clone with autograd enables
input = X.clone().requires_grad_().to(self.device)
return self.forward(input)
[docs]
def on_test_batch_end(
self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Test batch end save predictions.
Args:
outputs: dictionary of model outputs and aux variables
batch_idx: batch index
dataloader_idx: dataloader index
"""
save_regression_predictions(
outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
)
[docs]
class LaplaceClassification(LaplaceBase):
"""Laplace Approximation Wrapper for classification.
This is a lightning module wrapper for the
`Laplace library <https://aleximmer.github.io/Laplace/>`_.
If you use this model in your research, please cite the following paper:
* https://arxiv.org/abs/2106.14806
"""
valid_tasks = ["binary", "multiclass"]
[docs]
def __init__(
self,
laplace_model: Laplace,
task: str = "multiclass",
pred_type: str = "glm",
link_approx: str = "probit",
num_samples: int | None = None,
) -> None:
"""Initialize a new instance of Laplace Model Wrapper for Classification.
Args:
laplace_model: initialized Laplace model
task: classification task, one of ['binary', 'multiclass']
pred_type: prediction type, one of ['glm', 'nn']
link_approx: link function approximation, one of ['mc', 'probit', 'bridge']
for `pred_type='nn'` only 'mc' is supported
num_samples: number of samples for prediction, if specified
will call `predictive_samples` instead of `predictive` method in
Laplace library
"""
assert task in self.valid_tasks
self.task = task
super().__init__(laplace_model, pred_type, link_approx, num_samples)
self.loss_fn = torch.nn.CrossEntropyLoss()
assert self.laplace_model.likelihood == "classification"
[docs]
def forward(self, X: Tensor, **kwargs: Any) -> dict[str, Tensor]:
"""Fitted Laplace Model Forward Pass.
Args:
X: tensor of data to run through the model [batch_size, input_dim]
kwargs: additional arguments for laplace forward pass
Returns:
output from the laplace model
"""
if not self.laplace_fitted:
self.on_test_start()
pred_dict: dict[str, Tensor] = {}
if self.num_samples:
fsamples = self.laplace_model.predictive_samples(
X, pred_type=self.pred_type, n_samples=self.num_samples
)
mean = fsamples.mean(0)
pred_dict["samples"] = fsamples
else:
mean = self.laplace_model(
X, pred_type=self.pred_type, link_approx=self.link_approx
)
pred_dict["pred"] = mean
return pred_dict
[docs]
def setup_task(self) -> None:
"""Set up task specific attributes."""
self.test_metrics = default_classification_metrics(
"test", self.task, _get_num_outputs(self.laplace_model.model)
)
[docs]
def predict_step(
self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
) -> dict[str, Tensor]:
"""Predict step with Laplace Approximation.
Args:
X: prediction batch of shape [batch_size x input_dims]
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
prediction dictionary
"""
if not self.laplace_fitted:
self.on_test_start()
# also lightning automatically disables gradient computation during test
# but need it for laplace so set inference mode to false with context manager
with torch.inference_mode(False):
# inference tensors are not saved for backward so need to create
# a clone with autograd enables
input = X.clone().requires_grad_().to(self.device)
pred_dict = self.forward(input)
pred_dict["pred_uct"] = -torch.sum(
pred_dict["pred"] * torch.log(pred_dict["pred"]), dim=1
)
return pred_dict
[docs]
def on_test_batch_end(
self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Test batch end save predictions.
Args:
outputs: dictionary of model outputs and aux variables
batch_idx: batch index
dataloader_idx: dataloader index
"""
save_classification_predictions(
outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
)