import numpy as np
import torch
import torch.nn as nn

from .tools.t5 import T5EncoderModel
from .tools.wan_model import WanModel


class DiffForcingWanModel(nn.Module):
    """Diffusion-forcing sequence model built on ``WanModel`` with T5 text conditioning.

    Every position ``j`` of a sequence gets its own noise level
    ``clamp(1 + j / chunk_size - t, 0, 1)`` (see ``_get_noise_levels``): for a
    given time ``t`` the left part of the sequence is clean (level 0), the
    right part is pure noise (level 1) and a ramp of ``chunk_size`` positions
    lies in between.  Training (``forward``) samples a random ``t`` per
    sample; the samplers (``generate``, ``stream_generate``,
    ``stream_generate_step``) sweep ``t`` upwards so the sequence is denoised
    from left to right.
    """

    # Network-output parameterisations understood by the loss and the samplers.
    _PREDICTION_TYPES = ("vel", "x0", "noise")

    def __init__(
        self,
        checkpoint_path="deps/t5_umt5-xxl-enc-bf16/models_t5_umt5-xxl-enc-bf16.pth",
        tokenizer_path="deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl",
        input_dim=256,
        hidden_dim=1024,
        ffn_dim=2048,
        freq_dim=256,
        num_heads=8,
        num_layers=8,
        time_embedding_scale=1.0,
        chunk_size=5,
        noise_steps=10,
        use_text_cond=True,
        text_len=128,
        drop_out=0.1,
        cfg_scale=5.0,
        prediction_type="vel",
        causal=False,
    ):
        """Build the text encoder and the denoising backbone.

        Args:
            checkpoint_path: T5 encoder checkpoint, forwarded to ``T5EncoderModel``.
            tokenizer_path: Tokenizer location, forwarded to ``T5EncoderModel``.
            input_dim: Feature size per position; used as both ``in_dim`` and
                ``out_dim`` of ``WanModel``.
            hidden_dim: ``dim`` of ``WanModel``.
            ffn_dim: ``ffn_dim`` of ``WanModel``.
            freq_dim: ``freq_dim`` of ``WanModel``.
            num_heads: ``num_heads`` of ``WanModel``.
            num_layers: ``num_layers`` of ``WanModel``.
            time_embedding_scale: Factor applied to the noise levels before
                they are handed to ``WanModel``.
            chunk_size: Width (in positions) of the noise ramp; also the
                number of trailing positions the training loss looks at.
            noise_steps: Default number of denoising steps per unit of ``t``.
            use_text_cond: If False, the empty prompt is always used.
            text_len: ``text_len`` of the text encoder and of ``WanModel``.
            drop_out: Probability of replacing a prompt with ``""`` during
                training.
            cfg_scale: Classifier-free guidance scale; ``1.0`` skips the
                second (unconditional) model call.
            prediction_type: One of ``"vel"``, ``"x0"``, ``"noise"``.
            causal: Forwarded to ``WanModel``.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.time_embedding_scale = time_embedding_scale
        self.chunk_size = chunk_size
        self.noise_steps = noise_steps
        self.use_text_cond = use_text_cond
        self.drop_out = drop_out
        self.cfg_scale = cfg_scale
        self.prediction_type = prediction_type
        self.causal = causal

        self.text_dim = 4096
        self.text_len = text_len
        # The encoder is created on CPU; encode_text_with_cache moves its
        # inner model to the target device when something must be encoded.
        self.text_encoder = T5EncoderModel(
            text_len=self.text_len,
            dtype=torch.bfloat16,
            device=torch.device("cpu"),
            checkpoint_path=checkpoint_path,
            tokenizer_path=tokenizer_path,
            shard_fn=None,
        )

        # prompt string -> encoded feature (kept on CPU)
        self.text_cache = {}
        self.model = WanModel(
            model_type="t2v",
            patch_size=(1, 1, 1),
            text_len=self.text_len,
            in_dim=self.input_dim,
            dim=self.hidden_dim,
            ffn_dim=self.ffn_dim,
            freq_dim=self.freq_dim,
            text_dim=self.text_dim,
            out_dim=self.input_dim,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            window_size=(-1, -1),
            qk_norm=True,
            cross_attn_norm=True,
            eps=1e-6,
            causal=self.causal,
        )
        self.param_dtype = torch.float32

    # ------------------------------------------------------------------
    # Text conditioning
    # ------------------------------------------------------------------
    def encode_text_with_cache(self, text_list, device):
        """Encode prompts, reusing previously encoded ones.

        Each distinct uncached prompt is sent to the text encoder exactly
        once per call, even when it occurs several times in ``text_list``.

        Args:
            text_list: List[str], prompts to encode.
            device: torch.device the features should live on.

        Returns:
            List[Tensor]: one encoded feature per entry of ``text_list``, in
            the same order.
        """
        # dict.fromkeys de-duplicates while keeping first-occurrence order.
        missing = list(
            dict.fromkeys(t for t in text_list if t not in self.text_cache)
        )

        fresh = {}
        if missing:
            self.text_encoder.model.to(device)
            encoded = self.text_encoder(missing, device)
            for text, feature in zip(missing, encoded):
                # Cache on CPU so the cache does not pin accelerator memory.
                self.text_cache[text] = feature.cpu()
                fresh[text] = feature

        return [
            fresh[text] if text in fresh else self.text_cache[text].to(device)
            for text in text_list
        ]

    def _encode_texts(self, texts, device):
        """Encode ``texts`` (with caching) and cast the result to ``param_dtype``."""
        return [
            u.to(self.param_dtype)
            for u in self.encode_text_with_cache(texts, device)
        ]

    def _per_frame_context(self, texts, boundaries, total_len, device):
        """Expand per-segment prompts into one text feature per position.

        Args:
            texts: Prompts of one sample, one per segment.
            boundaries: ``[0, end_1, end_2, ...]`` segment end positions.
            total_len: Number of positions to cover; everything after the
                last boundary reuses the last prompt.
            device: torch.device for the encoded features.

        Returns:
            (frames, durations): the per-position feature list and the length
            of every segment.
        """
        durations = [
            end - start for start, end in zip(boundaries[:-1], boundaries[1:])
        ]
        encoded = self._encode_texts(texts, device)
        frames = []
        for feature, duration in zip(encoded, durations):
            frames.extend([feature for _ in range(duration)])
        frames.extend([encoded[-1] for _ in range(total_len - boundaries[-1])])
        return frames, durations

    def _check_prediction_type(self):
        """Raise ``ValueError`` if ``prediction_type`` is not supported."""
        if self.prediction_type not in self._PREDICTION_TYPES:
            raise ValueError(
                f"Unsupported prediction_type {self.prediction_type!r}; "
                f"expected one of {self._PREDICTION_TYPES}"
            )

    # ------------------------------------------------------------------
    # Tensor helpers
    # ------------------------------------------------------------------
    def preprocess(self, x):
        """Reshape ``(B, T, D)`` to ``(B, D, T, 1, 1)``."""
        x = x.permute(0, 2, 1)[:, :, :, None, None]
        return x

    def postprocess(self, x):
        """Inverse of ``preprocess``: ``(B, D, T, 1, 1)`` back to ``(B, T, D)``."""
        batch, length = x.size(0), x.size(2)
        return x.permute(0, 2, 1, 3, 4).contiguous().view(batch, length, -1)

    def _get_noise_levels(self, device, seq_len, time_steps):
        """Return per-position noise levels.

        Args:
            device: torch.device for the position index tensor.
            seq_len: Number of positions.
            time_steps: ``(B,)`` tensor with the time ``t`` of every sample.

        Returns:
            ``(B, seq_len)`` tensor ``clamp(1 + j / chunk_size - t, 0, 1)``.
        """
        noise_level = torch.clamp(
            1
            + torch.arange(seq_len, device=device) / self.chunk_size
            - time_steps.unsqueeze(1),
            min=0.0,
            max=1.0,
        )
        return noise_level

    def add_noise(self, x, noise_level):
        """Blend ``x`` with Gaussian noise.

        Args:
            x: (B, T, D)
            noise_level: (B, T); 0 keeps ``x``, 1 replaces it by noise.

        Returns:
            (noisy_x, noise), both shaped like ``x``.
        """
        noise = torch.randn_like(x)
        noise_level = noise_level.unsqueeze(-1)
        noisy_x = x * (1 - noise_level) + noise_level * noise
        return noisy_x, noise

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def forward(self, x):
        """Compute the training loss for one batch.

        Args:
            x: dict with
                ``"feature"``: (B, T, D) tensor,
                ``"feature_length"``: per-sample valid lengths (indexable,
                each element supporting ``.item()``),
                optionally ``"text"``: either one prompt per sample or a list
                of prompts per sample, the latter together with
                ``"feature_text_end"`` (segment end positions per sample).

        Returns:
            dict with keys ``"total"`` and ``"mse"`` (the same tensor).

        Raises:
            ValueError: if ``prediction_type`` is not supported.
        """
        self._check_prediction_type()

        feature = x["feature"]
        feature_length = x["feature_length"]
        batch_size, seq_len, _ = feature.shape
        device = feature.device

        # One random time per sample, drawn from [0, valid_len / chunk_size].
        time_steps = []
        for i in range(batch_size):
            valid_len = feature_length[i].item()
            max_time = valid_len / self.chunk_size
            time_steps.append(torch.FloatTensor(1).uniform_(0, max_time).item())
        time_steps = torch.tensor(time_steps, device=device)
        noise_level = self._get_noise_levels(device, seq_len, time_steps)

        noisy_feature, noise = self.add_noise(feature, noise_level)

        feature = self.preprocess(feature)
        noisy_feature = self.preprocess(noisy_feature)
        noise = self.preprocess(noise)

        # Only positions up to the noisy front (capped at the valid length)
        # are fed to the model.
        feature_ref = []
        noise_ref = []
        noisy_feature_input = []
        for i in range(batch_size):
            t = time_steps[i].item()
            end_index = min(
                feature_length[i].item(), int(self.chunk_size * t) + 1
            )
            feature_ref.append(feature[i, :, :end_index, ...])
            noise_ref.append(noise[i, :, :end_index, ...])
            noisy_feature_input.append(noisy_feature[i, :, :end_index, ...])

        # Text conditioning, with prompts randomly dropped to "" so the
        # unconditional branch needed for guidance is trained as well.
        if self.use_text_cond and "text" in x:
            text_list = x["text"]
            if isinstance(text_list[0], list):
                text_end_list = x["feature_text_end"]
                all_text_context = []
                for segment_texts, segment_ends in zip(text_list, text_end_list):
                    if np.random.rand() > self.drop_out:
                        boundaries = [0] + [min(t, seq_len) for t in segment_ends]
                    else:
                        segment_texts = [""]
                        boundaries = [0, seq_len]
                    frames, _ = self._per_frame_context(
                        segment_texts, boundaries, seq_len, device
                    )
                    all_text_context.extend(frames)
            else:
                kept = [
                    (u if np.random.rand() > self.drop_out else "")
                    for u in text_list
                ]
                all_text_context = self._encode_texts(kept, device)
        else:
            all_text_context = self._encode_texts([""] * batch_size, device)

        predicted_result = self.model(
            noisy_feature_input,
            noise_level * self.time_embedding_scale,
            all_text_context,
            seq_len,
            y=None,
        )

        # The loss only looks at the last chunk_size positions of every
        # sample and is summed (not averaged) over their elements.
        loss = 0.0
        tail = slice(-self.chunk_size, None)
        for b in range(batch_size):
            if self.prediction_type == "vel":
                target = feature_ref[b] - noise_ref[b]
            elif self.prediction_type == "x0":
                target = feature_ref[b]
            else:  # "noise"
                target = noise_ref[b]
            squared_error = (
                predicted_result[b][:, tail, ...] - target[:, tail, ...]
            ) ** 2
            loss += squared_error.sum()
        loss = loss / batch_size

        loss_dict = {"total": loss, "mse": loss}
        return loss_dict

    # ------------------------------------------------------------------
    # Shared sampling helpers
    # ------------------------------------------------------------------
    def _window_indices(self, t):
        """Return ``(start, end)`` of the positions updated at time ``t``."""
        start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
        end_index = int(self.chunk_size * t) + 1
        return start_index, end_index

    def _predict(self, noisy_input, noise_level, text_context, null_context, seq_len):
        """Run the model, applying classifier-free guidance unless ``cfg_scale == 1``."""
        predicted = self.model(
            noisy_input,
            noise_level * self.time_embedding_scale,
            text_context,
            seq_len,
            y=None,
        )
        if self.cfg_scale != 1.0:
            predicted_null = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                null_context,
                seq_len,
                y=None,
            )
            predicted = [
                self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                for pv, pvn in zip(predicted, predicted_null)
            ]
        return predicted

    def _apply_update(
        self, generated, predicted_result, noise_level, start_index, end_index, dt
    ):
        """Advance ``generated[:, :, start_index:end_index]`` by one step, in place.

        ``predicted_result[i]`` and ``noise_level[i]`` must be indexable with
        the same absolute positions as ``generated[i]``.
        """
        window = slice(start_index, end_index)
        for i in range(generated.size(0)):
            prediction = predicted_result[i][:, window, ...]
            if self.prediction_type == "vel":
                predicted_vel = prediction
            else:
                # (W,) -> (1, W, 1, 1) so it broadcasts over the feature axis.
                level = (
                    noise_level[i, window].unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                )
                if self.prediction_type == "x0":
                    predicted_vel = (prediction - generated[i, :, window, ...]) / level
                else:  # "noise"
                    predicted_vel = (generated[i, :, window, ...] - prediction) / (
                        1 + dt - level
                    )
            generated[i, :, window, ...] += predicted_vel * dt

    def _start_generation(self, x, num_denoise_steps):
        """Set up everything ``generate`` and ``stream_generate`` share.

        Returns a dict holding the noise-initialised buffer, the step size,
        the number of steps and the text conditioning.
        """
        feature_length = x["feature_length"]
        batch_size = len(feature_length)
        seq_len = max(feature_length).item()

        if num_denoise_steps is None:
            num_denoise_steps = self.noise_steps
        assert num_denoise_steps % self.chunk_size == 0

        device = next(self.parameters()).device

        # chunk_size extra positions so the noise ramp always fits.
        total_len = seq_len + self.chunk_size
        generated = self.preprocess(
            torch.randn(batch_size, total_len, self.input_dim, device=device)
        )

        # Time at which position seq_len - 1 has become fully clean.
        max_t = 1 + (seq_len - 1) / self.chunk_size
        dt = 1 / num_denoise_steps
        total_steps = int(max_t / dt)

        if self.use_text_cond and "text" in x:
            text_list = x["text"]
            if isinstance(text_list[0], list):
                generated_length = []
                full_text = []
                all_text_context = []
                for segment_texts, segment_ends in zip(
                    text_list, x["feature_text_end"]
                ):
                    boundaries = [0] + [min(t, seq_len) for t in segment_ends]
                    generated_length.append(boundaries[-1])
                    frames, durations = self._per_frame_context(
                        segment_texts, boundaries, total_len, device
                    )
                    full_text.append(
                        " ////////// ".join(
                            [
                                f"{u} //dur:{t}"
                                for u, t in zip(segment_texts, durations)
                            ]
                        )
                    )
                    all_text_context.extend(frames)
            else:
                generated_length = feature_length
                full_text = text_list
                all_text_context = self._encode_texts(text_list, device)
        else:
            generated_length = feature_length
            full_text = [""] * batch_size
            all_text_context = self._encode_texts([""] * batch_size, device)

        # Empty-prompt context for the unconditional guidance branch.
        text_null_context = self._encode_texts([""] * batch_size, device)

        return {
            "generated": generated,
            "batch_size": batch_size,
            "seq_len": seq_len,
            "total_len": total_len,
            "device": device,
            "dt": dt,
            "total_steps": total_steps,
            "text_context": all_text_context,
            "null_context": text_null_context,
            "full_text": full_text,
            "generated_length": generated_length,
        }

    def _denoise_step(self, state, step):
        """Run denoising step ``step`` on ``state["generated"]``.

        Returns the start index of the updated window; positions before it
        are no longer modified by this or later steps.
        """
        generated = state["generated"]
        dt = state["dt"]
        t = step * dt
        start_index, end_index = self._window_indices(t)
        time_steps = torch.full((state["batch_size"],), t, device=state["device"])
        noise_level = self._get_noise_levels(
            state["device"], state["total_len"], time_steps
        )

        noisy_input = [
            generated[i, :, :end_index, ...] for i in range(state["batch_size"])
        ]
        predicted_result = self._predict(
            noisy_input,
            noise_level,
            state["text_context"],
            state["null_context"],
            state["total_len"],
        )
        self._apply_update(
            generated, predicted_result, noise_level, start_index, end_index, dt
        )
        return start_index

    def _committed_output(self, generated, commit_index, stop_index, generated_length):
        """Package ``generated[:, :, commit_index:stop_index]`` for streaming.

        Samples whose requested length ends before ``commit_index`` get
        ``None``; the others are cut at their requested length.
        """
        output = self.postprocess(generated[:, :, commit_index:stop_index, ...])
        y_hat_out = []
        for i in range(output.size(0)):
            if commit_index < generated_length[i]:
                y_hat_out.append(
                    output[i, : generated_length[i] - commit_index, ...]
                )
            else:
                y_hat_out.append(None)
        return {"generated": y_hat_out}

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate(self, x, num_denoise_steps=None):
        """
        Generation - Diffusion Forcing inference
        Uses triangular noise schedule, progressively generating from left to right

        Generation process:
        1. Start from t=0, gradually increase t
        2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
        3. After each denoising step, t increases slightly and continues

        Runs under ``torch.no_grad()`` like the streaming samplers, so no
        autograd graph is kept across the denoising steps.

        Args:
            x: dict with ``"feature_length"`` and optionally ``"text"`` /
                ``"feature_text_end"`` (same layout as in ``forward``).
            num_denoise_steps: Steps per unit of ``t``; defaults to
                ``noise_steps``.  Must be a multiple of ``chunk_size``.

        Returns:
            dict with ``"generated"`` (list of per-sample ``(length, D)``
            tensors) and ``"text"`` (the prompt description per sample).

        Raises:
            ValueError: if ``prediction_type`` is not supported.
        """
        self._check_prediction_type()
        state = self._start_generation(x, num_denoise_steps)

        for step in range(state["total_steps"]):
            self._denoise_step(state, step)

        generated = self.postprocess(state["generated"])
        generated_length = state["generated_length"]
        y_hat_out = [
            generated[i, : generated_length[i], :]
            for i in range(state["batch_size"])
        ]
        out = {}
        out["generated"] = y_hat_out
        out["text"] = state["full_text"]

        return out

    @torch.no_grad()
    def stream_generate(self, x, num_denoise_steps=None):
        """
        Streaming generation - Diffusion Forcing inference
        Uses triangular noise schedule, progressively generating from left to right

        Generation process:
        1. Start from t=0, gradually increase t
        2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
        3. After each denoising step, t increases slightly and continues

        This is a generator: whenever the update window moves right, the
        positions it left behind are yielded as
        ``{"generated": [tensor or None per sample]}``; a final item carries
        everything that remains.

        Args:
            x: same layout as for ``generate``.
            num_denoise_steps: Steps per unit of ``t``; defaults to
                ``noise_steps``.  Must be a multiple of ``chunk_size``.

        Raises:
            ValueError: if ``prediction_type`` is not supported.
        """
        self._check_prediction_type()
        state = self._start_generation(x, num_denoise_steps)
        generated = state["generated"]
        generated_length = state["generated_length"]

        commit_index = 0
        for step in range(state["total_steps"]):
            start_index = self._denoise_step(state, step)

            if commit_index < start_index:
                yield self._committed_output(
                    generated, commit_index, start_index, generated_length
                )
                commit_index = start_index

        yield self._committed_output(generated, commit_index, None, generated_length)

    def init_generated(self, seq_len, batch_size=1, num_denoise_steps=None):
        """Reset the state used by ``stream_generate_step``.

        Args:
            seq_len: Maximum number of positions fed to the model per call.
                The rolling buffer holds ``2 * seq_len + chunk_size``
                positions.
            batch_size: Number of parallel streams.
            num_denoise_steps: Steps per unit of ``t``; defaults to
                ``noise_steps``.  Must be a multiple of ``chunk_size``.
        """
        self.seq_len = seq_len
        self.batch_size = batch_size
        if num_denoise_steps is None:
            self.num_denoise_steps = self.noise_steps
        else:
            self.num_denoise_steps = num_denoise_steps
        assert self.num_denoise_steps % self.chunk_size == 0
        self.dt = 1 / self.num_denoise_steps
        self.current_step = 0
        self.text_condition_list = [[] for _ in range(self.batch_size)]
        # Created on the default device; moved to the model's device by the
        # first stream_generate_step call (first_chunk=True).
        self.generated = torch.randn(
            self.batch_size, self.seq_len * 2 + self.chunk_size, self.input_dim
        )
        self.generated = self.preprocess(self.generated)
        self.commit_index = 0

    @torch.no_grad()
    def stream_generate_step(self, x, first_chunk=True):
        """
        Streaming generation step - Diffusion Forcing inference
        Uses triangular noise schedule, progressively generating from left to right

        Generation process:
        1. Start from t=0, gradually increase t
        2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
        3. After each denoising step, t increases slightly and continues

        Each call registers the current prompt, runs as many denoising steps
        as are needed to finish position ``commit_index`` and returns that
        single position.  ``init_generated`` must have been called before.

        Args:
            x: dict, optionally with ``"text"`` (one prompt per stream).
            first_chunk: True on the first call after ``init_generated``; the
                prompt is then registered for ``chunk_size`` positions instead
                of one and the buffer is moved to the model's device.

        Returns:
            dict with ``"generated"``: tensor of shape ``(batch, 1, D)``.

        Raises:
            ValueError: if ``prediction_type`` is not supported.
        """
        self._check_prediction_type()

        device = next(self.parameters()).device
        if first_chunk:
            self.generated = self.generated.to(device)

        if self.use_text_cond and "text" in x:
            new_text_context = self._encode_texts(x["text"], device)
        else:
            new_text_context = self._encode_texts([""] * self.batch_size, device)

        # Empty-prompt context for the unconditional guidance branch.
        text_null_context = self._encode_texts([""] * self.batch_size, device)

        repeats = self.chunk_size if first_chunk else 1
        for i in range(self.batch_size):
            self.text_condition_list[i].extend([new_text_context[i]] * repeats)

        end_step = (
            (self.commit_index + self.chunk_size)
            * self.num_denoise_steps
            / self.chunk_size
        )
        while self.current_step < end_step:
            current_time = self.current_step * self.dt
            start_index, end_index = self._window_indices(current_time)
            time_steps = torch.full((self.batch_size,), current_time, device=device)

            # Levels for absolute positions 0..end_index-1.  The model only
            # sees the last seq_len of them, but the update below indexes
            # with absolute positions and therefore needs the full tensor.
            full_noise_level = self._get_noise_levels(device, end_index, time_steps)
            noise_level = full_noise_level[:, -self.seq_len :]

            noisy_input = [
                self.generated[i, :, :end_index, ...][:, -self.seq_len :]
                for i in range(self.batch_size)
            ]

            text_condition = []
            for i in range(self.batch_size):
                text_condition.extend(
                    self.text_condition_list[i][:end_index][-self.seq_len :]
                )

            predicted_result = self._predict(
                noisy_input,
                noise_level,
                text_condition,
                text_null_context,
                min(end_index, self.seq_len),
            )

            if end_index > self.seq_len:
                # Left-pad with zeros so predictions can be indexed with
                # absolute positions; the padded part is never read because
                # the update window lies inside the last seq_len positions.
                predicted_result = [
                    torch.cat(
                        [
                            torch.zeros(
                                p.shape[0],
                                end_index - self.seq_len,
                                p.shape[2],
                                p.shape[3],
                                device=device,
                            ),
                            p,
                        ],
                        dim=1,
                    )
                    for p in predicted_result
                ]

            self._apply_update(
                self.generated,
                predicted_result,
                full_noise_level,
                start_index,
                end_index,
                self.dt,
            )
            self.current_step += 1

        output = self.generated[:, :, self.commit_index : self.commit_index + 1, ...]
        output = self.postprocess(output)
        out = {}
        out["generated"] = output
        self.commit_index += 1

        if self.commit_index == self.seq_len * 2:
            # Roll the buffer: drop the oldest seq_len positions, append
            # fresh noise and shift all counters back accordingly.
            self.generated = torch.cat(
                [
                    self.generated[:, :, self.seq_len :, ...],
                    torch.randn(
                        self.batch_size,
                        self.input_dim,
                        self.seq_len,
                        1,
                        1,
                        device=device,
                    ),
                ],
                dim=2,
            )
            self.current_step -= self.seq_len * self.num_denoise_steps / self.chunk_size
            self.commit_index -= self.seq_len
            for i in range(self.batch_size):
                self.text_condition_list[i] = self.text_condition_list[i][
                    self.seq_len :
                ]
        return out
| |
|