---
library_name: transformers
tags:
- dhivehi
- thaana
- layout-analysis
license: apache-2.0
datasets:
- alakxender/dhivehi-layout-syn-b1-paligemma
language:
- dv
base_model:
- facebook/detr-resnet-50-dc5
---

# DETR ResNet-50 DC5 for Dhivehi Layout-Aware Document Parsing

A fine-tuned DETR (DEtection TRansformer) model based on `facebook/detr-resnet-50-dc5`, trained on a custom COCO-style dataset for layout-aware document understanding in Dhivehi and similar documents. The model detects key structural elements such as headings, authorship, paragraphs, and text lines, with awareness of document reading direction (LTR/RTL).

## Model Summary

- **Base Model:** facebook/detr-resnet-50-dc5
- **Dataset:** Custom COCO-format document layout dataset (`coco-dv-layout`)
- **Categories:** `layout-analysis-QvA6`, `author`, `caption`, `columns`, `date`, `footnote`, `heading`, `paragraph`, `picture`, `textline` (the snippet below shows how to read this label set from the checkpoint)
- **Reading Direction Support:** Left-to-Right (LTR) and Right-to-Left (RTL) documents
- **Backbone:** ResNet-50 DC5
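
To verify the label set shipped with the checkpoint, you can read it directly off the model config; a quick sanity check (assumes network access to the Hub):

```python
from transformers import AutoModelForObjectDetection

# Prints the class-index -> category-name mapping stored in the checkpoint config
model = AutoModelForObjectDetection.from_pretrained("alakxender/detr-resnet-50-dc5-dv-layout-sm1")
print(model.config.id2label)
```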

---

## Usage

### Inference Script

```python
from transformers import pipeline
from PIL import Image
import torch

image = Image.open("ocr.png")

obj_detector = pipeline(
    "object-detection",
    model="alakxender/detr-resnet-50-dc5-dv-layout-sm1",
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    use_fast=True
)

results = obj_detector(image)
print(results)
```
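
The pipeline returns a list of `{'score', 'label', 'box'}` dictionaries. If you prefer the lower-level API over `pipeline`, a roughly equivalent sketch (illustrative, not part of the original scripts; the 0.6 threshold is an arbitrary starting point to tune):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

model_id = "alakxender/detr-resnet-50-dc5-dv-layout-sm1"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForObjectDetection.from_pretrained(model_id)

image = Image.open("ocr.png")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Convert raw logits/boxes into thresholded (score, label, xyxy-box) detections
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
detections = processor.post_process_object_detection(
    outputs, threshold=0.6, target_sizes=target_sizes
)[0]

for score, label, box in zip(detections["scores"], detections["labels"], detections["boxes"]):
    print(model.config.id2label[label.item()], f"{score.item():.2f}", [round(v, 1) for v in box.tolist()])
```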

### Test Script

```python
from transformers import pipeline
import numpy as np
from PIL import Image, ImageDraw
import torch
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--threshold", type=float, default=0.6)
parser.add_argument("--rtl", action="store_true", default=None,
                    help="Process as a right-to-left language document (auto-detected when omitted)")
args = parser.parse_args()

threshold = args.threshold
is_rtl = args.rtl  # None means "auto-detect from the layout"

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device set to use {device}")
if is_rtl is not None:
    print(f"Document direction: {'Right-to-Left' if is_rtl else 'Left-to-Right'}")

image = Image.open("ocr-bill.jpeg")

obj_detector = pipeline(
    "object-detection",
    model="alakxender/detr-resnet-50-dc5-dv-layout-sm1",
    device=device,
    use_fast=True  # use_fast=True avoids the slow-processor warning
)

results = obj_detector(image)
print(results)

# Define colors for different labels
category_colors = {
    "author": (0, 255, 0),      # Green
    "caption": (0, 0, 255),     # Blue
    "columns": (255, 255, 0),   # Yellow
    "date": (255, 0, 255),      # Magenta
    "footnote": (0, 255, 255),  # Cyan
    "heading": (128, 0, 0),     # Dark Red
    "paragraph": (0, 128, 0),   # Dark Green
    "picture": (0, 0, 128),     # Dark Blue
    "textline": (128, 128, 0)   # Olive
}

# Define document element hierarchy (lower value = higher priority)
element_priority = {
    "heading": 1,
    "author": 2,
    "date": 3,
    "columns": 4,
    "paragraph": 5,
    "textline": 6,
    "picture": 7,
    "caption": 8,
    "footnote": 9
}

def detect_text_direction(results, threshold=0.6):
    """
    Attempt to automatically detect whether the document is RTL based on detected text elements.
    This is a heuristic - for production use, consider language detection instead.
    """
    # Filter by confidence threshold
    filtered_results = [r for r in results if r['score'] > threshold]

    # Focus on text elements (textline, paragraph, heading)
    text_elements = [r for r in filtered_results if r['label'] in ['textline', 'paragraph', 'heading']]

    if not text_elements:
        return False  # Default to LTR if no text elements

    # Collect horizontal position info for each text element
    coordinates = []
    for r in text_elements:
        box = list(r['box'].values())
        if len(box) == 4:
            x1, y1, x2, y2 = box
            coordinates.append({
                'xmin': x1,
                'xmax': x2,
                'width': x2 - x1,
                'x_center': (x1 + x2) / 2
            })

    if not coordinates:
        return False  # Default to LTR

    # Analyze the horizontal distribution of elements
    image_width = max(c['xmax'] for c in coordinates)

    # Average center position relative to the image width
    avg_center_position = sum(c['x_center'] for c in coordinates) / len(coordinates)
    relative_position = avg_center_position / image_width

    # Elements skewed toward the right side suggest RTL. This is a simple heuristic -
    # a more sophisticated approach would use OCR or language detection.
    is_rtl_detected = relative_position > 0.55

    print(f"Auto-detected document direction: {'Right-to-Left' if is_rtl_detected else 'Left-to-Right'}")
    print(f"Average element center position: {relative_position:.2f} of document width")

    return is_rtl_detected

def get_reading_order(results, threshold=0.6, rtl=None):
    """
    Sort detection results into natural reading order for both LTR and RTL documents:
    1. First by element priority (headings first)
    2. Then by vertical position (top to bottom)
    3. For elements with similar y-values, by horizontal position based on text direction
    """
    # Filter by confidence threshold
    filtered_results = [r for r in results if r['score'] > threshold]

    # If no manual RTL flag is set, try to auto-detect
    if rtl is None:
        rtl = detect_text_direction(results, threshold)

    # Text elements within ~20 pixels vertically are considered to be on the same row
    y_tolerance = 20

    # Inspect the structure of 'box' to understand its keys
    if filtered_results and 'box' in filtered_results[0]:
        box_keys = filtered_results[0]['box'].keys()
        print(f"Box structure keys: {box_keys}")

        # Pick coordinate keys based on the box format,
        # assuming {'xmin', 'ymin', 'xmax', 'ymax'} or {'left', 'top', 'width', 'height'}
        if 'ymin' in box_keys:
            y_key, height_key = 'ymin', None
            x_key = 'xmin'
        elif 'top' in box_keys:
            y_key, height_key = 'top', 'height'
            x_key = 'left'
        else:
            print("Unknown box format, defaulting to list unpacking")
            y_key, x_key, height_key = None, None, None
    else:
        print("No box format detected, defaulting to list unpacking")
        y_key, x_key, height_key = None, None, None

    # Separate structural (heading-like) elements from content elements
    structural_elements = []
    content_elements = []

    for r in filtered_results:
        if r['label'] in ["heading", "author", "date"]:
            structural_elements.append(r)
        else:
            content_elements.append(r)

    # Coordinate accessors based on the detected box format
    def get_y(element):
        if y_key:
            return element['box'][y_key]
        # Unknown format: assume values() yields [xmin, ymin, xmax, ymax]
        return list(element['box'].values())[1]  # ymin is typically the second value

    def get_x(element):
        if x_key:
            return element['box'][x_key]
        return list(element['box'].values())[0]  # xmin is typically the first value

    def get_x_max(element):  # kept for completeness; not used in the ordering below
        box_values = list(element['box'].values())
        if len(box_values) >= 4:
            return box_values[2]  # xmax is typically the third value
        return get_x(element)  # fallback

    def get_y_center(element):
        if y_key and height_key:
            return element['box'][y_key] + (element['box'][height_key] / 2)
        # List format [xmin, ymin, xmax, ymax]
        box_values = list(element['box'].values())
        return (box_values[1] + box_values[3]) / 2  # (ymin + ymax) / 2

    # Sort structural elements by priority first, then by y position
    sorted_structural = sorted(
        structural_elements,
        key=lambda e: (element_priority.get(e['label'], 999), get_y(e))
    )

    # Group content elements that may be in the same row (similar y-coordinate)
    rows = []
    for element in content_elements:
        y_center = get_y_center(element)

        # Check if this element belongs to an existing row
        found_row = False
        for row in rows:
            row_y_center = sum(get_y_center(e) for e in row) / len(row)
            if abs(y_center - row_y_center) < y_tolerance:
                row.append(element)
                found_row = True
                break

        # Otherwise start a new row
        if not found_row:
            rows.append([element])

    # Sort elements within each row according to reading direction:
    # RTL sorts right-to-left (descending x), LTR left-to-right (ascending x)
    for row in rows:
        row.sort(key=get_x, reverse=bool(rtl))

    # Sort rows by y position (top to bottom)
    rows.sort(key=lambda row: sum(get_y_center(e) for e in row) / len(row))

    # Flatten the rows into a single list
    sorted_content = [element for row in rows for element in row]

    # Structural elements come first, then content in reading order
    return sorted_structural + sorted_content

def plot_results(image, results, threshold=0.6, save_path='output.jpg', rtl=None):
    # Convert to a PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.uint8(image))

    draw = ImageDraw.Draw(image)

    # If rtl is None (not explicitly specified), try to auto-detect
    if rtl is None:
        rtl = detect_text_direction(results, threshold)

    # Get results in reading order
    ordered_results = get_reading_order(results, threshold, rtl)

    # Collect formatted results for the JSON output
    formatted_results = []

    for i, result in enumerate(ordered_results):
        label = result['label']
        box = list(result['box'].values())
        score = result['score']

        # Make sure box has exactly 4 values
        if len(box) == 4:
            x1, y1, x2, y2 = tuple(box)
        else:
            print(f"Warning: Unexpected box format for {label}: {box}")
            continue

        color = category_colors.get(label, (255, 255, 255))  # Default to white if label not found

        # Draw the bounding box
        draw.rectangle((x1, y1, x2, y2), outline=color, width=2)

        # Add an order number to visualize the reading sequence
        draw.text((x1 + 5, y1 - 20), f'#{i+1}', fill=(255, 255, 255))

        # For RTL text elements, add a direction indicator
        if rtl and label in ['textline', 'paragraph', 'heading']:
            draw.text((x1 + 5, y1 - 10), f'{label} (RTL)', fill=color)
            # Draw an arrow showing reading direction (right to left)
            arrow_y = y1 - 5
            draw.line([(x2 - 20, arrow_y), (x1 + 20, arrow_y)], fill=color, width=1)
            draw.polygon([(x1 + 20, arrow_y - 3), (x1 + 20, arrow_y + 3), (x1 + 15, arrow_y)], fill=color)
        else:
            draw.text((x1 + 5, y1 - 10), label, fill=color)

        draw.text((x1 + 5, y1 + 10), f'{score:.2f}', fill='green' if score > 0.7 else 'red')

        # Record the result with its order index
        formatted_results.append({
            "order_index": i,
            "label": label,
            "is_rtl": rtl if label in ['textline', 'paragraph', 'heading'] else False,
            "score": float(score),
            "bbox": {
                "x1": float(x1),
                "y1": float(y1),
                "x2": float(x2),
                "y2": float(y2)
            }
        })

    image.save(save_path)

    # Save results to a JSON file with the direction information
    with open('results.json', 'w') as f:
        json.dump({
            "document_direction": "rtl" if rtl else "ltr",
            "elements": formatted_results
        }, f, indent=2)

    return image

if len(results) > 0:  # Only plot if there are results
    # If the RTL flag was not given on the command line, auto-detect it
    if is_rtl is None:
        is_rtl = detect_text_direction(results, threshold)

    plot_results(image, results, threshold=threshold, rtl=is_rtl)
    print(f"Processing complete. Document interpreted as {'RTL' if is_rtl else 'LTR'}")
else:
    print("No objects detected in the image")
```

---

## Output Example

- **Visual Output:** Bounding boxes with labels and reading-order numbers
- **JSON Output:**

```json
{
  "document_direction": "rtl",
  "elements": [
    {
      "order_index": 0,
      "label": "heading",
      "is_rtl": true,
      "score": 0.97,
      "bbox": {
        "x1": 120.5,
        "y1": 65.2,
        "x2": 620.4,
        "y2": 120.7
      }
    }
  ]
}
```
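
The JSON is straightforward to consume downstream, for example to crop each region for an OCR stage; a minimal sketch (assuming the `results.json` written by the test script above):

```python
import json

# Load the layout saved by the test script and walk it in reading order
with open("results.json") as f:
    layout = json.load(f)

print("Document direction:", layout["document_direction"])
for el in sorted(layout["elements"], key=lambda e: e["order_index"]):
    b = el["bbox"]
    print(f"#{el['order_index'] + 1}: {el['label']} "
          f"(score {el['score']:.2f}) at ({b['x1']:.0f}, {b['y1']:.0f})-({b['x2']:.0f}, {b['y2']:.0f})")
```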

---

## Training Summary

- **Training script:** Uses the Hugging Face `Trainer` API
- **Eval strategy:** `steps`, computing `MeanAveragePrecision` via `torchmetrics` (see the sketch below)
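
For reference, a self-contained illustration of the evaluation metric; the boxes, score, and label id below are made-up values for demonstration, not model output:

```python
import torch
from torchmetrics.detection import MeanAveragePrecision  # torchmetrics>=0.11, needs pycocotools

metric = MeanAveragePrecision()  # expects xyxy boxes by default

# One predicted box vs. one ground-truth box for the same (illustrative) class id
preds = [{
    "boxes": torch.tensor([[120.5, 65.2, 620.4, 120.7]]),
    "scores": torch.tensor([0.97]),
    "labels": torch.tensor([6]),
}]
target = [{
    "boxes": torch.tensor([[118.0, 64.0, 622.0, 122.0]]),
    "labels": torch.tensor([6]),
}]

metric.update(preds, target)
print(metric.compute()["map"])
```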

---