Laplace Approximation#

[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

Theoretic Foundation#

The Laplace Approximation was originally introduced by MacKay, 1992 and has since been adapted to modern neural networks by Ritter, 2018 and Daxberger, 2021. It is an approximate Bayesian method. The goal of the Laplace Approximation is to use a second-order Taylor expansion around the fitted MAP estimate to obtain a posterior approximation over the model parameters via a full-rank, diagonal or Kronecker-factorized approach. To keep the Laplace Approximation computationally feasible for larger network architectures, we use the Laplace library, which includes approaches such as the subnetwork selection proposed by Daxberger, 2021.

The general idea of the Laplace Approximation is to obtain a distribution over the network parameters by approximating the posterior with a Gaussian distribution centered at the MAP estimate of the parameters Daxberger, 2021. In this setting, we define a prior distribution \(p(\theta)\) over our network parameters. Because modern neural networks consist of millions of parameters, computing the exact posterior distribution over the weights \(\theta\) is intractable. The LA takes the MAP estimate of the parameters \(\theta_{MAP}\) from a trained network \(f_{\theta_{MAP}}(x) = \mu_{\theta_{MAP}}(x)\) and constructs a Gaussian distribution around it. The parameters \(\theta_{MAP}\) are obtained by

\[\theta_{MAP} = \underset{\theta}{\text{argmin}}\; \mathcal{L}(\theta; D),\]

where \(\mathcal{L}\) is the regularized negative log-likelihood, \(\mathcal{L}(\theta; D) := -\sum_{i=1}^n \log p(y_i|f_{\theta}(x_i)) - \log p(\theta)\). We choose the likelihood \(p(y_i|f_{\theta}(x_i))\) to be a Gaussian with constant variance \(\sigma^2\), so that the data term is, up to scaling and an additive constant, the mean squared error (also referred to as the \(\ell^2\) loss) and a homoscedastic noise model is assumed. Then with Bayes' Theorem, as in Daxberger, 2021, one can relate the posterior to the loss,

\[p(\theta|D) = p(D\vert\theta)p(\theta)/p(D)= \frac{1}{Z} \exp(- \mathcal{L}(\theta; D)),\]

with \(Z = \int p(D\vert\theta)p(\theta) d\theta\). Now a second-order expansion of \(\mathcal{L}\) around \(\theta_{MAP}\) is used to construct a Gaussian approximation to the posterior \(p(\theta|D)\):

\[-\mathcal{L}(\theta; D) \approx -\mathcal{L}(\theta_{MAP}; D)- \frac{1}{2}(\theta-\theta_{MAP})^T (\nabla_{\theta}^2 \mathcal{L}(\theta; D)\vert \theta_{MAP}) (\theta-\theta_{MAP}).\]

The term with the first-order derivative is zero because the loss is evaluated at a minimum \(\theta_{MAP}\) Murphy, 2022. Further, the first term \(-\mathcal{L}(\theta_{MAP}; D)\) does not depend on \(\theta\), so it only contributes a constant factor that is absorbed by the normalization. Taking the exponential of both sides then allows one to identify, after normalization, the Laplace approximation,

\begin{align*} p(\theta|D) \approx \mathcal{N}(\theta_{MAP}, \Sigma) && \text{with} \qquad \Sigma = (\nabla_{\theta}^2 \mathcal{L}(\theta; D)\vert \theta_{MAP})^{-1}. \end{align*}
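
The normalization step can be written out explicitly. The following short derivation is a sketch in the notation above, where \(H\) abbreviates the Hessian \(\nabla_{\theta}^2 \mathcal{L}(\theta; D)\vert \theta_{MAP}\) and \(W\) is the number of parameters. Exponentiating the expansion gives an unnormalized Gaussian, and the normalizer follows from the standard Gaussian integral:

\begin{align*} \exp(-\mathcal{L}(\theta; D)) &\approx \exp(-\mathcal{L}(\theta_{MAP}; D)) \exp\left(-\tfrac{1}{2}(\theta-\theta_{MAP})^T H (\theta-\theta_{MAP})\right), \\ Z &\approx \exp(-\mathcal{L}(\theta_{MAP}; D)) \, (2\pi)^{W/2} \det(H)^{-1/2}. \end{align*}

Dividing the first line by the second cancels the constant \(\exp(-\mathcal{L}(\theta_{MAP}; D))\) and leaves exactly the density of \(\mathcal{N}(\theta_{MAP}, H^{-1})\).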

As the covariance is just the inverse Hessian of the loss, \(H := \nabla_{\theta}^2 \mathcal{L}(\theta; D)\vert \theta_{MAP}\), with \(\theta_{MAP}\in \mathbb{R}^W\) and \(H^{-1}\in \mathbb{R}^{W\times W}\), where \(W\) is the number of weights, we get the posterior distribution

\[p(\theta|D)\approx \mathcal{N}(\theta_{MAP}, H^{-1}).\]
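
As a minimal sketch of this construction, the snippet below fits a Laplace approximation to a two-parameter linear regression with NumPy. The toy data, the prior precision `prior_prec`, and the finite-difference helper `numerical_hessian` are assumptions made for illustration and are not part of the Lightning-UQ-Box or the Laplace library. Because the loss is quadratic here, the Gaussian approximation happens to be exact, which makes the result easy to check.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: y = 2x - 1 + noise, with a constant feature for the bias.
n, sigma, prior_prec = 50, 0.5, 1.0
x = rng.uniform(-1.0, 1.0, size=n)
X = np.stack([x, np.ones(n)], axis=1)  # design matrix, shape (n, 2)
y = 2.0 * x - 1.0 + sigma * rng.normal(size=n)


def loss(theta):
    """Negative log joint up to a constant: Gaussian likelihood plus Gaussian prior."""
    resid = y - X @ theta
    return 0.5 * resid @ resid / sigma**2 + 0.5 * prior_prec * theta @ theta


def numerical_hessian(f, theta, eps=1e-3):
    """Central finite-difference Hessian of a scalar function f at theta."""
    d = theta.size
    basis = np.eye(d) * eps
    H = np.zeros((d, d))
    for i in range(d):
        for j in range(d):
            H[i, j] = (
                f(theta + basis[i] + basis[j])
                - f(theta + basis[i] - basis[j])
                - f(theta - basis[i] + basis[j])
                + f(theta - basis[i] - basis[j])
            ) / (4.0 * eps**2)
    return H


# MAP estimate: this quadratic loss has a closed-form minimizer.
H_exact = X.T @ X / sigma**2 + prior_prec * np.eye(2)
theta_map = np.linalg.solve(H_exact, X.T @ y / sigma**2)

# Laplace approximation: N(theta_map, H^{-1}) with H the Hessian at the MAP.
H = numerical_hessian(loss, theta_map)
Sigma = np.linalg.inv(H)

print("theta_MAP:", theta_map)
print("posterior std per parameter:", np.sqrt(np.diag(Sigma)))
```

For a neural network the MAP estimate comes from training and the Hessian is far too large for finite differences, which is why structured approximations are used instead.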

The computation of the Hessian term is still expensive. Therefore, further approximations are introduced in practice, most commonly the Generalized Gauss-Newton matrix Martens, 2020. This takes the following form:

\[H \approx \widetilde{H}=\sum_{n=1}^NJ_n^TH_nJ_n,\]

where \(J_n\in \mathbb{R}^{O\times W}\) is the Jacobian of the model outputs with respect to the parameters \(\theta\) for the \(n\)-th data point and \(H_n\in\mathbb{R}^{O\times O}\) is the Hessian of the negative log-likelihood with respect to the model outputs. Here \(O\) denotes the model output size and \(W\) the number of parameters.
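
To illustrate the Generalized Gauss-Newton sum, here is a small NumPy sketch for a one-hidden-layer tanh network with scalar output (\(O = 1\)). The network, its hand-written Jacobian, and the names `unpack`, `forward` and `jacobian` are hypothetical and only serve the example. For a Gaussian likelihood with variance \(\sigma^2\), the output Hessian \(H_n\) is simply \(1/\sigma^2\).

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, sigma = 3, 0.5

# Flat parameter vector of f(x) = w2 . tanh(w1 * x + b1) + b2, so W = 3 * hidden + 1.
theta = rng.normal(size=3 * hidden + 1)


def unpack(theta):
    """Split the flat parameter vector into (w1, b1, w2, b2)."""
    return (
        theta[:hidden],
        theta[hidden : 2 * hidden],
        theta[2 * hidden : 3 * hidden],
        theta[-1],
    )


def forward(theta, x):
    """Scalar network output for a scalar input x."""
    w1, b1, w2, b2 = unpack(theta)
    return w2 @ np.tanh(w1 * x + b1) + b2


def jacobian(theta, x):
    """Analytic Jacobian of the output w.r.t. theta, shape (O, W) with O = 1."""
    w1, b1, w2, _ = unpack(theta)
    a = np.tanh(w1 * x + b1)
    da = 1.0 - a**2  # derivative of tanh
    return np.concatenate([w2 * da * x, w2 * da, a, [1.0]])[None, :]


xs = rng.uniform(-1.0, 1.0, size=20)
H_out = np.array([[1.0 / sigma**2]])  # Hessian of the Gaussian NLL w.r.t. the output

# GGN: sum over data points of J_n^T H_n J_n.
ggn = np.zeros((theta.size, theta.size))
for x in xs:
    J_n = jacobian(theta, x)
    ggn += J_n.T @ H_out @ J_n

print("GGN shape:", ggn.shape)
print("smallest eigenvalue:", np.linalg.eigvalsh(ggn).min())
```

By construction the GGN is symmetric and positive semi-definite, which the exact Hessian of a neural network loss does not guarantee.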

During prediction we cannot compute the full posterior predictive distribution but instead resort to approximations. One strategy is to draw samples \(\theta_s \sim p(\theta|D)\) for \(s \in \{1, ...,S\}\) and approximate the predictive mean by

\[\hat{y}(x^{\star}) = \frac{1}{S} \sum_{s=1}^S f_{\theta_s}(x^{\star}),\]

and obtain the predictive uncertainty by

\[\sigma_{pred}(x^{\star}) = \sqrt{\frac{1}{S} \sum_{s=1}^S f_{\theta_s}(x^{\star})^2 - \hat{y}(x^{\star})^2+\sigma^2}.\]

However, Immer et al. 2021 suggested that a linearization of the form \(f_{\theta}(x)=f_{\theta_{MAP}}(x)+ J_{\theta_{MAP}}(\theta-\theta_{MAP})\) works better in practice, and this is also the default in the Laplace library.
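
The sampling-based and the linearized prediction can be compared on a model that is linear in its parameters, where the linearization is exact and both should agree up to Monte Carlo error. The sketch below uses NumPy only. The posterior mean `theta_map`, the covariance `Sigma` and the noise level `sigma_noise` are made-up values, not outputs of the Laplace library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical Gaussian posterior over the parameters of f(x) = theta_0 * x + theta_1.
theta_map = np.array([2.0, -1.0])
Sigma = np.array([[0.04, 0.01], [0.01, 0.02]])
sigma_noise = 0.3

x_star = np.linspace(-2.0, 2.0, 5)
J = np.stack([x_star, np.ones_like(x_star)], axis=1)  # Jacobian df/dtheta, shape (5, 2)

# Sampling-based predictive: average over S parameter samples.
S = 20_000
thetas = rng.multivariate_normal(theta_map, Sigma, size=S)  # shape (S, 2)
f_samples = thetas @ J.T  # shape (S, 5)
y_hat = f_samples.mean(axis=0)
sigma_pred = np.sqrt((f_samples**2).mean(axis=0) - y_hat**2 + sigma_noise**2)

# Linearized predictive: mean at the MAP, variance J Sigma J^T plus noise.
y_lin = J @ theta_map
sigma_lin = np.sqrt(np.einsum("nd,de,ne->n", J, Sigma, J) + sigma_noise**2)

print("sampled std:   ", np.round(sigma_pred, 3))
print("linearized std:", np.round(sigma_lin, 3))
```

For a nonlinear network the two estimates generally differ, since the sampled networks are evaluated exactly while the linearized model only uses first-order information at \(\theta_{MAP}\).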

The implementation is a wrapper around a model from the fantastic Laplace library so all the available options for subnet strategies can be found in their docs.

Imports#

[2]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from laplace import Laplace
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DeterministicRegression, LaplaceRegression
from lightning_uq_box.viz_utils import (
    plot_calibration_uq_toolbox,
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

plt.rcParams["figure.figsize"] = [14, 5]
[3]:
seed_everything(0)  # seed everything for reproducibility
INFO: Seed set to 0
[3]:
0

We define a temporary directory to look at some training metrics and results.

[4]:
my_temp_dir = tempfile.mkdtemp()

Datamodule#

To demonstrate the method, we will make use of a Toy Regression Example that is defined as a Lightning Datamodule. While this might seem like overkill for a small toy problem, we think it is more helpful to show how the individual pieces of the library fit together so you can train models on more complex tasks.

[5]:
dm = ToyHeteroscedasticDatamodule()

X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
    dm.X_train,
    dm.Y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.Y_test,
    dm.test_dataloader(),
    dm.X_gtext,
    dm.Y_gtext,
)
[6]:
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)
../../_images/tutorials_regression_laplace_10_0.png

Model#

For our Toy Regression problem, we will use a simple Multi-layer Perceptron (MLP) that you can configure to your needs. For the documentation of the MLP see here.

[7]:
network = MLP(n_inputs=1, n_hidden=[50, 50], n_outputs=1, activation_fn=nn.Tanh())
network
[7]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=1, out_features=50, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): Tanh()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=50, out_features=1, bias=True)
  )
)

For the Laplace model, we first train a plain deterministic model to obtain a MAP estimate of the weights via the standard MSE loss. Subsequently, we fit the Laplace Approximation to obtain an estimate of the epistemic uncertainty for predictions.

[8]:
deterministic_model = DeterministicRegression(
    model=network,
    optimizer=partial(torch.optim.Adam, lr=1e-2),
    loss_fn=torch.nn.MSELoss(),
)

Trainer#

Now that we have a LightningDataModule and base model, we can conduct training with a Lightning Trainer. It has tons of options to make your life easier, so we encourage you to check the documentation.

[ ]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    accelerator="auto",  # use a GPU if one is available, otherwise the CPU
    devices=1,
    max_epochs=100,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores

Training our model is now easy:

[10]:
trainer.fit(deterministic_model, dm)
INFO: You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
INFO:
  | Name          | Type             | Params | Mode
-----------------------------------------------------------
0 | model         | MLP              | 2.7 K  | train
1 | loss_fn       | MSELoss          | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
2.7 K     Trainable params
0         Non-trainable params
2.7 K     Total params
0.011     Total estimated model params size (MB)
21        Modules in train mode
0         Modules in eval mode
INFO: `Trainer.fit` stopped: `max_epochs=100` reached.

Training Metrics#

To get some insights into how the training went, we can use the utility function to plot the training loss and RMSE metric.

[11]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)
../../_images/tutorials_regression_laplace_20_0.png

Fit Laplace#

We will utilize the great Laplace Library that allows you to define different flavors of Laplace approximations. For small networks like in this example, one can fit the Laplace approximation over all weights, but this is not feasible for large networks with millions of parameters. In those cases, one can resort to a “last-layer” approximation, where only the last-layer weights are stochastic, while all other weights are deterministic. This behavior is controlled with the subset_of_weights parameter. It is chosen in combination with the structure of the Hessian that is fitted, see the hessian_structure parameter. Check their documentation for details. The Lightning-UQ-Box provides a wrapper so that the workflow is the same as with any other implemented UQ-Method.
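
To get a feeling for why these choices matter, the following back-of-the-envelope sketch counts how many covariance entries would have to be stored for the MLP used in this tutorial. It is plain Python arithmetic, not a call into the Laplace library. The option strings "all", "last_layer", "full" and "diag" are used only as labels, on the assumption that they match the library's option names.

```python
# Layer widths of the tutorial's MLP: 1 input, two hidden layers of 50 units, 1 output.
layer_sizes = [1, 50, 50, 1]

# Each linear layer has n_in * n_out weights plus n_out biases.
params_per_layer = [
    n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
]
n_all = sum(params_per_layer)  # every parameter of the network
n_last = params_per_layer[-1]  # parameters of the last layer only

# Number of stored covariance entries for each combination of options.
covariance_entries = {
    ("all", "full"): n_all**2,
    ("all", "diag"): n_all,
    ("last_layer", "full"): n_last**2,
    ("last_layer", "diag"): n_last,
}

for (subset, structure), n_entries in covariance_entries.items():
    print(f"subset_of_weights={subset!r}, hessian_structure={structure!r}: {n_entries:,}")
```

The full covariance over all 2,701 parameters already needs about 7.3 million entries, whereas a full covariance over the 51 last-layer parameters needs 2,601. The quadratic growth is what rules out the full structure for networks with millions of weights.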

One can also tune the prior precision and/or the sigma noise value after the Laplace fitting procedure with the arguments tune_prior_precision=True and tune_sigma_noise=True. Otherwise, the prediction will rely on the default sigma_noise value passed to the Laplace class for an estimate of aleatoric uncertainty under a homoscedastic noise assumption. For a related discussion one can look at this GitHub issue.
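
A common criterion for tuning the prior precision is the marginal likelihood under the Laplace approximation. Whether the wrapper uses exactly this criterion is an assumption here. The NumPy sketch below evaluates the Laplace estimate of \(\log p(D)\) on a grid of prior precisions for a toy linear regression, where the estimate is exact. The data, the grid and the function name `laplace_log_marglik` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma = 40, 0.4
x = rng.uniform(-1.0, 1.0, size=n)
X = np.stack([x, np.ones(n)], axis=1)  # design matrix with a bias column
y = 1.5 * x + 0.5 + sigma * rng.normal(size=n)
W = X.shape[1]  # number of parameters


def laplace_log_marglik(prior_prec):
    """Laplace estimate of log p(D) for the prior N(0, prior_prec^{-1} I)."""
    H = X.T @ X / sigma**2 + prior_prec * np.eye(W)  # Hessian of the loss
    theta_map = np.linalg.solve(H, X.T @ y / sigma**2)
    resid = y - X @ theta_map
    log_lik = -0.5 * n * np.log(2 * np.pi * sigma**2) - 0.5 * resid @ resid / sigma**2
    log_prior = (
        0.5 * W * np.log(prior_prec / (2 * np.pi))
        - 0.5 * prior_prec * theta_map @ theta_map
    )
    # log Z ~ -L(theta_MAP) + (W / 2) log(2 pi) - (1 / 2) log det H
    return log_lik + log_prior + 0.5 * W * np.log(2 * np.pi) - 0.5 * np.linalg.slogdet(H)[1]


grid = np.logspace(-3, 3, 61)
scores = np.array([laplace_log_marglik(p) for p in grid])
best_prior_prec = grid[np.argmax(scores)]

print("best prior precision on the grid:", best_prior_prec)
```

In practice the library performs this tuning for you, typically with gradient-based optimization rather than a grid.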

[ ]:
la = Laplace(
    deterministic_model.model,
    "regression",
    subset_of_weights="last_layer",
    hessian_structure="full",
    sigma_noise=0.4,
)


laplace_model = LaplaceRegression(laplace_model=la, tune_prior_precision=True)

trainer = Trainer(default_root_dir=my_temp_dir, enable_progress_bar=False)
INFO: Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores

Prediction#

For prediction we can either rely on the trainer.test() method or manually conduct a predict_step(). Using the trainer will save the predictions and some metrics to a CSV file, while the manual predict_step() with a single input tensor will generate a dictionary that holds the mean prediction as well as some other quantities of interest, for example the predicted standard deviation or quantile. The Laplace wrapper module will conduct the Laplace fitting procedure automatically before making the first prediction and will reuse the fitted approximation for any subsequent call. Originally, predictions were obtained through sampling and multiple forward passes. However, Immer et al. 2021 showed that a linearization of the model, \(f_{\theta}(x)=f_{\theta_{MAP}}(x)+ J_{\theta_{MAP}}(\theta-\theta_{MAP})\), achieves better performance in practice, and that is the current default implementation. Arguments can be passed to the class or the individual predict step to choose between the two procedures: pred_type="glm" (linearization) or pred_type="nn" (sampling).

[13]:
trainer.test(laplace_model, dm)
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────
         testMAE            0.3344504237174988
         testR2             0.7423638701438904
        testRMSE            0.5086640119552612
        test_loss           0.25873905420303345
────────────────────────────────────────────────
[13]:
[{'test_loss': 0.25873905420303345,
  'testMAE': 0.3344504237174988,
  'testR2': 0.7423638701438904,
  'testRMSE': 0.5086640119552612}]

Evaluate Predictions#

The constructed Data Module contains two possible test variables. X_test are IID samples from the same noise distribution as the training data, while X_gtext (“X ground truth extended”) are dense inputs from the underlying “ground truth” function without any noise that also extend the input range to either side, so we can visualize the method’s UQ tendencies when extrapolating beyond the training data range. Thus, we will use X_gtext for visualization purposes, but use X_test to compute uncertainty and calibration metrics because we want to analyse how well the method has learned the noisy data distribution.

[ ]:
preds = laplace_model.predict_step(X_gtext.to(laplace_model.device))

fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    preds["pred"],
    preds["pred_uct"],
    epistemic=preds["epistemic_uct"],
    aleatoric=preds["aleatoric_uct"],
    title="Laplace Approximation",
)

For some additional metrics relevant to UQ, we can use the great uncertainty-toolbox that gives us some insight into the calibration of our prediction. For a discussion of why this is important, see …

[ ]:
preds = laplace_model.predict_step(X_test.to(laplace_model.device))
fig = plot_calibration_uq_toolbox(
    preds["pred"].cpu().numpy(),
    preds["pred_uct"].cpu().numpy(),
    Y_test.cpu().numpy(),
    X_test.cpu().numpy(),
)
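
To make the notion of calibration concrete, the sketch below computes the observed coverage of central Gaussian prediction intervals by hand. It is an illustration with NumPy and the standard library, not the uncertainty-toolbox implementation, and the synthetic predictions as well as the helper `observed_coverage` are assumptions. For a well-calibrated model, the observed coverage should be close to the expected one.

```python
from statistics import NormalDist

import numpy as np


def observed_coverage(mean, std, target, expected):
    """Fraction of targets inside the central Gaussian interval of mass `expected`."""
    z = NormalDist().inv_cdf(0.5 + expected / 2.0)  # interval half-width in std units
    return float(np.mean(np.abs(target - mean) <= z * std))


# Synthetic, well-calibrated predictions: targets really follow N(mean, std^2).
rng = np.random.default_rng(0)
mean = rng.normal(size=5000)
std = rng.uniform(0.5, 1.5, size=5000)
target = mean + std * rng.normal(size=5000)

for expected in (0.5, 0.8, 0.95):
    calibrated = observed_coverage(mean, std, target, expected)
    overconfident = observed_coverage(mean, 0.5 * std, target, expected)
    print(f"expected {expected:.2f}: calibrated {calibrated:.3f}, overconfident {overconfident:.3f}")
```

Halving the predicted standard deviation mimics an overconfident model, whose intervals cover far fewer targets than their nominal level promises.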

Additional Resources#

Daxberger et al. 2020 introduced a subnetwork selection strategy that turns selected weights “Bayesian” while keeping the rest of the network deterministic and showed performance on par with deep ensembles.