Knightmovies committed on
Commit 28c6b80 · verified · 1 Parent(s): 65c6a60

Update app.py

Files changed (1)
  1. app.py +97 -105
app.py CHANGED
@@ -8,34 +8,30 @@ import pytesseract
 from scipy.spatial import distance as dist

 # ==============================================================================
-# App Configuration & Model Loading
+# App Configuration
 # ==============================================================================
-
 st.set_page_config(
     page_title="Document AI Toolkit",
     page_icon="🤖",
     layout="wide"
 )

-# Use Streamlit's caching to load the model only once.
+# ==============================================================================
+# Model Loading (Cached)
+# ==============================================================================
 @st.cache_resource
 def load_model():
     """Loads the Table Transformer model and processor."""
-    st.write("Cache miss: Loading Table Transformer model...")
-    processor = DetrImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
-    model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
-    return processor, model
+    return TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition"), DetrImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")

-processor, model = load_model()
+model, processor = load_model()

 # ==============================================================================
-# Core Image Processing Functions (Unchanged)
+# Core Image Processing Functions
 # ==============================================================================
-
 def order_points(pts):
     xSorted = pts[np.argsort(pts[:, 0]), :]
-    leftMost = xSorted[:2, :]
-    rightMost = xSorted[2:, :]
+    leftMost, rightMost = xSorted[:2, :], xSorted[2:, :]
     leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
     (tl, bl) = leftMost
     D = dist.cdist(tl[np.newaxis], rightMost, "euclidean")[0]
@@ -53,8 +49,7 @@ def perspective_transform(image, pts):
     maxHeight = max(int(heightA), int(heightB))
     dst = np.array([[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]], dtype="float32")
     M = cv2.getPerspectiveTransform(rect, dst)
-    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
-    return warped
+    return cv2.warpPerspective(image, M, (maxWidth, maxHeight))

 def find_and_straighten_document(image):
     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
@@ -63,127 +58,124 @@ def find_and_straighten_document(image):
     if not contours: return None
     page_contour = max(contours, key=cv2.contourArea)
     if cv2.contourArea(page_contour) < (image.shape[0] * image.shape[1] * 0.1): return None
-    rect = cv2.minAreaRect(page_contour)
-    box = cv2.boxPoints(rect)
+    box = cv2.boxPoints(cv2.minAreaRect(page_contour))
     return perspective_transform(image, box)

 def correct_orientation(image):
+    """Robust orientation correction using a cascade approach."""
     try:
-        osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT)
+        osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
         rotation = osd['rotate']
-        if rotation in [90, 180, 270]:
-            if rotation == 90:
-                rotated_image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
-            elif rotation == 180:
-                rotated_image = cv2.rotate(image, cv2.ROTATE_180)
-            else: # 270
-                rotated_image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
-            return rotated_image
-    except Exception as e:
-        st.warning(f"OSD check failed: {e}. Using original orientation.")
-        return image
+        if rotation > 0:
+            angle_map = {90: cv2.ROTATE_90_COUNTERCLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_CLOCKWISE}
+            return cv2.rotate(image, angle_map[rotation])
+        return image
+    except Exception:
+        # Fallback to bounding box method if OSD fails
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
+        orientations = {0: thresh, 90: cv2.rotate(thresh, cv2.ROTATE_90_CLOCKWISE), 180: cv2.rotate(thresh, cv2.ROTATE_180), 270: cv2.rotate(thresh, cv2.ROTATE_90_COUNTERCLOCKWISE)}
+        best_rotation, max_horizontal_boxes = 0, -1
+        for angle, rotated_img in orientations.items():
+            data = pytesseract.image_to_data(rotated_img, output_type=pytesseract.Output.DICT, timeout=5)
+            horizontal_boxes = sum(1 for i, conf in enumerate(data['conf']) if int(conf) > 10 and data['width'][i] > data['height'][i])
+            if horizontal_boxes > max_horizontal_boxes:
+                max_horizontal_boxes, best_rotation = horizontal_boxes, angle
+        angle_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
+        return cv2.rotate(image, angle_map[best_rotation]) if best_rotation > 0 else image

 def extract_and_draw_table_structure(image_bgr):
+    """Finds and draws table structure using OpenCV."""
     image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
     inputs = processor(images=image_pil, return_tensors="pt")
     with torch.no_grad():
         outputs = model(**inputs)
     target_sizes = torch.tensor([image_pil.size[::-1]])
-    results = processor.post_process_object_detection(outputs, threshold=0.7, target_sizes=target_sizes)[0]
+    results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
+
     img_with_boxes = image_bgr.copy()
-    colors = {"table row": (0, 255, 0), "table column": (0, 0, 255), "table": (255, 0, 255)}
+    colors = {"table row": (0, 255, 0), "table column": (255, 0, 0), "table": (255, 0, 255)}  # Red for columns
     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
         class_name = model.config.id2label[label.item()]
         if class_name in colors:
             xmin, ymin, xmax, ymax = [int(val) for val in box.tolist()]
-            color = colors[class_name]
-            cv2.rectangle(img_with_boxes, (xmin, ymin), (xmax, ymax), color, 2)
+            cv2.rectangle(img_with_boxes, (xmin, ymin), (xmax, ymax), colors[class_name], 2)
     return img_with_boxes

 # ==============================================================================
-# UI Functions for Each Step
+# Streamlit UI
 # ==============================================================================

-def initialize_state():
-    """Initializes the session state."""
-    if "stage" not in st.session_state:
-        st.session_state.stage = "upload"
-        st.session_state.original_image = None
-        st.session_state.processed_image = None
-
-def reset_app():
-    """Resets the app to the initial upload stage."""
-    for key in st.session_state.keys():
-        del st.session_state[key]
-    initialize_state()
-
-# --- Main App UI ---
-initialize_state()
-
-st.title("🤖 Document AI Toolkit")
-st.markdown("---")
-
-# Use columns for a centered and constrained layout
-left_col, main_col, right_col = st.columns([1, 4, 1])
-
-with main_col:
-    # --- STAGE 1: UPLOAD ---
+# --- Session State Management ---
+if "stage" not in st.session_state:
+    st.session_state.stage = "upload"
+    st.session_state.original_image = None
+    st.session_state.processed_image = None
+    st.session_state.annotated_image = None
+
+# --- Sidebar Controls ---
+with st.sidebar:
+    st.title("🤖 Document AI Toolkit")
+    st.markdown("---")
+
     if st.session_state.stage == "upload":
-        st.header("Step 1: Upload Your Document")
-        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-
+        st.header("Step 1: Upload Image")
+        uploaded_file = st.file_uploader("Upload your document image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
        if uploaded_file:
             file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
             st.session_state.original_image = cv2.imdecode(file_bytes, 1)
-            st.image(cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB), caption="Original Upload", use_container_width=True)
-
-            if st.button("▶️ Start Pre-processing"):
-                st.session_state.stage = "process"
-                st.rerun()
-
-    # --- STAGE 2: PRE-PROCESSING ---
-    elif st.session_state.stage == "process":
-        st.header("Step 2: Pre-processing Result")
-        with st.spinner("Straightening and correcting orientation..."):
-            original_image = st.session_state.original_image
-            straightened = find_and_straighten_document(original_image)
-            image_to_orient = straightened if straightened is not None and straightened.size > 0 else original_image
-            st.session_state.processed_image = correct_orientation(image_to_orient)
-
-        st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), caption="Corrected Document", use_container_width=True)
-        st.info("The document has been straightened and oriented.")
+            st.session_state.stage = "processing"
+            st.rerun()

+    elif st.session_state.stage == "processing":
+        st.header("Step 2: Pre-process")
+        st.info("Straightening and correcting orientation...")
+        if st.button("▶️ Start Pre-processing"):
+            with st.spinner("Working..."):
+                original_image = st.session_state.original_image
+                straightened = find_and_straighten_document(original_image)
+                image_to_orient = straightened if straightened is not None and straightened.size > 0 else original_image
+                st.session_state.processed_image = correct_orientation(image_to_orient)
+            st.session_state.stage = "analysis"
+            st.rerun()
+
+    elif st.session_state.stage == "analysis":
+        st.header("Step 3: Analyze Table")
+        st.info("Detecting table structure...")
         if st.button("📊 Find Table Structure"):
-            st.session_state.stage = "analyze"
+            with st.spinner("Running Table Transformer model..."):
+                st.session_state.annotated_image = extract_and_draw_table_structure(st.session_state.processed_image)
+            st.session_state.stage = "done"
             st.rerun()
-
-        if st.button("↩️ Upload New Image"):
-            reset_app()
+
+    if st.session_state.stage != "upload":
+        if st.button("🔄 Start Over"):
+            for key in list(st.session_state.keys()):
+                del st.session_state[key]
             st.rerun()

-    # --- STAGE 3: ANALYSIS ---
-    elif st.session_state.stage == "analyze":
-        st.header("Step 3: Table Structure Analysis")
-        processed_image = st.session_state.processed_image
-        with st.spinner("Running Table Transformer model... This can take a moment."):
-            annotated_image = extract_and_draw_table_structure(processed_image)
+# --- Main Panel Display ---
+st.header("Document Processing Stages")

-        st.subheader("Final Results")
-
-        # Display results side-by-side
-        res_col1, res_col2 = st.columns(2)
-        with res_col1:
-            st.image(cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB), caption="Cleaned Document", use_container_width=True)
-            _, buf = cv2.imencode(".jpg", processed_image)
-            st.download_button(
-                label="📥 Download Clean Image",
-                data=buf.tobytes(),
-                file_name="corrected_document.jpg",
-                mime="image/jpeg",
-            )
-        with res_col2:
-            st.image(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB), caption="Detected Table Structure", use_container_width=True)
+if st.session_state.stage == "upload":
+    st.info("Please upload a document image using the sidebar to begin.")
+
+if st.session_state.original_image is not None:
+    st.subheader("1. Original Image")
+    st.image(cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB), use_container_width=True)
+
+if st.session_state.processed_image is not None:
+    st.subheader("2. Pre-processed Image")
+    st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), caption="Straightened & Oriented", use_container_width=True)
+
+if st.session_state.annotated_image is not None:
+    st.subheader("3. Final Analysis")
+    tab1, tab2 = st.tabs(["✅ Corrected Document", "📊 Table Structure"])
+
+    with tab1:
+        st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), use_container_width=True)
+        _, buf = cv2.imencode(".jpg", st.session_state.processed_image)
+        st.download_button("📥 Download Clean Image", data=buf.tobytes(), file_name="corrected.jpg", mime="image/jpeg")

-        if st.button("🔄 Start Over"):
-            reset_app()
-            st.rerun()
+    with tab2:
+        st.image(cv2.cvtColor(st.session_state.annotated_image, cv2.COLOR_BGR2RGB), use_container_width=True)
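To exercise the updated table-structure path outside Streamlit, the following minimal sketch reuses the same transformers calls this commit relies on. The image path "table.jpg", the printed output, and the variable names are illustrative assumptions, not part of the commit.

# Minimal standalone sketch of the commit's Table Transformer detection path.
# Assumptions: transformers, torch and Pillow are installed; "table.jpg" is a placeholder input path.
import torch
from PIL import Image
from transformers import DetrImageProcessor, TableTransformerForObjectDetection

processor = DetrImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")

image = Image.open("table.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Convert raw outputs to pixel-space boxes, keeping detections above the commit's 0.6 threshold.
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), [round(v, 1) for v in box.tolist()])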
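The updated correct_orientation() first asks Tesseract's orientation and script detection (OSD) for the page rotation and only falls back to the bounding-box heuristic when OSD fails. A standalone sketch of that primary OSD branch, assuming the Tesseract binary is installed and "scan.jpg" is a readable image (both placeholders, not part of the commit):

# Standalone sketch of the OSD branch used by correct_orientation() in this commit.
# Assumptions: Tesseract is installed; "scan.jpg" and "scan_oriented.jpg" are placeholder paths.
import cv2
import pytesseract

image = cv2.imread("scan.jpg")
osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
rotation = osd["rotate"]  # one of 0, 90, 180, 270

# Same rotation mapping the commit applies before handing the image to the table model.
angle_map = {90: cv2.ROTATE_90_COUNTERCLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_CLOCKWISE}
corrected = cv2.rotate(image, angle_map[rotation]) if rotation > 0 else image
cv2.imwrite("scan_oriented.jpg", corrected)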