| |
|
|
| from typing import Dict, Any |
| import torch |
| from PIL import Image |
| import base64 |
| from io import BytesIO |
| import numpy as np |
| from diffusers import AutoencoderKL, DDIMScheduler |
| from einops import repeat |
| from omegaconf import OmegaConf |
| from transformers import CLIPVisionModelWithProjection |
| import cv2 |
| import os |
| import sys |
| import skvideo.io |
| from src.models.pose_guider import PoseGuider |
| from src.models.unet_2d_condition import UNet2DConditionModel |
| from src.models.unet_3d import UNet3DConditionModel |
| from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline |
| from src.utils.util import read_frames, get_fps, save_videos_grid |
|
|
| |
| import gc |
| import subprocess |
|
|
| import requests |
| import tempfile |
|
|
| from rembg import remove |
| import onnxruntime as ort |
| import shutil |
|
|
| import firebase_admin |
| from firebase_admin import credentials, storage, firestore |
| import json |
|
|
# Select the compute device; this model can only run on CUDA hardware.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fail fast at import time rather than during the first inference call.
if device.type != "cuda":
    raise ValueError("The model requires a GPU for inference.")
|
|
class EndpointHandler():
    """Inference endpoint that animates a still reference image with a driving
    pose video (Pose2Video pipeline), swaps the reference face onto the
    animation, uploads the result to Firebase Storage and records the public
    URL in Firestore.

    Requires a CUDA device and the pretrained weights laid out under
    ``pretrained_weights/`` next to this file.
    """

    def __init__(self, path=""):
        """Load configuration, initialize Firebase, and build the pipeline.

        Args:
            path: Unused; kept for compatibility with the hosting framework's
                handler contract.

        Raises:
            FileNotFoundError: if the animation config file is missing.
            ValueError: if the FIREBASE_ACCOUNT_INFO env variable is not set.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, 'configs', 'prompts', 'animation.yaml')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The configuration file was not found at: {config_path}")

        service_account_info = os.getenv("FIREBASE_ACCOUNT_INFO")

        if not service_account_info:
            # BUGFIX: the message previously named FIREBASE_SERVICE_ACCOUNT,
            # which is not the variable actually read above.
            raise ValueError("The FIREBASE_ACCOUNT_INFO environment variable is not set.")

        # BUGFIX: the old code called .replace('/\\n/g', '\n') — a JavaScript
        # regex literal used as a plain Python substring, which never matches.
        # Escaped newlines in the service-account private key were therefore
        # never unescaped and credential parsing would fail.
        service_account_info = service_account_info.replace('\\n', '\n')

        service_account_info_dict = json.loads(service_account_info)

        # Guard against double initialization: initialize_app raises if an app
        # already exists, e.g. when the handler is constructed twice in one
        # process.
        if not firebase_admin._apps:
            cred = credentials.Certificate(service_account_info_dict)
            firebase_admin.initialize_app(cred, {
                'storageBucket': 'quiz-app-edffe.appspot.com'
            })

        self.config = OmegaConf.load(config_path)
        self.weight_dtype = torch.float16
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pipeline = None
        self._initialize_pipeline()

    def _initialize_pipeline(self):
        """Assemble the Pose2VideoPipeline from its pretrained components.

        Loads the VAE, reference/denoising UNets, pose guider, CLIP image
        encoder and DDIM scheduler onto the device in fp16 and stores the
        assembled pipeline on ``self.pipeline``.

        Raises:
            FileNotFoundError: if the sd-vae-ft-mse weight folder is missing.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        vae_path = os.path.join(base_dir, 'pretrained_weights', 'sd-vae-ft-mse')

        if not os.path.exists(vae_path):
            raise FileNotFoundError(f"The sd-vae-ft-mse folder was not found at: {vae_path}")

        vae = AutoencoderKL.from_pretrained(vae_path).to(self.device, dtype=self.weight_dtype)

        pretrained_base_model_path_unet = os.path.join(base_dir, 'pretrained_weights', 'stable-diffusion-v1-5', 'unet')
        print("model path is " + pretrained_base_model_path_unet)
        reference_unet = UNet2DConditionModel.from_pretrained(
            pretrained_base_model_path_unet
        ).to(dtype=self.weight_dtype, device=self.device)

        inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
        motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')
        denoising_unet_path = os.path.join(base_dir, 'pretrained_weights', 'denoising_unet.pth')
        reference_unet_path = os.path.join(base_dir, 'pretrained_weights', 'reference_unet.pth')
        pose_guider_path = os.path.join(base_dir, 'pretrained_weights', 'pose_guider.pth')
        image_encoder_path = os.path.join(base_dir, 'pretrained_weights', 'image_encoder')

        infer_config = OmegaConf.load(inference_config_path)
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_base_model_path_unet,
            motion_module_path,
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(self.device, dtype=self.weight_dtype)

        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(self.device, dtype=self.weight_dtype)
        image_enc = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(self.device, dtype=self.weight_dtype)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # strict=False: the 3D UNet checkpoint intentionally lacks the motion
        # module weights, which were already injected by from_pretrained_2d.
        denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(pose_guider_path, map_location="cpu"))

        self.pipeline = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler
        ).to(self.device, dtype=self.weight_dtype)

    def _crop_face(self, image, save_path="cropped_face.jpg", margin=0.5):
        """Detect the first face in *image* and save/return a margin-padded crop.

        Args:
            image: PIL image (any mode; converted to RGB internally).
            save_path: where to write the cropped face as JPEG.
            margin: fraction of the detected box added as padding on each side.

        Returns:
            The cropped face as an RGB PIL image.

        Raises:
            ValueError: if no face is detected.
        """
        # BUGFIX: the rembg output passed in here is RGBA; force 3-channel RGB
        # so the BGR channel flip, grayscale conversion and JPEG save below all
        # operate on the layout they expect (4-channel input would previously
        # break cvtColor and the JPEG encoder).
        image = image.convert("RGB")
        cv_image = np.array(image)
        cv_image = cv_image[:, :, ::-1].copy()  # RGB -> BGR for OpenCV

        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

        gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)

        if len(faces) == 0:
            raise ValueError("No faces detected in the reference image.")

        # NOTE(review): this takes the first detection, not the largest face —
        # confirm that is acceptable for multi-face reference images.
        x, y, w, h = faces[0]
        x_margin = int(margin * w)
        y_margin = int(margin * h)

        # Clamp the padded box to the image bounds; less padding above the
        # face (y_margin // 2) than below.
        x1 = max(0, x - x_margin)
        y1 = max(0, y - y_margin // 2)
        x2 = min(cv_image.shape[1], x + w + x_margin)
        y2 = min(cv_image.shape[0], y + h + y_margin)

        cropped_face = cv_image[y1:y2, x1:x2]

        cropped_face = Image.fromarray(cropped_face[:, :, ::-1]).convert("RGB")  # BGR -> RGB

        cropped_face.save(save_path, format="JPEG", quality=95)

        return cropped_face

    def _swap_face(self, source_path, target_video_path, output_path):
        """Swap the face from *source_path* onto *target_video_path* via roop.

        WARNING: dead code. ``roop`` and its helpers (``suggest_max_memory``,
        ``decode_execution_providers``, ``suggest_execution_threads``,
        ``get_frame_processors_modules``, ``start``) are never imported in this
        module, so calling this raises NameError. ``__call__`` uses the
        facefusion subprocess instead. Kept only for interface compatibility.
        """
        roop.globals.source_path = source_path
        roop.globals.target_path = target_video_path
        roop.globals.output_path = output_path
        roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
        roop.globals.headless = True
        roop.globals.keep_fps = True
        roop.globals.keep_audio = True
        roop.globals.keep_frames = False
        roop.globals.many_faces = False
        roop.globals.video_quality = 50
        roop.globals.max_memory = suggest_max_memory()

        roop.globals.execution_providers = decode_execution_providers(["CUDAExecutionProvider"])
        roop.globals.execution_threads = suggest_execution_threads()

        ort.set_default_logger_severity(3)
        providers = ['CUDAExecutionProvider']
        options = ort.SessionOptions()
        options.intra_op_num_threads = 1

        # NOTE(review): set_providers expects a list of provider-option dicts
        # as its second argument, not a SessionOptions object — verify against
        # the onnxruntime API before resurrecting this path.
        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            if hasattr(frame_processor, 'onnx_session'):
                frame_processor.onnx_session.set_providers(providers, options)

        torch.cuda.empty_cache()

        start()

        for frame_processor in roop.globals.frame_processors:
            del frame_processor
        torch.cuda.empty_cache()

        return os.path.join(os.getcwd(), output_path)

    def print_memory_stat_for_stuff(self, phase, log_file="memory_stats.log"):
        """Append current/peak CUDA memory stats to *log_file*, tagged *phase*."""
        with open(log_file, "a") as f:
            f.write(f"Memory Stats - {phase}:\n")
            f.write(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB\n")
            f.write(f"Reserved memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB\n")
            f.write(f"Max allocated memory: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB\n")
            f.write(f"Max reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB\n")
            f.write("="*30 + "\n")

    def convert_to_playable_format(self, input_path, output_path):
        """Re-encode *input_path* to browser-playable H.264 at *output_path*.

        Encodes into a temp file first so input and output may be the same
        path.

        Raises:
            RuntimeError: if ffmpeg exits non-zero.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
            temp_output_path = tmp_file.name

        # SECURITY/BUGFIX: pass an argument list with shell=False so paths
        # containing spaces or shell metacharacters cannot break (or inject
        # into) the command line.
        command = [
            "ffmpeg", "-i", input_path,
            "-c:v", "libx264", "-preset", "fast", "-crf", "18",
            "-y", temp_output_path,
        ]
        result = subprocess.run(command, capture_output=True, text=True)
        print("Conversion STDOUT:", result.stdout)
        print("Conversion STDERR:", result.stderr)

        if result.returncode != 0:
            os.remove(temp_output_path)  # BUGFIX: don't leak the temp file on failure
            raise RuntimeError(f"FFmpeg conversion failed with exit code {result.returncode}")

        shutil.move(temp_output_path, output_path)

    def run_rife_interpolation(self, video_path, output_path, multi=2, scale=1.0):
        """Frame-interpolate *video_path* with Practical-RIFE, then re-encode.

        Args:
            multi: frame-rate multiplication factor.
            scale: RIFE processing scale.

        Raises:
            RuntimeError: if the RIFE subprocess exits non-zero.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        script_path = os.path.join(base_dir, "Practical-RIFE", "inference_video.py")
        model_directory = os.path.join(base_dir, "Practical-RIFE", "train_log")
        # Argument list + shell=False (see convert_to_playable_format).
        command = [
            "python3", script_path,
            f"--video={video_path}", f"--output={output_path}",
            f"--multi={multi}", f"--scale={scale}",
            f"--model={model_directory}",
        ]

        result = subprocess.run(command, capture_output=True, text=True)
        print(result)
        print(result.stdout)
        print(result.stderr)

        if result.returncode != 0:
            raise RuntimeError(f"RIFE interpolation failed with exit code {result.returncode}")

        # RIFE output is not always browser-playable; re-encode in place.
        self.convert_to_playable_format(output_path, output_path)

    def speed_up_video(self, input_path, output_path, factor=4):
        """Write *input_path* sped up *factor*x (audio dropped) to *output_path*.

        Raises:
            RuntimeError: if ffmpeg exits non-zero.
        """
        # Argument list + shell=False (see convert_to_playable_format).
        command = ["ffmpeg", "-i", input_path, "-filter:v", f"setpts=PTS/{factor}", "-an", output_path]

        result = subprocess.run(command, capture_output=True, text=True)
        print("Speed Up Video STDOUT:", result.stdout)
        print("Speed Up Video STDERR:", result.stderr)

        if result.returncode != 0:
            raise RuntimeError(f"FFmpeg speed up failed with exit code {result.returncode}")

    def slow_down_video(self, input_path, output_path, factor=4):
        """Write *input_path* slowed down *factor*x (audio dropped) to *output_path*.

        Raises:
            RuntimeError: if ffmpeg exits non-zero.
        """
        # Argument list + shell=False (see convert_to_playable_format).
        command = ["ffmpeg", "-i", input_path, "-filter:v", f"setpts={factor}*PTS", "-an", output_path]

        result = subprocess.run(command, capture_output=True, text=True)
        print("Slow Down Video STDOUT:", result.stdout)
        print("Slow Down Video STDERR:", result.stderr)

        if result.returncode != 0:
            raise RuntimeError(f"FFmpeg slow down failed with exit code {result.returncode}")

    def download_file(self, url: str, save_path: str):
        """Stream *url* to *save_path* in 8 KiB chunks.

        Raises:
            ValueError: on any non-200 response.
        """
        # BUGFIX: timeout added so a stalled connection cannot hang the
        # endpoint indefinitely.
        response = requests.get(url, stream=True, timeout=60)
        if response.status_code == 200:
            with open(save_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        else:
            raise ValueError(f"Failed to download file from {url}")

    def print_directory_contents(self, path='.'):
        """Print an indented tree of *path* (debug helper).

        BUGFIX: this method was previously defined twice with identical
        behavior; the second definition silently shadowed the first. Only the
        live (second) definition is kept.
        """
        for root, dirs, files in os.walk(path):
            level = root.replace(path, '').count(os.sep)
            indent = ' ' * 4 * level
            print(f'{indent}{os.path.basename(root)}/')
            sub_indent = ' ' * 4 * (level + 1)
            for f in files:
                print(f'{sub_indent}{f}')

    def __call__(self, data: Any) -> Dict[str, str]:
        """Run one animation job end to end.

        Expects ``data["inputs"]`` with: ``ref_image_url``, ``video_url``
        (a pre-rendered pose video), ``firebase_doc_id``, and optional
        ``length``, ``num_inference_steps``, ``cfg``, ``seed`` (``width`` /
        ``height`` are accepted but overridden by the reference image size,
        capped at 600 px on the long side).

        Returns:
            ``{"video": <base64-encoded mp4>}``; the mp4 is also uploaded to
            Firebase Storage and its public URL written to Firestore.

        Raises:
            ValueError: if firebase_doc_id is missing or a download fails.
            RuntimeError: if the face-swap subprocess fails.
        """
        inputs = data.get("inputs", {})
        ref_image_url = inputs.get("ref_image_url", "")
        video_url = inputs.get("video_url", "")
        width = inputs.get("width", 512)
        height = inputs.get("height", 768)
        length = inputs.get("length", 96)
        num_inference_steps = inputs.get("num_inference_steps", 15)
        cfg = inputs.get("cfg", 3.5)
        seed = inputs.get("seed", -1)
        firebase_doc_id = inputs.get("firebase_doc_id", "")

        # Fail early with a clear message instead of crashing later on an
        # empty Storage path / Firestore document id.
        if not firebase_doc_id:
            raise ValueError("firebase_doc_id is required in the request inputs.")

        base_dir = os.path.dirname(os.path.abspath(__file__))

        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"Temporary directory created at {temp_dir}")
            video_root = os.path.join(temp_dir, "dw_poses_videos")
            os.makedirs(video_root, exist_ok=True)
            downloaded_video_path = os.path.join(video_root, "downloaded_video.mp4")
            downloaded_image_path = os.path.join(video_root, "downloaded_image.jpg")

            self.download_file(video_url, downloaded_video_path)
            self.download_file(ref_image_url, downloaded_image_path)
            ref_image = Image.open(downloaded_image_path)

            # The requested width/height are deliberately replaced with the
            # reference image's own aspect ratio, capped at 600 px.
            # NOTE(review): diffusion UNets typically require dimensions
            # divisible by 8 — confirm the pipeline tolerates arbitrary sizes.
            original_width, original_height = ref_image.size
            max_dimension = max(original_width, original_height)
            if max_dimension > 600:
                ratio = max_dimension / 600
                width = int(original_width / ratio)
                height = int(original_height / ratio)
            else:
                width = original_width
                height = original_height

            # Strip the background so the animation is generated on a clean
            # subject (rembg returns an RGBA image).
            ref_image_no_bg = remove(ref_image)
            ref_image_no_bg_path = os.path.join(video_root, "ref_image_no_bg.png")
            ref_image_no_bg.save(ref_image_no_bg_path)

            torch.manual_seed(seed)

            pose_images = read_frames(downloaded_video_path)
            src_fps = get_fps(downloaded_video_path)

            # Cap the clip at `length` frames.
            total_length = min(length, len(pose_images))
            pose_list = list(pose_images[:total_length])

            video = self.pipeline(
                ref_image_no_bg,
                pose_list,
                width=width,
                height=height,
                video_length=total_length,
                num_inference_steps=num_inference_steps,
                guidance_scale=cfg
            ).videos

            save_dir = os.path.join(temp_dir, "output")
            os.makedirs(save_dir, exist_ok=True)
            animation_path = os.path.join(save_dir, "animation_output.mp4")
            save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)

            cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
            cropped_face = self._crop_face(ref_image_no_bg, save_path=cropped_face_path)

            torch.cuda.empty_cache()

            # Face swap via the facefusion CLI. Argument list + shell=False
            # so URLs/paths cannot inject into a shell command line.
            swapped_face_video_path = os.path.join(save_dir, "swapped_face_output.mp4")
            facefusion_script_path = os.path.join(base_dir, 'facefusion', 'core.py')
            swap_command = [
                "python3", facefusion_script_path,
                "--source", cropped_face_path,
                "--target", animation_path,
                "--output", swapped_face_video_path,
            ]
            swap_result = subprocess.run(swap_command, capture_output=True, text=True)
            if swap_result.returncode != 0:
                raise RuntimeError(f"Error running face swap: {swap_result.stderr}")

            torch.cuda.empty_cache()

            with open(swapped_face_video_path, "rb") as video_file:
                video_base64 = base64.b64encode(video_file.read()).decode("utf-8")

            # Upload the result and expose it via a public URL.
            bucket = storage.bucket()
            blob = bucket.blob(f"videos/{firebase_doc_id}/swapped_face_output.mp4")
            blob.upload_from_filename(swapped_face_video_path)
            blob.make_public()

            video_url = blob.public_url

            # Record the public URL on the caller's Firestore document.
            db = firestore.client()
            doc_ref = db.collection('danceResults').document(firebase_doc_id)
            doc_ref.update({"videoResultUrl": video_url})

            return {"video": video_base64}
|
|