from PIL import ImageOps
from PIL.Image import Image

import torch

from typing import Union, List
from tqdm import tqdm

from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from transformers import CLIPImageProcessor
from transformers.processing_utils import (
    ProcessorMixin,
)
from transformers import AutoTokenizer, PreTrainedTokenizer

from .image_processing_instellavl import InstellaVLImageProcessor
from .mm_utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, KeywordsStoppingCriteria
from .conversation import conv_templates


def tokenizer_image_token(
    prompt: str,
    tokenizer: PreTrainedTokenizer,
    image_token_index=IMAGE_TOKEN_INDEX,
    return_tensors=None,
) -> Union[torch.Tensor, List[int]]:
    r"""
    Tokenize a prompt containing image placeholders and insert the image token index at each placeholder.

    The prompt is split on ``DEFAULT_IMAGE_TOKEN``; every text chunk is tokenized on its own and the
    chunks are joined back together with ``image_token_index`` in place of each placeholder.

    Args:
        - prompt (str): The input prompt string containing text and ``DEFAULT_IMAGE_TOKEN`` placeholders.
        - tokenizer (PreTrainedTokenizer): The tokenizer used to tokenize the text chunks.
        - image_token_index (int): The token index written for each image placeholder. Default is IMAGE_TOKEN_INDEX.
        - return_tensors (str, optional): If "pt", a 1-D ``torch.long`` tensor is returned. Default is None.

    Returns:
        list or torch.Tensor: The token ids as a Python list, or as a PyTorch tensor when ``return_tensors="pt"``.

    Raises:
        ValueError: If ``return_tensors`` is given but is not "pt".
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

    def insert_separator(chunks, sep):
        # Interleave `sep` between consecutive chunks: [c0, sep, c1, sep, c2].
        # zip() pairs every chunk with a separator; the trailing separator is dropped.
        return [item for pair in zip(chunks, [sep] * len(chunks)) for item in pair][:-1]

    input_ids = []
    offset = 0
    # When the tokenizer prepends BOS to every chunk, keep it once at the very start and
    # strip it from each chunk below via `offset`.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # The separator is repeated (offset + 1) times so that slicing off `offset` leading
    # elements leaves exactly one image token per placeholder.
    for piece in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(piece[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids


class InstellaVLProcessor(ProcessorMixin):
    """Processor that pairs an image processor with a tokenizer to build model inputs from text and images."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "GPTNeoXTokenizerFast"

    def __init__(self, image_processor: InstellaVLImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
        """Forward the image processor and tokenizer to ``ProcessorMixin``."""
        super().__init__(image_processor, tokenizer, **kwargs)

    def pad_sequence(
        self,
        input_ids: Union[List[torch.Tensor], List[List[torch.Tensor]]],
        batch_first: bool,
        padding_value: int,
        tokenizer: AutoTokenizer,
    ):
        """
        Pad variable-length sequences to a common length, honoring ``tokenizer.padding_side``.

        ``torch.nn.utils.rnn.pad_sequence`` is used for the padding itself. For left padding each
        sequence is reversed first and the padded result is reversed back along the time axis,
        which moves the padding to the front.

        Args:
            - input_ids: Sequences to pad; each element is flipped along dim 0 when padding left.
            - batch_first (bool): Passed to ``pad_sequence``; selects ``(batch, time)`` vs ``(time, batch)`` layout.
            - padding_value (int): Value written into padded positions.
            - tokenizer: Only its ``padding_side`` attribute is read.

        Returns:
            torch.Tensor: The padded batch.
        """
        pad_left = tokenizer.padding_side == "left"
        if pad_left:
            input_ids = [torch.flip(seq, [0]) for seq in input_ids]
        padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if pad_left:
            # The time axis is dim 1 only when batch_first is set; otherwise it is dim 0.
            # Flipping dim 1 unconditionally would reverse the batch order instead.
            time_dim = 1 if batch_first else 0
            padded = torch.flip(padded, [time_dim])
        return padded

    def encode(
        self,
        text: TextInput = None,
        images: ImageInput = None,
        image_processor: CLIPImageProcessor = None,
        tokenizer: AutoTokenizer = None,
        model_cfg: dict = None,
    ) -> dict:
        """
        Build generation inputs for a single text prompt with optional image(s).

        Args:
            - text: The user prompt. Any ``DEFAULT_IMAGE_TOKEN`` occurrences are stripped and, when
              images are given, a single placeholder is prepended instead.
            - images: A PIL image or a list of PIL images, or None for text-only input. The images
              are EXIF-transposed in place.
            - image_processor: Forwarded to ``self.image_processor.process``.
            - tokenizer: Tokenizer used for the prompt. Falls back to ``self.tokenizer`` when None.
              The tokenizer that is used is stored on ``self.tokenizer`` for later ``decode`` calls.
            - model_cfg: Forwarded to ``self.image_processor.process``.

        Returns:
            dict: ``input_ids`` (with a leading batch dimension of 1), ``stopping_criteria`` and
            ``eos_token_id``; plus ``image_tensor`` and ``image_sizes`` when images are given.

        Raises:
            ValueError: If ``text`` is None, or no tokenizer is available.
            TypeError: If ``images`` is neither a PIL image nor a list.
        """
        if text is None:
            raise ValueError("Text must be provided for encoding.")
        if tokenizer is None:
            tokenizer = getattr(self, "tokenizer", None)
            if tokenizer is None:
                raise ValueError("A tokenizer must be provided for encoding.")

        image_tensor = None
        image_sizes = None
        if images is not None:
            if isinstance(images, Image):
                # Handle images with EXIF orientation tags, which PIL will ignore by default
                # https://github.com/python-pillow/Pillow/issues/4703
                ImageOps.exif_transpose(images, in_place=True)
                image_sizes = [images.size]
                images = [images]
            elif isinstance(images, list):
                image_sizes = []
                for img in images:
                    ImageOps.exif_transpose(img, in_place=True)
                    image_sizes.append(img.size)
            else:
                # Without this guard `image_sizes` would stay unset and the call would
                # fail later with an unrelated UnboundLocalError.
                raise TypeError(
                    f"images must be a PIL image or a list of PIL images, got {type(images).__name__}"
                )
            image_tensor = self.image_processor.process(images, image_processor, model_cfg)['pixel_values']

        text = text.replace(DEFAULT_IMAGE_TOKEN, "").strip()
        if images is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in text:
            question = DEFAULT_IMAGE_TOKEN + "\n" + text
        else:
            question = text

        # Wrap the question in the "instella" conversation template, leaving the
        # second role's turn open (None) so the prompt ends where the reply begins.
        conv = conv_templates["instella"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt_question = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0)
        keywords = [conv.sep]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("|||IP_ADDRESS|||")]

        out = {
            "input_ids": input_ids,
            "stopping_criteria": [stopping_criteria],
            "eos_token_id": terminators,
        }
        if images is not None:
            out = {
                "image_tensor": image_tensor,
                "image_sizes": image_sizes,
                **out,
            }
        # Remembered so that decode() can be called without passing the tokenizer again.
        self.tokenizer = tokenizer
        return out

    def batch_encode(
        self,
        texts: List[TextInput] = None,
        images: List[ImageInput] = None,
        image_processor: CLIPImageProcessor = None,
        tokenizer: AutoTokenizer = None,
        model_cfg: dict = None,
    ):
        """
        Encode several (text, image) samples by calling ``encode`` once per sample.

        NOTE: samples are not padded into a single tensor. An earlier draft of this method
        that padded all prompts with ``pad_sequence`` and returned one dict with an
        ``attention_mask`` was disabled; this method returns one ``encode`` result per sample.

        Args:
            - texts: List of prompts, one per sample.
            - images: List with one entry per prompt (a PIL image, a list of PIL images, or None).
              When None, every sample is treated as text-only.
            - image_processor, tokenizer, model_cfg: Forwarded unchanged to ``encode``.

        Returns:
            list: The ``encode`` output dict of every sample, in input order.

        Raises:
            ValueError: If ``texts`` is None.
            AssertionError: If ``texts`` is not a list or its length differs from ``images``.
        """
        if texts is None:
            raise ValueError("Text must be provided for batch encoding.")
        assert isinstance(texts, list), "Since batch encoding happening, provide batch of texts in a list."

        if images is None:
            images = [None] * len(texts)
        assert len(texts) == len(images), "The number of texts and images must be equal."

        batch_outs = []
        for txt, img in tqdm(zip(texts, images), total=len(texts), desc="Total Samples to encode"):
            batch_outs.append(self.encode(txt, img, image_processor, tokenizer, model_cfg))
        return batch_outs

    def decode(self, output_ids: torch.Tensor) -> str:
        """
        Decode the first row of ``output_ids`` to text, skipping special tokens and stripping whitespace.

        Uses ``self.tokenizer``, which ``encode`` updates to the tokenizer it last used.
        """
        return self.tokenizer.decode(output_ids[0, :], skip_special_tokens=True).strip()

    def batch_decode(self, output_ids_lst: List[torch.Tensor]) -> List[str]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Batch decode is not implemented for InstellaVLProcessor")


InstellaVLProcessor.register_for_auto_class()