Classification and Regression Diffusion (CARD)

Classification and Regression Diffusion (CARD)#

[ ]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

Imports#

[1]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP, ConditionalGuidedLinearModel
from lightning_uq_box.uq_methods import CARDClassification, DeterministicClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2
[2]:
# Fix the random seeds so repeated runs of the notebook give the same results.
seed_everything(0)  # seed everything for reproducibility
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # set the GPU device
# Allow PyTorch to use reduced-precision kernels for float32 matrix
# multiplications (see the torch.set_float32_matmul_precision documentation
# for the exact precision/speed trade-off of "medium").
torch.set_float32_matmul_precision("medium")
Seed set to 0
[3]:
# Scratch directory for this run; it is passed to the CSV logger and to the
# diffusion trainer further down. mkdtemp() does not delete the directory
# automatically, so remove it manually once the logs are no longer needed.
my_temp_dir = tempfile.mkdtemp()

Datamodule#

[4]:
# Two-moons toy dataset, served in mini-batches of 256 samples.
dm = TwoMoonsDataModule(batch_size=256, n_samples=8000)

# Pull the data splits and the evaluation grid off the datamodule so they can
# be handed to the plotting helpers and to the prediction step later on.
X_train = dm.X_train
Y_train = dm.Y_train
X_test = dm.X_test
Y_test = dm.Y_test
test_grid_points = dm.test_grid_points
[5]:
# Visualise the train and test splits before fitting any model.
fig = plot_two_moons_data(X_train, Y_train, X_test, Y_test)
../../_images/tutorials_classification_card_8_0.png

Model#

Pretrain the conditional mean estimation model before moving on to the diffusion training stage. We will use an MLP with the DeterministicClassification module.

[6]:
# Small MLP with two hidden layers of 50 units: 2 input features -> 2 outputs.
network = MLP(n_inputs=2, n_hidden=[50, 50], n_outputs=2)

# Wrap the network in a plain (non-probabilistic) classification module that
# is trained with cross-entropy; this is the conditional mean estimator that
# the diffusion stage later conditions on.
cond_mean_model = DeterministicClassification(
    model=network,
    optimizer=partial(torch.optim.Adam, lr=1e-2),
    loss_fn=nn.CrossEntropyLoss(),
)

trainer = Trainer(
    max_epochs=100,  # number of epochs we want to train
    # "auto" selects a GPU when one is available and falls back to the CPU
    # otherwise, so the notebook also runs on machines without CUDA.
    accelerator="auto",
    devices=1,  # a single device is plenty for this toy problem
    enable_checkpointing=False,
    enable_progress_bar=False,
)
trainer.fit(cond_mean_model, dm)
/home/nils/.conda/envs/uqboxEnv/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/nils/.conda/envs/uqboxEnv/lib/python3.9/site-p ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | MLP              | 2.8 K
1 | loss_fn       | CrossEntropyLoss | 0
2 | train_metrics | MetricCollection | 0
3 | val_metrics   | MetricCollection | 0
4 | test_metrics  | MetricCollection | 0
---------------------------------------------------
2.8 K     Trainable params
0         Non-trainable params
2.8 K     Total params
0.011     Total estimated model params size (MB)
/home/nils/.conda/envs/uqboxEnv/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/home/nils/.conda/envs/uqboxEnv/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/home/nils/.conda/envs/uqboxEnv/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (19) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=100` reached.

Diffusion Model#

[7]:
# Configuration of the diffusion stage. These values are passed to the
# guidance network and to the CARD module in the following cells.

n_steps = 1_000

# Conditioning switches for the guidance network:
#   cat_x      -> condition on input x through concatenation
#   cat_y_pred -> condition on y_0_hat
cat_x, cat_y_pred = True, True

# Dimensionality of the inputs and of the outputs.
x_dim = y_dim = 2

# Hidden layer widths of the guidance network.
n_hidden = [128] * 2

# Noise schedule: schedule type plus its start and end beta values.
beta_schedule = "linear"
beta_start, beta_end = 1e-4, 2e-2
[8]:
# Gather the constructor arguments in one mapping so the whole configuration
# of the guidance network can be read (or logged) at a glance.
guidance_kwargs = {
    "n_steps": n_steps,
    "x_dim": x_dim,
    "y_dim": y_dim,
    "n_hidden": n_hidden,
    "cat_x": cat_x,
    "cat_y_pred": cat_y_pred,
}
guidance_model = ConditionalGuidedLinearModel(**guidance_kwargs)
[9]:
# Hand over the bare network of the pretrained classifier (not the Lightning
# wrapper around it) as the conditional mean model.
pretrained_net = cond_mean_model.model

# Optimizer factory for the guidance network; the parameters are bound later.
guidance_optim = partial(torch.optim.Adam, lr=1e-2)

card_model = CARDClassification(
    cond_mean_model=pretrained_net,
    guidance_model=guidance_model,
    guidance_optim=guidance_optim,
    beta_schedule=beta_schedule,
    beta_start=beta_start,
    beta_end=beta_end,
    n_steps=n_steps,
)
[10]:
# Log metrics as CSV files inside the scratch directory so they can be read
# back for the training-curve plot.
logger = CSVLogger(my_temp_dir)
diff_trainer = Trainer(
    max_epochs=1000,  # number of epochs we want to train
    # "auto" selects a GPU when one is available and falls back to the CPU
    # otherwise, so the notebook also runs on machines without CUDA.
    accelerator="auto",
    devices=1,  # a single device is plenty for this toy problem
    logger=logger,
    # The pretraining log above reports only 19 training batches per epoch, so
    # log more often than Lightning's default of every 50 steps.
    log_every_n_steps=10,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
diff_trainer.fit(card_model, dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /tmp/tmpc86dexzc/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name            | Type                         | Params
-----------------------------------------------------------------
0 | cond_mean_model | MLP                          | 2.8 K
1 | guidance_model  | ConditionalGuidedLinearModel | 273 K
2 | train_metrics   | MetricCollection             | 0
3 | val_metrics     | MetricCollection             | 0
4 | test_metrics    | MetricCollection             | 0
-----------------------------------------------------------------
276 K     Trainable params
0         Non-trainable params
276 K     Total params
1.106     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=1000` reached.

Training Metrics#

[11]:
# The CSV logger stores its output in a "lightning_logs" folder below the
# scratch directory; plot the training loss recorded there.
log_dir = os.path.join(my_temp_dir, "lightning_logs")
fig = plot_training_metrics(log_dir, ["train_loss"])
../../_images/tutorials_classification_card_17_0.png

Prediction#

[12]:
# Run inference on the GPU when one is present, otherwise on the CPU, instead
# of assuming CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
card_model = card_model.to(device)

# Gradients are not needed for prediction; disabling them avoids building an
# autograd graph during sampling.
with torch.no_grad():
    preds = card_model.predict_step(test_grid_points.to(device))
[13]:
# Hard class decision per grid point: index of the largest entry along the
# last axis of the prediction.
pred_labels = preds["pred"].argmax(-1)
# Uncertainty values returned alongside the prediction.
pred_uncertainty = preds["pred_uct"]

fig = plot_predictions_classification(
    X_test, Y_test, pred_labels, test_grid_points, pred_uncertainty
)
../../_images/tutorials_classification_card_20_0.png
[ ]: