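"""Streamlit app for OCR with DeepSeek-VL2.

Loads deepseek-ai/deepseek-vl2-tiny once, wraps the uploaded image and an OCR
prompt into a DeepSeek-VL2 conversation, streams the model's output, and shows
the extracted text in the page.
"""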
import os

import torch
import streamlit as st
from PIL import Image

from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
from deepseek_vl2.models.conversation import SeparatorStyle
from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64

logger = configure_logger()

MODELS = ["deepseek-ai/deepseek-vl2-tiny"]
DEPLOY_MODELS = {}
IMAGE_TOKEN = "<image>"

def fetch_model(model_name: str, dtype=torch.bfloat16):
    """Load `model_name` once, move it to GPU if available, and cache it."""
    global DEPLOY_MODELS
    if model_name not in DEPLOY_MODELS:
        st.write(f"Loading {model_name}...")
        model_info = load_model(model_name, dtype=dtype)
        tokenizer, model, vl_chat_processor = model_info
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
        st.write(f"Loaded {model_name} on {device}")
    return DEPLOY_MODELS[model_name]

def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):
    """Build a DeepSeek-VL2 conversation from the new user turn and any prior history."""
    conversation = vl_chat_processor.new_chat_template()
    if history:
        conversation.messages = history
    if images:
        # Prepend the image placeholder token and pack the images together with the text.
        text = f"{IMAGE_TOKEN}\n{text}"
        text = (text, images)
    conversation.append_message(conversation.roles[0], text)
    conversation.append_message(conversation.roles[1], "")
    return conversation

def to_gradio_chatbot(conv):
    """Convert conversation messages into [user, assistant] display pairs."""
    ret = []
    for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
        if i % 2 == 0:
            if isinstance(msg, tuple):
                msg, images = msg
                for image in images:
                    img_b64 = pil_to_base64(image, "user upload", max_size=800, min_size=400)
                    msg = msg.replace(IMAGE_TOKEN, img_b64, 1)
            ret.append([msg, None])
        else:
            ret[-1][-1] = msg
    return ret

def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2-tiny"):
    """Yield (chatbot, history, status) tuples while the model streams its response."""
    tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
    if not text:
        # predict() is a generator, so report the empty prompt via yield rather than return.
        yield chatbot, history, "Empty context."
        return

    pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
    conversation = generate_prompt_with_history(
        text, pil_images, history, vl_chat_processor, tokenizer
    )
    all_conv, _ = convert_conversation_to_prompts(conversation)
    stop_words = conversation.stop_str
    gradio_chatbot_output = to_gradio_chatbot(conversation)

    full_response = ""
    try:
        with torch.no_grad():
            for x in deepseek_generate(
                conversations=all_conv,
                vl_gpt=vl_gpt,
                vl_chat_processor=vl_chat_processor,
                tokenizer=tokenizer,
                stop_words=stop_words,
                max_length=2048,
                temperature=0.1,
                top_p=0.9,
                repetition_penalty=1.1,
            ):
                full_response += x
                response = strip_stop_words(full_response, stop_words)
                conversation.update_last_message(response)
                gradio_chatbot_output[-1][1] = response
                yield gradio_chatbot_output, conversation.messages, "Generating..."

        torch.cuda.empty_cache()
        yield gradio_chatbot_output, conversation.messages, "Success"
    except Exception as e:
        yield gradio_chatbot_output, conversation.messages, f"Error: {str(e)}"

st.title("DeepSeek-VL2 OCR in Colab")
st.write("Upload an image and get the extracted text.")

image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

if image_input:
    prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only."
    chatbot = []
    history = []
    for chatbot_output, history_output, status in predict(prompt, [image_input], chatbot, history):
        if status == "Success":
            # Show the final assistant message as the extracted text.
            output_text = chatbot_output[-1][1]
            st.text_area("Extracted Text", output_text)
        elif status.startswith("Error"):
            st.error(status)
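
# One possible way to launch this from Colab (an assumption, not part of the original
# script): save it as app.py, install streamlit and the deepseek_vl2 package, start it
# with `!streamlit run app.py &`, and expose port 8501 through a tunnel (e.g. ngrok)
# to open the UI in the browser.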