| | from math import atan, cos, pi, sin, sqrt |
| | from typing import Any, Callable, List, Optional, Tuple, Type |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange, reduce |
| | from torch import Tensor |
| |
|
| | from .utils import * |
| |
|
| | """ |
| | Diffusion Training |
| | """ |
| |
|
| | """ Distributions """ |
| |
|
| |
|
class Distribution:
    """Interface for distributions that sample per-element training noise levels."""

    def __call__(self, num_samples: int, device: torch.device):
        # Subclasses return a 1D tensor of `num_samples` sigmas on `device`.
        raise NotImplementedError()
| |
|
| |
|
class LogNormalDistribution(Distribution):
    """Log-normal sigma distribution: exp(mean + std * N(0, 1))."""

    def __init__(self, mean: float, std: float):
        # Parameters of the underlying normal in log-space.
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        """Draw `num_samples` log-normally distributed sigmas."""
        gaussian = torch.randn((num_samples,), device=device)
        return (self.mean + self.std * gaussian).exp()
| |
|
| |
|
class UniformDistribution(Distribution):
    """Uniform sigma distribution over [0, 1)."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        return torch.rand(num_samples, device=device)
| |
|
| |
|
class VKDistribution(Distribution):
    """Heavy-tailed sigma distribution sampled by inverse-CDF through arctan.

    Sigmas are drawn such that atan(sigma / sigma_data) * 2 / pi is uniform,
    truncated to [min_value, max_value].
    """

    def __init__(
        self,
        min_value: float = 0.0,
        max_value: float = float("inf"),
        sigma_data: float = 1.0,
    ):
        self.min_value = min_value
        self.max_value = max_value
        self.sigma_data = sigma_data

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        """Draw `num_samples` sigmas within [min_value, max_value]."""
        sigma_data = self.sigma_data
        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
        # FIX: inverse transform sampling needs *uniform* samples in
        # [min_cdf, max_cdf]. The previous `torch.randn` produced values
        # outside that range, yielding sigmas outside [min_value, max_value]
        # (including negative sigmas).
        u = (max_cdf - min_cdf) * torch.rand((num_samples,), device=device) + min_cdf
        return torch.tan(u * pi / 2) * sigma_data
| |
|
| |
|
| | """ Diffusion Classes """ |
| |
|
| |
|
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Append `ndim` trailing singleton dimensions to `x` (for broadcasting)."""
    target_shape = tuple(x.shape) + (1,) * ndim
    return x.view(target_shape)
| |
|
| |
|
def clip(x: Tensor, dynamic_threshold: float = 0.0):
    """Clamp `x` to [-1, 1], or apply dynamic thresholding when enabled.

    With dynamic_threshold = q in (0, 1), each batch element is clamped to
    its q-th absolute-value quantile (never below 1) and rescaled by it.
    """
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)
    # Per-batch-element q-th quantile of |x| over all non-batch dims
    flattened = x.flatten(start_dim=1)
    scale = torch.quantile(flattened.abs(), dynamic_threshold, dim=-1)
    # Never shrink values that are already inside [-1, 1]
    scale = scale.clamp(min=1.0)
    # Broadcast the per-element scale back over the non-batch dims
    scale = scale.view(scale.shape[0], *((1,) * (x.ndim - 1)))
    return x.clamp(-scale, scale) / scale
| |
|
| |
|
def to_batch(
    batch_size: int,
    device: torch.device,
    x: Optional[float] = None,
    xs: Optional[Tensor] = None,
) -> Tensor:
    """Return `xs` as-is, or broadcast scalar `x` to a length-`batch_size` tensor."""
    assert exists(x) ^ exists(xs), "Either x or xs must be provided"
    if not exists(xs):
        # Expand the scalar noise level into a 1D batch tensor on `device`
        xs = torch.full(size=(batch_size,), fill_value=x).to(device)
    assert exists(xs)
    return xs
| |
|
| |
|
class Diffusion(nn.Module):
    """Base diffusion class"""

    # FIX: the docstring above was previously placed *after* `alias`, making
    # it a no-op string statement instead of the class docstring.
    # Short identifier used for sampler/diffusion compatibility checks.
    alias: str = ""

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Denoise `x_noisy` at the given noise level(s). Must be overridden."""
        raise NotImplementedError("Diffusion class missing denoise_fn")

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Compute the training loss for clean input `x`. Must be overridden."""
        raise NotImplementedError("Diffusion class missing forward function")
| |
|
| |
|
class VDiffusion(Diffusion):
    """v-objective diffusion training (trigonometric signal/noise mixing)."""

    alias = "v"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Map sigmas in [0, 1] to cosine/sine mixing weights."""
        angle = sigmas * pi / 2
        return torch.cos(angle), torch.sin(angle)

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Run the network on a noisy batch at the given noise level(s)."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        return self.net(x_noisy, sigmas, **kwargs)

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """v-objective MSE loss for a clean batch `x`."""
        batch_size, device = x.shape[0], x.device

        # One noise level per batch element, expanded for broadcasting
        # (assumes 3D input, e.g. (batch, channels, length) — TODO confirm)
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        noise = default(noise, lambda: torch.randn_like(x))

        # Mix signal and noise, and build the v-target
        alpha, beta = self.get_alpha_beta(sigmas_padded)
        x_noisy = x * alpha + noise * beta
        x_target = noise * alpha - x * beta

        x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
        return F.mse_loss(x_denoised, x_target)
| |
|
| |
|
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""

    alias = "k"

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution: Distribution,
        sigma_data: float,
        dynamic_threshold: float = 0.0,
    ):
        """
        Args:
            net: Denoising network, called as `net(x, sigmas, **kwargs)`.
            sigma_distribution: Distribution of training noise levels.
            sigma_data: Estimated standard deviation of the data.
            dynamic_threshold: Stored for sampling-side clipping
                (0.0 selects static clamping in `clip`).
        """
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return preconditioning weights (c_skip, c_out, c_in, c_noise)."""
        sigma_data = self.sigma_data
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
        c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Preconditioned denoiser: c_skip * x + c_out * net(c_in * x, c_noise)."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        # Predict network output and apply skip connection
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred

        return x_denoised

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        # Per-sigma loss weight; equals 1 / c_out(sigma)^2.
        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Weighted denoising MSE loss for a clean batch `x`."""
        # FIX: removed a redundant function-local
        # `from einops import rearrange, reduce` that shadowed the
        # module-level imports.
        batch_size, device = x.shape[0], x.device

        # Sample amount of noise to add for each batch element
        # (assumes 3D input, e.g. (batch, channels, length) — TODO confirm)
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Add noise to input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Compute denoised values
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)

        # Compute weighted loss, averaged over all non-batch dims
        losses = F.mse_loss(x_denoised, x, reduction="none")
        losses = reduce(losses, "b ... -> b", "mean")
        losses = losses * self.loss_weight(sigmas)
        loss = losses.mean()
        return loss
| |
|
| |
|
class VKDiffusion(Diffusion):
    """v-objective training with Karras-style preconditioning (sigma_data = 1)."""

    alias = "vk"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Preconditioning weights (c_skip, c_out, c_in) with sigma_data fixed to 1."""
        sigma_data = 1.0
        sigmas = rearrange(sigmas, "b -> b 1 1")
        denom_sq = sigmas ** 2 + sigma_data ** 2
        c_skip = (sigma_data ** 2) / denom_sq
        c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = denom_sq ** -0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
        """Map sigma in [0, inf) to t in [0, 1) via arctan."""
        return sigmas.atan() / pi * 2

    def t_to_sigma(self, t: Tensor) -> Tensor:
        """Inverse of `sigma_to_t`."""
        return (t * pi / 2).tan()

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Preconditioned denoiser: c_skip * x + c_out * net(c_in * x, t)."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
        return c_skip * x_noisy + c_out * x_pred

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """MSE between the raw network output and the v-target."""
        batch_size, device = x.shape[0], x.device

        # One noise level per batch element, expanded for broadcasting
        # (assumes 3D input, e.g. (batch, channels, length) — TODO confirm)
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Add noise to input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Raw network prediction with preconditioned input
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)

        # Target the network must output so the denoised value equals x
        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)

        return F.mse_loss(x_pred, v_target)
| |
|
| |
|
| | """ |
| | Diffusion Sampling |
| | """ |
| |
|
| | """ Schedules """ |
| |
|
| |
|
class Schedule(nn.Module):
    """Interface used by different sampling schedules"""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        # Subclasses return a 1D tensor of decreasing noise levels (sigmas).
        raise NotImplementedError()
| |
|
| |
|
class LinearSchedule(Schedule):
    """Linearly decreasing sigmas from 1 to 0 (endpoint 0 excluded)."""

    def forward(self, num_steps: int, device: Any) -> Tensor:
        # FIX: create the tensor directly on `device`; previously the
        # `device` argument was ignored, always yielding a CPU tensor.
        sigmas = torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
        return sigmas
| |
|
| |
|
class KarrasSchedule(Schedule):
    """https://arxiv.org/abs/2206.00364 equation 5"""

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return `num_steps` decreasing sigmas (eq. 5) plus a trailing 0."""
        inv_rho = 1.0 / self.rho
        # Interpolate linearly in sigma^(1/rho) space, then raise back
        hi = self.sigma_max ** inv_rho
        lo = self.sigma_min ** inv_rho
        ramp = torch.arange(num_steps, device=device, dtype=torch.float32)
        sigmas = (hi + (ramp / (num_steps - 1)) * (lo - hi)) ** self.rho
        # Append sigma = 0 as the final (fully denoised) level
        return F.pad(sigmas, pad=(0, 1), value=0.0)
| |
|
| |
|
| | """ Samplers """ |
| |
|
| |
|
class Sampler(nn.Module):
    """Interface for samplers that turn noise into data via a denoise function."""

    # Diffusion classes this sampler is compatible with; compared by `alias`
    # in DiffusionSampler's constructor check.
    diffusion_types: List[Type[Diffusion]] = []

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        # Subclasses iterate `fn` over `sigmas` starting from `noise`.
        raise NotImplementedError()

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        # Optional capability; only some samplers implement it.
        raise NotImplementedError("Inpainting not available with current sampler")
| |
|
| |
|
class VSampler(Sampler):
    """Sampler for the v-objective parameterization (trigonometric mixing)."""

    diffusion_types = [VDiffusion]

    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
        """Cosine/sine mixing weights for a noise level in [0, 1]."""
        angle = sigma * pi / 2
        alpha = cos(angle)
        beta = sin(angle)
        return alpha, beta

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Iteratively denoise from `noise`; returns the final prediction.

        Note: requires num_steps >= 2 (x_pred is first assigned in the loop).
        """
        x = sigmas[0] * noise
        alpha, beta = self.get_alpha_beta(sigmas[0].item())

        for i in range(num_steps - 1):
            # FIX: the loop runs i = 0 .. num_steps - 2, so the last
            # iteration is i == num_steps - 2. The previous test against
            # num_steps - 1 was never true, so a useless extra alpha/beta/x
            # update always ran after x_pred (returned value was unaffected).
            is_last = i == num_steps - 2

            x_denoised = fn(x, sigma=sigmas[i])
            x_pred = x * alpha - x_denoised * beta
            x_eps = x * beta + x_denoised * alpha

            if not is_last:
                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
                x = x_pred * alpha + x_eps * beta

        return x_pred
| |
|
| |
|
class KarrasSampler(Sampler):
    """https://arxiv.org/abs/2206.00364 algorithm 1"""

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(
        self,
        s_tmin: float = 0,
        s_tmax: float = float("inf"),
        s_churn: float = 0.0,
        s_noise: float = 1.0,
    ):
        """
        Args:
            s_tmin: Lower sigma bound of the churn (extra noise) region.
            s_tmax: Upper sigma bound of the churn region.
            s_churn: Total stochastic churn budget spread across steps.
            s_noise: Scale of the injected churn noise.
        """
        super().__init__()
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise
        self.s_churn = s_churn

    def step(
        self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
    ) -> Tensor:
        """Algorithm 2 (step)"""
        # Temporarily increase the noise level by the churn factor gamma
        sigma_hat = sigma + gamma * sigma
        # Add noise to move from sigma to sigma_hat
        epsilon = self.s_noise * torch.randn_like(x)
        x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
        # Evaluate dx/dsigma at sigma_hat
        d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
        # Euler step from sigma_hat to sigma_next
        x_next = x_hat + (sigma_next - sigma_hat) * d
        # Second-order (Heun) correction unless the next level is sigma = 0
        if sigma_next != 0:
            model_out_next = fn(x_next, sigma=sigma_next)
            d_prime = (x_next - model_out_next) / sigma_next
            # FIX: the correction interval is (sigma_next - sigma_hat), not
            # (sigma - sigma_hat), per Karras et al. Algorithm 2 — the old
            # form effectively cancelled the step when gamma == 0.
            x_next = x_hat + 0.5 * (sigma_next - sigma_hat) * (d + d_prime)
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Algorithm 1: iterate `step` down the sigma schedule."""
        x = sigmas[0] * noise
        # Compute gammas; churn only applies for sigmas in [s_tmin, s_tmax]
        gammas = torch.where(
            (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
            min(self.s_churn / num_steps, sqrt(2) - 1),
            0.0,
        )
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(
                x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i]
            )

        return x
| |
|
| |
|
class AEulerSampler(Sampler):
    """Ancestral Euler sampler."""

    diffusion_types = [KDiffusion, VKDiffusion]

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
        """Split a step into a deterministic part and an injected-noise part."""
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        return sigma_up, sigma_down

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """One ancestral Euler step from sigma to sigma_next."""
        sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
        # Euler step along the denoising direction (dx/dsigma)
        derivative = (x - fn(x, sigma=sigma)) / sigma
        denoised = x + derivative * (sigma_down - sigma)
        # Renoise by sigma_up to land exactly at sigma_next
        return denoised + torch.randn_like(x) * sigma_up

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Sample starting from pure noise scaled to sigmas[0]."""
        x = sigmas[0] * noise
        for step_index in range(num_steps - 1):
            x = self.step(
                x, fn=fn, sigma=sigmas[step_index], sigma_next=sigmas[step_index + 1]
            )
        return x
| |
|
| |
|
class ADPM2Sampler(Sampler):
    """https://www.desmos.com/calculator/jbxjlqd9mb"""

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(self, rho: float = 1.0):
        # rho: exponent controlling where the midpoint sigma is placed.
        super().__init__()
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Split a step into ancestral up/down components plus a midpoint sigma."""
        r = self.rho
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        # Power-mean interpolation between sigma and sigma_down, exponent r
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """One two-stage (midpoint) ancestral step from sigma to sigma_next."""
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative dx/dsigma at the current sigma
        d = (x - fn(x, sigma=sigma)) / sigma
        # Euler step to the midpoint
        x_mid = x + d * (sigma_mid - sigma)
        # Derivative re-evaluated at the midpoint
        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        # Full step from the ORIGINAL x using the midpoint derivative
        x = x + d_mid * (sigma_down - sigma)
        # Ancestral noise injection to land at sigma_next
        x_next = x + torch.randn_like(x) * sigma_up
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Sample by stepping through the decreasing sigma schedule."""
        x = sigmas[0] * noise
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])
        return x

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Resampling-based inpainting; `mask` must be boolean (True = keep source)."""
        x = sigmas[0] * torch.randn_like(source)

        for i in range(num_steps - 1):
            # Noise the known region to the current noise level
            source_noisy = source + sigmas[i] * torch.randn_like(source)
            for r in range(num_resamples):
                # Merge noisy source (kept region) with current estimate
                x = source_noisy * mask + x * ~mask
                x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])
                # Renoise back up before every resample pass except the last
                if r < num_resamples - 1:
                    sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
                    x = x + sigma * torch.randn_like(x)

        return source * mask + x * ~mask
| |
|
| |
|
| | """ Main Classes """ |
| |
|
| |
|
class DiffusionSampler(nn.Module):
    """Samples from a trained diffusion model with a given sampler + schedule."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        sampler: Sampler,
        sigma_schedule: Schedule,
        num_steps: Optional[int] = None,
        clamp: bool = True,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

        # Verify the sampler supports this diffusion objective (by alias)
        sampler_class = sampler.__class__.__name__
        diffusion_class = diffusion.__class__.__name__
        message = f"{sampler_class} incompatible with {diffusion_class}"
        assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message

    def forward(
        self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
    ) -> Tensor:
        """Sample from `noise`; optionally clamp the result to [-1, 1]."""
        num_steps = default(num_steps, self.num_steps)
        assert exists(num_steps), "Parameter `num_steps` must be provided"

        # Build the sigma schedule on the noise's device
        sigmas = self.sigma_schedule(num_steps, noise.device)

        # Forward extra kwargs (e.g. conditioning) into every denoise call
        def fn(*args, **fn_kwargs):
            return self.denoise_fn(*args, **{**fn_kwargs, **kwargs})

        x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        if self.clamp:
            x = x.clamp(-1.0, 1.0)
        return x
| |
|
| |
|
class DiffusionInpainter(nn.Module):
    """Wraps a diffusion model and a sampler's inpaint routine into one module."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        num_steps: int,
        num_resamples: int,
        sampler: Sampler,
        sigma_schedule: Schedule,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.num_steps = num_steps
        self.num_resamples = num_resamples
        self.inpaint_fn = sampler.inpaint
        self.sigma_schedule = sigma_schedule

    @torch.no_grad()
    def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
        """Fill in `inpaint` where `inpaint_mask` is False (True = keep)."""
        sigmas = self.sigma_schedule(self.num_steps, inpaint.device)
        return self.inpaint_fn(
            source=inpaint,
            mask=inpaint_mask,
            fn=self.denoise_fn,
            sigmas=sigmas,
            num_steps=self.num_steps,
            num_resamples=self.num_resamples,
        )
| |
|
| |
|
def sequential_mask(like: Tensor, start: int) -> Tensor:
    """Boolean mask shaped like `like`: True before `start` on the last dim,
    False from `start` onward. Assumes a 3D tensor (indexes dim 2)."""
    mask = torch.ones_like(like, dtype=torch.bool)
    mask[:, :, start:] = False
    return mask
| |
|
| |
|
class SpanBySpanComposer(nn.Module):
    """Autoregressively extends a signal half-span by half-span via inpainting."""

    def __init__(
        self,
        inpainter: DiffusionInpainter,
        *,
        num_spans: int,
    ):
        super().__init__()
        self.inpainter = inpainter
        self.num_spans = num_spans

    def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
        """Generate `num_spans` new half-length spans continuing `start`."""
        half = start.shape[2] // 2

        generated = list(start.chunk(chunks=2, dim=-1)) if keep_start else []

        # Condition on the second half of `start`, placed in the first half;
        # the mask marks the first half as known.
        inpaint = torch.zeros_like(start)
        inpaint[:, :, :half] = start[:, :, half:]
        inpaint_mask = sequential_mask(like=start, start=half)

        for _ in range(self.num_spans):
            # Inpaint the unknown second half
            span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
            # Shift: the freshly generated half becomes the next conditioning
            new_half = span[:, :, half:]
            inpaint[:, :, :half] = new_half
            generated.append(new_half)

        return torch.cat(generated, dim=2)
| |
|
| |
|
class XDiffusion(nn.Module):
    """Factory wrapper that instantiates a diffusion objective by alias."""

    def __init__(self, type: str, net: nn.Module, **kwargs):
        super().__init__()

        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
        aliases = [t.alias for t in diffusion_classes]
        message = f"type='{type}' must be one of {*aliases,}"
        assert type in aliases, message
        self.net = net

        # FIX: the loop variable previously shadowed the class name
        # `XDiffusion` inside __init__; use a distinct name.
        for diffusion_class in diffusion_classes:
            if diffusion_class.alias == type:
                self.diffusion = diffusion_class(net=net, **kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        """Delegate to the underlying diffusion objective (returns its loss)."""
        return self.diffusion(*args, **kwargs)

    def sample(
        self,
        noise: Tensor,
        num_steps: int,
        sigma_schedule: Schedule,
        sampler: Sampler,
        clamp: bool,
        **kwargs,
    ) -> Tensor:
        """Build a DiffusionSampler on the fly and sample starting from `noise`."""
        diffusion_sampler = DiffusionSampler(
            diffusion=self.diffusion,
            sampler=sampler,
            sigma_schedule=sigma_schedule,
            num_steps=num_steps,
            clamp=clamp,
        )
        return diffusion_sampler(noise, **kwargs)
| |
|