Source code for lightning_uq_box.uq_methods.mc_dropout

# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.

"""Mc-Dropout module."""

import os
from typing import Any

import torch
import torch.nn as nn
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import Tensor

from .base import DeterministicModel
from .utils import (
    _get_num_outputs,
    default_classification_metrics,
    default_px_regression_metrics,
    default_regression_metrics,
    default_segmentation_metrics,
    freeze_model_backbone,
    freeze_segmentation_model,
    process_classification_prediction,
    process_regression_prediction,
    process_segmentation_prediction,
    save_classification_predictions,
    save_image_predictions,
    save_regression_predictions,
)


def find_dropout_layers(model: nn.Module) -> list[str]:
    """Collect the qualified names of all ``nn.Dropout`` submodules.

    Args:
        model: module whose submodules (including the module itself)
            are inspected

    Returns:
        names as yielded by ``model.named_modules()`` for every module
        that is an instance of ``nn.Dropout``, in traversal order
    """
    # An empty list is a valid result and is not treated as an error here,
    # e.g. dropout applied through nn.functional is not a module and can
    # therefore never show up in this search.
    return [
        layer_name
        for layer_name, layer in model.named_modules()
        if isinstance(layer, nn.Dropout)
    ]


class MCDropoutBase(DeterministicModel):
    """MC-Dropout Base class.

    If you use this model in your research, please cite the following paper:

    * https://proceedings.mlr.press/v48/gal16.html
    """

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        dropout_layer_names: list[str] = [],
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new instance of MCDropoutModel.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction; not
                stored by this base initializer
            loss_fn: loss function
            dropout_layer_names: names of dropout layers to activate during
                prediction; when empty, the names of all ``nn.Dropout``
                modules found in ``model`` are used instead
            freeze_backbone: freeze backbone during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
        """
        super().__init__(model, loss_fn, freeze_backbone, optimizer, lr_scheduler)

        # The shared default list is only read and rebound, never mutated,
        # so the mutable default cannot leak state between instances.
        if not dropout_layer_names:
            dropout_layer_names = find_dropout_layers(model)
        self.dropout_layer_names = dropout_layer_names

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        pass

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        out = self.forward(batch[self.input_key])
        loss = self.loss_fn(out, batch[self.target_key])

        self.log(
            "train_loss", loss, batch_size=batch[self.input_key].shape[0]
        )  # logging to Logger
        self.train_metrics(self.adapt_output_for_metrics(out), batch[self.target_key])

        return loss

    def activate_dropout(self) -> None:
        """Activate dropout layers and put batch norm layers in eval mode.

        The whole model is first switched to train mode. The module tree is
        then walked: every ``nn.Dropout`` whose qualified name is listed in
        ``self.dropout_layer_names`` is recorded as found, and every
        ``nn.BatchNorm1d``/``nn.BatchNorm2d``/``nn.BatchNorm3d`` layer is
        switched to eval mode.

        Raises:
            UserWarning: if none of the names in ``self.dropout_layer_names``
                refers to an ``nn.Dropout`` module of the model
        """
        dropout_layers_found = []
        self.model.train()

        batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

        def activate_dropout_recursive(model: nn.Module, prefix: str = "") -> None:
            for name, module in model.named_children():
                full_name = f"{prefix}.{name}" if prefix else name
                if full_name in self.dropout_layer_names and isinstance(
                    module, nn.Dropout
                ):
                    module.train()
                    dropout_layers_found.append(full_name)
                # The batch norm check has to come before the generic
                # recursion: every child is an nn.Module, so testing for
                # nn.Module first would make this branch unreachable and
                # leave batch norm layers in train mode.
                elif isinstance(module, batch_norm_types):
                    module.eval()
                else:
                    activate_dropout_recursive(module, full_name)

        activate_dropout_recursive(self.model)

        if not dropout_layers_found:
            raise UserWarning(
                "No dropout layers found in model, maybe dropout "
                "is implemented via specialized layers?"
            )
class MCDropoutRegression(MCDropoutBase):
    """MC-Dropout Model for Regression.

    If you use this model in your research, please cite the following paper:

    * https://proceedings.mlr.press/v48/gal16.html
    """

    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        burnin_epochs: int = 0,
        dropout_layer_names: list[str] = [],
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Regression.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            burnin_epochs: number of burnin epochs before using the loss_fn
            dropout_layer_names: names of dropout layers to activate during
                prediction
            freeze_backbone: freeze backbone during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
        """
        super().__init__(
            model=model,
            num_mc_samples=num_mc_samples,
            loss_fn=loss_fn,
            dropout_layer_names=dropout_layer_names,
            freeze_backbone=freeze_backbone,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        # Everything except the ignored objects ends up in self.hparams,
        # which training_step and predict_step read from.
        self.save_hyperparameters(
            ignore=["model", "loss_fn", "optimizer", "lr_scheduler"]
        )

    def setup_task(self) -> None:
        """Create the default regression metrics for train, val and test."""
        for stage in ("train", "val", "test"):
            setattr(self, f"{stage}_metrics", default_regression_metrics(stage))

    def freeze_model(self) -> None:
        """Freeze the model backbone when ``freeze_backbone`` is set.

        Delegates to ``freeze_model_backbone``; does nothing otherwise.
        """
        if not self.freeze_backbone:
            return
        freeze_model_backbone(self.model)

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Keep only the first output column for metric computation.

        Args:
            out: model output whose last dimension has at most two entries

        Returns:
            the slice ``out[:, 0:1]``
        """
        num_outputs = out.shape[-1]
        assert num_outputs <= 2, "Ony support single mean or Gaussian output."
        return out[:, :1]

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        During the first ``burnin_epochs`` epochs an MSE loss on the adapted
        output is used; afterwards the configured ``loss_fn`` is applied to
        the raw model output.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        inputs = batch[self.input_key]
        targets = batch[self.target_key]
        out = self.forward(inputs)

        in_burnin = self.current_epoch < self.hparams.burnin_epochs
        if in_burnin:
            loss = nn.functional.mse_loss(
                self.adapt_output_for_metrics(out), targets
            )
        else:
            loss = self.loss_fn(out, targets)

        # logging to Logger
        self.log("train_loss", loss, batch_size=inputs.shape[0])
        self.train_metrics(self.adapt_output_for_metrics(out), targets)

        return loss

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict steps via Monte Carlo Sampling.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            the dictionary produced by ``process_regression_prediction``
            from the stacked MC samples
        """
        self.activate_dropout()

        with torch.no_grad():
            mc_samples = [
                self.model(X) for _ in range(self.hparams.num_mc_samples)
            ]
            # samples are stacked along a new trailing dimension
            preds = torch.stack(mc_samples, dim=-1)

        return process_regression_prediction(preds)

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        pred_path = os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        save_regression_predictions(outputs, pred_path)
class MCDropoutClassification(MCDropoutBase):
    """MC-Dropout Model for Classification.

    If you use this model in your research, please cite the following paper:

    * https://proceedings.mlr.press/v48/gal16.html
    """

    pred_file_name = "preds.csv"

    # Task names accepted by __init__; "multilabel" was previously misspelled
    # here, which made the documented task name fail the validity check.
    valid_tasks = ["binary", "multiclass", "multilabel"]

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        task: str = "multiclass",
        dropout_layer_names: list[str] = [],
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Classification.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            task: classification task, one of ['binary', 'multiclass', 'multilabel']
            dropout_layer_names: names of dropout layers to activate during
                prediction
            freeze_backbone: freeze backbone during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler

        Raises:
            AssertionError: if ``task`` is not one of ``valid_tasks``
        """
        assert task in self.valid_tasks, (
            f"task must be one of {self.valid_tasks}, got {task!r}"
        )
        # task and num_classes are assigned before the parent initializer
        # runs; setup_task reads both (presumably it is invoked from the
        # parent initializer -- confirm against the base class).
        self.task = task
        self.num_classes = _get_num_outputs(model)

        super().__init__(
            model,
            num_mc_samples,
            loss_fn,
            dropout_layer_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )

        self.save_hyperparameters(
            ignore=["model", "loss_fn", "optimizer", "lr_scheduler"]
        )

        # FIXME: why isn't save_hyperparameters working?
        # Kept as a plain attribute because predict_step reads it directly.
        self.num_mc_samples = num_mc_samples

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_classification_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_classification_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_classification_metrics(
            "test", self.task, self.num_classes
        )

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Return the model output unchanged for metric computation."""
        return out

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict steps via Monte Carlo Sampling.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            the dictionary produced by ``process_classification_prediction``
            from the stacked MC samples
        """
        self.activate_dropout()  # activate dropout during prediction
        with torch.no_grad():
            # samples are stacked along a new trailing dimension
            preds = torch.stack(
                [self.model(X) for _ in range(self.num_mc_samples)], dim=-1
            )

        return process_classification_prediction(preds)

    def on_test_batch_end(
        self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
class MCDropoutSegmentation(MCDropoutClassification):
    """MC-Dropout Model for Segmentation."""

    pred_dir_name = "preds"

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        task: str = "multiclass",
        dropout_layer_names: list[str] = [],
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
        save_preds: bool = False,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Segmentation.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            task: classification task, one of ['binary', 'multiclass', 'multilabel']
            dropout_layer_names: names of dropout layers to activate during
                prediction
            freeze_backbone: whether to freeze the model backbone, by default
                this is supported for torchseg Unet models
            freeze_decoder: whether to freeze the model decoder, by default
                this is supported for torchseg Unet models
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
            save_preds: whether to save predictions
        """
        # Both flags are read by freeze_model and are therefore assigned
        # before the parent initializer runs.
        self.freeze_backbone = freeze_backbone
        self.freeze_decoder = freeze_decoder
        super().__init__(
            model,
            num_mc_samples,
            loss_fn,
            task,
            dropout_layer_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        self.save_preds = save_preds

    def setup_task(self) -> None:
        """Set up task specific attributes for segmentation."""
        self.train_metrics = default_segmentation_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_segmentation_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_segmentation_metrics(
            "test", self.task, self.num_classes
        )

    def freeze_model(self) -> None:
        """Freeze the model backbone and/or decoder.

        Delegates to ``freeze_segmentation_model`` with the
        ``freeze_backbone`` and ``freeze_decoder`` flags.
        """
        freeze_segmentation_model(self.model, self.freeze_backbone, self.freeze_decoder)

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict steps via Monte Carlo Sampling.

        Args:
            X: prediction batch of shape [batch_size x num_channels x height x width]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            the dictionary produced by ``process_segmentation_prediction``
            from the stacked MC samples
        """
        self.activate_dropout()  # activate dropout during prediction
        with torch.no_grad():
            # samples are stacked along a new trailing dimension
            preds = torch.stack(
                [self.model(X) for _ in range(self.hparams.num_mc_samples)], dim=-1
            )

        return process_segmentation_prediction(preds)

    def on_test_start(self) -> None:
        """Set the prediction directory and create it if predictions are saved."""
        self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name)
        if self.save_preds:
            # exist_ok avoids the check-then-create race when several
            # processes reach this hook at the same time.
            os.makedirs(self.pred_dir, exist_ok=True)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        if self.save_preds:
            save_image_predictions(outputs, batch_idx, self.pred_dir)
class MCDropoutPxRegression(MCDropoutRegression):
    """MC-Dropout Model for Pixel-wise Regression.

    .. versionadded:: 0.2.0
    """

    pred_dir_name = "preds"

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        burnin_epochs: int = 0,
        dropout_layer_names: list[str] = [],
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
        save_preds: bool = False,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Pixel-wise Regression.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            burnin_epochs: number of burnin epochs before using the loss_fn
            dropout_layer_names: names of dropout layers to activate during
                prediction
            freeze_backbone: freeze backbone during training
            freeze_decoder: freeze decoder during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
            save_preds: whether to save predictions
        """
        # freeze_decoder is read by freeze_model and is therefore assigned
        # before the parent initializer runs.
        self.freeze_decoder = freeze_decoder
        super().__init__(
            model,
            num_mc_samples,
            loss_fn,
            burnin_epochs,
            dropout_layer_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        self.save_preds = save_preds

    def freeze_model(self) -> None:
        """Freeze the model backbone and/or decoder.

        Delegates to ``freeze_segmentation_model`` with the
        ``freeze_backbone`` and ``freeze_decoder`` flags.
        """
        freeze_segmentation_model(self.model, self.freeze_backbone, self.freeze_decoder)

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_px_regression_metrics("train")
        self.val_metrics = default_px_regression_metrics("val")
        self.test_metrics = default_px_regression_metrics("test")

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Keep only the first channel of the output for metric computation.

        Args:
            out: model output with at most two entries along dimension 1

        Returns:
            contiguous slice ``out[:, 0:1, ...]``
        """
        assert out.shape[1] <= 2, "Ony support single mean or Gaussian output."
        return out[:, 0:1, ...].contiguous()

    def on_test_start(self) -> None:
        """Set the prediction directory and create it if predictions are saved."""
        self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name)
        if self.save_preds:
            # exist_ok avoids the check-then-create race when several
            # processes reach this hook at the same time.
            os.makedirs(self.pred_dir, exist_ok=True)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        if self.save_preds:
            save_image_predictions(outputs, batch_idx, self.pred_dir)