import os
import subprocess

import gradio as gr

from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
# Change into the local LLaVA repository; update this path if needed
os.chdir("My_new_LLaVA/LLaVA")

# Verify the current working directory
print("Current Working Directory:", os.getcwd())
# Path to the fine-tuned LLaVA checkpoint
llava_model_path = "/My_new_LLaVA/llava-fine_tune_model"

# Path to the fine-tuned LLaVA-Med checkpoint
# (note: both variables currently point to the same checkpoint)
llava_med_model_path = "/My_new_LLaVA/llava-fine_tune_model"
# Container for the arguments expected by llava.eval.run_llava.eval_model
class Args:
    def __init__(self, model_path, model_base, model_name, query, image_path,
                 conv_mode, image_file, sep, temperature, top_p, num_beams,
                 max_new_tokens):
        self.model_path = model_path
        self.model_base = model_base
        self.model_name = model_name
        self.query = query
        self.image_path = image_path
        self.conv_mode = conv_mode
        self.image_file = image_file
        self.sep = sep
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens
# Run inference with the fine-tuned LLaVA model
def predict_llava(image, question, temperature, max_tokens):
    # Save the uploaded image to a temporary file
    image.save("temp_image.jpg")

    # Assemble the evaluation arguments
    args = Args(
        model_path=llava_model_path,
        model_base=None,
        model_name=get_model_name_from_path(llava_model_path),
        query=question,
        image_path="temp_image.jpg",
        conv_mode=None,
        image_file="temp_image.jpg",
        sep=",",
        temperature=temperature,
        top_p=None,
        num_beams=1,
        max_new_tokens=max_tokens,
    )

    # Generate the answer
    output = eval_model(args)
    return output
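
# Note: in some LLaVA releases eval_model() prints the answer to stdout and
# returns None, which would leave the Gradio output box empty. The wrapper
# below is a hedged fallback sketch (not part of the original app); if you
# hit that behavior, swap it in for eval_model() in predict_llava above.
def eval_model_captured(args):
    import io
    import contextlib

    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        ret = eval_model(args)
    # Prefer an explicit return value; otherwise fall back to captured stdout
    return ret if ret is not None else buf.getvalue().strip()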
# Run inference with the fine-tuned LLaVA-Med model via its CLI entry point
def predict_llava_med(image, question, temperature, max_tokens):
    # Save the uploaded image to a temporary file
    image_path = "temp_image_med.jpg"
    image.save(image_path)

    # Command to run the LLaVA-Med model
    command = [
        "python", "-m", "llava.eval.run_llava",
        "--model-name", llava_med_model_path,
        "--image-file", image_path,
        "--query", question,
        "--temperature", str(temperature),
        "--max-new-tokens", str(max_tokens),
    ]

    # Execute the command and capture the output
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        # Surface the subprocess error instead of silently returning nothing
        return f"LLaVA-Med error: {result.stderr.strip()}"
    return result.stdout.strip()
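
# Note: the flag names in the command above follow the older LLaVA/LLaVA-Med
# style CLI; newer upstream LLaVA releases expose --model-path rather than
# --model-name, so adjust the flags to match the installed version.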
# Dispatch to the model selected in the UI
def predict(model_name, image, text, temperature, max_tokens):
    if model_name == "LLaVA":
        return predict_llava(image, text, temperature, max_tokens)
    elif model_name == "LLaVA-Med":
        return predict_llava_med(image, text, temperature, max_tokens)
    else:
        return "Please select a model."
# Define the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(choices=["LLaVA", "LLaVA-Med"], label="Select Model"),
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max Tokens"),
    ],
    outputs=gr.Textbox(label="Output Text"),
    title="Multimodal LLM Interface",
    description="Switch between models and adjust parameters.",
)
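
# A minimal smoke test (assumes a local image at "sample.jpg" — a
# hypothetical file, not part of the app). Run it before launching to
# verify the pipeline end to end:
#
#   from PIL import Image
#   print(predict("LLaVA", Image.open("sample.jpg"), "Describe the image.", 0.7, 256))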
# Launch the Gradio interface
if __name__ == "__main__":
    interface.launch()