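"""
Inference script for the AvatarArtist demo Space.

Pipeline, as implemented below: encode a source portrait with CLIP and DINOv2,
sample a triplane latent with the DiT diffusion model (IDDPM or DPM-Solver),
decode it with the triplane VAE, then render every frame of the driving video
(FaceVerse expression coefficients, motion features, and camera pose) with the
motion-aware render model and encode the frames into MP4 videos.
"""
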
import os
import sys
import warnings
import logging
import argparse
import json
import random
import subprocess  # used by images_to_video to invoke FFmpeg
from datetime import datetime

import torch
import numpy as np
import cv2
import imageio_ffmpeg as ffmpeg  # assumed source of ffmpeg.get_ffmpeg_exe() used in images_to_video
from PIL import Image
from tqdm import tqdm
from natsort import natsorted, ns
from einops import rearrange
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
from transformers import (
    Dinov2Model, CLIPImageProcessor, CLIPVisionModelWithProjection, AutoImageProcessor
)

# Get the directory of the current script and add the necessary paths dynamically
# before importing the project modules, so that their internal imports
# (e.g. dnnlib, torch_utils) resolve correctly.
father_path = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([
    os.path.join(father_path, 'recon'),
    os.path.join(father_path, 'Next3d')
])

from Next3d.training_avatar_texture.camera_utils import LookAtPoseSampler, FOV_to_intrinsics
from data_process.lib.FaceVerse.renderer import Faceverse_manager
import recon.dnnlib as dnnlib
import recon.legacy as legacy
from DiT_VAE.diffusion.utils.misc import read_config
from DiT_VAE.vae.triplane_vae import AutoencoderKL as AutoencoderKLTriplane
from DiT_VAE.diffusion import IDDPM, DPMS
from DiT_VAE.diffusion.model.nets import TriDitCLIPDINO_XL_2
from DiT_VAE.diffusion.data.datasets import get_chunks
# Suppress warnings (especially for PyTorch)
warnings.filterwarnings("ignore")

# Configure logging settings
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

def get_args():
    """Parse and return command-line arguments."""
    parser = argparse.ArgumentParser(description="4D Triplane Generation Arguments")

    # Configuration and model checkpoints
    parser.add_argument("--config", type=str, default="./configs/infer_config.py",
                        help="Path to the configuration file.")

    # Input data paths
    parser.add_argument("--target_path", type=str, required=True,
                        default='./demo_data/target_video/data_obama',
                        help="Base path of the target (driving) video dataset.")
    parser.add_argument("--img_file", type=str, required=True,
                        default='./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs',
                        help="Directory containing input source images.")
    parser.add_argument("--input_img_motion", type=str,
                        default="./demo_data/source_img/img_generate_different_domain/motions/demo_imgs",
                        help="Directory containing motion features.")
    parser.add_argument("--video_name", type=str, required=True, default='Obama',
                        help="Name of the driving video.")
    parser.add_argument("--input_img_fvid", type=str,
                        default="./demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs",
                        help="Path to input image coefficients.")

    # Output settings
    parser.add_argument("--output_basedir", type=str, default="./output",
                        help="Base directory for saving output results.")

    # Generation parameters
    parser.add_argument("--bs", type=int, default=1,
                        help="Batch size for processing.")
    parser.add_argument("--cfg_scale", type=float, default=4.5,
                        help="Classifier-free guidance (CFG) scale.")
    parser.add_argument("--sampling_algo", type=str, default="dpm-solver",
                        choices=["iddpm", "dpm-solver"],
                        help="Sampling algorithm to be used.")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed for reproducibility.")
    parser.add_argument("--select_img", type=str, default=None,
                        help="Optional: process only this image from --img_file.")
    parser.add_argument('--step', default=-1, type=int,
                        help="Number of sampling steps (-1 uses the per-sampler default).")
    parser.add_argument('--use_demo_cam', action='store_true',
                        help="Enable predefined camera parameters.")
    return parser.parse_args()

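# Example invocation (a minimal sketch: the script filename is hypothetical and the
# paths simply mirror the argparse defaults above; adjust both to your checkout):
#   python inference.py \
#       --target_path ./demo_data/target_video/data_obama \
#       --img_file ./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs \
#       --video_name Obama \
#       --use_demo_cam
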
def set_env(seed=0):
    """Set random seed for reproducibility across multiple frameworks."""
    torch.manual_seed(seed)            # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)   # If using multi-GPU
    np.random.seed(seed)               # Set NumPy seed
    random.seed(seed)                  # Set Python built-in random module seed
    torch.set_grad_enabled(False)      # Disable gradients for inference

def to_rgb_image(image: Image.Image):
    """Convert an image to RGB format if necessary."""
    if image.mode == 'RGB':
        return image
    elif image.mode == 'RGBA':
        img = Image.new("RGB", image.size, (127, 127, 127))
        img.paste(image, mask=image.getchannel('A'))
        return img
    else:
        raise ValueError(f"Unsupported image type: {image.mode}")

def image_process(image_path):
    """Preprocess an image for the CLIP and DINO encoders.

    Relies on the module-level ``clip_image_processor``, ``dino_img_processor``
    and ``device`` set up in the ``__main__`` block.
    """
    image = to_rgb_image(Image.open(image_path))
    clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
    dino_image = dino_img_processor(images=image, return_tensors="pt").pixel_values.to(device)
    return dino_image, clip_image

def video_gen(frames_dir, output_path, fps=30):
    """Generate a video from image frames with OpenCV.

    Note: not called in this script; final videos are written by
    ``images_to_video`` (FFmpeg) instead.
    """
    frame_files = natsorted(os.listdir(frames_dir), alg=ns.PATH)
    frames = [cv2.imread(os.path.join(frames_dir, f)) for f in frame_files]
    H, W = frames[0].shape[:2]
    video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (W, H))
    for frame in frames:
        video_writer.write(frame)
    video_writer.release()

def trans(tensor_img):
    """Convert a (B, C, H, W) tensor in [-1, 1] to a uint8 (H, W, C) numpy image."""
    img = (tensor_img.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1) * 255.
    img = img.to(torch.uint8)
    img = img[0].detach().cpu().numpy()
    return img

def get_vert(vert_dir):
    """Load a UV-coordinate image and binarize its last channel into a 0/1 mask."""
    uvcoords_image = np.load(vert_dir)[..., :3]
    uvcoords_image[..., -1][uvcoords_image[..., -1] < 0.5] = 0
    uvcoords_image[..., -1][uvcoords_image[..., -1] >= 0.5] = 1
    return torch.tensor(uvcoords_image.copy()).float().unsqueeze(0)

def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature, uncond_clip_feature,
                     uncond_dino_feature, device, latent_size, sampling_algo):
    """
    Generate latent samples using the specified diffusion model.

    Args:
        DiT_model (torch.nn.Module): The diffusion model.
        cfg_scale (float): The classifier-free guidance scale.
        sample_steps (int): Number of sampling steps.
        clip_feature (torch.Tensor): CLIP feature tensor.
        dino_feature (torch.Tensor): DINO feature tensor.
        uncond_clip_feature (torch.Tensor): Unconditional CLIP feature tensor.
        uncond_dino_feature (torch.Tensor): Unconditional DINO feature tensor.
        device (str): Device for computation.
        latent_size (tuple): The latent space size.
        sampling_algo (str): The sampling algorithm ('iddpm' or 'dpm-solver').

    Returns:
        torch.Tensor: The generated samples.
    """
    n = 1  # Batch size
    z = torch.randn(n, 8, latent_size[0], latent_size[1], device=device)

    if sampling_algo == 'iddpm':
        z = z.repeat(2, 1, 1, 1)  # Duplicate for classifier-free guidance
        model_kwargs = dict(y=torch.cat([clip_feature, uncond_clip_feature]),
                            img_feature=torch.cat([dino_feature, dino_feature]),
                            cfg_scale=cfg_scale)
        diffusion = IDDPM(str(sample_steps))
        samples = diffusion.p_sample_loop(DiT_model.forward_with_cfg, z.shape, z, clip_denoised=False,
                                          model_kwargs=model_kwargs, progress=True, device=device)
        samples, _ = samples.chunk(2, dim=0)  # Remove unconditional samples
    elif sampling_algo == 'dpm-solver':
        dpm_solver = DPMS(DiT_model.forward_with_dpmsolver,
                          condition=[clip_feature, dino_feature],
                          uncondition=[uncond_clip_feature, dino_feature],
                          cfg_scale=cfg_scale)
        samples = dpm_solver.sample(z, steps=sample_steps, order=2, skip_type="time_uniform", method="multistep")
    else:
        raise ValueError(f"Invalid sampling_algo '{sampling_algo}'. Choose either 'iddpm' or 'dpm-solver'.")
    return samples

def images_to_video(image_folder, output_video, fps=30):
    """Encode ordered PNG frames into an H.264 MP4 using the bundled FFmpeg binary."""
    # Get all image files and ensure correct order
    images = [img for img in os.listdir(image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]
    images = natsorted(images)  # Sort filenames naturally to preserve frame order
    if not images:
        print("No images found in the directory!")
        return

    # Get the path to the FFmpeg executable
    ffmpeg_exe = ffmpeg.get_ffmpeg_exe()
    print(f"Using FFmpeg from: {ffmpeg_exe}")

    # Define input image pattern (expects images named like "%04d.png")
    image_pattern = os.path.join(image_folder, "%04d.png")

    # FFmpeg command to encode video
    command = [
        ffmpeg_exe, '-framerate', str(fps), '-i', image_pattern,
        '-c:v', 'libx264', '-preset', 'slow', '-crf', '18',  # High-quality H.264 encoding
        '-pix_fmt', 'yuv420p', '-b:v', '5000k',  # Ensure compatibility & increase bitrate
        output_video
    ]

    # Run FFmpeg command
    subprocess.run(command, check=True)
    print(f"High-quality MP4 video has been generated: {output_video}")

def avatar_generation(items, bs, sample_steps, cfg_scale, save_path_base, DiT_model, render_model, std, mean, ws_avg,
                      Faceverse, pitch_range=0.25, yaw_range=0.35, demo_cam=False):
    """
    Generate avatars from input images.

    Args:
        items (list): List of image paths.
        bs (int): Batch size.
        sample_steps (int): Number of sampling steps.
        cfg_scale (float): Classifier-free guidance scale.
        save_path_base (str): Base directory for saving results.
        DiT_model (torch.nn.Module): The diffusion model.
        render_model (torch.nn.Module): The rendering model.
        std (torch.Tensor): Standard deviation normalization tensor.
        mean (torch.Tensor): Mean normalization tensor.
        ws_avg (torch.Tensor): Latent average tensor.
        Faceverse (Faceverse_manager): FaceVerse manager used to render driven expressions.
        pitch_range (float): Pitch range of the demo camera trajectory.
        yaw_range (float): Yaw range of the demo camera trajectory.
        demo_cam (bool): Use the predefined orbiting camera instead of the dataset poses.
    """
    for chunk in tqdm(list(get_chunks(items, bs)), unit='batch'):
        if bs != 1:
            raise ValueError("Batch size > 1 not implemented")
        image_dir = chunk[0]
        image_name = os.path.splitext(os.path.basename(image_dir))[0]

        # Encode the source image with CLIP and DINO (plus unconditional features for CFG)
        dino_img, clip_image = image_process(image_dir)
        clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        uncond_clip_feature = image_encoder(torch.zeros_like(clip_image),
                                            output_hidden_states=True).hidden_states[-2]
        dino_feature = dinov2(dino_img).last_hidden_state
        uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state

        # Sample a triplane latent and decode it with the VAE
        samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
                                   uncond_clip_feature, uncond_dino_feature, device, latent_size,
                                   args.sampling_algo)
        samples = (samples / default_config.scale_factor)
        samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
        samples = vae_triplane.decode(samples)
        samples = rearrange(samples, "b c f h w -> b f c h w")
        samples = samples * std + mean
        torch.cuda.empty_cache()

        save_frames_path_combine = os.path.join(save_path_base, image_name, 'combine')
        save_frames_path_out = os.path.join(save_path_base, image_name, 'out')
        os.makedirs(save_frames_path_combine, exist_ok=True)
        os.makedirs(save_frames_path_out, exist_ok=True)

        img_ref = np.array(Image.open(image_dir))
        img_ref_out = img_ref.copy()
        img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)

        motion_app_dir = os.path.join(args.input_img_motion, image_name + '.npy')
        motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)

        id_motions = os.path.join(args.input_img_fvid, image_name + '.npy')
        with open(label_file_test) as f:
            all_pose = dict(json.load(f)['labels'])
        if os.path.exists(id_motions):
            coeff = np.load(id_motions).astype(np.float32)
            coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
            Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]

        motion_dir = os.path.join(motion_base_dir, args.video_name)
        exp_dir = os.path.join(exp_base_dir, args.video_name)
        for frame_index, motion_name in enumerate(
                tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
            exp_each_dir_img = os.path.join(exp_img_base_dir, args.video_name, motion_name.replace('.npy', '.png'))
            exp_each_dir = os.path.join(exp_dir, motion_name)
            motion_each_dir = os.path.join(motion_dir, motion_name)

            # Load pose data (predefined orbiting camera or per-frame dataset label)
            pose_key = os.path.join(args.video_name, motion_name.replace('.npy', '.png'))
            if demo_cam:
                cam2world_pose = LookAtPoseSampler.sample(
                    3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                    3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                    torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
                pose = torch.cat([cam2world_pose.reshape(-1, 16),
                                  FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)
            else:
                pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)

            # Load and resize expression image
            exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))

            # Load expression coefficients and render the driven FaceVerse mesh
            exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
            exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)

            # Load motion data
            motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)

            # Render the driven frame from the reconstructed triplane
            final_out = render_model(
                img_ref, None, motion_app, motion, c=pose, mesh=exp_target, triplane_recon=samples,
                ws_avg=ws_avg, motion_scale=1.
            )

            # Process output image
            final_out = trans(final_out['image_sr'])
            output_img_combine = np.hstack((img_ref_out, exp_img, final_out))

            # Save output images
            frame_name = f'{str(frame_index).zfill(4)}.png'
            Image.fromarray(output_img_combine, 'RGB').save(os.path.join(save_frames_path_combine, frame_name))
            Image.fromarray(final_out, 'RGB').save(os.path.join(save_frames_path_out, frame_name))

        # Generate videos
        images_to_video(save_frames_path_combine, os.path.join(save_path_base, image_name + '_combine.mp4'))
        images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + '_out.mp4'))
        logging.info("Video generation completed successfully!")
        logging.info(f"Combined video saved at: {os.path.join(save_path_base, image_name + '_combine.mp4')}")
        logging.info(f"Output video saved at: {os.path.join(save_path_base, image_name + '_out.mp4')}")

def load_motion_aware_render_model(ckpt_path):
    """Load the motion-aware render model from a checkpoint."""
    logging.info("Loading motion-aware render model...")
    with dnnlib.util.open_url(ckpt_path, 'rb') as f:
        network = legacy.load_network_pkl(f)  # type: ignore
    logging.info("Motion-aware render model loaded.")
    return network['G_ema'].to(device)

def load_diffusion_model(ckpt_path, latent_size):
    """Load the diffusion model (DiT)."""
    logging.info("Loading diffusion model (DiT)...")
    DiT_model = TriDitCLIPDINO_XL_2(input_size=latent_size).to(device)
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Remove keys that can cause mismatches
    for key in ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']:
        ckpt['state_dict'].pop(key, None)
        ckpt.get('state_dict_ema', {}).pop(key, None)

    # Prefer the EMA weights when present, otherwise fall back to the regular state dict
    state_dict = ckpt.get('state_dict_ema', ckpt['state_dict'])
    DiT_model.load_state_dict(state_dict, strict=False)
    DiT_model.eval()
    logging.info("Diffusion model (DiT) loaded.")
    return DiT_model

def load_vae_clip_dino(config, device):
    """Load VAE, CLIP, and DINO models."""
    logging.info("Loading VAE, CLIP, and DINO models...")

    # Load CLIP image encoder
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path)
    image_encoder.requires_grad_(False)
    image_encoder.to(device)

    # Load VAE
    config_vae = OmegaConf.load(config.vae_triplane_config_path)
    vae_triplane = AutoencoderKLTriplane(ddconfig=config_vae['ddconfig'], lossconfig=None, embed_dim=8)
    vae_triplane.to(device)
    vae_ckpt_path = os.path.join(config.vae_pretrained, 'pytorch_model.bin')
    if not os.path.isfile(vae_ckpt_path):
        raise RuntimeError(f"VAE checkpoint not found at {vae_ckpt_path}")
    vae_triplane.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"))
    vae_triplane.requires_grad_(False)

    # Load DINO model
    dinov2 = Dinov2Model.from_pretrained(config.dino_pretrained)
    dinov2.requires_grad_(False)
    dinov2.to(device)

    # Load image processors
    dino_img_processor = AutoImageProcessor.from_pretrained(config.dino_pretrained)
    clip_image_processor = CLIPImageProcessor()

    logging.info("VAE, CLIP, and DINO models loaded.")
    return vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor

def prepare_image_list(img_dir, selected_img):
    """Prepare the list of image paths for processing."""
    if selected_img and selected_img in os.listdir(img_dir):
        return [os.path.join(img_dir, selected_img)]
    return sorted([os.path.join(img_dir, img) for img in os.listdir(img_dir)])

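# Expected --target_path layout (derived from the path construction in the
# __main__ block and avatar_generation below):
#   <target_path>/coeffs/<video_name>/*.npy            FaceVerse expression coefficients
#   <target_path>/images512x512/<video_name>/*.png     driving frames (512x512)
#   <target_path>/images512x512/dataset_realcam.json   per-frame camera labels
#   <target_path>/motions/<video_name>/*.npy           motion features
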
if __name__ == '__main__':
    model_path = "./pretrained_model"
    if not os.path.exists(model_path):
        logging.info("Model not found. Downloading from Hugging Face...")
        snapshot_download(repo_id="KumaPower/AvatarArtist", local_dir=model_path)
        logging.info("Model downloaded successfully!")
    else:
        logging.info("Pretrained model already exists. Skipping download.")

    args = get_args()
    exp_base_dir = os.path.join(args.target_path, 'coeffs')
    exp_img_base_dir = os.path.join(args.target_path, 'images512x512')
    motion_base_dir = os.path.join(args.target_path, 'motions')
    label_file_test = os.path.join(args.target_path, 'images512x512/dataset_realcam.json')

    set_env(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.float32
    logging.info(f"Running inference with {weight_dtype}")

    # Load configuration
    default_config = read_config(args.config)

    # Ensure valid sampling algorithm
    assert args.sampling_algo in ['iddpm', 'dpm-solver', 'sa-solver']

    # Prepare image list
    items = prepare_image_list(args.img_file, args.select_img)
    logging.info(f"Input images: {items}")

    # Load motion-aware render model
    motion_aware_render_model = load_motion_aware_render_model(default_config.motion_aware_render_model_ckpt)

    # Load diffusion model (DiT)
    triplane_size = (256 * 4, 256)
    latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
    sample_steps = args.step if args.step != -1 else {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}[
        args.sampling_algo]
    DiT_model = load_diffusion_model(default_config.DiT_model_ckpt, latent_size)

    # Load VAE, CLIP, and DINO
    vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor = load_vae_clip_dino(default_config,
                                                                                                       device)

    # Load normalization parameters
    triplane_std = torch.load(default_config.std_dir).to(device).reshape(1, -1, 1, 1, 1)
    triplane_mean = torch.load(default_config.mean_dir).to(device).reshape(1, -1, 1, 1, 1)

    # Load average latent vector
    ws_avg = torch.load(default_config.ws_avg_pkl).to(device)[0]

    # Set up save directory
    save_root = os.path.join(args.output_basedir, f'{datetime.now().date()}', args.video_name)
    os.makedirs(save_root, exist_ok=True)

    # Set up FaceVerse for animation
    base_coff = np.load('pretrained_model/temp.npy').astype(np.float32)
    base_coff = torch.from_numpy(base_coff).float()
    Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)

    # Run avatar generation
    avatar_generation(items, args.bs, sample_steps, args.cfg_scale, save_root, DiT_model, motion_aware_render_model,
                      triplane_std, triplane_mean, ws_avg, Faceverse, demo_cam=args.use_demo_cam)