Mixture Density Network 1D Regression#

This tutorial will showcase the application of a fairly “old” (in ML age) modeling technique called Mixture Density Networks that can be used for supervised regression problems, as well as inverse problems. For a very nice tutorial, with more mathematical details, see this tutorial.

[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

Imports#

[26]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from torch.optim import Adam

from lightning_uq_box.datamodules import ToyHeteroscedasticDatamodule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DeterministicRegression, MDNRegression
from lightning_uq_box.viz_utils import (
    plot_predictions_regression,
    plot_toy_regression_data,
    plot_training_metrics,
)

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
[3]:
seed_everything(0)  # seed everything for reproducibility
Seed set to 0
[3]:
0

We define a temporary directory to look at some training metrics and results.

[4]:
my_temp_dir = tempfile.mkdtemp()

Datamodule#

[5]:
dm = ToyHeteroscedasticDatamodule(n_points=500)

X_train, Y_train, train_loader, X_test, Y_test, test_loader, X_gtext, Y_gtext = (
    dm.X_train,
    dm.Y_train,
    dm.train_dataloader(),
    dm.X_test,
    dm.Y_test,
    dm.test_dataloader(),
    dm.X_gtext,
    dm.Y_gtext,
)
[6]:
fig = plot_toy_regression_data(X_train, Y_train, X_test, Y_test)
../../_images/tutorials_regression_mixture_density_9_0.png

Deterministic Model#

For this toy regression problem, we will use a simple Mulit-layer Perceptron (MLP).

[7]:
network = MLP(
    n_inputs=1,
    n_hidden=[50, 50, 50],
    n_outputs=1,
    dropout_p=0.1,
    activation_fn=nn.Tanh(),
)
network
[7]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=1, out_features=50, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): Tanh()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=50, out_features=50, bias=True)
    (7): Tanh()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=50, out_features=1, bias=True)
  )
)

Fit deterministic model#

[8]:
det_model = DeterministicRegression(
    model=network, loss_fn=nn.MSELoss(), optimizer=partial(Adam, lr=0.003)
)
[9]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=250,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=20,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

trainer.fit(det_model, dm)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /tmp/tmp2t51mro4/lightning_logs

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | MLP              | 5.3 K
1 | loss_fn       | MSELoss          | 0
2 | train_metrics | MetricCollection | 0
3 | val_metrics   | MetricCollection | 0
4 | test_metrics  | MetricCollection | 0
---------------------------------------------------
5.3 K     Trainable params
0         Non-trainable params
5.3 K     Total params
0.021     Total estimated model params size (MB)
/home/nils/miniconda3/envs/py311uqbox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/nils/miniconda3/envs/py311uqbox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/nils/miniconda3/envs/py311uqbox/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=20). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=250` reached.
[10]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainRMSE"]
)
../../_images/tutorials_regression_mixture_density_15_0.png
[11]:
deterministic_preds = det_model.predict_step(X_gtext)
fig = plot_predictions_regression(
    X_train,
    Y_train,
    X_gtext,
    Y_gtext,
    deterministic_preds["pred"].squeeze(-1),
    title="Supervised Regression Deterministic Model",
)
../../_images/tutorials_regression_mixture_density_16_0.png

Fit deterministic model on inverse problem#

Okay, this is standard stuff, but what happens if we consider an inverse problem, where we swap X and Y and then solve for Y.

[28]:
dm_inverse = ToyHeteroscedasticDatamodule(n_points=500, invert=True)

X_train_inv, Y_train_inv, X_test_inv, Y_test_inv, X_gtext_inv, Y_gtext_inv = (
    dm_inverse.X_train,
    dm_inverse.Y_train,
    dm_inverse.X_test,
    dm_inverse.Y_test,
    dm_inverse.X_gtext,
    dm_inverse.Y_gtext,
)
[29]:
fig = plot_toy_regression_data(X_train_inv, Y_train_inv, X_test_inv, Y_test_inv)
../../_images/tutorials_regression_mixture_density_19_0.png
[30]:
network = MLP(
    n_inputs=1,
    n_hidden=[50, 50, 50],
    n_outputs=1,
    dropout_p=0.1,
    activation_fn=nn.Tanh(),
)
det_inv_model = DeterministicRegression(
    model=network, loss_fn=nn.MSELoss(), optimizer=partial(Adam, lr=0.003)
)
[31]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=250,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=20,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

trainer.fit(det_inv_model, dm_inverse)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | MLP              | 5.3 K
1 | loss_fn       | MSELoss          | 0
2 | train_metrics | MetricCollection | 0
3 | val_metrics   | MetricCollection | 0
4 | test_metrics  | MetricCollection | 0
---------------------------------------------------
5.3 K     Trainable params
0         Non-trainable params
5.3 K     Total params
0.021     Total estimated model params size (MB)
/home/nils/miniconda3/envs/py311uqbox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/nils/miniconda3/envs/py311uqbox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/nils/miniconda3/envs/py311uqbox/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=20). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=250` reached.
[34]:
det_inv_preds = det_inv_model.predict_step(X_gtext_inv)
fig = plot_predictions_regression(
    X_train_inv,
    Y_train_inv,
    X_gtext_inv,
    # Y_gtext_inv,
    Y_gtext_inv,
    det_inv_preds["pred"].squeeze(-1),
    title="Inverse Problem Deterministic Model",
)
../../_images/tutorials_regression_mixture_density_22_0.png

We can see that the model has difficulties fitting the data, as it tries to find a deterministic mapping for this multi-modal dataset.

Mixture Density Network Model#

The Mixture Density Network tackles the inverse problem by instead modeling the data as a mixture of multiple distributions, and effectively train a neural network to parameterize a Gaussian Mixture Model.

For the inverse problem, we will again use a simple Mulit-layer Perceptron (MLP), but adapt it afterwards for the Mixture Density Network idea.

[35]:
mdn_temp_dir = tempfile.mkdtemp()
[36]:
mdn_base_network = MLP(
    n_inputs=1,
    n_hidden=[50, 50, 50],
    n_outputs=1,
    dropout_p=0.1,
    activation_fn=nn.Tanh(),
)
mdn_base_network
[36]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=1, out_features=50, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): Tanh()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=50, out_features=50, bias=True)
    (7): Tanh()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=50, out_features=1, bias=True)
  )
)

We will create a model with a Mixture Density Layer as the last layer. This last layer will replace the last layer of the specified MLP architecture above, such that the method remains applicable for “off-the-shelve” models like found in timm as well as custom architectures of course. You can also only train the last layer and keep all other parameters frozen (via the freeze_backbone flag), which might be interesting for finetuning pretrained models to have a probabilistic output.

[37]:
mdn_model = MDNRegression(
    model=mdn_base_network, n_components=3, optimizer=partial(Adam, lr=0.003)
)
[38]:
logger = CSVLogger(mdn_temp_dir)
trainer = Trainer(
    max_epochs=250,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=20,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)

trainer.fit(mdn_model, dm_inverse)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /tmp/tmp6_g7pdxb/lightning_logs

  | Name          | Type               | Params
-----------------------------------------------------
0 | model         | MLP                | 10.8 K
1 | loss_fn       | MixtureDensityLoss | 0
2 | train_metrics | MetricCollection   | 0
3 | val_metrics   | MetricCollection   | 0
4 | test_metrics  | MetricCollection   | 0
-----------------------------------------------------
10.8 K    Trainable params
0         Non-trainable params
10.8 K    Total params
0.043     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=250` reached.

We can take a brief look at the loss to check whether the model has been trained to reasonable convergence.

[39]:
fig = plot_training_metrics(
    os.path.join(mdn_temp_dir, "lightning_logs"), ["train_loss", "val_loss"]
)
../../_images/tutorials_regression_mixture_density_31_0.png

Plot Predictions#

Now we are ready for the “big reveal”, namely whether we can fit the data more accurately with the Mixture Density approach.

[40]:
mixture_density_pred = mdn_model.predict_step(X_gtext_inv)

fig = plot_predictions_regression(
    X_train_inv,
    Y_train_inv,
    X_gtext_inv,
    Y_gtext_inv,
    mixture_density_pred["pred"],
    aleatoric=mixture_density_pred["pred_uct"],
    title="Mixture Density Mean Predictions",
    show_bands=False,
)
../../_images/tutorials_regression_mixture_density_33_0.png

Sampling#

We can also draw conditional samples from the mixture distributions.

[55]:
mixture_density_samples = mdn_model.sample(X_gtext_inv)

fig, axs = plt.subplots(1, figsize=(12, 7))

axs.scatter(
    X_gtext_inv,
    mixture_density_samples.squeeze(-1).T,
    color="orange",
    alpha=0.6,
    label="Mixture Density Samples",
)
axs.scatter(X_train_inv, Y_train_inv, color="blue", label="Training Data", alpha=0.6)

axs.set_title("Conditional Mixture Density Samples over input range")
plt.legend()
[55]:
<matplotlib.legend.Legend at 0x7f50c44c1190>
../../_images/tutorials_regression_mixture_density_35_1.png

It looks like the Mixture Density Model was able to approximate \(p(y|x)\) of the inverse problem reasonably well, much better than a default deterministic network.