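# Garment ("cloth") adapters built on top of diffusers pipelines.
#
# ClothAdapter conditions an image pipeline on a reference garment image: a
# reference UNet encodes the masked garment once, its attention features are
# cached in `attn_store`, and the generation UNet consumes them through custom
# attention processors. ClothAdapter_AnimateDiff applies the same mechanism to
# an AnimateDiff video pipeline.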
import copy

import torch
from safetensors import safe_open

from garment_seg.process import load_seg_model, generate_mask
from utils.utils import is_torch2_available, prepare_image, prepare_mask
from diffusers import UNet2DConditionModel

if is_torch2_available():
    from .attention_processor import REFAttnProcessor2_0 as REFAttnProcessor
    from .attention_processor import AttnProcessor2_0 as AttnProcessor
    from .attention_processor import REFAnimateDiffAttnProcessor2_0 as REFAnimateDiffAttnProcessor
else:
    # Note: REFAnimateDiffAttnProcessor is only imported on torch >= 2, so
    # ClothAdapter_AnimateDiff requires a torch 2.x install.
    from .attention_processor import REFAttnProcessor, AttnProcessor


class ClothAdapter:
    def __init__(self, sd_pipe, ref_path, device, enable_cloth_guidance, set_seg_model=True):
        self.enable_cloth_guidance = enable_cloth_guidance
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        self.set_adapter(self.pipe.unet, "write")

        print(ref_path)
        # The reference UNet is a copy of the generation UNet; it encodes the
        # garment image and shares its attention features via attn_store.
        ref_unet = copy.deepcopy(sd_pipe.unet)
        if ref_unet.config.in_channels == 9:
            # Inpainting UNets take 9 input channels; rebuild conv_in for the
            # standard 4 latent channels expected by the garment encoder.
            ref_unet.conv_in = torch.nn.Conv2d(4, 320, ref_unet.conv_in.kernel_size, ref_unet.conv_in.stride, ref_unet.conv_in.padding)
            ref_unet.register_to_config(in_channels=4)
        state_dict = {}
        with safe_open(ref_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        ref_unet.load_state_dict(state_dict, strict=False)

        self.ref_unet = ref_unet.to(self.device, dtype=self.pipe.dtype)
        self.set_adapter(self.ref_unet, "read")
        if set_seg_model:
            self.set_seg_model()
        self.attn_store = {}

    def set_seg_model(self):
        checkpoint_path = 'checkpoints/cloth_segm.pth'
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def set_adapter(self, unet, type):
        # Swap in custom processors on all self-attention ("attn1") layers so
        # garment features can be exchanged through the shared attn_store dict;
        # every other layer keeps a plain attention processor.
        attn_procs = {}
        for name in unet.attn_processors.keys():
            if "attn1" in name:
                attn_procs[name] = REFAttnProcessor(name=name, type=type)
            else:
                attn_procs[name] = AttnProcessor()
        unet.set_attn_processor(attn_procs)

    def generate(
            self,
            cloth_image,
            cloth_mask_image=None,
            prompt=None,
            a_prompt="best quality, high quality",
            num_images_per_prompt=4,
            negative_prompt=None,
            seed=-1,
            guidance_scale=7.5,
            cloth_guidance_scale=2.5,
            num_inference_steps=20,
            height=512,
            width=384,
            **kwargs,
    ):
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)

        # Keep only the garment region of the reference image.
        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * cloth_mask).to(self.device, dtype=torch.float16)

        if prompt is None:
            prompt = "a photography of a model"
        prompt = prompt + ", " + a_prompt
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
            # A single reference-UNet pass at timestep 0 over the garment
            # latents caches its attention features in attn_store.
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        # Any integer seed (including the default -1) seeds deterministically;
        # pass seed=None for a nondeterministic run.
        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        if self.enable_cloth_guidance:
            # This path expects a pipeline that accepts cloth_guidance_scale.
            images = self.pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                guidance_scale=guidance_scale,
                cloth_guidance_scale=cloth_guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                height=height,
                width=width,
                cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
                **kwargs,
            ).images
        else:
            images = self.pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                height=height,
                width=width,
                cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
                **kwargs,
            ).images

        return images, cloth_mask_image

    def generate_inpainting(
            self,
            cloth_image,
            cloth_mask_image=None,
            num_images_per_prompt=4,
            seed=-1,
            cloth_guidance_scale=2.5,
            num_inference_steps=20,
            height=512,
            width=384,
            **kwargs,
    ):
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)

        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * cloth_mask).to(self.device, dtype=torch.float16)

        with torch.inference_mode():
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds_null,
            cloth_guidance_scale=cloth_guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": cloth_guidance_scale > 1.0, "enable_cloth_guidance": False},
            **kwargs,
        ).images

        return images, cloth_mask_image
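

# The AnimateDiff variant mirrors ClothAdapter: it loads the reference UNet
# from a separate pretrained path instead of copying the pipeline's UNet,
# leaves motion-module attention untouched, and returns video frames instead
# of images.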
class ClothAdapter_AnimateDiff:
    def __init__(self, sd_pipe, pipe_path, ref_path, device, set_seg_model=True):
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        self.set_adapter(self.pipe.unet, "write")

        ref_unet = UNet2DConditionModel.from_pretrained(pipe_path, subfolder='unet', torch_dtype=sd_pipe.dtype)
        state_dict = {}
        with safe_open(ref_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        ref_unet.load_state_dict(state_dict, strict=False)

        self.ref_unet = ref_unet.to(self.device)
        self.set_adapter(self.ref_unet, "read")
        if set_seg_model:
            self.set_seg_model()
        self.attn_store = {}

    def set_seg_model(self):
        checkpoint_path = 'checkpoints/cloth_segm.pth'
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def set_adapter(self, unet, type):
        attn_procs = {}
        for name in unet.attn_processors.keys():
            # Only spatial self-attention layers share garment features; the
            # AnimateDiff motion modules keep a plain attention processor.
            if "attn1" in name and "motion_modules" not in name:
                attn_procs[name] = REFAnimateDiffAttnProcessor(name=name, type=type)
            else:
                attn_procs[name] = AttnProcessor()
        unet.set_attn_processor(attn_procs)

    def generate(
            self,
            cloth_image,
            cloth_mask_image=None,
            prompt=None,
            a_prompt="best quality, high quality",
            num_images_per_prompt=4,
            negative_prompt=None,
            seed=-1,
            guidance_scale=7.5,
            cloth_guidance_scale=3.0,
            num_inference_steps=20,
            height=512,
            width=384,
            **kwargs,
    ):
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)

        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * cloth_mask).to(self.device, dtype=torch.float16)

        if prompt is None:
            prompt = "a photography of a model"
        prompt = prompt + ", " + a_prompt
        if negative_prompt is None:
            negative_prompt = "bare, naked, nude, undressed, monochrome, lowres, bad anatomy, worst quality, low quality"

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=False)[0]
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            self.ref_unet(torch.cat([cloth_embeds] * num_images_per_prompt), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        frames = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            cloth_guidance_scale=cloth_guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0},
            **kwargs,
        ).frames

        return frames, cloth_mask_image
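

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The model id,
    # weight path and image file below are placeholders; the segmentation
    # checkpoint 'checkpoints/cloth_segm.pth' must also exist. With
    # enable_cloth_guidance=True a pipeline that accepts cloth_guidance_scale
    # is required, so a plain StableDiffusionPipeline only covers the other
    # path. Run as a module (python -m ...) so the relative imports resolve.
    from PIL import Image
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    adapter = ClothAdapter(
        pipe,
        ref_path="checkpoints/garment_adapter.safetensors",  # placeholder path
        device="cuda",
        enable_cloth_guidance=False,
    )
    images, cloth_mask = adapter.generate(
        Image.open("cloth.jpg"), num_images_per_prompt=1, seed=42
    )
    images[0].save("result.png")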