import gradio as gr
from PIL import Image
import os
from IndicPhotoOCR.ocr import OCR  # End-to-end OCR: detection, script ID, recognition
from IndicPhotoOCR.theme import Seafoam
from IndicPhotoOCR.utils.helper import detect_para
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
import torch

from IndicTransToolkit import IndicProcessor

# Initialize the OCR object for text detection and recognition
ocr = OCR(device="cpu", verbose=False)

def translate(given_str, lang="hindi"):
    """Translate `given_str` between English and Hindi using IndicTrans2.

    If `lang` is "english", the input is treated as English and translated
    into Hindi; otherwise it is treated as Hindi and translated into English.
    """
    DEVICE = "cpu"
    model_name = (
        "ai4bharat/indictrans2-en-indic-1B"
        if lang == "english"
        else "ai4bharat/indictrans2-indic-en-1B"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

    ip = IndicProcessor(inference=True)

    model = model.to(DEVICE)
    model.eval()
    src_lang, tgt_lang = (
        ("eng_Latn", "hin_Deva") if lang == "english" else ("hin_Deva", "eng_Latn")
    )

    batch = ip.preprocess_batch(
        [given_str],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    translation = ip.postprocess_batch(generated_tokens, lang=tgt_lang)[0]
    return translation
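

# translate() reloads a ~1B-parameter checkpoint from disk on every call,
# which is slow on CPU. A minimal caching sketch of one way around that
# (an optional, hypothetical helper; not wired into translate() above):
from functools import lru_cache

@lru_cache(maxsize=2)
def load_translation_model(model_name, device="cpu"):
    """Load and cache an IndicTrans2 tokenizer/model pair."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    return tokenizer, model.to(device).eval()
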
def process_image(image):
    """
    Processes the uploaded image for text detection, recognition, and translation.
    - Detects bounding boxes in the image
    - Draws the bounding boxes on the image and identifies the script in each
      detected region
    - Recognizes the text in each region, groups it into paragraphs, and
      translates it

    Parameters:
    image (PIL.Image): The input image to be processed.

    Returns:
    tuple: A PIL.Image with bounding boxes drawn and a string of translated text.
    """
    
    # Save the input image temporarily
    image_path = "input_image.jpg"
    image.save(image_path)
    
    # Detect bounding boxes on the image using OCR
    detections = ocr.detect(image_path)
    
    # Draw bounding boxes on the image and save it as output
    ocr.visualize_detection(image_path, detections, save_path="output_image.png")
    
    # Load the annotated image with bounding boxes drawn
    output_image = Image.open("output_image.png")
    
    # Map each detected region to its recognized text and bounding box
    recognized_texts = {}
    pil_image = Image.open(image_path)
    
    # Process each detected bounding box: identify the script, then recognize the text
    for idx, bbox in enumerate(detections):
        # Identify the script and crop the image to this region
        script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)

        # Reduce the detected polygon to an axis-aligned bounding box
        x1 = min(point[0] for point in bbox)
        y1 = min(point[1] for point in bbox)
        x2 = max(point[0] for point in bbox)
        y2 = max(point[1] for point in bbox)

        if script_lang:  # Only proceed if a script language is identified
            recognized_text = ocr.recognise(cropped_path, script_lang)
            recognized_texts[f"img_{idx}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}

    # Group the recognized words into lines and paragraphs
    paragraphs = detect_para(recognized_texts)
    recognized_texts_combined = "\n".join(" ".join(line) for line in paragraphs)

    # Translate using the script identified for the last detected region;
    # skip translation when nothing was detected
    if recognized_texts:
        recognized_texts_combined = translate(recognized_texts_combined, script_lang)

    return output_image, recognized_texts_combined
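

# Quick manual check for process_image(); a minimal sketch that assumes one of
# the sample images used in `examples` below is present on disk. Guarded by an
# environment flag so importing this module stays side-effect free.
if os.environ.get("RUN_OCR_SMOKE_TEST"):
    annotated, translated = process_image(Image.open("test_images/208.jpg").convert("RGB"))
    annotated.save("smoke_test_output.png")
    print(translated)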

# Custom HTML for interface header with logos and alignment
interface_html = """
<div style="text-align: left; padding: 10px;">
    <div style="background-color: white; padding: 10px; display: inline-block;">
        <img src="https://iitj.ac.in/images/logo/Design-of-New-Logo-of-IITJ-2.png" alt="IITJ Logo" style="width: 100px; height: 100px;">
    </div>
    <img src="https://play-lh.googleusercontent.com/_FXSr4xmhPfBykmNJvKvC0GIAVJmOLhFl6RA5fobCjV-8zVSypxX8yb8ka6zu6-4TEft=w240-h480-rw" alt="Bhashini Logo" style="width: 100px; height: 100px; float: right;">
</div>
"""



# Links to GitHub and Dataset repositories with GitHub icon
links_html = """
<div style="text-align: center; padding-top: 20px;">
    <a href="https://github.com/Bhashini-IITJ/IndicPhotoOCR" target="_blank" style="margin-right: 20px; font-size: 18px; text-decoration: none;">
        GitHub Repository
    </a>
    <a href="https://github.com/Bhashini-IITJ/BharatSceneTextDataset" target="_blank" style="font-size: 18px; text-decoration: none;">
        Dataset Repository
    </a>
</div>
"""

# Custom CSS to style the text box font size
custom_css = """
.custom-textbox textarea {
    font-size: 20px !important;
}
"""

# Create an instance of the Seafoam theme for a consistent visual style
seafoam = Seafoam()

# Define examples for users to try out
examples = [
    ["test_images/208.jpg"],
    ["test_images/1310.jpg"]
]
title = "<h1 style='text-align: center;'>Developed by IITJ</h1>"

# Set up the Gradio Interface with the defined function and customizations
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=[
        gr.Image(type="pil", label="Detected Bounding Boxes"),
        gr.Textbox(label="Translated Text", elem_classes="custom-textbox")
    ],
    title="Scene Text Translator",
    description=title + interface_html + links_html,
    theme=seafoam,
    css=custom_css,
    examples=examples
)

# Server setup and launch configuration
if __name__ == "__main__":
    server = "0.0.0.0"  # Bind to all network interfaces
    port = 7867  # Port to serve the app on
    demo.launch(server_name=server, server_port=port)

# Alternatively, for a quick local run on Gradio's default host and port:
# demo.launch()
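
# Once the server is running, the endpoint can also be exercised
# programmatically; a minimal sketch assuming a recent `gradio_client`
# (installed separately):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7867/")
#   annotated_path, translated = client.predict(handle_file("test_images/208.jpg"))
#   print(translated)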