Knightmovies commited on
Commit
99975be
Β·
verified Β·
1 Parent(s): c726dc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +301 -120
app.py CHANGED
@@ -1,145 +1,326 @@
1
  import gradio as gr
2
  import numpy as np
3
- import torch
4
- from doctr.models import ocr_predictor
5
- from doctr.utils.visualization import visualize_page
6
- import tempfile
7
  import cv2
 
8
  from PIL import Image
 
 
9
 
10
- print("Initializing Doctr model... This will download ~170MB of files on the first startup and can be slow.")
11
- # Load the pre-trained Doctr AI model.
12
- model = ocr_predictor(
13
- det_arch='db_resnet50',
14
- reco_arch='crnn_vgg16_bn',
15
- pretrained=True,
16
- detect_orientation=True,
17
- assume_straight_pages=False
18
- )
19
- print("βœ… Doctr model is ready.")
 
 
 
 
20
 
21
- def process_image_with_doctr(input_image_pil):
22
- """
23
- Processes an image using Doctr library to extract text and visualize detections.
24
- """
25
- if input_image_pil is None:
26
- return None, None
27
 
28
- # Convert PIL Image to RGB NumPy array
29
- input_image_numpy = np.array(input_image_pil)
 
 
30
 
31
- # Process the document with the AI model
32
- result = model([input_image_numpy])
 
 
33
 
34
- # Get the first page results
35
- page = result.pages[0]
 
36
 
37
- # Method 1: Create visualization with detected text boxes
38
- try:
39
- # Use doctr's built-in visualization
40
- visualized_image = visualize_page(page.export(), input_image_numpy)
41
- final_image_rgb = visualized_image
42
- except Exception as e:
43
- print(f"Visualization error: {e}")
44
- # Fallback: return original image
45
- final_image_rgb = input_image_numpy
46
 
47
- # Convert to BGR for OpenCV saving
48
- final_image_bgr = cv2.cvtColor(final_image_rgb, cv2.COLOR_RGB_BGR)
 
49
 
50
- # Save to temporary file
51
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
52
- cv2.imwrite(tmp_file.name, final_image_bgr)
53
- return final_image_rgb, tmp_file.name
54
 
55
- # Alternative version that extracts text and draws boxes manually
56
- def process_image_with_manual_boxes(input_image_pil):
57
- """
58
- Alternative approach: manually draw bounding boxes on detected text.
59
- """
60
- if input_image_pil is None:
61
- return None, None
62
-
63
- input_image_numpy = np.array(input_image_pil)
64
- result = model([input_image_numpy])
65
-
66
- # Create a copy of the original image to draw on
67
- output_image = input_image_numpy.copy()
68
- h, w = output_image.shape[:2]
69
-
70
- # Extract text and draw bounding boxes
71
- page = result.pages[0]
72
-
73
- for block in page.blocks:
74
- for line in block.lines:
75
- for word in line.words:
76
- # Get word geometry (normalized coordinates)
77
- geometry = word.geometry
78
- # Convert normalized coordinates to pixel coordinates
79
- x1, y1 = int(geometry[0][0] * w), int(geometry[0][1] * h)
80
- x2, y2 = int(geometry[1][0] * w), int(geometry[1][1] * h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- # Draw rectangle around detected word
83
- cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
 
 
 
84
 
85
- # Optionally add text
86
- cv2.putText(output_image, word.value, (x1, y1-5),
87
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
88
-
89
- # Save to temporary file
90
- final_image_bgr = cv2.cvtColor(output_image, cv2.COLOR_RGB_BGR)
91
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
92
- cv2.imwrite(tmp_file.name, final_image_bgr)
93
- return output_image, tmp_file.name
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- # Text extraction function
96
- def extract_text_from_image(input_image_pil):
97
  """
98
- Extract text from image and return both visualization and plain text.
 
 
 
 
99
  """
100
  if input_image_pil is None:
101
- return None, None, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- input_image_numpy = np.array(input_image_pil)
104
- result = model([input_image_numpy])
 
 
 
 
105
 
106
- # Extract all text
107
- extracted_text = ""
108
- for page in result.pages:
109
- for block in page.blocks:
110
- for line in block.lines:
111
- line_text = " ".join([word.value for word in line.words])
112
- extracted_text += line_text + "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- # Create visualization
115
- page = result.pages[0]
116
- try:
117
- visualized_image = visualize_page(page.export(), input_image_numpy)
118
- except:
119
- visualized_image = process_image_with_manual_boxes(input_image_pil)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # Save visualization
122
- final_image_bgr = cv2.cvtColor(visualized_image, cv2.COLOR_RGB_BGR)
123
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
124
- cv2.imwrite(tmp_file.name, final_image_bgr)
125
- return visualized_image, tmp_file.name, extracted_text.strip()
126
-
127
- # ==============================================================================
128
- # Gradio Interface
129
- # ==============================================================================
130
- demo = gr.Interface(
131
- fn=extract_text_from_image,
132
- inputs=gr.Image(type="pil", label="Upload Document Photo"),
133
- outputs=[
134
- gr.Image(type="numpy", label="Text Detection Visualization"),
135
- gr.File(label="Download Visualization"),
136
- gr.Textbox(label="Extracted Text", lines=10)
137
- ],
138
- title="πŸ“„ AI-Powered Document Scanner & OCR",
139
- description="Upload a document image to detect and extract text using the Doctr deep learning library. The tool will show detected text regions and provide the extracted text.",
140
- flagging_options=None
141
- )
142
 
 
143
  if __name__ == "__main__":
144
- demo.launch()
145
-
 
1
  import gradio as gr
2
  import numpy as np
 
 
 
 
3
  import cv2
4
+ import tempfile
5
  from PIL import Image
6
+ import math
7
+ import os
8
 
9
+ def order_points(pts):
10
+ """Order points in top-left, top-right, bottom-right, bottom-left order"""
11
+ rect = np.zeros((4, 2), dtype="float32")
12
+
13
+ # Sum and difference to find corners
14
+ s = pts.sum(axis=1)
15
+ diff = np.diff(pts, axis=1)
16
+
17
+ rect[0] = pts[np.argmin(s)] # top-left
18
+ rect[2] = pts[np.argmax(s)] # bottom-right
19
+ rect[1] = pts[np.argmin(diff)] # top-right
20
+ rect[3] = pts[np.argmax(diff)] # bottom-left
21
+
22
+ return rect
23
 
24
+ def four_point_transform(image, pts):
25
+ """Apply perspective transformation to get bird's eye view"""
26
+ rect = order_points(pts)
27
+ (tl, tr, br, bl) = rect
 
 
28
 
29
+ # Compute width of new image
30
+ widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
31
+ widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
32
+ maxWidth = max(int(widthA), int(widthB))
33
 
34
+ # Compute height of new image
35
+ heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
36
+ heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
37
+ maxHeight = max(int(heightA), int(heightB))
38
 
39
+ # Ensure minimum dimensions
40
+ maxWidth = max(maxWidth, 100)
41
+ maxHeight = max(maxHeight, 100)
42
 
43
+ # Destination points for perspective transform
44
+ dst = np.array([
45
+ [0, 0],
46
+ [maxWidth - 1, 0],
47
+ [maxWidth - 1, maxHeight - 1],
48
+ [0, maxHeight - 1]], dtype="float32")
 
 
 
49
 
50
+ # Perspective transformation
51
+ M = cv2.getPerspectiveTransform(rect, dst)
52
+ warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
53
 
54
+ return warped
 
 
 
55
 
56
+ def detect_document_edges(image):
57
+ """Detect document edges using contour detection"""
58
+ # Convert to grayscale
59
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
60
+
61
+ # Apply Gaussian blur
62
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
63
+
64
+ # Edge detection
65
+ edged = cv2.Canny(blurred, 75, 200)
66
+
67
+ # Morphological operations to close gaps
68
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
69
+ edged = cv2.morphologyEx(edged, cv2.MORPH_CLOSE, kernel)
70
+
71
+ # Find contours
72
+ contours, _ = cv2.findContours(edged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
73
+
74
+ if not contours:
75
+ # Fallback to image corners
76
+ h, w = image.shape[:2]
77
+ return np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype="float32")
78
+
79
+ # Sort contours by area (largest first)
80
+ contours = sorted(contours, key=cv2.contourArea, reverse=True)
81
+
82
+ # Find the largest rectangular contour
83
+ for contour in contours:
84
+ # Skip very small contours
85
+ if cv2.contourArea(contour) < 1000:
86
+ continue
87
+
88
+ # Approximate contour
89
+ epsilon = 0.02 * cv2.arcLength(contour, True)
90
+ approx = cv2.approxPolyDP(contour, epsilon, True)
91
+
92
+ # If we found a 4-sided contour, it's likely our document
93
+ if len(approx) == 4:
94
+ return approx.reshape(4, 2).astype("float32")
95
+
96
+ # If no rectangular contour found, use image corners
97
+ h, w = image.shape[:2]
98
+ return np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype="float32")
99
+
100
+ def enhance_document(image):
101
+ """Enhance the document image for better readability"""
102
+ try:
103
+ # Convert to LAB color space
104
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
105
+ l, a, b = cv2.split(lab)
106
+
107
+ # Apply CLAHE to L channel
108
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
109
+ l = clahe.apply(l)
110
+
111
+ # Merge channels and convert back to RGB
112
+ enhanced = cv2.merge([l, a, b])
113
+ enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2RGB)
114
+
115
+ return enhanced
116
+ except:
117
+ # Fallback: simple contrast enhancement
118
+ return cv2.convertScaleAbs(image, alpha=1.2, beta=10)
119
+
120
+ def auto_rotate_image(image):
121
+ """Auto-rotate image to correct orientation using text line detection"""
122
+ try:
123
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
124
+
125
+ # Detect lines using HoughLinesP
126
+ edges = cv2.Canny(gray, 50, 150, apertureSize=3)
127
+ lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=100, minLineLength=100, maxLineGap=10)
128
+
129
+ if lines is not None and len(lines) > 0:
130
+ angles = []
131
+ for line in lines:
132
+ x1, y1, x2, y2 = line[0]
133
+ angle = math.atan2(y2 - y1, x2 - x1)
134
+ angles.append(angle)
135
+
136
+ # Get median angle
137
+ if angles:
138
+ median_angle = np.median(angles)
139
+ angle_deg = np.degrees(median_angle)
140
 
141
+ # Correct angle to nearest 90-degree orientation
142
+ if angle_deg > 45:
143
+ angle_deg -= 90
144
+ elif angle_deg < -45:
145
+ angle_deg += 90
146
 
147
+ # Rotate image if significant rotation detected
148
+ if abs(angle_deg) > 1: # Only rotate if angle > 1 degree
149
+ h, w = image.shape[:2]
150
+ center = (w // 2, h // 2)
151
+ M = cv2.getRotationMatrix2D(center, angle_deg, 1.0)
152
+
153
+ # Calculate new image dimensions
154
+ cos = np.abs(M[0, 0])
155
+ sin = np.abs(M[0, 1])
156
+ new_w = int((h * sin) + (w * cos))
157
+ new_h = int((h * cos) + (w * sin))
158
+
159
+ # Adjust rotation matrix for new center
160
+ M[0, 2] += (new_w / 2) - center[0]
161
+ M[1, 2] += (new_h / 2) - center[1]
162
+
163
+ rotated = cv2.warpAffine(image, M, (new_w, new_h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT)
164
+ return rotated
165
+ except Exception as e:
166
+ print(f"Auto-rotation failed: {e}")
167
+
168
+ return image
169
 
170
+ def scan_document(input_image_pil):
 
171
  """
172
+ Complete document scanning pipeline:
173
+ 1. Detect document edges
174
+ 2. Apply perspective correction
175
+ 3. Auto-rotate for correct orientation
176
+ 4. Enhance image quality
177
  """
178
  if input_image_pil is None:
179
+ return None, None, "❌ No image uploaded"
180
+
181
+ try:
182
+ # Convert PIL to numpy array
183
+ image = np.array(input_image_pil)
184
+ original_image = image.copy()
185
+
186
+ # Validate image
187
+ if image.size == 0:
188
+ return original_image, None, "❌ Invalid image"
189
+
190
+ # Step 1: Auto-rotate to correct orientation
191
+ print("πŸ”„ Auto-rotating image...")
192
+ rotated_image = auto_rotate_image(image)
193
+
194
+ # Step 2: Detect document edges
195
+ print("πŸ“ Detecting document edges...")
196
+ edges = detect_document_edges(rotated_image)
197
+
198
+ # Step 3: Apply perspective transformation
199
+ print("βœ‚οΈ Applying perspective correction...")
200
+ scanned = four_point_transform(rotated_image, edges)
201
+
202
+ # Step 4: Enhance the scanned document
203
+ print("✨ Enhancing document...")
204
+ enhanced = enhance_document(scanned)
205
+
206
+ # Save to temporary file
207
+ enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_RGB2BGR)
208
+ temp_path = tempfile.mktemp(suffix=".jpg")
209
+ cv2.imwrite(temp_path, enhanced_bgr)
210
+
211
+ return enhanced, temp_path, "βœ… Document scanned successfully!"
212
+
213
+ except Exception as e:
214
+ print(f"Error in scan_document: {e}")
215
+ return original_image if 'original_image' in locals() else None, None, f"❌ Error: {str(e)}"
216
+
217
+ # Custom CSS for better UI
218
+ custom_css = """
219
+ #image_upload {
220
+ max-height: 400px !important;
221
+ }
222
+
223
+ .gradio-container {
224
+ max-width: 1200px !important;
225
+ margin: auto !important;
226
+ }
227
+
228
+ #output_image {
229
+ max-height: 500px !important;
230
+ }
231
+
232
+ .primary {
233
+ background: linear-gradient(45deg, #4CAF50, #45a049) !important;
234
+ border: none !important;
235
+ }
236
+
237
+ .primary:hover {
238
+ background: linear-gradient(45deg, #45a049, #4CAF50) !important;
239
+ transform: translateY(-2px) !important;
240
+ }
241
+ """
242
+
243
+ # Create Gradio interface
244
+ with gr.Blocks(css=custom_css, title="πŸ“„ AI Document Scanner", theme=gr.themes.Soft()) as demo:
245
 
246
+ gr.HTML("""
247
+ <div style="text-align: center; margin-bottom: 30px;">
248
+ <h1 style="color: #2E7D32; font-size: 2.5em; margin-bottom: 10px;">πŸ“„ AI Document Scanner</h1>
249
+ <p style="color: #666; font-size: 1.2em;">Professional document scanning with automatic perspective correction, rotation, and enhancement</p>
250
+ </div>
251
+ """)
252
 
253
+ with gr.Row():
254
+ with gr.Column(scale=1):
255
+ gr.HTML("<h3 style='color: #1976D2; text-align: center;'>πŸ“€ Upload Document</h3>")
256
+
257
+ input_image = gr.Image(
258
+ type="pil",
259
+ label="Upload your document photo",
260
+ elem_id="image_upload",
261
+ height=400,
262
+ sources=["upload", "webcam"]
263
+ )
264
+
265
+ scan_btn = gr.Button(
266
+ "πŸ” Scan Document",
267
+ variant="primary",
268
+ size="lg"
269
+ )
270
+
271
+ status_text = gr.Textbox(
272
+ label="πŸ“Š Status",
273
+ value="Ready to scan documents",
274
+ interactive=False,
275
+ lines=2
276
+ )
277
+
278
+ with gr.Column(scale=1):
279
+ gr.HTML("<h3 style='color: #1976D2; text-align: center;'>πŸ“‹ Scanned Result</h3>")
280
+
281
+ output_image = gr.Image(
282
+ type="numpy",
283
+ label="Scanned Document",
284
+ elem_id="output_image",
285
+ height=400
286
+ )
287
+
288
+ download_file = gr.File(
289
+ label="πŸ“₯ Download Scanned Document"
290
+ )
291
 
292
+ # Features section
293
+ gr.HTML("""
294
+ <div style="margin-top: 30px; padding: 20px; background: linear-gradient(135deg, #E8F5E8, #F0F8FF); border-radius: 15px;">
295
+ <h3 style="color: #2E7D32; text-align: center; margin-bottom: 15px;">✨ Key Features</h3>
296
+ <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 15px;">
297
+ <div style="text-align: center;">
298
+ <span style="font-size: 2em;">πŸ”„</span>
299
+ <p><strong>Auto Rotation</strong><br>Automatically detects and corrects orientation</p>
300
+ </div>
301
+ <div style="text-align: center;">
302
+ <span style="font-size: 2em;">πŸ“</span>
303
+ <p><strong>Perspective Correction</strong><br>Straightens tilted and skewed documents</p>
304
+ </div>
305
+ <div style="text-align: center;">
306
+ <span style="font-size: 2em;">βœ‚οΈ</span>
307
+ <p><strong>Smart Cropping</strong><br>Automatically crops to document boundaries</p>
308
+ </div>
309
+ <div style="text-align: center;">
310
+ <span style="font-size: 2em;">✨</span>
311
+ <p><strong>Enhancement</strong><br>Improves contrast and readability</p>
312
+ </div>
313
+ </div>
314
+ </div>
315
+ """)
316
 
317
+ # Set up the scanning function
318
+ scan_btn.click(
319
+ fn=scan_document,
320
+ inputs=[input_image],
321
+ outputs=[output_image, download_file, status_text]
322
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
+ # Launch the app
325
  if __name__ == "__main__":
326
+ demo.launch()