AnyFlow
AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation by Yuchao Gu, Guian Fang and collaborators at NUS ShowLab in collaboration with NVIDIA.
Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.
The original training code is at NVlabs/AnyFlow. The project page is at nvlabs.github.io/AnyFlow.
The following AnyFlow checkpoints are supported:
| Checkpoint | Backbone | Description |
|---|---|---|
| nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers | Wan2.1 1.3B | Bidirectional T2V, lightweight |
| nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers | Wan2.1 14B | Bidirectional T2V, full quality |
| nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers | FAR + Wan2.1 1.3B | Causal T2V / I2V / V2V |
| nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers | FAR + Wan2.1 14B | Causal T2V / I2V / V2V |
All four are grouped under the nvidia/anyflow Hugging Face collection.
Choose AnyFlowPipeline for traditional bidirectional text-to-video generation. Choose AnyFlowFARPipeline for streaming I2V, video continuation (V2V), or any setup that benefits from frame-by-frame autoregressive sampling.
AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16… NFE without retraining. Quality scales monotonically with steps in our benchmarks.
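To make the any-step idea concrete, the sketch below lists the (t, r) intervals a flow-map sampler would traverse for a given step budget. It is illustrative only: the helper name `flow_map_intervals` is hypothetical, and uniform spacing on [0, 1] with t = 1 as pure noise is an assumption. The released scheduler may space or shift its sigmas differently.

```python
def flow_map_intervals(num_steps):
    """Return (t, r) pairs covering [1, 0] in `num_steps` equal jumps.

    Hypothetical helper: uniform spacing is an assumption, not the
    schedule used by the released checkpoints.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    # Boundaries from t = 1 (noise) down to t = 0 (clean sample).
    bounds = [1.0 - i / num_steps for i in range(num_steps + 1)]
    return list(zip(bounds[:-1], bounds[1:]))


for steps in (1, 2, 4):
    print(steps, flow_map_intervals(steps))
```

With one step the model is asked for the single jump from t = 1 to r = 0. With more steps the same model is queried on shorter intervals, which is what lets one checkpoint serve several inference budgets.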
Optimizing Memory and Inference Speed
import torch
from diffusers import AnyFlowPipeline
from diffusers.hooks import apply_group_offloading
pipe = AnyFlowPipeline.from_pretrained(
    "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
)
apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

Generation with AnyFlow (Bidirectional T2V)
import torch
from diffusers import AnyFlowPipeline
from diffusers.utils import export_to_video
pipe = AnyFlowPipeline.from_pretrained(
    "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
prompt = "A red panda eating bamboo in a forest, cinematic lighting"
video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0]
export_to_video(video, "out.mp4", fps=16)

Generation with AnyFlow (FAR Causal)
The causal pipeline selects between T2V / I2V / V2V via the video (or video_latents) argument:
omit both for plain text-to-video, or pass video=<tensor> of shape (B, T, C, H, W) in [0, 1]
with T = 4n + 1 to condition on existing frames. Use a single conditioning frame for I2V and a longer
clip for V2V continuation. If you already have pre-encoded latents in the model layout, pass them via
video_latents=<tensor> to skip VAE encoding. video and video_latents are mutually exclusive.
AnyFlowFARPipeline.default_chunk_partition = [1, 3, 3, 3, 3, 3, 3, 2] (sum 21) is matched to the released checkpoints’ canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When you change num_frames, you must also pass a matching chunk_partition summing to (num_frames - 1) // 4 + 1, otherwise the pipeline raises an AssertionError.
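The arithmetic above can be sketched as a small helper. `latent_frame_count` and `make_chunk_partition` are hypothetical names, not part of diffusers. The layout they produce (one leading single-frame chunk, then chunks of three, then a shorter remainder) is an assumption patterned on the default partition. Whether the released checkpoints behave well with other partitions is not stated here.

```python
VAE_TEMPORAL_STRIDE = 4  # temporal stride stated for the released checkpoints


def latent_frame_count(num_frames, stride=VAE_TEMPORAL_STRIDE):
    """Number of latent frames for `num_frames` raw frames."""
    if (num_frames - 1) % stride != 0:
        raise ValueError(f"num_frames must be {stride}n + 1, got {num_frames}")
    return (num_frames - 1) // stride + 1


def make_chunk_partition(num_frames, chunk_size=3):
    """Hypothetical helper: one leading single-frame chunk, then chunks of
    `chunk_size`, with any remainder in a final shorter chunk."""
    remaining = latent_frame_count(num_frames) - 1
    partition = [1]
    while remaining > 0:
        size = min(chunk_size, remaining)
        partition.append(size)
        remaining -= size
    return partition


print(make_chunk_partition(81))  # [1, 3, 3, 3, 3, 3, 3, 2]
print(make_chunk_partition(33))  # [1, 3, 3, 2]
```

For example, a 33-frame request would pass num_frames=33 together with chunk_partition=make_chunk_partition(33), whose entries sum to the required 9 latent frames.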
import torch
from diffusers import AnyFlowFARPipeline
from diffusers.utils import export_to_video
pipe = AnyFlowFARPipeline.from_pretrained(
    "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
video = pipe(
    prompt="A cat surfing a wave, sunset",
    num_inference_steps=4,
    num_frames=81,
).frames[0]
export_to_video(video, "out.mp4", fps=16)

Notes
- Classifier-free guidance is fused into the released checkpoints, so inference does not run a second guided forward pass. Keep the default guidance_scale=1.0 unless your own checkpoint requires otherwise.
- FlowMapEulerDiscreteScheduler is general-purpose. You can attach it to any flow-map-distilled checkpoint via from_pretrained(..., scheduler=FlowMapEulerDiscreteScheduler.from_config(...)).
- AnyFlowPipeline uses AnyFlowTransformer3DModel (bidirectional). AnyFlowFARPipeline uses AnyFlowFARTransformer3DModel, which adds a compressed-frame patch embedding and the FAR causal block-mask.
- LoRA loading is supported via WanLoraLoaderMixin, the same mixin used by the upstream Wan pipelines.
- For training recipes (forward flow-map training and on-policy distillation), refer to the original AnyFlow training framework at NVlabs/AnyFlow; training is out of scope for diffusers.
AnyFlowPipeline
class diffusers.AnyFlowPipeline
( tokenizer: AutoTokenizer text_encoder: UMT5EncoderModel transformer: AnyFlowTransformer3DModel vae: AutoencoderKLWan scheduler: FlowMapEulerDiscreteScheduler )
Parameters
- tokenizer ([AutoTokenizer]) — Tokenizer from google/umt5-xxl.
- text_encoder ([UMT5EncoderModel]) — google/umt5-xxl text encoder.
- transformer ([AnyFlowTransformer3DModel]) — Bidirectional flow-map 3D Transformer.
- vae ([AutoencoderKLWan]) — VAE that encodes/decodes videos to and from latent representations.
- scheduler ([FlowMapEulerDiscreteScheduler]) — Flow-map sampler. The pipeline drives scheduler.step(..., timestep, sample, r_timestep) per inference step.
Bidirectional text-to-video generation pipeline for AnyFlow flow-map-distilled checkpoints, introduced in AnyFlow by Yuchao Gu, Guian Fang et al.
AnyFlow learns arbitrary-interval transitions rather than the fixed mapping
of consistency models, so a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16… NFE without
retraining. This pipeline operates over the full video tensor in one bidirectional pass; for frame-level
autoregressive (causal) generation use AnyFlowFARPipeline.
Sampling is plain Euler in mean-velocity form (z_r = z_t - (t - r) * u) with no re-noising. The released NVIDIA
checkpoints fold classifier-free guidance into the model weights, so the default guidance_scale=1.0 is the
recommended setting.
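The update rule z_r = z_t - (t - r) * u can be illustrated on a scalar toy problem. This is a sketch under stated assumptions, not the pipeline's implementation: the straight-line path z_t = (1 - t) * x0 + t * eps, the uniform step spacing, and the names `mean_velocity` and `sample` are all hypothetical. The toy function returns the exact mean velocity, which a trained network only approximates.

```python
def mean_velocity(z_t, t, r, x0, eps):
    """Toy stand-in for the network: exact mean velocity of a straight path.

    Assumes z_t = (1 - t) * x0 + t * eps, so the velocity is constant and
    equals eps - x0 on every interval. z_t, t and r are unused here but
    kept to mirror the (sample, t, r) conditioning of a flow-map model.
    """
    return eps - x0


def sample(num_steps, x0, eps):
    """Run `num_steps` flow-map Euler updates from t = 1 down to t = 0."""
    z = eps  # at t = 1 the sample is pure noise
    for i in range(num_steps):
        t = 1.0 - i / num_steps
        r = 1.0 - (i + 1) / num_steps
        u = mean_velocity(z, t, r, x0, eps)
        z = z - (t - r) * u  # z_r = z_t - (t - r) * u, no re-noising
    return z


x0, eps = 2.0, -1.0
for steps in (1, 2, 4, 8):
    print(steps, sample(steps, x0, eps))  # each value is approximately 2.0
```

Because the mean velocity over an interval is exact in this toy setting, every step budget lands on the same clean sample. With a learned model the results differ by the approximation error on each interval.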
This model inherits from [DiffusionPipeline]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
__call__
( prompt: typing.Union[str, typing.List[str]] = None video: typing.Optional[torch.Tensor] = None video_latents: typing.Optional[torch.Tensor] = None negative_prompt: typing.Union[str, typing.List[str]] = None height: int = 480 width: int = 832 num_frames: int = 81 num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None timesteps: typing.Optional[typing.List[float]] = None guidance_scale: float = 1.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'np' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 use_mean_velocity: bool = True ) → ~AnyFlowPipelineOutput or tuple
Parameters
- prompt (str or List[str], optional) — The prompt or prompts to guide the video generation. If not defined, pass prompt_embeds instead.
- video (torch.Tensor, optional) — Pre-VAE conditioning frames of shape (B, T, C, H, W) in [0, 1]. When provided, the pipeline VAE-encodes them and keeps the corresponding latent prefix fixed during sampling. Mutually exclusive with video_latents.
- video_latents (torch.Tensor, optional) — Pre-encoded VAE latents in the AnyFlow layout (B, T_latent, C, H_latent, W_latent). Skips VAE encoding on the pipeline side. Mutually exclusive with video.
- negative_prompt (str or List[str], optional) — The prompt or prompts to avoid during video generation. Ignored when not using guidance (guidance_scale < 1).
- height (int, defaults to 480) — The height in pixels of the generated video.
- width (int, defaults to 832) — The width in pixels of the generated video.
- num_frames (int, defaults to 81) — The number of frames in the generated video. Must satisfy (num_frames - 1) % vae_scale_factor_temporal == 0.
- num_inference_steps (int, defaults to 50) — The number of denoising steps. Distilled AnyFlow checkpoints support any-step sampling, so values as low as 1, 2, 4, or 8 are typical. Ignored when sigmas or timesteps is provided.
- sigmas (List[float], optional) — Custom sigma schedule for any-step sampling, in [0, 1] and ordered from noisy to clean. Length determines the effective num_inference_steps; the scheduler appends the terminal 0 sigma.
- timesteps (List[float], optional) — Custom timestep schedule for any-step sampling, in the same units as self.scheduler.timesteps (i.e. scaled by num_train_timesteps). Mutually exclusive with sigmas.
- guidance_scale (float, defaults to 1.0) — Classifier-free guidance scale. The released AnyFlow checkpoints fuse CFG into the weights during training; keep at 1.0 unless you know your checkpoint expects otherwise.
- num_videos_per_prompt (int, optional, defaults to 1) — The number of videos to generate per prompt.
- generator (torch.Generator or List[torch.Generator], optional) — A torch.Generator to make generation deterministic.
- latents (torch.Tensor, optional) — Pre-generated noisy latents to use as inputs. If not provided, latents are sampled from the supplied generator.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. Can be used to tweak text inputs (e.g., prompt weighting). If not provided, embeddings are generated from prompt.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings.
- output_type (str, optional, defaults to "np") — The output format. One of "pil", "np", "pt", or "latent".
- return_dict (bool, optional, defaults to True) — Whether to return an AnyFlowPipelineOutput instead of a plain tuple.
- attention_kwargs (dict, optional) — A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
- callback_on_step_end (Callable, optional) — A function or PipelineCallback called at the end of each inference step. See callbacks for details.
- callback_on_step_end_tensor_inputs (List[str], optional, defaults to ["latents"]) — The tensor inputs forwarded to the callback. Must be a subset of self._callback_tensor_inputs.
- max_sequence_length (int, defaults to 512) — The maximum text-encoder sequence length. Longer prompts are truncated.
- use_mean_velocity (bool, defaults to True) — When True, the flow-map model is conditioned on both the source timestep t and the target timestep r to predict a mean velocity, matching the training-time behavior. Disable to mirror raw Euler stepping (r = t).
Returns
~AnyFlowPipelineOutput or tuple
If return_dict is True, AnyFlowPipelineOutput is returned, otherwise a tuple whose first
element is the generated video.
The call function to the pipeline for generation.
Examples:
>>> import torch
>>> from diffusers import AnyFlowPipeline
>>> from diffusers.utils import export_to_video
>>> pipe = AnyFlowPipeline.from_pretrained(
...     "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
... ).to("cuda")
>>> prompt = "A red panda eating bamboo in a forest, cinematic lighting"
>>> video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0]
>>> export_to_video(video, "anyflow_t2v.mp4", fps=16)

encode_prompt
( prompt: str | list[str] negative_prompt: str | list[str] | None = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: torch.Tensor | None = None negative_prompt_embeds: torch.Tensor | None = None max_sequence_length: int = 226 device: torch.device | None = None dtype: torch.dtype | None = None )
Parameters
- prompt (str or list[str], optional) — The prompt to be encoded.
- negative_prompt (str or list[str], optional) — The prompt or prompts not to guide the video generation. If not defined, one has to pass negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
- do_classifier_free_guidance (bool, optional, defaults to True) — Whether to use classifier-free guidance or not.
- num_videos_per_prompt (int, optional, defaults to 1) — Number of videos that should be generated per prompt.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_prompt_embeds will be generated from the negative_prompt input argument.
- device (torch.device, optional) — The torch device to place the resulting embeddings on.
- dtype (torch.dtype, optional) — The torch dtype.
Encodes the prompt into text encoder hidden states.
Encode a pixel-space video into AnyFlow’s latent layout.
Mirrors the single-helper convention of other diffusers pipelines (cf.
WanImageToVideoPipeline.encode_image): wraps preprocessing, VAE encoding, and latent normalization into one
call. Output layout is (B, T_latent, C, H, W), which is what the AnyFlow transformer expects for
conditioning frames.
AnyFlowFARPipeline
class diffusers.AnyFlowFARPipeline
( tokenizer: AutoTokenizer text_encoder: UMT5EncoderModel transformer: AnyFlowFARTransformer3DModel vae: AutoencoderKLWan scheduler: FlowMapEulerDiscreteScheduler )
Parameters
- tokenizer ([AutoTokenizer]) — Tokenizer from google/umt5-xxl.
- text_encoder ([UMT5EncoderModel]) — google/umt5-xxl text encoder.
- transformer ([AnyFlowFARTransformer3DModel]) — FAR causal flow-map 3D Transformer.
- vae ([AutoencoderKLWan]) — VAE that encodes/decodes videos to and from latent representations.
- scheduler ([FlowMapEulerDiscreteScheduler]) — Flow-map sampler.
Causal (FAR-based) text-to-video / image-to-video / video-to-video pipeline for AnyFlow checkpoints, introduced in AnyFlow by Yuchao Gu, Guian Fang et al.
The pipeline drives a frame-level autoregressive sampling loop over chunks: each chunk is denoised with flow-map steps while attending only to past chunks via block-sparse causal attention, and intermediate KV cache is reused across chunks.
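The control flow of such a chunked loop can be outlined in a few lines. This is a hypothetical outline rather than the pipeline's code: `chunk_ranges` and `autoregressive_plan` are illustrative names, and the sketch ignores KV-cache bookkeeping and any conditioning chunks that are kept fixed instead of denoised.

```python
def chunk_ranges(chunk_partition):
    """Map per-chunk latent frame counts to half-open (start, end) ranges."""
    ranges, start = [], 0
    for size in chunk_partition:
        ranges.append((start, start + size))
        start += size
    return ranges


def autoregressive_plan(chunk_partition, num_inference_steps):
    """Hypothetical outline of the loop: for each chunk, which earlier
    latent frames serve as context and how many flow-map steps are run."""
    plan = []
    for start, end in chunk_ranges(chunk_partition):
        plan.append({
            "frames": (start, end),
            "context": (0, start),  # past frames only (causal)
            "steps": num_inference_steps,
        })
    return plan


for entry in autoregressive_plan([1, 3, 3, 3, 3, 3, 3, 2], num_inference_steps=4):
    print(entry)
```

Each chunk only sees latent frames that precede it, which is why keys and values computed for earlier chunks can be cached and reused rather than recomputed.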
The task mode (T2V / I2V / V2V) is selected by which conditioning argument is passed to __call__:
- both video=None and video_latents=None — pure text-to-video.
- video=<tensor of shape (B, T, C, H, W) in [0, 1] with T = 4n + 1> — pre-VAE conditioning frames; the pipeline VAE-encodes them. Pass a single-frame video for I2V or a multi-frame clip for V2V.
- video_latents=<latent tensor of shape (B, T_latent, C, H_latent, W_latent)> — already-encoded latents in the FAR layout (skips the VAE encode step).
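The T = 4n + 1 constraint on conditioning clips is easy to check before calling the pipeline. The helpers below are a minimal sketch with hypothetical names. They work on frame counts only and do not touch tensors or the pipeline itself.

```python
def is_valid_conditioning_length(num_cond_frames, stride=4):
    """True when the clip length has the form stride * n + 1."""
    return num_cond_frames >= 1 and (num_cond_frames - 1) % stride == 0


def trim_to_valid_length(num_cond_frames, stride=4):
    """Largest length of the form stride * n + 1 not exceeding the input."""
    if num_cond_frames < 1:
        raise ValueError("need at least one conditioning frame")
    return (num_cond_frames - 1) // stride * stride + 1


def describe_mode(num_cond_frames):
    """Hypothetical helper naming the task mode implied by the clip length."""
    if num_cond_frames == 0:
        return "T2V"
    return "I2V" if num_cond_frames == 1 else "V2V"


print(trim_to_valid_length(16), describe_mode(13))  # 13 V2V
```

A 16-frame clip, for instance, would be cut to 13 frames before being passed as video. Which end of the clip to drop is a choice left to the caller.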
The FAR backbone is the causal Wan2.1 variant introduced by FAR (Gu et al., 2025; arXiv:2503.19325). Inference is plain Euler in mean-velocity form per chunk with no re-noising. Joint T2V / I2V / V2V is supported by a single distilled model.
This model inherits from [DiffusionPipeline]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
__call__
( prompt: typing.Union[str, typing.List[str]] = None video: typing.Optional[torch.Tensor] = None video_latents: typing.Optional[torch.Tensor] = None negative_prompt: typing.Union[str, typing.List[str]] = None height: int = 480 width: int = 832 num_frames: int = 81 num_inference_steps: int = 50 sigmas: typing.Optional[typing.List[float]] = None timesteps: typing.Optional[typing.List[float]] = None guidance_scale: float = 1.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'np' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 use_mean_velocity: bool = True use_kv_cache: bool = True chunk_partition: typing.Optional[typing.List[int]] = None ) → ~AnyFlowPipelineOutput or tuple
Parameters
- prompt (str or List[str], optional) — The prompt or prompts to guide the video generation. If not defined, pass prompt_embeds instead.
- video (torch.Tensor, optional) — Pre-VAE conditioning frames of shape (B, T, C, H, W) in [0, 1] (T = 4n + 1). When provided, the pipeline VAE-encodes them and keeps the corresponding latent prefix fixed during sampling. Mutually exclusive with video_latents.
- video_latents (torch.Tensor, optional) — Pre-encoded VAE latents in the FAR layout (B, T_latent, C, H_latent, W_latent). Skips VAE encoding on the pipeline side. Mutually exclusive with video.
- negative_prompt (str or List[str], optional) — The prompt or prompts to avoid during video generation. Ignored when not using guidance (guidance_scale < 1).
- height (int, defaults to 480) — The height in pixels of the generated video.
- width (int, defaults to 832) — The width in pixels of the generated video.
- num_frames (int, defaults to 81) — The number of frames in the generated video. Must satisfy (num_frames - 1) % vae_scale_factor_temporal == 0.
- num_inference_steps (int, defaults to 50) — The number of denoising steps per chunk. Distilled AnyFlow-FAR checkpoints support any-step sampling (1, 2, 4, 8, …). Ignored when sigmas or timesteps is provided.
- sigmas (List[float], optional) — Custom sigma schedule for any-step sampling, in [0, 1] and ordered from noisy to clean. Length determines the effective num_inference_steps; the scheduler appends the terminal 0 sigma.
- timesteps (List[float], optional) — Custom timestep schedule for any-step sampling, in the same units as self.scheduler.timesteps (i.e. scaled by num_train_timesteps). Mutually exclusive with sigmas.
- guidance_scale (float, defaults to 1.0) — Classifier-free guidance scale. The released AnyFlow checkpoints fuse CFG into the weights during training; keep at 1.0 unless the checkpoint requires otherwise.
- num_videos_per_prompt (int, optional, defaults to 1) — The number of videos to generate per prompt.
- generator (torch.Generator or List[torch.Generator], optional) — Generator used to seed sampling.
- latents (torch.Tensor, optional) — Pre-generated noisy latents. If not provided, latents are sampled from the supplied generator.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. If not provided, embeddings are generated from prompt.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings.
- output_type (str, optional, defaults to "np") — Output format. One of "pil", "np", "pt", or "latent".
- return_dict (bool, optional, defaults to True) — Whether to return an AnyFlowPipelineOutput instead of a plain tuple.
- attention_kwargs (dict, optional) — A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
- callback_on_step_end (Callable, optional) — A function or PipelineCallback called at the end of each inference step.
- callback_on_step_end_tensor_inputs (List[str], optional, defaults to ["latents"]) — Tensor inputs forwarded to the callback. Must be a subset of self._callback_tensor_inputs.
- max_sequence_length (int, defaults to 512) — The maximum text-encoder sequence length.
- use_mean_velocity (bool, defaults to True) — When True, condition the flow-map model on both the source timestep t and the target timestep r to predict a mean velocity. Disable to mirror raw Euler stepping.
- use_kv_cache (bool, defaults to True) — Reuse the FAR attention KV cache across causal chunks. Disable only for debugging.
- chunk_partition (List[int], optional) — Per-chunk frame counts. Defaults to default_chunk_partition (matched to the released 81-frame checkpoints). When you change num_frames, supply a chunk_partition that sums to (num_frames - 1) // vae_scale_factor_temporal + 1.
Returns
~AnyFlowPipelineOutput or tuple
If return_dict is True, an AnyFlowPipelineOutput is returned, otherwise a tuple whose first
element is the generated video.
The call function to the pipeline for generation.
Examples:
>>> import numpy as np
>>> import torch
>>> from diffusers import AnyFlowFARPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> pipe = AnyFlowFARPipeline.from_pretrained(
...     "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
... ).to("cuda")
>>> # Single-frame I2V: wrap the conditioning image as a (1, 1, 3, H, W) tensor in [0, 1].
>>> first_frame = load_image("path/to/first_frame.png").resize((832, 480))
>>> arr = np.asarray(first_frame).astype("float32") / 255.0
>>> context = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda")
>>> video = pipe(
...     prompt="a cat walks across a sunlit lawn",
...     video=context,
...     num_inference_steps=4,
...     num_frames=81,
... ).frames[0]
>>> export_to_video(video, "anyflow_far.mp4", fps=16)

encode_prompt
( prompt: str | list[str] negative_prompt: str | list[str] | None = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: torch.Tensor | None = None negative_prompt_embeds: torch.Tensor | None = None max_sequence_length: int = 226 device: torch.device | None = None dtype: torch.dtype | None = None )
Parameters
- prompt (str or list[str], optional) — The prompt to be encoded.
- negative_prompt (str or list[str], optional) — The prompt or prompts not to guide the video generation. If not defined, one has to pass negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
- do_classifier_free_guidance (bool, optional, defaults to True) — Whether to use classifier-free guidance or not.
- num_videos_per_prompt (int, optional, defaults to 1) — Number of videos that should be generated per prompt.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_prompt_embeds will be generated from the negative_prompt input argument.
- device (torch.device, optional) — The torch device to place the resulting embeddings on.
- dtype (torch.dtype, optional) — The torch dtype.
Encodes the prompt into text encoder hidden states.
Encode a pixel-space video into AnyFlow’s latent layout.
Mirrors the single-helper convention of other diffusers pipelines (cf.
WanImageToVideoPipeline.encode_image): wraps preprocessing, VAE encoding, and latent normalization into one
call. Output layout is (B, T_latent, C, H, W), which is what the AnyFlow transformer expects for
conditioning frames.
AnyFlowPipelineOutput
class diffusers.pipelines.anyflow.pipeline_output.AnyFlowPipelineOutput
( frames: Tensor )
Parameters
- frames (torch.Tensor, np.ndarray, or list[list[PIL.Image.Image]]) — List of video outputs. It can be a nested list of length batch_size, with each sub-list containing denoised PIL image sequences of length num_frames. It can also be a NumPy array or Torch tensor of shape (batch_size, num_frames, channels, height, width).
Output class for AnyFlow pipelines.