Spaces:
Sleeping
Sleeping
| import os | |
| import contextlib | |
| import functools | |
| from datetime import datetime | |
| import cv2 | |
| import gradio as gr | |
| import kiui | |
| import numpy as np | |
| import rembg | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import trimesh | |
| from huggingface_hub import hf_hub_download | |
try:
    import spaces
except ImportError:
    class spaces:
        """No-op stand-in used when the `spaces` package cannot be imported."""

        class GPU:
            """Decorator stub that hands the decorated function back unchanged.

            Works both as ``@spaces.GPU`` (bare) and as
            ``@spaces.GPU(duration=...)`` (parameterised).
            """

            def __new__(cls, *args, **kwargs):
                # Bare usage: the only positional argument is the function
                # being decorated. Return it directly; because the result is
                # not a GPU instance, __init__ is skipped.
                if len(args) == 1 and not kwargs and callable(args[0]):
                    return args[0]
                return super().__new__(cls)

            def __init__(self, duration=60):
                # Stored only for parity with the parameterised call form;
                # the stub never uses it.
                self.duration = duration

            def __call__(self, func):
                return func
| from flow.model import Model | |
| from flow.configs.schema import ModelConfig | |
| from flow.utils import get_random_color, recenter_foreground | |
| from vae.utils import postprocess_mesh | |
| # ========================================================= | |
| # CPU / dtype 基础设置 | |
| # ========================================================= | |
# Everything in this app runs on the CPU in single precision.
DEVICE = torch.device("cpu")
DTYPE = torch.float32

# Intra-op thread count; set the CPU_THREADS environment variable to tune it
# for the machine the Space runs on.
CPU_THREADS = int(os.getenv("CPU_THREADS", "2"))
torch.set_num_threads(CPU_THREADS)
# The inter-op pool size is CPU_THREADS clamped to the range [1, 2].
torch.set_num_interop_threads(min(2, max(1, CPU_THREADS)))

# Make float32 the default floating-point dtype explicitly.
torch.set_default_dtype(DTYPE)

# Inference only: turn autograd off globally (best effort).
try:
    torch.set_grad_enabled(False)
except Exception:
    pass

# 3x3 axis permutation applied to mesh vertices before GLB export.
TRIMESH_GLB_EXPORT = np.array(
    ((0, 1, 0), (0, 0, 1), (1, 0, 0)),
    dtype=np.float32,
)

# Largest seed value offered to the user (int32 maximum).
MAX_SEED = np.iinfo(np.int32).max
| # Shared rembg session, created once at import time and passed to rembg.remove() in process_image. | |
| bg_remover = rembg.new_session() | |
| # ========================================================= | |
| # 工具函数:递归转换任意对象中的浮点 Tensor 为 float32 | |
| # ========================================================= | |
def to_cpu_fp32(obj):
    """Recursively move tensors inside ``obj`` to the CPU as float32.

    Dicts, lists and tuples are rebuilt as plain ``dict`` / ``list`` /
    ``tuple`` with every element converted. Floating-point tensors are moved
    to ``DEVICE`` and cast to float32; other tensors are moved to ``DEVICE``
    with their dtype untouched. Anything else is returned as-is.
    """
    if isinstance(obj, dict):
        return {key: to_cpu_fp32(value) for key, value in obj.items()}

    if isinstance(obj, (list, tuple)):
        converted = [to_cpu_fp32(item) for item in obj]
        return converted if isinstance(obj, list) else tuple(converted)

    if not torch.is_tensor(obj):
        return obj

    # Only floating-point tensors get a dtype change.
    move_kwargs = {"device": DEVICE, "non_blocking": False}
    if obj.is_floating_point():
        move_kwargs["dtype"] = torch.float32
    return obj.to(**move_kwargs)
| # ========================================================= | |
| # 工具函数:强制整个模块转 float32 | |
| # ========================================================= | |
def force_module_fp32(module: torch.nn.Module):
    """Move a module to the CPU and cast its floating-point state to float32.

    ``Module.to`` and ``Module.float`` both operate in place and already
    recurse over every submodule, parameter and buffer. One call of each on
    the root is therefore sufficient; no per-child recursion or manual buffer
    reassignment is needed. Non-floating-point tensors are moved to
    ``DEVICE`` but keep their dtype.

    Args:
        module: Root module to convert in place.

    Returns:
        The same ``module`` object, for call chaining.
    """
    module.to(device=DEVICE)
    module.float()
    return module
| # ========================================================= | |
| # 工具函数:禁用 CPU autocast | |
| # ========================================================= | |
@contextlib.contextmanager
def disable_cpu_autocast():
    """Context manager that turns CPU autocast off for the enclosed block.

    Guards against code paths that would otherwise switch computation to a
    reduced-precision dtype through autocast.

    If the autocast context cannot be constructed in the running PyTorch
    build, a no-op context is used instead. Exceptions raised inside the
    ``with`` body propagate unchanged: the generator yields exactly once.
    """
    try:
        autocast_ctx = torch.autocast(device_type="cpu", enabled=False)
    except Exception:
        # Some environments/versions may not support this call; degrade to a
        # plain context rather than failing.
        autocast_ctx = contextlib.nullcontext()

    with autocast_ctx:
        yield
| # ========================================================= | |
| # 兜底补丁 1:全局修补 F.linear | |
| # ========================================================= | |
def patch_functional_linear():
    """Monkey-patch ``torch.nn.functional.linear`` with a dtype-aligning wrapper.

    The wrapper is a last line of defence against dtype mismatches:

    * a floating-point CPU ``input`` whose dtype differs from a
      floating-point ``weight`` is cast to ``weight.dtype``;
    * a floating-point ``bias`` whose dtype differs from a floating-point
      ``weight`` is cast to ``weight.dtype``.

    Calling this function again after the patch is installed does nothing.
    """
    if getattr(F.linear, "_fp32_safe_patched", False):
        return

    unpatched_linear = F.linear

    def _is_float_tensor(value):
        return torch.is_tensor(value) and value.is_floating_point()

    def linear_fp32_safe(input, weight, bias=None):
        if _is_float_tensor(weight):
            ref_dtype = weight.dtype
            # Input is only touched when it lives on the CPU.
            if (
                _is_float_tensor(input)
                and input.device.type == "cpu"
                and input.dtype != ref_dtype
            ):
                input = input.to(dtype=ref_dtype)
            if _is_float_tensor(bias) and bias.dtype != ref_dtype:
                bias = bias.to(dtype=ref_dtype)
        return unpatched_linear(input, weight, bias)

    # Marker checked at the top so the patch is applied at most once.
    setattr(linear_fp32_safe, "_fp32_safe_patched", True)
    F.linear = linear_fp32_safe
| # ========================================================= | |
| # 兜底补丁 2:给常见模块加 forward pre-hook | |
| # ========================================================= | |
def register_dtype_guard_hooks(root_module: nn.Module):
    """Attach forward pre-hooks that align inputs with each layer's dtype.

    Every submodule of ``root_module`` (the root included) that is one of the
    guarded layer types gets a pre-hook. Before the layer's forward runs, the
    hook moves and casts its positional inputs to the device and dtype of the
    layer's first floating-point parameter, or of its first floating-point
    buffer when it has no such parameter. Layers with neither are left alone.

    Args:
        root_module: Module tree to scan.

    Returns:
        List of the hook handles, one per guarded submodule.
    """
    guarded_types = (
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.LayerNorm,
        nn.GroupNorm,
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.MultiheadAttention,
    )

    def _align(value, dtype, device):
        # Containers are rebuilt as plain dict / list / tuple.
        if isinstance(value, dict):
            return {k: _align(v, dtype, device) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            aligned = [_align(v, dtype, device) for v in value]
            return aligned if isinstance(value, list) else tuple(aligned)
        if not torch.is_tensor(value):
            return value
        # Non-floating tensors only change device, never dtype.
        if value.is_floating_point():
            return value.to(device=device, dtype=dtype)
        return value.to(device=device)

    def _first_float(tensors):
        return next(
            (t for t in tensors if torch.is_tensor(t) and t.is_floating_point()),
            None,
        )

    def pre_hook(module, inputs):
        # Parameters take precedence over buffers as the reference tensor.
        anchor = _first_float(module.parameters(recurse=False))
        if anchor is None:
            anchor = _first_float(module.buffers(recurse=False))
        if anchor is None:
            return inputs
        return _align(inputs, anchor.dtype, anchor.device)

    return [
        submodule.register_forward_pre_hook(pre_hook)
        for submodule in root_module.modules()
        if isinstance(submodule, guarded_types)
    ]
| # ========================================================= | |
| # 兜底补丁 3:包装 forward,统一禁用 autocast + 输入转 fp32 | |
| # ========================================================= | |
def wrap_forward_fp32(module: nn.Module):
    """Replace ``module.forward`` with a float32-enforcing wrapper.

    The wrapper:
      1. converts positional and keyword inputs with ``to_cpu_fp32``,
      2. runs the original forward under ``disable_cpu_autocast()``,
      3. converts the output with ``to_cpu_fp32``.

    A module that has already been wrapped is left untouched.
    """
    already_wrapped = getattr(module, "_forward_fp32_wrapped", False)
    if already_wrapped:
        return

    inner_forward = module.forward

    def forward_fp32_safe(*args, **kwargs):
        safe_args, safe_kwargs = to_cpu_fp32(args), to_cpu_fp32(kwargs)
        with disable_cpu_autocast():
            result = inner_forward(*safe_args, **safe_kwargs)
        return to_cpu_fp32(result)

    # Flag read by the guard above on later calls.
    module._forward_fp32_wrapped = True
    module.forward = forward_fp32_safe
| # ========================================================= | |
| # 下载模型 | |
| # ========================================================= | |
# Fetch both checkpoints from the Hub: the flow model first, then the VAE.
flow_ckpt_path, vae_ckpt_path = (
    hf_hub_download(repo_id="nvidia/PartPacker", filename=ckpt_name)
    for ckpt_name in ("flow.pt", "vae.pt")
)

# =========================================================
# Model configuration
# =========================================================
model_config = ModelConfig(
    vae_conf="vae.configs.part_woenc",
    vae_ckpt_path=vae_ckpt_path,
    qknorm=True,
    qknorm_type="RMSNorm",
    use_pos_embed=False,
    dino_model="dinov2_vitg14",
    hidden_dim=1536,
    flow_shift=3.0,
    logitnorm_mean=1.0,
    logitnorm_std=1.0,
    latent_size=4096,
    use_parts=True,
)
| # ========================================================= | |
| # 初始化模型(CPU + float32) | |
| # ========================================================= | |
print("正在加载模型到 CPU ...")

# Install the F.linear dtype guard before any layer runs.
patch_functional_linear()

model = Model(model_config)
model.eval()
model.to(DEVICE)

# Load the weights straight onto the CPU. If this torch build rejects the
# `weights_only` keyword (TypeError), retry with a plain torch.load.
try:
    ckpt_dict = torch.load(flow_ckpt_path, map_location=DEVICE, weights_only=True)
except TypeError:
    ckpt_dict = torch.load(flow_ckpt_path, map_location=DEVICE)
model.load_state_dict(ckpt_dict, strict=True)

# Force the whole model to float32 after the checkpoint is loaded.
force_module_fp32(model)
model.eval()

# Wrap forward on the model and, when present, its `dit` and `vae` members so
# inputs are float32 and CPU autocast is off during their forward passes.
wrap_forward_fp32(model)
if hasattr(model, "dit"):
    wrap_forward_fp32(model.dit)
if hasattr(model, "vae"):
    wrap_forward_fp32(model.vae)

# Register the dtype guard hooks. register_dtype_guard_hooks(model) walks
# model.modules(), which already contains every registered submodule. The VAE
# is hooked separately only when it is NOT part of that tree; otherwise each
# VAE layer would receive the same pre-hook twice.
_DTYPE_GUARD_HOOKS = list(register_dtype_guard_hooks(model))
if hasattr(model, "vae") and not any(m is model.vae for m in model.modules()):
    _DTYPE_GUARD_HOOKS.extend(register_dtype_guard_hooks(model.vae))

print("模型加载完成。")
try:
    print("主模型 dtype:", next(model.parameters()).dtype)
except StopIteration:
    print("主模型没有可见参数。")
def get_random_seed(randomize_seed, seed):
    """Return the seed to use for this run as a built-in ``int``.

    When ``randomize_seed`` is truthy a value is drawn from
    ``np.random.randint(0, MAX_SEED)``; otherwise ``seed`` is used.
    """
    chosen = np.random.randint(0, MAX_SEED) if randomize_seed else seed
    return int(chosen)
def process_image(image_path):
    """Load an image file and prepare it as model input.

    Steps:
      1. Read the file with OpenCV, keeping any alpha channel.
      2. If the image has no alpha channel (grayscale or 3-channel colour),
         remove the background with rembg.
      3. Recenter the foreground using the alpha > 0 mask.
      4. Resize to 518 x 518.

    Args:
        image_path: Path of the uploaded image, or None.

    Returns:
        The processed image array; its last channel is treated as alpha.

    Raises:
        gr.Error: If no path was given or the file cannot be decoded.
    """
    if image_path is None:
        raise gr.Error("请先上传图片。")

    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise gr.Error("图片读取失败,请上传有效图片。")

    if image.ndim == 2:
        # Single-channel input carries no alpha. Expand it to BGR (not RGBA):
        # a synthetic, fully opaque alpha channel would send the image down
        # the "already has alpha" branch below and skip background removal.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    if image.shape[-1] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    else:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # No alpha channel: let rembg cut out the foreground.
        # NOTE(review): the result is assumed to be RGBA, since its last
        # channel is read as alpha below — confirm against rembg.
        image = rembg.remove(image, session=bg_remover)

    mask = image[..., -1] > 0
    image = recenter_foreground(image, mask, border_ratio=0.1)
    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_AREA)
    return image
def _split_decoded_mesh(vae_output, target_num_faces):
    """Convert one VAE decode result into a list of export-ready mesh parts."""
    vertices, faces = vae_output["meshes"][0]
    mesh = trimesh.Trimesh(
        np.asarray(vertices, dtype=np.float32),
        np.asarray(faces, dtype=np.int64),
        process=False,
    )
    # Permute the axes for GLB export.
    mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
    mesh = postprocess_mesh(mesh, int(target_num_faces))
    return mesh.split(only_watertight=False)


def process_3d(
    input_image,
    num_steps=10,
    cfg_scale=7.0,
    grid_res=128,
    seed=42,
    simplify_mesh=True,
    target_num_faces=20000
):
    """Generate a part-based 3D model from a processed image, on CPU.

    Args:
        input_image: Image array with an alpha channel at index 3 (the output
            of ``process_image``), or None.
        num_steps: Number of inference steps passed to the model.
        cfg_scale: CFG scale passed to the model.
        grid_res: Resolution passed to the VAE decoder.
        seed: Seed given to ``kiui.seed_everything``.
        simplify_mesh: When False, ``target_num_faces`` is replaced by -1
            before it is handed to ``postprocess_mesh``.
        target_num_faces: Face budget handed to ``postprocess_mesh``.

    Returns:
        Path of the exported ``.glb`` file under ``output/``.

    Raises:
        gr.Error: If no image is given, the model returns no latent tensor,
            no mesh parts are produced, or any other exception occurs (the
            original exception is attached as the cause).
    """
    if input_image is None:
        raise gr.Error("请先上传并处理图片。")

    try:
        kiui.seed_everything(int(seed))
        os.makedirs("output", exist_ok=True)
        output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"

        # 1) RGBA -> RGB composited over white, as a float32 tensor with a
        #    leading batch dimension and channels first.
        image = input_image.astype(np.float32) / 255.0
        image = image[..., :3] * image[..., 3:4] + (1.0 - image[..., 3:4])
        image_tensor = (
            torch.from_numpy(image)
            .permute(2, 0, 1)
            .contiguous()
            .unsqueeze(0)
            .to(device=DEVICE, dtype=torch.float32)
        )
        data = to_cpu_fp32({"cond_images": image_tensor})

        # 2) Re-assert float32 on the model (and VAE) before inference.
        force_module_fp32(model)
        model.eval()
        if hasattr(model, "vae"):
            force_module_fp32(model.vae)
            model.vae.eval()

        # 3) Main model inference with CPU autocast explicitly disabled.
        with torch.inference_mode():
            with disable_cpu_autocast():
                results = model(
                    data,
                    num_steps=int(num_steps),
                    cfg_scale=float(cfg_scale)
                )
        results = to_cpu_fp32(results)

        latent = results.get("latent", None)
        if not isinstance(latent, torch.Tensor):
            raise gr.Error("模型输出 latent 异常。")
        latent = latent.to(device=DEVICE, dtype=torch.float32).contiguous()

        # 4) VAE decoding. The latent is split along dim 1: the first
        #    `latent_size` entries form part 0, the remainder part 1.
        split_at = model.config.latent_size
        data_part0 = to_cpu_fp32({"latent": latent[:, :split_at, :].contiguous()})
        data_part1 = to_cpu_fp32({"latent": latent[:, split_at:, :].contiguous()})

        with torch.inference_mode():
            with disable_cpu_autocast():
                results_part0 = model.vae(data_part0, resolution=int(grid_res))
                results_part1 = model.vae(data_part1, resolution=int(grid_res))
        results_part0 = to_cpu_fp32(results_part0)
        results_part1 = to_cpu_fp32(results_part1)

        if not simplify_mesh:
            target_num_faces = -1

        # 5) Build, post-process and split the meshes: part 0, then part 1.
        parts = []
        parts.extend(_split_decoded_mesh(results_part0, target_num_faces))
        parts.extend(_split_decoded_mesh(results_part1, target_num_faces))

        if not parts:
            raise gr.Error("没有生成有效网格,请换一张更清晰、背景更简单的图片。")

        # 6) Colour each part by its index and export the scene.
        for j, part in enumerate(parts):
            part.visual.vertex_colors = get_random_color(j, use_float=True)

        scene = trimesh.Scene(parts)
        scene.export(output_glb_path)
        return output_glb_path

    except gr.Error:
        # Already user-facing; pass through untouched.
        raise
    except Exception as e:
        raise gr.Error(
            "CPU 生成失败:"
            + str(e)
            + "\n\n建议:\n"
            "1. Inference Steps 先设为 10\n"
            "2. Grid Resolution 先设为 128\n"
            "3. 勾选 Simplify Mesh\n"
            "4. Target Face Count 设为 20000\n"
            "5. 使用主体清晰、背景简单的 PNG 图片"
        ) from e
| # Page title and the Markdown help text rendered at the top of the UI. | |
| _TITLE = "🎨 Image to 3D Model - CPU Version" | |
| _DESCRIPTION = """ | |
| ### CPU 版说明 | |
| 这是适配 Hugging Face CPU Space 的版本。 | |
| ### 建议参数 | |
| - Inference Steps:10 | |
| - CFG Scale:7.0 | |
| - Grid Resolution:128 | |
| - Simplify Mesh:开启 | |
| - Target Face Count:20000 | |
| ### 注意 | |
| 该模型原本更适合 GPU,CPU 下会比较慢。 | |
| """ | |
| # Gradio app; the request queue is created with max_size=2. | |
| block = gr.Blocks(title=_TITLE).queue(max_size=2) | |
| with block: | |
| gr.Markdown("# " + _TITLE) | |
| gr.Markdown(_DESCRIPTION) | |
| with gr.Row(): | |
| with gr.Column(): | |
| # Uploaded image is passed to callbacks as a file path. | |
| input_image = gr.Image( | |
| label="上传图片", | |
| type="filepath" | |
| ) | |
| # Read-only RGBA preview that receives the output of process_image. | |
| seg_image = gr.Image( | |
| label="处理后图片", | |
| type="numpy", | |
| interactive=False, | |
| image_mode="RGBA" | |
| ) | |
| # Generation parameters; these widgets are the inputs of process_3d. | |
| with gr.Accordion("高级设置", open=False): | |
| num_steps = gr.Slider( | |
| label="Inference Steps", | |
| minimum=1, | |
| maximum=30, | |
| step=1, | |
| value=10 | |
| ) | |
| cfg_scale = gr.Slider( | |
| label="CFG Scale", | |
| minimum=2.0, | |
| maximum=10.0, | |
| step=0.1, | |
| value=7.0 | |
| ) | |
| input_grid_res = gr.Slider( | |
| label="Grid Resolution", | |
| minimum=64, | |
| maximum=256, | |
| step=1, | |
| value=128 | |
| ) | |
| with gr.Row(): | |
| randomize_seed = gr.Checkbox(label="随机种子", value=True) | |
| seed = gr.Slider( | |
| label="Seed", | |
| minimum=0, | |
| maximum=MAX_SEED, | |
| step=1, | |
| value=0 | |
| ) | |
| with gr.Row(): | |
| simplify_mesh = gr.Checkbox(label="简化网格", value=True) | |
| target_num_faces = gr.Slider( | |
| label="目标面数", | |
| minimum=5000, | |
| maximum=50000, | |
| step=1000, | |
| value=20000 | |
| ) | |
| button_gen = gr.Button("生成 3D 模型", variant="primary") | |
| with gr.Column(): | |
| output_model = gr.Model3D(label="3D 预览", height=512) | |
| # Example images; they are wired to process_image with example caching turned off. | |
| with gr.Row(): | |
| gr.Examples( | |
| examples=[ | |
| ["examples/rabbit.png"], | |
| ["examples/robot.png"], | |
| ["examples/teapot.png"], | |
| ], | |
| fn=process_image, | |
| inputs=[input_image], | |
| outputs=[seg_image], | |
| cache_examples=False | |
| ) | |
| # Button click chains three steps in order: process_image -> get_random_seed -> process_3d. | |
| button_gen.click( | |
| fn=process_image, | |
| inputs=[input_image], | |
| outputs=[seg_image] | |
| ).then( | |
| fn=get_random_seed, | |
| inputs=[randomize_seed, seed], | |
| outputs=[seed] | |
| ).then( | |
| fn=process_3d, | |
| inputs=[ | |
| seg_image, | |
| num_steps, | |
| cfg_scale, | |
| input_grid_res, | |
| seed, | |
| simplify_mesh, | |
| target_num_faces | |
| ], | |
| outputs=[output_model] | |
| ) | |
| # Serve on all interfaces; the port comes from the PORT environment variable, defaulting to 7860. | |
| block.launch( | |
| server_name="0.0.0.0", | |
| server_port=int(os.environ.get("PORT", 7860)) | |
| ) | |