Deep Evidential Regression
[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
Theoretic Foundation
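Deep Evidential Regression (DER) (Amini, 2020) is a single-forward-pass UQ method that aims to disentangle aleatoric and epistemic uncertainty. DER uses a four-headed network output
\[f_{\theta}(x^{\star}) = \left(\gamma_{\theta}(x^{\star}), \nu_{\theta}(x^{\star}), \alpha_{\theta}(x^{\star}), \beta_{\theta}(x^{\star})\right), \qquad \nu_{\theta}(x^{\star})>0,\ \alpha_{\theta}(x^{\star})>1,\ \beta_{\theta}(x^{\star})>0,\]
which is used to compute the predictive t-distribution with \(2\alpha_{\theta}(x^{\star})\) degrees of freedom:
\[p\left(y^{\star} \mid f_{\theta}(x^{\star})\right) = \mathrm{St}_{2\alpha_{\theta}(x^{\star})}\left(y^{\star} \,\middle|\, \gamma_{\theta}(x^{\star}), \frac{\beta_{\theta}(x^{\star})\left(1+\nu_{\theta}(x^{\star})\right)}{\nu_{\theta}(x^{\star})\,\alpha_{\theta}(x^{\star})}\right).\]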
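As a small numerical illustration of the predictive t-distribution, the sketch below computes its parameters from already-constrained DER outputs (\(\nu>0\), \(\alpha>1\), \(\beta>0\)) with NumPy. The helper names `der_student_t` and `student_t_variance` are hypothetical and are not part of lightning-uq-box. The code only restates the formula. The variance uses the standard Student-t identity \(\mathrm{Var} = s^2\,d/(d-2)\) for \(d>2\). Here \(d = 2\alpha\) and \(s^2 = \beta(1+\nu)/(\nu\alpha)\), so the variance is finite exactly when \(\alpha>1\).

```python
import numpy as np


def der_student_t(gamma, nu, alpha, beta):
    """Return (df, loc, scale) of the DER predictive Student-t distribution.

    Hypothetical helper: assumes constrained inputs nu > 0, alpha > 1, beta > 0.
    """
    gamma, nu, alpha, beta = (
        np.asarray(v, dtype=float) for v in (gamma, nu, alpha, beta)
    )
    df = 2.0 * alpha                                   # 2 * alpha degrees of freedom
    loc = gamma                                        # location is gamma
    scale = np.sqrt(beta * (1.0 + nu) / (nu * alpha))  # squared scale beta(1+nu)/(nu*alpha)
    return df, loc, scale


def student_t_variance(df, scale):
    """Variance of a Student-t distribution with df > 2."""
    return scale**2 * df / (df - 2.0)


# Example with gamma=0.5, nu=1, alpha=2, beta=1.
df, loc, scale = der_student_t(0.5, 1.0, 2.0, 1.0)
print(df, loc, scale)                 # 4.0 0.5 1.0
print(student_t_variance(df, scale))  # 2.0
```

For \(\alpha>1\), the variance simplifies algebraically to \(\beta(1+\nu)/(\nu(\alpha-1))\). The test below checks this identity for a second set of values.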
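In Amini, 2020 the network weights are obtained by minimizing a loss objective that combines the negative log-likelihood of the predictive distribution with a regularization term. However, because of several drawbacks of DER, Meinert, 2022 propose the following adapted loss objective, which we also use:
\[\mathcal{L}\left(\theta, (x^{\star}, y^{\star})\right) = \log\sigma_{\theta}^2(x^{\star}) + \left(1+\lambda\,\nu_{\theta}(x^{\star})\right)\frac{\left(y^{\star}-\gamma_{\theta}(x^{\star})\right)^2}{\sigma_{\theta}^2(x^{\star})},\]
where \(\sigma_{\theta}^2(x^{\star})=\beta_{\theta}(x^{\star})/\nu_{\theta}(x^{\star})\) and \(\lambda\) is a regularization weight. The mean prediction is given by
\[\hat{y}(x^{\star}) = \gamma_{\theta}(x^{\star}).\]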
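The NumPy sketch below shows how such an objective is evaluated over a batch. It assumes the adapted loss has the form \(\log\sigma^2 + (1+\lambda\nu)(y-\gamma)^2/\sigma^2\) with \(\sigma^2=\beta/\nu\). The function name `meinert_der_loss` is hypothetical, and the default `lam=0.01` is an assumed example value, not a documented library default.

```python
import numpy as np


def meinert_der_loss(gamma, nu, beta, y, lam=0.01):
    """Batch mean of log(sigma^2) + (1 + lam*nu) * (y - gamma)^2 / sigma^2.

    Hypothetical restatement of the adapted DER objective, with
    sigma^2 = beta / nu. `lam` plays the role of lambda (assumed value).
    """
    gamma, nu, beta, y = (np.asarray(v, dtype=float) for v in (gamma, nu, beta, y))
    sigma2 = beta / nu  # sigma_theta^2 = beta / nu
    per_sample = np.log(sigma2) + (1.0 + lam * nu) * (y - gamma) ** 2 / sigma2
    return float(per_sample.mean())  # average over the batch


# A perfect mean prediction with sigma^2 = 1 gives zero loss.
print(meinert_der_loss(0.0, 1.0, 1.0, 0.0))  # 0.0
```

Note that \(\alpha_{\theta}\) does not appear in this objective. Only \(\gamma_{\theta}\), \(\nu_{\theta}\) and \(\beta_{\theta}\) enter it.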
Further following Meinert, 2022, we use their reformulation of the uncertainty decomposition. The aleatoric uncertainty is given by
and the epistemic uncertainty by
The predictive uncertainty is then given by
Imports
[2]:
import os
import tempfile
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DER
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)
plt.rcParams["figure.figsize"] = [14, 5]
[3]:
seed_everything(0) # seed everything for reproducibility
Seed set to 0
[3]:
0
We create a temporary directory that stores training metrics and results for later inspection.
[4]:
my_temp_dir = tempfile.mkdtemp()
Datamodule
To demonstrate the method, we use a toy regression example that is defined as a Lightning DataModule. While this might seem like overkill for a small toy problem, we think it is helpful to show how the individual pieces of the library fit together, so that you can train models on more complex tasks.
[5]:
dm = ToyHeteroscedasticDatamodule()
X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
    dm.X_train,
    dm.Y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.Y_test,
    dm.test_dataloader(),
    dm.X_gtext,
    dm.Y_gtext,
)
[6]:
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)
Model
For our toy regression problem, we use a simple multi-layer perceptron (MLP) that you can configure to your needs. For details, see the documentation of the MLP. For DER, the underlying network requires four outputs. These feed into a DERLayer, and the loss function is computed from that layer's output.
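The DERLayer's job is to turn four unconstrained network outputs into valid parameters. A common choice is assumed here: a softplus for \(\nu\) and \(\beta\), and a softplus plus one for \(\alpha\) so that \(\alpha>1\). The library's DERLayer may differ in detail. The column order \((\gamma, \nu, \alpha, \beta)\) and the function name `der_constrain` are also hypothetical.

```python
import numpy as np


def softplus(x):
    """Numerically stable softplus, log(1 + exp(x))."""
    return np.logaddexp(0.0, x)


def der_constrain(raw):
    """Map raw outputs of shape (N, 4) to (gamma, nu, alpha, beta).

    Hypothetical sketch: assumes the column order (gamma, nu, alpha, beta).
    """
    raw = np.asarray(raw, dtype=float)
    gamma = raw[:, 0]                  # mean prediction, unconstrained
    nu = softplus(raw[:, 1])           # nu > 0
    alpha = softplus(raw[:, 2]) + 1.0  # alpha > 1
    beta = softplus(raw[:, 3])         # beta > 0
    return gamma, nu, alpha, beta


raw = np.array([[0.3, -2.0, 0.0, 5.0],
                [-1.2, 4.0, -10.0, -3.0]])
gamma, nu, alpha, beta = der_constrain(raw)
```

Softplus keeps the outputs smooth and strictly positive for moderate inputs. An exponential would also guarantee positivity, but it grows much faster for large raw values.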
[7]:
network = MLP(n_inputs=1, n_hidden=[50, 50, 50], n_outputs=4, activation_fn=nn.Tanh())
network
[7]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=1, out_features=50, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): Tanh()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=50, out_features=50, bias=True)
    (7): Tanh()
    (8): Dropout(p=0.0, inplace=False)
    (9): Linear(in_features=50, out_features=4, bias=True)
  )
)
With the underlying neural network in place, we can now use the desired UQ-Method as a wrapper around it. All UQ-Methods are implemented as LightningModules, which let us organize the code concisely and remove as much boilerplate as possible.
[8]:
der_model = DER(network, optimizer=partial(torch.optim.Adam, lr=1e-3))
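The optimizer is passed as a `functools.partial` rather than as an instance because an optimizer needs the model parameters, and those belong to the module being built. We assume the UQ-Method calls this factory with its own parameters when it configures the optimizer; that detail of the library internals is an assumption. The pure-Python sketch below shows the pattern with a hypothetical `DummyOptimizer` standing in for `torch.optim.Adam`.

```python
from functools import partial


class DummyOptimizer:
    """Hypothetical stand-in for torch.optim.Adam: stores params and lr."""

    def __init__(self, params, lr=0.1):
        self.params = list(params)
        self.lr = lr


# Fix the hyperparameters now and supply the parameters later.
optimizer_factory = partial(DummyOptimizer, lr=1e-3)

# Later, e.g. when a LightningModule configures its optimizer (assumed):
model_params = [0.5, -1.0, 2.0]  # placeholder for module.parameters()
optimizer = optimizer_factory(model_params)
print(optimizer.lr, optimizer.params)  # 0.001 [0.5, -1.0, 2.0]
```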
Trainer
Now that we have a LightningDataModule and a UQ-Method as a LightningModule, we can conduct training with a Lightning Trainer. It has tons of options to make your life easier, so we encourage you to check the documentation.
[9]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    accelerator="cpu",
    max_epochs=250,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Training our model is now easy:
[10]:
trainer.fit(der_model, dm)
| Name | Type | Params | Mode | FLOPs
-------------------------------------------------------------------
0 | model | MLP | 5.4 K | train | 0
1 | train_metrics | MetricCollection | 0 | train | 0
2 | val_metrics | MetricCollection | 0 | train | 0
3 | test_metrics | MetricCollection | 0 | train | 0
4 | loss_fn | DERLoss | 0 | train | 0
5 | der_layer | DERLayer | 0 | train | 0
-------------------------------------------------------------------
5.4 K Trainable params
0 Non-trainable params
5.4 K Total params
0.022 Total estimated model params size (MB)
24 Modules in train mode
0 Modules in eval mode
0 Total Flops
`Trainer.fit` stopped: `max_epochs=250` reached.
Training Metrics
To get some insight into how the training went, we can use the plot_training_metrics utility function to plot the training loss and the RMSE metric.
[11]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)
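If you prefer to inspect the logged values directly, Lightning's CSVLogger writes a metrics.csv file. By default it sits under lightning_logs/version_<n>/ inside the directory passed to the logger. Its rows are often sparse: a metric column is empty in rows where that metric was not logged. The standard-library sketch below collects one metric column and skips empty cells. The function name `read_metric` is hypothetical, and the demo file is a hand-written toy, not real logger output, so check the header of your own file for the exact column names.

```python
import csv
import os
import tempfile


def read_metric(csv_path, column, step_column="step"):
    """Return (steps, values) for one metric column, skipping empty cells."""
    steps, values = [], []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            value = row.get(column)
            if not value:  # metric was not logged in this row
                continue
            steps.append(int(row[step_column]))
            values.append(float(value))
    return steps, values


# Self-contained demo on a small hand-written file with a sparse layout.
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "metrics.csv")
with open(demo_path, "w", newline="") as f:
    f.write("epoch,step,train_loss,trainRMSE\n")
    f.write("0,0,1.5,\n")
    f.write("0,0,,0.9\n")
    f.write("1,1,1.1,\n")

steps, losses = read_metric(demo_path, "train_loss")
print(steps, losses)  # [0, 1] [1.5, 1.1]
```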
Evaluate Predictions
The constructed datamodule provides two possible test sets. X_test contains IID samples from the same noisy distribution as the training data. X_gtext (“X ground truth extended”) contains dense inputs from the underlying, noise-free “ground truth” function and extends the input range to either side. This lets us visualize the method’s UQ tendencies when extrapolating beyond the training data range. We therefore use X_gtext for visualization but X_test to compute uncertainty and calibration metrics, because we want to analyse how well the method has learned the noisy data distribution.
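To make the distinction between the two test sets concrete, here is a hypothetical NumPy generator in the same spirit. It produces noisy, heteroscedastic training and IID test samples on one interval, plus a dense, noise-free grid that extends beyond that interval. The function `f`, the noise model and the ranges are illustrative assumptions, not the ones used by ToyHeteroscedasticDatamodule.

```python
import numpy as np


def f(x):
    """Hypothetical ground-truth function."""
    return x * np.sin(x)


def make_toy_splits(n_train=200, n_test=100, low=-3.0, high=3.0, extend=2.0, seed=0):
    """Return (x_train, y_train, x_test, y_test, x_gtext, y_gtext)."""
    rng = np.random.default_rng(seed)

    def noisy_samples(n):
        x = np.sort(rng.uniform(low, high, size=n))
        noise_std = 0.1 + 0.2 * np.abs(x)  # heteroscedastic: noise grows with |x|
        return x, f(x) + rng.normal(0.0, noise_std)

    x_train, y_train = noisy_samples(n_train)  # training data
    x_test, y_test = noisy_samples(n_test)     # IID test data, same noise model
    x_gtext = np.linspace(low - extend, high + extend, 500)  # dense, extended grid
    y_gtext = f(x_gtext)                       # noise-free ground truth
    return x_train, y_train, x_test, y_test, x_gtext, y_gtext


x_train, y_train, x_test, y_test, x_gtext, y_gtext = make_toy_splits()
```

Metrics computed on x_test reflect the noisy distribution the model was trained on. The grid x_gtext is better suited to plotting how the predicted uncertainty behaves outside [low, high].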
[12]:
preds = der_model.predict_step(X_gtext.to(der_model.device))
fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    preds["pred"],
    preds["pred_uct"].squeeze(-1),
    epistemic=preds["epistemic_uct"],
    aleatoric=preds["aleatoric_uct"],
    title="Deep Evidential Regression",
    show_bands=False,
)
[13]:
preds = der_model.predict_step(X_test.to(der_model.device))
fig = plot_calibration_uq_toolbox(
    preds["pred"].squeeze(-1).cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),  # move to CPU first, like the other arrays
    Y_test.cpu().numpy(),
    X_test.cpu().numpy(),
)
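A quick check that complements the calibration plot is empirical interval coverage. For several nominal levels, count how often the test target falls inside the central interval around the prediction. The sketch below treats each prediction as Gaussian with a standard deviation equal to pred_uct. This is a simplifying assumption, since the DER predictive distribution is a Student-t. It uses only the standard library. The function `interval_coverage` is hypothetical and is not part of the plotting utility used above.

```python
from statistics import NormalDist


def interval_coverage(mean, std, target, levels=(0.5, 0.9, 0.95)):
    """Empirical coverage of central Gaussian intervals for each nominal level."""
    result = {}
    for level in levels:
        z = NormalDist().inv_cdf(0.5 + level / 2.0)  # half-width in units of std
        hits = sum(abs(t - m) <= z * s for m, s, t in zip(mean, std, target))
        result[level] = hits / len(target)
    return result


# Toy values: all predictions are 0 with std 1; the targets are made up.
mean = [0.0] * 4
std = [1.0] * 4
target = [0.1, -0.5, 1.2, 3.0]
print(interval_coverage(mean, std, target))  # {0.5: 0.5, 0.9: 0.75, 0.95: 0.75}
```

When applied to the flattened prediction, uncertainty and target arrays from the previous cell, coverage values close to the nominal levels indicate good calibration under the Gaussian approximation.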