# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
"""Bayesian Neural Networks with Variational Inference."""
import os
from typing import Any
import torch
import torch.nn as nn
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import Tensor
from lightning_uq_box.models.bnn_layers.bnn_utils import (
convert_deterministic_to_bnn,
get_kl_loss,
)
from .base import DeterministicModel
from .utils import (
_get_num_outputs,
default_classification_metrics,
default_regression_metrics,
default_segmentation_metrics,
freeze_segmentation_model,
map_stochastic_modules,
process_classification_prediction,
process_regression_prediction,
process_segmentation_prediction,
save_classification_predictions,
save_image_predictions,
save_regression_predictions,
)
[docs]
class BNN_VI_ELBO_Base(DeterministicModel):
    """Bayes By Backprop Base with Variational Inference (VI).

    If you use this model in your work, please cite:

    * https://arxiv.org/abs/1505.05424
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        beta: float = 100,
        num_mc_samples_train: int = 10,
        num_mc_samples_test: int = 50,
        output_noise_scale: float = 1.3,
        prior_mu: float = 0.0,
        prior_sigma: float = 1.0,
        posterior_mu_init: float = 0.0,
        posterior_rho_init: float = -5.0,
        bayesian_layer_type: str = "reparameterization",
        stochastic_module_names: list[int | str] | None = None,
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new Model instance.

        Args:
            model: pytorch model that will be converted into a BNN
            criterion: loss function used for optimization
            beta: beta factor for negative elbo loss computation,
                should be number of weights and biases
            num_mc_samples_train: number of MC samples during training when computing
                the negative ELBO loss. When setting num_mc_samples_train=1, this
                is just Bayes by Backprop.
            num_mc_samples_test: number of MC samples during test and prediction
            output_noise_scale: scale of predicted sigmas
            prior_mu: prior mean value for bayesian layer
            prior_sigma: prior sigma value for bayesian layer
            posterior_mu_init: mean initialization value for approximate posterior
            posterior_rho_init: rho initialization value for approximate posterior,
                mapped through softplus σ = log(1 + exp(ρ))
            bayesian_layer_type: `flipout` or `reparameterization`
            stochastic_module_names: list of module names or indices that should
                be converted to variational layers
            freeze_backbone: whether to freeze the backbone
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler

        Raises:
            AssertionError: if ``num_mc_samples_train`` is not positive.
            AssertionError: if ``num_mc_samples_test`` is not positive.
        """
        # Validate with explicit raises (not ``assert``) so the checks survive
        # ``python -O``, and do it before the model is converted in place.
        if not num_mc_samples_train > 0:
            raise AssertionError("Need to sample at least once during training.")
        if not num_mc_samples_test > 0:
            raise AssertionError("Need to sample at least once during testing.")

        # arguments forwarded to the deterministic -> Bayesian layer conversion
        self.bnn_args = {
            "prior_mu": prior_mu,
            "prior_sigma": prior_sigma,
            "posterior_mu_init": posterior_mu_init,
            "posterior_rho_init": posterior_rho_init,
            "layer_type": bayesian_layer_type,
        }
        self.stochastic_module_names = map_stochastic_modules(
            model, stochastic_module_names
        )

        # the model is converted before it is handed to the parent constructor
        self._setup_bnn_with_vi(model)

        super().__init__(model, criterion, freeze_backbone, optimizer, lr_scheduler)

        self.save_hyperparameters(
            ignore=["model", "criterion", "optimizer", "lr_scheduler"]
        )

        # update hyperparameters
        self.hparams["weight_decay"] = 1e-5

        # hyperparameter depending on network size
        self.beta = beta

        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.freeze_backbone = freeze_backbone

    def setup_task(self) -> None:
        """Set up task."""
        pass

    def _setup_bnn_with_vi(self, model: nn.Module) -> None:
        """Convert the selected modules of ``model`` to Bayesian layers.

        Args:
            model: model passed to ``convert_deterministic_to_bnn`` together with
                ``self.bnn_args`` and ``self.stochastic_module_names``
        """
        convert_deterministic_to_bnn(
            model, self.bnn_args, stochastic_module_names=self.stochastic_module_names
        )

    def forward(self, X: Tensor) -> Tensor:
        """Forward pass BNN+VI.

        Args:
            X: input data

        Returns:
            bnn output
        """
        return self.model(X)

    def on_fit_start(self) -> None:
        """Before fitting compute number of training points.

        NOTE(review): reads the dataset from ``self.trainer.datamodule``, so this
        requires fitting with a datamodule - confirm for dataloader-only usage.
        """
        self.num_training_points = len(
            self.trainer.datamodule.train_dataloader().dataset
        )

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        X, y = batch[self.input_key], batch[self.target_key]

        elbo_loss, mean_output = self.compute_elbo_loss(X, y)

        self.log("train_loss", elbo_loss, batch_size=X.shape[0])  # logging to Logger
        # metrics are only updated for batches with more than one sample
        if X.shape[0] > 1:
            self.train_metrics(mean_output, y)

        return elbo_loss

    def validation_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute validation loss and update the validation metrics.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            validation loss
        """
        X, y = batch[self.input_key], batch[self.target_key]

        elbo_loss, mean_output = self.compute_elbo_loss(X, y)

        self.log("val_loss", elbo_loss, batch_size=X.shape[0])  # logging to Logger
        # metrics are only updated for batches with more than one sample
        if X.shape[0] > 1:
            self.val_metrics(mean_output, y)

        return elbo_loss

    def compute_elbo_loss(self, X: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
        """Compute the ELBO loss with mse/nll.

        Runs ``num_mc_samples_train`` stochastic forward passes, averages the
        per-sample task losses and combines them with the KL term of the model.

        Args:
            X: input data
            y: target

        Returns:
            negative elbo loss and mean model output (averaged over the MC
            samples, detached) for logging
        """
        sample_losses: list[Tensor] = []
        sample_preds: list[Tensor] = []

        for _ in range(self.hparams.num_mc_samples_train):
            pred = self.forward(X)
            sample_losses.append(self.compute_task_loss(pred, y))
            # predictions are only used for metrics, so no gradient is needed
            sample_preds.append(self.adapt_output_for_metrics(pred).detach())

        mean_pred = torch.stack(sample_preds, dim=-1).mean(-1)

        # Stack instead of writing into a pre-allocated ``torch.zeros`` buffer:
        # the losses stay on their own device and keep their dtype.
        mean_pred_nll_loss = torch.stack(sample_losses).mean()

        mean_kl = get_kl_loss(self.model)

        # data term scaled by the training set size, KL term weighted by beta
        negative_beta_elbo = (
            self.num_training_points * mean_pred_nll_loss + self.beta * mean_kl
        )
        return negative_beta_elbo, mean_pred

    def compute_task_loss(self, X: Tensor, y: Tensor) -> Tensor:
        """Compute the loss for the respective task for a single sampling iteration.

        Args:
            X: model prediction of a single sampling iteration
            y: target

        Returns:
            nll loss for the task

        Raises:
            NotImplementedError: always, subclasses have to implement this.
        """
        raise NotImplementedError

    def exclude_from_wt_decay(
        self,
        named_params,
        weight_decay: float,
        skip_list: list[str] | tuple[str, ...] = ("mu", "rho"),
    ) -> list[dict[str, Any]]:
        """Split parameters into groups with and without weight decay.

        Parameters whose name contains any entry of ``skip_list`` (by default
        ``mu`` and ``rho``) get a weight decay of ``0.0``; parameters that do
        not require gradients are dropped entirely.

        Args:
            named_params: named parameters of the model
            weight_decay: weight decay factor
            skip_list: strings that if found in parameter name
                exclude the parameter from weight decay

        Returns:
            split parameter groups for optimization with and without
            weight_decay
        """
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            elif any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)
        return [
            {"params": params, "weight_decay": weight_decay},
            {"params": excluded_params, "weight_decay": 0.0},
        ]
[docs]
class BNN_VI_ELBO_Regression(BNN_VI_ELBO_Base):
    """Bayes By Backprop Model with Variational Inference (VI) for Regression.

    If you use this model in your work, please cite:

    * https://arxiv.org/abs/1505.05424
    """

    # file name (inside the trainer's default root dir) for saved test predictions
    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        burnin_epochs: int,
        beta: float = 100,
        num_mc_samples_train: int = 10,
        num_mc_samples_test: int = 50,
        output_noise_scale: float = 1.3,
        prior_mu: float = 0,
        prior_sigma: float = 1,
        posterior_mu_init: float = 0,
        posterior_rho_init: float = -5,
        bayesian_layer_type: str = "reparameterization",
        stochastic_module_names: list[int] | list[str] | None = None,
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new Model instance.

        Args:
            model: pytorch model that will be converted into a BNN
            criterion: loss function used for optimization
            burnin_epochs: number of epochs to train before switching to nll loss
            beta: beta factor for negative elbo loss computation,
                should be number of weights and biases
            num_mc_samples_train: number of MC samples during training when computing
                the negative ELBO loss. When setting num_mc_samples_train=1, this
                is just Bayes by Backprop.
            num_mc_samples_test: number of MC samples during test and prediction
            output_noise_scale: scale of predicted sigmas
            prior_mu: prior mean value for bayesian layer
            prior_sigma: prior sigma value for bayesian layer
            posterior_mu_init: mean initialization value for approximate posterior
            posterior_rho_init: rho initialization value for approximate posterior,
                mapped through softplus σ = log(1 + exp(ρ))
            bayesian_layer_type: `flipout` or `reparameterization`
            stochastic_module_names: list of module names or indices that should
                be converted to variational layers
            freeze_backbone: whether to freeze the backbone
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler

        Raises:
            AssertionError: if ``num_mc_samples_train`` is not positive.
            AssertionError: if ``num_mc_samples_test`` is not positive.
        """
        super().__init__(
            model,
            criterion,
            beta,
            num_mc_samples_train,
            num_mc_samples_test,
            output_noise_scale,
            prior_mu,
            prior_sigma,
            posterior_mu_init,
            posterior_rho_init,
            bayesian_layer_type,
            stochastic_module_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        # saved again at this level; ``burnin_epochs`` is read from
        # ``self.hparams`` in ``compute_task_loss``
        self.save_hyperparameters(
            ignore=["model", "criterion", "optimizer", "lr_scheduler"]
        )

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_regression_metrics("train")
        self.val_metrics = default_regression_metrics("val")
        self.test_metrics = default_regression_metrics("test")

    def compute_task_loss(self, pred: Tensor, y: Tensor) -> Tensor:
        """Compute the loss for the respective task for a single sampling iteration.

        Args:
            pred: model_prediction
            y: target

        Returns:
            mse loss during burn-in (or when the criterion is ``nn.MSELoss``),
            otherwise the loss given by ``self.criterion``
        """
        use_mse = self.current_epoch < self.hparams.burnin_epochs or isinstance(
            self.criterion, nn.MSELoss
        )
        if use_mse:
            # mse on the metric-adapted output instead of the raw prediction
            return torch.nn.functional.mse_loss(self.adapt_output_for_metrics(pred), y)

        # after burnin compute the criterion on the full prediction
        return self.criterion(pred, y)

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Prediction step.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            result of ``process_regression_prediction`` applied to the
            ``num_mc_samples_test`` stacked stochastic forward passes
        """
        with torch.no_grad():
            preds = torch.stack(
                [self.model(X) for _ in range(self.hparams.num_mc_samples_test)], dim=-1
            )  # shape [batch_size, num_outputs, num_samples]

        return process_regression_prediction(preds)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any = None,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        The signature follows the Lightning hook, which passes the batch
        positionally after ``outputs``. ``batch``, ``batch_idx`` and
        ``dataloader_idx`` are not used, so the defaults keep shorter legacy
        calls such as ``on_test_batch_end(outputs, batch_idx)`` working.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_regression_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
[docs]
class BNN_VI_ELBO_Classification(BNN_VI_ELBO_Base):
    """Bayes By Backprop Model with Variational Inference (VI) for Classification.

    If you use this model in your work, please cite:

    * https://arxiv.org/abs/1505.05424
    """

    # file name (inside the trainer's default root dir) for saved test predictions
    pred_file_name = "preds.csv"

    # accepted values for the ``task`` argument
    valid_tasks = ["binary", "multiclass", "multilabel"]

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        task: str = "multiclass",
        beta: float = 100,
        num_mc_samples_train: int = 10,
        num_mc_samples_test: int = 50,
        output_noise_scale: float = 1.3,
        prior_mu: float = 0,
        prior_sigma: float = 1,
        posterior_mu_init: float = 0,
        posterior_rho_init: float = -5,
        bayesian_layer_type: str = "reparameterization",
        stochastic_module_names: list[int] | list[str] | None = None,
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new Model instance.

        Args:
            model: pytorch model that will be converted into a BNN
            criterion: loss function used for optimization
            task: classification task, one of `binary`, `multiclass`, `multilabel`
            beta: beta factor for negative elbo loss computation,
                should be number of weights and biases
            num_mc_samples_train: number of MC samples during training when computing
                the negative ELBO loss. When setting num_mc_samples_train=1, this
                is just Bayes by Backprop.
            num_mc_samples_test: number of MC samples during test and prediction
            output_noise_scale: scale of predicted sigmas
            prior_mu: prior mean value for bayesian layer
            prior_sigma: prior sigma value for bayesian layer
            posterior_mu_init: mean initialization value for approximate posterior
            posterior_rho_init: rho initialization value for approximate posterior,
                mapped through softplus σ = log(1 + exp(ρ))
            bayesian_layer_type: `flipout` or `reparameterization`
            stochastic_module_names: list of module names or indices that should
                be converted to variational layers
            freeze_backbone: whether to freeze the backbone
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler

        Raises:
            AssertionError: if ``task`` is not one of ``valid_tasks``.
            AssertionError: if ``num_mc_samples_train`` is not positive.
            AssertionError: if ``num_mc_samples_test`` is not positive.
        """
        # explicit raise (same exception type as before) so the check is not
        # stripped when running with ``python -O``
        if task not in self.valid_tasks:
            raise AssertionError(
                f"task must be one of {self.valid_tasks}, got {task!r}."
            )

        # ``setup_task`` reads these two attributes; they are set before the
        # parent constructor runs - presumably because setup happens in there
        self.task = task
        self.num_classes = _get_num_outputs(model)

        super().__init__(
            model,
            criterion,
            beta,
            num_mc_samples_train,
            num_mc_samples_test,
            output_noise_scale,
            prior_mu,
            prior_sigma,
            posterior_mu_init,
            posterior_rho_init,
            bayesian_layer_type,
            stochastic_module_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        self.save_hyperparameters(
            ignore=["model", "criterion", "optimizer", "lr_scheduler"]
        )

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_classification_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_classification_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_classification_metrics(
            "test", self.task, self.num_classes
        )

    def compute_task_loss(self, pred: Tensor, y: Tensor) -> Tensor:
        """Compute the loss for the respective task for a single sampling iteration.

        Args:
            pred: model_prediction
            y: target

        Returns:
            nll loss for the task
        """
        return self.criterion(pred, y)

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Adapt model output to be compatible for metric computation.

        Args:
            out: model output

        Returns:
            the unchanged model output
        """
        return out

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Prediction step.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            result of ``process_classification_prediction`` applied to the
            ``num_mc_samples_test`` stacked stochastic forward passes
        """
        with torch.no_grad():
            preds = torch.stack(
                [self.model(X) for _ in range(self.hparams.num_mc_samples_test)], dim=-1
            )  # shape [batch_size, num_classes, num_samples]

        return process_classification_prediction(preds)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any = None,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        The signature follows the Lightning hook, which passes the batch
        positionally after ``outputs``. ``batch``, ``batch_idx`` and
        ``dataloader_idx`` are not used, so the defaults keep shorter legacy
        calls such as ``on_test_batch_end(outputs, batch_idx)`` working.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
[docs]
class BNN_VI_ELBO_Segmentation(BNN_VI_ELBO_Classification):
    """Bayes By Backprop Model with Variational Inference (VI) for Segmentation.

    If you use this model in your work, please cite:

    * https://arxiv.org/abs/1505.05424
    """

    # directory name (inside the trainer's default root dir) for saved predictions
    pred_dir_name = "preds"

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        task: str = "multiclass",
        beta: float = 100,
        num_mc_samples_train: int = 10,
        num_mc_samples_test: int = 50,
        output_noise_scale: float = 1.3,
        prior_mu: float = 0,
        prior_sigma: float = 1,
        posterior_mu_init: float = 0,
        posterior_rho_init: float = -5,
        bayesian_layer_type: str = "reparameterization",
        stochastic_module_names: list[int] | list[str] | None = None,
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
        save_preds: bool = False,
    ) -> None:
        """Initialize a new BNN VI ELBO Segmentation instance.

        Args:
            model: pytorch model that will be converted into a BNN
            criterion: loss function used for optimization
            task: classification task, one of `binary`, `multiclass`, `multilabel`
            beta: beta factor for negative elbo loss computation,
                should be number of weights and biases
            num_mc_samples_train: number of MC samples during training when computing
                the negative ELBO loss. When setting num_mc_samples_train=1, this
                is just Bayes by Backprop.
            num_mc_samples_test: number of MC samples during test and prediction
            output_noise_scale: scale of predicted sigmas
            prior_mu: prior mean value for bayesian layer
            prior_sigma: prior sigma value for bayesian layer
            posterior_mu_init: mean initialization value for approximate posterior
            posterior_rho_init: rho initialization value for approximate posterior,
                mapped through softplus σ = log(1 + exp(ρ))
            bayesian_layer_type: `flipout` or `reparameterization`
            stochastic_module_names: list of module names or indices that should
                be converted to variational layers
            freeze_backbone: whether to freeze the model backbone, by default this is
                supported for torchseg Unet models
            freeze_decoder: whether to freeze the model decoder, by default this is
                supported for torchseg Unet models
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
            save_preds: whether to save predictions
        """
        # ``freeze_model`` reads both flags; they are set before the parent
        # constructor runs - presumably because freezing is triggered in there
        self.freeze_backbone = freeze_backbone
        self.freeze_decoder = freeze_decoder

        super().__init__(
            model,
            criterion,
            task,
            beta,
            num_mc_samples_train,
            num_mc_samples_test,
            output_noise_scale,
            prior_mu,
            prior_sigma,
            posterior_mu_init,
            posterior_rho_init,
            bayesian_layer_type,
            stochastic_module_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        self.save_preds = save_preds

    def freeze_model(self) -> None:
        """Freeze the segmentation model according to the configured flags.

        Delegates to ``freeze_segmentation_model`` with the ``freeze_backbone``
        and ``freeze_decoder`` settings of this instance.
        """
        freeze_segmentation_model(self.model, self.freeze_backbone, self.freeze_decoder)

    def setup_task(self) -> None:
        """Set up task specific attributes for segmentation."""
        self.train_metrics = default_segmentation_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_segmentation_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_segmentation_metrics(
            "test", self.task, self.num_classes
        )

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Prediction step for segmentation.

        Args:
            X: prediction batch of shape [batch_size x num_channels x height x width]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            result of ``process_segmentation_prediction`` applied to the
            ``num_mc_samples_test`` stacked stochastic forward passes
        """
        with torch.no_grad():
            preds = torch.stack(
                [self.model(X) for _ in range(self.hparams.num_mc_samples_test)], dim=-1
            )  # shape [batch_size, num_classes, height, width, num_samples]

        return process_segmentation_prediction(preds)

    def on_test_start(self) -> None:
        """Set the prediction directory and create it if predictions are saved."""
        self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name)
        if self.save_preds:
            # exist_ok avoids the check-then-create race of a separate
            # ``os.path.exists`` test
            os.makedirs(self.pred_dir, exist_ok=True)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Predictions are only written when ``save_preds`` is enabled.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        if self.save_preds:
            save_image_predictions(outputs, batch_idx, self.pred_dir)