rodrigomasini committed on
Commit
8fccd9c
·
verified ·
1 Parent(s): 763903d

Update mdr_pdf_parser.py

Browse files
Files changed (1) hide show
  1. mdr_pdf_parser.py +179 -23
mdr_pdf_parser.py CHANGED
@@ -2589,8 +2589,8 @@ class MDRExtractionEngine:
2589
  def _get_yolo_model(self) -> YOLOv10 | None:
2590
  """Loads the YOLOv10 layout detection model using hf_hub_download."""
2591
  if self._yolo is None and YOLOv10 is not None:
2592
- repo_id = "juliozhao/DocLayout-YOLO-DocStructBench"
2593
- filename = "doclayout_yolo_docstructbench_imgsz1024.pt"
2594
  # Use a subdirectory within the main model dir for YOLO cache via HF Hub
2595
  yolo_cache_dir = Path(self._model_dir) / "yolo_hf_cache"
2596
  mdr_ensure_directory(str(yolo_cache_dir)) # Ensure cache dir exists
@@ -2684,37 +2684,193 @@ class MDRExtractionEngine:
2684
  return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image,
2685
  adjusted_image=optimizer.adjusted_image)
2686
 
2687
def _run_yolo_detection(self, img: Image, yolo: YOLOv10):
    """Detect layout regions on ``img`` with ``yolo`` and yield layout elements.

    The image is converted to RGB and passed to ``yolo.predict`` (1024 px
    inference size, 0.20 confidence threshold, on ``self._device``). Each
    predicted class id is converted directly into an ``MDRLayoutClass``
    member, so the model's class ids are expected to line up with that type.

    Yields:
        ``MDRTableLayoutElement`` for table boxes, ``MDRFormulaLayoutElement``
        for isolated formulas, and ``MDRPlainLayoutElement`` for the text-like
        classes listed below. Boxes of any other class, and boxes whose
        rectangle area is below 10, produce nothing.
    """
    rgb_image = img.convert("RGB")
    predictions = yolo.predict(
        source=rgb_image,
        imgsz=1024,
        conf=0.20,
        device=self._device,
        verbose=False,
    )

    # Nothing to emit when the model produced no result or an empty box set.
    if not predictions or not predictions[0].boxes:
        return

    # Classes that are represented by a plain (text-carrying) layout element.
    text_like_classes: set[MDRLayoutClass] = {
        MDRLayoutClass.TITLE,
        MDRLayoutClass.PLAIN_TEXT,
        MDRLayoutClass.ABANDON,
        MDRLayoutClass.FIGURE_CAPTION,
        MDRLayoutClass.TABLE_CAPTION,
        MDRLayoutClass.TABLE_FOOTNOTE,
        MDRLayoutClass.FORMULA_CAPTION,
    }

    for class_tensor, corners_tensor in zip(predictions[0].boxes.cls, predictions[0].boxes.xyxy):
        layout_cls = MDRLayoutClass(int(class_tensor))
        left, top, right, bottom = map(float, corners_tensor)
        # Corner order: left-top, right-top, left-bottom, right-bottom.
        rect = MDRRectangle((left, top), (right, top), (left, bottom), (right, bottom))

        # Drop degenerate detections.
        if rect.area < 10:
            continue

        if layout_cls == MDRLayoutClass.TABLE:
            yield MDRTableLayoutElement(rect=rect, fragments=[], parsed=None)
            continue
        if layout_cls == MDRLayoutClass.ISOLATE_FORMULA:
            yield MDRFormulaLayoutElement(rect=rect, fragments=[], latex=None)
            continue
        if layout_cls in text_like_classes:
            yield MDRPlainLayoutElement(cls=layout_cls, rect=rect, fragments=[])
 
 
 
 
 
 
2718
 
2719
  def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[
2720
  MDRLayoutElement]:
 
2589
  def _get_yolo_model(self) -> YOLOv10 | None:
2590
  """Loads the YOLOv10 layout detection model using hf_hub_download."""
2591
  if self._yolo is None and YOLOv10 is not None:
2592
+ repo_id = "hantian/yolo-doclaynet"
2593
+ filename = "yolov10b-doclaynet.pt"
2594
  # Use a subdirectory within the main model dir for YOLO cache via HF Hub
2595
  yolo_cache_dir = Path(self._model_dir) / "yolo_hf_cache"
2596
  mdr_ensure_directory(str(yolo_cache_dir)) # Ensure cache dir exists
 
2684
  return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image,
2685
  adjusted_image=optimizer.adjusted_image)
2686
 
2687
+ # In class MDRExtractionEngine:
2688
+
2689
def _run_yolo_detection(self, img: Image, yolo: Any):  # yolo can be doclayout_yolo.YOLOv10 or ultralytics.YOLO
    """Detect layout regions on ``img`` with ``yolo`` and yield layout elements.

    The image is converted to RGB and passed to ``yolo.predict`` (1024 px
    inference size, 0.25 confidence threshold, on ``self._device``). Predicted
    class ids are resolved to class names through the result's ``names``
    dict, and the names are then mapped to ``MDRLayoutClass`` members.

    Args:
        img: Page image; only its ``convert("RGB")`` method is used here.
        yolo: Detection model exposing a ``predict`` method that returns a
            list of result objects with ``boxes`` (``cls`` / ``xyxy``) and,
            ideally, ``names``.

    Yields:
        ``MDRTableLayoutElement`` for tables, ``MDRFormulaLayoutElement`` for
        isolated formulas, and ``MDRPlainLayoutElement`` for every other
        mapped class (text, titles, captions, footnotes, figures, and
        page headers/footers tagged ``ABANDON``). Boxes whose class name is
        not in the mapping, and boxes with a rectangle area below 10, are
        skipped.
    """
    img_rgb = img.convert("RGB")

    res_list = yolo.predict(source=img_rgb, imgsz=1024, conf=0.25,
                            device=self._device, verbose=False)

    results = res_list[0] if res_list else None
    boxes = getattr(results, "boxes", None)
    if boxes is None:
        print(" Engine: YOLO detection returned no results or no boxes.")
        return

    # Prefer the id -> name table reported alongside the predictions.
    reported_names = getattr(results, "names", None)
    if isinstance(reported_names, dict):
        model_class_names = reported_names
        print(f" Engine: YOLO model class names: {model_class_names}")
    else:
        print(
            " Engine: Warning - Could not automatically get class names from YOLO model. Using predefined fallback mapping.")
        # NOTE(review): this order is an assumed DocLayNet label order and has
        # not been verified against the model actually loaded -- confirm it.
        model_class_names = dict(enumerate((
            'Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header',
            'Picture', 'Section-header', 'Table', 'Text', 'Title',
        )))

    # Single source of truth: detector class name -> MDR layout class.
    name_to_mdr_class = {
        'Text': MDRLayoutClass.PLAIN_TEXT,
        'Title': MDRLayoutClass.TITLE,
        'Section-header': MDRLayoutClass.TITLE,
        'List-item': MDRLayoutClass.PLAIN_TEXT,  # list items are treated as plain text
        'Table': MDRLayoutClass.TABLE,
        'Picture': MDRLayoutClass.FIGURE,
        'Formula': MDRLayoutClass.ISOLATE_FORMULA,
        'Caption': MDRLayoutClass.FIGURE_CAPTION,  # default; a caption may belong to a table
        'Footnote': MDRLayoutClass.TABLE_FOOTNOTE,
        'Page-header': MDRLayoutClass.ABANDON,
        'Page-footer': MDRLayoutClass.ABANDON,
    }

    # Resolve once, outside the box loop: predicted class id -> MDR layout class.
    yolo_id_to_mdr_class = {
        cls_id: name_to_mdr_class[cls_name]
        for cls_id, cls_name in model_class_names.items()
        if cls_name in name_to_mdr_class
    }
    print(f" Engine: Mapping YOLO classes to MDR classes. Effective map used for generation: {yolo_id_to_mdr_class}")

    for cls_id_tensor, xyxy_tensor in zip(boxes.cls, boxes.xyxy):
        mdr_cls = yolo_id_to_mdr_class.get(int(cls_id_tensor))
        if mdr_cls is None:
            # Class id unknown to the name table, or name not mapped above.
            continue

        x1, y1, x2, y2 = map(float, xyxy_tensor)
        rect = MDRRectangle(lt=(x1, y1), rt=(x2, y1), lb=(x1, y2), rb=(x2, y2))
        if rect.area < 10:  # Filter tiny boxes
            continue

        if mdr_cls == MDRLayoutClass.TABLE:
            yield MDRTableLayoutElement(rect=rect, fragments=[], parsed=None, cls=mdr_cls)
        elif mdr_cls == MDRLayoutClass.ISOLATE_FORMULA:
            yield MDRFormulaLayoutElement(rect=rect, fragments=[], latex=None, cls=mdr_cls)
        else:
            # Every remaining mapped class (text, title, captions, footnote,
            # figure, abandon) is carried by a plain layout element.
            yield MDRPlainLayoutElement(cls=mdr_cls, rect=rect, fragments=[])
2874
 
2875
  def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[
2876
  MDRLayoutElement]: