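"""Domain image generation with Stable Diffusion + ControlNet (MediaPipe face) img2img.

Reads face images from --src_img_path, builds a facial-landmark control image for each
face, and runs ControlNet-guided img2img with the text prompt from --prompt, saving
results under <save_base>/<sd_model_id>/<prompt_with_underscores>/.

Local checkpoints are expected under --model_base_path; judging from _get_model_path and
_load_controlnet below, the layout is roughly:
    stable-diffusion-2-1-base/ (or stable-diffusion-v1-5/, or a custom model folder)
    ControlNetMediaPipeFace/   (ControlNetMediaPipeFaceold/ is paired with SD 2.1)
"""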

import sys
import os
import json
import zipfile
import argparse

sys.path.append(os.getcwd())

from os.path import join as opj

import numpy as np
import torch
from PIL import Image, ImageFile
from torchvision.transforms import ToPILImage

ImageFile.LOAD_TRUNCATED_IMAGES = True

import transformers
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
)
from landmark_generation import generate_annotation
from natsort import ns, natsorted


class DomainImageGeneration:
    def __init__(
        self, src_path, args, model_name,
        device="cuda:0", clip_skip=2,
        use_anime_vae=False, save_path='',
        only_lora=False, model_base_path='./diffusion_model'
    ):
        """Initialize DomainImageGeneration: resolve checkpoint paths and load the pipeline."""
        self.model_base_path = model_base_path
        self.device = device
        self.args = args
        self.src_path = src_path
        self.use_model = model_name
        self.only_lora = only_lora
        self.out_path_base = os.path.join(save_path, model_name)
        os.makedirs(self.out_path_base, exist_ok=True)
        self.diffusion_checkpoint_path = self._get_model_path(model_name)
        self.pipe = self._load_pipeline(model_name, use_anime_vae, clip_skip)
        print("All models loaded successfully")

    def _get_model_path(self, model_name):
        """Retrieve the model checkpoint path based on the model name."""
        base_path = self.model_base_path
        if model_name == "stable-diffusion-2-1-base":
            return os.path.join(base_path, "stable-diffusion-2-1-base")
        elif self.only_lora:
            # LoRA-only checkpoints are applied on top of the SD 1.5 base model
            return os.path.join(base_path, "stable-diffusion-v1-5")
        else:
            return os.path.join(base_path, model_name)

    def _load_controlnet(self, model_name):
        """Load the MediaPipe-face ControlNet model."""
        controlnet_path = os.path.join(self.model_base_path, 'ControlNetMediaPipeFace')
        if model_name == "stable-diffusion-2-1-base":
            # The SD 2.1 base model pairs with the "ControlNetMediaPipeFaceold" weights
            controlnet_path += "old"
        return ControlNetModel.from_pretrained(
            controlnet_path, torch_dtype=torch.float16
        )

    def _load_pipeline(self, model_name, use_anime_vae, clip_skip):
        """Load the Stable Diffusion ControlNet img2img pipeline."""
        controlnet = self._load_controlnet(model_name)
        if use_anime_vae:
            print("Using Anime VAE")
            anime_vae = AutoencoderKL.from_pretrained(
                "/nas8/liuhongyu/model/kl-f8-anime2", torch_dtype=torch.float16
            )
            pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                self.diffusion_checkpoint_path, torch_dtype=torch.float16, safety_checker=None,
                vae=anime_vae, controlnet=controlnet
            ).to(self.device)
            self._load_lora(pipeline, "detail-tweaker-lora/add_detail.safetensors")
        elif model_name == "stable-diffusion-2-1-base":
            pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                self.diffusion_checkpoint_path, torch_dtype=torch.float16,
                use_safetensors=True, controlnet=controlnet, variant="fp16"
            ).to(self.device)
        else:
            # Emulate "clip skip" by truncating the CLIP text encoder's hidden layers
            text_encoder = transformers.CLIPTextModel.from_pretrained(
                self.diffusion_checkpoint_path,
                subfolder="text_encoder",
                num_hidden_layers=12 - (clip_skip - 1),
                torch_dtype=torch.float16
            )
            pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                self.diffusion_checkpoint_path, torch_dtype=torch.float16,
                use_safetensors=True, text_encoder=text_encoder,
                controlnet=controlnet, variant="fp16"
            ).to(self.device)
        # Negative embedding / LoRA handling (skipped internally for SD 2.1 and SDXL)
        self._apply_negative_embedding(pipeline, model_name)
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            pipeline.scheduler.config, use_karras_sigmas=True
        )
        print("Target diffusion model loaded")
        return pipeline

    def _load_lora(self, pipeline, lora_name):
        """Load LoRA weights into the model."""
        lora_path = f"/nas8/liuhongyu/model/{lora_name}"
        state_dict, network_alphas = pipeline.lora_state_dict(lora_path)
        pipeline.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=pipeline.unet)

    def _apply_negative_embedding(self, pipeline, model_name):
        """Apply a negative textual-inversion embedding, or load a LoRA in LoRA-only mode."""
        if model_name not in ["stable-diffusion-xl-base-1.0", "stable-diffusion-2-1-base"]:
            if self.only_lora:
                self._load_lora(pipeline, model_name)
                # Disable the safety checker (returns images unchanged)
                pipeline.safety_checker = lambda images, clip_input: (images, None)
            else:
                pipeline.load_textual_inversion(
                    "/nas8/liuhongyu/lora_model",
                    weight_name="EasyNegativeV2.safetensors",
                    token="EasyNegative"
                )

    def image_generation(self, prompt, strength=0.7,
                         guidance_scale=7.5, num_inference_steps=30):
        """Generate images with the diffusion model and save them under a prompt-named folder."""
        out_path = os.path.join(self.out_path_base, prompt.replace(" ", "_"))
        os.makedirs(out_path, exist_ok=True)
        src_img_list = natsorted(
            [f for f in os.listdir(self.src_path) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif'))],
            alg=ns.PATH
        )
        all_gen_nums = 0  # Counter for generated images
        for img_name in src_img_list:
            # Ensure a 3-channel RGB input for both the landmark detector and the pipeline
            src_img_pil = Image.open(os.path.join(self.src_path, img_name)).convert("RGB")
            # Facial-landmark control image; images with no detected face are skipped
            control_image = generate_annotation(src_img_pil, max_faces=1)
            if control_image is not None:
                prompt_input = prompt
                # Apply different generation settings based on the model type
                if self.use_model in ['stable-diffusion-xl-base-1.0', 'stable-diffusion-2-1-base']:
                    trg_img_pil = self.pipe(
                        prompt=prompt_input,
                        image=src_img_pil,
                        strength=strength,
                        control_image=Image.fromarray(control_image),
                        guidance_scale=guidance_scale,
                        negative_prompt='worst quality, normal quality, low quality, low res, blurry',
                        num_inference_steps=num_inference_steps,
                        controlnet_conditioning_scale=1.5
                    )['images'][0]
                else:
                    trg_img_pil = self.pipe(
                        prompt=prompt_input,
                        image=src_img_pil,
                        control_image=Image.fromarray(control_image),
                        strength=strength,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        controlnet_conditioning_scale=1.5,
                        negative_prompt='EasyNegative, worst quality, normal quality, low quality, low res, blurry'
                    )['images'][0]
                # Save the generated image only if it is not fully black
                if np.array(trg_img_pil).max() > 0:
                    trg_img_pil.save(opj(out_path, img_name))
                    all_gen_nums += 1


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Domain Image Generation")
    parser.add_argument(
        "--src_img_path",
        type=str,
        default="demo_input",
        help="Path to the source image directory"
    )
    parser.add_argument("--strength", type=float, default=0.6, help="Strength of the SDEdit-style img2img denoising")
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt for image generation")
    parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale for Stable Diffusion")
    parser.add_argument("--sd_model_id", type=str, default="stable-diffusion-2-1-base", help="Stable Diffusion model ID")
    parser.add_argument("--num_inference_steps", type=int, default=30, help="Number of inference steps")
    parser.add_argument("--save_base", type=str, default="./output", help="Output directory for generated images")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to run inference on (e.g., 'cuda:0')")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--use_anime_vae", action="store_true", help="Enable the anime VAE for image generation")
    parser.add_argument("--model_base_path", type=str, default="./diffusion_model", help="Base directory containing the diffusion model checkpoints")
    return parser.parse_args()


def set_random_seed(seed: int):
    """Set random seed for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main():
    """Main function to execute the image generation process."""
    args = parse_args()
    # Set random seed to ensure reproducibility
    set_random_seed(args.seed)
    # Check if the source image path exists and a prompt was given
    if not os.path.exists(args.src_img_path):
        raise FileNotFoundError(f"❌ Source image path does not exist: {args.src_img_path}")
    if args.prompt is None:
        raise ValueError("A text prompt is required; pass one with --prompt")
    # Ensure the output directory exists
    os.makedirs(args.save_base, exist_ok=True)
    # Initialize the DomainImageGeneration class and generate images
    data_generation = DomainImageGeneration(
        src_path=args.src_img_path,
        args=args,
        model_name=args.sd_model_id,
        save_path=args.save_base,
        device=args.device,
        use_anime_vae=args.use_anime_vae,
        model_base_path=args.model_base_path
    )
    # Start image generation
    data_generation.image_generation(
        prompt=args.prompt,
        strength=args.strength,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps
    )


if __name__ == "__main__":
    main()
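
# Example invocation (a sketch; the script filename and the presence of local checkpoints
# under ./diffusion_model are assumptions):
#   python domain_image_generation.py \
#       --src_img_path demo_input \
#       --prompt "<your prompt>" \
#       --sd_model_id stable-diffusion-2-1-base \
#       --strength 0.6 --guidance_scale 7.5 --num_inference_steps 30 \
#       --save_base ./output --device cuda:0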