Source code for lightning_uq_box.uq_methods.conformal_qr
# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
"""conformalized Quantile Regression Model."""
import copy
import math
import os
import torch
import torch.nn as nn
from lightning import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torch import Tensor
from lightning_uq_box.eval_utils import compute_sample_mean_std_from_quantile
from .base import PosthocBase
from .utils import default_regression_metrics, save_regression_predictions
def compute_q_hat_with_cqr(
    lower_quant: Tensor, upper_quant: Tensor, cal_labels: Tensor, alpha: float
) -> Tensor:
    """Compute q_hat, the conformal adjustment factor for the quantiles.

    The conformity score of each calibration sample is
    ``max(y - upper_quant, lower_quant - y)``. q_hat is the
    ``ceil((n + 1) * (1 - alpha)) / n`` empirical quantile of those scores,
    taken with ``interpolation="higher"``.

    Args:
        lower_quant: lower quantile predictions on the calibration set
        upper_quant: upper quantile predictions on the calibration set
        cal_labels: calibration set targets
        alpha: desired miscoverage rate, i.e. 1 - alpha is the desired coverage

    Returns:
        q_hat, a 0-dim tensor by which the prediction intervals
        can be adjusted according to CQR

    Raises:
        ValueError: if the calibration set is empty
    """
    # Squeeze singleton dims such as [n, 1] -> [n]. atleast_1d keeps a single
    # calibration sample as shape [1] instead of collapsing it to a 0-dim tensor.
    cal_labels = torch.atleast_1d(cal_labels.squeeze())
    n = cal_labels.shape[0]
    if n == 0:
        raise ValueError("Calibration set is empty, cannot compute q_hat.")
    # Get scores. cal_upper.shape[0] == cal_lower.shape[0] == n
    cal_scores = torch.maximum(cal_labels - upper_quant, lower_quant - cal_labels)
    # Finite-sample corrected quantile level. For small n it exceeds 1, which
    # torch.quantile rejects. Clamp to 1 (i.e. use the largest score); note
    # the finite-sample coverage guarantee does not hold in that regime.
    q_level = min(math.ceil((n + 1) * (1 - alpha)) / n, 1.0)
    q_hat = torch.quantile(cal_scores, q_level, interpolation="higher")
    return q_hat
[docs]
class ConformalQR(PosthocBase):
    """Conformalized Quantile Regression.

    Wraps a quantile regression model. After a post hoc fit on calibration
    data, its lowest and highest quantile predictions are shifted by a
    conformal adjustment ``q_hat``.

    If you use this model, please cite the following paper:

    * https://papers.nips.cc/paper_files/paper/2019/hash/5103c3584b063c431bd1268e9b5e76fb-Abstract.html
    """  # noqa: E501

    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: nn.Module | LightningModule,
        quantiles: list[float] = [0.1, 0.5, 0.9],
        alpha: float = 0.1,
    ) -> None:
        """Initialize a new CQR model.

        Args:
            model: underlying model to be wrapped
            quantiles: quantile levels of the model outputs, in the same order
                as the output columns (lowest first, highest last)
            alpha: desired miscoverage rate, 1 - alpha is the desired coverage

        Raises:
            ValueError: if alpha is not in (0, 1)
        """
        super().__init__(model)
        self.save_hyperparameters(ignore=["model"])
        # copy so that mutating self.quantiles cannot alter the shared default list
        self.quantiles = list(quantiles)
        # explicit check instead of ``assert``, which is stripped under ``python -O``
        if not 0 < alpha < 1:
            raise ValueError("alpha must be in (0, 1)")
        self.alpha = alpha
        self.desired_coverage = 1 - self.alpha  # 1-alpha is the desired coverage
        self.setup_task()

    def setup_task(self) -> None:
        """Set up task."""
        self.test_metrics = default_regression_metrics("test")

    def _raise_if_not_fitted(self) -> None:
        """Raise a RuntimeError if the post hoc calibration has not been run."""
        if not self.post_hoc_fitted:
            raise RuntimeError(
                "Model has not been post hoc fitted, "
                "please call "
                "trainer.fit(model, train_dataloaders=dm.calib_dataloader()) first."
            )

    def forward(self, X: Tensor) -> dict[str, Tensor]:
        """Forward pass of model.

        Args:
            X: input tensor of shape [batch_size x input_dims]

        Returns:
            dictionary of conformalized predictions as produced by
            ``adjust_model_logits``

        Raises:
            RuntimeError: if the model has not been post hoc fitted yet
        """
        self._raise_if_not_fitted()
        with torch.no_grad():
            # prefer the wrapped model's own predict_step when it has one
            if hasattr(self.model, "predict_step"):
                pred = self.model.predict_step(X)
            else:
                pred = self.model(X)
        return self.adjust_model_logits(pred)

    def adjust_model_logits(
        self, model_output: dict[str, Tensor] | Tensor
    ) -> dict[str, Tensor]:
        """Conformalize underlying model output.

        Args:
            model_output: either a dict with "pred", "lower_quant" and
                "upper_quant" entries, or a tensor whose columns follow
                the order of ``self.quantiles``

        Returns:
            dict with "pred", the widened "lower_quant"/"upper_quant"
            (shifted by ``q_hat``), and "pred_uct"/"aleatoric_uct" derived
            from the adjusted quantiles
        """
        if isinstance(model_output, dict):
            output_dict = copy.deepcopy(model_output)
            output_dict["pred"] = model_output["pred"].squeeze(-1)
            output_dict["lower_quant"] = model_output["lower_quant"] - self.q_hat
            output_dict["upper_quant"] = model_output["upper_quant"] + self.q_hat
        else:
            output_dict = {}
            # conformalize predictions, assuming the column order matches the
            # quantile order: lowest at 0, prediction at 1, highest at -1
            output_dict["lower_quant"] = model_output[:, 0] - self.q_hat
            output_dict["pred"] = model_output[:, 1]
            output_dict["upper_quant"] = model_output[:, -1] + self.q_hat
        # compute gaussian assumption std with updated quantiles
        _, std = compute_sample_mean_std_from_quantile(
            torch.stack(
                [
                    output_dict["lower_quant"],
                    output_dict["pred"],
                    output_dict["upper_quant"],
                ],
                dim=-1,
            ),
            self.quantiles,
        )
        output_dict["pred_uct"] = std
        output_dict["aleatoric_uct"] = std
        return output_dict

    def on_train_end(self) -> None:
        """Perform CQR computation to obtain q_hat for predictions.

        Concatenates ``self.model_logits`` and ``self.labels`` (presumably
        collected by the base class during the calibration fit; confirm in
        PosthocBase), computes ``self.q_hat`` from the first and last output
        columns and marks the model as post hoc fitted.
        """
        all_outputs = torch.cat(self.model_logits, dim=0)
        all_labels = torch.cat(self.labels, dim=0)
        # calibration quantiles assume order of outputs corresponds
        # to order of quantiles
        self.q_hat = compute_q_hat_with_cqr(
            all_outputs[:, 0], all_outputs[:, -1], all_labels, self.alpha
        )
        self.post_hoc_fitted = True

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Test step.

        Args:
            batch: batch of testing data
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            conformalized predictions plus the detached targets and whatever
            ``add_aux_data_to_dict`` attaches from the batch
        """
        out_dict = self.predict_step(batch[self.input_key])
        out_dict[self.target_key] = batch[self.target_key].detach().squeeze(-1).cpu()
        out_dict["pred"] = out_dict["pred"].detach().cpu().squeeze(-1)
        self.test_metrics(out_dict["pred"], out_dict[self.target_key])
        # save metadata
        out_dict = self.add_aux_data_to_dict(out_dict, batch)
        out_dict.pop("out", None)
        return out_dict

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Prediction step that produces conformalized prediction sets.

        Args:
            X: input tensor of shape [batch_size x input_dims]
            batch_idx: batch index (unused, accepted for Lightning compatibility)
            dataloader_idx: dataloader index (unused, accepted for Lightning
                compatibility)

        Returns:
            dictionary of conformalized predictions

        Raises:
            RuntimeError: if the model has not been post hoc fitted yet
        """
        self._raise_if_not_fitted()
        return self.forward(X)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """No optimizer needed for Conformal QR."""
        return None

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_regression_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )