# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
# Changes
# - Removed all references to tensorflow
# - adapt to lightning training framework
# - https://arxiv.org/pdf/1905.13077.pdf paper
"""Hierarchical Probabilistic U-Net."""
import os
from collections.abc import Callable
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import Tensor
from lightning_uq_box.uq_methods import BaseModule
from ..models.hierarchical_prob_unet import (
LagrangeMultiplier,
MovingAverage,
_HierarchicalCore,
_StitchingDecoder,
)
from .utils import (
default_segmentation_metrics,
process_segmentation_prediction,
save_image_predictions,
)
[docs]
class HierarchicalProbUNet(BaseModule):
"""Hierarchical Probabilistic U-Net.
If you use this model, please cite the following paper:
- https://arxiv.org/pdf/1905.13077.pdf
"""
valid_loss_types = ["elbo", "geco"]
valid_tasks = ["multiclass", "binary"]
pred_dir_name = "preds"
[docs]
def __init__(
self,
latent_dims: tuple[int, ...] = (1, 1, 1, 1),
channels_per_block: tuple[int, ...] | None = None,
num_in_channels: int = 3,
num_classes: int = 2,
down_channels_per_block: tuple[int, ...] | None = None,
activation_fn: Callable[[Tensor], Tensor] = F.relu,
convs_per_block: int = 3,
blocks_per_level: int = 3,
loss_type: str = "geco",
kappa: int = 10,
beta: float | None = 100,
top_k_percentage: float | None = None,
deterministic_top_k: int | None = None,
num_samples: int = 5,
task: str = "multiclass",
optimizer: OptimizerCallable = torch.optim.Adam,
lr_scheduler: LRSchedulerCallable | None = None,
save_preds: bool = False,
) -> None:
"""Initialize a new HierarchicalProbUNET.
Args:
latent_dims: A tuple of integers specifying the dimensions of the latents
at each scale. The length of the tuple indicates the number of U-Net
decoder scales that have latents
channels_per_block: A tuple of integers specifying the number of output
channels for each block or None. If None, the default values are used
num_in_channels: the number of input channels
num_classes: An integer specifying the number of output classes.
down_channels_per_block: A tuple of integers specifying the number of
intermediate channels for each block or None. If None, the
intermediate channels are chosen equal to channels_per_block
activation_fn: A callable activation function
convs_per_block: An integer specifying the number of convolutional layers
blocks_per_level: An integer specifying the number of residual blocks per
level
model: the hierarchical probabilistic u-net model
loss_type: the type of loss to use, either "elbo" or "geco"
kappa: the kappa parameter for the geco loss
beta: the beta parameter for the elbo loss
top_k_percentage: the percentage of top loss_mask pixels to use for the
loss
deterministic_top_k: An optional percentage for the top-k loss_mask.
num_samples: the number of samples to use during prediction
task: task type, either "multiclass" or "binary"
optimizer: optimizer
lr_scheduler: learning rate scheduler
save_preds: whether to save predictions
"""
super().__init__()
self.latent_dims = latent_dims
self.channels_per_block = channels_per_block
self.num_classes = num_classes
self.num_in_channels = num_in_channels
self.down_channels_per_block = down_channels_per_block
self.activation_fn = activation_fn
self.convs_per_block = convs_per_block
self.blocks_per_level = blocks_per_level
self.save_preds = save_preds
self._build_model()
assert (
loss_type in self.valid_loss_types
), f"Loss type {loss_type} not valid, please choose from {self.valid_loss_types}" # noqa: E501
self.loss_type = loss_type
assert (
task in self.valid_tasks
), f"Task {task} not valid, please choose from {self.valid_tasks}."
self.task = task
# TODO check that beta can only be used with elbo loss
self.beta = beta
# TODO check that kappa can only be used with geco loss
self.kappa = kappa
self.top_k_percentage = top_k_percentage
self.deterministic_top_k = deterministic_top_k
self.num_samples = num_samples
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self._cache = ()
if self.loss_type == "geco":
self._lagmul = LagrangeMultiplier()
self._moving_average = MovingAverage()
self.setup_task()
def _build_model(self) -> None:
"""Build the HierarchicalProbUnet model."""
# these are default arguments from the original code
# TODO need to add input channels
base_channels = 24
default_channels_per_block = (
self.num_in_channels + self.latent_dims[0],
base_channels,
2 * base_channels,
4 * base_channels,
8 * base_channels,
8 * base_channels,
)
if self.channels_per_block is None:
self.channels_per_block = default_channels_per_block
if self.down_channels_per_block is None:
self.all_gatherdown_channels_per_block = tuple(
[i / 2 for i in default_channels_per_block]
)
self._prior = _HierarchicalCore(
latent_dims=self.latent_dims,
channels_per_block=self.channels_per_block,
down_channels_per_block=self.down_channels_per_block,
activation_fn=self.activation_fn,
convs_per_block=self.convs_per_block,
blocks_per_level=self.blocks_per_level,
)
self._posterior = _HierarchicalCore(
latent_dims=self.latent_dims,
channels_per_block=self.channels_per_block,
down_channels_per_block=self.down_channels_per_block,
activation_fn=self.activation_fn,
convs_per_block=self.convs_per_block,
blocks_per_level=self.blocks_per_level,
)
self._f_comb = _StitchingDecoder(
latent_dims=self.latent_dims,
channels_per_block=self.channels_per_block,
num_classes=self.num_classes,
down_channels_per_block=self.down_channels_per_block,
activation_fn=self.activation_fn,
convs_per_block=self.convs_per_block,
blocks_per_level=self.blocks_per_level,
)
[docs]
def setup_task(self) -> None:
"""Set up the task."""
self.train_metrics = default_segmentation_metrics(
prefix="train", num_classes=self.num_classes, task=self.task
)
self.val_metrics = default_segmentation_metrics(
prefix="val", num_classes=self.num_classes, task=self.task
)
self.test_metrics = default_segmentation_metrics(
prefix="test", num_classes=self.num_classes, task=self.task
)
[docs]
def forward(self, X: Tensor) -> Tensor:
"""Forward pass through the model.
Args:
X: the input data
Returns:
the output of the model
"""
return self.model(X)
[docs]
def construct_latent_space(self, img: Tensor, seg_mask: Tensor) -> None:
"""Construct the latent space.
Args:
img: the input image
seg_mask: the segmentation mask
"""
inputs = (seg_mask, img)
if (
len(self._cache) == 2
and torch.all(self._cache[0] == seg_mask)
and torch.all(self._cache[1] == img)
):
return
else:
self._q_sample = self._posterior(
torch.cat([seg_mask, img], dim=1), mean=False
)
self._q_sample_mean = self._posterior(
torch.cat([seg_mask, img], dim=1), mean=True
)
self._p_sample = self._prior(img, mean=False, z_q=None)
self._p_sample_z_q = self._prior(img, z_q=self._q_sample["used_latents"])
self._p_sample_z_q_mean = self._prior(
img, z_q=self._q_sample_mean["used_latents"]
)
self._cache = inputs
return
[docs]
def compute_loss(
self, batch: dict[str, Tensor], loss_mask: Tensor | None = None
) -> Tensor:
"""Compute the loss from the output of the model.
Args:
batch: the batch of data containing image and segmentation mask
loss_mask: Optional tensor that masks some pixels from being
included in the loss
Returns:
the loss
"""
img = batch[self.input_key]
seg_mask = batch[self.target_key]
if len(seg_mask.shape) == 3:
seg_mask_target = seg_mask.long()
seg_mask_target = F.one_hot(seg_mask_target, num_classes=self.num_classes)
seg_mask_target = seg_mask_target.permute(
0, 3, 1, 2
).float() # move class dim to the channel dim
# channel dimension for concatenation
seg_mask = seg_mask.unsqueeze(1)
# forward pass through model that computes prior and posterior
# but does not return anything
self.construct_latent_space(img, seg_mask)
# Compute reconstruction loss
reconstruction = self.reconstruct(mean=False)
rec_loss = ce_loss(
reconstruction,
seg_mask_target,
loss_mask,
self.top_k_percentage,
self.deterministic_top_k,
)
# Compute KL divergence loss
kl_dict = self.kl()
kl_sum = torch.sum(torch.stack([kl for _, kl in kl_dict.items()], dim=-1))
# TODO summaries should be logged
# summaries["rec_loss_mean"] = rec_loss["mean"]
# summaries["rec_loss_sum"] = rec_loss["sum"]
# summaries["kl_sum"] = kl_sum
# for level, kl in kl_dict.items():
# summaries[f"kl_{level}"] = kl
if self.loss_type == "elbo":
loss = rec_loss["sum"] + self.beta * kl_sum
# summaries["elbo_loss"] = loss
elif self.loss_type == "geco":
# TODO need to keep a record of the moving average
ma_rec_loss = self._moving_average(rec_loss["sum"])
mask_sum_per_instance = torch.sum(rec_loss["mask"], dim=-1)
num_valid_pixels = torch.mean(mask_sum_per_instance)
reconstruction_threshold = self.kappa * num_valid_pixels
rec_constraint = ma_rec_loss - reconstruction_threshold
lagmul = self._lagmul(rec_constraint)
loss = lagmul * rec_constraint + kl_sum
# TODO should be logged
# summaries["geco_loss"] = loss
# summaries["ma_rec_loss_mean"] = ma_rec_loss / num_valid_pixels
# summaries["num_valid_pixels"] = num_valid_pixels
# summaries["lagmul"] = lagmul
return {
"loss": loss,
"kl_loss": kl_sum,
"rec_loss": rec_loss["sum"],
"reconstruction": reconstruction,
}
[docs]
def reconstruct(self, mean: bool = False) -> dict[str, Any]:
"""Reconstruct the input.
Args:
mean (bool, optional): Whether to use the mean. Defaults to False.
Returns:
dict[str, Any]: A dictionary containing encoder and decoder features.
"""
if mean:
prior_out = self._p_sample_z_q_mean
else:
prior_out = self._p_sample_z_q
encoder_features = prior_out["encoder_features"]
decoder_features = prior_out["decoder_features"]
return self._f_comb(
encoder_features=encoder_features, decoder_features=decoder_features
)
[docs]
def kl(self) -> dict[int, torch.Tensor]:
"""Compute the KL divergence.
Returns:
A dictionary containing the KL divergence for each level
"""
posterior_out = self._q_sample
prior_out = self._p_sample_z_q
q_dists = posterior_out["distributions"]
p_dists = prior_out["distributions"]
kl = {}
for level, (q, p) in enumerate(zip(q_dists, p_dists)):
kl_per_pixel = torch.distributions.kl_divergence(q, p)
kl_per_instance = torch.sum(kl_per_pixel, dim=[1, 2])
kl[level] = torch.mean(kl_per_instance)
return kl
[docs]
def sample(
self, img: torch.Tensor, mean: bool = False, z_q: torch.Tensor | None = None
) -> dict[str, Any]:
"""Sample from the model.
Args:
img: The image tensor.
mean: Whether to use the mean
z_q: Latent tensor
Returns:
A dictionary containing encoder and decoder features
"""
prior_out = self._prior(img, mean, z_q)
encoder_features = prior_out["encoder_features"]
decoder_features = prior_out["decoder_features"]
return self._f_comb(
encoder_features=encoder_features, decoder_features=decoder_features
)
[docs]
def training_step(
self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute and return the training loss.
Args:
batch: the output of your DataLoader
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
training loss
"""
loss_dict = self.compute_loss(batch)
self.log("train_loss", loss_dict["loss"])
self.log("train_kl_loss", loss_dict["kl_loss"])
self.log("train_rec_loss", loss_dict["rec_loss"])
# compute metrics with reconstruction
self.train_metrics(
loss_dict["reconstruction"],
batch[self.target_key],
batch_size=batch[self.input_key].shape[0],
)
return loss_dict["loss"]
[docs]
def validation_step(
self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute and return the validation loss.
Args:
batch: the output of your DataLoader
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
Returns:
validation loss
"""
loss_dict = self.compute_loss(batch)
self.log("val_loss", loss_dict["loss"])
self.log("val_kl_loss", loss_dict["kl_loss"])
self.log("val_rec_loss", loss_dict["rec_loss"])
# compute metrics with reconstruction
self.val_metrics(
loss_dict["reconstruction"],
batch[self.target_key],
batch_size=batch[self.input_key].shape[0],
)
return loss_dict["loss"]
[docs]
def test_step(
self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute and return the test loss.
Args:
batch: the output of your DataLoader
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
Returns:
test prediction dict
"""
preds = self.predict_step(batch[self.input_key])
# compute metrics with sampled reconstruction
self.test_metrics(
preds["pred"],
batch[self.target_key],
batch_size=batch[self.input_key].shape[0],
)
preds = self.add_aux_data_to_dict(preds, batch)
preds[self.target_key] = batch[self.target_key]
return preds
[docs]
def on_test_start(self) -> None:
"""Create logging directory and initialize metrics."""
self.pred_dir = os.path.join(self.trainer.default_root_dir, self.pred_dir_name)
if not os.path.exists(self.pred_dir) and self.save_preds:
os.makedirs(self.pred_dir)
[docs]
def on_test_batch_end(
self,
outputs: dict[str, Tensor],
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Test batch end save predictions.
Args:
outputs: dictionary of model outputs and aux variables
batch: batch from dataloader
batch_idx: batch index
dataloader_idx: dataloader index
"""
if self.save_preds:
save_image_predictions(outputs, batch_idx, self.pred_dir)
[docs]
def predict_step(
self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
) -> Tensor:
"""Compute and return the prediction.
Args:
X: the input data
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
Returns:
prediction dict
"""
samples = torch.stack(
[self.sample(X) for _ in range(self.num_samples)], dim=-1
) # shape: (batch_size, num_classes, height, width, num_samples)
return process_segmentation_prediction(samples)
def _sample_gumbel(shape: torch.Size, eps: float = 1e-20) -> Tensor:
"""Transforms a uniform random variable to be standard Gumbel distributed.
Args:
shape: The shape of the output tensor.
eps: A small constant for numerical stability.
Returns:
A tensor of the specified shape sampled from the Gumbel distribution.
"""
U = torch.rand(shape)
return -torch.log(-torch.log(U + eps) + eps)
def _topk_mask(score: Tensor, k: int) -> Tensor:
"""Create a loss_mask for the top-k elements in a score tensor.
Args:
score: The score tensor.
k: The number of top elements to loss_mask.
Returns:
A binary loss_mask tensor of the same shape as the score tensor.
"""
_, indices = torch.topk(score, k=k)
loss_mask = torch.zeros_like(score)
loss_mask.scatter_(1, indices, 1)
return loss_mask
def ce_loss(
logits: Tensor,
labels: Tensor,
mask: Tensor | None = None,
top_k_percentage: float | None = None,
deterministic: bool = False,
) -> dict[str, Tensor]:
"""Compute the cross-entropy loss between logits and labels.
Args:
logits: The logits tensor of shape (batch_size, num_classes, height, width)
labels: The labels tensor of shape (batch_size, num_classes, height, width)
mask: An optional mask tensor
top_k_percentage: An optional percentage for the top-k loss_mask
deterministic: A boolean indicating whether to use deterministic top-k masking
Returns:
A dict containing the mean and sum of the cross-entropy loss,
and the mask
"""
num_classes = logits.shape[1]
y_flat = logits.view(-1, num_classes)
t_flat = torch.argmax(labels, dim=1).view(-1)
if mask is None:
mask = torch.ones(y_flat.shape[0])
else:
mask = mask.view(-1)
n_pixels_in_batch = y_flat.shape[0]
xe = nn.CrossEntropyLoss(reduction="none")(y_flat, t_flat)
if top_k_percentage is not None:
assert 0.0 < top_k_percentage <= 1.0
k_pixels = int(n_pixels_in_batch * top_k_percentage)
stopgrad_xe = xe.detach()
norm_xe = stopgrad_xe / torch.sum(stopgrad_xe)
if deterministic:
score = torch.log(norm_xe)
else:
score = torch.log(norm_xe) + _sample_gumbel(norm_xe.shape)
score = score + torch.log(mask)
top_k_mask = _topk_mask(score, k_pixels)
mask = mask * top_k_mask
xe = xe.view(logits.shape[0], -1)
mask = mask.view(logits.shape[0], -1)
ce_sum_per_instance = torch.sum(mask * xe, dim=1)
ce_sum = torch.mean(ce_sum_per_instance)
ce_mean = torch.sum(mask * xe) / torch.sum(mask)
return {"mean": ce_mean, "sum": ce_sum, "mask": mask}