Deterministic Classification
This notebook shows how to train a standard (deterministic) classification network and use the entropy of its softmax outputs as a measure of predictive uncertainty.
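Concretely, for an input $x$ with logits $f(x) \in \mathbb{R}^K$, the softmax probabilities and their entropy are given below. This is the standard definition, with the natural logarithm assumed.

```latex
p_k(x) = \frac{\exp f_k(x)}{\sum_{j=1}^{K} \exp f_j(x)}, \qquad
H(x) = -\sum_{k=1}^{K} p_k(x) \log p_k(x), \qquad
0 \le H(x) \le \log K
```

The entropy is 0 when all probability mass sits on a single class and reaches its maximum, $\log K$, for a uniform prediction. For the two classes used here the maximum is $\log 2 \approx 0.693$.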
[1]:
%%capture
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
Imports
[2]:
import os
import tempfile
from functools import partial
import matplotlib.pyplot as plt
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from torch.optim import Adam
from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import DeterministicClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)
plt.rcParams["figure.figsize"] = [14, 5]
%load_ext autoreload
%autoreload 2
[3]:
seed_everything(0)
Seed set to 0
[3]:
0
[4]:
my_temp_dir = tempfile.mkdtemp()
Datamodule
[5]:
dm = TwoMoonsDataModule(batch_size=128)
X_train, Y_train, X_test, Y_test, test_grid_points = (
    dm.X_train,
    dm.Y_train,
    dm.X_test,
    dm.Y_test,
    dm.test_grid_points,
)
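This notebook does not show how `TwoMoonsDataModule` samples its data. For intuition, here is a minimal stdlib sketch of the classic two-moons construction, which has the same geometry as scikit-learn's `make_moons`. The function `make_two_moons` and its parameters are hypothetical and are not the datamodule's implementation.

```python
import math
import random


def make_two_moons(n_per_class, noise=0.1, seed=0):
    """Hypothetical two-moons generator (not TwoMoonsDataModule's implementation).

    Class 0: upper half of a unit circle centred at (0, 0).
    Class 1: lower half of a unit circle centred at (1, 0.5), so the arcs interleave.
    Each coordinate gets Gaussian noise with standard deviation `noise`.
    """
    rng = random.Random(seed)
    points, labels = [], []
    for i in range(n_per_class):
        t = math.pi * i / max(n_per_class - 1, 1)  # angle in [0, pi]
        # upper moon (class 0)
        points.append((math.cos(t) + rng.gauss(0, noise),
                       math.sin(t) + rng.gauss(0, noise)))
        labels.append(0)
        # lower moon (class 1), shifted right and down
        points.append((1 - math.cos(t) + rng.gauss(0, noise),
                       0.5 - math.sin(t) + rng.gauss(0, noise)))
        labels.append(1)
    return points, labels


X, y = make_two_moons(100, noise=0.1)
print(len(X), y.count(0), y.count(1))  # 200 100 100
```

The two classes are not linearly separable, so a small nonlinear network such as the MLP below is needed to fit the boundary.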
[6]:
fig = plot_two_moons_data(X_train, Y_train, X_test, Y_test)
Model
[7]:
network = MLP(
    n_inputs=2,
    n_hidden=[50, 50, 50],
    n_outputs=2,
    dropout_p=0.2,
    activation_fn=nn.ReLU(),
)
network
[7]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=2, out_features=50, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=50, out_features=50, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.2, inplace=False)
    (9): Linear(in_features=50, out_features=2, bias=True)
  )
)
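As a sanity check on the model summary that `trainer.fit` prints later, the parameter count can be worked out by hand from the layer widths above. Only the `Linear` layers have parameters, since `ReLU` and `Dropout` have none. The size estimate assumes 32-bit floats (4 bytes per parameter) and 1 MB = 10^6 bytes.

```python
def linear_params(n_in, n_out, bias=True):
    """Parameters of a fully connected layer: an n_in x n_out weight matrix plus an optional bias."""
    return n_in * n_out + (n_out if bias else 0)


# Layer widths of the MLP above: 2 -> 50 -> 50 -> 50 -> 2.
widths = [2, 50, 50, 50, 2]
# Sum over consecutive (input, output) width pairs, one per Linear layer.
total = sum(linear_params(a, b) for a, b in zip(widths[:-1], widths[1:]))
print(total)  # 5352

# Assumption: float32 weights, 4 bytes each, reported in megabytes (1e6 bytes).
size_mb = total * 4 / 1e6
print(round(size_mb, 3))  # 0.021
```

The 5,352 parameters round to the "5.4 K" shown in the summary, and about 0.021 MB matches its estimated model size.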
[8]:
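# Despite its name, mc_dropout_module is a plain deterministic classifier:
# one forward pass per input and no Monte Carlo dropout sampling.
mc_dropout_module = DeterministicClassification(
    model=network, optimizer=partial(Adam, lr=1e-2), loss_fn=nn.CrossEntropyLoss()
)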
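`nn.CrossEntropyLoss` expects raw logits and integer class indices, and it applies the log-softmax internally. Below is a stdlib sketch of what it computes under its default settings (mean reduction, no class weights, no label smoothing). `cross_entropy` is a hypothetical helper written for illustration and is not PyTorch code.

```python
import math


def cross_entropy(logits_batch, targets):
    """Hypothetical helper: mean of -log softmax(logits)[target] over a batch."""
    losses = []
    for logits, target in zip(logits_batch, targets):
        m = max(logits)  # subtract the max before exponentiating, for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        # -log softmax(z)[t] = logsumexp(z) - z[t]
        losses.append(log_sum_exp - logits[target])
    return sum(losses) / len(losses)


print(cross_entropy([[0.0, 0.0]], [0]))  # ≈ 0.6931, i.e. log 2 for an undecided prediction
print(cross_entropy([[4.0, -4.0], [-4.0, 4.0]], [0, 1]))  # small: confident and correct
```

Minimizing this loss pushes the logit of the correct class up relative to the others, which is why a well-trained deterministic classifier tends to produce confident (low-entropy) softmax outputs away from the decision boundary.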
Trainer
[9]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    accelerator="cpu",
    max_epochs=100,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[10]:
trainer.fit(mc_dropout_module, dm)
  | Name          | Type             | Params | Mode  | FLOPs
-------------------------------------------------------------------
0 | model         | MLP              | 5.4 K  | train | 0
1 | loss_fn       | CrossEntropyLoss | 0      | train | 0
2 | train_metrics | MetricCollection | 0      | train | 0
3 | val_metrics   | MetricCollection | 0      | train | 0
4 | test_metrics  | MetricCollection | 0      | train | 0
-------------------------------------------------------------------
5.4 K     Trainable params
0         Non-trainable params
5.4 K     Total params
0.021     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode
0         Total Flops
`Trainer.fit` stopped: `max_epochs=100` reached.
Training Metrics
[11]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainAcc"]
)
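`plot_training_metrics` is given the `lightning_logs` directory that `CSVLogger` writes. To inspect the logged values directly, here is a minimal sketch that pulls one metric column out of the logger's `metrics.csv`. It assumes CSVLogger's usual wide format, where each metric is a column and a cell is left blank when that metric was not logged on a given row. `read_metric` is a hypothetical helper that uses only the stdlib `csv` module.

```python
import csv


def read_metric(lines, key):
    """Hypothetical helper: non-empty values of one column from CSVLogger-style CSV content.

    `lines` can be an open file or any iterable of CSV lines. Rows where `key`
    was not logged have an empty cell (or no cell at all) and are skipped.
    """
    values = []
    for row in csv.DictReader(lines):
        cell = row.get(key)
        if cell:  # skips both "" and None
            values.append(float(cell))
    return values


# Illustrative content only; the real file's columns depend on what was logged.
sample = [
    "epoch,step,train_loss,trainAcc",
    "0,0,0.70,",
    "0,0,,0.55",
    "0,1,0.62,",
]
print(read_metric(sample, "train_loss"))  # [0.7, 0.62]
```

Assuming CSVLogger's default directory layout, this run's file would be at `os.path.join(my_temp_dir, "lightning_logs", "version_0", "metrics.csv")` and could be read with `with open(path, newline="") as f: losses = read_metric(f, "train_loss")`.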
Prediction
[12]:
# evaluate the trained model on the test set and report the test metrics
trainer.test(mc_dropout_module, dm.test_dataloader())
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
testAcc 1.0
testCalibration 0.00033162086037918925
testEmpirical Coverage 1.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[12]:
[{'testAcc': 1.0,
  'testCalibration': 0.00033162086037918925,
  'testEmpirical Coverage': 1.0}]
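`testCalibration` is a calibration error, which measures how far the model's confidence is from its actual accuracy. Below is a minimal sketch of the common binned expected calibration error (ECE). It assumes equal-width confidence bins and the L1 norm, which are the defaults of torchmetrics' `CalibrationError`. The exact metric configuration used by the library is not shown here, and `expected_calibration_error` is a hypothetical helper.

```python
def expected_calibration_error(confidences, correct, n_bins=15):
    """Hypothetical helper: binned ECE = weighted mean of |accuracy - confidence| per bin.

    confidences: max softmax probability per sample, in [0, 1].
    correct: 1 if the predicted class was right, else 0.
    """
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        # bins are [0, 1/n_bins), [1/n_bins, 2/n_bins), ...; conf == 1.0 goes into the last bin
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, hit))
    ece = 0.0
    for b in bins:
        if not b:
            continue
        avg_conf = sum(c for c, _ in b) / len(b)
        accuracy = sum(h for _, h in b) / len(b)
        ece += len(b) / n * abs(accuracy - avg_conf)  # weight each bin by its share of samples
    return ece


print(expected_calibration_error([0.9] * 10, [1] * 9 + [0]))  # ≈ 0: 90% confident, 90% right
print(expected_calibration_error([0.9] * 10, [1] * 10))  # ≈ 0.1: underconfident by 0.1
```

Under this definition, when the test accuracy is 1.0 every non-empty bin has accuracy 1, so the ECE reduces to one minus the mean confidence. The reported value of about 0.00033 would then correspond to an average confidence of roughly 0.9997 on the test points.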
[13]:
preds = mc_dropout_module.predict_step(test_grid_points)
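`predict_step` returns a dictionary. This notebook uses its `pred` entry, whose argmax gives the predicted label, and its `pred_uct` entry, which holds the per-point uncertainty. Below is a minimal stdlib sketch of how such outputs can be derived from logits. It assumes that `pred` holds softmax probabilities and that `pred_uct` is their entropy, as described at the top of the notebook. The helper names are hypothetical, and this is not the library's implementation.

```python
import math


def softmax(logits):
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    # natural-log entropy; terms with p == 0 contribute 0 by convention
    return -sum(p * math.log(p) for p in probs if p > 0)


def predict_with_entropy(logits_batch):
    """Hypothetical deterministic prediction: one forward pass, softmax, entropy.

    The dictionary keys mirror the ones this notebook reads from predict_step.
    """
    probs = [softmax(row) for row in logits_batch]
    return {"pred": probs, "pred_uct": [entropy(p) for p in probs]}


out = predict_with_entropy([[3.0, -3.0], [0.1, 0.0]])
labels = [max(range(len(p)), key=p.__getitem__) for p in out["pred"]]  # like .argmax(-1)
print(labels)  # [0, 0]
```

Both points get the same label, but the second is nearly undecided. Its entropy is close to the two-class maximum of log 2, while the first point's entropy is close to 0. In PyTorch, the same entropy can be obtained with `torch.distributions.Categorical(logits=logits).entropy()`.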
Evaluate Predictions
[14]:
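# predicted label = argmax over the class dimension; preds["pred_uct"] holds the per-point uncertainty
fig = plot_predictions_classification(
    X_test, Y_test, preds["pred"].argmax(-1), test_grid_points, preds["pred_uct"]
)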
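The decision-surface plot works by evaluating the model on `test_grid_points`, a dense set of 2D inputs that covers the data. This notebook does not show how the datamodule builds that grid. Below is a minimal sketch of a regular grid over a bounding box, with hypothetical bounds and resolution. The grid is flattened row by row, so per-point predictions or uncertainties can be reshaped back into an image for contour plotting. `make_grid` and `to_rows` are hypothetical helpers.

```python
def make_grid(x_min, x_max, y_min, y_max, n):
    """Hypothetical n x n grid of (x, y) points, flattened row by row (y outer, x inner)."""
    xs = [x_min + (x_max - x_min) * i / (n - 1) for i in range(n)]
    ys = [y_min + (y_max - y_min) * j / (n - 1) for j in range(n)]
    return [(x, y) for y in ys for x in xs]


def to_rows(values, n):
    """Reshape a flat list of n * n per-point values into n rows of n values."""
    return [values[r * n:(r + 1) * n] for r in range(n)]


# Hypothetical bounds and resolution chosen for illustration only.
grid = make_grid(-2.0, 3.0, -1.5, 2.0, n=5)
# Row r of to_rows(values, 5) holds the values for the grid points with the r-th y coordinate.
print(grid[0], grid[-1])  # (-2.0, -1.5) (3.0, 2.0)
```

A finer grid gives a smoother uncertainty map at the cost of more forward passes. For a deterministic model, the high-entropy band in such a map typically traces the learned decision boundary.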