import os
import argparse

import torch
import torchvision.transforms as T
import gradio as gr
from PIL import Image
from transformers import AutoTokenizer

import spaces
from resnet50 import build_model
from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
from utils import IMAGENET_MEAN, IMAGENET_STD
from internvl.train.dataset import dynamic_preprocess
from internvl.model.internvl_chat import InternVLChatModel

# Model configuration
CHECKPOINTS = {
    "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
    "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
}

# Global state shared between the Gradio callbacks
HF_TOKEN = os.getenv("HF_TOKEN")
current_vis = []
current_bpe = []
current_index = 0


def load_model(check_type):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cuda")

    # Note: the two R50 branches expect 'R50' / 'R50_siglip' entries in
    # CHECKPOINTS that this demo does not define; only the TokenFD
    # checkpoints above are selectable from the UI.
    if check_type == 'R50':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')

    elif check_type == 'R50_siglip':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')

    elif 'TokenFD' in check_type:
        model_path = CHECKPOINTS[check_type]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True,
                                                  use_fast=False, use_auth_token=HF_TOKEN)
        model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
        ])

    return model.to(device), tokenizer, transform, device
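# A minimal sanity-check sketch for the TokenFD preprocessing above, using a
# hypothetical blank test image; it exercises only the torchvision transform,
# not the model itself.
def _transform_sketch():
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB')),
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    dummy = Image.new('RGB', (640, 480), color='white')  # stand-in for a user upload
    pixel_values = transform(dummy)                      # normalized float tensor
    assert pixel_values.shape == (3, 224, 224)
    return pixel_values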
def process_image(model, tokenizer, transform, device, check_type, image, text):
    global current_vis, current_bpe
    src_size = image.size
    if 'TokenFD' in check_type:  # match the CHECKPOINTS keys (was 'TokenOCR', which never matched)
        images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
                                                  image_size=model.config.force_image_size,
                                                  use_thumbnail=model.config.use_thumbnail,
                                                  return_ratio=True)
        pixel_values = torch.stack([transform(img) for img in images]).to(device)
    else:
        pixel_values = torch.stack([transform(image)]).to(device)
        target_ratio = (1, 1)

    # Text processing: append a trailing space and drop the BOS token
    text += ' '
    input_ids = tokenizer(text)['input_ids'][1:]
    input_ids = torch.tensor(input_ids, device=device)

    # Get the text and vision embeddings
    with torch.no_grad():
        if 'R50' in check_type:
            text_embeds = model.language_embedding(input_ids)
        else:
            text_embeds = model.tok_embeddings(input_ids)
        vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(device))
        vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)

    # Cosine similarity between every BPE token and every image patch
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
    similarity = text_embeds @ vit_embeds.T
    resized_size = size1 if size1 is not None else size2

    # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}")  # e.g. torch.Size([4, 2048])
    # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}")     # e.g. torch.Size([9728, 2048])
    # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")     # e.g. torch.Size([4, 9728])

    # Generate the visualization: one (H, W) heat map per BPE token
    attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
    # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
    current_vis = generate_similiarity_map([image], attn_map,
                                           [tokenizer.decode([i]) for i in input_ids],
                                           [], target_ratio, src_size)
    current_bpe = [tokenizer.decode([i]) for i in input_ids]
    # current_bpe[-1] = 'Input text'
    current_bpe[-1] = text
    return image, current_vis[0], current_bpe[0]


# Event handlers
def update_index(change):
    global current_index
    current_index = max(0, min(len(current_vis) - 1, current_index + change))
    return current_vis[current_index], format_bpe_display(current_bpe[current_index])


def format_bpe_display(bpe):
    # Use HTML to set the font size and color, bold the token, and center it.
    # (The original return string was truncated; this markup is an assumption
    # reconstructed from the comment above.)
    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
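# A standalone sketch of the similarity-map math used in process_image, with
# toy tensors whose shapes are assumptions taken from the debug prints above
# (4 BPE tokens, a 76 x 128 patch grid, 2048-dim embeddings).
def _similarity_map_sketch():
    text_embeds = torch.randn(4, 2048)    # one row per BPE token
    vit_embeds = torch.randn(9728, 2048)  # one row per image patch (76 * 128 = 9728)
    # L2-normalize both sides so the matmul below yields cosine similarities.
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
    similarity = text_embeds @ vit_embeds.T    # (4, 9728), values in [-1, 1]
    attn_map = similarity.reshape(4, 76, 128)  # one (H, W) heat map per token
    return attn_map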
# Gradio UI. The Blocks header and the input widgets below are reconstructed
# to match the event wiring further down; labels and defaults are assumptions.
with gr.Blocks() as demo:
    gr.Markdown(
        "**Note:** if the input text does not appear in the image, the attention map "
        "will show a lot of noise (the actual response values are very low), because "
        "we normalize the attention map by its relative values."
    )
    with gr.Row():
        model_type = gr.Dropdown(choices=list(CHECKPOINTS.keys()),
                                 value="TokenFD_4096_English_seg", label="Model type")
        image_input = gr.Image(label="Input image", type="pil")
        text_input = gr.Textbox(label="Input text")
    run_btn = gr.Button("Run")
") with gr.Row(): orig_img = gr.Image(label="Original picture", interactive=False) heatmap = gr.Image(label="BPE visualization", interactive=False) with gr.Row() as controls: prev_btn = gr.Button("⬅ Last", visible=False) index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False) next_btn = gr.Button("⮕ Next", visible=False) bpe_display = gr.Markdown("Current BPE: ", visible=False) # 事件处理 @spaces.GPU def on_run_clicked(model_type, image, text): global current_vis, current_bpe, current_index current_index = 0 # Reset index when new image is processed image, vis, bpe = process_image(*load_model(model_type), model_type, image, text) # Update the slider range and set value to 0 slider_max_val = len(current_bpe) - 1 bpe_text = format_bpe_display(bpe) return image, vis, bpe_text, slider_max_val run_btn.click( on_run_clicked, inputs=[model_type, image_input, text_input], outputs=[orig_img, heatmap, bpe_display, index_slider], ).then( lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)), inputs=index_slider, outputs=[prev_btn, index_slider, next_btn, bpe_display], ) prev_btn.click( lambda: (*update_index(-1), current_index), outputs=[heatmap, bpe_display, index_slider] ) next_btn.click( lambda: (*update_index(1), current_index), outputs=[heatmap, bpe_display, index_slider] ) index_slider.change( lambda x: (current_vis[x], format_bpe_display(current_bpe[x])), inputs=index_slider, outputs=[heatmap, bpe_display] ) if __name__ == "__main__": demo.launch()