import atexit
import functools
import os
import re
import tempfile
import threading
import traceback
from queue import Queue
from threading import Event, Thread

from flask import Flask, request, jsonify
from paddleocr import PaddleOCR

# NLP analysis function (local module)
from nlp_service import analyze_text

# --- Configuration ---
LANG = 'en'       # Default OCR language, can be overridden if needed
NUM_WORKERS = 2   # Number of OCR worker threads
# --- PaddleOCR Model Manager ---
class PaddleOCRModelManager:
    def __init__(self, num_workers, model_factory):
        super().__init__()
        self._model_factory = model_factory
        self._queue = Queue()
        self._workers = []
        self._model_initialized_event = Event()
        print(f"Initializing {num_workers} OCR worker(s)...")
        for i in range(num_workers):
            print(f"Starting worker {i + 1}...")
            worker = Thread(target=self._worker, daemon=True)
            worker.start()
            self._model_initialized_event.wait()  # Wait for this worker's model to load
            self._model_initialized_event.clear()
            self._workers.append(worker)
        print("All OCR workers initialized.")

    def infer(self, *args, **kwargs):
        result_queue = Queue(maxsize=1)
        self._queue.put((args, kwargs, result_queue))
        success, payload = result_queue.get()
        if success:
            return payload
        print(f"Error during OCR inference: {payload}")
        raise payload

    def close(self):
        print("Shutting down OCR workers...")
        for _ in self._workers:
            self._queue.put(None)  # One shutdown sentinel per worker
        print("OCR worker shutdown signaled.")
    def _worker(self):
        name = threading.current_thread().name
        print(f"Worker thread {name}: Loading PaddleOCR model ({LANG})...")
        try:
            model = self._model_factory()
            print(f"Worker thread {name}: Model loaded.")
            self._model_initialized_event.set()
        except Exception as e:
            print(f"FATAL: Worker thread {name} failed to load model: {e}")
            # Set the event anyway so __init__ does not block forever on a dead worker.
            self._model_initialized_event.set()
            return
        while True:
            item = self._queue.get()
            if item is None:  # Shutdown sentinel
                print(f"Worker thread {name}: Exiting.")
                break
            args, kwargs, result_queue = item
            try:
                result = model.ocr(*args, **kwargs)
                # PaddleOCR returns a list with one entry per page/image.
                if result and result[0]:
                    result_queue.put((True, result[0]))
                else:
                    result_queue.put((True, []))
            except Exception as e:
                print(f"Worker thread {name}: Error processing request: {e}")
                result_queue.put((False, e))
            finally:
                self._queue.task_done()
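
# A minimal usage sketch of the manager (illustrative only; the file path is
# hypothetical, and the factory is like the one defined further below):
#
#   manager = PaddleOCRModelManager(num_workers=1, model_factory=ocr_model_factory)
#   lines = manager.infer("receipt.png", cls=True)  # blocks until a worker replies
#   manager.close()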
# --- Amount Extraction Logic ---
def find_main_amount(ocr_results):
    if not ocr_results:
        return None

    # Matches amounts like "1,234.56" or "12.50", plus bare integers that are
    # not immediately followed by ".<digit>".
    amount_regex = re.compile(
        r'(?<!%)\b\d{1,3}(?:,?\d{3})*(?:\.\d{2})\b|\b\d+\.\d{2}\b|\b\d+\b(?!\.\d{1})'
    )

    # Keywords in decreasing order of priority.
    priority_keywords = ['grand total', 'total amount', 'amount due', 'to pay', 'bill total', 'total payable']
    secondary_keywords = ['total', 'balance', 'net amount', 'paid', 'charge', 'net total']
    lower_priority_keywords = ['subtotal', 'sub total']

    parsed_lines = []
    for i, line_info in enumerate(ocr_results):
        if not line_info or len(line_info) < 2 or len(line_info[1]) < 1:
            continue
        text = line_info[1][0].lower().strip()
        confidence = line_info[1][1]
        numbers_in_line = amount_regex.findall(text)
        float_numbers = []
        for num_str in numbers_in_line:
            try:
                # Skip year-like integers (e.g. "2023") standing alone on short lines.
                if len(text) < 7 and '.' not in num_str and 1900 < int(num_str.replace(',', '')) < 2100:
                    if len(numbers_in_line) == 1 and len(num_str) == 4:
                        continue
                float_numbers.append(float(num_str.replace(',', '')))
            except ValueError:
                continue

        # Check for keywords (whole-word matches only).
        has_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in priority_keywords)
        has_secondary_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in secondary_keywords)
        has_lower_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in lower_priority_keywords)

        parsed_lines.append({
            "index": i,
            "text": text,
            "numbers": float_numbers,
            "has_priority_keyword": has_priority_keyword,
            "has_secondary_keyword": has_secondary_keyword,
            "has_lower_priority_keyword": has_lower_priority_keyword,
            "confidence": confidence,
        })
    # --- Strategy to find the best candidate ---

    # 1. Numbers on the SAME line as a PRIORITY keyword, or on the line
    #    immediately after it.
    priority_candidates = []
    for i, line in enumerate(parsed_lines):
        if line["has_priority_keyword"]:
            if line["numbers"]:
                priority_candidates.extend(line["numbers"])
            # Check the next line if this one has no numbers.
            elif i + 1 < len(parsed_lines) and parsed_lines[i + 1]["numbers"]:
                priority_candidates.extend(parsed_lines[i + 1]["numbers"])
    if priority_candidates:
        # The largest number on/near these lines is usually the final total.
        return max(priority_candidates)

    # 2. Same search for SECONDARY keywords.
    secondary_candidates = []
    for i, line in enumerate(parsed_lines):
        if line["has_secondary_keyword"]:
            if line["numbers"]:
                secondary_candidates.extend(line["numbers"])
            elif i + 1 < len(parsed_lines) and parsed_lines[i + 1]["numbers"]:
                secondary_candidates.extend(parsed_lines[i + 1]["numbers"])
    if secondary_candidates:
        return max(secondary_candidates)

    # 3. Collect subtotal numbers (LOWER PRIORITY keywords, same line or the
    #    one after). These are never returned directly; they only serve to
    #    de-prioritize subtotals in the fallback below.
    subtotal_numbers = {num for line in parsed_lines if line["has_lower_priority_keyword"] for num in line["numbers"]}
    for i, line in enumerate(parsed_lines):
        if line["has_lower_priority_keyword"] and not line["numbers"] and i + 1 < len(parsed_lines):
            subtotal_numbers.update(parsed_lines[i + 1]["numbers"])

    # 4. Fallback: largest plausible number overall, preferring non-subtotals.
    print("Warning: No numbers found on/near priority/secondary keyword lines. Using fallback.")
    all_numbers = []
    for line in parsed_lines:
        all_numbers.extend(line["numbers"])
    if all_numbers:
        unique_numbers = list(set(all_numbers))
        # Keep small decimals; drop large integers (likely IDs or phone
        # numbers) unless they have a non-zero decimal part.
        plausible_numbers = [n for n in unique_numbers if n >= 0.01]
        plausible_numbers = [n for n in plausible_numbers if n < 50000 or n != int(n)]
        non_subtotal_plausible = [n for n in plausible_numbers if n not in subtotal_numbers]
        if non_subtotal_plausible:
            return max(non_subtotal_plausible)
        if plausible_numbers:
            # Only subtotals (or nothing else plausible) were found; return
            # the largest as a last resort.
            return max(plausible_numbers)

    # 5. Nothing usable found.
    print("Warning: Could not determine main amount.")
    return None
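
# Illustrative example (hypothetical OCR lines in PaddleOCR's
# [bbox, (text, confidence)] format):
#
#   lines = [
#       [[[0, 0], [200, 0], [200, 20], [0, 20]], ("Subtotal 10.00", 0.98)],
#       [[[0, 25], [200, 25], [200, 45], [0, 45]], ("Grand Total 11.50", 0.97)],
#   ]
#   find_main_amount(lines)  # -> 11.5 (priority keyword beats the subtotal)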
# --- Flask App Setup ---
app = Flask(__name__)

# --- Initialize OCR Manager ---
ocr_model_factory = functools.partial(PaddleOCR, lang=LANG, use_angle_cls=True, use_gpu=False, show_log=False)
ocr_manager = PaddleOCRModelManager(num_workers=NUM_WORKERS, model_factory=ocr_model_factory)

# Ensure workers are signaled to shut down when the process exits.
atexit.register(ocr_manager.close)
# --- API Endpoint ---
# Route path is an assumption; adjust it to whatever the client expects.
@app.route('/extract-expense', methods=['POST'])
def extract_expense():
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if file:
        temp_file_path = None
        try:
            # Save the upload to a temporary file for PaddleOCR to read.
            _, file_extension = os.path.splitext(file.filename)
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
                file.save(temp_file.name)
                temp_file_path = temp_file.name

            # Perform OCR.
            ocr_result = ocr_manager.infer(temp_file_path, cls=True)

            # Process OCR results.
            extracted_text = ""
            main_amount_ocr = None
            if ocr_result:
                extracted_lines = [line[1][0] for line in ocr_result if line and len(line) > 1 and len(line[1]) > 0]
                extracted_text = "\n".join(extracted_lines)
                main_amount_ocr = find_main_amount(ocr_result)

            # Construct the response (OCR results only; NLP is handled by the
            # message endpoint below).
            response_data = {
                "type": "photo",
                "extracted_text": extracted_text,
                "main_amount_ocr": main_amount_ocr,  # Amount found by the regex logic above
            }
            return jsonify(response_data)
        except Exception as e:
            print(f"Error processing file: {e}")
            traceback.print_exc()
            return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
        finally:
            # Always clean up the temporary file.
            if temp_file_path and os.path.exists(temp_file_path):
                os.remove(temp_file_path)
    return jsonify({"error": "File processing failed"}), 500
# --- NLP Message Endpoint ---
# Route path is an assumption; adjust it to whatever the client expects.
@app.route('/process-message', methods=['POST'])
def process_message():
    data = request.get_json()
    if not data or 'text' not in data:
        return jsonify({"error": "Missing 'text' field in JSON payload"}), 400
    text_message = data['text']
    if not text_message:
        return jsonify({"error": "'text' field cannot be empty"}), 400
    try:
        nlp_analysis_result = analyze_text(text_message)
        print(f"NLP Service Analysis Result: {nlp_analysis_result}")

        # Check whether the NLP analysis itself reported a failure or requires fallback.
        status = nlp_analysis_result.get("status")
        if status == "failed":
            # Client-side errors (e.g. empty text) come back as 400.
            return jsonify(nlp_analysis_result), 400
        elif status == "fallback_required":
            # Return 200 but indicate that the caller should fall back (e.g. for queries).
            return jsonify(nlp_analysis_result), 200

        # Successful analysis.
        return jsonify(nlp_analysis_result)
    except Exception as nlp_e:
        nlp_error = f"Error calling NLP analysis function: {nlp_e}"
        print(f"Error calling NLP function: {nlp_error}")
        return jsonify({"error": "An internal error occurred during NLP processing", "details": nlp_error}), 500
# --- Health Check Endpoint ---
# Route path is an assumption; '/health' is a common convention.
@app.route('/health', methods=['GET'])
def health_check():
    # More checks could be added here (e.g. whether OCR workers are alive).
    return jsonify({"status": "ok"}), 200

# --- Run the App ---
if __name__ == '__main__':
    # Port 7860 is what Hugging Face Spaces expects;
    # host='0.0.0.0' makes the server reachable inside Docker/Spaces.
    app.run(host='0.0.0.0', port=7860, debug=False)
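
# Example requests (paths match the routes assumed above):
#   curl -F "file=@receipt.jpg" http://localhost:7860/extract-expense
#   curl -H "Content-Type: application/json" \
#        -d '{"text": "spent 12.50 on lunch"}' http://localhost:7860/process-message
#   curl http://localhost:7860/health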