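"""Gradio benchmark demo comparing two multimodal LLMs (IDEFICS2 and BLIP-2)
on image captioning and visual question answering (VQA), reporting inference
time, peak VRAM usage, and a rough BLEU score for generated captions."""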
import gradio as gr
import torch
from transformers import (
    Idefics2Processor, Idefics2ForConditionalGeneration,
    Blip2Processor, Blip2ForConditionalGeneration
)
from PIL import Image
import time
import nltk
from nltk.translate.bleu_score import sentence_bleu

# Download the 'punkt' tokenizer if it is not already available
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model definitions
models = {
    "IDEFICS2": {
        "model_id": "HuggingFaceM4/idefics2-8b",
        "processor_class": Idefics2Processor,
        "model_class": Idefics2ForConditionalGeneration,
        "caption_prompt": "<image>Describe the image in detail"
    },
    "BLIP2": {
        "model_id": "Salesforce/blip2-opt-2.7b",
        "processor_class": Blip2Processor,
        "model_class": Blip2ForConditionalGeneration,
        "caption_prompt": ""  # Prompt vacío para BLIP2
    }
}
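# NOTE: IDEFICS2 requires the literal "<image>" placeholder token in its text
# prompts, while BLIP-2 captions unconditionally when given an empty prompt.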

# Load models up front (avoids per-request loading delays).
# Half precision on GPU roughly halves VRAM; idefics2-8b needs ~16 GB in fp16.
model_dtype = torch.float16 if device == "cuda" else torch.float32
model_instances = {}
for model_name, config in models.items():
    processor = config["processor_class"].from_pretrained(config["model_id"])
    model = config["model_class"].from_pretrained(config["model_id"], torch_dtype=model_dtype).to(device)
    model_instances[model_name] = (processor, model)

# Predefined example VQA questions (shown as the textbox placeholder below)
vqa_questions = [
    "Are there people in the image?",
    "Which color predominates in the image?"
]

# Generic reference caption for BLEU (adjust to your dataset as needed)
reference_caption = ["An image with people and various objects"]
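# BLEU against one generic reference is only a sanity check; for meaningful
# scores, compare against per-image references (e.g., COCO ground-truth captions).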

def infer(image, model_name, task, question=None):
    if image is None:
        return "Por favor, sube una imagen.", None, None, None, None, None

    # Open and prepare the image (BLIP-2's vision encoder expects 224x224 inputs)
    image = Image.open(image).convert("RGB")
    if "BLIP2" in model_name:
        image = image.resize((224, 224))

    processor, model = model_instances[model_name]
    
    start_time = time.time()
    # Track peak VRAM during generation instead of the pre-inference allocation.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    if task == "captioning":
        caption_text = models[model_name]["caption_prompt"]
        inputs = processor(images=image, text=caption_text, return_tensors="pt").to(device)
        inputs["pixel_values"] = inputs["pixel_values"].to(model_dtype)  # match model precision
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,
            num_beams=5 if "BLIP2" in model_name else 1,
            no_repeat_ngram_size=2 if "BLIP2" in model_name else 0
        )
        # IDEFICS2 returns the prompt tokens along with the completion; strip them.
        if "BLIP2" not in model_name:
            output_ids = output_ids[:, inputs["input_ids"].shape[1]:]
        caption = processor.decode(output_ids[0], skip_special_tokens=True).strip()
        inference_time = time.time() - start_time
        vram = torch.cuda.max_memory_allocated() / 1024**3 if torch.cuda.is_available() else 0.0

        # Compute BLEU against the single generic reference (a rough proxy, not a real evaluation)
        bleu_score = sentence_bleu([reference_caption[0].split()], caption.split()) if caption else 0.0

        return (caption, inference_time, None, None, vram, bleu_score)

    elif task == "vqa" and question:
        # BLIP-2 uses the "Question: ... Answer:" prompt format from its model card.
        vqa_text = f"Question: {question} Answer:" if "BLIP2" in model_name else f"<image>Q: {question}"
        inputs = processor(images=image, text=vqa_text, return_tensors="pt").to(device)
        inputs["pixel_values"] = inputs["pixel_values"].to(model_dtype)  # match model precision
        output_ids = model.generate(
            **inputs,
            max_new_tokens=10,
            num_beams=5 if "BLIP2" in model_name else 1,
            no_repeat_ngram_size=2 if "BLIP2" in model_name else 0
        )
        # Strip the IDEFICS2 prompt tokens before decoding.
        if "BLIP2" not in model_name:
            output_ids = output_ids[:, inputs["input_ids"].shape[1]:]
        vqa_answer = processor.decode(output_ids[0], skip_special_tokens=True).strip()
        inference_time = time.time() - start_time
        vram = torch.cuda.max_memory_allocated() / 1024**3 if torch.cuda.is_available() else 0.0

        return (None, None, vqa_answer, inference_time, vram, None)

    return "Selecciona una tarea válida y, para VQA, una pregunta.", None, None, None, None, None

# Gradio interface
with gr.Blocks(title="MLLM Benchmark Demo") as demo:
    gr.Markdown("# Benchmark for Multimodal Models (MLLMs)")
    gr.Markdown("Upload an image, select a model and a task, and get captioning or VQA results.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Image")
            model_dropdown = gr.Dropdown(choices=["IDEFICS2", "BLIP2"], label="Select Model", value="IDEFICS2")
            task_dropdown = gr.Dropdown(choices=["captioning", "vqa"], label="Select Task", value="captioning")
            question_input = gr.Textbox(label="VQA Question (only used for the VQA task)", placeholder=vqa_questions[0])
            submit_btn = gr.Button("Generate")

        with gr.Column():
            caption_output = gr.Textbox(label="Generated Caption")
            caption_time_output = gr.Number(label="Captioning Time (s)")
            vqa_output = gr.Textbox(label="VQA Answer")
            vqa_time_output = gr.Number(label="VQA Time (s)")
            vram_output = gr.Number(label="VRAM (GB)")
            bleu_output = gr.Number(label="BLEU Score")

    submit_btn.click(
        fn=infer,
        inputs=[image_input, model_dropdown, task_dropdown, question_input],
        outputs=[caption_output, caption_time_output, vqa_output, vqa_time_output, vram_output, bleu_output]
    )

    gr.Markdown("### Notas")
    gr.Markdown("""
    - para mejroar la velocidad de inferencia, descarga en local y usar GPU avanzada.
    - La métrica BLEU usa una referencia genérica y puede no reflejar la calidad real.
    - Para más detalles, consulta el [repositorio del paper](https://huggingface.co/spaces/Pdro-ruiz/MLLM_Estado_del_Arte_Feb25/tree/main).
    """)

if __name__ == "__main__":
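    # Set share=True to expose a temporary public URL when running remotely.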
    demo.launch()