Source code for lightning_uq_box.uq_methods.temp_scaling

"""Temperature Scaling.

Adapted from https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py. # noqa: E501
"""

import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule
from torch import Tensor
from torch.optim import LBFGS

from .base import PosthocBase
from .utils import default_classification_metrics, save_classification_predictions


class TempScaling(PosthocBase):
    """Temperature Scaling.

    Learns a single scalar temperature that divides the logits of an already
    trained classifier. The temperature is fitted by minimizing the
    cross-entropy (NLL) loss on calibration data passed to ``trainer.fit``.

    If you use this method, please cite the following paper:

    * https://arxiv.org/abs/1706.04599
    """

    pred_file_name = "preds.csv"

    valid_tasks = ["binary", "multiclass"]

    def __init__(
        self,
        model: LightningModule | nn.Module,
        optim_lr: float = 0.01,
        max_iter: int = 50,
        task: str = "multiclass",
    ) -> None:
        """Initialize Temperature Scaling method.

        Args:
            model: model to be calibrated with Temperature Scaling
            optim_lr: learning rate for the LBFGS optimizer
            max_iter: maximum number of iterations to run the LBFGS optimizer
            task: classification task, one of "multiclass" or "binary"
        """
        super().__init__(model)

        # Learnable scalar temperature, initialized to 1.5 and fitted in
        # ``on_train_epoch_end``.
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
        self.optim_lr = optim_lr
        self.max_iter = max_iter
        self.criterion = nn.CrossEntropyLoss()

        assert (
            task in self.valid_tasks
        ), f"Task {task} not supported, please choose from {self.valid_tasks}"  # noqa: E501
        self.task = task

        self.setup_task()

    def setup_task(self) -> None:
        """Set up task-specific test metrics."""
        self.test_metrics = default_classification_metrics(
            prefix="test", task=self.task, num_classes=self.num_outputs
        )

    def adjust_model_logits(self, model_logits: Tensor) -> Tensor:
        """Adjust model logits by applying temperature scaling.

        Args:
            model_logits: model output logits of shape [batch_size x num_outputs]

        Returns:
            adjusted model logits of shape [batch_size x num_outputs]
        """
        return temp_scale_logits(model_logits, self.temperature)

    def on_train_epoch_end(self) -> None:
        """Fit the temperature on the logits and labels gathered this epoch.

        Concatenates ``self.model_logits`` and ``self.labels`` (presumably
        accumulated by ``PosthocBase`` during the calibration ``trainer.fit``
        run; verify against the base class). It then optimizes
        ``self.temperature`` w.r.t. the NLL loss with LBFGS and marks the
        module as post hoc fitted.
        """
        all_logits = torch.cat(self.model_logits, dim=0).detach()
        all_labels = torch.cat(self.labels, dim=0).detach()

        # optimize temperature w.r.t. NLL
        optimizer = partial(LBFGS, lr=self.optim_lr, max_iter=self.max_iter)
        self.temperature = run_temperature_optimization(
            all_logits, all_labels, self.criterion, self.temperature, optimizer
        )

        self.post_hoc_fitted = True

    def predict_step(self, X: Tensor) -> dict[str, Tensor]:
        """Prediction step with applied temperature scaling.

        Args:
            X: input tensor of shape [batch_size x num_features]

        Returns:
            dictionary with the temperature scaled logits under ``"pred"`` and
            ``"logits"``, and the predictive entropy under ``"pred_uct"``

        Raises:
            RuntimeError: if the temperature has not been fitted yet
        """
        if not self.post_hoc_fitted:
            raise RuntimeError(
                "Model has not been post hoc fitted, "
                "please call "
                "trainer.fit(model, train_dataloaders=dm.calib_dataloader()) first."
            )

        with torch.no_grad():
            temp_scaled_outputs = self.forward(X)

        # Entropy of the softmax distribution over classes (dim=1).
        entropy = -torch.sum(
            F.softmax(temp_scaled_outputs, dim=1)
            * F.log_softmax(temp_scaled_outputs, dim=1),
            dim=1,
        )

        return {
            "pred": temp_scaled_outputs,
            "pred_uct": entropy,
            "logits": temp_scaled_outputs,
        }

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Test step after running posthoc fitting methodology.

        Args:
            batch: batch of testing data
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            prediction dictionary extended with the targets and auxiliary data
        """
        out_dict = self.predict_step(batch[self.input_key])
        out_dict[self.target_key] = batch[self.target_key]

        self.test_metrics(out_dict["pred"], batch[self.target_key])

        out_dict = self.add_aux_data_to_dict(out_dict, batch)

        return out_dict

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
def temp_scale_logits(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor: """Apply temperature scaling to logits. Args: logits: model output logits of shape [batch_size x num_outputs] temperature: temperature tensor of shape [batch_size x 1] Returns: temperature scaled logits of shape [batch_size x num_outputs] """ return logits / temperature def run_temperature_optimization( logits: torch.Tensor, labels: torch.Tensor, criterion: nn.Module, temperature: nn.Parameter, optimizer: type[torch.optim.Optimizer] = partial(LBFGS, lr=0.01, max_iter=50), max_iter: int | None = 50, ) -> Tensor: """Run temperature optimization. Args: logits: model output logits of shape [batch_size x num_outputs] labels: labels of shape [batch_size] criterion: loss function temperature: temperature parameter optimizer: optimizer class max_iter: maximum number of iterations to run optimizer Returns: optimized temperature parameter """ optimizer = optimizer([temperature]) with torch.inference_mode(False): logits = logits.clone().requires_grad_(True) if isinstance(optimizer, torch.optim.LBFGS): def closure(): optimizer.zero_grad() loss = criterion(temp_scale_logits(logits, temperature), labels) loss.backward() return loss optimizer.step(closure) else: for _ in range( max_iter ): # You might need to adjust the number of iterations optimizer.zero_grad() loss = criterion(temp_scale_logits(logits, temperature), labels) loss.backward() optimizer.step() return temperature