Source code for lightning_uq_box.uq_methods.deep_ensemble

# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.

"""Implement a Deep Ensemble Model for prediction."""

import os
from typing import Any

import torch
from lightning import LightningModule
from torch import Tensor

from .base import BaseModule
from .utils import (
    default_classification_metrics,
    default_px_regression_metrics,
    default_regression_metrics,
    default_segmentation_metrics,
    process_classification_prediction,
    process_regression_prediction,
    process_segmentation_prediction,
    save_classification_predictions,
    save_image_predictions,
    save_regression_predictions,
)


class DeepEnsemble(BaseModule):
    """Base Class for different Ensemble Models.

    Every ensemble member is described by a dict that holds a model under the
    key ``"base_model"`` and a checkpoint path under the key ``"ckpt_path"``.
    Subclasses override :meth:`setup_task` to configure their test metrics.

    If you use this model in your work, please cite:

    * https://proceedings.neurips.cc/paper_files/paper/2017/hash/9ef2ed4b7fd2c810847ffa5fa85bce38-Abstract.html
    """  # noqa: E501

    def __init__(
        self, ensemble_members: list[dict[str, type[LightningModule] | str]]
    ) -> None:
        """Initialize a new instance of DeepEnsembleModel Wrapper.

        Args:
            ensemble_members: List of dicts where each element specifies the
                model under the key ``"base_model"`` and a path to a
                checkpoint under the key ``"ckpt_path"``
        """
        super().__init__()
        self.n_ensemble_members = len(ensemble_members)
        self.ensemble_members = ensemble_members
        self.setup_task()

    def setup_task(self) -> None:
        """Set up task.

        Called at the end of ``__init__``; the base implementation does
        nothing and subclasses override it.
        """
        pass

    def forward(self, X: Tensor) -> Tensor:
        """Forward step of Deep Ensemble.

        Args:
            X: input tensor of shape [batch_size, input_dims]

        Returns:
            Ensemble member outputs stacked over last dimension for output of
            [batch_size, num_outputs, num_ensemble_members]
        """
        out: list[Tensor] = []
        for model_config in self.ensemble_members:
            model = model_config["base_model"]
            # Load onto the CPU first so a checkpoint that was written on an
            # accelerator can still be restored on a machine without one;
            # ``load_state_dict`` copies the values into the existing
            # parameters wherever they live.
            # NOTE(review): the checkpoint is read again on every call. This
            # is kept because several members may share one ``base_model``
            # object with different checkpoints - confirm before caching.
            checkpoint = torch.load(
                model_config["ckpt_path"], map_location="cpu", weights_only=True
            )
            model.load_state_dict(checkpoint["state_dict"])
            model.to(X.device).eval()
            out.append(model(X))
        return torch.stack(out, dim=-1)

    def test_step(
        self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Any]:
        """Compute test step for deep ensemble and update test metrics.

        Args:
            batch: dictionary holding the model input under ``self.input_key``
                and the target under ``self.target_key``
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            dictionary of uncertainty outputs
        """
        out_dict = self.predict_step(batch[self.input_key])
        out_dict[self.target_key] = batch[self.target_key].detach().squeeze(-1).cpu()

        # metrics are only updated for batches with more than one sample
        if batch[self.input_key].shape[0] > 1:
            self.test_metrics(out_dict["pred"], batch[self.target_key])

        # detach the prediction and move it to the CPU
        out_dict["pred"] = out_dict["pred"].detach().cpu().squeeze(-1)

        # save metadata
        out_dict = self.add_aux_data_to_dict(out_dict, batch)

        return out_dict

    def generate_ensemble_predictions(self, X: Tensor) -> Tensor:
        """Generate DeepEnsemble Predictions.

        Args:
            X: input tensor of shape [batch_size, input_dims]

        Returns:
            the ensemble predictions
        """
        return self.forward(X)  # [batch_size, num_outputs, num_ensemble_members]
class DeepEnsembleRegression(DeepEnsemble):
    """Deep Ensemble Model for regression.

    If you use this model in your work, please cite:

    * https://proceedings.neurips.cc/paper_files/paper/2017/hash/9ef2ed4b7fd2c810847ffa5fa85bce38-Abstract.html
    """  # noqa: E501

    # file name, relative to the trainer's default root dir, for predictions
    pred_file_name = "preds.csv"

    def setup_task(self) -> None:
        """Set up task for regression."""
        self.test_metrics = default_regression_metrics("test")

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> Any:
        """Compute prediction step for a deep ensemble.

        Args:
            X: input tensor of shape [batch_size, input_dims]
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            the dictionary produced by ``process_regression_prediction`` with
            the raw ensemble member predictions added under ``"samples"``
        """
        with torch.no_grad():
            preds = self.generate_ensemble_predictions(X)

        pred_dict = process_regression_prediction(preds)
        pred_dict["samples"] = preds
        return pred_dict

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any = None,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Lightning invokes this hook positionally with ``outputs``, ``batch``,
        ``batch_idx`` and, for several test dataloaders, ``dataloader_idx``.
        Only ``outputs`` is used here; the other arguments carry defaults so
        calls written against the former three-argument form keep working.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader (unused)
            batch_idx: batch index (unused)
            dataloader_idx: dataloader index (unused)
        """
        save_regression_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
class DeepEnsembleClassification(DeepEnsemble):
    """Deep Ensemble Model for classification.

    If you use this model in your work, please cite:

    * https://proceedings.neurips.cc/paper_files/paper/2017/hash/9ef2ed4b7fd2c810847ffa5fa85bce38-Abstract.html
    """  # noqa: E501

    # task names accepted by ``__init__``
    valid_tasks = ["multiclass", "binary", "multilabel"]
    # file name, relative to the trainer's default root dir, for predictions
    pred_file_name = "preds.csv"

    def __init__(
        self,
        ensemble_members: list[dict[str, type[LightningModule] | str]],
        num_classes: int,
        task: str = "multiclass",
    ) -> None:
        """Initialize a new instance of DeepEnsemble for Classification.

        Args:
            ensemble_members: List of dicts where each element specifies the
                LightningModule class and a path to a checkpoint
            num_classes: number of classes
            task: classification task, one of "multiclass",
                "binary" or "multilabel"

        Raises:
            AssertionError: if ``task`` is not one of ``valid_tasks``
        """
        assert (
            task in self.valid_tasks
        ), f"task must be one of {self.valid_tasks}, got {task!r}"
        # Both attributes have to exist before the parent initializer runs,
        # because it calls ``setup_task`` which reads them.
        self.task = task
        self.num_classes = num_classes
        super().__init__(ensemble_members)

    def setup_task(self) -> None:
        """Set up task for classification."""
        self.test_metrics = default_classification_metrics(
            "test", self.task, self.num_classes
        )

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> Any:
        """Compute prediction step for a deep ensemble.

        Args:
            X: input tensor of shape [batch_size, input_dims]
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            the output of ``process_classification_prediction`` applied to
            the ensemble member predictions
        """
        with torch.no_grad():
            preds = self.generate_ensemble_predictions(X)

        return process_classification_prediction(preds)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any = None,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
    ) -> None:
        """Test batch end save predictions.

        Lightning invokes this hook positionally with ``outputs``, ``batch``,
        ``batch_idx`` and, for several test dataloaders, ``dataloader_idx``.
        Only ``outputs`` is used here; the other arguments carry defaults so
        calls written against the former three-argument form keep working.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader (unused)
            batch_idx: batch index (unused)
            dataloader_idx: dataloader index (unused)
        """
        save_classification_predictions(
            outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
        )
class DeepEnsembleSegmentation(DeepEnsembleClassification):
    """Deep Ensemble Model for segmentation.

    If you use this model in your work, please cite:

    * https://proceedings.neurips.cc/paper_files/paper/2017/hash/9ef2ed4b7fd2c810847ffa5fa85bce38-Abstract.html
    """  # noqa: E501

    # sub-directory of the trainer's default root dir used for predictions
    pred_dir_name = "preds"

    def __init__(
        self,
        ensemble_members: list[dict[str, type[LightningModule] | str]],
        num_classes: int,
        task: str = "multiclass",
        save_preds: bool = False,
    ) -> None:
        """Initialize a new instance of DeepEnsemble for Segmentation.

        Args:
            ensemble_members: List of dicts where each element specifies the
                LightningModule class and a path to a checkpoint
            num_classes: number of classes
            task: classification task, one of "multiclass",
                "binary" or "multilabel"
            save_preds: whether to save predictions
        """
        super().__init__(ensemble_members, num_classes, task)
        self.save_preds = save_preds

    def setup_task(self) -> None:
        """Set up the test metrics for segmentation."""
        segmentation_metrics = default_segmentation_metrics(
            "test", self.task, self.num_classes
        )
        self.test_metrics = segmentation_metrics

    def predict_step(
        self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> Any:
        """Compute prediction step for a deep ensemble.

        Args:
            X: input tensor passed to every ensemble member
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader

        Returns:
            the output of ``process_segmentation_prediction`` applied to
            the ensemble member predictions
        """
        with torch.no_grad():
            member_preds = self.generate_ensemble_predictions(X)

        return process_segmentation_prediction(member_preds)

    def on_test_start(self) -> None:
        """Set the prediction directory and create it when saving is enabled."""
        root_dir = self.trainer.default_root_dir
        self.pred_dir = os.path.join(root_dir, self.pred_dir_name)

        if self.save_preds and not os.path.exists(self.pred_dir):
            os.makedirs(self.pred_dir)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Save the predictions of a test batch if saving is enabled.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        if not self.save_preds:
            return

        save_image_predictions(outputs, batch_idx, self.pred_dir)
class DeepEnsemblePxRegression(DeepEnsembleRegression):
    """Deep Ensemble Model for pixelwise regression.

    If you use this model in your work, please cite:

    * https://proceedings.neurips.cc/paper_files/paper/2017/hash/9ef2ed4b7fd2c810847ffa5fa85bce38-Abstract.html

    .. versionadded:: 0.2.0
    """  # noqa: E501

    # sub-directory of the trainer's default root dir used for predictions
    pred_dir_name = "preds"

    def __init__(
        self,
        ensemble_members: list[dict[str, type[LightningModule] | str]],
        save_preds: bool = False,
    ) -> None:
        """Initialize a new instance of DeepEnsemble for Pixelwise Regression.

        Args:
            ensemble_members: List of dicts where each element specifies the
                LightningModule class and a path to a checkpoint
            save_preds: whether to save predictions
        """
        super().__init__(ensemble_members)
        self.save_preds = save_preds

    def setup_task(self) -> None:
        """Set up the test metrics for pixelwise regression."""
        px_metrics = default_px_regression_metrics("test")
        self.test_metrics = px_metrics

    def on_test_start(self) -> None:
        """Set the prediction directory and create it when saving is enabled."""
        root_dir = self.trainer.default_root_dir
        self.pred_dir = os.path.join(root_dir, self.pred_dir_name)

        if self.save_preds and not os.path.exists(self.pred_dir):
            os.makedirs(self.pred_dir)

    def on_test_batch_end(
        self,
        outputs: dict[str, Tensor],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Save the predictions of a test batch if saving is enabled.

        Args:
            outputs: dictionary of model outputs and aux variables
            batch: batch from dataloader
            batch_idx: the index of this batch
            dataloader_idx: the index of the dataloader
        """
        if not self.save_preds:
            return

        save_image_predictions(outputs, batch_idx, self.pred_dir)