Anuji committed on
Commit 3db54ae · verified · 1 Parent(s): 92affd7

Update app.py

Files changed (1): app.py (+34 -28)
app.py CHANGED
@@ -1,33 +1,31 @@
- import os
  import torch
  import streamlit as st
  from PIL import Image
  from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
- from deepseek_vl2.models.conversation import SeparatorStyle
  from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64
  
- # Initialize logger
+ # Set up logging
  logger = configure_logger()
  
- # Global variables for model loading
+ # Models and deployment
  MODELS = ["deepseek-ai/deepseek-vl2-tiny"]
  DEPLOY_MODELS = {}
  IMAGE_TOKEN = "<image>"
  
- # Load model function
+ # Fetch model
  def fetch_model(model_name: str, dtype=torch.bfloat16):
      global DEPLOY_MODELS
      if model_name not in DEPLOY_MODELS:
-         st.write(f"Loading {model_name}...")
+         logger.info(f"Loading {model_name}...")
          model_info = load_model(model_name, dtype=dtype)
          tokenizer, model, vl_chat_processor = model_info
          device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          model = model.to(device)
          DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
-         st.write(f"Loaded {model_name} on {device}")
+         logger.info(f"Loaded {model_name} on {device}")
      return DEPLOY_MODELS[model_name]
  
- # Generate prompt with conversation history
+ # Generate prompt with history
  def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):
      conversation = vl_chat_processor.new_chat_template()
      if history:
@@ -39,7 +37,7 @@ def generate_prompt_with_history(text, images, history, vl_chat_processor, token
      conversation.append_message(conversation.roles[1], "")
      return conversation
  
- # Convert conversation to Gradio-compatible format
+ # Convert conversation to gradio format
  def to_gradio_chatbot(conv):
      ret = []
      for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
@@ -54,12 +52,15 @@ def to_gradio_chatbot(conv):
              ret[-1][-1] = msg
      return ret
  
- # Prediction function
+ # Predict function
  def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2-tiny"):
+     logger.info("Starting predict function...")
      tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
      if not text:
+         logger.warning("Empty text input detected.")
          return chatbot, history, "Empty context."
  
+     logger.info("Processing images...")
      pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
      conversation = generate_prompt_with_history(
          text, pil_images, history, vl_chat_processor, tokenizer
@@ -69,6 +70,7 @@ def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2
      gradio_chatbot_output = to_gradio_chatbot(conversation)
  
      full_response = ""
+     logger.info("Generating response...")
      try:
          with torch.no_grad():
              for x in deepseek_generate(
@@ -86,31 +88,35 @@ def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2
                  response = strip_stop_words(full_response, stop_words)
                  conversation.update_last_message(response)
                  gradio_chatbot_output[-1][1] = response
+                 logger.info(f"Yielding partial response: {response[:50]}...")
                  yield gradio_chatbot_output, conversation.messages, "Generating..."
  
+         logger.info("Generation complete.")
          torch.cuda.empty_cache()
          yield gradio_chatbot_output, conversation.messages, "Success"
      except Exception as e:
+         logger.error(f"Error in generation: {str(e)}")
          yield gradio_chatbot_output, conversation.messages, f"Error: {str(e)}"
  
- # Streamlit UI setup
- st.title("DeepSeek-VL2 OCR in Colab")
- st.write("Upload an image and get the extracted text.")
- 
- # Image upload
- image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
- 
- # Output text
- output_text = st.text_area("Extracted Text", "")
- 
- # Handle the image upload and processing
- if image_input:
-     prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only."
+ # Streamlit OCR app interface
+ def upload_and_process(image):
+     if image is None:
+         return "Please upload an image.", []
+     prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only. Preserve spaces, bullets, numbers, and all formatting. Do not translate, generate, or include text in any other language. Stop at the last character of the image text."
      chatbot = []
      history = []
-     for chatbot_output, history_output, status in predict(prompt, [image_input], chatbot, history):
+     logger.info("Starting upload_and_process...")
+     for chatbot_output, history_output, status in predict(prompt, [image], chatbot, history):
+         logger.info(f"Status: {status}")
          if status == "Success":
-             output_text = chatbot_output[-1][1]
-             st.write("Extracted Text:", output_text)
-         else:
-             st.error(f"Error: {status}")
+             return chatbot_output[-1][1], history_output
+     return "Processing failed.", []
+ 
+ # Streamlit UI
+ st.title("DeepSeek-VL2 OCR with Streamlit")
+ image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
+ output_text = st.text_area("Extracted Text", height=300)
+ if image_input:
+     output, _ = upload_and_process(image_input)
+     output_text.write(output)
+ 
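
Note on the committed code: st.text_area returns the widget's current string value, so output_text.write(output) will raise AttributeError once an image is uploaded, and because predict contains yield statements it is a generator, so its early return chatbot, history, "Empty context." ends iteration without ever surfacing that message to the caller. A minimal corrected sketch of those two spots, assuming the rest of app.py exactly as committed above:

    # In predict(): a generator should yield its status rather than return it
    if not text:
        logger.warning("Empty text input detected.")
        yield chatbot, history, "Empty context."
        return

    # In the Streamlit UI: render the extracted text as the text area's value
    if image_input:
        output, _ = upload_and_process(image_input)
        st.text_area("Extracted Text", value=output, height=300)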