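Deep Kernel Learning Classification

This tutorial trains a Deep Kernel Learning (DKL) classifier on the two-moons toy dataset. The model is an MLP feature extractor followed by a variational Gaussian process layer. The notebook then visualizes the classifier's predictions and predictive uncertainty over a grid of input points.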

[1]:
%%capture
# Install lightning-uq-box directly from its GitHub repository; %%capture hides
# the pip output. NOTE(review): no version/commit is pinned, so results may
# change with upstream updates.
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git

Imports

[2]:
import os
import tempfile
from collections import defaultdict
from functools import partial

import torch
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DKLClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)
INFO:root:Asdfghjkl backend not available since the old asdfghjkl dependency is not installed. If you want to use it, run: pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl
[3]:
# Seed the global RNGs up front so the data, initialization and training below
# are reproducible; the returned seed is shown as this cell's output.
SEED = 2
seed_everything(SEED)
Seed set to 2
[3]:
2
[4]:
# Scratch directory that receives the CSV logs and other artifacts written by
# this tutorial. mkdtemp() does not remove it automatically.
my_temp_dir = tempfile.mkdtemp(suffix=None, prefix=None, dir=None)
[5]:
# Two-moons toy dataset wrapped in a data module with mini-batches of 100.
dm = TwoMoonsDataModule(
    batch_size=100,
)
[6]:
# Pull the raw splits and the evaluation grid out of the data module so they
# can be plotted and fed to the model directly.
X_train, Y_train = dm.X_train, dm.Y_train
X_test, Y_test = dm.X_test, dm.Y_test
test_grid_points = dm.test_grid_points
[7]:
# Plot the train and test splits of the two-moons data.
fig = plot_two_moons_data(
    X_train,
    Y_train,
    X_test,
    Y_test,
)
../../_images/tutorials_classification_dkl_8_0.png

Feature Extractor

[8]:
# Small MLP mapping the 2-D inputs to a 13-dimensional learned feature space,
# which the DKL model below uses as input to its GP layer.
feature_extractor = MLP(
    n_inputs=2,
    n_hidden=[50],
    n_outputs=13,
    activation_fn=torch.nn.ELU(),
)

Deep Kernel Learning Model

[9]:
# DKL classifier: the MLP features feed a GP with an RBF kernel and
# 20 inducing points; trained with Adam at learning rate 1e-2.
dkl_model = DKLClassification(
    feature_extractor,
    num_classes=2,  # two-moons is a binary task
    gp_kernel="RBF",
    n_inducing_points=20,
    optimizer=partial(torch.optim.Adam, lr=0.01),
)

Trainer

[10]:
# Log metrics as CSV into the scratch directory so they can be plotted later.
logger = CSVLogger(save_dir=my_temp_dir)
trainer = Trainer(
    default_root_dir=my_temp_dir,
    max_epochs=100,
    logger=logger,
    log_every_n_steps=1,  # log every step for dense training curves
    enable_checkpointing=False,
    enable_progress_bar=False,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[11]:
# Train the DKL model on the two-moons data module.
trainer.fit(model=dkl_model, datamodule=dm)

  | Name              | Type              | Params | Mode
----------------------------------------------------------------
0 | feature_extractor | MLP               | 813    | train
1 | train_metrics     | MetricCollection  | 0      | train
2 | val_metrics       | MetricCollection  | 0      | train
3 | test_metrics      | MetricCollection  | 0      | train
4 | gp_layer          | DKLGPLayer        | 5.8 K  | train
5 | scale_to_bounds   | ScaleToBounds     | 0      | train
6 | likelihood        | SoftmaxLikelihood | 26     | train
7 | elbo_fn           | VariationalELBO   | 5.8 K  | train
----------------------------------------------------------------
6.6 K     Trainable params
0         Non-trainable params
6.6 K     Total params
0.026     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=100` reached.

Training Metrics

[12]:
# CSVLogger(my_temp_dir) writes under <my_temp_dir>/lightning_logs
# ("lightning_logs" is its default logger name).
metrics_dir = os.path.join(my_temp_dir, "lightning_logs")
fig = plot_training_metrics(metrics_dir, ["train_loss", "trainAcc"])
../../_images/tutorials_classification_dkl_17_0.png

Prediction

[13]:
# Evaluate on the held-out test split. The result is a list with one metrics
# dict per dataloader, displayed below.
# NOTE(review): the original note says this also saves predictions; that is
# presumably done by the UQ module's test hooks — confirm.
trainer.test(model=dkl_model, dataloaders=dm.test_dataloader())
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         testAcc            0.9950000047683716
     testCalibration       0.007077312096953392
 testEmpirical Coverage     0.9950000047683716
        test_loss          0.052400149405002594
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[13]:
[{'test_loss': 0.052400149405002594,
  'testAcc': 0.9950000047683716,
  'testCalibration': 0.007077312096953392,
  'testEmpirical Coverage': 0.9950000047683716}]

Evaluate Predictions

[14]:
# Predict over the evaluation grid in mini-batches; per the original note, the
# GP requires batched prediction here.
# predict_step is called directly rather than through trainer.predict, so we
# do what Lightning's predict loop would: switch to eval mode and disable
# autograd (the model may still be in train mode after trainer.test).
batch_size = 200
dkl_model.eval()

preds = defaultdict(list)
with torch.no_grad():
    # torch.split yields consecutive slices of at most `batch_size` rows and,
    # unlike chunk(0), does not fail on an empty grid.
    for batch in torch.split(test_grid_points, batch_size):
        for key, value in dkl_model.predict_step(batch).items():
            # "out" is skipped; presumably it is not a tensor that torch.cat
            # can join — TODO confirm.
            if key != "out":
                preds[key].append(value)

# Stitch the per-batch results back together along the sample dimension.
preds = {key: torch.cat(value, dim=0) for key, value in preds.items()}
[15]:
# Hard class labels (argmax over the last dim) and per-point uncertainty on the
# evaluation grid, overlaid with the test points.
grid_labels = preds["pred"].argmax(-1)
grid_uncertainty = preds["pred_uct"].cpu().numpy()
fig = plot_predictions_classification(
    X_test, Y_test, grid_labels, test_grid_points, grid_uncertainty
)
../../_images/tutorials_classification_dkl_22_0.png