Update app.py

app.py CHANGED
@@ -4,6 +4,9 @@ from PIL import Image
 from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
 from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64
 
+# Set Page Config (Must be first!)
+st.set_page_config(layout="wide")
+
 # Set up logging
 logger = configure_logger()
 
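Note: the unchanged lines 1-3 of app.py are not shown in this diff, so the hunk assumes `import streamlit as st` already exists there. Streamlit also requires `set_page_config` to be the first Streamlit command the script executes. A condensed sketch of the expected file header (the exact contents of the unshown lines are an assumption):

    import streamlit as st  # assumed to be present in the unshown lines 1-3
    import torch
    from PIL import Image

    st.set_page_config(layout="wide")  # must run before any other st.* call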
@@ -20,7 +23,7 @@ def fetch_model(model_name: str, dtype=torch.bfloat16):
     model_info = load_model(model_name, dtype=dtype)
     tokenizer, model, vl_chat_processor = model_info
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = model.to(device).eval()  # Move to appropriate device
+    model = model.to(device).eval()  # Move to appropriate device
     DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
     logger.info(f"Loaded {model_name} on {device}")
     return DEPLOY_MODELS[model_name]
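Note: the removed and added `model = model.to(device).eval()` lines render identically here, so this is most likely a whitespace-only change. The surrounding pattern is load-once-and-cache: load the model on first request, move it to the best available device, and store it in the module-level DEPLOY_MODELS dict so later calls return instantly. A minimal self-contained sketch of the same pattern (the Linear layer is a stand-in for `load_model`, which this sketch does not call):

    import torch

    _CACHE = {}  # model_name -> model; mirrors DEPLOY_MODELS above

    def fetch_cached(model_name: str):
        if model_name not in _CACHE:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = torch.nn.Linear(4, 4)  # placeholder for load_model(model_name, ...)
            _CACHE[model_name] = model.to(device).eval()
        return _CACHE[model_name]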
@@ -52,22 +55,21 @@ def to_gradio_chatbot(conv):
         ret[-1][-1] = msg
     return ret
 
-# Predict function
-def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2…
+# Predict function (simplified for OCR)
+def predict(text, images, model_name="deepseek-ai/deepseek-vl2-tiny"):
     logger.info("Starting predict function...")
     tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
     if not text:
         logger.warning("Empty text input detected.")
-        return
+        return "Empty context."
 
     logger.info("Processing images...")
     pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
     conversation = generate_prompt_with_history(
-        text, pil_images, …
+        text, pil_images, [], vl_chat_processor, tokenizer
     )
     all_conv, _ = convert_conversation_to_prompts(conversation)
     stop_words = conversation.stop_str
-    gradio_chatbot_output = to_gradio_chatbot(conversation)
 
     full_response = ""
     logger.info("Generating response...")
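Note: this hunk turns predict() from a Gradio-oriented streaming generator into a plain function: the chatbot/history parameters are gone, the history argument to generate_prompt_with_history becomes an empty list, and (per the next hunk) the function returns a single string instead of yielding chatbot updates. A hedged usage sketch, before and after:

    # new (this commit): an ordinary call that returns one string
    text = predict("Describe this image.", ["photo.jpg"])

    # old (removed): a generator iterated for streaming UI updates,
    # yielding (chatbot_output, messages, status) tuples:
    # for chatbot, history, status in predict(text, images, chatbot, history):
    #     ...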
@@ -85,53 +87,96 @@ def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2
             repetition_penalty=1.1
         ):
             full_response += x
-
-            conversation.update_last_message(response)
-            gradio_chatbot_output[-1][1] = response
-            logger.info(f"Yielding partial response: {response[:50]}...")
-            yield gradio_chatbot_output, conversation.messages, "Generating..."
-
+            response = strip_stop_words(full_response, stop_words)
         logger.info("Generation complete.")
-
-        torch.cuda.empty_cache()  # Clear cache only if CUDA is available
-        yield gradio_chatbot_output, conversation.messages, "Success"
+        return response
     except Exception as e:
         logger.error(f"Error in generation: {str(e)}")
-
+        return f"Error: {str(e)}"
 
-# …
-def …
+# OCR extraction function
+def extract_text(image):
     if image is None:
-        return "Please upload an image."
+        return "Please upload an image."
     prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only. Preserve spaces, bullets, numbers, and all formatting. Do not translate, generate, or include text in any other language. Stop at the last character of the image text."
-… [31 removed lines (old 107-137); their content was not captured in this view]
+    logger.info("Starting text extraction...")
+    extracted_text = predict(prompt, [image])
+    return extracted_text
+
+# User Selection State
+if "user_selected" not in st.session_state:
+    st.session_state["user_selected"] = False
+if "selected_user" not in st.session_state:
+    st.session_state["selected_user"] = None
+
+# User Selection Page
+@st.cache_data
+def get_user_names():
+    return ["DS", "NW", "RB", "IG", "AR", "NU"]
+
+user_names = get_user_names()
+
+# Retrieve cached user from query params
+query_params = st.query_params
+cached_user = query_params.get("user", None)
+
+if cached_user and "user_selected" not in st.session_state:
+    st.session_state["user_selected"] = True
+    st.session_state["selected_user"] = cached_user
+
+if not st.session_state["user_selected"]:
+    st.title("👋 Let’s Get Started! Identify Yourself to Begin")
+    selected_user = st.selectbox("Choose your name:", user_names, index=user_names.index(cached_user) if cached_user in user_names else None)
+    continue_button = st.button("Continue", disabled=not selected_user)
+
+    if continue_button:
+        st.session_state["user_selected"] = True
+        st.session_state["selected_user"] = selected_user
+        st.query_params["user"] = selected_user
+        st.rerun()
+
+    if not selected_user:
+        st.warning("⚠ Please select a user to continue.")
+else:
+    st.write(f"✅ Welcome Back, {st.session_state['selected_user']}!")
+
+# Main UI (Only loads after user selection)
+if st.session_state["user_selected"]:
+    st.markdown("<h1 style='text-align: center;'>🔍 Extract Job Info in One Click</h1>", unsafe_allow_html=True)
+
+    uploaded_file = st.file_uploader("Upload an Image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"])
+
+    if uploaded_file:
+        img = Image.open(uploaded_file)
+        extracted_text = extract_text(uploaded_file)
+
+        col1, col2 = st.columns(2)
+        with col1:
+            st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
+        with col2:
+            st.markdown(
+                f"""
+                <div style="border: 1px solid #ccc; padding: 10px; width: 100%; white-space: pre-wrap; overflow: hidden; text-align: left;">
+                {extracted_text}
+                </div>
+                """,
+                unsafe_allow_html=True
+            )
+
+        st.session_state["extracted_text"] = extracted_text
+
+        errors_text = st.text_area("Paste Any Errors Here", height=200)
+        rating = st.radio("Rate the OCR Extraction (5-1)", options=[5, 4, 3, 2, 1], index=None, horizontal=True)
+        ref_number = st.text_input("Enter Reference Number", max_chars=10)
+
+        if ref_number and not ref_number.isdigit():
+            st.warning("⚠ Reference Number must be a number.")
+            ref_number = ""
+
+        if not ref_number or rating is None:
+            st.warning("⚠ Please enter a Reference Number and select a Rating to proceed.")
+            st.button("Submit", disabled=True)
+        else:
+            if st.button("Submit"):
+                st.success("✅ Submitted successfully!")
+                # No Google Drive or Sheets upload; just a confirmation
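Note: one behavior worth flagging in the new user-selection flow: st.session_state["user_selected"] is initialized to False near the top of this hunk, so by the time the `if cached_user and "user_selected" not in st.session_state:` guard runs, the key always exists and the ?user= query parameter is never actually restored. A minimal sketch of one way to make the restore take effect (an assumption about the intended behavior, not part of this commit):

    import streamlit as st

    if "user_selected" not in st.session_state:
        # read the query param BEFORE defaulting the session keys
        cached_user = st.query_params.get("user", None)
        st.session_state["user_selected"] = cached_user is not None
        st.session_state["selected_user"] = cached_user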