# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
"""Mc-Dropout module."""
import os
from typing import Any
import torch
import torch.nn as nn
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import Tensor
from .base import DeterministicModel
from .utils import (
_get_num_outputs,
default_classification_metrics,
default_px_regression_metrics,
default_regression_metrics,
default_segmentation_metrics,
freeze_model_backbone,
freeze_segmentation_model,
process_classification_prediction,
process_regression_prediction,
process_segmentation_prediction,
save_classification_predictions,
save_image_predictions,
save_regression_predictions,
)
def find_dropout_layers(model: nn.Module) -> list[str]:
    """Collect the qualified names of all ``nn.Dropout`` modules in a model.

    Args:
        model: module whose submodule tree (including the module itself) is
            scanned via ``named_modules``

    Returns:
        names, in ``named_modules`` traversal order, of every submodule that
        is an instance of ``nn.Dropout``; an empty list if there is none
    """
    # An empty result is returned as-is rather than reported here; dropout
    # that is not applied through an ``nn.Dropout`` module cannot be detected
    # by this scan.
    return [
        layer_name
        for layer_name, layer in model.named_modules()
        if isinstance(layer, nn.Dropout)
    ]
class MCDropoutBase(DeterministicModel):
    """MC-Dropout Base class.

    Stores the names of the dropout layers that stay stochastic at
    prediction time and provides :meth:`activate_dropout` to switch them on.

    If you use this model in your research, please cite the following paper:

    * https://proceedings.mlr.press/v48/gal16.html
    """

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        dropout_layer_names: list[str] = [],  # only read, never mutated
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new instance of MCDropoutModel.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction; it is not
                stored by this class, the subclasses in this file read it
                from ``self.hparams``
            loss_fn: loss function
            dropout_layer_names: names of dropout layers to activate during
                prediction; if empty, the names of all ``nn.Dropout`` layers
                returned by ``find_dropout_layers`` are used
            freeze_backbone: freeze backbone during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
        """
        super().__init__(model, loss_fn, freeze_backbone, optimizer, lr_scheduler)

        if not dropout_layer_names:
            dropout_layer_names = find_dropout_layers(model)
        self.dropout_layer_names = dropout_layer_names

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        pass

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        out = self.forward(batch[self.input_key])
        loss = self.loss_fn(out, batch[self.target_key])

        self.log(
            "train_loss", loss, batch_size=batch[self.input_key].shape[0]
        )  # logging to Logger
        self.train_metrics(self.adapt_output_for_metrics(out), batch[self.target_key])

        return loss

    def activate_dropout(self) -> None:
        """Activate the selected dropout layers for MC sampling.

        The wrapped model is first put into train mode. Then the module tree
        is walked and

        * every ``nn.Dropout`` whose qualified name is listed in
          ``self.dropout_layer_names`` is kept in train mode (stochastic),
        * every other ``nn.Dropout`` is put into eval mode, so that only the
          requested layers are sampled,
        * ``BatchNorm1d/2d/3d`` layers are put into eval mode.

        All remaining modules stay in train mode.

        Raises:
            UserWarning: if none of the names in ``self.dropout_layer_names``
                refers to an ``nn.Dropout`` module of the model
        """
        dropout_layers_found = []
        selected_names = set(self.dropout_layer_names)
        batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

        self.model.train()

        def activate_dropout_recursive(model: nn.Module, prefix: str = "") -> None:
            for name, module in model.named_children():
                full_name = f"{prefix}.{name}" if prefix else name
                if isinstance(module, nn.Dropout):
                    if full_name in selected_names:
                        module.train()
                        dropout_layers_found.append(full_name)
                    else:
                        # not requested: keep this layer deterministic
                        module.eval()
                elif isinstance(module, batch_norm_types):
                    # set batch norm layers to eval mode; this check has to
                    # come before the generic recursion, otherwise it can
                    # never be reached because every child is an nn.Module
                    module.eval()
                else:
                    activate_dropout_recursive(module, full_name)

        activate_dropout_recursive(self.model)

        if not dropout_layers_found:
            raise UserWarning(
                "No dropout layers found in model, maybe dropout "
                "is implemented via specialized layers?"
            )
class MCDropoutRegression(MCDropoutBase):
    """MC-Dropout Model for Regression.

    If you use this model in your research, please cite the following paper:

    * https://proceedings.mlr.press/v48/gal16.html
    """

    # file name, relative to the trainer's default_root_dir, for test predictions
    pred_file_name = "preds.csv"

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        burnin_epochs: int = 0,
        dropout_layer_names: list[str] = [],  # only read, never mutated
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Regression.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            burnin_epochs: number of burnin epochs before using the loss_fn;
                during burn-in the MSE on the first output column is used
            dropout_layer_names: names of dropout layers to activate during prediction
            freeze_backbone: freeze backbone during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
        """
        super().__init__(
            model,
            num_mc_samples,
            loss_fn,
            dropout_layer_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        # makes num_mc_samples and burnin_epochs available via self.hparams
        self.save_hyperparameters(
            ignore=["model", "loss_fn", "optimizer", "lr_scheduler"]
        )

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_regression_metrics("train")
        self.val_metrics = default_regression_metrics("val")
        self.test_metrics = default_regression_metrics("test")

    def freeze_model(self) -> None:
        """Freeze model backbone.

        Calls ``freeze_model_backbone`` on the model when
        ``self.freeze_backbone`` is set, otherwise does nothing.
        """
        if self.freeze_backbone:
            freeze_model_backbone(self.model)

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Adapt model output to be compatible for metric computation.

        Args:
            out: model output whose last dimension has at most two entries

        Returns:
            the first output column, with the column dimension kept
        """
        assert out.shape[-1] <= 2, "Only support single mean or Gaussian output."
        return out[:, 0:1]

    def training_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            training loss
        """
        out = self.forward(batch[self.input_key])

        if self.current_epoch < self.hparams.burnin_epochs:
            # burn-in phase: fit only the first output column with plain MSE
            loss = nn.functional.mse_loss(
                self.adapt_output_for_metrics(out), batch[self.target_key]
            )
        else:
            loss = self.loss_fn(out, batch[self.target_key])

        self.log(
            "train_loss", loss, batch_size=batch[self.input_key].shape[0]
        )  # logging to Logger
        self.train_metrics(self.adapt_output_for_metrics(out), batch[self.target_key])

        return loss

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict steps via Monte Carlo Sampling.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            result of ``process_regression_prediction`` applied to the
            ``num_mc_samples`` stacked forward passes
        """
        self.activate_dropout()
        with torch.no_grad():
            # the MC samples are stacked along a new trailing dimension
            preds = torch.stack(
                [self.model(X) for _ in range(self.hparams.num_mc_samples)], dim=-1
            )

        return process_regression_prediction(preds)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any = None,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        The parameter list follows the Lightning hook
        ``on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)`` so
        that a call with all four positional arguments is accepted. The
        defaults keep the previous call forms ``(outputs, batch_idx)`` and
        ``(outputs, batch_idx, dataloader_idx)`` working; only ``outputs``
        is used.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader (unused)
            batch_idx: batch index (unused)
            dataloader_idx: dataloader index (unused)
        """
        save_regression_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
class MCDropoutClassification(MCDropoutBase):
    """MC-Dropout Model for Classification.

    If you use this model in your research, please cite the following paper:

    * https://proceedings.mlr.press/v48/gal16.html
    """

    # file name, relative to the trainer's default_root_dir, for test predictions
    pred_file_name = "preds.csv"

    # accepted values of the ``task`` argument
    valid_tasks = ["binary", "multiclass", "multilabel"]

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        task: str = "multiclass",
        dropout_layer_names: list[str] = [],  # only read, never mutated
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Classification.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            task: classification task, one of ['binary', 'multiclass', 'multilabel']
            dropout_layer_names: names of dropout layers to activate during prediction
            freeze_backbone: freeze backbone during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler

        Raises:
            AssertionError: if ``task`` is not one of ``valid_tasks``
        """
        assert (
            task in self.valid_tasks
        ), f"task must be one of {self.valid_tasks}, got {task!r}"
        # task and num_classes are set before the parent initializer runs
        # because setup_task reads them
        self.task = task
        self.num_classes = _get_num_outputs(model)
        super().__init__(
            model,
            num_mc_samples,
            loss_fn,
            dropout_layer_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        # makes num_mc_samples available via self.hparams
        self.save_hyperparameters(
            ignore=["model", "loss_fn", "optimizer", "lr_scheduler"]
        )

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_classification_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_classification_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_classification_metrics(
            "test", self.task, self.num_classes
        )

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Return the model output unchanged for metric computation."""
        return out

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict steps via Monte Carlo Sampling.

        Args:
            X: prediction batch of shape [batch_size x input_dims]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            result of ``process_classification_prediction`` applied to the
            ``num_mc_samples`` stacked forward passes
        """
        self.activate_dropout()  # activate dropout during prediction
        with torch.no_grad():
            # the MC samples are stacked along a new trailing dimension
            preds = torch.stack(
                [self.model(X) for _ in range(self.hparams.num_mc_samples)], dim=-1
            )

        return process_classification_prediction(preds)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any = None,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        The parameter list follows the Lightning hook
        ``on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)`` so
        that a call with all four positional arguments is accepted. The
        defaults keep the previous call forms ``(outputs, batch_idx)`` and
        ``(outputs, batch_idx, dataloader_idx)`` working; only ``outputs``
        is used.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader (unused)
            batch_idx: batch index (unused)
            dataloader_idx: dataloader index (unused)
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
class MCDropoutSegmentation(MCDropoutClassification):
    """MC-Dropout Model for Segmentation."""

    # directory name, relative to the trainer's default_root_dir, for predictions
    pred_dir_name = "preds"

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        task: str = "multiclass",
        dropout_layer_names: list[str] = [],  # only read, never mutated
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
        save_preds: bool = False,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Segmentation.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            task: classification task, one of ['binary', 'multiclass', 'multilabel']
            dropout_layer_names: names of dropout layers to activate during prediction
            freeze_backbone: whether to freeze the model backbone, by default this is
                supported for torchseg Unet models
            freeze_decoder: whether to freeze the model decoder, by default this is
                supported for torchseg Unet models
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
            save_preds: whether to save predictions
        """
        # both flags are read by freeze_model, so they are set before the
        # parent initializer runs
        self.freeze_backbone = freeze_backbone
        self.freeze_decoder = freeze_decoder
        super().__init__(
            model,
            num_mc_samples,
            loss_fn,
            task,
            dropout_layer_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        self.save_preds = save_preds

    def setup_task(self) -> None:
        """Set up task specific attributes for segmentation."""
        self.train_metrics = default_segmentation_metrics(
            "train", self.task, self.num_classes
        )
        self.val_metrics = default_segmentation_metrics(
            "val", self.task, self.num_classes
        )
        self.test_metrics = default_segmentation_metrics(
            "test", self.task, self.num_classes
        )

    def freeze_model(self) -> None:
        """Freeze model backbone and/or decoder.

        Delegates to ``freeze_segmentation_model`` with the
        ``freeze_backbone`` and ``freeze_decoder`` flags of this instance.
        """
        freeze_segmentation_model(self.model, self.freeze_backbone, self.freeze_decoder)

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        """Predict steps via Monte Carlo Sampling.

        Args:
            X: prediction batch of shape [batch_size x num_channels x height x width]
            batch_idx: batch index
            dataloader_idx: dataloader index

        Returns:
            result of ``process_segmentation_prediction`` applied to the
            ``num_mc_samples`` stacked forward passes
        """
        self.activate_dropout()  # activate dropout during prediction
        with torch.no_grad():
            # the MC samples are stacked along a new trailing dimension
            preds = torch.stack(
                [self.model(X) for _ in range(self.hparams.num_mc_samples)], dim=-1
            )

        return process_segmentation_prediction(preds)

    def on_test_start(self) -> None:
        """Set the prediction directory and create it if predictions are saved."""
        self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name)
        if self.save_preds:
            # exist_ok avoids the check-then-create race when several
            # processes reach this hook at the same time
            os.makedirs(self.pred_dir, exist_ok=True)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Predictions are only written when ``save_preds`` is enabled.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        if self.save_preds:
            save_image_predictions(outputs, batch_idx, self.pred_dir)
class MCDropoutPxRegression(MCDropoutRegression):
    """MC-Dropout Model for Pixel-wise Regression.

    .. versionadded:: 0.2.0
    """

    # directory name, relative to the trainer's default_root_dir, for predictions
    pred_dir_name = "preds"

    def __init__(
        self,
        model: nn.Module,
        num_mc_samples: int,
        loss_fn: nn.Module,
        burnin_epochs: int = 0,
        dropout_layer_names: list[str] = [],  # only read, never mutated
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
        optimizer: OptimizerCallable = torch.optim.Adam,
        lr_scheduler: LRSchedulerCallable = None,
        save_preds: bool = False,
    ) -> None:
        """Initialize a new instance of MC-Dropout Model for Pixel-wise Regression.

        Args:
            model: pytorch model with dropout layers
            num_mc_samples: number of MC samples during prediction
            loss_fn: loss function
            burnin_epochs: number of burnin epochs before using the loss_fn
            dropout_layer_names: names of dropout layers to activate during prediction
            freeze_backbone: freeze backbone during training
            freeze_decoder: freeze decoder during training
            optimizer: optimizer used for training
            lr_scheduler: learning rate scheduler
            save_preds: whether to save predictions
        """
        # both flags are read by freeze_model; as in MCDropoutSegmentation
        # they are set before the parent initializer runs
        self.freeze_backbone = freeze_backbone
        self.freeze_decoder = freeze_decoder
        super().__init__(
            model,
            num_mc_samples,
            loss_fn,
            burnin_epochs,
            dropout_layer_names,
            freeze_backbone,
            optimizer,
            lr_scheduler,
        )
        self.save_preds = save_preds

    def freeze_model(self) -> None:
        """Freeze model backbone and/or decoder.

        Delegates to ``freeze_segmentation_model`` with the
        ``freeze_backbone`` and ``freeze_decoder`` flags of this instance.
        """
        freeze_segmentation_model(self.model, self.freeze_backbone, self.freeze_decoder)

    def setup_task(self) -> None:
        """Set up task specific attributes."""
        self.train_metrics = default_px_regression_metrics("train")
        self.val_metrics = default_px_regression_metrics("val")
        self.test_metrics = default_px_regression_metrics("test")

    def adapt_output_for_metrics(self, out: Tensor) -> Tensor:
        """Adapt model output to be compatible for metric computation.

        Args:
            out: model output with at most two entries along dimension 1

        Returns:
            contiguous tensor holding the first entry along dimension 1,
            with that dimension kept
        """
        assert out.shape[1] <= 2, "Only support single mean or Gaussian output."
        return out[:, 0:1, ...].contiguous()

    def on_test_start(self) -> None:
        """Set the prediction directory and create it if predictions are saved."""
        self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name)
        if self.save_preds:
            # exist_ok avoids the check-then-create race when several
            # processes reach this hook at the same time
            os.makedirs(self.pred_dir, exist_ok=True)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Predictions are only written when ``save_preds`` is enabled.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        if self.save_preds:
            save_image_predictions(outputs, batch_idx, self.pred_dir)