|
|
import torch |
|
|
|
|
|
from diffusers.pipelines import FluxPipeline |
|
|
from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition |
|
|
from omini.rotation import RotationConfig, RotationTuner |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
    """
    Load rotation adapter weights into ``transformer``.

    ``<adapter_name>.safetensors`` inside ``path`` is preferred; if it does not
    exist, ``<adapter_name>.pth`` is used instead. Before loading, every
    checkpoint key has ``.rotation.`` rewritten to
    ``.rotation.<adapter_name>.``. If ``<adapter_name>_config.yaml`` exists in
    ``path`` it is parsed and printed; it is not applied to the model.

    Args:
        transformer: Module that receives the weights. It must expose
            ``.device`` and ``load_state_dict``.
        path: Directory containing the saved adapter weights
        adapter_name: Name of the adapter to load
        strict: Whether to strictly match all keys

    Returns:
        The state dict exactly as read from disk (keys without the adapter
        name inserted).

    Raises:
        FileNotFoundError: If neither weight file exists in ``path``.
    """
    import os

    device = transformer.device
    print(f"device for loading: {device}")

    safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
    pth_path = os.path.join(path, f"{adapter_name}.pth")

    if os.path.exists(safetensors_path):
        # Imported only on this branch so that a .pth-only checkpoint can be
        # loaded without safetensors being installed.
        from safetensors.torch import load_file

        state_dict = load_file(safetensors_path)
        print(f"Loaded rotation adapter from {safetensors_path}")
    elif os.path.exists(pth_path):
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(pth_path, map_location=device)
        print(f"Loaded rotation adapter from {pth_path}")
    else:
        raise FileNotFoundError(
            f"No adapter weights found for '{adapter_name}' in {path}\n"
            f"Looking for: {safetensors_path} or {pth_path}"
        )

    # Insert the adapter name after ".rotation." in a single pass.
    # No explicit device/dtype cast is done here: load_state_dict copies each
    # tensor into the already-allocated parameter or buffer, so the
    # destination keeps its own device and dtype.
    state_dict_with_adapter = {}
    for key, tensor in state_dict.items():
        new_key = key.replace(".rotation.", f".rotation.{adapter_name}.")
        if "_adapter_config" in new_key:
            print(f"adapter_config key: {new_key}")
        state_dict_with_adapter[new_key] = tensor

    missing, unexpected = transformer.load_state_dict(
        state_dict_with_adapter,
        strict=strict,
    )

    # Only the first five keys are printed to keep the output readable.
    if missing:
        print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    # The optional YAML config is informational only.
    config_path = os.path.join(path, f"{adapter_name}_config.yaml")
    if os.path.exists(config_path):
        import yaml

        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        print(f"Loaded config: {config}")

    total_params = sum(p.numel() for p in state_dict.values())
    print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")

    return state_dict
|
|
|
|
|
|
|
|
|
|
|
# --- Input preparation -------------------------------------------------------
# Open the reference picture, centre-crop it to a square and scale the square
# to 512x512.
image = Image.open("assets/coffee.png").convert("RGB")

w, h = image.size
min_dim = min(w, h)
crop_box = (
    (w - min_dim) // 2,
    (h - min_dim) // 2,
    (w + min_dim) // 2,
    (h + min_dim) // 2,
)
image = image.crop(crop_box).resize((512, 512))

prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."

# Build the "canny" condition from the prepared image.
canny_image = convert_to_condition("canny", image)
condition = Condition(canny_image, "canny")

# Seeding happens once, before the sweep starts.
seed_everything()

# --- Adapter settings shared by every run ------------------------------------
adapter_name = "default"
rotation_adapter_config = {
    "r": 4,
    "num_rotations": 4,
    "target_modules": "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)",
}

# --- Sweep -------------------------------------------------------------------
# Each run uses config.T = (i + 1) / 20, i.e. 2.05 up to 3.00 in steps of 0.05,
# and builds a fresh pipeline before the rotation adapter is attached.
for i in range(40, 60):
    run_id = i + 1

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    transformer = pipe.transformer
    transformer._hf_peft_config_loaded = True

    config = RotationConfig(**rotation_adapter_config)
    config.T = float(run_id) / 20
    rotation_tuner = RotationTuner(transformer, config, adapter_name=adapter_name)

    transformer = transformer.to(torch.bfloat16)
    transformer.set_adapter(adapter_name)

    # Fill the freshly attached adapter with the saved checkpoint weights.
    load_rotation(
        transformer,
        path="runs/20251110-191859/ckpt/4000",
        adapter_name=adapter_name,
        strict=False,
    )

    pipe = pipe.to("cuda")

    generation = generate(pipe, prompt=prompt, conditions=[condition])
    result_img = generation.images[0]

    # Side-by-side strip of three 512-pixel-wide panels:
    # input image | condition image | generated result.
    concat_image = Image.new("RGB", (1536, 512))
    panels = (image, condition.condition, result_img)
    for x_offset, panel in zip((0, 512, 1024), panels):
        concat_image.paste(panel, (x_offset, 0))

    result_path = f"result_{run_id}.png"
    concat_path = f"result_concat_{run_id}.png"
    result_img.save(result_path)
    concat_image.save(concat_path)
    print(f"Saved {result_path} and {concat_path}")