Source code for lightning_uq_box.uq_methods.swag

#    Copyright 2021 GlaxoSmithKline

#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at

#        http://www.apache.org/licenses/LICENSE-2.0

#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

# Changes include:
# - integrating the functions into pytorch lightning Lightning Module framework
# - enable selections of stochastic modules

# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.

"""Stochastic Weight Averaging - Gaussian.

Adapted from https://github.com/GSK-AI/afterglow/blob/master/afterglow/trackers/trackers.py (Apache License 2.0) # noqa: E501
for support of partial stochasticity and integration to lightning.
"""

import math
import os
from collections import OrderedDict
from copy import deepcopy
from typing import Any

import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions import Normal

from .base import DeterministicModel
from .utils import (
    _get_num_outputs,
    default_classification_metrics,
    default_px_regression_metrics,
    default_regression_metrics,
    default_segmentation_metrics,
    map_stochastic_modules,
    process_classification_prediction,
    process_regression_prediction,
    process_segmentation_prediction,
    save_classification_predictions,
    save_image_predictions,
    save_regression_predictions,
)


[docs] class SWAGBase(DeterministicModel): """Stochastic Weight Averaging - Gaussian (SWAG). If you use this model in your research, please cite the following paper: * https://proceedings.neurips.cc/paper_files/paper/2019/hash/118921efba23fc329e6560b27861f0c2-Abstract.html # noqa: E501 """
[docs] def __init__( self, model: nn.Module, max_swag_snapshots: int, snapshot_freq: int, num_mc_samples: int, swag_lr: float, loss_fn: nn.Module, stochastic_module_names: list[int | str] | None = None, ) -> None: """Initialize a new instance of SWAG Model Wrapper. Args: model: pytorch model max_swag_snapshots: maximum number of snapshots to store snapshot_freq: frequency of snapshots num_mc_samples: number of MC samples during prediction swag_lr: learning rate for swag loss_fn: loss function stochastic_module_names: list of module names or indices that should be converted to variational layer """ super().__init__(model, loss_fn, None, None) self.stochastic_module_names = map_stochastic_modules( self.model, stochastic_module_names ) self.swag_fitted = False self.current_iteration = 0 self.num_tracked = 0 self.model_w_and_b_module_names = self._find_weights_and_bias_modules( self.model ) self.max_swag_snapshots = max_swag_snapshots self._create_swag_buffers(self.model) # manual optimization with SWAG optimization process self.automatic_optimization = False
[docs] def setup_task(self) -> None: """Set up task specific attributes.""" pass
[docs] def training_step( self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute SWAG optimization step. Args: batch: the output of your DataLoader batch_idx: the index of this batch dataloader_idx: the index of the dataloader """ swag_opt = self.optimizers() swag_opt.zero_grad() if self.trainer.global_step % self.hparams.snapshot_freq == 0: self.update_uncertainty_buffers() loss = self.loss_fn(self.model(batch[self.input_key]), batch[self.target_key]) self.manual_backward(loss) swag_opt.step()
[docs] def on_train_epoch_end(self): """Do not Log epoch-level training metrics.""" pass
[docs] def on_train_end(self) -> None: """After training stage is completed, swag is fitted.""" self.swag_fitted = True
[docs] def validation_step(self, *args: Any, **kwargs: Any) -> Tensor: """Not intended to be used.""" pass
[docs] def on_validation_epoch_end(self) -> None: """Do not log any validation metrics.""" pass
def _find_weights_and_bias_modules(self, instance: nn.Module) -> list[str]: """Find weights and bias modules corresponding to part stochastic modules.""" model_w_and_b_module_names: list[str] = [] for name, _ in instance.named_parameters(): if ( name.removesuffix(".weight").removesuffix(".bias") in self.stochastic_module_names ): # noqa: E501 model_w_and_b_module_names.append(name) return model_w_and_b_module_names def _create_swag_buffers(self, instance: nn.Module) -> None: """Create swawg buffers for an underlying module. Args: instance: underlying model instance for which to create buffers """ for name, parameter in instance.named_parameters(): # check for partial stochasticity modules if name in self.model_w_and_b_module_names: name = name.replace(".", "_") instance.register_buffer(f"{name}_mean", deepcopy(parameter)) instance.register_buffer( f"{name}_squared_mean", torch.zeros_like(parameter) ) instance.register_buffer( f"{name}_D_block", torch.zeros( (self.max_swag_snapshots, *parameter.shape), device=parameter.device, ), ) else: continue instance.register_buffer("num_snapshots_tracked", torch.tensor(0, dtype=int)) def _get_buffer_for_param(self, param_name: str, buffer_name: str): """Get buffer for parameter name. Args: param_name: parameter name buffer_name: buffer_name """ safe_name = param_name.replace(".", "_") # TODO be able to access and retrieve nested # param names in custom models return getattr(self.model, f"{safe_name}_{buffer_name}") def _set_buffer_for_param(self, param_name, buffer_name, value): safe_name = param_name.replace(".", "_") setattr(self.model, f"{safe_name}_{buffer_name}", value) def _update_tracked_state_dict(self, state_dict: dict[str, nn.Parameter]) -> None: """Update tracked state_dict. Args: state_dict: model state_dict Returns: state_dict """ full_state_dict = OrderedDict({**state_dict, **self._untracked_state_dict()}) full_state_dict._metadata = getattr(self.model.state_dict(), "_metadata", None) self.model.load_state_dict(full_state_dict) def _untracked_state_dict(self) -> dict[str, nn.Parameter]: """Return filtered untracked state dict.""" filtered_state_dict = {} for k, v in self.model.state_dict().items(): if k not in self.model_w_and_b_module_names: filtered_state_dict[k] = v return filtered_state_dict def _sample_state_dict(self) -> dict: """Sample the underlying model state dict.""" if self.num_tracked == 0: raise RuntimeError( "Attempted to sample weights using a tracker that has " "recorded no snapshots" ) sampled = {} # find first param for name, param in self.model.named_parameters(): if name in self.model_w_and_b_module_names: K_sample = ( Normal( torch.zeros(self.hparams.max_swag_snapshots), torch.ones(self.hparams.max_swag_snapshots), ) .sample() .to(param.device) # should have lightning device ) break else: continue for name in self.model_w_and_b_module_names: mean = self._get_buffer_for_param(name, "mean") squared_mean = self._get_buffer_for_param(name, "squared_mean") d_block = self._get_buffer_for_param(name, "D_block") p1 = mean p2 = Normal( torch.zeros_like(mean), (0.5 * (squared_mean - mean.pow(2)).clamp(1e-30)).sqrt(), ).sample() shape = d_block.shape[1:] aux = d_block.reshape(self.hparams.max_swag_snapshots, -1) p3 = torch.matmul(K_sample, aux).reshape(shape) / math.sqrt( 2 * (self.hparams.max_swag_snapshots - 1) ) sampled[name] = p1 + p2 + p3 return sampled
[docs] def update_uncertainty_buffers(self): """Update the running average over weights.""" if self.num_tracked == 0: with torch.no_grad(): for name, parameter in self.model.named_parameters(): if name in self.model_w_and_b_module_names: mean = self._get_buffer_for_param(name, "mean") squared_mean = self._get_buffer_for_param(name, "squared_mean") self._set_buffer_for_param(name, "mean", mean + parameter) self._set_buffer_for_param( name, "squared_mean", squared_mean + parameter.pow(2) ) else: continue else: with torch.no_grad(): for name, parameter in self.model.named_parameters(): if name in self.model_w_and_b_module_names: mean = self._get_buffer_for_param(name, "mean") squared_mean = self._get_buffer_for_param(name, "squared_mean") d_block = self._get_buffer_for_param(name, "D_block") self._set_buffer_for_param( name, "mean", (self.num_tracked * mean + parameter) / (self.num_tracked + 1), ) self._set_buffer_for_param( name, "squared_mean", (self.num_tracked * squared_mean + parameter.pow(2)) / (self.num_tracked + 1), ) d_block = d_block.roll(1, dims=0) d_block[0] = parameter - mean self._set_buffer_for_param(name, "D_block", d_block) else: continue self.num_tracked += 1
[docs] def sample_state(self): """Update the state with a sample.""" sampled_state_dict = self._sample_state_dict() self._update_tracked_state_dict(sampled_state_dict)
[docs] def sample_predictions(self, X: Tensor) -> Tensor: """Sample predictions. Args: X: input batch of shape [batch_size x input_dims] Returns: predictions of shape [batch_size x num_outputs x num_mc_samples] """ preds = [] for i in range(self.hparams.num_mc_samples): # sample weights self.sample_state() with torch.no_grad(): pred = self.model(X) preds.append(pred) preds = torch.stack(preds, dim=-1) return preds
[docs] def configure_optimizers(self) -> dict[str, Any]: """Manually implemented SWAG optimization.""" swag_params: list[nn.Parameter] = [ param for name, param in self.model.named_parameters() if name in self.model_w_and_b_module_names ] swag_optimizer = torch.optim.SGD(swag_params, lr=self.hparams.swag_lr) return swag_optimizer
[docs] class SWAGRegression(SWAGBase): """SWAG Model for Regression. If you use this model in your research, please cite the following paper: * https://proceedings.neurips.cc/paper_files/paper/2019/hash/118921efba23fc329e6560b27861f0c2-Abstract.html """ # noqa: E501 pred_file_name = "preds.csv"
[docs] def __init__( self, model: nn.Module, max_swag_snapshots: int, snapshot_freq: int, num_mc_samples: int, swag_lr: float, loss_fn: nn.Module, stochastic_module_names: list[int] | list[str] | None = None, ) -> None: """Initialize a new instance of SWAG Model for Regression. Args: model: pytorch model num_swag_epochs: number of epochs to train swag max_swag_snapshots: maximum number of snapshots to store snapshot_freq: frequency of snapshots num_mc_samples: number of MC samples during prediction swag_lr: learning rate for swag loss_fn: loss function stochastic_module_names: names of modules that are partially stochastic """ super().__init__( model, max_swag_snapshots, snapshot_freq, num_mc_samples, swag_lr, loss_fn, stochastic_module_names, ) self.save_hyperparameters(ignore=["model", "loss_fn"])
[docs] def setup_task(self) -> None: """Set up task specific attributes.""" self.test_metrics = default_regression_metrics("test")
[docs] def on_test_batch_end( self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0 ) -> None: """Test batch end save predictions. Args: outputs: dictionary of model outputs and aux variables batch_idx: batch index dataloader_idx: dataloader index """ save_regression_predictions( outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name) )
[docs] def predict_step( self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0 ) -> dict[str, Tensor]: """Prediction step that with SWAG uncertainty. Args: X: prediction batch of shape [batch_size x input_dims] batch_idx: batch index dataloader_idx: dataloader index Returns: prediction dictionary """ if not self.swag_fitted: raise RuntimeError( "SWAG is not fitted yet, please call trainer.fit() first." ) preds = self.sample_predictions(X) return process_regression_prediction(preds)
[docs] class SWAGClassification(SWAGBase): """SWAG Model for Classification. If you use this model in your research, please cite the following paper: * https://proceedings.neurips.cc/paper_files/paper/2019/hash/118921efba23fc329e6560b27861f0c2-Abstract.html """ # noqa: E501 pred_file_name = "preds.csv" valid_tasks = ["binary", "multiclass", "multilable"]
[docs] def __init__( self, model: nn.Module, max_swag_snapshots: int, snapshot_freq: int, num_mc_samples: int, swag_lr: float, loss_fn: nn.Module, task: str = "multiclass", stochastic_module_names: list[int] | list[str] | None = None, ) -> None: """Initialize a new instance of SWAG Model for Classification. Args: model: pytorch model num_swag_epochs: number of epochs to train swag max_swag_snapshots: maximum number of snapshots to store snapshot_freq: frequency of snapshots num_mc_samples: number of MC samples during prediction swag_lr: learning rate for swag loss_fn: loss function task: classification task, one of ['binary', 'multiclass', 'multilabel'] stochastic_module_names: names of modules that are partially stochastic """ assert task in self.valid_tasks self.task = task self.num_classes = _get_num_outputs(model) super().__init__( model, max_swag_snapshots, snapshot_freq, num_mc_samples, swag_lr, loss_fn, stochastic_module_names, ) self.save_hyperparameters(ignore=["model", "loss_fn"])
[docs] def adapt_output_for_metrics(self, out: Tensor) -> Tensor: """Adapt model output to be compatible for metric computation. Args: out: output from the model Returns: mean output """ return out
[docs] def setup_task(self) -> None: """Set up task specific attributes.""" self.test_metrics = default_classification_metrics( "test", self.task, self.num_classes )
[docs] def predict_step( self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0 ) -> dict[str, Tensor]: """Prediction step with SWAG uncertainty. Args: X: prediction batch of shape [batch_size x input_dims] batch_idx: the index of this batch dataloader_idx: the index of the dataloader Returns: prediction dictionary """ if not self.swag_fitted: raise RuntimeError( "SWAG is not fitted yet, please call trainer.fit() first." ) preds = self.sample_predictions(X) return process_classification_prediction(preds)
[docs] def on_test_batch_end( self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0 ) -> None: """Test batch end save predictions. Args: outputs: dictionary of model outputs and aux variables batch_idx: batch index dataloader_idx: dataloader index """ save_classification_predictions( outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name) )
[docs] class SWAGSegmentation(SWAGClassification): """SWAG Model for Segmentation.""" pred_dir_name = "preds"
[docs] def __init__( self, model: nn.Module, max_swag_snapshots: int, snapshot_freq: int, num_mc_samples: int, swag_lr: float, loss_fn: nn.Module, task: str = "multiclass", stochastic_module_names: list[int] | list[str] | None = None, save_preds: bool = False, ) -> None: """Initialize a new instance of SWAG Model for Segmentation. Args: model: pytorch model num_swag_epochs: number of epochs to train swag max_swag_snapshots: maximum number of snapshots to store snapshot_freq: frequency of snapshots num_mc_samples: number of MC samples during prediction swag_lr: learning rate for swag loss_fn: loss function task: segmentation task, one of ['binary', 'multiclass'] stochastic_module_names: names of modules that are partially stochastic save_preds: save predictions """ super().__init__( model, max_swag_snapshots, snapshot_freq, num_mc_samples, swag_lr, loss_fn, task, stochastic_module_names, ) self.save_preds = save_preds
[docs] def setup_task(self) -> None: """Set up task specific attributes.""" self.test_metrics = default_segmentation_metrics( "test", self.task, self.num_classes )
[docs] def predict_step( self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0 ) -> dict[str, Tensor]: """Prediction step with SWAG uncertainty. Args: X: prediction batch of shape [batch_size x num_channels x height x width] batch_idx: batch index dataloader_idx: dataloader index Returns: prediction dictionary """ if not self.swag_fitted: raise RuntimeError( "SWAG is not fitted yet, please call trainer.fit() first." ) preds = self.sample_predictions(X) return process_segmentation_prediction(preds)
[docs] def on_test_start(self) -> None: """Create logging directory and initialize metrics.""" self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name) if not os.path.exists(self.pred_dir) and self.save_preds: os.makedirs(self.pred_dir)
[docs] def on_test_batch_end( self, outputs: dict[str, Tensor], batch: Any, batch_idx: int, dataloader_idx: int = 0, ) -> None: """Test batch end save predictions. Args: outputs: dictionary of model outputs and aux variables batch: batch from dataloader batch_idx: batch index dataloader_idx: dataloader index """ if self.save_preds: save_image_predictions(outputs, batch_idx, self.pred_dir)
[docs] class SWAGPxRegression(SWAGRegression): """SWAG Model for Pixelwise Regression. .. versionadded:: 0.2.0 """ pred_dir_name = "preds"
[docs] def __init__( self, model: nn.Module, max_swag_snapshots: int, snapshot_freq: int, num_mc_samples: int, swag_lr: float, loss_fn: nn.Module, stochastic_module_names: list[int] | list[str] | None = None, save_preds: bool = False, ) -> None: """Initialize a new instance of SWAG Model for Pixelwise Regression. Args: model: pytorch model num_swag_epochs: number of epochs to train swag max_swag_snapshots: maximum number of snapshots to store snapshot_freq: frequency of snapshots num_mc_samples: number of MC samples during prediction swag_lr: learning rate for swag loss_fn: loss function stochastic_module_names: names of modules that are partially stochastic save_preds: save predictions """ super().__init__( model, max_swag_snapshots, snapshot_freq, num_mc_samples, swag_lr, loss_fn, stochastic_module_names, ) self.save_preds = save_preds
[docs] def setup_task(self) -> None: """Set up task specific attributes.""" self.test_metrics = default_px_regression_metrics("test")
[docs] def on_test_start(self) -> None: """Create logging directory and initialize metrics.""" self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name) if not os.path.exists(self.pred_dir) and self.save_preds: os.makedirs(self.pred_dir)
[docs] def on_test_batch_end( self, outputs: dict[str, Tensor], batch: Any, batch_idx: int, dataloader_idx: int = 0, ) -> None: """Test batch end save predictions. Args: outputs: dictionary of model outputs and aux variables batch: batch from dataloader batch_idx: batch index dataloader_idx: dataloader index """ if self.save_preds: save_image_predictions(outputs, batch_idx, self.pred_dir)