import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import transformers
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from .conversation import SeparatorStyle, default_conversation

SENTINEL_TOKEN = "<vila/sentinel>"
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
}

DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10


def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    # Media tokens are registered as ordinary special tokens (see
    # `build_llm_and_tokenizer`), so the prompt can be tokenized directly.
    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]


def has_tokenizer(repo_id_or_path: str) -> bool:
    # Check whether the tokenizer is available locally.
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # Otherwise, check whether the repo and its tokenizer config exist on the Hugging Face Hub.
    try:
        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False
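

# Usage sketch (illustrative, not executed at import time): `has_tokenizer`
# accepts either a local checkpoint directory or a Hugging Face Hub repo id.
# The path and repo id below are placeholders, not guaranteed to exist.
#
#   has_tokenizer("/path/to/checkpoint/llm")      # True if tokenizer_config.json is on disk
#   has_tokenizer("org-name/some-llm-checkpoint") # True if the Hub repo ships a tokenizer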
					
						
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    # Register the sentinel token once; it is used to mark assistant turns when
    # probing the chat template (see `infer_stop_tokens`).
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Skip the first message if it is not from a human.
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Append an empty assistant turn so the prompt ends with a generation cue.
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2], "Conversation roles must alternate human/gpt."
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    # Normalize the conversation before tokenization.
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    # Map the "human"/"gpt" senders onto the chat-template roles.
    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
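

# Usage sketch (illustrative, not executed at import time): tokenizing a
# two-turn exchange. The tokenizer is assumed to be one returned by
# `build_llm_and_tokenizer` below; the messages are made up.
#
#   messages = [
#       {"from": "human", "value": "<image>\nWhat is in this image?"},
#       {"from": "gpt", "value": "A cat sitting on a windowsill."},
#   ]
#   input_ids = tokenize_conversation(messages, tokenizer, add_generation_prompt=False)
#   # `input_ids` is a 1-D torch.Tensor of token ids rendered through the chat template.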
					
						
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    # Render a dummy conversation with every assistant turn replaced by the
    # sentinel token; whatever token the chat template emits right after the
    # sentinel is treated as a stop token, in addition to the EOS token.
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
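

# Usage sketch (illustrative, not executed at import time): the inferred stop
# tokens can be converted to ids and passed to generation as additional
# end-of-sequence ids. `model` and `input_ids` are assumed to come from
# `build_llm_and_tokenizer` and `tokenize_conversation` above.
#
#   stop_ids = tokenizer.convert_tokens_to_ids(infer_stop_tokens(tokenizer))
#   output = model.generate(input_ids.unsqueeze(0), eos_token_id=stop_ids, max_new_tokens=64)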
					
						
def context_length_extension(config):
    # Extend the context window with linear RoPE scaling when the requested
    # `model_max_length` exceeds the model's original `max_position_embeddings`.
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
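

# Worked example (illustrative numbers): for a base context of 4096 tokens and
# a requested `model_max_length` of 16384, the scaling factor is
# ceil(16384 / 4096) = 4.0, so the config ends up with
# `rope_scaling = {"type": "linear", "factor": 4.0}`.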
					
						
def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # Quantized (FP8) checkpoint restoration is disabled by default.
    quantization_restore_from_checkpoint = False

    # `config.model_dtype` is expected to be a string such as "torch.float16",
    # resolved to a dtype via `eval`.
    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    else:
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )

    # Locate the tokenizer: try the checkpoint root first, then its `llm` subdirectory.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load a custom chat template if one is specified in the config: look in the
    # bundled `chat_templates/` directory first, then next to the checkpoint.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        # Strip indentation and newlines so the Jinja template does not inject
        # extra whitespace into rendered prompts.
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # Infer and cache stop tokens for generation.
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Register media tokens as special tokens and cache their ids.
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    # Expose the LLM hidden size on the top-level config for downstream modules.
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
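

# End-to-end usage sketch (illustrative, not executed at import time): the
# checkpoint path is a placeholder, and `config` is assumed to carry at least
# `model_dtype` (e.g. "torch.float16") and, optionally, `chat_template`.
#
#   config = AutoConfig.from_pretrained("/path/to/vila/checkpoint")
#   llm, tokenizer = build_llm_and_tokenizer(
#       "/path/to/vila/checkpoint/llm",
#       config,
#       attn_implementation="flash_attention_2",
#       model_max_length=4096,
#   )
#   print(tokenizer.stop_tokens, tokenizer.media_token_ids)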