# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
"""Test Time Augmentation (TTA)."""
import os
from collections.abc import Callable
from typing import Any, Literal
import kornia.augmentation as K
import torch
import torch.nn as nn
from lightning import LightningModule
from torch import Tensor
from .base import PosthocBase
from .utils import (
default_classification_metrics,
default_regression_metrics,
process_classification_prediction,
process_regression_prediction,
save_classification_predictions,
save_regression_predictions,
)
def torch_median_val_only(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Return the median of ``tensor`` along ``dim``, discarding the indices.

    ``torch.median`` with a ``dim`` argument returns a ``(values, indices)``
    named tuple; only the ``values`` field is kept so the function can be used
    as a plain reduction ``fn(tensor, dim=...)``.
    """
    return torch.median(tensor, dim=dim).values
def torch_max_val_only(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Return the maximum of ``tensor`` along ``dim``, discarding the indices.

    ``torch.max`` with a ``dim`` argument returns a ``(values, indices)``
    named tuple; only the ``values`` field is kept.
    """
    return torch.max(tensor, dim=dim).values
def torch_min_val_only(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Return the minimum of ``tensor`` along ``dim``, discarding the indices.

    ``torch.min`` with a ``dim`` argument returns a ``(values, indices)``
    named tuple; only the ``values`` field is kept.
    """
    return torch.min(tensor, dim=dim).values
# Lookup table from merge-strategy name to a reduction callable invoked as
# ``fn(tensor, dim=...)``; the median/max/min wrappers return values only.
merge_strategy_dict = dict(
    mean=torch.mean,
    median=torch_median_val_only,
    sum=torch.sum,
    max=torch_max_val_only,
    min=torch_min_val_only,
)
class TTABase(PosthocBase):
    """Test Time Augmentation Module.

    In addition to a prediction with no test time augmentation,
    an additional prediction will be made for each element in `tt_augmentation`.
    The individual predictions are then merged according to `merge_strategy`.
    """

    # Derived from the module-level lookup table so the accepted names can
    # never drift from the strategies that are actually implemented.
    valid_merge_strategies: list[str] = list(merge_strategy_dict)

    def __init__(
        self,
        model: LightningModule | nn.Module,
        tt_augmentation: list[Callable] | None = None,
        merge_strategy: Literal["mean", "median", "sum", "max", "min"] = "mean",
    ) -> None:
        """Initialize a new instance of TTA module.

        Args:
            model: LightningModule or nn.Module used for prediction
            tt_augmentation: list of test time augmentation functions; each one
                is called on the input tensor and its output is fed to the
                model. If None, no augmentations are configured here
                (subclasses may install defaults in `setup_task`)
            merge_strategy: strategy to merge the predictions from the different
                augmentations, one of `valid_merge_strategies`

        Raises:
            AssertionError: if `merge_strategy` is not a valid strategy (only
                when assertions are enabled)
        """
        super().__init__(model)
        self.tt_augmentation = tt_augmentation
        assert merge_strategy in self.valid_merge_strategies, (
            f"Merge strategy must be one of {self.valid_merge_strategies}"
        )
        self.merge_strategy = merge_strategy

    def compute_predictive_uncertainty(self, aug_tensor: Tensor) -> dict[str, Tensor]:
        """Merge predictions via different strategies to compute predictive uncertainty.

        Must be implemented by subclasses.

        Args:
            aug_tensor: The tensor containing predictions from different
                augmentations, stacked along the last dimension

        Returns:
            dict with predictive mean and uncertainty
        """
        raise NotImplementedError

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Test step for TTA procedure.

        Args:
            batch: batch dictionary containing `self.input_key` and
                `self.target_key`
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            merged TTA predictions, with the target and whatever
            `add_aux_data_to_dict` copies over from the batch
        """
        # make prediction with TTA
        merged_aug_preds = self.predict_step(batch[self.input_key])
        # The targets are used unaugmented for every prediction.
        # TODO: predict_step does not invert the augmentations on the
        # predictions, so this is only appropriate when the target is
        # invariant to the chosen augmentations.
        self.test_metrics(merged_aug_preds["pred"], batch[self.target_key])
        merged_aug_preds[self.target_key] = batch[self.target_key]
        merged_aug_preds = self.add_aux_data_to_dict(merged_aug_preds, batch)
        return merged_aug_preds

    def _predict_once(self, X: Tensor) -> Tensor | dict[str, Tensor]:
        """Run the wrapped model once, preferring its `predict_step` if present."""
        if hasattr(self.model, "predict_step"):
            return self.model.predict_step(X)
        return self.model(X)

    def _merge_dict_predictions(
        self, aug_predictions: list[dict[str, Tensor]]
    ) -> dict[str, Tensor]:
        """Stack per-augmentation prediction dicts key by key and reduce them.

        The `pred` and `logits` entries are reduced with `merge_strategy`; every
        other entry is averaged over the augmentations.
        """
        merged: dict[str, Tensor] = {}
        for key in aug_predictions[0]:
            stacked = torch.stack([pred[key] for pred in aug_predictions], dim=-1)
            if key in ("pred", "logits"):
                merged[key] = merge_strategy_dict[self.merge_strategy](
                    stacked, dim=-1
                )
            else:
                merged[key] = stacked.mean(dim=-1)
        return merged

    def predict_step(
        self,
        X: Tensor,
        aug: list[Callable] | None = None,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ) -> dict[str, Tensor]:
        """Predict step with TTA applied.

        One prediction is made on the unaugmented input plus one per
        augmentation function; these predictions are then merged.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            aug: list of augmentation functions, each called on `X`; if None,
                `self.tt_augmentation` is used
            batch_idx: batch index (unused)
            dataloader_idx: dataloader index (unused)

        Returns:
            dictionary of merged predictions. If the model returns dicts, keys
            are merged individually; otherwise the stacked predictions are
            passed to `compute_predictive_uncertainty`

        Note:
            NOTE(review): the second positional parameter is `aug`, not
            `batch_idx`; a caller passing `batch_idx` positionally would put it
            into `aug` - confirm call sites.
        """
        self.eval()
        if aug is None:
            aug = self.tt_augmentation
        # With no augmentations configured at all, fall back to the plain
        # prediction instead of failing on iteration over None.
        aug_fns = list(aug) if aug is not None else []

        # TODO: how should predictions be merged from underlying sampling based
        # models? For example are pred and pred_uct just averages over the
        # augmentations, or should the underlying samples be pooled across
        # augmentations and the uncertainty computed on that larger set?
        with torch.no_grad():
            # first prediction with no augmentation
            aug_predictions = [self._predict_once(X)]
            for aug_fn in aug_fns:
                # TODO: reverse the augmentation on the prediction / label
                aug_predictions.append(self._predict_once(aug_fn(X)))

        if isinstance(aug_predictions[0], dict):
            return self._merge_dict_predictions(aug_predictions)
        return self.compute_predictive_uncertainty(
            torch.stack(aug_predictions, dim=-1)
        )

    def setup_task(self) -> None:
        """Setup task. Must be implemented by subclasses."""
        raise NotImplementedError

    def validation_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """No validation step in TTA."""
        pass

    def on_validation_start(self) -> None:
        """No validation step in TTA."""
        pass
class TTARegression(TTABase):
    """Regression Test Time Augmentation Module."""

    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: LightningModule | nn.Module,
        tt_augmentation: list[Callable[..., Any]] | None = None,
        merge_strategy: Literal["mean", "median", "sum", "max", "min"] = "mean",
    ) -> None:
        """Initialize a new instance of TTA Regression module.

        Args:
            model: LightningModule or nn.Module used for prediction
            tt_augmentation: list of test time augmentation functions; each one
                is called on the input tensor. If None, a default set
                (horizontal flip, vertical flip, color jiggle) is used
            merge_strategy: strategy to merge the predictions from the different
                augmentations
        """
        super().__init__(model, tt_augmentation, merge_strategy)
        self.setup_task()

    def setup_task(self) -> None:
        """Set up test metrics and, if none were given, default augmentations."""
        self.test_metrics = default_regression_metrics("test")
        if self.tt_augmentation is None:
            self.tt_augmentation: list[nn.Module] = [
                K.RandomHorizontalFlip(p=1.0),
                K.RandomVerticalFlip(p=1.0),
                K.ColorJiggle(
                    brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=1.0
                ),
            ]

    def compute_predictive_uncertainty(self, aug_tensor: Tensor) -> dict[str, Tensor]:
        """Merge predictions according to `merge_strategy`.

        Args:
            aug_tensor: The tensor containing predictions from different
                augmentations, stacked along the last dimension

        Returns:
            the prediction dictionary produced by `process_regression_prediction`
            with the merge strategy's reduction as `aggregate_fn`

        Raises:
            ValueError: if `self.merge_strategy` is not a known strategy
        """
        # A table lookup replaces the former if/elif chain, which left the
        # result unbound (UnboundLocalError) for an unknown strategy.
        try:
            aggregate_fn = merge_strategy_dict[self.merge_strategy]
        except KeyError as err:
            raise ValueError(
                f"Merge strategy must be one of {self.valid_merge_strategies}"
            ) from err
        return process_regression_prediction(aug_tensor, aggregate_fn=aggregate_fn)

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_regression_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
class TTAClassification(TTABase):
    """Classification Test Time Augmentation Module."""

    pred_file_name = "preds.csv"
    valid_tasks: list[str] = ["binary", "multiclass", "multilabel"]

    def __init__(
        self,
        model: LightningModule | nn.Module,
        tt_augmentation: list[Callable[..., Any]] | None = None,
        merge_strategy: Literal["mean", "median", "sum", "max", "min"] = "mean",
        task: Literal["binary", "multiclass", "multilabel"] = "multiclass",
    ) -> None:
        """Initialize a new instance of TTA Classification module.

        Args:
            model: LightningModule or nn.Module used for prediction
            tt_augmentation: list of test time augmentation functions; each one
                is called on the input tensor. If None, a default set
                (horizontal flip, vertical flip, color jiggle) is used
            merge_strategy: strategy to merge the predictions from the different
                augmentations
            task: task type, one of "binary", "multiclass", "multilabel"; the
                legacy misspelling "multilable" is accepted as an alias

        Raises:
            AssertionError: if `task` is not a valid task (only when assertions
                are enabled)
        """
        # The task list used to contain the misspelling "multilable", which
        # rejected the documented "multilabel"; keep the old spelling working.
        if task == "multilable":
            task = "multilabel"
        assert task in self.valid_tasks, f"Task must be one of {self.valid_tasks}"
        self.task = task
        super().__init__(model, tt_augmentation, merge_strategy)
        self.num_classes = self.num_outputs
        self.setup_task()

    def setup_task(self) -> None:
        """Set up test metrics and, if none were given, default augmentations."""
        self.test_metrics = default_classification_metrics(
            "test", self.task, self.num_classes
        )
        if self.tt_augmentation is None:
            self.tt_augmentation: list[nn.Module] = [
                K.RandomHorizontalFlip(p=1.0),
                K.RandomVerticalFlip(p=1.0),
                K.ColorJiggle(
                    brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=1.0
                ),
            ]

    def compute_predictive_uncertainty(self, aug_tensor: Tensor) -> dict[str, Tensor]:
        """Merge predictions according to `merge_strategy`.

        Args:
            aug_tensor: The tensor containing predictions from different
                augmentations, stacked along the last dimension

        Returns:
            the prediction dictionary produced by
            `process_classification_prediction` with the merge strategy's
            reduction as `aggregate_fn`

        Raises:
            ValueError: if `self.merge_strategy` is not a known strategy
        """
        # A table lookup replaces the former if/elif chain, which left the
        # result unbound (UnboundLocalError) for an unknown strategy.
        try:
            aggregate_fn = merge_strategy_dict[self.merge_strategy]
        except KeyError as err:
            raise ValueError(
                f"Merge strategy must be one of {self.valid_merge_strategies}"
            ) from err
        return process_classification_prediction(
            aug_tensor, aggregate_fn=aggregate_fn
        )

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )