import gradio as gr
from PIL import Image
import os
from IndicPhotoOCR.ocr import OCR # Ensure OCR class is saved in a file named ocr.py
from IndicPhotoOCR.theme import Seafoam
from IndicPhotoOCR.utils.helper import detect_para
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
import numpy as np
import torch
from IndicTransToolkit import IndicProcessor

# Initialize the OCR object for text detection and recognition
ocr = OCR(device='cpu', verbose=False)

def translate(given_str, lang='hindi'):
    """Translate `given_str` with IndicTrans2: English to Hindi when `lang`
    is "english", otherwise Hindi to English."""
    DEVICE = 'cpu'
    model_name = "ai4bharat/indictrans2-en-indic-1B" if lang == "english" else "ai4bharat/indictrans2-indic-en-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    ip = IndicProcessor(inference=True)
    model = model.to(DEVICE)
    model.eval()

    src_lang, tgt_lang = ("eng_Latn", "hin_Deva") if lang == "english" else ("hin_Deva", "eng_Latn")
    batch = ip.preprocess_batch(
        [given_str],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    translation = ip.postprocess_batch(generated_tokens, lang=tgt_lang)[0]
    return translation
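
# Usage sketch (hypothetical inputs, for illustration only; the first call
# downloads the IndicTrans2 checkpoint from the Hugging Face Hub):
#   translate("this shop sells books", lang="english")  # English -> Hindi
#   translate("यह दुकान किताबें बेचती है", lang="hindi")   # Hindi -> English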

def process_image(image):
    """
    Processes the uploaded image for text detection and recognition.

    - Detects bounding boxes in the image
    - Draws bounding boxes on the image and identifies the script in each detected area
    - Recognizes text in each cropped region, then returns the annotated image
      together with the recognized text, translated via `translate`

    Parameters:
        image (PIL.Image): The input image to be processed.

    Returns:
        tuple: A PIL.Image with bounding boxes and a string of translated text.
    """
    # Save the input image temporarily
    image_path = "input_image.jpg"
    image.save(image_path)

    # Detect bounding boxes on the image using OCR
    detections = ocr.detect(image_path)

    # Draw bounding boxes on the image and save it as output
    ocr.visualize_detection(image_path, detections, save_path="output_image.png")

    # Load the annotated image with bounding boxes drawn
    output_image = Image.open("output_image.png")

    # Map each detected region to its recognized text and bounding box
    recognized_texts = {}
    pil_image = Image.open(image_path)
    # Ensure `script_lang` exists even when nothing is detected
    script_lang = None
    for idx, bbox in enumerate(detections):
        # Identify the script and crop the image to this region
        script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)

        # Axis-aligned bounding box enclosing the detected polygon
        x1 = min(point[0] for point in bbox)
        y1 = min(point[1] for point in bbox)
        x2 = max(point[0] for point in bbox)
        y2 = max(point[1] for point in bbox)

        if script_lang:  # Only proceed if a script language is identified
            # Recognize text in the cropped area
            recognized_text = ocr.recognise(cropped_path, script_lang)
            recognized_texts[f"img_{idx}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}

    # Group the recognized words into reading-order lines for display
    para_lines = detect_para(recognized_texts)
    recognized_texts_combined = '\n'.join(' '.join(line) for line in para_lines)

    # Translate the combined text; `script_lang` holds the script of the
    # last detected region and selects the translation direction
    if script_lang:
        recognized_texts_combined = translate(recognized_texts_combined, script_lang)

    return output_image, recognized_texts_combined
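
# Usage sketch (offline, without the Gradio UI; assumes one of the bundled
# example images is present on disk):
#   img = Image.open("test_images/208.jpg")
#   annotated, translated = process_image(img)
#   annotated.save("annotated.png")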

# Custom HTML for the interface header with logos and alignment
interface_html = """
<div style="text-align: left; padding: 10px;">
    <div style="background-color: white; padding: 10px; display: inline-block;">
        <img src="https://iitj.ac.in/images/logo/Design-of-New-Logo-of-IITJ-2.png" alt="IITJ Logo" style="width: 100px; height: 100px;">
    </div>
    <img src="https://play-lh.googleusercontent.com/_FXSr4xmhPfBykmNJvKvC0GIAVJmOLhFl6RA5fobCjV-8zVSypxX8yb8ka6zu6-4TEft=w240-h480-rw" alt="Bhashini Logo" style="width: 100px; height: 100px; float: right;">
</div>
"""
# Links to GitHub and Dataset repositories with GitHub icon
links_html = """
<div style="text-align: center; padding-top: 20px;">
<a href="https://github.com/Bhashini-IITJ/IndicPhotoOCR" target="_blank" style="margin-right: 20px; font-size: 18px; text-decoration: none;">
GitHub Repository
</a>
<a href="https://github.com/Bhashini-IITJ/BharatSceneTextDataset" target="_blank" style="font-size: 18px; text-decoration: none;">
Dataset Repository
</a>
</div>
"""
# Custom CSS to style the text box font size
custom_css = """
.custom-textbox textarea {
font-size: 20px !important;
}
"""
# Create an instance of the Seafoam theme for a consistent visual style
seafoam = Seafoam()
# Define examples for users to try out
examples = [
    ["test_images/208.jpg"],
    ["test_images/1310.jpg"],
]
title = "<h1 style='text-align: center;'>Developed by IITJ</h1>"
# Set up the Gradio Interface with the defined function and customizations
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=[
        gr.Image(type="pil", label="Detected Bounding Boxes"),
        gr.Textbox(label="Translated Text", elem_classes="custom-textbox"),
    ],
    title="Scene Text Translator",
    description=title + interface_html + links_html,
    theme=seafoam,
    css=custom_css,
    examples=examples,
)

# Server setup and launch configuration (uncomment to bind to a specific host/port)
# if __name__ == "__main__":
#     server = "0.0.0.0"  # IP address for the server
#     port = 7867         # Port to run the server on
#     demo.launch(server_name=server, server_port=port)
demo.launch()