| | """ |
| | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 |
| | """ |
| |
|
| | from typing import Callable, Iterable, Sequence, Union |
| |
|
| | import torch |
| |
|
| |
|
| | def checkpoint( |
| | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], |
| | inputs: Sequence[torch.Tensor], |
| | params: Iterable[torch.Tensor], |
| | flag: bool, |
| | ): |
| | """ |
| | Evaluate a function without caching intermediate activations, allowing for |
| | reduced memory at the expense of extra compute in the backward pass. |
| | :param func: the function to evaluate. |
| | :param inputs: the argument sequence to pass to `func`. |
| | :param params: a sequence of parameters `func` depends on but does not |
| | explicitly take as arguments. |
| | :param flag: if False, disable gradient checkpointing. |
| | """ |
| | if flag: |
| | args = tuple(inputs) + tuple(params) |
| | return CheckpointFunction.apply(func, len(inputs), *args) |
| | else: |
| | return func(*inputs) |
| |
|
| |
|
class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Run the function without tracking gradients; intermediate
        # activations are discarded and recomputed in backward().
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs and turn gradient tracking back on so
        # the forward pass can be recomputed with a fresh autograd graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The first two forward() arguments (run_function, length) are
        # non-tensors and receive no gradient.
        return (None, None) + input_grads
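

# A minimal usage sketch, added for illustration; the toy module and shapes
# below are assumptions, not part of the original file. `func` receives
# `inputs` positionally, while `params` only tells autograd which extra
# tensors (e.g. module weights) need gradients through the recomputation.
def _example_usage() -> torch.Tensor:
    import torch.nn as nn

    block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
    x = torch.randn(2, 16, requires_grad=True)

    # flag=True: the forward runs under no_grad here and is recomputed
    # inside backward(); flag=False is an ordinary call to block(x).
    out = checkpoint(block, (x,), block.parameters(), flag=True)
    out.sum().backward()
    return x.grad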
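

if __name__ == "__main__":
    # Sanity check (illustrative, not from the original module): gradients
    # computed with checkpointing enabled should match the plain autograd
    # path; allclose is used to allow for nondeterministic kernels.
    import torch.nn as nn

    torch.manual_seed(0)
    block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
    x = torch.randn(2, 16, requires_grad=True)

    checkpoint(block, (x,), block.parameters(), flag=True).sum().backward()
    grad_ckpt = x.grad.clone()
    x.grad = None

    checkpoint(block, (x,), block.parameters(), flag=False).sum().backward()
    assert torch.allclose(grad_ckpt, x.grad)
    print("checkpointed and plain gradients match")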