Knightmovies commited on
Commit
5cac9c5
Β·
verified Β·
1 Parent(s): 28c6b80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -45
app.py CHANGED
@@ -8,7 +8,7 @@ import pytesseract
8
  from scipy.spatial import distance as dist
9
 
10
  # ==============================================================================
11
- # App Configuration
12
  # ==============================================================================
13
  st.set_page_config(
14
  page_title="Document AI Toolkit",
@@ -16,6 +16,19 @@ st.set_page_config(
16
  layout="wide"
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # ==============================================================================
20
  # Model Loading (Cached)
21
  # ==============================================================================
@@ -27,7 +40,7 @@ def load_model():
27
  model, processor = load_model()
28
 
29
  # ==============================================================================
30
- # Core Image Processing Functions
31
  # ==============================================================================
32
  def order_points(pts):
33
  xSorted = pts[np.argsort(pts[:, 0]), :]
@@ -71,16 +84,18 @@ def correct_orientation(image):
71
  return cv2.rotate(image, angle_map[rotation])
72
  return image
73
  except Exception:
74
- # Fallback to bounding box method if OSD fails
75
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
76
  thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
77
  orientations = {0: thresh, 90: cv2.rotate(thresh, cv2.ROTATE_90_CLOCKWISE), 180: cv2.rotate(thresh, cv2.ROTATE_180), 270: cv2.rotate(thresh, cv2.ROTATE_90_COUNTERCLOCKWISE)}
78
  best_rotation, max_horizontal_boxes = 0, -1
79
  for angle, rotated_img in orientations.items():
80
- data = pytesseract.image_to_data(rotated_img, output_type=pytesseract.Output.DICT, timeout=5)
81
- horizontal_boxes = sum(1 for i, conf in enumerate(data['conf']) if int(conf) > 10 and data['width'][i] > data['height'][i])
82
- if horizontal_boxes > max_horizontal_boxes:
83
- max_horizontal_boxes, best_rotation = horizontal_boxes, angle
 
 
 
84
  angle_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
85
  return cv2.rotate(image, angle_map[best_rotation]) if best_rotation > 0 else image
86
 
@@ -92,9 +107,8 @@ def extract_and_draw_table_structure(image_bgr):
92
  outputs = model(**inputs)
93
  target_sizes = torch.tensor([image_pil.size[::-1]])
94
  results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
95
-
96
  img_with_boxes = image_bgr.copy()
97
- colors = {"table row": (0, 255, 0), "table column": (255, 0, 0), "table": (255, 0, 255)} # Red for columns
98
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
99
  class_name = model.config.id2label[label.item()]
100
  if class_name in colors:
@@ -118,64 +132,71 @@ with st.sidebar:
118
  st.title("πŸ€– Document AI Toolkit")
119
  st.markdown("---")
120
 
 
 
 
 
 
121
  if st.session_state.stage == "upload":
122
  st.header("Step 1: Upload Image")
123
- uploaded_file = st.file_uploader("Upload your document image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
124
  if uploaded_file:
125
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
126
  st.session_state.original_image = cv2.imdecode(file_bytes, 1)
127
  st.session_state.stage = "processing"
128
  st.rerun()
129
-
130
  elif st.session_state.stage == "processing":
131
  st.header("Step 2: Pre-process")
132
- st.info("Straightening and correcting orientation...")
133
- if st.button("▢️ Start Pre-processing"):
134
- with st.spinner("Working..."):
135
  original_image = st.session_state.original_image
136
  straightened = find_and_straighten_document(original_image)
137
  image_to_orient = straightened if straightened is not None and straightened.size > 0 else original_image
138
  st.session_state.processed_image = correct_orientation(image_to_orient)
139
  st.session_state.stage = "analysis"
140
  st.rerun()
141
-
142
  elif st.session_state.stage == "analysis":
143
  st.header("Step 3: Analyze Table")
144
- st.info("Detecting table structure...")
145
- if st.button("πŸ“Š Find Table Structure"):
146
  with st.spinner("Running Table Transformer model..."):
147
  st.session_state.annotated_image = extract_and_draw_table_structure(st.session_state.processed_image)
148
  st.session_state.stage = "done"
149
  st.rerun()
150
 
151
- if st.session_state.stage != "upload":
152
- if st.button("πŸ”„ Start Over"):
153
- for key in list(st.session_state.keys()):
154
- del st.session_state[key]
155
- st.rerun()
156
-
157
  # --- Main Panel Display ---
158
- st.header("Document Processing Stages")
159
-
160
- if st.session_state.stage == "upload":
161
- st.info("Please upload a document image using the sidebar to begin.")
162
-
 
 
 
 
 
 
 
163
  if st.session_state.original_image is not None:
164
- st.subheader("1. Original Image")
165
- st.image(cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB), use_container_width=True)
166
-
 
 
 
 
 
 
167
  if st.session_state.processed_image is not None:
168
- st.subheader("2. Pre-processed Image")
169
- st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), caption="Straightened & Oriented", use_container_width=True)
170
-
171
- if st.session_state.annotated_image is not None:
172
- st.subheader("3. Final Analysis")
173
- tab1, tab2 = st.tabs(["βœ… Corrected Document", "πŸ“Š Table Structure"])
174
-
175
- with tab1:
176
- st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), use_container_width=True)
177
- _, buf = cv2.imencode(".jpg", st.session_state.processed_image)
178
- st.download_button("πŸ“₯ Download Clean Image", data=buf.tobytes(), file_name="corrected.jpg", mime="image/jpeg")
179
-
180
- with tab2:
181
- st.image(cv2.cvtColor(st.session_state.annotated_image, cv2.COLOR_BGR2RGB), use_container_width=True)
 
8
  from scipy.spatial import distance as dist
9
 
10
  # ==============================================================================
11
+ # App Configuration & Styling
12
  # ==============================================================================
13
  st.set_page_config(
14
  page_title="Document AI Toolkit",
 
16
  layout="wide"
17
  )
18
 
19
+ # Inject CSS for a centered, fixed-width layout
20
+ st.markdown("""
21
+ <style>
22
+ .main .block-container {
23
+ max-width: 900px;
24
+ padding-top: 2rem;
25
+ padding-right: 2rem;
26
+ padding-left: 2rem;
27
+ padding-bottom: 2rem;
28
+ }
29
+ </style>
30
+ """, unsafe_allow_html=True)
31
+
32
  # ==============================================================================
33
  # Model Loading (Cached)
34
  # ==============================================================================
 
40
  model, processor = load_model()
41
 
42
  # ==============================================================================
43
+ # Core Image Processing Functions (Unchanged)
44
  # ==============================================================================
45
  def order_points(pts):
46
  xSorted = pts[np.argsort(pts[:, 0]), :]
 
84
  return cv2.rotate(image, angle_map[rotation])
85
  return image
86
  except Exception:
 
87
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
88
  thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
89
  orientations = {0: thresh, 90: cv2.rotate(thresh, cv2.ROTATE_90_CLOCKWISE), 180: cv2.rotate(thresh, cv2.ROTATE_180), 270: cv2.rotate(thresh, cv2.ROTATE_90_COUNTERCLOCKWISE)}
90
  best_rotation, max_horizontal_boxes = 0, -1
91
  for angle, rotated_img in orientations.items():
92
+ try:
93
+ data = pytesseract.image_to_data(rotated_img, output_type=pytesseract.Output.DICT, timeout=5)
94
+ horizontal_boxes = sum(1 for i, conf in enumerate(data['conf']) if int(conf) > 10 and data['width'][i] > data['height'][i])
95
+ if horizontal_boxes > max_horizontal_boxes:
96
+ max_horizontal_boxes, best_rotation = horizontal_boxes, angle
97
+ except Exception:
98
+ continue
99
  angle_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
100
  return cv2.rotate(image, angle_map[best_rotation]) if best_rotation > 0 else image
101
 
 
107
  outputs = model(**inputs)
108
  target_sizes = torch.tensor([image_pil.size[::-1]])
109
  results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
 
110
  img_with_boxes = image_bgr.copy()
111
+ colors = {"table row": (0, 255, 0), "table column": (255, 0, 0), "table": (255, 0, 255)}
112
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
113
  class_name = model.config.id2label[label.item()]
114
  if class_name in colors:
 
132
  st.title("πŸ€– Document AI Toolkit")
133
  st.markdown("---")
134
 
135
+ if st.button("πŸ”„ Start Over", use_container_width=True):
136
+ for key in list(st.session_state.keys()):
137
+ del st.session_state[key]
138
+ st.rerun()
139
+
140
  if st.session_state.stage == "upload":
141
  st.header("Step 1: Upload Image")
142
+ uploaded_file = st.file_uploader("Upload your document", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
143
  if uploaded_file:
144
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
145
  st.session_state.original_image = cv2.imdecode(file_bytes, 1)
146
  st.session_state.stage = "processing"
147
  st.rerun()
 
148
  elif st.session_state.stage == "processing":
149
  st.header("Step 2: Pre-process")
150
+ if st.button("▢️ Start Pre-processing", use_container_width=True, type="primary"):
151
+ with st.spinner("Straightening & correcting orientation..."):
 
152
  original_image = st.session_state.original_image
153
  straightened = find_and_straighten_document(original_image)
154
  image_to_orient = straightened if straightened is not None and straightened.size > 0 else original_image
155
  st.session_state.processed_image = correct_orientation(image_to_orient)
156
  st.session_state.stage = "analysis"
157
  st.rerun()
 
158
  elif st.session_state.stage == "analysis":
159
  st.header("Step 3: Analyze Table")
160
+ if st.button("πŸ“Š Find Table Structure", use_container_width=True, type="primary"):
 
161
  with st.spinner("Running Table Transformer model..."):
162
  st.session_state.annotated_image = extract_and_draw_table_structure(st.session_state.processed_image)
163
  st.session_state.stage = "done"
164
  st.rerun()
165
 
 
 
 
 
 
 
166
  # --- Main Panel Display ---
167
+ st.title("Document Processing Workflow")
168
+
169
+ # Step 1: Upload
170
+ expander1 = st.expander("Step 1: Upload Original Image", expanded=(st.session_state.stage == "upload"))
171
+ with expander1:
172
+ if st.session_state.original_image is None:
173
+ st.info("Please upload a document image using the sidebar to begin.")
174
+ else:
175
+ st.image(cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB), use_container_width=True)
176
+ st.success("Image uploaded successfully.")
177
+
178
+ # Step 2: Pre-process
179
  if st.session_state.original_image is not None:
180
+ expander2 = st.expander("Step 2: Pre-process Document", expanded=(st.session_state.stage == "processing" or st.session_state.stage == "analysis"))
181
+ with expander2:
182
+ if st.session_state.processed_image is None:
183
+ st.info("Click 'Start Pre-processing' in the sidebar.")
184
+ else:
185
+ st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), caption="Straightened & Oriented", use_container_width=True)
186
+ st.success("Pre-processing complete.")
187
+
188
+ # Step 3: Analysis
189
  if st.session_state.processed_image is not None:
190
+ expander3 = st.expander("Step 3: Analyze Table Structure", expanded=(st.session_state.stage == "done"))
191
+ with expander3:
192
+ if st.session_state.annotated_image is None:
193
+ st.info("Click 'Find Table Structure' in the sidebar to run the analysis.")
194
+ else:
195
+ tab1, tab2 = st.tabs(["βœ… Corrected Document", "πŸ“Š Table Structure"])
196
+ with tab1:
197
+ st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), use_container_width=True)
198
+ _, buf = cv2.imencode(".jpg", st.session_state.processed_image)
199
+ st.download_button("πŸ“₯ Download Clean Image", data=buf.tobytes(), file_name="corrected.jpg", mime="image/jpeg", use_container_width=True)
200
+ with tab2:
201
+ st.image(cv2.cvtColor(st.session_state.annotated_image, cv2.COLOR_BGR2RGB), use_container_width=True)
202
+ st.success("Analysis complete.")