# Copyright (c) 2023 lightning-uq-box. All rights reserved.
# Licensed under the Apache License 2.0.
"""CARD Regression Diffusion Model.
Based on official PyTorch implementation from https://github.com/XzwHan/CARD # noqa: E501
"""
import math
import os
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from ema_pytorch import EMA
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import Tensor
from .base import BaseModule
from .utils import (
_get_num_outputs,
default_classification_metrics,
default_regression_metrics,
process_classification_prediction,
save_classification_predictions,
save_regression_predictions,
)
[docs]
class CARDBase(BaseModule):
"""CARD Model.
Diffusion Model based on CARD paper.
If you use this in your research, please cite the following paper:
* https://arxiv.org/abs/2206.07275
"""
pred_file_name = "predictions.csv"
[docs]
def __init__(
self,
cond_mean_model: nn.Module,
guidance_model: nn.Module,
n_steps: int = 1000,
beta_schedule: str = "linear",
beta_start: float = 1e-5,
beta_end: float = 1e-2,
n_z_samples: int = 100,
ema_decay: float = 0.995,
ema_update_every: float = 10,
ema_update_after_step: int = 0,
guidance_optim: OptimizerCallable = torch.optim.Adam,
lr_scheduler: LRSchedulerCallable = None,
) -> None:
"""Initialize a new instance of the CARD Model.
Args:
cond_mean_model: conditional mean model, should be
pretrained model that estimates $E[y|x]$
guidance_model: guidance diffusion model
n_steps: number of diffusion steps
beta_schedule: what type of noise scheduling to conduct
beta_start: start value of beta scheduling
beta_end: end value of beta scheduling
n_z_samples: number of samples during prediction
ema_decay: exponential moving average decay
ema_update_every: How often to update the EMA model,
in terms of every n gradient steps.
ema_update_after_step: after which step to start updating the EMA model
guidance_optim: optimizer for the guidance model
lr_scheduler: learning rate scheduler
.. versionchanged:: 0.2.0
Added arguments `ema_decay`, `ema_update_every`, `ema_update_after_step` for EMA support.
"""
super().__init__()
self.cond_mean_model = cond_mean_model
self.guidance_model = guidance_model
self.n_steps = n_steps
self.n_z_samples = n_z_samples
self.noise_scheduler = NoiseScheduler(
beta_schedule, n_steps, beta_start, beta_end
)
self.guidance_optim = guidance_optim
self.lr_scheduler = lr_scheduler
self.ema_decay = ema_decay
self.ema_update_every = ema_update_every
self.ema_update_after_step = ema_update_after_step
self.ema = EMA(
self.guidance_model,
beta=self.ema_decay,
update_after_step=self.ema_update_after_step,
update_every=self.ema_update_every,
)
self.use_ema_model = False
self.setup_task()
[docs]
def setup_task(self) -> None:
"""Setup task specific attributes."""
pass
[docs]
def diffusion_process(self, batch: dict[str, Tensor]) -> Tensor:
"""Diffusion process during training.
Args:
batch: the output of your DataLoader
Returns:
loss from diffusion process
"""
x, y = batch[self.input_key], batch[self.target_key].float()
batch_size = x.shape[0]
# antithetic sampling
ant_samples_t = torch.randint(
low=0, high=self.n_steps, size=(batch_size // 2 + 1,)
).to(x.device)
ant_samples_t = torch.cat(
[ant_samples_t, self.n_steps - 1 - ant_samples_t], dim=0
)[:batch_size]
# noise estimation loss
y_0_hat = self.cond_mean_model(x)
e = torch.randn_like(y)
y_t_sample = self.q_sample(
y,
y_0_hat,
self.noise_scheduler.alphas_bar_sqrt.to(self.device),
self.noise_scheduler.one_minus_alphas_bar_sqrt.to(self.device),
ant_samples_t,
noise=e,
)
if self.use_ema_model:
guidance_output = self.ema.ema_model(x, y_t_sample, y_0_hat, ant_samples_t)
else:
guidance_output = self.guidance_model(x, y_t_sample, y_0_hat, ant_samples_t)
# in classification y usually don't have target dimension
# but in regression they do so for broadcasting align them
if e.dim() == 1:
e = e.unsqueeze(-1)
# TODO does this change?
# use the same noise sample e during training to compute loss
loss = (e - guidance_output).square().mean()
return loss, y_t_sample
[docs]
def on_after_backward(self):
"""Update EMA after each backward pass."""
self.ema.update()
[docs]
def training_step(
self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute and return the training loss.
Args:
batch: the output of your DataLoader
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
training loss
"""
self.use_ema_model = False
loss, y_t_sample = self.diffusion_process(batch)
self.log("train_loss", loss, batch_size=batch[self.input_key].shape[0])
return loss
# TODO what metrics should be logged?
# def on_train_epoch_end(self):
# """Log epoch-level training metrics."""
# self.log_dict(self.train_metrics.compute())
# self.train_metrics.reset()
[docs]
def validation_step(
self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute and return the validation loss.
Args:
batch: the output of your DataLoader
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
validation loss
"""
self.use_ema_model = True
loss, y_t_sample = self.diffusion_process(batch)
self.log("val_loss", loss, batch_size=batch[self.input_key].shape[0])
self.use_ema_model = False
return loss
# def on_validation_epoch_end(self) -> None:
# """Log epoch level validation metrics."""
# self.log_dict(self.val_metrics.compute())
# self.val_metrics.reset()
# def on_test_epoch_end(self):
# """Log epoch-level test metrics."""
# self.log_dict(self.test_metrics.compute())
# self.test_metrics.reset()
[docs]
def predict_step(
self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
) -> dict[str, np.ndarray]:
"""Prediction step.
Args:
X: prediction batch of shape [batch_size x input_dims]
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
diffusion samples for each time step
"""
self.use_ema_model = True
# compute y_0_hat only once as the initial prediction
with torch.no_grad():
y_0_hat = self.cond_mean_model(X)
if X.dim() == 2:
# TODO: This works for Vector 1D Regression with the tiling
# y_0_tile = torch.tile(y, (n_z_samples, 1))
y_0_hat_tile = torch.tile(y_0_hat, (self.n_z_samples, 1)).to(
self.device
)
test_x_tile = torch.tile(X, (self.n_z_samples, 1)).to(self.device)
z = torch.randn_like(y_0_hat_tile).to(self.device)
# TODO check what happens, here and why y_0_hat_tile is passed twice
y_t = y_0_hat_tile + z
# generate samples from all time steps for the mini-batch
y_tile_seq: list[Tensor] = self.p_sample_loop(
test_x_tile,
y_0_hat_tile,
y_0_hat_tile,
self.n_steps,
self.noise_scheduler.alphas.to(self.device),
self.noise_scheduler.one_minus_alphas_bar_sqrt.to(self.device),
)
# put in shape [n_z_samples, batch_size, output_dimension]
y_tile_seq = [
arr.reshape(self.n_z_samples, X.shape[0], y_t.shape[-1])
for arr in y_tile_seq
]
final_recoverd = y_tile_seq[-1]
else:
# TODO make this more efficient
y_tile_seq: list[Tensor] = [
self.p_sample_loop(
X,
y_0_hat,
y_0_hat,
self.n_steps,
self.noise_scheduler.alphas.to(self.device),
self.noise_scheduler.one_minus_alphas_bar_sqrt.to(self.device),
)[-1]
for i in range(self.n_z_samples)
]
final_recoverd = torch.stack(y_tile_seq, dim=0)
self.use_ema_model = False
return final_recoverd, y_tile_seq
[docs]
def p_sample(
self,
x: Tensor,
y: Tensor,
y_0_hat: Tensor,
y_T_mean: Tensor,
t: int,
alphas: Tensor,
one_minus_alphas_bar_sqrt: Tensor,
) -> Tensor:
"""Reverse diffusion process sampling, one time step.
This is the process of generating a sample from the model's prior distribution
and then evolving it through the diffusion process. It starts from the final
time step and goes backwards to the initial time step. At each time step,
a noise variable is sampled and the state is updated according to the
reverse diffusion process.
Args:
x: input features
y: sampled y at time step t, y_t.
y_0_hat: prediction of pre-trained guidance model.
y_T_mean: mean of prior distribution at timestep T.
t: time step
alphas: noise schedule alpha
one_minus_alphas_bar_sqrt: noise schedule one minus alpha sqrt
Returns:
reverse process sample
"""
z = torch.randn_like(y) # if t > 1 else torch.zeros_like(y)
t = torch.tensor([t]).to(self.device)
alpha_t = self.extract(alphas, t, y)
sqrt_one_minus_alpha_bar_t = self.extract(one_minus_alphas_bar_sqrt, t, y)
sqrt_one_minus_alpha_bar_t_m_1 = self.extract(
one_minus_alphas_bar_sqrt, t - 1, y
)
sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
sqrt_alpha_bar_t_m_1 = (1 - sqrt_one_minus_alpha_bar_t_m_1.square()).sqrt()
# y_t_m_1 posterior mean component coefficients
gamma_0 = (
(1 - alpha_t) * sqrt_alpha_bar_t_m_1 / (sqrt_one_minus_alpha_bar_t.square())
)
gamma_1 = (
(sqrt_one_minus_alpha_bar_t_m_1.square())
* (alpha_t.sqrt())
/ (sqrt_one_minus_alpha_bar_t.square())
)
gamma_2 = 1 + (sqrt_alpha_bar_t - 1) * (
alpha_t.sqrt() + sqrt_alpha_bar_t_m_1
) / (sqrt_one_minus_alpha_bar_t.square())
if self.use_ema_model:
eps_theta = self.ema.ema_model(x, y, y_0_hat, t).detach()
else:
eps_theta = self.guidance_model(x, y, y_0_hat, t).detach()
# y_0 reparameterization
y_0_reparam = (
1
/ sqrt_alpha_bar_t
* (
y
- (1 - sqrt_alpha_bar_t) * y_T_mean
- eps_theta * sqrt_one_minus_alpha_bar_t
)
)
# posterior mean
y_t_m_1_hat = gamma_0 * y_0_reparam + gamma_1 * y + gamma_2 * y_T_mean
# posterior variance
beta_t_hat = (
(sqrt_one_minus_alpha_bar_t_m_1.square())
/ (sqrt_one_minus_alpha_bar_t.square())
* (1 - alpha_t)
)
y_t_m_1 = y_t_m_1_hat.to(self.device) + beta_t_hat.sqrt().to(
self.device
) * z.to(self.device)
return y_t_m_1
# Reverse function -- sample y_0 given y_1
[docs]
def p_sample_t_1to0(
self,
x: Tensor,
y: Tensor,
y_0_hat: Tensor,
y_T_mean: Tensor,
one_minus_alphas_bar_sqrt: Tensor,
) -> Tensor:
"""Reverse sample function, sample y_0 given y_1.
Args:
x: input
y: sampled y at time step t, y_t.
y_0_hat: prediction of pre-trained guidance model.
y_T_mean: mean of prior distribution at timestep T.
one_minus_alphas_bar_sqrt: noise schedule one minus alpha bar sqrt
Returns:
y_0 sample
"""
# corresponding to timestep 1 (i.e., t=1 in diffusion models)
t = torch.tensor([0]).to(self.device)
sqrt_one_minus_alpha_bar_t = self.extract(one_minus_alphas_bar_sqrt, t, y)
sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
if self.use_ema_model:
eps_theta = self.ema.ema_model(x, y, y_0_hat, t).detach()
else:
eps_theta = self.guidance_model(x, y, y_0_hat, t).detach()
# y_0 reparameterization
y_0_reparam = (
1
/ sqrt_alpha_bar_t
* (
y
- (1 - sqrt_alpha_bar_t) * y_T_mean
- eps_theta * sqrt_one_minus_alpha_bar_t
)
)
y_t_m_1 = y_0_reparam.to(self.device)
return y_t_m_1
[docs]
def p_sample_loop(
self,
x: Tensor,
y_0_hat: Tensor,
y_T_mean: Tensor,
n_steps: int,
alphas: Tensor,
one_minus_alphas_bar_sqrt: Tensor,
only_last_sample: bool = False,
) -> list[Tensor]:
"""P sample loop for the entire chain.
Args:
x: input
y_0_hat: prediction of pre-trained guidance model.
y_T_mean: mean of prior distribution at timestep T.
n_steps: number of diffusion steps
alphas: noise schedule alpha
one_minus_alphas_bar_sqrt: noise schedule one minus alpha
only_last_sample: whether to only return the last sample
Returns:
list of samples for each diffusion time step
"""
num_t, y_p_seq = None, None
z = torch.randn_like(y_T_mean).to(self.device)
cur_y = z + y_T_mean # sampled y_T
if only_last_sample:
num_t = 1
else:
y_p_seq = [cur_y]
for t in reversed(range(1, n_steps)):
y_t = cur_y
cur_y = self.p_sample(
x, y_t, y_0_hat, y_T_mean, t, alphas, one_minus_alphas_bar_sqrt
) # y_{t-1}
if only_last_sample:
num_t += 1
else:
y_p_seq.append(cur_y)
if only_last_sample:
assert num_t == n_steps
y_0 = self.p_sample_t_1to0(
x, cur_y, y_0_hat, y_T_mean, one_minus_alphas_bar_sqrt
)
return y_0
else:
assert len(y_p_seq) == n_steps
y_0 = self.p_sample_t_1to0(
x, y_p_seq[-1], y_0_hat, y_T_mean, one_minus_alphas_bar_sqrt
)
y_p_seq.append(y_0)
return y_p_seq
[docs]
def q_sample(
self,
y: Tensor,
y_0_hat: Tensor,
alphas_bar_sqrt: Tensor,
one_minus_alphas_bar_sqrt: Tensor,
t: int,
noise: Tensor | None = None,
) -> Tensor:
"""Q sampling process.
This is the process of approximating the posterior distribution of the
latent variables given the observed data. It starts from the initial
time step and goes forward to the final time step. At each time step,
a noise variable is sampled and the state is updated according to the
forward diffusion process.
Args:
y: sampled y at time step t, y_t.
y_0_hat: prediction of pre-trained guidance model.
alphas_bar_sqrt: noise schedule alpha bar
one_minus_alphas_bar_sqrt: noise schedule one minus alpha bar
t: time step
noise: optional noise tensor
Returns:
q sample at time step t
"""
if y.dim() == 1:
y = y.unsqueeze(-1)
if noise is None:
noise = torch.randn_like(y)
elif noise.shape != y.shape:
noise = noise.unsqueeze(-1)
sqrt_alpha_bar_t = self.extract(alphas_bar_sqrt, t, y)
sqrt_one_minus_alpha_bar_t = self.extract(one_minus_alphas_bar_sqrt, t, y)
# q(y_t | y_0, x)
# add feature dimension for proper broadcasting
y_t = (
sqrt_alpha_bar_t * y
+ (1 - sqrt_alpha_bar_t) * y_0_hat
+ sqrt_one_minus_alpha_bar_t * noise
)
return y_t
[docs]
def test_step(
self, batch: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute and return the test loss.
Args:
batch: the output of your DataLoader
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
test loss
"""
out_dict = self.predict_step(batch[self.input_key])
out_dict[self.target_key] = batch[self.target_key].detach().squeeze(-1).cpu()
# turn mean to np array
out_dict["pred"] = out_dict["pred"].detach().cpu().squeeze(-1)
out_dict["pred_uct"] = out_dict["pred_uct"].detach().cpu().squeeze(-1)
if "aleatoric_uct" in out_dict:
out_dict["aleatoric_uct"] = (
out_dict["aleatoric_uct"].detach().cpu().squeeze(-1)
)
# save metadata
out_dict = self.add_aux_data_to_dict(out_dict, batch)
return out_dict
[docs]
class CARDRegression(CARDBase):
"""CARD Regression Model.
If you use this in your research, please cite the following paper:
* https://arxiv.org/abs/2206.07275
"""
[docs]
def setup_task(self) -> None:
"""Setup task specific attributes."""
self.train_metrics = default_regression_metrics("train")
self.val_metrics = default_regression_metrics("val")
self.test_metrics = default_regression_metrics("test")
[docs]
def predict_step(
self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
) -> dict[str, Tensor]:
"""Prediction step.
Args:
X: prediction batch of shape [batch_size x input_dims]
batch_idx: batch index
dataloader_idx: dataloader index
Returns:
prediction dictionary with uncertainty estimates and samples
"""
final_recoverd, y_tile_seq = super().predict_step(X, batch_idx, dataloader_idx)
# momenet matching
mean_pred = final_recoverd.mean(dim=0).detach().cpu().squeeze()
std_pred = final_recoverd.std(dim=0).detach().cpu().squeeze()
return {
"pred": mean_pred,
"pred_uct": std_pred,
"aleatoric_uct": std_pred,
"samples": y_tile_seq,
}
[docs]
def on_test_batch_end(
self,
outputs: dict[str, np.ndarray],
batch: Any,
batch_idx: int,
dataloader_idx=0,
):
"""Test batch end save predictions."""
del outputs["samples"]
save_regression_predictions(
outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
)
[docs]
class CARDClassification(CARDBase):
"""CARD Classification Model.
If you use this in your research, please cite the following paper:
* https://arxiv.org/abs/2206.07275
"""
valid_tasks = ["binary", "multiclass"]
[docs]
def __init__(
self,
cond_mean_model: nn.Module,
guidance_model: nn.Module,
n_steps: int = 1000,
beta_schedule: str = "linear",
beta_start: float = 0.00001,
beta_end: float = 0.01,
n_z_samples: int = 100,
task: str = "multiclass",
ema_decay: float = 0.995,
ema_update_every: float = 10,
ema_update_after_step: int = 0,
guidance_optim: OptimizerCallable = torch.optim.Adam,
lr_scheduler: LRSchedulerCallable = None,
) -> None:
"""Initialize a new instance of the CARD Classification.
Args:
cond_mean_model: conditional mean model, should be
pretrained model that estimates $E[y|x]$
guidance_model: guidance diffusion model
n_steps: number of diffusion steps
beta_schedule: what type of noise scheduling to conduct
beta_start: start value of beta scheduling
beta_end: end value of beta scheduling
n_z_samples: number of samples during prediction
task: classification task, either `binary` or `multiclass`
ema_decay: exponential moving average decay
ema_update_every: How often to update the EMA model,
in terms of every n gradient steps.
ema_update_after_step: after which step to start updating the EMA model
guidance_optim: optimizer for the guidance model
lr_scheduler: learning rate scheduler
.. versionchanged:: 0.2.0
Added arguments `ema_decay`, `ema_update_every`, `ema_update_after_step` for EMA support.
"""
assert task in self.valid_tasks
self.task = task
self.num_classes = _get_num_outputs(cond_mean_model)
super().__init__(
cond_mean_model,
guidance_model,
n_steps,
beta_schedule,
beta_start,
beta_end,
n_z_samples,
ema_decay,
ema_update_every,
ema_update_after_step,
guidance_optim,
lr_scheduler,
)
self.save_hyperparameters(
ignore=[
"cond_mean_model",
"guidance_model",
"guidance_optim",
"lr_scheduler",
]
)
[docs]
def setup_task(self) -> None:
"""Setup task specific attributes."""
self.train_metrics = default_classification_metrics(
"train", self.task, self.num_classes
)
self.val_metrics = default_classification_metrics(
"val", self.task, self.num_classes
)
self.test_metrics = default_classification_metrics(
"test", self.task, self.num_classes
)
[docs]
def predict_step(
self, X: Tensor, batch_idx: int = 0, dataloader_idx: int = 0
) -> dict[str, Tensor]:
"""Prediction step.
Args:
X: prediction batch of shape [batch_size x input_dims]
batch_idx: the index of this batch
dataloader_idx: the index of the dataloader
Returns:
predictions
"""
final_recoverd, y_tile_seq = super().predict_step(X, batch_idx, dataloader_idx)
# change from [num_samples, ...] to shape [batch_size, num_classes, num_samples]
final_recoverd = final_recoverd.permute(1, 2, 0).cpu()
# momenet matching
pred_dict = process_classification_prediction(final_recoverd)
pred_dict["samples"] = y_tile_seq
return pred_dict
[docs]
def on_test_batch_end(
self, outputs: dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Test batch end save predictions.
Args:
outputs: dictionary of model outputs and aux variables
batch_idx: batch index
dataloader_idx: dataloader index
"""
del outputs["samples"]
save_classification_predictions(
outputs, os.path.join(self.trainer.default_root_dir, self.pred_file_name)
)
[docs]
class NoiseScheduler:
"""Noise Scheduler for Diffusion Training."""
valid_schedules = [
"linear",
"const",
"quad",
"jsd",
"sigmoid",
"cosine",
"cosine_anneal",
]
[docs]
def __init__(
self,
schedule: str = "linear",
n_steps: int = 1000,
beta_start: float = 1e-5,
beta_end: float = 1e-2,
) -> None:
"""Initialize a new instance of the noise scheduler.
Args:
schedule: type of noise schedule
n_steps: number of diffusion time steps
beta_start: beta noise start value
beta_end: beta noise end value
Raises:
AssertionError if schedule is invalid
"""
assert schedule in self.valid_schedules, (
f"Invalid schedule, please choose one of {self.valid_schedules}."
)
self.schedule = schedule
self.n_steps = n_steps
self.beta_start = beta_start
self.beta_end = beta_end
self.betas = {
"linear": self.linear_schedule(),
"const": self.constant_schedule(),
"quad": self.quadratic_schedule(),
"sigmoid": self.sigmoid_schedule(),
"cosine": self.cosine_schedule(),
"cosine_anneal": self.cosine_anneal_schedule(),
}[schedule]
self.betas_sqrt = torch.sqrt(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = self.alphas.cumprod(dim=0)
self.alphas_bar_sqrt = torch.sqrt(self.alphas_cumprod)
self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_cumprod)
[docs]
def linear_schedule(self) -> Tensor:
"""Linear Schedule."""
return torch.linspace(self.beta_start, self.beta_end, self.n_steps)
[docs]
def constant_schedule(self) -> Tensor:
"""Constant Schedule."""
return self.beta_end * torch.ones(self.n_steps)
[docs]
def quadratic_schedule(self) -> Tensor:
"""Quadratic Schedule."""
return (
torch.linspace(self.beta_start**0.5, self.beta_end**0.5, self.n_steps) ** 2
)
[docs]
def sigmoid_schedule(self) -> Tensor:
"""Sigmoid Schedule."""
betas = (
torch.sigmoid(torch.linspace(-6, 6, self.n_steps))
* (self.beta_end - self.beta_start)
+ self.beta_start
)
return torch.sigmoid(betas)
[docs]
def cosine_schedule(self) -> Tensor:
"""Cosine Schedule."""
max_beta = 0.999
cosine_s = 0.008
return torch.tensor(
[
min(
1
- (
math.cos(
((i + 1) / self.n_steps + cosine_s)
/ (1 + cosine_s)
* math.pi
/ 2
)
** 2
)
/ (
math.cos(
(i / self.n_steps + cosine_s) / (1 + cosine_s) * math.pi / 2
)
** 2
),
max_beta,
)
for i in range(self.n_steps)
]
)
[docs]
def cosine_anneal_schedule(self) -> Tensor:
"""Cosine Annealing Schedule."""
return torch.tensor(
[
self.beta_start
+ 0.5
* (self.beta_end - self.beta_start)
* (1 - math.cos(t / (self.n_steps - 1) * math.pi))
for t in range(self.n_steps)
]
)
[docs]
def get_noisy_x_at_t(input, t, x) -> Tensor:
"""Retrieve a noisy representation at time step t.
Args:
input: schedule version
t: time step
x: tensor ot make noisy version of
Returns:
A noisy
"""
shape = x.shape
out = torch.gather(input, 0, t.to(input.device))
reshape = [t.shape[0]] + [1] * (len(shape) - 1)
return out.reshape(*reshape)