Source code for lightning_uq_box.uq_methods.raps

# Apache License 2.0
# Copyright (c) 2020 Anastasios Angelopoulos

# Implementation adapted from above
# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.

"""Regularized Adaptive Prediction Sets (RAPS).

Adapted from https://github.com/aangelopoulos/conformal_classification
to be integrated with Lightning and port several functions to pytorch.
"""

import os
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule
from torch import Tensor
from torch.utils.data import DataLoader, Subset, TensorDataset, random_split
from torchmetrics import Accuracy, CalibrationError, MetricCollection

from .base import PosthocBase
from .metrics import EmpiricalCoverage, SetSize
from .temp_scaling import run_temperature_optimization, temp_scale_logits
from .utils import process_classification_prediction, save_classification_predictions


class RAPS(PosthocBase):
    """Regularized Adaptive Prediction Sets (RAPS).

    Conformal prediction method for classification tasks, as
    introduced by
    `Angelopoulos et al. (2020) <https://arxiv.org/abs/2009.14193>`_.

    If you use this method, please cite the following paper:

    * https://arxiv.org/abs/2009.14193
    """

    valid_tasks = ["binary", "multiclass", "multilabel"]
    valid_lambda_criterion = ["size", "adaptiveness"]
    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: LightningModule | nn.Module,
        optim_lr: float = 0.01,
        max_iter: int = 50,
        alpha: float = 0.1,
        kreg: int | None = None,
        lamda_param: float | None = None,
        randomized: bool = False,
        allow_zero_sets: bool = False,
        pct_param_tune: float = 0.3,
        lamda_criterion: str = "size",
        task: str = "multiclass",
    ) -> None:
        """Initialize RAPS.

        Args:
            model: model to be calibrated with RAPS
            optim_lr: learning rate for optimizer
            max_iter: maximum number of iterations to run optimizer
            alpha: 1 - alpha is the desired coverage
            kreg: regularization param (smaller kreg leads to smaller sets),
                tuned on a split of the calibration data when ``None``
            lamda_param: regularization param (larger lamda leads to
                smaller sets), tuned on a split of the calibration data
                when ``None`` (any value of kreg and lambda will lead to
                coverage, but will yield different set sizes)
            randomized: whether to use randomized version of conformal
                prediction
            allow_zero_sets: whether to allow sets of size zero
            pct_param_tune: fraction of calibration data to use for
                parameter tuning
            lamda_criterion: optimize for 'size' or 'adaptiveness'
            task: task type, one of 'binary', 'multiclass', or 'multilabel'

        Raises:
            AssertionError: if ``task`` or ``lamda_criterion`` is not one of
                the supported values
        """
        super().__init__(model)

        self.num_classes = self.num_outputs
        assert (
            task in self.valid_tasks
        ), f"task must be one of {self.valid_tasks}, got {task}"
        self.task = task

        # initial temperature, refined during `on_train_end`
        self.temperature = nn.Parameter(torch.Tensor([1.3]))
        self.optim_lr = optim_lr
        self.max_iter = max_iter
        self.loss_fn = nn.CrossEntropyLoss()

        self.randomized = randomized
        self.allow_zero_sets = allow_zero_sets
        self.pct_param_tune = pct_param_tune

        assert lamda_criterion in self.valid_lambda_criterion, (
            f"lamda_criterion must be one of {self.valid_lambda_criterion}, "
            f"got {lamda_criterion}"
        )
        self.lamda_criterion = lamda_criterion

        self.alpha = alpha
        self.kreg = kreg
        self.lamda_param = lamda_param

        self.setup_task()

    def setup_task(self) -> None:
        """Setup task specific attributes."""
        self.test_metrics = MetricCollection(
            {
                "Acc": Accuracy(task=self.task, num_classes=self.num_classes),
                "Calibration": CalibrationError(
                    self.task, num_classes=self.num_classes
                ),
                "Empirical Coverage": EmpiricalCoverage(alpha=self.alpha),
                "Set Size": SetSize(alpha=self.alpha),
            },
            prefix="test",
        )

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Test step after running posthoc fitting methodology.

        Args:
            batch: batch dictionary holding inputs and targets
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            prediction dictionary extended by targets and auxiliary data
        """
        # need to set manually because of inference mode
        self.eval()
        pred_dict = self.predict_step(batch[self.input_key])
        pred_dict[self.target_key] = batch[self.target_key]

        # logging metrics
        self.log(
            "test_loss",
            self.loss_fn(pred_dict["pred"], batch[self.target_key]),
            batch_size=batch[self.input_key].shape[0],
        )
        self.test_metrics(pred_dict["pred"], batch[self.target_key])

        pred_dict = self.add_aux_data_to_dict(pred_dict, batch)
        return pred_dict

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict step with RAPS applied.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            logits and conformalized prediction sets
        """
        # need to set manually because of inference mode
        self.eval()
        with torch.no_grad():
            if hasattr(self.model, "predict_step"):
                logits = self.model.predict_step(X)["logits"]
            else:
                logits = self.model(X)

            # temp scale logits just for prediction processing
            scores = temp_scale_logits(logits, self.temperature)
            # adjust model logits for set prediction
            S = self.adjust_model_logits(logits)

        if logits.ndim == 3:
            # three-dimensional logits are averaged before processing
            agg_func = torch.mean
        else:

            def identity(x, dim=None):
                return x

            agg_func = identity

        pred_dict = process_classification_prediction(scores, aggregate_fn=agg_func)
        pred_dict["pred_set"] = S
        pred_dict["size"] = torch.tensor([len(x) for x in S])
        return pred_dict

    def adjust_model_logits(self, model_logits: Tensor) -> list[Tensor]:
        """Turn model output into RAPS prediction sets with fitted Qhat.

        Args:
            model_logits: model output tensor of shape
                [batch_size x num_outputs]

        Returns:
            list with one tensor of class indices (the prediction set)
            per sample
        """
        scores = temp_scale_logits(model_logits, self.temperature)
        softmax_scores = F.softmax(scores, dim=1)

        if softmax_scores.ndim == 3:
            # predictions from MC-type model
            softmax_scores = softmax_scores.mean(-1)

        sorted_score_indices, ordered, cumsum = sort_sum(softmax_scores)

        return gen_cond_quantile_function(
            softmax_scores,
            self.Qhat,
            sorted_score_indices,
            ordered,
            cumsum,
            self.penalties,
            self.randomized,
            self.allow_zero_sets,
        )

    @torch.no_grad()
    def on_train_end(self) -> None:
        """Apply RAPS conformal method.

        Tunes ``kreg`` and ``lamda_param`` on a random split of the
        collected calibration data when they were not given, then fits the
        temperature and the conformal quantile ``Qhat``.
        """
        self.eval()
        calib_logits = torch.cat(self.model_logits, dim=0).detach()
        calib_labels = torch.cat(self.labels, dim=0).detach()
        calib_ds = TensorDataset(calib_logits, calib_labels)

        # check kreg and lamda param
        if self.kreg is None or self.lamda_param is None:
            num_paramtune = int(np.ceil(self.pct_param_tune * len(calib_logits)))
            paramtune_ds, calib_ds = random_split(
                calib_ds, [num_paramtune, len(calib_logits) - num_paramtune]
            )
            paramtune_loader = DataLoader(paramtune_ds, batch_size=64, shuffle=False)

            if self.kreg is None:
                self.kreg = find_kreg_param(paramtune_loader, self.alpha)

            if self.lamda_param is None:
                if self.lamda_criterion == "size":
                    self.lamda_param = find_lamda_param_size(
                        self.model,
                        self.loss_fn,
                        paramtune_loader,
                        self.alpha,
                        self.kreg,
                        self.randomized,
                        self.allow_zero_sets,
                    )
                elif self.lamda_criterion == "adaptiveness":
                    self.lamda_param = find_lamda_param_adaptiveness(
                        self.model,
                        self.loss_fn,
                        paramtune_loader,
                        self.alpha,
                        self.kreg,
                        self.randomized,
                        self.allow_zero_sets,
                    )

        # every rank from `kreg` onwards is penalized by `lamda_param`
        self.penalties = torch.zeros((1, self.num_classes))
        self.penalties[:, int(self.kreg) :] += self.lamda_param  # noqa: E203

        optimizer = partial(torch.optim.SGD, lr=self.optim_lr)
        self.temperature = run_temperature_optimization(
            calib_logits, calib_labels, self.loss_fn, self.temperature, optimizer
        )

        self.Qhat = compute_q_hat(
            calib_logits, calib_labels, self.temperature, self.penalties, self.alpha
        )

        self.post_hoc_fitted = True

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
def compute_q_hat( logits: Tensor, targets: Tensor, temperature: nn.Parameter, penalties: Tensor, alpha: float, ) -> Tensor: """Compute q_hat. Args: logits: model output logits of shape [batch_size x num_outputs] targets: labels of shape [batch_size] temperature: temperature parameter penalties: regularization penalties alpha: 1 - alpha is the desired coverage Returns: q_hat """ scores = temp_scale_logits(logits, temperature) softmax_scores = F.softmax(scores, dim=1) sorted_score_indices, ordered, cumsum = sort_sum(softmax_scores) # TODO why is this always randomized and allow zero sets to true? E = gen_inverse_quantile_function( targets, sorted_score_indices, ordered, cumsum, penalties, True, True ) Qhat = torch.quantile(E, 1 - alpha, interpolation="higher") return Qhat def find_kreg_param(paramtune_logits: DataLoader, alpha: float) -> int: """Find kreg parameter. Args: paramtune_logits: logits of paramtune data alpha: 1 - alpha is the desired coverage Returns: kreg parameter """ all_ranks = [] for sample in paramtune_logits: # Get the sorted indices of logits in descending order sorted_indices = torch.argsort(sample[0], dim=1, descending=True) target_expanded = sample[1].unsqueeze(1) ranks = (sorted_indices == target_expanded).nonzero(as_tuple=True)[1] all_ranks.append(ranks) all_ranks = torch.cat(all_ranks) kstar = torch.quantile(all_ranks.float(), 1 - alpha, interpolation="higher") + 1 return int(kstar.cpu().item()) def find_lamda_param_size( model: nn.Module, loss_fn: nn.Module, paramtune_loader: DataLoader, alpha: float, kreg: float, randomized: bool, allow_zero_sets: bool, ) -> Tensor: """Find lamda parameter for size criterion. 
Args: model: model to be used as conformal model loss_fn: loss function paramtune_loader: paramtune data loader alpha: 1 - alpha is the desired coverage kreg: regularization param (smaller kreg leads to smaller sets) randomized: whether to use randomized version of conformal prediction allow_zero_sets: whether to allow sets of size zero """ best_size = iter(paramtune_loader).__next__()[0][1].shape[0] # number of classes lamda_star = 0 for temp_lam in [0.001, 0.01, 0.1, 0.2, 0.5]: conformal_model = ConformalModelLogits( model, paramtune_loader, loss_fn=loss_fn, alpha=alpha, kreg=kreg, lamda=temp_lam, randomized=randomized, allow_zero_sets=allow_zero_sets, ) size_metric = SetSize(topk=None) for i, (logit, target) in enumerate(paramtune_loader): _, S = conformal_model(logit) size_metric.update(S, target) size = size_metric.compute() if size < best_size: best_size = size lamda_star = temp_lam return lamda_star def find_lamda_param_adaptiveness( model: nn.Module, loss_fn: nn.Module, paramtune_loader: DataLoader, alpha: float, kreg: float, randomized: bool, allow_zero_sets: bool, strata: list[list[int]] = [[0, 1], [2, 3], [4, 6], [7, 10], [11, 100], [101, 1000]], ) -> Tensor: """Find lamda parameter for adaptiveness criterion. 
Args: model: model to be used as conformal model loss_fn: loss function paramtune_loader: paramtune data loader alpha: 1 - alpha is the desired coverage kreg: regularization param (smaller kreg leads to smaller sets) randomized: whether to use randomized version of conformal prediction allow_zero_sets: whether to allow sets of size zero strata: strata to use for adaptiveness criterion """ lamda_star = 0 best_violation = 1 for temp_lam in [0, 1e-5, 1e-4, 8e-4, 9e-4, 1e-3, 1.5e-3, 2e-3]: conformal_model = ConformalModelLogits( model, paramtune_loader, loss_fn=loss_fn, alpha=alpha, kreg=kreg, lamda=temp_lam, randomized=randomized, allow_zero_sets=allow_zero_sets, ) curr_violation = get_violation(conformal_model, paramtune_loader, strata, alpha) if curr_violation < best_violation: best_violation = curr_violation lamda_star = temp_lam return lamda_star def get_violation( cmodel: nn.Module, loader_paramtune: DataLoader, strata: list[list[int]], alpha: float, ): """Get violation for adaptiveness criterion. 
Args: cmodel: conformal model loader_paramtune: paramtune data loader strata: strata to use for adaptiveness criterion alpha: 1 - alpha is the desired coverage Returns: violation """ df = pd.DataFrame(columns=["size", "correct"]) for logit, target in loader_paramtune: _, S = cmodel(logit) size = np.array([x.size()[0] for x in S]) sorted_score_indices, _, _ = sort_sum(logit) sorted_score_indices = sorted_score_indices.cpu().numpy() correct = np.zeros_like(size) for j in range(correct.shape[0]): correct[j] = int(target[j] in list(S[j])) batch_df = pd.DataFrame({"size": size, "correct": correct}) df = pd.concat([df, batch_df], ignore_index=True) wc_violation = 0 for stratum in strata: temp_df = df[(df["size"] >= stratum[0]) & (df["size"] <= stratum[1])] if len(temp_df) == 0: continue stratum_violation = abs(temp_df.correct.mean() - (1 - alpha)) wc_violation = max(wc_violation, stratum_violation) return wc_violation class ConformalModelLogits(nn.Module): """Conformal Model for logits.""" def __init__( self, model: nn.Module, calib_loader: DataLoader, loss_fn: nn.Module, alpha: float, kreg: int | None = None, lamda: float | None = None, randomized: bool = False, allow_zero_sets: bool = False, ): """Initialize ConformalModelLogits. 
Args: model: model to be used as conformal model calib_loader: calibration data loader loss_fn: loss function alpha: 1 - alpha is the desired coverage kreg: regularization param (smaller kreg leads to smaller sets) lamda: regularization param (larger lamda leads to smaller sets) (any value of kreg and lambda will lead to coverage, but will yield different set sizes) randomized: whether to use randomized version of conformal prediction allow_zero_sets: whether to allow sets of size zero pct_paramtune: fraction of calibration data to use for parameter tuning batch_size: batch size for parameter tuning lamda_criterion: optimize for 'size' or 'adaptiveness' """ super().__init__() self.model = model self.alpha = alpha self.randomized = randomized self.allow_zero_sets = allow_zero_sets # temperature optimization if isinstance(calib_loader.dataset, TensorDataset): logits = calib_loader.dataset.tensors[0] labels = calib_loader.dataset.tesnors[1] elif isinstance(calib_loader.dataset, Subset): logits = calib_loader.dataset.dataset.tensors[0] labels = calib_loader.dataset.dataset.tensors[1] optimizer = partial(torch.optim.SGD, lr=0.01) self.temperature = run_temperature_optimization( logits, labels, nn.CrossEntropyLoss(), nn.Parameter(torch.Tensor([1.3]).to(logits.device)), optimizer, ) self.penalties = torch.zeros((1, calib_loader.dataset[0][0].shape[0])) self.penalties[:, kreg:] += lamda self.Qhat = compute_q_hat( logits, labels, self.temperature, self.penalties, self.alpha ) def forward( self, logits: Tensor, randomized: bool | None = None, allow_zero_sets: bool | None = None, ): """Forward pass. 
Args: logits: model output logits of shape [batch_size x num_outputs] randomized: whether to use randomized version of conformal prediction allow_zero_sets: whether to allow sets of size zero Returns: logits and prediction sets """ if randomized is None: randomized = self.randomized if allow_zero_sets is None: allow_zero_sets = self.allow_zero_sets with torch.no_grad(): scores = F.softmax(logits / self.temperature.item(), dim=1) sorted_score_indices, ordered, cumsum = sort_sum(scores) S = gen_cond_quantile_function( scores, self.Qhat, sorted_score_indices=sorted_score_indices, ordered=ordered, cumsum=cumsum, penalties=self.penalties, randomized=randomized, allow_zero_sets=allow_zero_sets, ) return logits, S def gen_inverse_quantile_function( targets: Tensor, sorted_score_indices: Tensor, ordered: Tensor, cumsum: Tensor, penalties: Tensor, randomized: bool, allow_zero_sets: bool, ) -> Tensor: """Generalized inverse quantile conformity score function. E from equation (7) in Romano, Sesia, Candes. Find the minimum tau in [0, 1] such that the correct label enters. 
Args:a targets: shape [batch_size] sorted_score_indices: shape [batch_size x num_classes] ordered: shape [batch_size x num_classes] cumsum: shape [batch_size x num_classes] penalties: shape [1 x num_classes] randomized: whether to use randomized version of conformal prediction allow_zero_sets: whether to allow sets of size zero Returns: E: generalized inverse quantile conformity score """ E = -torch.ones((targets.shape[0],)) for i in range(targets.shape[0]): E[i] = get_single_tau( targets[i].item(), sorted_score_indices[i : i + 1, :], # noqa: E203 ordered[i : i + 1, :], # noqa: E203 cumsum[i : i + 1, :], # noqa: E203 penalties[0, :], # noqa: E203 randomized=randomized, allow_zero_sets=allow_zero_sets, ) return E def gen_cond_quantile_function( scores: Tensor, tau: Tensor, sorted_score_indices: Tensor, ordered: Tensor, cumsum: Tensor, penalties: Tensor, randomized: bool, allow_zero_sets: bool, ) -> list[Tensor]: """Generalized conditional quantile function. Args: scores: shape [batch_size x num_classes] tau: no shape, uses Q_hat sorted_score_indices: shape [batch_size x num_classes] ordered: shape [batch_size x num_classes] cumsum: shape [batch_size x num_classes] penalties: shape [1 x num_classes] randomized: whether to use randomized version of conformal prediction allow_zero_sets: whether to allow sets of size zero Returns: prediction sets """ penalties = penalties.to(scores.device) tau = tau.to(scores.device) penalties_cumsum = torch.cumsum(penalties, dim=1) sizes_base = ((cumsum + penalties_cumsum) <= tau).sum(dim=1).to(scores.device) + 1 sizes_base = torch.minimum( sizes_base, torch.ones_like(sizes_base) * scores.shape[1] ) if randomized: V = torch.zeros_like(sizes_base) for i in range(sizes_base.shape[0]): V[i] = ( 1 / ordered[i, sizes_base[i] - 1] * ( tau - (cumsum[i, sizes_base[i] - 1] - ordered[i, sizes_base[i] - 1]) - penalties_cumsum[0, sizes_base[i] - 1] ) ) # -1 since sizes_base \in {1,...,1000}. 
sizes = sizes_base - (torch.rand(V.shape).to(scores.device) >= V).int() else: sizes = sizes_base # always predict max size if alpha==0. (Avoids numerical error.) if tau == 1.0: sizes[:] = cumsum.shape[1] # allow the user the option to never have empty sets (will lead to incorrect # coverage if 1-alpha < model's top-1 accuracy if not allow_zero_sets: sizes[sizes == 0] = 1 # Construct S from equation (5) pred_sets: list[Tensor] = [] for i in range(sorted_score_indices.shape[0]): pred_sets.append(sorted_score_indices[i, 0 : sizes[i]]) # noqa: E203 return pred_sets def get_single_tau( target: Tensor, sorted_score_indices: Tensor, ordered: Tensor, cumsum: Tensor, penalty: Tensor, randomized: bool, allow_zero_sets: bool, ) -> Tensor: """Get tau for one example. Args: target: target label sorted_score_indices: shape [batch_size x num_classes] ordered: shape [batch_size x num_classes] cumsum: shape [batch_size x num_classes] penalty: shape [1 x num_classes] randomized: whether to use randomized version of conformal prediction allow_zero_sets: whether to allow sets of size zero Returns: single tau """ idx = torch.where(sorted_score_indices == target) tau_nonrandom = cumsum[idx] if not randomized: return tau_nonrandom + penalty[0] U = np.random.random() if idx == (0, 0): if not allow_zero_sets: return tau_nonrandom + penalty[0] else: return U * tau_nonrandom + penalty[0] else: return ( U * ordered[idx] + cumsum[(idx[0], idx[1] - 1)] + (penalty[0 : (idx[1][0] + 1)]).sum() # noqa: E203 ) def sort_sum(scores: Tensor) -> tuple[Tensor, Tensor, Tensor]: """Sort and sum scores. 
Args: scores: model scores [batch_size x num_classes] Returns: sorted_score_indices, ordered, cumsum """ # https://stackoverflow.com/questions/66767610/difference-between-numpy-argsort-and-torch-argsort # scores = scores.double() if scores.dtype != torch.float64 else scores # sorted_score_indices = torch.argsort(scores, dim=1, descending=True) device = scores.device sorted_score_indices = scores.cpu().numpy().argsort(axis=1)[:, ::-1] sorted_score_indices = torch.from_numpy(sorted_score_indices.copy()).to(device) ordered = torch.sort(scores, dim=1, descending=True).values cumsum = torch.cumsum(ordered, dim=1) return sorted_score_indices, ordered, cumsum