# OCR-app / app.py
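"""Streamlit OCR demo: loads the DeepSeek-VL2 tiny vision-language model and streams
the text it extracts from an uploaded image using a fixed OCR prompt."""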
import os
import torch
import streamlit as st
from PIL import Image
from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
from deepseek_vl2.models.conversation import SeparatorStyle
from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64
# Initialize logger
logger = configure_logger()
# Available models, in-process model cache, and the image placeholder token
MODELS = ["deepseek-ai/deepseek-vl2-tiny"]
DEPLOY_MODELS = {}
IMAGE_TOKEN = "<image>"

# Load model function
def fetch_model(model_name: str, dtype=torch.bfloat16):
    global DEPLOY_MODELS
    if model_name not in DEPLOY_MODELS:
        st.write(f"Loading {model_name}...")
        model_info = load_model(model_name, dtype=dtype)
        tokenizer, model, vl_chat_processor = model_info
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
        st.write(f"Loaded {model_name} on {device}")
    return DEPLOY_MODELS[model_name]

# Generate prompt with conversation history
def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):
    conversation = vl_chat_processor.new_chat_template()
    if history:
        conversation.messages = history
    if images:
        text = f"{IMAGE_TOKEN}\n{text}"
        text = (text, images)
    conversation.append_message(conversation.roles[0], text)
    conversation.append_message(conversation.roles[1], "")
    return conversation
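
# After generate_prompt_with_history(), the conversation ends with a user turn carrying
# (f"{IMAGE_TOKEN}\n<prompt>", [PIL images]) and an empty assistant turn that the model
# fills in during predict() below.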

# Convert conversation to Gradio-compatible format
def to_gradio_chatbot(conv):
    ret = []
    for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
        if i % 2 == 0:
            if isinstance(msg, tuple):
                msg, images = msg
                for image in images:
                    img_b64 = pil_to_base64(image, "user upload", max_size=800, min_size=400)
                    msg = msg.replace(IMAGE_TOKEN, img_b64, 1)
            ret.append([msg, None])
        else:
            ret[-1][-1] = msg
    return ret

# Prediction function
def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2-tiny"):
    tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
    if not text:
        # predict() is a generator, so report the empty prompt via yield rather than return
        yield chatbot, history, "Empty context."
        return
    pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
    conversation = generate_prompt_with_history(
        text, pil_images, history, vl_chat_processor, tokenizer
    )
    all_conv, _ = convert_conversation_to_prompts(conversation)
    stop_words = conversation.stop_str
    gradio_chatbot_output = to_gradio_chatbot(conversation)
    full_response = ""
    try:
        with torch.no_grad():
            for x in deepseek_generate(
                conversations=all_conv,
                vl_gpt=vl_gpt,
                vl_chat_processor=vl_chat_processor,
                tokenizer=tokenizer,
                stop_words=stop_words,
                max_length=2048,
                temperature=0.1,
                top_p=0.9,
                repetition_penalty=1.1
            ):
                # Accumulate the streamed text and push partial results to the caller
                full_response += x
                response = strip_stop_words(full_response, stop_words)
                conversation.update_last_message(response)
                gradio_chatbot_output[-1][1] = response
                yield gradio_chatbot_output, conversation.messages, "Generating..."
        torch.cuda.empty_cache()
        yield gradio_chatbot_output, conversation.messages, "Success"
    except Exception as e:
        yield gradio_chatbot_output, conversation.messages, f"Error: {str(e)}"
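
# Example (sketch, not used by the Streamlit flow): predict() yields
# (chatbot, history, status) triples, so it can also be consumed directly,
# e.g. in a notebook. "sample.png" is an illustrative filename only.
#
#   for chat, hist, status in predict("Extract all text from this image.", ["sample.png"], [], []):
#       if status == "Success":
#           print(chat[-1][1])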

# Streamlit UI setup
st.title("DeepSeek-VL2 OCR in Colab")
st.write("Upload an image and get the extracted text.")

# Image upload
image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

# Output text
output_text = st.text_area("Extracted Text", "")

# Handle the image upload and processing
if image_input:
    prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only."
    chatbot = []
    history = []
    for chatbot_output, history_output, status in predict(prompt, [image_input], chatbot, history):
        if status == "Success":
            output_text = chatbot_output[-1][1]
            st.write("Extracted Text:", output_text)
        elif status != "Generating...":
            # Surface errors (and other terminal statuses) without flooding the page
            # with the intermediate "Generating..." updates.
            st.error(status)
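
# To run locally (assuming streamlit and the deepseek_vl2 package are installed):
#   streamlit run app.py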