Bayes By Backprop - Mean Field Variational Inference

[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

Theoretical Foundation#
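
Bayes by Backprop approximates the posterior over the network weights with a factorized (mean-field) Gaussian and fits its parameters by minimizing the negative evidence lower bound (ELBO). The following is a brief sketch of the standard formulation only. The symbols $\mu$, $\rho$, $S$ and the prior $p(w)$ are notation introduced here for illustration, and the library may scale or weight the KL term differently.

$$
q_\theta(w) = \prod_i \mathcal{N}\left(w_i \mid \mu_i, \sigma_i^2\right), \qquad \sigma_i = \log\left(1 + e^{\rho_i}\right), \qquad \theta = (\mu, \rho)
$$

$$
\mathcal{L}(\theta) = \mathrm{KL}\left(q_\theta(w) \,\|\, p(w)\right) - \mathbb{E}_{q_\theta(w)}\left[\log p(\mathcal{D} \mid w)\right]
\approx \mathrm{KL}\left(q_\theta(w) \,\|\, p(w)\right) - \frac{1}{S} \sum_{s=1}^{S} \log p\left(\mathcal{D} \mid w^{(s)}\right)
$$

$$
w^{(s)} = \mu + \sigma \odot \epsilon^{(s)}, \qquad \epsilon^{(s)} \sim \mathcal{N}(0, I)
$$

The last line is the reparameterization that lets gradients flow through the sampled weights to $\mu$ and $\rho$.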

Imports#

[2]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import BNN_VI_ELBO_Classification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2
[3]:
seed_everything(0)  # seed everything for reproducibility
Seed set to 0
[3]:
0

We create a temporary directory to store the training logs and results.

[4]:
my_temp_dir = tempfile.mkdtemp()

Datamodule#

To demonstrate the method, we use the two-moons toy classification dataset, which is defined as a Lightning DataModule. While this might seem like overkill for a small toy problem, we think it is more helpful to show how the individual pieces of the library fit together, so that you can train models on more complex tasks.

[5]:
dm = TwoMoonsDataModule(batch_size=128)

X_train, Y_train, X_test, Y_test, test_grid_points = (
    dm.X_train,
    dm.Y_train,
    dm.X_test,
    dm.Y_test,
    dm.test_grid_points,
)

X_train.min(), X_train.max()
[5]:
(tensor(-1.1298), tensor(2.1606))
[6]:
fig = plot_two_moons_data(X_train, Y_train, X_test, Y_test)
../../_images/tutorials_classification_bnn_vi_elbo_10_0.png
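
For readers who have not seen the two-moons dataset before, the sketch below generates a comparable point cloud in plain Python. It is an illustration only: `make_two_moons` and its parameters are hypothetical names, the construction follows the common interleaved half-circle recipe, and it is not necessarily how `TwoMoonsDataModule` builds its data.

```python
import math
import random


def make_two_moons(n_per_class=100, noise=0.1, seed=0):
    """Generate two interleaving half circles (hypothetical helper, not the library API)."""
    rng = random.Random(seed)
    points, labels = [], []
    for i in range(n_per_class):
        t = math.pi * i / (n_per_class - 1)  # angle from 0 to pi
        # Class 0: upper half circle centered at the origin.
        points.append(
            (math.cos(t) + rng.gauss(0.0, noise), math.sin(t) + rng.gauss(0.0, noise))
        )
        labels.append(0)
        # Class 1: lower half circle, shifted right and slightly up.
        points.append(
            (
                1.0 - math.cos(t) + rng.gauss(0.0, noise),
                0.5 - math.sin(t) + rng.gauss(0.0, noise),
            )
        )
        labels.append(1)
    return points, labels


points, labels = make_two_moons()
```

Without noise the x-coordinates span roughly -1 to 2, which is in line with the range of `X_train` printed above.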

Model#

For our toy classification problem, we use a simple multi-layer perceptron (MLP) that you can configure to your needs. Its options are described in the library's MLP documentation.

[7]:
network = MLP(n_inputs=2, n_hidden=[50, 50], n_outputs=2, activation_fn=nn.ReLU())
network
[7]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=2, out_features=50, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=50, out_features=2, bias=True)
  )
)

With an underlying neural network in place, we can now wrap it in the desired UQ method. All UQ methods are implemented as LightningModules, which lets us organize the code concisely and remove as much boilerplate as possible.

[8]:
bbp_model = BNN_VI_ELBO_Classification(
    network,
    optimizer=partial(torch.optim.Adam, lr=1e-2),
    criterion=nn.CrossEntropyLoss(),
    num_mc_samples_train=10,
    num_mc_samples_test=25,
)
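
To make the mean-field idea behind this wrapper more concrete, here is a minimal pure-Python sketch of the two ingredients for a single weight: drawing a sample with the reparameterization trick and computing the closed-form KL divergence to a zero-mean Gaussian prior. The function names, the softplus parameterization via `rho`, and the `prior_sigma` argument are assumptions made for illustration and are not the library's internals. Presumably `num_mc_samples_train` controls how many such weight samples are averaged per training step.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x)); keeps the standard deviation positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_weight(mu, rho, rng):
    """Reparameterized sample w = mu + sigma * eps with eps ~ N(0, 1)."""
    sigma = softplus(rho)
    return mu + sigma * rng.gauss(0.0, 1.0)


def kl_gaussian(mu, rho, prior_sigma=1.0):
    """KL( N(mu, sigma^2) || N(0, prior_sigma^2) ) for one weight."""
    sigma = softplus(rho)
    return (
        math.log(prior_sigma / sigma)
        + (sigma**2 + mu**2) / (2.0 * prior_sigma**2)
        - 0.5
    )


rng = random.Random(0)
samples = [sample_weight(mu=0.3, rho=-2.0, rng=rng) for _ in range(10)]
kl = kl_gaussian(mu=0.3, rho=-2.0)
```

In a mean-field approximation the total KL term is the sum of these per-weight values, which is what makes the objective cheap to evaluate.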

Trainer#

Now that we have a LightningDataModule and a UQ-Method as a LightningModule, we can conduct training with a Lightning Trainer. It has tons of options to make your life easier, so we encourage you to check the documentation.

[9]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    accelerator="cpu",
    max_epochs=250,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=20,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

Training our model is now easy:

[10]:
trainer.fit(bbp_model, dm)

  | Name          | Type             | Params | Mode  | FLOPs
-------------------------------------------------------------------
0 | model         | MLP              | 5.6 K  | train | 0
1 | loss_fn       | CrossEntropyLoss | 0      | train | 0
2 | train_metrics | MetricCollection | 0      | train | 0
3 | val_metrics   | MetricCollection | 0      | train | 0
4 | test_metrics  | MetricCollection | 0      | train | 0
-------------------------------------------------------------------
5.6 K     Trainable params
0         Non-trainable params
5.6 K     Total params
0.022     Total estimated model params size (MB)
21        Modules in train mode
0         Modules in eval mode
0         Total Flops
/home/docs/checkouts/readthedocs.org/user_builds/lightning-uq-box/envs/latest/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/docs/checkouts/readthedocs.org/user_builds/lightning-uq-box/envs/latest/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=20). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=250` reached.

Training Metrics#

To get some insight into how the training went, we can use a utility function to plot the training loss and the training accuracy.

[11]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainAcc"]
)
../../_images/tutorials_classification_bnn_vi_elbo_20_0.png
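
If you want the raw numbers rather than a plot, the metrics written by `CSVLogger` can be read with the standard library. The sketch below assumes the logger's default layout, where a `metrics.csv` file sits in a `version_*` folder under `lightning_logs`, and that cells are left empty in rows where a metric was not logged. `read_metric` is a hypothetical helper, and the in-memory CSV is made-up demo data, not output from the run above.

```python
import csv
import io


def read_metric(csv_file, name):
    """Return the logged values of `name`, skipping rows where the cell is empty."""
    values = []
    for row in csv.DictReader(csv_file):
        cell = row.get(name, "")
        if cell:  # metrics that were not logged in this row are blank
            values.append(float(cell))
    return values


# Made-up stand-in for a metrics.csv file, used only to show the call.
demo = io.StringIO(
    "epoch,step,train_loss,trainAcc\n"
    "0,4,0.69,0.52\n"
    "1,9,0.41,\n"
    "2,14,0.30,0.91\n"
)
losses = read_metric(demo, "train_loss")
demo.seek(0)  # rewind before reading a second column
accuracies = read_metric(demo, "trainAcc")
```

For a real run you would pass an open file handle instead, for example one opened on the `metrics.csv` inside `my_temp_dir`.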

Prediction#

[12]:
preds = bbp_model.predict_step(test_grid_points.to(bbp_model.device))
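
A Bayesian network produces a prediction by sampling several weight configurations and aggregating the resulting outputs. The sketch below shows one common way to do this for classification: average the softmax probabilities over the samples and use the entropy of the mean as the uncertainty score. It assumes, without confirming it from the library source, that `preds["pred"]` and `preds["pred_uct"]` are computed along these lines. The function names are hypothetical.

```python
import math


def softmax(logits):
    """Convert one vector of logits to class probabilities."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mc_predict(sample_logits):
    """Average class probabilities over weight samples and compute predictive entropy."""
    probs = [softmax(logits) for logits in sample_logits]
    n_samples = len(probs)
    n_classes = len(probs[0])
    mean = [sum(p[k] for p in probs) / n_samples for k in range(n_classes)]
    entropy = -sum(p * math.log(p) for p in mean if p > 0.0)
    return mean, entropy


# Samples that agree give low entropy; samples that disagree give high entropy.
confident_mean, confident_entropy = mc_predict([[5.0, -5.0]] * 3)
unsure_mean, unsure_entropy = mc_predict([[5.0, -5.0], [-5.0, 5.0]])
```

For two classes the entropy is bounded by `log(2)`, which is reached when the averaged probabilities are exactly 0.5 each.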

Evaluate Predictions#

[13]:
fig = plot_predictions_classification(
    X_test, Y_test, preds["pred"].argmax(-1), test_grid_points, preds["pred_uct"]
)
../../_images/tutorials_classification_bnn_vi_elbo_24_0.png