#!/usr/bin/env python3
import gc
import json
import logging
import os
from io import BytesIO
from typing import Any, Dict, List, Optional

import boto3
import fitz  # PyMuPDF
import requests
import torch
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("topic_processor.log")
    ]
)
logger = logging.getLogger(__name__)


class S3Writer:
    """Thin wrapper around a boto3 S3 client for uploading raw bytes."""

    def __init__(self, ak: str, sk: str, bucket: str, endpoint_url: str):
        self.bucket = bucket
        self.client = boto3.client(
            "s3",
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url
        )

    def write(self, path: str, data: bytes) -> None:
        try:
            file_obj = BytesIO(data)
            self.client.upload_fileobj(file_obj, self.bucket, path)
            logger.info("Uploaded to S3: %s", path)
        except Exception as e:
            logger.error("Failed to upload to S3: %s", e)
            raise


class S3ImageWriter:
    """Image writer that uploads extracted images to S3 and records their
    S3 paths so markdown image links can be rewritten afterwards."""

    def __init__(self, s3_writer: S3Writer, base_path: str, gemini_api_key: str):
        self.s3_writer = s3_writer
        self.base_path = base_path if base_path.endswith("/") else base_path + "/"
        self.gemini_api_key = gemini_api_key
        self.descriptions: Dict[str, Dict[str, Any]] = {}

    def write(self, path: str, data: bytes) -> None:
        full_path = f"{self.base_path}{os.path.basename(path)}"
        self.s3_writer.write(full_path, data)
        self.descriptions[path] = {
            "data": data,
            "s3_path": full_path
        }

    def post_process(self, prefix: str, md_content: str) -> str:
        """Rewrite local image links (prefix + original path) to their S3 paths."""
        for path, info in self.descriptions.items():
            s3_path = info.get("s3_path")
            md_content = md_content.replace(f"![]({prefix}{path})", f"![]({s3_path})")
        return md_content


def delete_non_heading_text(md_content: str) -> str:
    """Keep only markdown headings and image links; drop all body text."""
    filtered_lines = []
    for line in md_content.splitlines():
        stripped = line.lstrip()
        if stripped.startswith("#") or stripped.startswith("![]("):
            filtered_lines.append(line)
    return "\n".join(filtered_lines)


class TopicExtractionProcessor:
    def __init__(self, gemini_api_key: Optional[str] = None):
        try:
            self.s3_writer = S3Writer(
                ak=os.getenv("S3_ACCESS_KEY"),
                sk=os.getenv("S3_SECRET_KEY"),
                bucket="quextro-resources",
                endpoint_url=os.getenv("S3_ENDPOINT")
            )
            config_path = "/home/user/magic-pdf.json"
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    config = json.load(f)
                self.layout_model = config.get("layout-config", {}).get("model", "doclayout_yolo")
                self.formula_enable = config.get("formula-config", {}).get("enable", True)
            else:
                self.layout_model = "doclayout_yolo"
                self.formula_enable = True
            self.table_enable = False
            self.language = "en"
            # Read the key from the environment; never hard-code credentials in source.
            self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY")
            logger.info("TopicExtractionProcessor initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize TopicExtractionProcessor: %s", str(e))
            raise

    def cleanup_gpu(self):
        """Release Python and CUDA memory between jobs."""
        try:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("GPU memory cleaned up.")
        except Exception as e:
            logger.error("Error during GPU cleanup: %s", e)

    def process(self, input_file: Dict[str, Any]) -> str:
        key = input_file.get("key", "")
        try:
            url = input_file.get("url", "")
            pages = input_file.get("page", [])
            if not url or not pages:
                raise ValueError("Missing required 'url' or 'page' in input file")
            if url.startswith(("http://", "https://")):
                response = requests.get(url, timeout=60)  # avoid hanging on slow hosts
                response.raise_for_status()
                pdf_bytes = response.content
            else:
                with open(url, "rb") as f:
                    pdf_bytes = f.read()
            pages = self.parse_page_range(pages)
            logger.info("Processing %s with pages %s", key, pages)
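            # Build a subset PDF containing only the requested pages, so the
            # layout/OCR analysis below runs on just those pages rather than
            # the whole document.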
            subset_pdf = self.create_subset_pdf(pdf_bytes, pages)
            logger.info("Created subset PDF with size: %d bytes", len(subset_pdf))
            dataset = PymuDocDataset(subset_pdf)
            inference = doc_analyze(
                dataset,
                ocr=True,
                lang=self.language,
                layout_model=self.layout_model,
                formula_enable=self.formula_enable,
                table_enable=self.table_enable
            )
            # Upload images under a per-key prefix; markdown links are emitted
            # with md_prefix and rewritten to S3 paths in post_process().
            base_path = f"/topic-extraction/{key}/"
            writer = S3ImageWriter(self.s3_writer, base_path, self.gemini_api_key)
            md_prefix = "/topic-extraction/"
            pipe_result = inference.pipe_ocr_mode(writer, lang=self.language)
            md_content = pipe_result.get_markdown(md_prefix)
            post_processed = writer.post_process(md_prefix, md_content)
            # Remove non-heading text from the markdown output.
            final_markdown = delete_non_heading_text(post_processed)
            return final_markdown
        except Exception as e:
            logger.error("Processing failed for %s: %s", key, str(e))
            raise
        finally:
            self.cleanup_gpu()

    def create_subset_pdf(self, pdf_bytes: bytes, page_indices: List[int]) -> bytes:
        """Create a PDF subset from the specified (0-indexed) pages."""
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        new_doc = fitz.open()
        try:
            for p in sorted(set(page_indices)):
                if 0 <= p < doc.page_count:
                    new_doc.insert_pdf(doc, from_page=p, to_page=p)
                else:
                    raise ValueError(f"Page index {p} out of range (0-{doc.page_count - 1})")
            return new_doc.tobytes()
        finally:
            new_doc.close()
            doc.close()

    def parse_page_range(self, page_field) -> List[int]:
        """Parse the page field (1-indexed) into a list of 0-indexed page numbers."""
        if isinstance(page_field, list):
            return [int(p) - 1 for p in page_field]
        if isinstance(page_field, str):
            parts = [p.strip() for p in page_field.split(",")]
            return [int(p) - 1 for p in parts]
        raise ValueError("Invalid page field type")


def main():
    """Local test execution without RabbitMQ."""
    test_input = {
        "key": "local_test",
        "url": "/home/user/app/input_output/a-level-pearson-mathematics-specification.pdf",  # local PDF path
        "page": [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42]
    }
    processor = TopicExtractionProcessor()
    try:
        logger.info("Starting test processing.")
        result = processor.process(test_input)
        logger.info("Processing completed successfully")
        print("Markdown:\n", result)
    except Exception as e:
        logger.error("Test failed: %s", str(e))


if __name__ == "__main__":
    main()
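
# A minimal usage sketch for a remote document, following the same input schema
# as main() above; the URL is a hypothetical placeholder. Note that "page" may
# also be given as a comma-separated string, which parse_page_range() converts
# to 0-indexed page numbers.
#
#     processor = TopicExtractionProcessor()
#     markdown = processor.process({
#         "key": "remote_test",
#         "url": "https://example.com/sample-specification.pdf",  # hypothetical
#         "page": "15,16,17",
#     })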