Variational Bayesian Last Layer (VBLL) Regression

This notebook illustrates the “Variational Bayesian Last Layer” (VBLL) method from Harrison et al. (2024), a sampling-free variational inference method that can be attached as the last layer of a neural network. For more details, we refer the reader to their paper or to their implementation, for which we provide a Lightning wrapper.

[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
%pip install vbll
[2]:
import os
import tempfile
from functools import partial

import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import VBLLRegression
from lightning_uq_box.viz_utils import (
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

%load_ext autoreload
%autoreload 2
[3]:
# temporary directory for saving
my_temp_dir = tempfile.mkdtemp()

seed_everything(42)
Seed set to 42
[3]:
42

Datamodule

[4]:
dm = ToyHeteroscedasticDatamodule(batch_size=64)

X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
    dm.X_train,
    dm.Y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.Y_test,
    dm.test_dataloader(),
    dm.X_gtext,
    dm.Y_gtext,
)
[5]:
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)
../../_images/tutorials_regression_vbll_7_0.png

Model

We define a backbone model to which the VBLL layer will be attached. This can be any neural network architecture, including pretrained ones. For this regression task, we use a simple MLP.

[6]:
network = MLP(n_inputs=1, n_hidden=[64, 64], n_outputs=64, activation_fn=nn.Tanh())
network
[6]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=1, out_features=64, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): Tanh()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=64, out_features=64, bias=True)
  )
)

Under the hood, the VBLL layer is attached to the network to form a single model that can be trained. If the backbone is pretrained, the freeze_backbone argument can be set to True so that only the last variational layer is trained.
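
To make “sampling-free” concrete, the following is a minimal sketch of the closed-form Gaussian predictive of a Bayesian linear last layer with a scalar output, assuming a Gaussian posterior over the last-layer weights and a fixed noise variance. It is plain Python for illustration only and is not the vbll or lightning-uq-box implementation; the function name and the toy numbers are hypothetical.

```python
import math


def last_layer_predictive(phi, w_mean, w_cov, noise_var):
    """Closed-form predictive of a Bayesian linear last layer (scalar output).

    phi:       backbone features for one input (list of floats)
    w_mean:    posterior mean of the last-layer weights (list of floats)
    w_cov:     posterior covariance of the weights (list of lists)
    noise_var: observation noise variance (aleatoric part)
    """
    # Predictive mean: w_mean^T phi
    mean = sum(w * f for w, f in zip(w_mean, phi))
    # Epistemic variance: phi^T S phi, computed without sampling any weights
    n = len(phi)
    epistemic_var = sum(
        phi[i] * w_cov[i][j] * phi[j] for i in range(n) for j in range(n)
    )
    # Total predictive standard deviation combines both variance sources
    total_std = math.sqrt(epistemic_var + noise_var)
    return mean, epistemic_var, total_std


# Hypothetical toy values, chosen only to illustrate the computation
phi = [1.0, 2.0]
w_mean = [0.5, -0.25]
w_cov = [[0.04, 0.0], [0.0, 0.01]]
mean, epistemic_var, total_std = last_layer_predictive(
    phi, w_mean, w_cov, noise_var=0.01
)
print(mean, epistemic_var, total_std)
```

The predictive variance splits into an epistemic part (from weight uncertainty) and the noise variance, and no Monte Carlo samples of the weights are needed.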

One should pay attention to the following hyperparameters:

  • parameterization: how the last layer covariance matrix is parameterized; possible options are diagonal or dense

  • regularization_weight: specifies how strongly the KL term is weighted, which affects the epistemic uncertainty in the last layer. By default it should be 1 / (number of training examples); however, it can also be treated as a hyperparameter to tune the epistemic uncertainty, where a larger regularization weight leads to a larger epistemic uncertainty estimate

  • prior_scale: specifies the scale of the prior in the last layer

  • wishart_scale: specifies the regularization weight on the noise covariance in the last layer and influences the aleatoric uncertainty
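
The default regularization_weight and a scaled variant can be computed as in the short sketch below. The helper name and the training-set size of 200 are hypothetical and chosen only for illustration; in this notebook the size is taken from X_train.shape[0].

```python
def kl_regularization_weight(n_train, scale=1.0):
    """Return scale / n_train, the weight applied to the KL term.

    scale=1.0 gives the default 1 / (number of training examples);
    scale > 1 weights the KL term more strongly.
    """
    if n_train <= 0:
        raise ValueError("n_train must be a positive integer")
    return scale / n_train


n_train = 200  # hypothetical number of training examples
default_weight = kl_regularization_weight(n_train)
scaled_weight = kl_regularization_weight(n_train, scale=50.0)
print(default_weight, scaled_weight)
```

Treating the scale factor as a separate knob keeps the dependence on the dataset size explicit when tuning the epistemic uncertainty.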

VBLL can also be applied to pretrained networks. For this purpose, set replace_ll=True and freeze_backbone=True to replace the last layer of the architecture with a VBLL layer and train only that layer.

[7]:
vbll_model = VBLLRegression(
    model=network,
    num_targets=1,
    regularization_weight=(1 / X_train.shape[0]) * 50,
    optimizer=partial(torch.optim.Adam, lr=3e-3),
    parameterization="dense",
    prior_scale=1.0,
    wishart_scale=0.1,
)

Trainer

[8]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    accelerator="cpu",
    max_epochs=500,  # number of epochs we want to train
    logger=logger,
    log_every_n_steps=3,
    enable_checkpointing=False,
    enable_progress_bar=False,
    limit_val_batches=0.0,  # no validation runs
    default_root_dir=my_temp_dir,
    gradient_clip_val=1.0,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[9]:
trainer.fit(vbll_model, dm)

  | Name          | Type             | Params | Mode  | FLOPs
-------------------------------------------------------------------
0 | model         | MLP              | 8.5 K  | train | 0
1 | train_metrics | MetricCollection | 0      | train | 0
2 | val_metrics   | MetricCollection | 0      | train | 0
3 | test_metrics  | MetricCollection | 0      | train | 0
-------------------------------------------------------------------
8.5 K     Trainable params
1         Non-trainable params
8.5 K     Total params
0.034     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode
0         Total Flops
/home/docs/checkouts/readthedocs.org/user_builds/lightning-uq-box/envs/latest/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
`Trainer.fit` stopped: `max_epochs=500` reached.
[10]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)
../../_images/tutorials_regression_vbll_16_0.png

Evaluate Predictions

The constructed datamodule contains two possible test variables. X_test contains IID samples from the same noisy distribution as the training data, while X_gtext (“X ground truth extended”) contains dense, noise-free inputs from the underlying “ground truth” function over an input range that is extended on both sides, so we can visualize the method’s UQ tendencies when extrapolating beyond the training data range. Thus, we use X_gtext for visualization purposes, but use X_test to compute uncertainty and calibration metrics, because we want to analyse how well the method has learned the noisy data distribution.
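
As a rough illustration of the kind of uncertainty and calibration metrics meant here, the sketch below computes the average Gaussian negative log-likelihood and the empirical coverage of a central 95% interval from predicted means and standard deviations. The function names and the toy values are hypothetical, and this is not the metrics API of lightning-uq-box; with real predictions, one would pass flattened targets, predictive means, and predictive standard deviations for X_test.

```python
import math


def gaussian_nll(y_true, mean, std):
    """Average negative log-likelihood under N(mean, std**2)."""
    terms = [
        0.5 * math.log(2.0 * math.pi * s**2) + (y - m) ** 2 / (2.0 * s**2)
        for y, m, s in zip(y_true, mean, std)
    ]
    return sum(terms) / len(terms)


def interval_coverage(y_true, mean, std, z=1.96):
    """Fraction of targets inside mean +/- z * std (z=1.96 is roughly 95%)."""
    inside = sum(abs(y - m) <= z * s for y, m, s in zip(y_true, mean, std))
    return inside / len(y_true)


# Hypothetical toy values: the last target is far outside its interval
y_true = [0.0, 1.0, 2.0, 10.0]
pred_mean = [0.1, 0.9, 2.2, 2.0]
pred_std = [0.5, 0.5, 0.5, 0.5]

nll = gaussian_nll(y_true, pred_mean, pred_std)
coverage = interval_coverage(y_true, pred_mean, pred_std)
print(nll, coverage)
```

A coverage well below the nominal level suggests overconfident predictions, while a coverage well above it suggests uncertainty estimates that are too wide.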

[11]:
preds = vbll_model.predict_step(X_gtext.to(vbll_model.device))

fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    preds["pred"],
    preds["pred_uct"].squeeze(-1),
    epistemic=preds["pred_uct"].squeeze(-1),
    show_bands=False,
    title="VBLL",
)
../../_images/tutorials_regression_vbll_18_0.png