import copy

import torch
from safetensors import safe_open

from garment_seg.process import load_seg_model, generate_mask
from utils.utils import is_torch2_available, prepare_image, prepare_mask
from diffusers import UNet2DConditionModel

if is_torch2_available():
    from .attention_processor import REFAttnProcessor2_0 as REFAttnProcessor
    from .attention_processor import AttnProcessor2_0 as AttnProcessor
    from .attention_processor import REFAnimateDiffAttnProcessor2_0 as REFAnimateDiffAttnProcessor
else:
    from .attention_processor import REFAttnProcessor, AttnProcessor


class ClothAdapter:
    def __init__(self, sd_pipe, ref_path, device, enable_cloth_guidance, set_seg_model=True):
        self.enable_cloth_guidance = enable_cloth_guidance
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        # The main denoising UNet gets "write"-type reference-attention processors:
        # during sampling it consumes the garment features held in self.attn_store.
        self.set_adapter(self.pipe.unet, "write")
        print(ref_path)
        ref_unet = copy.deepcopy(sd_pipe.unet)
        if ref_unet.config.in_channels == 9:
            # Inpainting UNets take 9-channel input; swap conv_in so the reference copy
            # accepts plain 4-channel latents.
            ref_unet.conv_in = torch.nn.Conv2d(4, 320, ref_unet.conv_in.kernel_size, ref_unet.conv_in.stride, ref_unet.conv_in.padding)
            ref_unet.register_to_config(in_channels=4)
        # Load the garment (reference) UNet weights from the safetensors checkpoint.
        state_dict = {}
        with safe_open(ref_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        ref_unet.load_state_dict(state_dict, strict=False)
        self.ref_unet = ref_unet.to(self.device, dtype=self.pipe.dtype)
        # The reference UNet gets "read"-type processors: when it is run on the garment
        # latents, its self-attention features are captured into self.attn_store.
        self.set_adapter(self.ref_unet, "read")
        if set_seg_model:
            self.set_seg_model()
        self.attn_store = {}

    def set_seg_model(self):
        checkpoint_path = 'checkpoints/cloth_segm.pth'
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def set_adapter(self, unet, type):
        # Replace every self-attention ("attn1") processor with a reference-attention
        # processor; all other attention layers keep the default processor.
        attn_procs = {}
        for name in unet.attn_processors.keys():
            if "attn1" in name:
                attn_procs[name] = REFAttnProcessor(name=name, type=type)
            else:
                attn_procs[name] = AttnProcessor()
        unet.set_attn_processor(attn_procs)

    def generate(
            self,
            cloth_image,
            cloth_mask_image=None,
            prompt=None,
            a_prompt="best quality, high quality",
            num_images_per_prompt=4,
            negative_prompt=None,
            seed=-1,
            guidance_scale=7.5,
            cloth_guidance_scale=2.5,
            num_inference_steps=20,
            height=512,
            width=384,
            **kwargs,
    ):
        # Segment the garment automatically when no mask is supplied.
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)

        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * cloth_mask).to(self.device, dtype=torch.float16)

        if prompt is None:
            prompt = "a photography of a model"
        prompt = prompt + ", " + a_prompt
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
            # Encode the masked garment into latent space and run the reference UNet once
            # (at timestep 0) so its self-attention features land in self.attn_store.
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        if self.enable_cloth_guidance:
            images = self.pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                guidance_scale=guidance_scale,
                cloth_guidance_scale=cloth_guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                height=height,
                width=width,
                cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
                **kwargs,
            ).images
        else:
            images = self.pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                height=height,
                width=width,
                cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
                **kwargs,
            ).images

        return images, cloth_mask_image

    def generate_inpainting(
            self,
            cloth_image,
            cloth_mask_image=None,
            num_images_per_prompt=4,
            seed=-1,
            cloth_guidance_scale=2.5,
            num_inference_steps=20,
            height=512,
            width=384,
            **kwargs,
    ):
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)

        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * cloth_mask).to(self.device, dtype=torch.float16)

        with torch.inference_mode():
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds_null,
            cloth_guidance_scale=cloth_guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": cloth_guidance_scale > 1.0, "enable_cloth_guidance": False},
            **kwargs,
        ).images

        return images, cloth_mask_image


class ClothAdapter_AnimateDiff:
    def __init__(self, sd_pipe, pipe_path, ref_path, device, set_seg_model=True):
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        self.set_adapter(self.pipe.unet, "write")
        # The AnimateDiff UNet carries motion modules, so the reference UNet is loaded as a
        # plain UNet2DConditionModel from pipe_path rather than deep-copied from the pipeline.
        ref_unet = UNet2DConditionModel.from_pretrained(pipe_path, subfolder='unet', torch_dtype=sd_pipe.dtype)
        state_dict = {}
        with safe_open(ref_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        ref_unet.load_state_dict(state_dict, strict=False)
        self.ref_unet = ref_unet.to(self.device)
        self.set_adapter(self.ref_unet, "read")
        if set_seg_model:
            self.set_seg_model()
        self.attn_store = {}

    def set_seg_model(self):
        checkpoint_path = 'checkpoints/cloth_segm.pth'
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def set_adapter(self, unet, type):
        # Self-attention layers get reference-attention processors, except those inside the
        # motion modules, which keep the default processor.
        attn_procs = {}
        for name in unet.attn_processors.keys():
            if "attn1" in name and "motion_modules" not in name:
                attn_procs[name] = REFAnimateDiffAttnProcessor(name=name, type=type)
            else:
                attn_procs[name] = AttnProcessor()
        unet.set_attn_processor(attn_procs)

    def generate(
            self,
            cloth_image,
            cloth_mask_image=None,
            prompt=None,
            a_prompt="best quality, high quality",
            num_images_per_prompt=4,
            negative_prompt=None,
            seed=-1,
            guidance_scale=7.5,
            cloth_guidance_scale=3.0,
            num_inference_steps=20,
            height=512,
            width=384,
            **kwargs,
    ):
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)

        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * cloth_mask).to(self.device, dtype=torch.float16)

        if prompt is None:
            prompt = "a photography of a model"
        prompt = prompt + ", " + a_prompt
        if negative_prompt is None:
            negative_prompt = "bare, naked, nude, undressed, monochrome, lowres, bad anatomy, worst quality, low quality"

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        frames = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            cloth_guidance_scale=cloth_guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0},
            **kwargs,
        ).frames

        return frames, cloth_mask_image
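

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a driver
# script that imports this package, a custom diffusers pipeline accepting
# `cloth_guidance_scale` (e.g. the repo's OmsDiffusionPipeline -- an assumed
# import path), and placeholder checkpoint/image paths; adjust all of these to
# your setup before running.
#
#     import torch
#     from PIL import Image
#     from pipelines.OmsDiffusionPipeline import OmsDiffusionPipeline  # assumed
#     from garment_adapter.garment_diffusion import ClothAdapter
#
#     pipe = OmsDiffusionPipeline.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#     )
#     adapter = ClothAdapter(
#         pipe,
#         ref_path="checkpoints/garment_unet.safetensors",  # placeholder path
#         device="cuda",
#         enable_cloth_guidance=True,
#     )
#     images, cloth_mask = adapter.generate(Image.open("cloth.jpg").convert("RGB"))
#     images[0].save("result.png")
# ---------------------------------------------------------------------------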