Anuji committed on
Commit
f5ef54e
·
verified ·
1 Parent(s): 65e53fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -50
app.py CHANGED
@@ -4,6 +4,9 @@ from PIL import Image
4
  from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
5
  from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64
6
 
 
 
 
7
  # Set up logging
8
  logger = configure_logger()
9
 
@@ -20,7 +23,7 @@ def fetch_model(model_name: str, dtype=torch.bfloat16):
20
  model_info = load_model(model_name, dtype=dtype)
21
  tokenizer, model, vl_chat_processor = model_info
22
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
- model = model.to(device).eval() # Move to appropriate device and set eval mode
24
  DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
25
  logger.info(f"Loaded {model_name} on {device}")
26
  return DEPLOY_MODELS[model_name]
@@ -52,22 +55,21 @@ def to_gradio_chatbot(conv):
52
  ret[-1][-1] = msg
53
  return ret
54
 
55
- # Predict function
56
- def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2-tiny"):
57
  logger.info("Starting predict function...")
58
  tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
59
  if not text:
60
  logger.warning("Empty text input detected.")
61
- return chatbot, history, "Empty context."
62
 
63
  logger.info("Processing images...")
64
  pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
65
  conversation = generate_prompt_with_history(
66
- text, pil_images, history, vl_chat_processor, tokenizer
67
  )
68
  all_conv, _ = convert_conversation_to_prompts(conversation)
69
  stop_words = conversation.stop_str
70
- gradio_chatbot_output = to_gradio_chatbot(conversation)
71
 
72
  full_response = ""
73
  logger.info("Generating response...")
@@ -85,53 +87,96 @@ def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2
85
  repetition_penalty=1.1
86
  ):
87
  full_response += x
88
- response = strip_stop_words(full_response, stop_words)
89
- conversation.update_last_message(response)
90
- gradio_chatbot_output[-1][1] = response
91
- logger.info(f"Yielding partial response: {response[:50]}...")
92
- yield gradio_chatbot_output, conversation.messages, "Generating..."
93
-
94
  logger.info("Generation complete.")
95
- if torch.cuda.is_available():
96
- torch.cuda.empty_cache() # Clear cache only if CUDA is available
97
- yield gradio_chatbot_output, conversation.messages, "Success"
98
  except Exception as e:
99
  logger.error(f"Error in generation: {str(e)}")
100
- yield gradio_chatbot_output, conversation.messages, f"Error: {str(e)}"
101
 
102
- # Streamlit OCR app interface
103
- def upload_and_process(image):
104
  if image is None:
105
- return "Please upload an image.", []
106
  prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only. Preserve spaces, bullets, numbers, and all formatting. Do not translate, generate, or include text in any other language. Stop at the last character of the image text."
107
- chatbot = []
108
- history = []
109
- logger.info("Starting upload_and_process...")
110
- for chatbot_output, history_output, status in predict(prompt, [image], chatbot, history):
111
- logger.info(f"Status: {status}")
112
- if status == "Success":
113
- return chatbot_output[-1][1], history_output
114
- elif status.startswith("Error"):
115
- return f"Error: {status}", []
116
- return "Processing failed.", []
117
-
118
- # Streamlit UI
119
- st.markdown("<h1 style='text-align: center;'>🔍 Extract Job Info in One Click</h1>", unsafe_allow_html=True)
120
-
121
- uploaded_file = st.file_uploader("Upload an Image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"])
122
-
123
- if uploaded_file:
124
- extracted_text, _ = upload_and_process(uploaded_file)
125
-
126
- col1, col2 = st.columns(2)
127
- with col1:
128
- st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
129
- with col2:
130
- st.markdown(
131
- f"""
132
- <div style="border: 1px solid #ccc; padding: 10px; width: 100%; white-space: pre-wrap; overflow: hidden; text-align: left;">
133
- {extracted_text}
134
- </div>
135
- """,
136
- unsafe_allow_html=True
137
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
5
  from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64
6
 
7
# Streamlit requires set_page_config to be the first st.* call in the script.
st.set_page_config(layout="wide")

# Module-level logger shared by the rest of the app.
logger = configure_logger()
12
 
 
23
  model_info = load_model(model_name, dtype=dtype)
24
  tokenizer, model, vl_chat_processor = model_info
25
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
+ model = model.to(device).eval() # Move to appropriate device
27
  DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
28
  logger.info(f"Loaded {model_name} on {device}")
29
  return DEPLOY_MODELS[model_name]
 
55
  ret[-1][-1] = msg
56
  return ret
57
 
58
+ # Predict function (simplified for OCR)
59
+ def predict(text, images, model_name="deepseek-ai/deepseek-vl2-tiny"):
60
  logger.info("Starting predict function...")
61
  tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
62
  if not text:
63
  logger.warning("Empty text input detected.")
64
+ return "Empty context."
65
 
66
  logger.info("Processing images...")
67
  pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
68
  conversation = generate_prompt_with_history(
69
+ text, pil_images, [], vl_chat_processor, tokenizer
70
  )
71
  all_conv, _ = convert_conversation_to_prompts(conversation)
72
  stop_words = conversation.stop_str
 
73
 
74
  full_response = ""
75
  logger.info("Generating response...")
 
87
  repetition_penalty=1.1
88
  ):
89
  full_response += x
90
+ response = strip_stop_words(full_response, stop_words)
 
 
 
 
 
91
  logger.info("Generation complete.")
92
+ return response
 
 
93
  except Exception as e:
94
  logger.error(f"Error in generation: {str(e)}")
95
+ return f"Error: {str(e)}"
96
 
97
# OCR extraction function
def extract_text(image):
    """Run OCR on one uploaded image and return the recognized text (or an error string)."""
    # Guard clause: nothing to extract without an image.
    if image is None:
        return "Please upload an image."
    # Instruction prompt steering the VLM toward verbatim, English-only OCR.
    prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only. Preserve spaces, bullets, numbers, and all formatting. Do not translate, generate, or include text in any other language. Stop at the last character of the image text."
    logger.info("Starting text extraction...")
    # Delegate the heavy lifting to the model-backed predict() helper.
    return predict(prompt, [image])
105
+
106
# --- User selection state -------------------------------------------------
# Seed the session-state keys on the first run so later reads never KeyError.
for _state_key, _state_default in (("user_selected", False), ("selected_user", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
111
+
112
# User Selection Page
@st.cache_data
def get_user_names():
    """Return the fixed roster of user initials offered in the selector."""
    return "DS NW RB IG AR NU".split()

user_names = get_user_names()
118
+
119
# Retrieve cached user from query params so a bookmarked ?user=XX URL can
# skip the selection page after a browser refresh.
query_params = st.query_params
cached_user = query_params.get("user", None)

# BUGFIX: the session-state keys are unconditionally seeded earlier in this
# script, so the old test `"user_selected" not in st.session_state` was always
# False and a cached user was never restored. Check the flag's VALUE instead.
if cached_user and not st.session_state["user_selected"]:
    st.session_state["user_selected"] = True
    st.session_state["selected_user"] = cached_user

if not st.session_state["user_selected"]:
    # Selection page: shown until the visitor picks a name and continues.
    st.title("👋 Let’s Get Started! Identify Yourself to Begin")
    # Pre-select the cached user when it is a known name; otherwise no default.
    selected_user = st.selectbox("Choose your name:", user_names, index=user_names.index(cached_user) if cached_user in user_names else None)
    continue_button = st.button("Continue", disabled=not selected_user)

    if continue_button:
        st.session_state["user_selected"] = True
        st.session_state["selected_user"] = selected_user
        st.query_params["user"] = selected_user  # persist the choice in the URL
        st.rerun()

    if not selected_user:
        st.warning("⚠ Please select a user to continue.")
else:
    st.write(f"✅ Welcome Back, {st.session_state['selected_user']}!")
142
+
143
# Main UI (only rendered after a user has been selected)
if st.session_state["user_selected"]:
    import html  # stdlib; used to escape untrusted OCR output before HTML rendering

    st.markdown("<h1 style='text-align: center;'>🔍 Extract Job Info in One Click</h1>", unsafe_allow_html=True)

    uploaded_file = st.file_uploader("Upload an Image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"])

    if uploaded_file:
        # NOTE(review): extraction re-runs on every Streamlit rerun for the
        # same file; consider caching by file if this proves slow.
        # (Removed an unused `img = Image.open(uploaded_file)` local here.)
        extracted_text = extract_text(uploaded_file)

        col1, col2 = st.columns(2)
        with col1:
            st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
        with col2:
            # SECURITY: the OCR output is untrusted text; escape it so it is
            # displayed literally instead of being interpreted as HTML inside
            # this unsafe_allow_html block.
            st.markdown(
                f"""
                <div style="border: 1px solid #ccc; padding: 10px; width: 100%; white-space: pre-wrap; overflow: hidden; text-align: left;">
                {html.escape(extracted_text)}
                </div>
                """,
                unsafe_allow_html=True
            )

        # Keep the latest extraction available to later reruns/widgets.
        st.session_state["extracted_text"] = extracted_text

        # Feedback widgets: free-text errors, a 5..1 rating, and a numeric
        # reference number that gates the Submit button.
        errors_text = st.text_area("Paste Any Errors Here", height=200)
        rating = st.radio("Rate the OCR Extraction (5-1)", options=[5, 4, 3, 2, 1], index=None, horizontal=True)
        ref_number = st.text_input("Enter Reference Number", max_chars=10)

        if ref_number and not ref_number.isdigit():
            st.warning("⚠ Reference Number must be a number.")
            ref_number = ""  # treat a non-numeric entry as missing

        if not ref_number or rating is None:
            st.warning("⚠ Please enter a Reference Number and select a Rating to proceed.")
            st.button("Submit", disabled=True)
        else:
            if st.button("Submit"):
                st.success("✅ Submitted successfully!")
                # No Google Drive or Sheets upload; just a confirmation