| |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import functools |
| import os |
| import pathlib |
| import sys |
| from typing import Callable |
|
|
# On Hugging Face Spaces, strip lines 10-17 of the custom CUDA op wrappers so
# the vendored StyleGAN ops fall back to a version that imports without the
# compiled extensions.
# NOTE(review): the hard-coded sed line range is brittle — it assumes the
# upstream files never change; re-check after updating the DualStyleGAN repo.
if os.environ.get('SYSTEM') == 'spaces':
    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")

# Make the vendored DualStyleGAN checkout importable (provides the `model.*`
# modules imported below).
sys.path.insert(0, 'DualStyleGAN')
|
|
| import dlib |
| import gradio as gr |
| import huggingface_hub |
| import numpy as np |
| import PIL.Image |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as T |
| from model.dualstylegan import DualStyleGAN |
| from model.encoder.align_all_parallel import align_face |
| from model.encoder.psp import pSp |
|
|
# UI copy and links shown in the Gradio interface.
ORIGINAL_REPO_URL = 'https://github.com/williamyang1991/DualStyleGAN'
TITLE = 'williamyang1991/DualStyleGAN'
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.



You can select style images for each style type from the tables below.
The style image index should be in the following range:
(cartoon: 0-316, caricature: 0-198, anime: 0-173, arcane: 0-99, comic: 0-100, pixar: 0-121, slamdunk: 0-119)
"""
# NOTE(review): the per-style image-table links appear to have been lost from
# this markdown (the sections below are empty) — restore them from the
# original demo if available.
ARTICLE = """## Style images

Note that the style images here for Arcane, comic, Pixar, and Slamdunk are the reconstructed ones, not the original ones due to copyright issues.

### Cartoon


### Caricature


### Anime


### Arcane


### Comic


### Pixar


### Slamdunk

"""

# Hugging Face access token; a missing variable raises KeyError at import
# time, which is intentional — the app cannot download its models without it.
TOKEN = os.environ['TOKEN']
# Repository holding the encoder/generator checkpoints and style codes.
MODEL_REPO = 'hysts/DualStyleGAN'
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Return the command-line options controlling the demo.

    Defaults correspond to a local CPU run. `--disable-queue` stores
    ``False`` into ``enable_queue``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--device', type=str, default='cpu')
    ap.add_argument('--theme', type=str)
    ap.add_argument('--live', action='store_true')
    ap.add_argument('--share', action='store_true')
    ap.add_argument('--port', type=int)
    ap.add_argument('--disable-queue',
                    dest='enable_queue',
                    action='store_false')
    ap.add_argument('--allow-flagging', type=str, default='never')
    ap.add_argument('--allow-screenshot', action='store_true')
    return ap.parse_args()
|
|
|
|
def load_encoder(device: torch.device) -> nn.Module:
    """Download the pSp encoder checkpoint and build the model on `device`."""
    ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                                'models/encoder.pt',
                                                use_auth_token=TOKEN)
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # The checkpoint carries its training-time options; point them at this
    # runtime's device and checkpoint location before constructing pSp.
    opts = checkpoint['opts']
    opts.update(device=device.type, checkpoint_path=ckpt_path)
    encoder = pSp(argparse.Namespace(**opts))
    return encoder.to(device).eval()
|
|
|
|
def load_generator(style_type: str, device: torch.device) -> nn.Module:
    """Download and return the DualStyleGAN generator for one style type."""
    generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)
    weights_path = huggingface_hub.hf_hub_download(
        MODEL_REPO, f'models/{style_type}/generator.pt', use_auth_token=TOKEN)
    state = torch.load(weights_path, map_location='cpu')
    # Use the exponential-moving-average weights for inference.
    generator.load_state_dict(state['g_ema'])
    return generator.to(device).eval()
|
|
|
|
def load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
    """Download and load the extrinsic style codes for one style type.

    Cartoon, caricature, and anime ship a refined code file; the other
    styles only provide the raw one.

    Returns:
        A dict mapping style-image file names to their latent-code arrays.
    """
    if style_type in ['cartoon', 'caricature', 'anime']:
        filename = 'refined_exstyle_code.npy'
    else:
        filename = 'exstyle_code.npy'
    # BUG FIX: the download path previously never interpolated `filename`,
    # so the computed name was dead and the request pointed at a
    # non-existent repo file.
    path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                           f'models/{style_type}/{filename}',
                                           use_auth_token=TOKEN)
    # allow_pickle is required: the .npy stores a pickled dict, and .item()
    # unwraps it from the 0-d object array.
    exstyles = np.load(path, allow_pickle=True).item()
    return exstyles
|
|
|
|
def create_transform() -> Callable:
    """Build the encoder input pipeline: 256x256 crop, tensor, scale to [-1, 1]."""
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    return T.Compose([
        T.Resize(256),
        T.CenterCrop(256),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
|
|
|
|
def create_dlib_landmark_model():
    """Download dlib's 68-point face landmark predictor and return it."""
    model_path = huggingface_hub.hf_hub_download(
        'hysts/dlib_face_landmark_model',
        'shape_predictor_68_face_landmarks.dat',
        use_auth_token=TOKEN)
    return dlib.shape_predictor(model_path)
|
|
|
|
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Map a [-1, 1] image tensor to uint8 pixel values in [0, 255]."""
    scaled = (tensor + 1) / 2 * 255
    return scaled.clamp(0, 255).to(torch.uint8)
|
|
|
|
def postprocess(tensor: torch.Tensor) -> PIL.Image.Image:
    """Convert a single CHW [-1, 1] image tensor into a PIL image."""
    pixels = denormalize(tensor).cpu().numpy()
    # CHW -> HWC, the layout PIL expects.
    return PIL.Image.fromarray(pixels.transpose(1, 2, 0))
|
|
|
|
@torch.inference_mode()
def run(
    image,
    style_type: str,
    style_id: float,
    structure_weight: float,
    color_weight: float,
    dlib_landmark_model,
    encoder: nn.Module,
    generator_dict: dict[str, nn.Module],
    exstyle_dict: dict[str, dict[str, np.ndarray]],
    transform: Callable,
    device: torch.device,
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image,
           PIL.Image.Image]:
    """Run style transfer on one uploaded face image.

    Returns (aligned face, reconstruction, color+structure transfer result,
    structure-only transfer result, color-layers-deactivated result).
    """
    generator = generator_dict[style_type]
    exstyles = exstyle_dict[style_type]

    # Gradio's Number input delivers a float; coerce and clamp it to a valid
    # index into this style's code table.
    style_id = int(style_id)
    style_id = min(max(0, style_id), len(exstyles) - 1)

    stylename = list(exstyles.keys())[style_id]

    # `image` is a Gradio file object; detect and align the face via dlib
    # landmarks before encoding.
    image = align_face(filepath=image.name, predictor=dlib_landmark_model)
    input_data = transform(image).unsqueeze(0).to(device)

    # Encode to Z+ latents; img_rec is the encoder's reconstruction preview.
    img_rec, instyle = encoder(input_data,
                               randomize_noise=False,
                               return_latents=True,
                               z_plus_latent=True,
                               return_z_plus_latent=True,
                               resize=False)
    img_rec = torch.clamp(img_rec.detach(), -1, 1)

    # Two copies of the chosen extrinsic style code: row 0 unmodified,
    # row 1 with layers 7-17 replaced by the input's intrinsic style
    # (presumably the color-carrying layers, matching the "structure transfer
    # only" output below — confirm against upstream DualStyleGAN).
    latent = torch.tensor(exstyles[stylename]).repeat(2, 1, 1).to(device)
    latent[1, 7:18] = instyle[0, 7:18]
    # Push the codes through the generator's style MLP, flattening the
    # (copies, layers, dim) batch for the call and restoring it after.
    exstyle = generator.generator.style(
        latent.reshape(latent.shape[0] * latent.shape[1],
                       latent.shape[2])).reshape(latent.shape)

    # Batch of two: full color+structure transfer and structure-only
    # transfer, weighted per layer (7 structure layers, 11 color layers).
    img_gen, _ = generator([instyle.repeat(2, 1, 1)],
                           exstyle,
                           z_plus_latent=True,
                           truncation=0.7,
                           truncation_latent=0,
                           use_res=True,
                           interp_weights=[structure_weight] * 7 +
                           [color_weight] * 11)
    img_gen = torch.clamp(img_gen.detach(), -1, 1)
    # Third variant: color-related interpolation weights zeroed out entirely.
    img_gen2, _ = generator([instyle],
                            exstyle[0:1],
                            z_plus_latent=True,
                            truncation=0.7,
                            truncation_latent=0,
                            use_res=True,
                            interp_weights=[structure_weight] * 7 + [0] * 11)
    img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)

    img_rec = postprocess(img_rec[0])
    img_gen0 = postprocess(img_gen[0])
    img_gen1 = postprocess(img_gen[1])
    img_gen2 = postprocess(img_gen2[0])

    return image, img_rec, img_gen0, img_gen1, img_gen2
|
|
|
|
def main():
    """Load all models, assemble the Gradio interface, and launch the app."""
    gr.close_all()

    args = parse_args()
    device = torch.device(args.device)

    style_types = [
        'cartoon',
        'caricature',
        'anime',
        'arcane',
        'comic',
        'pixar',
        'slamdunk',
    ]
    generator_dict = {
        name: load_generator(name, device)
        for name in style_types
    }
    exstyle_dict = {name: load_exstylecode(name) for name in style_types}

    dlib_landmark_model = create_dlib_landmark_model()
    encoder = load_encoder(device)
    transform = create_transform()

    # Bind the heavy load-once objects so Gradio only supplies UI inputs.
    func = functools.partial(run,
                             dlib_landmark_model=dlib_landmark_model,
                             encoder=encoder,
                             generator_dict=generator_dict,
                             exstyle_dict=exstyle_dict,
                             transform=transform,
                             device=device)
    func = functools.update_wrapper(func, run)

    examples = [[path.as_posix(), 'cartoon', 26, 0.6, 1.0]
                for path in sorted(pathlib.Path('images').glob('*.jpg'))]

    inputs = [
        gr.inputs.Image(type='file', label='Input Image'),
        gr.inputs.Radio(
            style_types,
            type='value',
            default='cartoon',
            label='Style Type',
        ),
        gr.inputs.Number(default=26, label='Style Image Index'),
        gr.inputs.Slider(0, 1, step=0.1, default=0.6,
                         label='Structure Weight'),
        gr.inputs.Slider(0, 1, step=0.1, default=1.0,
                         label='Color Weight'),
    ]
    outputs = [
        gr.outputs.Image(type='pil', label='Aligned Face'),
        gr.outputs.Image(type='pil', label='Reconstructed'),
        gr.outputs.Image(type='pil',
                         label='Result 1 (Color and structure transfer)'),
        gr.outputs.Image(type='pil',
                         label='Result 2 (Structure transfer only)'),
        gr.outputs.Image(
            type='pil',
            label='Result 3 (Color-related layers deactivated)'),
    ]
    gr.Interface(
        func,
        inputs,
        outputs,
        examples=examples,
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
|
|
|
|
# Launch the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|