import torch from diffusers.pipelines import FluxPipeline from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition from omini.rotation import RotationConfig, RotationTuner from PIL import Image def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False): """ Load rotation adapter weights. Args: path: Directory containing the saved adapter weights adapter_name: Name of the adapter to load strict: Whether to strictly match all keys """ from safetensors.torch import load_file import os import yaml device = transformer.device print(f"device for loading: {device}") # Try to load safetensors first, then fallback to .pth safetensors_path = os.path.join(path, f"{adapter_name}.safetensors") pth_path = os.path.join(path, f"{adapter_name}.pth") if os.path.exists(safetensors_path): state_dict = load_file(safetensors_path) print(f"Loaded rotation adapter from {safetensors_path}") elif os.path.exists(pth_path): state_dict = torch.load(pth_path, map_location=device) print(f"Loaded rotation adapter from {pth_path}") else: raise FileNotFoundError( f"No adapter weights found for '{adapter_name}' in {path}\n" f"Looking for: {safetensors_path} or {pth_path}" ) # # Get the device and dtype of the transformer transformer_device = next(transformer.parameters()).device transformer_dtype = next(transformer.parameters()).dtype state_dict_with_adapter = {} for k, v in state_dict.items(): # Reconstruct the full key with adapter name new_key = k.replace(".rotation.", f".rotation.{adapter_name}.") if "_adapter_config" in new_key: print(f"adapter_config key: {new_key}") # Move to target device and dtype # Check if this parameter should keep its original dtype (e.g., indices, masks) if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]: # Keep integer/boolean dtypes, only move device state_dict_with_adapter[new_key] = v.to(device=transformer_device) else: # Convert floating point tensors to target dtype and device state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype) # Add adapter name back to keys (reverse of what we did in save) state_dict_with_adapter = { k.replace(".rotation.", f".rotation.{adapter_name}."): v for k, v in state_dict.items() } # Load into the model missing, unexpected = transformer.load_state_dict( state_dict_with_adapter, strict=strict ) if missing: print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}") if unexpected: print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") # Load config if available config_path = os.path.join(path, f"{adapter_name}_config.yaml") if os.path.exists(config_path): with open(config_path, 'r') as f: config = yaml.safe_load(f) print(f"Loaded config: {config}") total_params = sum(p.numel() for p in state_dict.values()) print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)") return state_dict # prepare input image and prompt image = Image.open("assets/coffee.png").convert("RGB") w, h, min_dim = image.size + (min(image.size),) image = image.crop( ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2) ).resize((512, 512)) prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table." canny_image = convert_to_condition("canny", image) condition = Condition(canny_image, "canny") seed_everything() for i in range(40, 60): pipe = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 ) # add adapter to the transformer transformer = pipe.transformer adapter_name = "default" transformer._hf_peft_config_loaded = True rotation_adapter_config = { "r": 4, "num_rotations": 4, "target_modules": "(.*x_embedder|.*(?