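# Deep Kernel Learning Classification

This notebook trains a deep kernel learning (DKL) classifier on the two-moons toy dataset. The classifier is an MLP feature extractor followed by a variational Gaussian-process layer. The notebook then evaluates it on the test split and visualizes its predictions and predictive uncertainty over a grid of input points.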
[1]:
%%capture
# Install lightning-uq-box from the tip of its GitHub repository (output is hidden by %%capture).
# NOTE(review): this is unpinned; pin a release tag or commit hash for reproducible runs.
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
## Imports
[2]:
import os
import tempfile
from collections import defaultdict
from functools import partial
import torch
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DKLClassification
from lightning_uq_box.viz_utils import (
plot_predictions_classification,
plot_training_metrics,
plot_two_moons_data,
)
[3]:
# Seed the Python, NumPy and torch RNGs so the whole run is reproducible.
seed_everything(seed=2)
Seed set to 2
[3]:
2
[4]:
# temporary directory for saving
# Scratch directory that receives the CSV logs and any trainer artifacts of this run.
my_temp_dir = tempfile.mkdtemp(prefix=None)
[5]:
# Two-moons toy dataset; its dataloaders use a batch size of 100.
moons_batch_size = 100
dm = TwoMoonsDataModule(batch_size=moons_batch_size)
[6]:
# define data
# Pull the raw splits and the evaluation grid out of the datamodule.
X_train = dm.X_train
Y_train = dm.Y_train
X_test = dm.X_test
Y_test = dm.Y_test
test_grid_points = dm.test_grid_points
[7]:
# Visualize the train and test splits of the two-moons data.
fig = plot_two_moons_data(
    X_train, Y_train,
    X_test, Y_test,
)
## Feature Extractor
[8]:
# MLP mapping the 2-D inputs to a 13-dimensional feature vector
# (one hidden layer of 50 units, ELU activations); the DKL model below consumes these features.
feature_extractor = MLP(
    n_inputs=2,
    n_hidden=[50],
    n_outputs=13,
    activation_fn=torch.nn.ELU(),
)
## Deep Kernel Learning Model
[9]:
# Adam with lr=1e-2, passed un-instantiated (presumably bound to the model's
# parameters inside DKLClassification).
optimizer_factory = partial(torch.optim.Adam, lr=1e-2)

# Binary DKL classifier: RBF-kernel GP with 20 inducing points on top of the MLP features.
dkl_model = DKLClassification(
    feature_extractor,
    num_classes=2,
    gp_kernel="RBF",
    n_inducing_points=20,
    optimizer=optimizer_factory,
)
## Trainer
[10]:
# CSV logger so the training curves can be plotted after fitting.
logger = CSVLogger(save_dir=my_temp_dir)

trainer_config = dict(
    accelerator="cpu",
    max_epochs=100,  # number of training epochs
    logger=logger,  # record metrics to CSV for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
trainer = Trainer(**trainer_config)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
[11]:
# Fit the DKL model on the training split served by the datamodule.
trainer.fit(model=dkl_model, datamodule=dm)
| Name | Type | Params | Mode | FLOPs
------------------------------------------------------------------------
0 | feature_extractor | MLP | 813 | train | 0
1 | train_metrics | MetricCollection | 0 | train | 0
2 | val_metrics | MetricCollection | 0 | train | 0
3 | test_metrics | MetricCollection | 0 | train | 0
4 | gp_layer | DKLGPLayer | 5.8 K | train | 0
5 | scale_to_bounds | ScaleToBounds | 0 | train | 0
6 | likelihood | SoftmaxLikelihood | 26 | train | 0
7 | elbo_fn | VariationalELBO | 5.8 K | train | 0
------------------------------------------------------------------------
6.6 K Trainable params
0 Non-trainable params
6.6 K Total params
0.026 Total estimated model params size (MB)
31 Modules in train mode
0 Modules in eval mode
0 Total Flops
/home/docs/checkouts/readthedocs.org/user_builds/lightning-uq-box/envs/latest/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
`Trainer.fit` stopped: `max_epochs=100` reached.
## Training Metrics
[12]:
# Plot the logged training loss and accuracy curves.
# `logger.root_dir` is `<save_dir>/<name>` of the CSVLogger configured above, so this
# stays correct if the logger is given a custom experiment name instead of the
# default "lightning_logs".
fig = plot_training_metrics(logger.root_dir, ["train_loss", "trainAcc"])
## Test-Set Evaluation
[13]:
# Evaluate the trained model on the held-out test split. Nothing is saved here:
# the call returns the logged test metrics (loss, accuracy, calibration,
# empirical coverage) as a list with one dict per dataloader.
trainer.test(dkl_model, datamodule=dm)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
testAcc 0.9950000047683716
testCalibration 0.004979097284376621
testEmpirical Coverage 0.9950000047683716
test_loss 0.04922037571668625
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[13]:
[{'test_loss': 0.04922037571668625,
'testAcc': 0.9950000047683716,
'testCalibration': 0.004979097284376621,
'testEmpirical Coverage': 0.9950000047683716}]
## Predictions on the Input Grid
[14]:
# Predict on the dense grid of test points in batches rather than in one pass,
# because GP inference over the full grid at once can be memory-heavy.
batch_size = 200

# Put every submodule in inference mode. Lightning may restore training mode after
# `trainer.test`, and some modules (e.g. the ScaleToBounds layer in the model
# summary) can behave differently while training.
dkl_model.eval()

preds = defaultdict(list)
with torch.no_grad():  # pure inference: no autograd graph needed
    # torch.split yields consecutive batches of exactly `batch_size` rows (last one may be smaller).
    for batch in torch.split(test_grid_points, batch_size):
        batch_preds = dkl_model.predict_step(batch.to(dkl_model.device))
        for key, value in batch_preds.items():
            # "out" is skipped. It is presumably the raw model output/distribution,
            # which cannot be concatenated with torch.cat.
            if key != "out":
                preds[key].append(value)

# Stitch the per-batch results back into one tensor per key.
preds = {key: torch.cat(value, dim=0) for key, value in preds.items()}
[15]:
# Hard class decisions and predictive uncertainty, both computed on test_grid_points.
grid_labels = preds["pred"].argmax(-1)
grid_uncertainty = preds["pred_uct"].cpu().numpy()

fig = plot_predictions_classification(
    X_test,
    Y_test,
    grid_labels,
    test_grid_points,
    grid_uncertainty,
)