# MIT License
# Copyright (c) 2024 CVLAB @ EPFL
# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
# Adapted from https://github.com/cvlab-epfl/zigzag as a LightningModule
"""ZigZag Universal Sampling-free Uncertainty Estimation."""
import os
import torch
import torch.nn as nn
from einops import repeat
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import Tensor
from torch.optim.adam import Adam as Adam
from .base import DeterministicModel
from .utils import (
_get_input_layer_name_and_module,
_get_num_outputs,
default_classification_metrics,
default_regression_metrics,
save_classification_predictions,
save_regression_predictions,
)
class ZigZagBase(DeterministicModel):
    """ZigZag Uncertainty Quantification Base.

    The wrapped model receives, in addition to its regular input, an extra
    "zigzag" input that is either a constant blank value (``blank_const``) or
    the target. During training the loss drives both variants towards the
    target, so the disagreement between a blank-conditioned prediction and a
    prediction re-conditioned on that output can serve as an uncertainty
    signal (computed by the task-specific ``predict_step``).

    If you use this method in your work, please cite:

    * https://openreview.net/forum?id=QSvb6jBXML

    .. versionadded:: 0.2
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module,
        blank_const: int = -100,
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable | None = None,
    ) -> None:
        """Initialize a new instance of ZigZag.

        Args:
            model: PyTorch model. Its input layer receives the extra zigzag
                input (one feature for a linear input layer, extra channel(s)
                for a conv input layer) appended by :meth:`forward`.
            loss_fn: Loss function.
            blank_const: constant for the blank zig zag input, should be a
                value far from training targets
            freeze_backbone: Whether or not to freeze the backbone.
            optimizer: Optimizer.
            lr_scheduler: Learning rate scheduler.
        """
        super().__init__(model, loss_fn, freeze_backbone, optimizer, lr_scheduler)
        self.blank_const = blank_const
        # NOTE(review): check_input_layer is not defined in this view; it is
        # presumably what sets ``self.input_linear`` (read by forward and the
        # step methods) - confirm.
        self.check_input_layer()

    def forward(
        self, x: Tensor, y: Tensor | None = None, training: bool = False
    ) -> Tensor:
        """Forward pass of Zig Zag method.

        Appends the zigzag condition to ``x`` and runs the model. The
        condition is ``blank_const`` when ``y`` is None and the target ``y``
        otherwise.

        Args:
            x: Input tensor, ``[batch, features]`` for a linear input layer or
                ``[batch, channels, height, width]`` for a conv input layer.
            y: Optional target tensor; 1D targets are treated as ``[batch, 1]``.
            training: Only relevant for conv input layers with ``y`` given: if
                True, each sample is independently conditioned on either the
                blank constant or its target (probability 0.5 each).

        Returns:
            Output of model with Zig Zag operation.
        """
        if y is not None and y.dim() == 1:
            y = y.unsqueeze(-1)

        if self.input_linear:
            if y is None:
                condition = self.blank_const * torch.ones(
                    [x.shape[0], 1], device=x.device, dtype=x.dtype
                )
            else:
                # classification labels are just 1D
                condition = torch.atleast_2d(y)
            return self.model(torch.cat([x, condition], dim=1))

        batch_size, _, height, width = x.shape
        if y is None:
            ones_tensor = torch.ones(
                [batch_size, 1, height, width], device=x.device, dtype=x.dtype
            )
            return self.model(torch.cat([x, self.blank_const * ones_tensor], dim=1))

        channel_y = torch.atleast_2d(y).shape[-1]
        ones_tensor = torch.ones(
            [batch_size, channel_y, height, width], device=x.device, dtype=x.dtype
        )
        # Input with the actual targets broadcast over the spatial dims,
        # the second term in Eq. 1.
        target_inputs = torch.cat([x, y.reshape(-1, 1, 1, 1) * ones_tensor], dim=1)
        if not training:
            return self.model(target_inputs)

        blank_inputs = torch.cat([x, self.blank_const * ones_tensor], dim=1)
        p = 0.5
        # Per-sample choice between blank and target conditioning. A boolean
        # mask with torch.where avoids the dtype promotion (and inf * 0 = NaN)
        # that arithmetic blending with a float32 mask would introduce.
        use_blank = (
            torch.empty(batch_size, 1, 1, 1).uniform_(0, 1) > p
        ).to(x.device)
        return self.model(torch.where(use_blank, blank_inputs, target_inputs))

    def _duplicate_with_blank(
        self, X: Tensor, y: Tensor
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Stack a blank-conditioned copy of the batch before a target-conditioned one.

        Args:
            X: input batch
            y: target batch

        Returns:
            ``(x_in, y_in, y_target)``: ``X`` repeated twice; the conditions
            (``blank_const`` for the first ``X.shape[0]`` rows, ``y`` for the
            rest); and ``y`` repeated twice to match.
        """
        x_in = repeat(X, "b ... -> (repeat b) ...", repeat=2)
        y_in = torch.cat([self.blank_const * torch.ones_like(y), y])
        y_target = repeat(y, "b ... -> (repeat b) ...", repeat=2)
        return x_in, y_in, y_target

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        For a linear input layer the batch is duplicated: the first half is
        conditioned on the blank constant and the second half on the targets.
        For a conv input layer :meth:`forward` picks the condition per sample
        at random instead.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        X, y = batch[self.input_key], batch[self.target_key]
        if self.input_linear:
            x_in, y_in, y_target = self._duplicate_with_blank(X, y)
        else:
            x_in, y_in, y_target = X, y, y
        out = self.forward(x_in, y_in, training=True)
        loss = self.loss_fn(out, y_target)

        # Linear case: metrics only on the blank-conditioned half (the real
        # prediction, not the zigzag condition). Conv case: ``out`` has exactly
        # X.shape[0] rows of randomly mixed conditioning, so the slice is a no-op.
        if X.shape[0] > 1:
            self.train_metrics(out[: X.shape[0]], y)
        self.log("train_loss", loss, batch_size=X.shape[0])
        return loss

    def validation_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return validation loss.

        Both the blank-conditioned and the target-conditioned predictions
        enter the loss. Metrics are computed only on the blank-conditioned
        predictions, which do not see the target.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            validation loss
        """
        X, y = batch[self.input_key], batch[self.target_key]
        if self.input_linear:
            x_in, y_in, y_target = self._duplicate_with_blank(X, y)
            out = self.forward(x_in, y_in, training=False)
        else:
            # Conv input: run the blank pass and the target pass separately.
            # Previously only the target-conditioned pass was evaluated, so
            # val loss/metrics were computed on outputs that see the target.
            out = torch.cat(
                [self.forward(X, training=False), self.forward(X, y, training=False)]
            )
            y_target = repeat(y, "b ... -> (repeat b) ...", repeat=2)
        loss = self.loss_fn(out, y_target)

        # compute metrics only on the real input not the zigzag condition
        if X.shape[0] > 1:
            self.val_metrics(out[: X.shape[0]], y)
        self.log("val_loss", loss, batch_size=X.shape[0])
        return loss

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Test step.

        Runs ``predict_step`` on the input and stores the detached,
        last-dim-squeezed targets under ``self.target_key``. Updates the test
        metrics (skipped for single-sample batches), then squeezes the
        prediction and adds auxiliary batch data.

        Returns:
            prediction dictionary
        """
        pred_dict = self.predict_step(batch[self.input_key])
        pred_dict[self.target_key] = batch[self.target_key].detach().squeeze(-1)
        if batch[self.input_key].shape[0] > 1:
            # metrics use the un-squeezed prediction and target
            self.test_metrics(pred_dict["pred"], batch[self.target_key])
        pred_dict["pred"] = pred_dict["pred"].detach().squeeze(-1)
        pred_dict = self.add_aux_data_to_dict(pred_dict, batch)
        return pred_dict
class ZigZagRegression(ZigZagBase):
    """Zig Zag Uncertainty Estimation for Regression.

    If you use this method in your work, please cite:

    * https://openreview.net/forum?id=QSvb6jBXML

    .. versionadded:: 0.2
    """

    pred_file_name = "preds.csv"

    def setup_task(self) -> None:
        """Create the regression metric collections for the train, val and test stages."""
        for stage in ("train", "val", "test"):
            setattr(self, f"{stage}_metrics", default_regression_metrics(stage))

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict Step.

        Runs the model twice without gradients: first with the blank zigzag
        condition, then conditioned on that first prediction. The
        uncertainty is the L2 norm, over dim 1, of the difference between
        the two outputs.

        Args:
            X: prediction input tensor
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            prediction dictionary with keys ``"pred"`` and ``"pred_uct"``
        """
        with torch.no_grad():
            first_pass = self.forward(X, training=False)
            second_pass = self.forward(X, first_pass, training=False)
            disagreement = torch.linalg.norm(first_pass - second_pass, dim=1)
        return {"pred": first_pass, "pred_uct": disagreement}

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Append this batch's predictions to the CSV in the trainer's root dir.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        csv_path = os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        save_regression_predictions(outputs, csv_path)
class ZigZagClassification(ZigZagBase):
    """Zig Zag Uncertainty Estimation for Classification.

    If you use this method in your work, please cite:

    * https://openreview.net/forum?id=QSvb6jBXML

    .. versionadded:: 0.2
    """

    pred_file_name = "preds.csv"
    # Accepted values for the ``task`` argument ("multilabel" was previously
    # misspelled, which made the documented task value fail validation).
    valid_tasks = ["binary", "multiclass", "multilabel"]

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module,
        blank_const: int = -100,
        task: str = "multiclass",
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable | None = None,
    ) -> None:
        """Initialize a new instance of ZigZag for classification.

        Args:
            model: PyTorch model.
            loss_fn: Loss function.
            blank_const: constant for the blank zig zag input, should be a
                value far from training targets
            task: Task type. One of "binary", "multiclass", "multilabel".
            freeze_backbone: Whether or not to freeze the backbone.
            optimizer: Optimizer.
            lr_scheduler: Learning rate scheduler.

        Raises:
            AssertionError: if ``task`` is not one of ``valid_tasks``.
        """
        # num_classes and task are set before super().__init__ because
        # setup_task reads them; presumably the base constructor calls
        # setup_task - confirm.
        self.num_classes = _get_num_outputs(model)
        assert task in self.valid_tasks, f"Task must be one of {self.valid_tasks}"
        self.task = task
        super().__init__(
            model, loss_fn, blank_const, freeze_backbone, optimizer, lr_scheduler
        )

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_classification_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_classification_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_classification_metrics(
            "test", self.task, self.num_classes
        )

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Prediction step.

        Switches the wrapped model to eval mode, then runs two forward
        passes without gradients. The first uses the blank condition. The
        second is conditioned on the argmax class labels of the first. The
        uncertainty is the mean absolute difference between the two softmax
        distributions.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            prediction dictionary with keys ``"pred"`` (softmax
            probabilities), ``"pred_uct"`` and ``"logits"``
        """
        # Side effect: leaves self.model in eval mode after this call.
        self.model.eval()
        with torch.no_grad():
            logits = self.forward(X, training=False)
            probs = torch.softmax(logits, dim=1)
            labels = torch.argmax(probs, dim=1)
            cond_probs = torch.softmax(self.forward(X, labels, training=False), dim=1)
        return {
            "pred": probs,
            "pred_uct": torch.abs(probs - cond_probs).mean(dim=1),
            "logits": logits,
        }

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )