Laplace Approximation#
[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
Theoretic Foundation#
The Laplace Approximation was originally introduced by MacKay, 1992 and has since been adapted to modern neural networks by Ritter, 2018 and Daxberger, 2021. It is an approximate Bayesian method. The goal of the Laplace Approximation is to use a second-order Taylor expansion around the fitted MAP estimate to obtain a posterior approximation over the model parameters via a full-rank, diagonal, or Kronecker-factorized covariance structure. To keep the Laplace Approximation computationally feasible for larger network architectures, we use the Laplace library, which includes approaches such as subnetwork selection, proposed for example by Daxberger, 2021.
The general idea of the Laplace Approximation is to obtain a distribution over the network parameters by approximating the posterior with a Gaussian distribution centered at the MAP estimate of the parameters Daxberger, 2021. In this setting, we define a prior distribution \(p(\theta)\) over our network parameters. Because modern neural networks consist of millions of parameters, obtaining the exact posterior distribution over the weights \(\theta\) is intractable. The LA takes the MAP estimate of the parameters \(\theta_{MAP}\) from a trained network \(f_{\theta_{MAP}}(x) = \mu_{\theta_{MAP}}(x)\) and constructs a Gaussian distribution around it. The parameters \(\theta_{MAP}\) are obtained by
\begin{align*} \theta_{MAP} = \operatorname{argmin}_{\theta} \mathcal{L}(\theta; \mathcal{D}), \end{align*}
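where \(\mathcal{L}\) is the mean squared error, also referred to as the \(\ell^2\) loss. With \(\mathcal{L}(\theta; \mathcal{D}) := -\sum_{i=1}^n \log p(y_i|f_{\theta}(x_i))\) and the likelihood \(p(y_i|f_{\theta}(x_i))\) chosen to be a Gaussian with constant variance \(\sigma^2\), the loss reduces to the mean squared error up to scaling and additive constants, and a homoscedastic noise model is assumed. Then with Bayes' theorem, as in Daxberger, 2021, and with the negative log-prior \(-\log p(\theta)\) absorbed into \(\mathcal{L}\) as a regularization term, one can relate the posterior to the loss,
\begin{align*} p(\theta|D) = \frac{1}{Z} p(D|\theta) p(\theta) = \frac{1}{Z} \exp(-\mathcal{L}(\theta; D)), \end{align*}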
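with \(Z = \int p(D\vert\theta)p(\theta) d\theta\). Now a second-order Taylor expansion of \(\mathcal{L}\) around \(\theta_{MAP}\) is used to construct a Gaussian approximation to the posterior \(p(\theta|D)\):
\begin{align*} \mathcal{L}(\theta; D) \approx \mathcal{L}(\theta_{MAP}; D) + (\theta - \theta_{MAP})^T \nabla_{\theta} \mathcal{L}(\theta; D)\vert_{\theta_{MAP}} + \frac{1}{2} (\theta - \theta_{MAP})^T (\nabla_{\theta}^2 \mathcal{L}(\theta; D)\vert_{\theta_{MAP}}) (\theta - \theta_{MAP}). \end{align*}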
The first-order term vanishes because the loss is evaluated at a minimum \(\theta_{MAP}\) Murphy, 2022. Furthermore, the zeroth-order term \(\mathcal{L}(\theta_{MAP}; D)\) does not depend on \(\theta\) and is absorbed into the normalization constant. Exponentiating \(-\mathcal{L}\) then allows one to identify, after normalization, the Laplace approximation,
\begin{align*} p(\theta|D) \approx \mathcal{N}(\theta_{MAP}, \Sigma) && \text{with} \qquad \Sigma = (\nabla_{\theta}^2 \mathcal{L}(\theta; D)\vert \theta_{MAP})^{-1}. \end{align*}
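To make the recipe concrete, here is a minimal sketch for a one-parameter toy problem. It is not taken from the Laplace library: the model (a Gaussian likelihood with a Gaussian prior on a scalar mean), the helper names, and the finite-difference derivatives are all assumptions chosen for illustration. Because this toy posterior is itself Gaussian, the Laplace approximation is exact here, which makes the result easy to check by hand.

```python
def neg_log_posterior(theta, ys, sigma=0.5, tau=2.0):
    """Toy loss: Gaussian negative log-likelihood plus negative log-prior (constants dropped)."""
    nll = sum((y - theta) ** 2 for y in ys) / (2 * sigma**2)
    neg_log_prior = theta**2 / (2 * tau**2)
    return nll + neg_log_prior


def first_derivative(f, x, h=1e-3):
    """Central finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)


def second_derivative(f, x, h=1e-3):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + h) - 2 * f(x) + f(x - h)) / h**2


def laplace_1d(f, theta0=0.0, steps=20):
    """Find the minimum of f with Newton steps, then return (theta_MAP, Sigma = 1 / Hessian)."""
    theta = theta0
    for _ in range(steps):
        theta -= first_derivative(f, theta) / second_derivative(f, theta)
    hessian = second_derivative(f, theta)
    return theta, 1.0 / hessian


ys = [0.9, 1.1, 1.3, 0.7, 1.0]  # hypothetical observations
theta_map, variance = laplace_1d(lambda t: neg_log_posterior(t, ys))
print(f"theta_MAP = {theta_map:.4f}, Sigma = {variance:.4f}")
```

For a neural network the same two steps apply, with gradient-based training replacing the Newton iterations and a structured Hessian approximation replacing the scalar second derivative.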
As the covariance is just the inverse Hessian of the loss, \(\Sigma = H^{-1}\), with \(\theta_{MAP}\in \mathcal{R}^W\) and \(H^{-1}\in \mathcal{R}^{W\times W}\), where \(W\) is the number of weights, we get the posterior distribution
\begin{align*} p(\theta|D) \approx \mathcal{N}(\theta_{MAP}, H^{-1}). \end{align*}
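Computing the full Hessian is still expensive. Therefore, further approximations are introduced in practice, most commonly the Generalized Gauss-Newton (GGN) matrix Martens, 2020. Summed over the training points indexed by \(n\), this takes the following form:
\begin{align*} H \approx \widetilde{H} = \sum_{n} J_n^T H_n J_n, \end{align*}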
where \(J_n\in \mathcal{R}^{O\times W}\) is the Jacobian of the model outputs with respect to the parameters \(\theta\) and \(H_n\in\mathcal{R}^{O\times O}\) is the Hessian of the negative log-likelihood with respect to the model outputs, where \(O\) denotes the model output size and \(W\) the number of parameters.
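As an illustration of the sum of \(J_n^T H_n J_n\) terms, the following sketch builds the GGN matrix for a hypothetical two-parameter model \(f(x; a, b) = a \tanh(b x)\) with a single output, so \(O = 1\) and \(W = 2\). The model, the analytic Jacobian, and the function names are assumptions made for this example. For a Gaussian likelihood with constant noise \(\sigma\), the output Hessian \(H_n\) is simply \(1/\sigma^2\).

```python
import math


def jacobian(x, a, b):
    """Jacobian of f(x; a, b) = a * tanh(b * x) with respect to (a, b)."""
    t = math.tanh(b * x)
    return [t, a * x * (1.0 - t * t)]


def ggn(xs, a, b, sigma_noise=0.4):
    """Accumulate sum_n J_n^T H_n J_n, where H_n = 1 / sigma^2 for a Gaussian likelihood."""
    h_n = 1.0 / sigma_noise**2
    n_params = 2
    H = [[0.0] * n_params for _ in range(n_params)]
    for x in xs:
        J = jacobian(x, a, b)
        for i in range(n_params):
            for j in range(n_params):
                H[i][j] += J[i] * h_n * J[j]
    return H


xs = [-1.0, -0.5, 0.0, 0.5, 1.0]  # hypothetical training inputs
H_ggn = ggn(xs, a=1.5, b=0.8)
for row in H_ggn:
    print(row)
```

By construction the result is symmetric and positive semi-definite, which is one reason the GGN is preferred over the exact Hessian in practice.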
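During prediction we cannot compute the full posterior predictive distribution and instead resort to approximations. One strategy is to draw samples \(\theta_s \sim p(\theta|D)\) for \(s \in \{1, ...,S\}\) and average the resulting predictions. However, Immer et al. 2021 suggested that a linearization of the form \(f_{\theta}(x)=f_{\theta_{MAP}}(x)+ J_{\theta_{MAP}}(x)(\theta-\theta_{MAP})\) works better in practice, and this is also the default in the Laplace library. Under this linearization and the Gaussian posterior above, we take \(f_{\theta_{MAP}}(x)\) as the predictive mean and obtain the predictive uncertainty by
\begin{align*} p(y|x, D) \approx \mathcal{N}(y;\, f_{\theta_{MAP}}(x),\, J_{\theta_{MAP}}(x) \Sigma J_{\theta_{MAP}}(x)^T + \sigma^2). \end{align*}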
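Under the linearization, the epistemic variance of a prediction is the posterior covariance pushed through the Jacobian, and the homoscedastic noise variance is added on top. The sketch below evaluates this for a hypothetical two-parameter model \(f(x; a, b) = a \tanh(b x)\). The MAP values, the covariance matrix, and the function names are made up for illustration and do not come from a fitted model.

```python
import math


def jacobian(x, a, b):
    """Jacobian of f(x; a, b) = a * tanh(b * x) with respect to (a, b)."""
    t = math.tanh(b * x)
    return [t, a * x * (1.0 - t * t)]


def glm_predictive(x, theta_map, cov, sigma_noise):
    """Return (mean, epistemic std, total std) of the linearized predictive."""
    a, b = theta_map
    mean = a * math.tanh(b * x)
    J = jacobian(x, a, b)
    # J Sigma J^T for a single output collapses to a scalar quadratic form
    epistemic_var = sum(J[i] * cov[i][j] * J[j] for i in range(2) for j in range(2))
    total_std = math.sqrt(epistemic_var + sigma_noise**2)
    return mean, math.sqrt(epistemic_var), total_std


theta_map = (1.5, 0.8)              # hypothetical MAP estimate
cov = [[0.04, 0.01], [0.01, 0.09]]  # hypothetical posterior covariance Sigma
results = {x: glm_predictive(x, theta_map, cov, sigma_noise=0.4) for x in (0.0, 1.0, 3.0)}
for x, (mean, epistemic, total) in results.items():
    print(f"x={x}: mean={mean:.3f}, epistemic={epistemic:.3f}, total={total:.3f}")
```

At \(x = 0\) the Jacobian of this toy model is zero, so the prediction carries no epistemic uncertainty there and the total standard deviation equals the noise level.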
The implementation is a wrapper around a model from the fantastic Laplace library, so all available options for subnetwork strategies can be found in their docs.
Imports#
[2]:
import os
import tempfile
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from laplace import Laplace
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DeterministicRegression, LaplaceRegression
from lightning_uq_box.viz_utils import (
plot_calibration_uq_toolbox,
plot_predictions_regression,
plot_toy_regression_data,
plot_training_metrics,
)
plt.rcParams["figure.figsize"] = [14, 5]
[3]:
seed_everything(0) # seed everything for reproducibility
INFO: Seed set to 0
[3]:
0
We define a temporary directory to look at some training metrics and results.
[4]:
my_temp_dir = tempfile.mkdtemp()
Datamodule#
To demonstrate the method, we will make use of a Toy Regression Example that is defined as a Lightning Datamodule. While this might seem like overkill for a small toy problem, we think it is helpful to see how the individual pieces of the library fit together so you can train models on more complex tasks.
[5]:
dm = ToyHeteroscedasticDatamodule()
X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
dm.X_train,
dm.Y_train,
dm.train_dataloader(),
dm.X_test,
dm.Y_test,
dm.test_dataloader(),
dm.X_gtext,
dm.Y_gtext,
)
[6]:
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)
Model#
For our Toy Regression problem, we will use a simple Multi-layer Perceptron (MLP) that you can configure to your needs. For the documentation of the MLP see here.
[7]:
network = MLP(n_inputs=1, n_hidden=[50, 50], n_outputs=1, activation_fn=nn.Tanh())
network
[7]:
MLP(
(model): Sequential(
(0): Linear(in_features=1, out_features=50, bias=True)
(1): Tanh()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=50, out_features=50, bias=True)
(4): Tanh()
(5): Dropout(p=0.0, inplace=False)
(6): Linear(in_features=50, out_features=1, bias=True)
)
)
For the Laplace model, we first train a plain deterministic model to obtain a MAP estimate of the weights via the standard MSE loss. Subsequently, we fit the Laplace Approximation to obtain an estimate of the epistemic uncertainty for predictions.
[8]:
deterministic_model = DeterministicRegression(
model=network,
optimizer=partial(torch.optim.Adam, lr=1e-2),
loss_fn=torch.nn.MSELoss(),
)
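Training with the MSE loss yields a MAP-style point estimate because, for a Gaussian likelihood with constant variance, the negative log-likelihood is an affine function of the MSE. The following sketch checks this identity numerically. The targets, predictions, and noise level are hypothetical values chosen only for the check.

```python
import math


def mse(preds, targets):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def gaussian_nll(preds, targets, sigma):
    """Summed negative log-likelihood under N(target; pred, sigma^2)."""
    return sum(
        0.5 * math.log(2 * math.pi * sigma**2) + (t - p) ** 2 / (2 * sigma**2)
        for p, t in zip(preds, targets)
    )


targets = [0.2, -0.1, 0.7, 1.3]
preds_a = [0.0, 0.0, 0.5, 1.0]
preds_b = [0.3, -0.2, 0.6, 1.2]
sigma = 0.4
n = len(targets)

# NLL = n/2 * log(2 * pi * sigma^2) + n * MSE / (2 * sigma^2)
const = 0.5 * n * math.log(2 * math.pi * sigma**2)
for name, preds in (("a", preds_a), ("b", preds_b)):
    nll = gaussian_nll(preds, targets, sigma)
    affine = const + n * mse(preds, targets) / (2 * sigma**2)
    print(f"preds_{name}: nll={nll:.4f}, affine-in-MSE={affine:.4f}")
```

Since the offset and the scale do not depend on the predictions, minimizing either quantity gives the same network weights.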
Trainer#
Now that we have a LightningDataModule and base model, we can conduct training with a Lightning Trainer. It has tons of options to make your life easier, so we encourage you to check the documentation.
[ ]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
accelerator="cpu",
devices=1,  # the CPU accelerator expects a positive integer, not a list of device indices
max_epochs=100, # number of epochs we want to train
logger=logger, # log training metrics for later evaluation
log_every_n_steps=1,
enable_checkpointing=False,
enable_progress_bar=False,
)
/opt/anaconda3/envs/uqEnv/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /opt/anaconda3/envs/uqEnv/lib/python3.12/site-packag ...
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
Training our model is now easy:
[10]:
trainer.fit(deterministic_model, dm)
/opt/anaconda3/envs/uqEnv/lib/python3.12/site-packages/torch/__init__.py:1551: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
return _C._get_float32_matmul_precision()
INFO: You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
INFO:
| Name | Type | Params | Mode
-----------------------------------------------------------
0 | model | MLP | 2.7 K | train
1 | loss_fn | MSELoss | 0 | train
2 | train_metrics | MetricCollection | 0 | train
3 | val_metrics | MetricCollection | 0 | train
4 | test_metrics | MetricCollection | 0 | train
-----------------------------------------------------------
2.7 K Trainable params
0 Non-trainable params
2.7 K Total params
0.011 Total estimated model params size (MB)
21 Modules in train mode
0 Modules in eval mode
/opt/anaconda3/envs/uqEnv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/opt/anaconda3/envs/uqEnv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
INFO: `Trainer.fit` stopped: `max_epochs=100` reached.
Training Metrics#
To get some insights into how the training went, we can use the utility function to plot the training loss and RMSE metric.
[11]:
fig = plot_training_metrics(
os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)
Fit Laplace#
We will utilize the great Laplace Library, which allows you to define different flavors of Laplace approximations. For small networks like in this example, one can fit the Laplace approximation over all weights, but this is not feasible for large networks with millions of parameters. In those cases, one can resort to a “last-layer” approximation, where only the last-layer weights are stochastic while all other weights are deterministic. This behavior is controlled with the subset_of_weights parameter. It is chosen in combination with the structure of the Hessian that is fitted, see the hessian_structure parameter. Check their documentation for details. The Lightning-UQ-Box provides a wrapper so that the workflow is the same as with any other implemented UQ-Method.
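To get a feeling for why these two choices matter, the sketch below counts how many Hessian entries have to be stored for the MLP used in this tutorial (layer sizes 1, 50, 50, 1, all with biases). The helper function and the dictionary layout are illustrative assumptions. Only the option names "all", "last_layer", "full" and "diag" correspond to values accepted by the Laplace library.

```python
def linear_params(n_in, n_out):
    """Number of weights plus biases in a fully connected layer."""
    return n_in * n_out + n_out


layer_sizes = [1, 50, 50, 1]
per_layer = [linear_params(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
n_all = sum(per_layer)   # all weights of the network
n_last = per_layer[-1]   # weights of the last layer only

# Number of stored Hessian entries per (subset_of_weights, hessian_structure) choice
hessian_entries = {
    ("all", "full"): n_all**2,
    ("all", "diag"): n_all,
    ("last_layer", "full"): n_last**2,
    ("last_layer", "diag"): n_last,
}
for (subset, structure), count in hessian_entries.items():
    print(f"subset_of_weights={subset!r}, hessian_structure={structure!r}: {count:,} entries")
```

The total of 2,701 parameters matches the "2.7 K" reported in the model summary. A full Hessian over all weights grows quadratically with the parameter count, which is what makes the last-layer or diagonal variants attractive for larger networks.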
One can also tune the prior precision and/or the sigma noise value after the Laplace fitting procedure with the arguments tune_prior_precision=True and tune_sigma_noise=True. Otherwise, the prediction will rely on the default sigma_noise value passed to the Laplace class for an estimate of aleatoric uncertainty under a homoscedastic noise assumption. For a related discussion, see this GitHub issue.
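The idea behind tuning the prior precision can be illustrated on a model where the marginal likelihood is available in closed form. The sketch below uses one-dimensional Bayesian linear regression \(y = w x + \epsilon\) with prior \(w \sim \mathcal{N}(0, 1/\lambda)\) and picks \(\lambda\) from a grid by maximizing the log marginal likelihood. This is an analogue chosen for illustration, not the optimization routine of the Laplace library. The data, the grid, and all function names are assumptions.

```python
import math


def log_marginal_likelihood(xs, ys, prior_precision, sigma_noise):
    """log N(y; 0, sigma^2 I + x x^T / prior_precision), evaluated in closed form."""
    n = len(xs)
    s_xx = sum(x * x for x in xs)
    s_xy = sum(x * y for x, y in zip(xs, ys))
    s_yy = sum(y * y for y in ys)
    noise_var = sigma_noise**2
    # Matrix determinant lemma for the rank-one update of sigma^2 I
    log_det = n * math.log(noise_var) + math.log(1.0 + s_xx / (prior_precision * noise_var))
    # Sherman-Morrison formula for the quadratic form y^T K^{-1} y
    quad = (s_yy - s_xy**2 / (prior_precision * noise_var + s_xx)) / noise_var
    return -0.5 * (n * math.log(2 * math.pi) + log_det + quad)


xs = [-2.0, -1.0, 0.0, 1.0, 2.0]   # hypothetical inputs
ys = [-1.1, -0.4, 0.1, 0.6, 0.9]   # hypothetical targets
grid = [10.0**k for k in range(-3, 4)]
scores = {p: log_marginal_likelihood(xs, ys, p, sigma_noise=0.4) for p in grid}
best_prior_precision = max(scores, key=scores.get)
print(f"best prior precision on the grid: {best_prior_precision}")
```

A very small prior precision is penalized through the determinant term (the model is too flexible), while a very large one is penalized through the data-fit term (the weight is forced towards zero), so the maximum sits in between.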
[ ]:
la = Laplace(
deterministic_model.model,
"regression",
subset_of_weights="last_layer",
hessian_structure="full",
sigma_noise=0.4,
)
laplace_model = LaplaceRegression(laplace_model=la, tune_prior_precision=True)
trainer = Trainer(default_root_dir=my_temp_dir, enable_progress_bar=False)
INFO: Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/opt/anaconda3/envs/uqEnv/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /opt/anaconda3/envs/uqEnv/lib/python3.12/site-packag ...
INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
/opt/anaconda3/envs/uqEnv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Prediction#
For prediction we can either rely on the trainer.test() method or manually call predict_step(). Using the trainer will save the predictions and some metrics to a CSV file, while the manual predict_step() with a single input tensor will return a dictionary that holds the mean prediction as well as some other quantities of interest, for example the predicted standard deviation or quantiles. The Laplace wrapper module will conduct the Laplace fitting procedure automatically before making the first prediction and will reuse the fit for any subsequent call. Originally, prediction was done through sampling and multiple forward passes. However, Immer et al. 2021 showed that a linearization of the model achieves better performance in practice: \(f_{\theta}(x)=f_{\theta_{MAP}}(x)+ J_{\theta_{MAP}}(x)(\theta-\theta_{MAP})\), and that is the current default implementation. Arguments can be passed to the class or the individual predict step to choose between the two procedures: pred_type="glm" (linearization) or pred_type="nn" (sampling).
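The difference between the two prediction types can be seen on a hypothetical one-parameter model \(f(x; b) = \tanh(b x)\) with a Gaussian posterior over \(b\). The sketch below is a toy illustration under these assumptions and not the library's implementation. It only covers the epistemic part of the uncertainty, without observation noise.

```python
import math
import random


def f(x, b):
    """Toy model with a single parameter b."""
    return math.tanh(b * x)


def predict_glm(x, b_map, b_var):
    """Linearization: mean from the MAP model, variance pushed through the Jacobian."""
    jac = x * (1.0 - math.tanh(b_map * x) ** 2)
    return f(x, b_map), math.sqrt(jac * b_var * jac)


def predict_nn(x, b_map, b_var, n_samples=10_000, seed=0):
    """Sampling: draw parameters from the posterior and push each through the model."""
    rng = random.Random(seed)
    outs = [f(x, rng.gauss(b_map, math.sqrt(b_var))) for _ in range(n_samples)]
    mean = sum(outs) / n_samples
    var = sum((o - mean) ** 2 for o in outs) / (n_samples - 1)
    return mean, math.sqrt(var)


glm_mean, glm_std = predict_glm(1.0, b_map=0.8, b_var=0.01)
nn_mean, nn_std = predict_nn(1.0, b_map=0.8, b_var=0.01)
print(f"glm: mean={glm_mean:.4f}, std={glm_std:.4f}")
print(f"nn:  mean={nn_mean:.4f}, std={nn_std:.4f}")
```

For a narrow posterior the two estimates nearly coincide. They drift apart as the posterior gets wider and the model becomes more nonlinear over the region covered by the samples.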
[13]:
trainer.test(laplace_model, dm)
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
/opt/anaconda3/envs/uqEnv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
Testing: | | 0/? [00:00<?, ?it/s]
100%|██████████| 100/100 [00:01<00:00, 83.09it/s, neg_marglik=75.84266662597656]
Testing: | | 0/? [00:00<?, ?it/s]
Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 18.35it/s]
────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────
         testMAE           0.3344504237174988
          testR2           0.7423638701438904
        testRMSE           0.5086640119552612
        test_loss          0.25873905420303345
────────────────────────────────────────────────────────────
[13]:
[{'test_loss': 0.25873905420303345,
'testMAE': 0.3344504237174988,
'testR2': 0.7423638701438904,
'testRMSE': 0.5086640119552612}]
Evaluate Predictions#
The constructed Data Module contains two possible test variables. X_test are IID samples from the same noise distribution as the training data, while X_gtext (“X ground truth extended”) are dense inputs from the underlying “ground truth” function without any noise that also extend the input range to either side, so we can visualize the method’s UQ tendencies when extrapolating beyond the training data range. Thus, we will use X_gtext for visualization purposes, but use X_test to compute uncertainty and calibration metrics because we want to analyse how well the method has learned the noisy data distribution.
[ ]:
preds = laplace_model.predict_step(X_gtext.to(laplace_model.device))
fig = plot_predictions_regression(
X_train,
Y_train,
X_gtext,
Y_gtext,
preds["pred"],
preds["pred_uct"],
epistemic=preds["epistemic_uct"],
aleatoric=preds["aleatoric_uct"],
title="Laplace Approximation",
)
For some additional metrics relevant to UQ, we can use the great uncertainty-toolbox that gives us some insight into the calibration of our prediction. For a discussion of why this is important, see …
[ ]:
preds = laplace_model.predict_step(X_test.to(laplace_model.device))
fig = plot_calibration_uq_toolbox(
preds["pred"].cpu().numpy(),
preds["pred_uct"].cpu().numpy(),
Y_test.cpu().numpy(),
X_test.cpu().numpy(),
)
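A simple calibration check that complements such plots is the empirical coverage of a central prediction interval: for a well-calibrated Gaussian predictive, roughly 95% of the test targets should fall inside the 95% interval. The sketch below computes this with the standard library only. The function name and the four example values are hypothetical and are not outputs of the model above.

```python
from statistics import NormalDist


def empirical_coverage(means, stds, targets, level=0.95):
    """Fraction of targets inside the central `level` interval of N(mean, std^2)."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    inside = sum(abs(t - m) <= z * s for m, s, t in zip(means, stds, targets))
    return inside / len(targets)


means = [0.0, 0.5, 1.0, 1.5]    # hypothetical predictive means
stds = [0.4, 0.4, 0.5, 0.5]     # hypothetical predictive standard deviations
targets = [0.3, 0.4, 2.5, 1.2]  # hypothetical test targets
coverage = empirical_coverage(means, stds, targets)
print(f"empirical coverage at the 95% level: {coverage:.2f}")
```

With real predictions one would pass the "pred" and "pred_uct" entries of the prediction dictionary (converted to lists or arrays) together with Y_test. A coverage well below the nominal level points to overconfident uncertainties, a coverage well above it to underconfident ones.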
Additional Resources#
Daxberger et al. 2020 introduced a subnetwork selection strategy that turns selected weights “Bayesian” while keeping the rest of the network deterministic, and showed performance on par with deep ensembles.