# Source code for lightning_uq_box.uq_methods.quantile_regression

# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.

"""Implement Quantile Regression Model."""

import os
from typing import Any

import torch
import torch.nn as nn
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import Tensor

from lightning_uq_box.eval_utils import compute_sample_mean_std_from_quantile

from .base import DeterministicModel
from .loss_functions import PinballLoss
from .utils import (
    default_px_regression_metrics,
    default_regression_metrics,
    freeze_segmentation_model,
    save_image_predictions,
    save_regression_predictions,
)


class QuantileRegressionBase(DeterministicModel):
    """Quantile Regression Base Module.

    Shared constructor and metric setup for the quantile regression models
    in this module. The quantiles are validated on construction and the
    position of the median (``0.5``) is stored as :attr:`median_index`.

    If you use this model in your research, please cite the following paper:

    * https://www.jstor.org/stable/1913643
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module | None = None,
        quantiles: list[float] = [0.1, 0.5, 0.9],
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable | None = None,
    ) -> None:
        """Initialize a new instance of Quantile Regression Model.

        Args:
            model: pytorch model
            loss_fn: loss function, defaults to :class:`PinballLoss` over
                ``quantiles`` when ``None``
            quantiles: quantiles to compute; each must lie strictly between
                0 and 1 and the median ``0.5`` must be included
            freeze_backbone: whether to freeze the backbone
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler

        Raises:
            ValueError: if a quantile is not strictly between 0 and 1, or if
                ``0.5`` is not among ``quantiles``
        """
        # Work on a private copy so the shared default list (or the caller's
        # list) is never aliased by this instance.
        quantiles = list(quantiles)

        # Explicit checks instead of ``assert`` so validation is not stripped
        # under ``python -O``. ``not all(q < 1)`` also rejects NaN values.
        if not all(q < 1 for q in quantiles):
            raise ValueError("Quantiles should be less than 1.")
        if not all(q > 0 for q in quantiles):
            raise ValueError("Quantiles should be greater than 0.")
        if 0.5 not in quantiles:
            raise ValueError(
                f"Quantiles must include the median 0.5, got {quantiles}."
            )

        if loss_fn is None:
            loss_fn = PinballLoss(quantiles=quantiles)

        super().__init__(model, loss_fn, freeze_backbone, optimizer, lr_scheduler)

        self.quantiles = quantiles
        # Index of the median output channel, used for point predictions.
        self.median_index = self.quantiles.index(0.5)

    def setup_task(self) -> None:
        """Set up task specific attributes (train/val/test regression metrics)."""
        self.train_metrics = default_regression_metrics("train")
        self.val_metrics = default_regression_metrics("val")
        self.test_metrics = default_regression_metrics("test")
class QuantileRegression(QuantileRegressionBase):
    """Quantile Regression Module for Regression.

    If you use this model in your research, please cite the following paper:

    * https://www.jstor.org/stable/1913643
    """

    # File name, joined with ``trainer.default_root_dir``, for test predictions.
    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module | None = None,
        quantiles: list[float] = [0.1, 0.5, 0.9],
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable | None = None,
    ) -> None:
        """Initialize a new instance of Quantile Regression Model.

        Args:
            model: pytorch model
            loss_fn: loss function
            quantiles: quantiles to compute
            freeze_backbone: whether to freeze the backbone
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
        """
        super().__init__(
            model, loss_fn, quantiles, freeze_backbone, optimizer, lr_scheduler
        )
        # Store the remaining constructor arguments (e.g. ``quantiles``) as
        # hyperparameters; modules and callables are excluded.
        self.save_hyperparameters(
            ignore=["model", "loss_fn", "optimizer", "lr_scheduler"]
        )

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Adapt model output to be compatible for metric computation.

        Args:
            out: output from :meth:`self.forward` [batch_size x num_outputs]

        Returns:
            extracted median used for metric computation [batch_size x 1]
        """
        # Slice (instead of index) so the output keeps a trailing dim of size 1.
        return out[:, self.median_index : self.median_index + 1]  # noqa: E203

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Test step.

        Args:
            batch: batch of testing data
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            the prediction dictionary from :meth:`predict_step`, extended with
            the detached target and auxiliary batch data
        """
        out_dict = self.predict_step(batch[self.input_key])
        out_dict[self.target_key] = batch[self.target_key].detach().squeeze(-1)
        # Metrics are only updated for batches with more than one sample.
        # NOTE(review): presumably because some regression metrics (e.g. R2)
        # need at least two samples — confirm.
        if batch[self.input_key].shape[0] > 1:
            self.test_metrics(out_dict["pred"], batch[self.target_key])

        out_dict["pred"] = out_dict["pred"].detach().cpu().squeeze(-1)
        out_dict = self.add_aux_data_to_dict(out_dict, batch)
        return out_dict

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict step with Quantile Regression.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            dictionary with the median prediction (``pred``), the standard
            deviation derived from the quantiles (``pred_uct`` and
            ``aleatoric_uct``) and the first/last quantile outputs
            (``lower_quant``/``upper_quant``; these are the lowest/highest
            quantiles assuming ``quantiles`` is sorted ascending)
        """
        with torch.no_grad():
            out = self.model(X)  # [batch_size, len(self.quantiles)]

        median = self.adapt_output_for_metrics(out)
        # Use the validated quantiles stored on the instance (the same source
        # as ``median_index``) instead of relying on ``self.hparams``.
        _, std = compute_sample_mean_std_from_quantile(out, self.quantiles)

        return {
            "pred": median,
            "pred_uct": std,
            "lower_quant": out[:, 0],
            "upper_quant": out[:, -1],
            "aleatoric_uct": std,
        }

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        The signature follows Lightning's ``on_test_batch_end`` hook, which
        passes the batch as the second argument.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_regression_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
class QuantilePxRegression(QuantileRegressionBase):
    """Quantile Regression for Pixelwise Regression.

    .. versionadded:: 0.2.0
    """

    # Directory name, joined with ``trainer.default_root_dir``, for saved
    # test predictions.
    pred_dir_name = "preds"

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module | None = None,
        quantiles: list[float] = [0.1, 0.5, 0.9],
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable | None = None,
        save_preds: bool = False,
    ) -> None:
        """Initialize a new instance of Quantile Regression Model.

        Args:
            model: pytorch model
            loss_fn: loss function
            quantiles: quantiles to compute
            freeze_backbone: whether to freeze the backbone
            freeze_decoder: whether to freeze the decoder
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
            save_preds: whether to save predictions
        """
        # Set before ``super().__init__`` because :meth:`freeze_model` reads it
        # and may run during base-class initialization.
        self.freeze_decoder = freeze_decoder
        super().__init__(
            model, loss_fn, quantiles, freeze_backbone, optimizer, lr_scheduler
        )
        self.save_hyperparameters(
            ignore=["model", "loss_fn", "optimizer", "lr_scheduler"]
        )
        self.save_preds = save_preds

    def freeze_model(self) -> None:
        """Freeze parts of the segmentation model.

        Delegates to :func:`freeze_segmentation_model` with the
        ``freeze_backbone`` and ``freeze_decoder`` flags of this instance.
        """
        freeze_segmentation_model(self.model, self.freeze_backbone, self.freeze_decoder)

    def setup_task(self) -> None:
        """Set up task specific attributes (pixelwise regression metrics)."""
        self.train_metrics = default_px_regression_metrics("train")
        self.val_metrics = default_px_regression_metrics("val")
        self.test_metrics = default_px_regression_metrics("test")

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Adapt model output to be compatible for metric computation.

        Args:
            out: output from :meth:`self.forward`
                [batch_size x num_outputs x height x width]

        Returns:
            extracted median used for metric computation
            [batch_size x 1 x height x width]
        """
        return out[
            :, self.median_index : self.median_index + 1, ...  # noqa: E203
        ].contiguous()

    def on_test_start(self) -> None:
        """Resolve the prediction directory and create it if saving is enabled."""
        self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name)
        if self.save_preds:
            # ``exist_ok`` avoids the check-then-create race of a separate
            # ``os.path.exists`` test.
            os.makedirs(self.pred_dir, exist_ok=True)

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Test step.

        Args:
            batch: batch of testing data
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            the prediction dictionary from :meth:`predict_step`, extended with
            the detached CPU target and auxiliary batch data
        """
        pred_dict = self.predict_step(batch[self.input_key])
        pred_dict[self.target_key] = batch[self.target_key].detach().squeeze(-1).cpu()
        pred_dict = self.add_aux_data_to_dict(pred_dict, batch)

        pred = pred_dict["pred"].contiguous()
        target = batch[self.target_key]
        # Align the target with the prediction for the metrics. When the two
        # shapes differ only by singleton dims (e.g. [B, 1, H, W] vs
        # [B, H, W]), reshape: unlike a bare ``squeeze()``, this keeps a batch
        # (or spatial) dim of size 1. Otherwise fall back to ``squeeze()``.
        same_layout = [d for d in target.shape if d != 1] == [
            d for d in pred.shape if d != 1
        ]
        target = target.reshape(pred.shape) if same_layout else target.squeeze()
        self.test_metrics(pred, target)
        return pred_dict

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict step with Quantile Regression.

        Args:
            X: prediction batch of input images
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            dictionary with the median prediction (``pred``, channel dim
            removed) and the first/last quantile outputs (``lower``/``upper``;
            these are the lowest/highest quantiles assuming ``quantiles`` is
            sorted ascending)
        """
        with torch.no_grad():
            # [batch_size, len(self.quantiles), height, width]
            out = self.model(X)

        return {
            "pred": self.adapt_output_for_metrics(out).squeeze(1),
            "lower": out[:, 0],
            "upper": out[:, -1],
        }

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        if self.save_preds:
            save_image_predictions(outputs, batch_idx, self.pred_dir)