Masksemble

[1]:
%pip install git+https://github.com/lightning-uq-box/lightning-uq-box.git
Collecting git+https://github.com/lightning-uq-box/lightning-uq-box.git
  Cloning https://github.com/lightning-uq-box/lightning-uq-box.git to /tmp/pip-req-build-ezwvqb3r
  Running command git clone --filter=blob:none --quiet https://github.com/lightning-uq-box/lightning-uq-box.git /tmp/pip-req-build-ezwvqb3r
  Resolved https://github.com/lightning-uq-box/lightning-uq-box.git to commit bee3172f68a5a21cd63c7996f4dcbce30c40adf6
Note: you may need to restart the kernel to use updated packages.

Imports

[2]:
import os
import tempfile
from functools import partial

import matplotlib.pyplot as plt
import torch.nn as nn
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger
from torch.optim import Adam

from lightning_uq_box.datamodules import TwoMoonsDataModule
from lightning_uq_box.models import MLP
from lightning_uq_box.uq_methods import MasksemblesClassification
from lightning_uq_box.viz_utils import (
    plot_predictions_classification,
    plot_training_metrics,
    plot_two_moons_data,
)

plt.rcParams["figure.figsize"] = [14, 5]

%load_ext autoreload
%autoreload 2
INFO:root:Asdfghjkl backend not available since the old asdfghjkl dependency is not installed. If you want to use it, run: pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl
[3]:
seed_everything(0)
Seed set to 0
[3]:
0
[4]:
my_temp_dir = tempfile.mkdtemp()

Datamodule

[5]:
dm = TwoMoonsDataModule(batch_size=128)

X_train, Y_train, X_test, Y_test, test_grid_points = (
    dm.X_train,
    dm.Y_train,
    dm.X_test,
    dm.Y_test,
    dm.test_grid_points,
)
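The datamodule exposes training and test splits plus `test_grid_points`, a set of inputs used to query the model across the whole input plane. As a rough, self-contained sketch (an assumption about what such data looks like, not the internals of `TwoMoonsDataModule`), two noisy interleaving half-circles and a flattened evaluation grid can be built with NumPy. `make_two_moons` and `make_grid` are hypothetical helpers.

```python
import numpy as np


def make_two_moons(n_per_moon: int = 250, noise: float = 0.1, seed: int = 0):
    """Hypothetical helper: two interleaving half-circles with Gaussian noise."""
    rng = np.random.default_rng(seed)
    t = np.linspace(0.0, np.pi, n_per_moon)
    upper = np.stack([np.cos(t), np.sin(t)], axis=1)              # class 0
    lower = np.stack([1.0 - np.cos(t), 0.5 - np.sin(t)], axis=1)  # class 1
    X = np.concatenate([upper, lower])
    X = X + rng.normal(scale=noise, size=X.shape)
    y = np.concatenate([np.zeros(n_per_moon, dtype=int), np.ones(n_per_moon, dtype=int)])
    return X, y


def make_grid(X: np.ndarray, steps: int = 100, margin: float = 1.0) -> np.ndarray:
    """Hypothetical helper: (steps * steps, 2) grid covering the data plus a margin."""
    xs = np.linspace(X[:, 0].min() - margin, X[:, 0].max() + margin, steps)
    ys = np.linspace(X[:, 1].min() - margin, X[:, 1].max() + margin, steps)
    gx, gy = np.meshgrid(xs, ys)
    return np.stack([gx.ravel(), gy.ravel()], axis=1)


X, y = make_two_moons()
grid = make_grid(X)
print(X.shape, y.shape, grid.shape)  # (500, 2) (500,) (10000, 2)
```

Evaluating the model on a grid like this, and not only on the test points, is what makes it possible to plot predictions and uncertainty over the whole plane, including regions far from the training data.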
[6]:
fig = plot_two_moons_data(X_train, Y_train, X_test, Y_test)
[Figure: output of plot_two_moons_data for the training and test splits]

Model

[7]:
network = MLP(
    n_inputs=2,
    n_hidden=[50, 50, 50],
    n_outputs=2,
    dropout_p=0.2,
    activation_fn=nn.ReLU(),
)
network
[7]:
MLP(
  (model): Sequential(
    (0): Linear(in_features=2, out_features=50, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=50, out_features=50, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=50, out_features=50, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.2, inplace=False)
    (9): Linear(in_features=50, out_features=2, bias=True)
  )
)
[8]:
masksemble = MasksemblesClassification(
    model=network,
    optimizer=partial(Adam, lr=1e-2),
    loss_fn=nn.CrossEntropyLoss(),
    num_estimators=5,
    scale=3.5,
)
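`num_estimators` sets how many binary masks are used and `scale` controls how much those masks overlap. The sketch below is a simplified take on the Masksembles mask-sampling idea, not the code inside `MasksemblesClassification`: each mask keeps `active` units out of `int(active * scale)` candidate positions, and positions that no mask uses are dropped. `sample_masks` and `mean_pairwise_iou` are hypothetical helpers.

```python
import numpy as np


def sample_masks(active: int, num_masks: int, scale: float, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: draw `num_masks` binary masks that each keep `active` units.

    Each mask picks `active` of int(active * scale) candidate positions without
    replacement; positions that no mask selected are then dropped.
    """
    rng = np.random.default_rng(seed)
    total = int(active * scale)
    masks = np.zeros((num_masks, total))
    for i in range(num_masks):
        masks[i, rng.choice(total, size=active, replace=False)] = 1.0
    return masks[:, masks.any(axis=0)]


def mean_pairwise_iou(masks: np.ndarray) -> float:
    """Average intersection-over-union over all pairs of masks."""
    ious = []
    for i in range(len(masks)):
        for j in range(i + 1, len(masks)):
            inter = np.logical_and(masks[i], masks[j]).sum()
            union = np.logical_or(masks[i], masks[j]).sum()
            ious.append(inter / union)
    return float(np.mean(ious))


for s in (1.0, 2.0, 3.5, 6.0):
    m = sample_masks(active=20, num_masks=5, scale=s)
    print(f"scale={s}: width={m.shape[1]}, mean IoU={mean_pairwise_iou(m):.2f}")
```

With `scale=1.0` every mask keeps all positions, so the masks are identical (IoU 1.0) and the model behaves like a single network; larger values of `scale` reduce the overlap and move toward independent ensemble members. This sketch does not match the resulting mask width to the layer width (50 here), which a real layer would need to do.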

Trainer

[9]:
logger = CSVLogger(my_temp_dir)
trainer = Trainer(
    max_epochs=100,  # number of epochs we want to train
    logger=logger,  # log training metrics for later evaluation
    log_every_n_steps=1,
    enable_checkpointing=False,
    enable_progress_bar=False,
    default_root_dir=my_temp_dir,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[10]:
trainer.fit(masksemble, dm)

  | Name          | Type             | Params | Mode
-----------------------------------------------------------
0 | loss_fn       | CrossEntropyLoss | 0      | train
1 | model         | MLP              | 6.1 K  | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
5.4 K     Trainable params
750       Non-trainable params
6.1 K     Total params
0.024     Total estimated model params size (MB)
29        Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=100` reached.
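The summary reports 5.4 K trainable and 750 non-trainable parameters. The trainable count comes from the four `Linear` layers of the MLP. The 750 is consistent with, though the summary does not confirm it, three fixed masks of shape `num_estimators x 50`, one per hidden layer. A quick check of the arithmetic:

```python
import torch.nn as nn

# Rebuild the printed MLP: 2 -> 50 -> 50 -> 50 -> 2, each hidden layer
# followed by ReLU and Dropout (neither has parameters).
sizes = [2, 50, 50, 50]
layers = []
for n_in, n_out in zip(sizes[:-1], sizes[1:]):
    layers += [nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(p=0.2)]
layers.append(nn.Linear(sizes[-1], 2))
mlp = nn.Sequential(*layers)

# Weights + biases of the four Linear layers: 150 + 2550 + 2550 + 102.
trainable = sum(p.numel() for p in mlp.parameters() if p.requires_grad)

# Assumption: one fixed (num_estimators x width) mask per hidden layer.
num_estimators, width, num_hidden = 5, 50, 3
mask_params = num_estimators * width * num_hidden

print(trainable, mask_params, trainable + mask_params)  # 5352 750 6102
```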

Training Metrics

[11]:
fig = plot_training_metrics(
    os.path.join(my_temp_dir, "lightning_logs"), ["train_loss", "trainAcc"]
)
[Figure: training curves for train_loss and trainAcc]

Prediction

[12]:
# run the test loop: log test metrics and save predictions
trainer.test(masksemble, dm.test_dataloader())
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         testAcc                    1.0
     testCalibration       0.0029539649840444326
 testEmpirical Coverage             1.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[12]:
[{'testAcc': 1.0,
  'testCalibration': 0.0029539649840444326,
  'testEmpirical Coverage': 1.0}]
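`testCalibration` summarises how closely the predicted confidences match the observed accuracy. Assuming it is a binned expected calibration error (ECE), as computed by `torchmetrics`' `CalibrationError` (default of 15 bins), a minimal NumPy sketch of the computation is shown below; `expected_calibration_error` is a hypothetical helper and the inputs are toy values.

```python
import numpy as np


def expected_calibration_error(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    """Binned ECE: weighted mean of |accuracy - confidence| over confidence bins."""
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    correct = (predictions == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap  # weight by the fraction of samples in the bin
    return float(ece)


# Toy inputs: fully confident and correct gives ECE 0;
# 80% confident but always correct gives ECE of about 0.2.
labels = np.array([0, 1, 0, 1])
sharp = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
under = np.array([[0.8, 0.2], [0.2, 0.8], [0.8, 0.2], [0.2, 0.8]])
print(expected_calibration_error(sharp, labels))
print(expected_calibration_error(under, labels))
```

Under this reading, the small reported value means the model's confidence on the test set is close to its (perfect) test accuracy.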
[13]:
preds = masksemble.predict_step(test_grid_points)
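`predict_step` returns a dictionary whose `pred` entry holds class scores and whose `pred_uct` entry holds a per-point uncertainty. The toy sketch below illustrates the general Masksembles inference recipe under stated assumptions: replicate the input once per mask, apply mask `i` to copy `i`, average the softmax outputs across copies, and use predictive entropy as the uncertainty. `ToyMaskedLayer` is hypothetical and uses random masks; whether `pred_uct` is exactly this entropy is an assumption, not something the notebook shows.

```python
import torch
import torch.nn as nn


class ToyMaskedLayer(nn.Module):
    """Hypothetical Masksembles-style layer: batch chunk i is multiplied by mask i."""

    def __init__(self, masks: torch.Tensor):
        super().__init__()
        # Fixed binary masks of shape (num_estimators, width); a buffer, so not trained.
        self.register_buffer("masks", masks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = self.masks.shape[0]
        # The batch must be num_estimators * B rows; split it into n equal chunks.
        chunks = x.view(n, -1, x.shape[-1])
        return (chunks * self.masks.unsqueeze(1)).view_as(x)


torch.manual_seed(0)
num_estimators, width = 5, 50
masks = (torch.rand(num_estimators, width) > 0.5).float()
net = nn.Sequential(nn.Linear(2, width), nn.ReLU(), ToyMaskedLayer(masks), nn.Linear(width, 2))

x = torch.randn(8, 2)                # 8 query points
x_rep = x.repeat(num_estimators, 1)  # (5 * 8, 2): one copy of the batch per mask
with torch.no_grad():
    logits = net(x_rep).view(num_estimators, x.shape[0], -1)
probs = logits.softmax(dim=-1)       # (5, 8, 2) per-member class probabilities
mean_probs = probs.mean(dim=0)       # (8, 2) ensemble prediction
entropy = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=-1)
print(mean_probs.argmax(-1), entropy)
```

Averaging probabilities rather than logits is one common design choice: where the members disagree, the mean distribution is flatter and its entropy is higher, which is what makes the uncertainty map informative away from the training data.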

Evaluate Predictions

[14]:
fig = plot_predictions_classification(
    X_test, Y_test, preds["pred"].argmax(-1), test_grid_points, preds["pred_uct"]
)
[Figure: output of plot_predictions_classification showing predicted classes and predictive uncertainty (pred_uct) over test_grid_points]