# EasyRef: ip_adapter/ip_adapter.py
import os
from typing import List
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from .utils import is_torch2_available, get_generator
if is_torch2_available():
from .attention_processor import (
AttnProcessor2_0 as AttnProcessor,
)
from .attention_processor import (
CNAttnProcessor2_0 as CNAttnProcessor,
)
from .attention_processor import (
IPAttnProcessor2_0 as IPAttnProcessor,
)
from .attention_processor_faceid import (
LoRAAttnProcessor2_0 as LoRAAttnProcessor,
)
from .attention_processor_faceid import (
LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
)
else:
from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
from .resampler import Resampler
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
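# Shape sketch (illustrative, not part of the original file): with the default
# clip_extra_context_tokens=4, a batch of pooled CLIP image embeddings of shape
# (B, clip_embeddings_dim) is projected and reshaped into 4 context tokens per image:
#
#   proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024)
#   tokens = proj(torch.randn(2, 1024))   # -> torch.Size([2, 4, 768])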
class MLPProjModel(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim)
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
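# Shape sketch (illustrative): the MLP is applied token-wise, so a sequence of image
# features of shape (B, seq_len, clip_embeddings_dim) maps to (B, seq_len, cross_attention_dim):
#
#   proj = MLPProjModel(cross_attention_dim=768, clip_embeddings_dim=1280)
#   tokens = proj(torch.randn(1, 16, 1280))   # -> torch.Size([1, 16, 768])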
class IPAdapter:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
# image proj model
self.image_proj_model = self.init_proj()
self.load_ip_adapter()
def init_proj(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
if hasattr(self.pipe, "controlnet"):
if isinstance(self.pipe.controlnet, MultiControlNetModel):
for controlnet in self.pipe.controlnet.nets:
controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
else:
self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
def load_ip_adapter(self):
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
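    # Note: both checkpoint formats are expected to contain an "image_proj" group
    # (projection-model weights) and an "ip_adapter" group (attention-processor weights);
    # loading the latter through a ModuleList relies on the ordering of
    # unet.attn_processors matching the ordering used when the checkpoint was saved.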
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, (IPAttnProcessor, LoRAIPAttnProcessor)):
attn_processor.scale = scale
def generate(
self,
pil_image=None,
clip_image_embeds=None,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
if pil_image is not None:
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
else:
num_prompts = clip_image_embeds.size(0)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
pil_image=pil_image, clip_image_embeds=clip_image_embeds
)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
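# Usage sketch for IPAdapter (illustrative; the model paths below are placeholders and are
# not shipped with this repository):
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#   )
#   ip_model = IPAdapter(pipe, "path/to/image_encoder", "path/to/ip-adapter.bin", "cuda")
#   images = ip_model.generate(pil_image=Image.open("reference.png"), num_samples=1, seed=42)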
class IPAdapterXL(IPAdapter):
"""SDXL"""
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
clip_image_embeds = clip_image_embeds.mean(0, keepdim=True)
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
# if image_prompt_embeds.shape[0] > 1:
# image_prompt_embeds = image_prompt_embeds.mean(0, keepdim=True)
# uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds[:1]))
# else:
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
# num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
num_prompts = 1
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
self.generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=self.generator,
**kwargs,
).images
return images
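# Usage sketch for IPAdapterXL (illustrative; paths are placeholders). Unlike the base class,
# this generate() has no guidance_scale argument, so pass it through **kwargs if needed:
#
#   from diffusers import StableDiffusionXLPipeline
#   pipe = StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   )
#   ip_model = IPAdapterXL(pipe, "path/to/image_encoder", "path/to/ip-adapter_sdxl.bin", "cuda")
#   images = ip_model.generate(pil_image=Image.open("reference.png"), num_samples=1, guidance_scale=7.5)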
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def init_proj(self):
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
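# Note: the "Plus" variant conditions on fine-grained features, i.e. the penultimate hidden
# states of the CLIP vision tower (one token per image patch) compressed by a Resampler into
# num_tokens query tokens, instead of the pooled projection embedding used by IPAdapter above.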
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter with full features"""
def init_proj(self):
image_proj_model = MLPProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.hidden_size,
).to(self.device, dtype=torch.float16)
return image_proj_model
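# Note: the "Full" variant reuses IPAdapterPlus.get_image_embeds but projects the full
# sequence of penultimate CLIP hidden states token-wise through an MLP rather than
# compressing it with a Resampler.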
class IPAdapterPlusXL(IPAdapter):
"""SDXL"""
def init_proj(self):
image_proj_model = Resampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
class EasyRef(IPAdapter):
"""EasyRef-SDXL"""
def __init__(self, sd_pipe, multimodal_llm_path, ip_ckpt, device, num_tokens=64, use_lora=False, lora_rank=128, cond_image_size=336):
self.device = device
self.multimodal_llm_path = multimodal_llm_path
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
self.use_lora = use_lora
self.lora_rank = lora_rank
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
        # load the multimodal LLM (Qwen2-VL) used to encode the reference images
mllm_final_layer = Qwen2VLForConditionalGeneration.from_pretrained(
multimodal_llm_path,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
device_map="cuda"
)
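        # Keep only the last decoder layer as a lightweight aggregator: the embedding and
        # vision towers are swapped for identities because this branch consumes hidden states
        # produced by the truncated LLM below, and learnable reference tokens are appended to
        # the sequence so that their final hidden states can be read out as the image prompt.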
mllm_final_layer = mllm_final_layer.model
mllm_final_layer.layers = mllm_final_layer.layers[-1:]
mllm_final_layer.embed_tokens = torch.nn.Identity()
mllm_final_layer.visual = torch.nn.Identity()
mllm_final_layer.lm_head = torch.nn.Identity()
mllm_final_layer.reference_tokens = torch.nn.Parameter(0.1 * torch.randn(num_tokens, mllm_final_layer.config.hidden_size))
self.mllm_final_layer = mllm_final_layer.to(self.device)
for i in range(len(self.mllm_final_layer.layers)):
self.mllm_final_layer.layers[i].self_attn.is_causal = False
multimodal_llm = Qwen2VLForConditionalGeneration.from_pretrained(
multimodal_llm_path,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
device_map="cuda"
)
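        # Trunk: the same Qwen2-VL model with its last decoder layer removed; its final hidden
        # states are fed to mllm_final_layer in get_image_embeds.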
multimodal_llm.model.layers = multimodal_llm.model.layers[:-1]
multimodal_llm.norm = torch.nn.Identity()
self.multimodal_llm = multimodal_llm.to(self.device)
min_pixels = ((cond_image_size // 28 - 1)**2) * 28 * 28
max_pixels = ((cond_image_size // 28)**2 + 1) * 28 * 28
self.image_processor = AutoProcessor.from_pretrained(
multimodal_llm_path, min_pixels=min_pixels, max_pixels=max_pixels)
# image proj model
self.image_proj_model = self.init_proj()
self.load_ip_adapter()
def load_ip_adapter(self):
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
state_dict = {"image_proj_model": {}, "mllm_final_layer": {}, "unet": {}, "multimodal_llm": {}}
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj_model."):
state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f.get_tensor(key)
elif key.startswith("mllm_final_layer."):
state_dict["mllm_final_layer"][key.replace("mllm_final_layer.", "")] = f.get_tensor(key)
elif key.startswith("multimodal_llm."):
state_dict["multimodal_llm"][key.replace("multimodal_llm.", "")] = f.get_tensor(key)
elif key.startswith("unet."):
state_dict["unet"][key.replace("unet.", "")] = f.get_tensor(key)
else:
state_dict = {"image_proj_model": {}, "mllm_final_layer": {}, "unet": {}, "multimodal_llm": {}}
f = torch.load(self.ip_ckpt, map_location="cpu")["module"]
for key in f.keys():
if key.startswith("image_proj_model."):
state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f[key]
elif key.startswith("mllm_final_layer."):
state_dict["mllm_final_layer"][key.replace("mllm_final_layer.", "")] = f[key]
elif key.startswith("multimodal_llm."):
state_dict["multimodal_llm"][key.replace("multimodal_llm.", "")] = f[key]
elif key.startswith("unet."):
state_dict["unet"][key.replace("unet.", "")] = f[key]
if len(list(state_dict["multimodal_llm"].keys())) > 0:
self.multimodal_llm.load_state_dict(state_dict["multimodal_llm"], strict=False)
self.image_proj_model.load_state_dict(state_dict["image_proj_model"])
self.mllm_final_layer.load_state_dict(state_dict["mllm_final_layer"])
unet_state_dict = self.pipe.unet.state_dict()
unet_state_dict.update(state_dict["unet"])
self.pipe.unet.load_state_dict(unet_state_dict, strict=False)
# ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
# ip_layers.load_state_dict(state_dict["ip_adapter"])
def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
if self.use_lora:
attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank).to(self.device, dtype=torch.float16)
else:
attn_procs[name] = AttnProcessor().to(self.device, dtype=torch.float16)
else:
if self.use_lora:
attn_procs[name] = LoRAIPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=self.num_tokens,
rank=self.lora_rank,
).to(self.device, dtype=torch.float16)
else:
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
if hasattr(self.pipe, "controlnet"):
if isinstance(self.pipe.controlnet, MultiControlNetModel):
for controlnet in self.pipe.controlnet.nets:
controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
else:
self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
def init_proj(self):
image_proj_model = MLPProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.multimodal_llm.config.hidden_size,
).to(self.device, dtype=torch.bfloat16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image, system_prompt):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
data = []
messages = [
{
"role": "user",
"content": [],
}
]
for image in pil_image:
messages[0]["content"].append({"type": "image", "image": image})
messages[0]["content"].append({"type": "text", "text": system_prompt})
prompt = self.image_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.image_processor(
text=[prompt],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
data.append(inputs)
input_ids = torch.stack([example["input_ids"] for example in data], dim=0).to(self.device)
attention_mask = torch.cat([example["attention_mask"] for example in data], dim=0).to(self.device)
pixel_values = [example["pixel_values"] for example in data]
image_grid_thw = torch.stack([example["image_grid_thw"] for example in data], dim=0).to(self.device)
with torch.no_grad():
inputs_embeds = self.multimodal_llm.model.embed_tokens(input_ids)
new_inputs_embeds = []
for i in range(len(pixel_values)):
pixel_value = pixel_values[i].type(self.multimodal_llm.visual.get_dtype()).to(inputs_embeds.device)
grid_thw = image_grid_thw[i]
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32
)
cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
image_embeds = []
for j in range(1, len(cu_seqlens)):
image_embed = self.multimodal_llm.visual(pixel_value[cu_seqlens[j - 1] : cu_seqlens[j]], grid_thw=grid_thw[(j - 1) : j]).to(inputs_embeds.device)
image_embeds.append(image_embed)
image_embeds = torch.cat(image_embeds, dim=0)
image_mask = input_ids[i] == self.multimodal_llm.config.image_token_id
inputs_embed = inputs_embeds[i].clone()
inputs_embed[image_mask] = image_embeds
new_inputs_embeds.append(inputs_embed)
inputs_embeds = torch.cat(new_inputs_embeds, dim=0)
image_embeds = self.multimodal_llm(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_hidden_states=True
).hidden_states[-1]
reference_tokens = self.mllm_final_layer.reference_tokens.to(self.device)
image_embeds = torch.cat([image_embeds, reference_tokens.unsqueeze(0).repeat(image_embeds.shape[0], 1, 1)], dim=1).to(dtype=torch.bfloat16)
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :reference_tokens.shape[0]])], dim=1)
outputs = self.mllm_final_layer(
attention_mask=attention_mask.to(self.device),
inputs_embeds=image_embeds.to(self.device),
output_hidden_states=True,
)
image_embeds = outputs.hidden_states[-1]
image_embeds_ = []
for image_embed in image_embeds:
new_image_embed = image_embed[-reference_tokens.shape[0]:]
image_embeds_.append(new_image_embed)
image_prompt_embeds = self.image_proj_model(torch.stack(image_embeds_)).to(dtype=torch.float16)
return image_prompt_embeds
def generate(
self,
pil_image,
system_prompt,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
# num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
num_prompts = 1
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
# image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
image_prompt_embeds = self.get_image_embeds(pil_image, system_prompt[0])
        uncond_image_prompt_embeds = self.get_image_embeds(Image.new(mode="RGB", size=(512, 512)), system_prompt[1])
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
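# Usage sketch for EasyRef (illustrative; model paths and prompts are placeholders).
# generate() expects a pair of multimodal-LLM system prompts: one for the reference images
# and one for the unconditional (blank-image) branch:
#
#   from diffusers import StableDiffusionXLPipeline
#   pipe = StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   )
#   easyref = EasyRef(pipe, "Qwen/Qwen2-VL-7B-Instruct", "path/to/easyref_checkpoint.safetensors", "cuda")
#   images = easyref.generate(
#       pil_image=[Image.open("ref1.png"), Image.open("ref2.png")],
#       system_prompt=[
#           "Describe the common elements of these reference images.",
#           "Describe the image.",
#       ],
#       prompt="a portrait in the same style",
#       num_samples=1,
#       seed=42,
#   )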