Spectral Normalized Gaussian Process (SNGP) Classification

[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
[2]:
import os
import tempfile

import torch
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models.fc_resnet import FCResNet
from lightning_uq_box.uq_methods import SNGPClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)

%load_ext autoreload
%autoreload 2
[3]:
seed_everything(2)
# temporary directory for the training logs
my_temp_dir = tempfile.mkdtemp()
Seed set to 2

Datamodule

[4]:
dm = TwoMoonsDataModule(batch_size=128)
[5]:
# define data
X_train, Y_train, X_test, Y_test, test_grid_points = (
    dm.X_train,
    dm.Y_train,
    dm.X_test,
    dm.Y_test,
    dm.test_grid_points,
)
[6]:
fig = plot_two_moons_data(X_train, Y_train, X_test, Y_test)
[Figure: training and test points of the two moons dataset]

Model

[7]:
feature_extractor = FCResNet(input_dim=2, features=64, depth=4)
[8]:
sngp = SNGPClassification(
    feature_extractor=feature_extractor,
    loss_fn=torch.nn.CrossEntropyLoss(),
    num_targets=2,
)
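The model summary printed during training lists an `rff` (RandomFourierFeatures) module with no trainable parameters followed by a trainable `beta` linear layer. As a rough illustration of that idea, the sketch below builds a GP-style output head from fixed random Fourier features. It is a minimal sketch, not the library's implementation: the class name `RandomFeatureHead`, the dimensions, and the initialization are assumptions chosen for illustration.

```python
import math

import torch
from torch import nn


class RandomFeatureHead(nn.Module):
    """Hypothetical GP-style output head built from random Fourier features."""

    def __init__(self, in_dim: int, num_features: int, num_targets: int) -> None:
        super().__init__()
        # Fixed random projection and phase, stored as buffers so they are
        # saved with the module but never updated by the optimizer.
        self.register_buffer("W", torch.randn(num_features, in_dim))
        self.register_buffer("b", 2 * math.pi * torch.rand(num_features))
        self.scale = math.sqrt(2.0 / num_features)
        # Trainable output weights mapping random features to class logits.
        self.beta = nn.Linear(num_features, num_targets)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # phi(h) = sqrt(2 / D) * cos(W h + b) approximates an RBF kernel feature map.
        phi = self.scale * torch.cos(h @ self.W.T + self.b)
        return self.beta(phi)


torch.manual_seed(0)
# Assumed sizes for illustration only: 64 input features, 1024 random features.
head = RandomFeatureHead(in_dim=64, num_features=1024, num_targets=2)
logits = head(torch.randn(8, 64))  # one row of two class logits per input
```

Only `beta` is trained in this sketch; the random projection stays fixed, which is why a module of this kind can report zero trainable parameters.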

Trainer

[9]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    accelerator="cpu",
    max_epochs=100,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[10]:
trainer.fit(sngp, dm)

  | Name              | Type                  | Params | Mode  | FLOPs
----------------------------------------------------------------------------
0 | feature_extractor | FCResNet              | 16.8 K | train | 0
1 | loss_fn           | CrossEntropyLoss      | 0      | train | 0
2 | normalize         | LayerNorm             | 256    | train | 0
3 | rff               | RandomFourierFeatures | 0      | train | 0
4 | beta              | Linear                | 2.0 K  | train | 0
5 | train_metrics     | MetricCollection      | 0      | train | 0
6 | val_metrics       | MetricCollection      | 0      | train | 0
7 | test_metrics      | MetricCollection      | 0      | train | 0
----------------------------------------------------------------------------
19.1 K    Trainable params
0         Non-trainable params
19.1 K    Total params
0.077     Total estimated model params size (MB)
24        Modules in train mode
0         Modules in eval mode
0         Total Flops
`Trainer.fit` stopped: `max_epochs=100` reached.
[11]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainAcc"]
)
[Figure: training loss and training accuracy logged during training]

Predictions

We evaluate the trained model on a grid of test points spanning the extent of the input data, then plot the predicted class at each grid point (which reveals the decision boundary) together with the corresponding predictive uncertainty.

[12]:
preds = sngp.predict_step(test_grid_points.to(sngp.device))
[13]:
fig = plot_predictions_classification(
    X_test,
    Y_test,
    preds["pred"].argmax(-1),
    test_grid_points,
    preds["pred_uct"].cpu().numpy(),
)
[Figure: predicted classes and predictive uncertainty over the test grid, with the test points overlaid]
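To make the idea of a per-point uncertainty on a grid concrete, the sketch below computes the predictive entropy of class probabilities over a small grid. Entropy is one common uncertainty score for classifiers; whether `preds["pred_uct"]` is exactly this quantity is an assumption, and the grid extent and the stand-in logits here are hypothetical rather than taken from the trained model.

```python
import torch


def predictive_entropy(probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Entropy (in nats) of each row of class probabilities."""
    probs = probs.clamp_min(eps)  # avoid log(0) for confident predictions
    return -(probs * probs.log()).sum(dim=-1)


# Hypothetical 5 x 5 grid over [-2, 2] x [-2, 2]; the tutorial instead uses
# the grid provided by the datamodule.
xs = torch.linspace(-2.0, 2.0, 5)
grid = torch.cartesian_prod(xs, xs)  # shape (25, 2)

# Stand-in logits with a decision boundary at x = 0, used only to have
# something to score; a real model's outputs would go here.
logits = torch.stack([grid[:, 0], -grid[:, 0]], dim=-1)
probs = logits.softmax(dim=-1)

entropy = predictive_entropy(probs)   # one uncertainty value per grid point
pred_class = probs.argmax(dim=-1)     # predicted class per grid point
```

In this toy setup the entropy peaks on the line x = 0, where both classes are equally likely, and falls off away from it.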