Stochastic Weight Averaging - Gaussian (SWAG)#

[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

Theoretical Foundation#

SWAG was introduced in "A Simple Baseline for Bayesian Uncertainty in Deep Learning" (Maddox et al., 2019).
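
The core idea of the paper is to fit a Gaussian over the network weights using the iterates of SGD: the mean is the running average of the weights, and the covariance is estimated from their spread. The following is a minimal sketch of the diagonal part of that approximation only (the paper also adds a low-rank covariance term, which is omitted here). It is not the implementation used by lightning-uq-box; the function names and the three-weight snapshots are hypothetical and chosen purely for illustration.

```python
import math
import random


def swag_diag_moments(snapshots):
    """Return the per-weight mean and variance over a list of weight snapshots."""
    n = len(snapshots)
    dim = len(snapshots[0])
    mean = [sum(s[i] for s in snapshots) / n for i in range(dim)]
    sq_mean = [sum(s[i] ** 2 for s in snapshots) / n for i in range(dim)]
    # Clamp at zero to guard against tiny negative values caused by rounding.
    var = [max(sq_mean[i] - mean[i] ** 2, 0.0) for i in range(dim)]
    return mean, var


def sample_weights(mean, var, rng):
    """Draw one weight vector from N(mean, diag(var))."""
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mean, var)]


# Hypothetical snapshots of a 3-weight model collected along the SGD trajectory.
snapshots = [[1.0, 2.0, 0.5], [1.2, 1.8, 0.5], [0.8, 2.2, 0.5]]
mean, var = swag_diag_moments(snapshots)

rng = random.Random(0)
sampled = sample_weights(mean, var, rng)
```

A weight that never moves across snapshots (the third one here) gets zero variance and is therefore always sampled at its mean, while weights that fluctuate during training are sampled with matching spread.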

Imports#

[2]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DeterministicClassification, SWAGClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)

plt.rcParams["figure.figsize"] = [14, 5]
[3]:
seed_everything(0)  # seed everything for reproducibility
Seed set to 0
[3]:
0

We create a temporary directory to store the training metrics and results so that we can inspect them later.

[4]:
my_temp_dir = tempfile.mkdtemp()

Datamodule#

To demonstrate the method, we will use a toy classification example, the two moons dataset, which is defined as a Lightning DataModule. While this might seem like overkill for a small toy problem, we think it is more helpful to show how the individual pieces of the library fit together so that you can train models on more complex tasks.

[5]:
dm = TwoMoonsDataModule()

X_train, Y_train, X_test, Y_test, test_grid_points = (
    dm.X_train,
    dm.Y_train,
    dm.X_test,
    dm.Y_test,
    dm.test_grid_points,
)
[6]:
fig = plot_two_moons_data(X_train, Y_train, X_test, Y_test)
../../_images/tutorials_classification_swag_10_0.png
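
For readers unfamiliar with the two moons problem, the sketch below generates a comparable dataset from scratch: two interleaving half circles, one per class, with Gaussian noise added to each coordinate. This is an illustration of the geometry only and makes no claim about how TwoMoonsDataModule builds its data; the function name and its parameters are hypothetical.

```python
import math
import random


def make_two_moons(n_per_class, noise, rng):
    """Generate two interleaving half circles with Gaussian coordinate noise."""
    points, labels = [], []
    for i in range(n_per_class):
        t = math.pi * i / max(n_per_class - 1, 1)
        # Upper moon (class 0): half circle of radius 1 centred at the origin.
        points.append(
            (math.cos(t) + rng.gauss(0.0, noise), math.sin(t) + rng.gauss(0.0, noise))
        )
        labels.append(0)
        # Lower moon (class 1): mirrored half circle, shifted right and down.
        points.append(
            (
                1.0 - math.cos(t) + rng.gauss(0.0, noise),
                0.5 - math.sin(t) + rng.gauss(0.0, noise),
            )
        )
        labels.append(1)
    return points, labels


points, labels = make_two_moons(n_per_class=100, noise=0.1, rng=random.Random(0))
clean_points, clean_labels = make_two_moons(n_per_class=5, noise=0.0, rng=random.Random(0))
```

Because the two arcs interlock, no straight line separates the classes, which is why the tutorial uses a small non-linear network.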

Model#

For our toy classification problem, we will use a simple multi-layer perceptron (MLP) that you can configure to your needs; see the MLP documentation for details. We define a model with two outputs, one logit for each of the two classes.

[7]:
network = MLP(n_inputs=2, n_hidden=[50, 50, 50], n_outputs=2, activation_fn=nn.ReLU())
network
[7]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=2, out_features=50, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=50, out_features=50, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.0, inplace=False)
    (9): Linear(in_features=50, out_features=2, bias=True)
  )
)
[8]:
deterministic_model = DeterministicClassification(
    network, loss_fn=nn.CrossEntropyLoss(), optimizer=partial(torch.optim.Adam, lr=1e-3)
)

Trainer#

Now that we have a LightningDataModule and a UQ-Method as a LightningModule, we can conduct training with a Lightning Trainer. It has tons of options to make your life easier, so we encourage you to check the documentation.

[9]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=50,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Training our model is now easy:

[10]:
trainer.fit(deterministic_model, dm)

  | Name          | Type             | Params | Mode
-----------------------------------------------------------
0 | model         | MLP              | 5.4 K  | train
1 | loss_fn       | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
5.4 K     Trainable params
0         Non-trainable params
5.4 K     Total params
0.021     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=50` reached.
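
The parameter count of 5.4 K reported in the summary above can be checked by hand from the layer sizes of the MLP: each linear layer contributes one weight per input-output pair plus one bias per output. The size estimate below assumes 32-bit floats (4 bytes per parameter), which is an assumption on our part rather than something stated in the output.

```python
# Layer widths of the MLP: 2 inputs, three hidden layers of 50, 2 outputs.
layer_sizes = [2, 50, 50, 50, 2]

# Weights (n_in * n_out) plus biases (n_out) for each linear layer.
n_params = sum(
    n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
)

# Estimated size in MB, assuming 4 bytes per parameter (float32).
size_mb = n_params * 4 / 1e6

print(n_params)           # 5352
print(round(size_mb, 3))  # 0.021
```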

Training Metrics#

To get some insight into how the training of our underlying deterministic model went, we can use a utility function to plot the training loss and the training accuracy.

[11]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainAcc"]
)
../../_images/tutorials_classification_swag_19_0.png
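
If you prefer to inspect the logged values directly, the metrics can also be read with the standard csv module. The sketch below assumes the usual CSVLogger layout, where metrics are written to a metrics.csv file with a step column and one column per logged metric, and where a cell is left empty when that metric was not logged at that step. To stay self-contained, it parses a small synthetic file; the numbers in it are made up for illustration and are not results from this notebook. The helper name read_metric is hypothetical.

```python
import csv
import os
import tempfile


def read_metric(csv_path, name):
    """Return (step, value) pairs for one metric, skipping rows where it is empty."""
    pairs = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            if row.get(name):
                pairs.append((int(row["step"]), float(row[name])))
    return pairs


# Synthetic stand-in for a metrics file; all values are invented.
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "metrics.csv")
with open(demo_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "step", "train_loss", "trainAcc"])
    writer.writerow([0, 0, 0.70, ""])
    writer.writerow([0, 1, 0.65, ""])
    writer.writerow([0, 1, "", 0.55])

loss_curve = read_metric(demo_path, "train_loss")
acc_curve = read_metric(demo_path, "trainAcc")
```

For the run in this notebook, the corresponding file would be expected somewhere below os.path.join(my_temp_dir, "lightning_logs"); the exact subdirectory depends on the logger version.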

We can also make predictions with the underlying model.

[12]:
deterministic_preds = deterministic_model.predict_step(test_grid_points)
fig = plot_predictions_classification(
    X_test, Y_test, deterministic_preds["pred"].argmax(-1), test_grid_points
)
../../_images/tutorials_classification_swag_21_0.png

Apply SWAG#

We now have a deterministic model that can make predictions. However, it gives us no estimate of the uncertainty around the network weights. SWAG is an approximate Bayesian method that captures this uncertainty, and we will now apply it to obtain epistemic uncertainty.

[13]:
swag_model = SWAGClassification(
    deterministic_model.model,
    max_swag_snapshots=30,
    snapshot_freq=1,
    num_mc_samples=50,
    swag_lr=1e-3,
    loss_fn=nn.CrossEntropyLoss(),
)

Prediction#

For prediction, we can either rely on the trainer.test() method or manually call predict_step(). Using the trainer saves the predictions and some metrics to a CSV file, while a manual predict_step() with a single input tensor returns a dictionary that holds the prediction under "pred" as well as an uncertainty estimate under "pred_uct". In this notebook, the SWAG wrapper is first fitted with trainer.fit(), starting from the weights of the trained deterministic model. Predictions then sample network weights for the desired number of Monte Carlo samples (num_mc_samples).

[14]:
trainer = Trainer(enable_progress_bar=False, max_epochs=30)
trainer.fit(swag_model, datamodule=dm)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name         | Type             | Params | Mode
----------------------------------------------------------
0 | model        | MLP              | 5.4 K  | train
1 | loss_fn      | CrossEntropyLoss | 0      | train
2 | test_metrics | MetricCollection | 0      | train
----------------------------------------------------------
5.4 K     Trainable params
0         Non-trainable params
5.4 K     Total params
0.021     Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode
/home/docs/checkouts/readthedocs.org/user_builds/lightning-uq-box/envs/stable/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (19) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=30` reached.
[15]:
preds = swag_model.predict_step(test_grid_points)
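
Conceptually, a Monte Carlo prediction of this kind runs the network once per sampled weight vector and averages the resulting class probabilities. The sketch below illustrates that averaging for a single input, together with the predictive entropy as one common uncertainty score. Whether "pred_uct" is computed exactly this way is not stated in this tutorial, so treat the entropy as an assumption; the function names and the logits are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mc_predict(sampled_logits):
    """Average class probabilities over samples; return (mean_probs, entropy)."""
    probs = [softmax(logits) for logits in sampled_logits]
    n_classes = len(probs[0])
    mean_probs = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    # Predictive entropy of the averaged distribution (in nats).
    entropy = -sum(p * math.log(p) for p in mean_probs if p > 0)
    return mean_probs, entropy


# Hypothetical logits from three sampled networks for one input each.
agree = [[4.0, -4.0], [3.5, -3.5], [4.2, -4.2]]     # samples agree on class 0
disagree = [[4.0, -4.0], [-4.0, 4.0], [0.0, 0.0]]   # samples contradict each other

agree_probs, agree_entropy = mc_predict(agree)
disagree_probs, disagree_entropy = mc_predict(disagree)
```

When the sampled networks agree, the averaged distribution stays peaked and the entropy is close to zero; when they disagree, the average flattens toward uniform and the entropy approaches log(2) for two classes.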

Evaluate Predictions#

[16]:
fig = plot_predictions_classification(
    X_test, Y_test, preds["pred"].argmax(-1), test_grid_points, preds["pred_uct"]
)
../../_images/tutorials_classification_swag_28_0.png