# ClearSpend / app.py
import atexit
import functools
import os
import re
import tempfile
from queue import Queue
from threading import Event, Thread
import threading  # used for threading.current_thread() in worker log messages
from flask import Flask, request, jsonify
from paddleocr import PaddleOCR
from PIL import Image
# Import the NLP analysis function from the local nlp_service module
from nlp_service import analyze_text
# --- Configuration ---
LANG = 'en' # Default language, can be overridden if needed
NUM_WORKERS = 2 # Number of OCR worker threads
# --- PaddleOCR Model Manager ---
class PaddleOCRModelManager:
    def __init__(self, num_workers, model_factory):
super().__init__()
self._model_factory = model_factory
self._queue = Queue()
self._workers = []
self._model_initialized_event = Event()
print(f"Initializing {num_workers} OCR worker(s)...")
for i in range(num_workers):
print(f"Starting worker {i+1}...")
worker = Thread(target=self._worker, daemon=True)
worker.start()
            self._model_initialized_event.wait()   # block until this worker's model loads (or fails)
            self._model_initialized_event.clear()  # reset the shared event for the next worker
self._workers.append(worker)
print("All OCR workers initialized.")
def infer(self, *args, **kwargs):
result_queue = Queue(maxsize=1)
self._queue.put((args, kwargs, result_queue))
success, payload = result_queue.get()
if success:
return payload
else:
print(f"Error during OCR inference: {payload}")
raise payload
def close(self):
print("Shutting down OCR workers...")
for _ in self._workers:
self._queue.put(None)
print("OCR worker shutdown signaled.")
def _worker(self):
print(f"Worker thread {threading.current_thread().name}: Loading PaddleOCR model ({LANG})...")
try:
model = self._model_factory()
print(f"Worker thread {threading.current_thread().name}: Model loaded.")
self._model_initialized_event.set()
except Exception as e:
print(f"FATAL: Worker thread {threading.current_thread().name} failed to load model: {e}")
self._model_initialized_event.set()
return
while True:
item = self._queue.get()
if item is None:
print(f"Worker thread {threading.current_thread().name}: Exiting.")
break
args, kwargs, result_queue = item
try:
result = model.ocr(*args, **kwargs)
if result and result[0]:
result_queue.put((True, result[0]))
else:
result_queue.put((True, []))
except Exception as e:
print(f"Worker thread {threading.current_thread().name}: Error processing request: {e}")
result_queue.put((False, e))
finally:
self._queue.task_done()
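# A minimal usage sketch for the manager (hypothetical path; the real instance,
# `ocr_manager`, is created below with a PaddleOCR factory):
#
#   manager = PaddleOCRModelManager(num_workers=1, model_factory=ocr_model_factory)
#   lines = manager.infer('/tmp/receipt.jpg', cls=True)  # blocks until a worker replies
#   manager.close()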
# --- Amount Extraction Logic ---
def find_main_amount(ocr_results):
if not ocr_results:
return None
    amount_regex = re.compile(r'(?<!%)\b\d{1,3}(?:,?\d{3})*(?:\.\d{2})\b|\b\d+\.\d{2}\b|\b\d+\b(?!\.\d)')
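    # Roughly: the first alternative matches comma-grouped decimals such as
    # '1,234.56', the second plain decimals such as '45.00', and the third bare
    # integers such as '250'; the leading (?<!%) skips digit runs immediately
    # preceded by a percent sign.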
    # Keyword tiers, checked in order of reliability
    priority_keywords = ['grand total', 'total amount', 'amount due', 'to pay', 'bill total', 'total payable']
    secondary_keywords = ['total', 'balance', 'net amount', 'paid', 'charge', 'net total']
    lower_priority_keywords = ['subtotal', 'sub total']
parsed_lines = []
for i, line_info in enumerate(ocr_results):
if not line_info or len(line_info) < 2 or len(line_info[1]) < 1:
continue
text = line_info[1][0].lower().strip()
confidence = line_info[1][1]
numbers_in_line = amount_regex.findall(text)
float_numbers = []
for num_str in numbers_in_line:
            try:
                # Skip year-like values: a four-digit integer in 1900-2100 that is
                # the only number on a short line is almost certainly a date, not an amount.
                if len(text) < 7 and '.' not in num_str and 1900 < int(num_str.replace(',', '')) < 2100:
                    if len(numbers_in_line) == 1 and len(num_str) == 4:
                        continue
                float_numbers.append(float(num_str.replace(',', '')))
except ValueError:
continue
# Check for keywords
has_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in priority_keywords)
has_secondary_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in secondary_keywords)
has_lower_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in lower_priority_keywords)
parsed_lines.append({
"index": i,
"text": text,
"numbers": float_numbers,
"has_priority_keyword": has_priority_keyword,
"has_secondary_keyword": has_secondary_keyword,
"has_lower_priority_keyword": has_lower_priority_keyword,
"confidence": confidence
})
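    # At this point each parsed_lines entry looks roughly like (illustrative values):
    #   {"index": 3, "text": "grand total 99.50", "numbers": [99.5],
    #    "has_priority_keyword": True, "has_secondary_keyword": True,
    #    "has_lower_priority_keyword": False, "confidence": 0.97}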
# --- Strategy to find the best candidate ---
# 1. Look for numbers on the SAME line as PRIORITY keywords OR the line IMMEDIATELY AFTER
priority_candidates = []
for i, line in enumerate(parsed_lines):
if line["has_priority_keyword"]:
if line["numbers"]:
priority_candidates.extend(line["numbers"])
# Check next line if current line has no numbers and next line exists
elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
priority_candidates.extend(parsed_lines[i+1]["numbers"])
if priority_candidates:
# Often the largest number on/near these lines is the final total
return max(priority_candidates)
# 2. Look for numbers on the SAME line as SECONDARY keywords OR the line IMMEDIATELY AFTER
secondary_candidates = []
for i, line in enumerate(parsed_lines):
if line["has_secondary_keyword"]:
if line["numbers"]:
secondary_candidates.extend(line["numbers"])
# Check next line if current line has no numbers and next line exists
elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
secondary_candidates.extend(parsed_lines[i+1]["numbers"])
if secondary_candidates:
# If we only found secondary keywords, return the largest number found on/near those lines
return max(secondary_candidates)
    # 3. Collect numbers on the SAME line as LOWER PRIORITY keywords (subtotal)
    #    OR the line IMMEDIATELY AFTER. Subtotals are never returned directly;
    #    they only inform the fallback below.
    subtotal_numbers = set()
    for i, line in enumerate(parsed_lines):
        if line["has_lower_priority_keyword"]:
            if line["numbers"]:
                subtotal_numbers.update(line["numbers"])
            elif i + 1 < len(parsed_lines):
                subtotal_numbers.update(parsed_lines[i+1]["numbers"])
    # 4. Fallback: largest plausible number overall, preferring non-subtotal values
    print("Warning: No numbers found on/near priority/secondary keyword lines. Using fallback.")
    all_numbers = []
for line in parsed_lines:
all_numbers.extend(line["numbers"])
if all_numbers:
unique_numbers = list(set(all_numbers))
# Filter out potential quantities/years/small irrelevant numbers
plausible_numbers = [n for n in unique_numbers if n >= 0.01] # Keep small decimals too
# Stricter filter for large numbers: exclude large integers (likely IDs, phone numbers)
# Keep numbers < 50000 OR numbers that have a non-zero decimal part
plausible_numbers = [n for n in plausible_numbers if n < 50000 or (n != int(n))]
# If we have plausible numbers other than subtotals, prefer them
non_subtotal_plausible = [n for n in plausible_numbers if n not in subtotal_numbers]
if non_subtotal_plausible:
return max(non_subtotal_plausible)
elif plausible_numbers: # Only subtotals (or nothing else plausible) were found
return max(plausible_numbers) # Return the largest subtotal/plausible as last resort
    # 5. If still nothing, return None
print("Warning: Could not determine main amount.")
return None
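# A quick sanity sketch of find_main_amount on hand-written data shaped like
# PaddleOCR output ([bbox, (text, confidence)] per line; bbox values are dummies):
#
#   sample = [
#       [[[0, 0]], ("Subtotal", 0.99)],
#       [[[0, 0]], ("90.00", 0.98)],
#       [[[0, 0]], ("Grand Total 99.50", 0.99)],
#   ]
#   find_main_amount(sample)  # -> 99.5 via the priority-keyword rule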
# --- Flask App Setup ---
app = Flask(__name__)
# NLP analysis is invoked directly via analyze_text, so no blueprint registration is needed.
# --- Initialize OCR Manager ---
ocr_model_factory = functools.partial(PaddleOCR, lang=LANG, use_angle_cls=True, use_gpu=False, show_log=False)
ocr_manager = PaddleOCRModelManager(num_workers=NUM_WORKERS, model_factory=ocr_model_factory)
# Register cleanup function
atexit.register(ocr_manager.close)
# --- API Endpoint ---
@app.route('/extract_expense', methods=['POST'])
def extract_expense():
if 'file' not in request.files:
return jsonify({"error": "No file part in the request"}), 400
file = request.files['file']
if file.filename == '':
return jsonify({"error": "No selected file"}), 400
if file:
        temp_file_path = None  # so the finally block can safely check it
try:
# Save to a temporary file
_, file_extension = os.path.splitext(file.filename)
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
file.save(temp_file.name)
temp_file_path = temp_file.name
# Perform OCR
ocr_result = ocr_manager.infer(temp_file_path, cls=True)
# Process OCR results
extracted_text = ""
main_amount_ocr = None
if ocr_result:
extracted_lines = [line[1][0] for line in ocr_result if line and len(line) > 1 and len(line[1]) > 0]
extracted_text = "\n".join(extracted_lines)
main_amount_ocr = find_main_amount(ocr_result) # Keep OCR amount extraction
            # NLP analysis of free-text messages is handled by the /message endpoint below.
# Construct the response (only OCR results)
response_data = {
"type": "photo",
"extracted_text": extracted_text,
"main_amount_ocr": main_amount_ocr, # Amount found by OCR regex logic
}
return jsonify(response_data)
except Exception as e:
print(f"Error processing file: {e}")
import traceback
traceback.print_exc()
return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
finally:
if temp_file_path and os.path.exists(temp_file_path):
os.remove(temp_file_path)
return jsonify({"error": "File processing failed"}), 500
# --- NLP Message Endpoint ---
@app.route('/message', methods=['POST'])
def process_message():
data = request.get_json()
if not data or 'text' not in data:
return jsonify({"error": "Missing 'text' field in JSON payload"}), 400
text_message = data['text']
if not text_message:
return jsonify({"error": "'text' field cannot be empty"}), 400
nlp_analysis_result = None
nlp_error = None
try:
        # Call the imported analysis function from nlp_service
        nlp_analysis_result = analyze_text(text_message)
print(f"NLP Service Analysis Result: {nlp_analysis_result}")
# Check if the NLP analysis itself reported an error/failure or requires fallback
status = nlp_analysis_result.get("status")
if status == "failed":
nlp_error = nlp_analysis_result.get("message", "NLP processing failed")
# Return the failure result from NLP service
return jsonify(nlp_analysis_result), 400 # Use 400 for client-side errors like empty text
elif status == "fallback_required":
# Return the fallback result (e.g., for queries)
return jsonify(nlp_analysis_result), 200 # Return 200, but indicate fallback needed
# Return the successful analysis result
return jsonify(nlp_analysis_result)
except Exception as nlp_e:
nlp_error = f"Error calling NLP analysis function: {nlp_e}"
print(f"Error calling NLP function: {nlp_error}")
return jsonify({"error": "An internal error occurred during NLP processing", "details": nlp_error}), 500
# --- Health Check Endpoint ---
@app.route('/health', methods=['GET'])
def health_check():
# You could add more checks here (e.g., if OCR workers are alive)
return jsonify({"status": "ok"}), 200
# --- Run the App ---
if __name__ == '__main__':
# Use port 7860 as expected by Hugging Face Spaces
# Use host='0.0.0.0' for accessibility within Docker/Spaces
app.run(host='0.0.0.0', port=7860, debug=False)