# Source code for lightning_uq_box.datamodules.toy_image_regression
# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
"""Toy Image Regression Datamodule."""
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from lightning_uq_box.datasets import ToyImageRegressionDataset
# [docs]
class ToyImageRegressionDatamodule(LightningDataModule):
    """Datamodule serving the toy image regression dataset, meant for tests.

    Every dataloader method (train, val, calib, test) builds a new
    ``ToyImageRegressionDataset`` with its default arguments and wraps it in
    a ``DataLoader`` that only has ``batch_size`` configured.
    """

    def __init__(self, batch_size: int = 10) -> None:
        """Set up the datamodule and remember the batch size.

        Args:
            batch_size: number of samples per batch, passed on to every
                dataloader created by this datamodule
        """
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self) -> DataLoader:
        """Build the dataloader used for training."""
        train_dataset = ToyImageRegressionDataset()
        return DataLoader(train_dataset, batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        """Build the dataloader used for validation."""
        val_dataset = ToyImageRegressionDataset()
        return DataLoader(val_dataset, batch_size=self.batch_size)

    def calib_dataloader(self) -> DataLoader:
        """Build the dataloader used for calibration."""
        calib_dataset = ToyImageRegressionDataset()
        return DataLoader(calib_dataset, batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        """Build the dataloader used for testing."""
        test_dataset = ToyImageRegressionDataset()
        return DataLoader(test_dataset, batch_size=self.batch_size)