Deep Ensemble#
[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
Theoretic Foundation#
Introduced in Lakshminarayanan, 2017, Deep Ensembles approximate the posterior distribution over the model weights with a set of separately initialized and independently trained networks, whose outputs are combined into a Gaussian mixture model. In Wilson, 2020, the authors showed that Deep Ensembles can be interpreted as an approximate Bayesian method.
For the Deep Ensembles model, the predictive mean is given by the mean taken over \(N \in \mathbb{N}\) models \(f_{\theta_i}(x^{\star}) = \mu_{\theta_i}(x^{\star})\) with different weights \(\{\theta_i\}_{i=1}^N\), each of which outputs a mean,
\[m(x^{\star}) = \frac{1}{N}\sum_{i=1}^{N} \mu_{\theta_i}(x^{\star}).\]
The predictive uncertainty is given by the standard deviation of the predictions of the \(N\) different networks,
\[\sigma(x^{\star}) = \sqrt{\frac{1}{N}\sum_{i=1}^{N} \left(\mu_{\theta_i}(x^{\star}) - m(x^{\star})\right)^2}.\]
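As a minimal sketch of the predictive mean and standard deviation described above, assume the mean predictions of the \(N\) members have already been stacked into a NumPy array of shape (N, number of inputs); this layout and the function name deep_ensemble_predict are hypothetical choices for illustration and not part of lightning-uq-box.

```python
import numpy as np


def deep_ensemble_predict(member_means: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Combine the mean predictions of N ensemble members.

    member_means: array of shape (N, n_inputs); row i holds mu_{theta_i}(x*).
    Returns the predictive mean and the predictive standard deviation.
    """
    # predictive mean: (1/N) * sum_i mu_{theta_i}(x*)
    mean = member_means.mean(axis=0)
    # predictive std: spread of the member means, using the 1/N normalization
    std = np.sqrt(((member_means - mean) ** 2).mean(axis=0))
    return mean, std


# Usage: two members that disagree on the first input and agree on the second
m, s = deep_ensemble_predict(np.array([[1.0, 2.0], [3.0, 2.0]]))
```

Here m is [2.0, 2.0] and s is [1.0, 0.0]: the uncertainty is zero wherever the members agree, which is why this variant only captures disagreement between networks.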
Deep Ensembles GMM#
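For the Deep Ensembles GMM model, the predictive mean is given by the mean taken over \(N \in \mathbb{N}\) models \(f_{\theta_i}(x^{\star}) = (\mu_{\theta_i}(x^{\star}), \sigma_{\theta_i}(x^{\star}))\) with different weights \(\{\theta_i\}_{i=1}^N\),
\[\mu_g(x^{\star}) = \frac{1}{N}\sum_{i=1}^{N} \mu_{\theta_i}(x^{\star}).\]
The predictive uncertainty is given by the standard deviation of the Gaussian mixture model consisting of the \(N\) Gaussian ensemble members,
\[\sigma_g(x^{\star}) = \sqrt{\frac{1}{N}\sum_{i=1}^{N} \left(\sigma_{\theta_i}^2(x^{\star}) + \mu_{\theta_i}^2(x^{\star})\right) - \mu_g^2(x^{\star})}.\]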
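The mixture mean and standard deviation described above can be sketched in the same way, assuming each member provides both a mean and a standard deviation per input, stacked into two NumPy arrays of shape (N, number of inputs). The function name gmm_ensemble_predict is hypothetical and this is an illustration, not the library's implementation.

```python
import numpy as np


def gmm_ensemble_predict(member_means: np.ndarray, member_stds: np.ndarray):
    """Mean and std of an equally weighted mixture of N Gaussians.

    member_means, member_stds: arrays of shape (N, n_inputs) holding
    mu_{theta_i}(x*) and sigma_{theta_i}(x*) for every ensemble member.
    """
    # mixture mean: (1/N) * sum_i mu_{theta_i}(x*)
    mu_g = member_means.mean(axis=0)
    # mixture second moment: (1/N) * sum_i (sigma_i^2 + mu_i^2)
    second_moment = (member_stds**2 + member_means**2).mean(axis=0)
    # variance = second moment - squared mean; clip guards against tiny negative rounding errors
    sigma_g = np.sqrt(np.clip(second_moment - mu_g**2, 0.0, None))
    return mu_g, sigma_g


# Usage: two members with means 1 and 3, each predicting a standard deviation of 1
mu, sd = gmm_ensemble_predict(np.array([[1.0], [3.0]]), np.array([[1.0], [1.0]]))
```

This gives mu = [2.0] and sd = [sqrt(2)]. Setting all member standard deviations to zero recovers the plain Deep Ensembles standard deviation of the means (here 1.0).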
Note that the difference between “Deep Ensembles” and “Deep Ensembles GMM” is that in the latter we also consider the predictive uncertainty output of each individual ensemble member, whereas in the former we only consider the means and the variance of the mean predictions of the ensemble members.
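Expanding the mixture variance makes the relation between the two variants explicit. With \(\mu_g(x^{\star}) = \frac{1}{N}\sum_{i=1}^{N}\mu_{\theta_i}(x^{\star})\), the variance of the equally weighted Gaussian mixture splits (law of total variance) into the average variance predicted by the members plus the spread of the member means, and the second term is exactly the squared standard deviation used by the plain Deep Ensembles variant. The two terms are commonly read as aleatoric and epistemic uncertainty, respectively:

\[
\begin{aligned}
\sigma_g^2(x^{\star})
&= \frac{1}{N}\sum_{i=1}^{N}\left(\sigma_{\theta_i}^2(x^{\star}) + \mu_{\theta_i}^2(x^{\star})\right) - \mu_g^2(x^{\star}) \\
&= \underbrace{\frac{1}{N}\sum_{i=1}^{N}\sigma_{\theta_i}^2(x^{\star})}_{\text{aleatoric}}
 + \underbrace{\frac{1}{N}\sum_{i=1}^{N}\left(\mu_{\theta_i}(x^{\star}) - \mu_g(x^{\star})\right)^2}_{\text{epistemic}}.
\end{aligned}
\]

The second line follows because \(\frac{1}{N}\sum_{i=1}^{N}\mu_{\theta_i}^2(x^{\star}) - \mu_g^2(x^{\star}) = \frac{1}{N}\sum_{i=1}^{N}\left(\mu_{\theta_i}(x^{\star}) - \mu_g(x^{\star})\right)^2\).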
Imports#
[2]:
import os
import tempfile
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DeepEnsembleRegression, MVERegression
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
)
plt.rcParams["figure.figsize"] = [14, 5]
[3]:
seed_everything(0) # seed everything for reproducibility
Seed set to 0
[3]:
0
We define a temporary directory to store training outputs, such as the model checkpoints of the ensemble members.
Datamodule#
To demonstrate the method, we will make use of a toy regression example that is defined as a Lightning DataModule. While this might seem like overkill for a small toy problem, we think it is helpful to show how the individual pieces of the library fit together, so that you can train models on more complex tasks.
[4]:
my_temp_dir = tempfile.mkdtemp()
[5]:
dm = ToyHeteroscedasticDatamodule()
X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
    dm.X_train,
    dm.Y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.Y_test,
    dm.test_dataloader(),
    dm.X_gtext,
    dm.Y_gtext,
)
[6]:
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)
Model#
For our toy regression problem, we will use a simple multi-layer perceptron (MLP) that you can configure to your needs. For the documentation of the MLP, see here. For the Deep Ensemble, we train 5 differently initialized base networks and later combine them into an ensemble for prediction. Each base network is a deterministic MLP wrapped in MVERegression and trained with a negative log-likelihood (NLL) loss, so that it predicts both a mean and an uncertainty estimate. We keep track of the model checkpoints of these networks and save them manually in our temporary directory. At prediction time, these checkpoints are used to initialize the individual ensemble members, and only the member that is currently needed is loaded, which reduces memory requirements.
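The memory-saving pattern described above (keep only checkpoint paths, then load one member at a time during prediction) can be sketched without Lightning. In this sketch the "members" are hypothetical linear models stored with pickle, and the helper names save_members and predict_lazily are made up for illustration; it shows the idea only and is not how DeepEnsembleRegression is implemented internally.

```python
import os
import pickle
import tempfile

import numpy as np


def save_members(member_params, directory):
    """Save each member's parameters to its own checkpoint file and return the paths."""
    paths = []
    for i, params in enumerate(member_params):
        path = os.path.join(directory, f"member_{i}.pkl")
        with open(path, "wb") as f:
            pickle.dump(params, f)
        paths.append(path)
    return paths


def predict_lazily(ckpt_paths, x):
    """Predict with one member at a time, so only one set of parameters is held at once."""
    preds = []
    for path in ckpt_paths:
        with open(path, "rb") as f:
            params = pickle.load(f)  # load only the member that is needed right now
        # hypothetical member: a linear model y = slope * x + bias
        preds.append(params["slope"] * x + params["bias"])
        del params  # drop our reference to this member before loading the next one
    return np.stack(preds)  # shape (N, len(x)), ready for a mean/std over axis 0


# Usage with two hypothetical linear "members"
with tempfile.TemporaryDirectory() as tmp_dir:
    paths = save_members(
        [{"slope": 1.0, "bias": 0.0}, {"slope": 2.0, "bias": 1.0}], tmp_dir
    )
    member_preds = predict_lazily(paths, np.array([0.0, 1.0]))
```

The stacked member predictions can then be reduced to an ensemble mean and spread, so peak memory grows with the size of one member rather than with the size of the whole ensemble.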
[7]:
n_ensembles = 5
trained_models_nll = []
for i in range(n_ensembles):
    mlp_model = MLP(n_hidden=[50, 50], n_outputs=2, activation_fn=nn.Tanh())
    ensemble_member = MVERegression(
        mlp_model, optimizer=partial(torch.optim.Adam, lr=1e-2), burnin_epochs=20
    )
    trainer = Trainer(
        max_epochs=150,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
        default_root_dir=my_temp_dir,
    )
    trainer.fit(ensemble_member, dm)
    save_path = os.path.join(my_temp_dir, f"model_nll_{i}.ckpt")
    trainer.save_checkpoint(save_path)
    trained_models_nll.append({"base_model": ensemble_member, "ckpt_path": save_path})
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
| Name | Type | Params | Mode
-----------------------------------------------------------
0 | model | MLP | 2.8 K | train
1 | loss_fn | NLL | 0 | train
2 | train_metrics | MetricCollection | 0 | train
3 | val_metrics | MetricCollection | 0 | train
4 | test_metrics | MetricCollection | 0 | train
-----------------------------------------------------------
2.8 K Trainable params
0 Non-trainable params
2.8 K Total params
0.011 Total estimated model params size (MB)
21 Modules in train mode
0 Modules in eval mode
`Trainer.fit` stopped: `max_epochs=150` reached.
Construct Deep Ensemble#
[8]:
deep_ens_nll = DeepEnsembleRegression(trained_models_nll)
Evaluate Predictions#
The constructed DataModule contains two test sets. X_test contains IID samples from the same noisy distribution as the training data, while X_gtext (“X ground truth extended”) contains dense inputs to the underlying noise-free “ground truth” function, extending the input range to either side of the training data so that we can visualize the method’s UQ tendencies when extrapolating. Thus, we use X_gtext for visualization purposes, but X_test to compute uncertainty and calibration metrics, because we want to analyse how well the method has learned the noisy data distribution.
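The difference between the three kinds of inputs can be illustrated with a small NumPy sketch. The ground-truth function, the noise model, the sample sizes and the input ranges below are all hypothetical choices for illustration; they are not the ones used by ToyHeteroscedasticDatamodule. The demo variable names are chosen so they do not overwrite the notebook's own X_train, X_test and X_gtext.

```python
import numpy as np

rng = np.random.default_rng(0)


def ground_truth(x):
    # hypothetical noise-free target function
    return np.sin(x)


def noise_std(x):
    # hypothetical heteroscedastic noise level that grows with |x|
    return 0.1 + 0.2 * np.abs(x)


def sample_noisy(n, low, high):
    """Draw n IID inputs from [low, high) and noisy targets around the ground truth."""
    x = rng.uniform(low, high, size=n)
    y = ground_truth(x) + rng.normal(0.0, noise_std(x))
    return x, y


x_train_demo, y_train_demo = sample_noisy(200, -3.0, 3.0)
x_test_demo, y_test_demo = sample_noisy(100, -3.0, 3.0)  # same noisy distribution as training
x_gtext_demo = np.linspace(-5.0, 5.0, 500)  # dense grid extending beyond [-3, 3]
y_gtext_demo = ground_truth(x_gtext_demo)  # noise-free targets, used for plotting only
```

Metrics computed on the test split reflect how well the noisy distribution was learned, while the extended grid exposes how the uncertainty behaves outside the training range.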
[9]:
preds = deep_ens_nll.predict_step(X_gtext)
fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    preds["pred"],
    preds["pred_uct"],
    epistemic=preds["epistemic_uct"],
    aleatoric=preds["aleatoric_uct"],
    title="Deep Ensemble NLL",
    show_bands=False,
)
For some additional metrics relevant to UQ, we can use the uncertainty-toolbox, which gives us some insight into the calibration of our predictions. For a discussion of why this is important, see …
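To make the notion of calibration concrete, here is a simplified sketch that checks how often the targets fall inside central Gaussian prediction intervals at several confidence levels; for a calibrated model, the observed fraction matches the expected level. It assumes Gaussian predictive distributions given by a mean and a standard deviation per input (for example preds["pred"] and preds["pred_uct"] converted to NumPy, assuming pred_uct is a standard deviation). The function names are hypothetical, and this is similar in spirit to, but not the implementation of, the metrics reported by uncertainty-toolbox.

```python
from statistics import NormalDist

import numpy as np


def interval_coverage(pred_mean, pred_std, y_true, levels):
    """Observed fraction of targets inside central Gaussian prediction intervals.

    levels: expected coverage probabilities, each strictly between 0 and 1.
    """
    observed = []
    for p in levels:
        # half-width of the central p-interval of N(0, 1), about 1.96 for p = 0.95
        z = NormalDist().inv_cdf(0.5 + p / 2.0)
        inside = np.abs(y_true - pred_mean) <= z * pred_std
        observed.append(inside.mean())
    return np.array(observed)


def mean_abs_coverage_error(pred_mean, pred_std, y_true, levels):
    """Average gap between expected and observed coverage (0 means perfectly calibrated)."""
    levels = np.asarray(levels)
    return np.mean(np.abs(interval_coverage(pred_mean, pred_std, y_true, levels) - levels))


# Usage: targets drawn from exactly the predicted Gaussians should be well calibrated
rng = np.random.default_rng(0)
demo_mean = np.zeros(100_000)
demo_std = np.ones(100_000)
demo_y = rng.normal(demo_mean, demo_std)
demo_levels = np.linspace(0.05, 0.95, 19)
demo_error = mean_abs_coverage_error(demo_mean, demo_std, demo_y, demo_levels)
```

An overconfident model (predicted standard deviations that are too small) shows observed coverage below the expected levels, and an underconfident one shows coverage above them.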
[10]:
preds = deep_ens_nll.predict_step(X_test)
fig = plot_calibration_uq_toolbox(
    preds["pred"].numpy(),
    preds["pred_uct"].numpy(),
    Y_test.cpu().numpy(),
    X_test.cpu().numpy(),
)
Additional Resources#
Links to other related literature that might be of interest.