vit model added
Files changed:
- .gitignore +1 -0
- IndicPhotoOCR/detection/east_detector.py +11 -10
- IndicPhotoOCR/ocr.py +35 -14
- IndicPhotoOCR/script_identification/vit/__init__.py +0 -0
- IndicPhotoOCR/script_identification/vit/config.py +58 -0
- IndicPhotoOCR/script_identification/vit/vit_infer.py +213 -0
- IndicPhotoOCR/utils/__init__.py +0 -0
- IndicPhotoOCR/utils/helper.py +220 -0
- app.py +33 -17
- requirements.txt +2 -1
.gitignore
CHANGED
@@ -163,6 +163,7 @@ IndicPhotoOCR/recognition/models
 
 IndicPhotoOCR/script_identification/images
 IndicPhotoOCR/script_identification/models
+IndicPhotoOCR/script_identification/vit/models
 
 
 build/
IndicPhotoOCR/detection/east_detector.py
CHANGED
@@ -74,14 +74,15 @@ class EASTdetector:
 
         return bbox_result_dict
 
-
-
-
-
-
-
-
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description='Text detection using EAST model')
+    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
+    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
+    parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
+    args = parser.parse_args()
 
-    #
-
-
+    # Run prediction and get results as dictionary
+    east = EASTdetector(model_path = args.model_checkpoint)
+    detection_result = east.detect(args.image_path, args.model_checkpoint, args.device)
+    # print(detection_result)
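With this entry point the EAST detector can be exercised on its own from the repository root. A minimal invocation sketch (both paths below are placeholders, not files shipped with this commit):

    python -m IndicPhotoOCR.detection.east_detector --image_path <path_to_image> --model_checkpoint <path_to_east_checkpoint> --device cpu

The checkpoint path is passed both to the constructor and to detect(), mirroring the code above, and the result dictionary is only printed if the commented print line is re-enabled.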
IndicPhotoOCR/ocr.py
CHANGED
@@ -7,11 +7,14 @@ import numpy as np
 
 
 # from IndicPhotoOCR.detection.east_detector import EASTdetector
-from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier
+# from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier
+from IndicPhotoOCR.script_identification.vit.vit_infer import VIT_identifier
 from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser
 import IndicPhotoOCR.detection.east_config as cfg
 from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector
 
+from IndicPhotoOCR.utils.helper import detect_para
+
 
 class OCR:
     def __init__(self, device='cuda:0', verbose=False):
@@ -22,7 +25,8 @@ class OCR:
         # self.detector = EASTdetector()
         self.detector = TextBPNpp_detector(device=self.device)
         self.recogniser = PARseqrecogniser()
-        self.identifier = CLIPidentifier()
+        # self.identifier = CLIPidentifier()
+        self.identifier = VIT_identifier()
 
     # def detect(self, image_path, detect_model_checkpoint=cfg.checkpoint):
     #     """Run the detection model to get bounding boxes of text areas."""
@@ -123,6 +127,7 @@ class OCR:
 
     def ocr(self, image_path):
         """Process the image by detecting text areas, identifying script, and recognizing text."""
+        recognized_texts = {}
        recognized_words = []
         image = Image.open(image_path)
 
@@ -130,25 +135,41 @@ class OCR:
         detections = self.detect(image_path)
 
         # Process each detected text area
-        for bbox in detections:
-            # Crop and identify script language
-            script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
+        # for bbox in detections:
+        #     # Crop and identify script language
+        #     script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
 
-            # Check if the script language is valid
-            if script_lang:
+        #     # Check if the script language is valid
+        #     if script_lang:
+
+        #         # Recognize text
+        #         recognized_word = self.recognise(cropped_path, script_lang)
+        #         recognized_words.append(recognized_word)
+
+        #         if self.verbose:
+        #             print(f"Recognized word: {recognized_word}")
 
-                # Recognize text
-                recognized_word = self.recognise(cropped_path, script_lang)
-                recognized_words.append(recognized_word)
 
-                if self.verbose:
-                    print(f"Recognized word: {recognized_word}")
+        for id, bbox in enumerate(detections):
+            # Identify the script and crop the image to this region
+            script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
+
+            # Calculate bounding box coordinates
+            x1 = min([bbox[i][0] for i in range(len(bbox))])
+            y1 = min([bbox[i][1] for i in range(len(bbox))])
+            x2 = max([bbox[i][0] for i in range(len(bbox))])
+            y2 = max([bbox[i][1] for i in range(len(bbox))])
+
+            if script_lang:
+                recognized_text = self.recognise(cropped_path, script_lang)
+                recognized_texts[f"img_{id}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}
 
-        return recognized_words
+        return detect_para(recognized_texts)
+        # return recognized_words
 
 if __name__ == '__main__':
     # detect_model_checkpoint = 'bharatSTR/East/tmp/epoch_990_checkpoint.pth.tar'
-    sample_image_path = 'test_images/
+    sample_image_path = 'test_images/image_88.jpg'
     cropped_image_path = 'test_images/cropped_image/image_141_0.jpg'
 
     ocr = OCR(device="cuda", verbose=False)
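The practical effect of this change is that OCR.ocr() now returns the output of detect_para, a list of lines where each line is a list of recognized words ordered left to right, instead of the earlier flat word list. A minimal usage sketch (the image path is an example; the ViT script-identification weights are downloaded on first use by ensure_model):

    from IndicPhotoOCR.ocr import OCR

    ocr = OCR(device="cuda", verbose=False)
    lines = ocr.ocr("test_images/image_88.jpg")  # list of lines, each a list of words
    print("\n".join(" ".join(words) for words in lines))

Callers that expected a flat list of words will need to flatten or join the nested structure, as app.py now does.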
IndicPhotoOCR/script_identification/vit/__init__.py
ADDED
File without changes
IndicPhotoOCR/script_identification/vit/config.py
ADDED
@@ -0,0 +1,58 @@
+common_config={
+    'pretrained_vit_model': 'google/vit-base-patch16-224-in21k'
+}
+
+train_config = {
+    'epochs': 20,
+    'max_images_real':1900,
+    'classes':12,
+    'hindi_path_real': '<path_for_hindi_dataset>',
+    'english_path_real':'<path_for_eng_dataset>',
+    'gujarati_path_real':'<path_for_gujarati_dataset>',
+    'punjabi_path_real':'<path_for_punjabi_dataset>',
+    'assamese_path_real':'<path_for_assamese_dataset>',
+    'bengali_path_real':'<path_for_bengali_dataset>',
+    'kannada_path_real':'<path_for_kannada_dataset>',
+    'malayalam_path_real':'<path_for_malayalam_dataset>',
+    'marathi_path_real':'<path_for_marathi_dataset>',
+    'odia_path_real':'<path_for_odia_dataset>',
+    'tamil_path_real':'<path_for_tamil_dataset>',
+    'telugu_path_real':'<path_for_telegu_dataset>',
+    'checkpoints_dir': '<path_for_model>'
+
+}
+train_config.update(common_config)
+
+test_config = {
+    'reload_model': '<path_for_model>',
+    'max_images':2000,
+    'classes':12,
+    'hindi_path_real': '<path_for_hindi_dataset>',
+    'english_path_real':'<path_for_eng_dataset>',
+    'gujarati_path_real':'<path_for_gujarati_dataset>',
+    'punjabi_path_real':'<path_for_punjabi_dataset>',
+    'assamese_path_real':'<path_for_assamese_dataset>',
+    'bengali_path_real':'<path_for_bengali_dataset>',
+    'kannada_path_real':'<path_for_kannada_dataset>',
+    'malayalam_path_real':'<path_for_malayalam_dataset>',
+    'marathi_path_real':'<path_for_marathi_dataset>',
+    'odia_path_real':'<path_for_odia_dataset>',
+    'tamil_path_real':'<path_for_tamil_dataset>',
+    'telugu_path_real':'<path_for_telegu_dataset>',
+
+
+}
+test_config.update(common_config)
+
+
+
+infer_config = {
+    'model_path':'<path_for_model>',
+    'img_path': 'image_path',
+    'folder_path':'<path_dataset_folder>',
+    'csv_path':'<csv_path>',
+}
+
+
+infer_config.update(common_config)
+
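Of these dictionaries, only the 'pretrained_vit_model' key is actually read by the new inference code in this commit: vit_infer.py imports infer_config under the name config and uses it to build the image processor at import time, while the remaining infer_config keys are placeholders for standalone experiments. A minimal sketch of that relationship, assuming the defaults above:

    from IndicPhotoOCR.script_identification.vit.config import infer_config as config

    # resolves to 'google/vit-base-patch16-224-in21k' with the defaults above
    print(config['pretrained_vit_model'])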
IndicPhotoOCR/script_identification/vit/vit_infer.py
ADDED
@@ -0,0 +1,213 @@
+from transformers import AutoImageProcessor,ViTForImageClassification,pipeline
+from PIL import Image
+from datasets import DatasetDict,Dataset,ClassLabel
+import torchvision.transforms as transforms
+import numpy as np
+import csv
+import os
+import argparse
+import requests
+from tqdm import tqdm
+import zipfile
+import time
+import glob
+from IndicPhotoOCR.script_identification.vit.config import infer_config as config
+
+model_info = {
+    "hindi": {
+        "path": "models/hindienglish",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglish.zip",
+        "subcategories": ["hindi", "english"]
+    },
+    "assamese": {
+        "path": "models/hindienglishassamese",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishassamese.zip",
+        "subcategories": ["hindi", "english", "assamese"]
+    },
+    "bengali": {
+        "path": "models/hindienglishbengali",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishbengali.zip",
+        "subcategories": ["hindi", "english", "bengali"]
+    },
+    "gujarati": {
+        "path": "models/hindienglishgujarati",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishgujarati.zip",
+        "subcategories": ["hindi", "english", "gujarati"]
+    },
+    "kannada": {
+        "path": "models/hindienglishkannada",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishkannada.zip",
+        "subcategories": ["hindi", "english", "kannada"]
+    },
+    "malayalam": {
+        "path": "models/hindienglishmalayalam",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmalayalam.zip",
+        "subcategories": ["hindi", "english", "malayalam"]
+    },
+    "marathi": {
+        "path": "models/hindienglishmarathi",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmarathi.zip",
+        "subcategories": ["hindi", "english", "marathi"]
+    },
+    "meitei": {
+        "path": "models/hindienglishmeitei",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmeitei.zip",
+        "subcategories": ["hindi", "english", "meitei"]
+    },
+    "odia": {
+        "path": "models/hindienglishodia",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishodia.zip",
+        "subcategories": ["hindi", "english", "odia"]
+    },
+    "punjabi": {
+        "path": "models/hindienglishpunjabi",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishpunjabi.zip",
+        "subcategories": ["hindi", "english", "punjabi"]
+    },
+    "tamil": {
+        "path": "models/hindienglishtamil",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishtamil.zip",
+        "subcategories": ["hindi", "english", "tamil"]
+    },
+    "telugu": {
+        "path": "models/hindienglishtelugu",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishtelugu.zip",
+        "subcategories": ["hindi", "english", "telugu"]
+    },
+    "12C": {
+        "path": "models/12_classes",
+        "url" : "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/12_classes.zip",
+        "subcategories": ["hindi", "english", "assamese","bengali","gujarati","kannada","malayalam","marathi","odia","punjabi","tamil","telegu"]
+    },
+
+
+}
+
+pretrained_vit_model = config['pretrained_vit_model']
+processor = AutoImageProcessor.from_pretrained(pretrained_vit_model,use_fast=True)
+
+
+class VIT_identifier:
+    def __init__(self):
+        pass
+
+    def unzip_file(self, zip_path, extract_to):
+
+        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extractall(extract_to)
+            print(f"Extracted files to {extract_to}")
+
+
+
+
+    def ensure_model(self, model_name):
+        model_path = model_info[model_name]["path"]
+        url = model_info[model_name]["url"]
+        root_model_dir = "IndicPhotoOCR/script_identification/vit"
+        model_path = os.path.join(root_model_dir, model_path)
+
+        if not os.path.exists(model_path):
+            print(f"Model not found locally. Downloading {model_name} from {url}...")
+
+            response = requests.get(url, stream=True)
+            zip_path = os.path.join(model_path, "temp_download.zip")
+
+            os.makedirs(model_path, exist_ok=True)
+
+            with open(zip_path, "wb") as file:
+                for chunk in response.iter_content(chunk_size=8192):
+                    file.write(chunk)
+
+            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+                zip_ref.extractall(model_path)
+
+            os.remove(zip_path)
+
+            print(f"Downloaded and extracted to {model_path}")
+
+        else:
+            # print(f"Model folder already exists: {model_path}")
+            pass
+
+        return model_path
+
+
+
+
+
+    def identify(self, image_path,model_name):
+        model_path = self.ensure_model(model_name)
+
+        vit = ViTForImageClassification.from_pretrained(model_path)
+        model= pipeline('image-classification', model=vit, feature_extractor=processor,device=0)
+
+        if image_path.endswith((".png", ".jpg", ".jpeg")):
+
+            image = Image.open(image_path)
+            output = model(image)
+            predicted_label = max(output, key=lambda x: x['score'])['label']
+
+            # print(f"image_path: {image_path}, predicted_label: {predicted_label}\n")
+
+            return predicted_label
+
+
+    def predict_batch(self, image_dir,model_name,time_show,output_csv="prediction.csv"):
+        model_path = self.ensure_model(model_name)
+        vit = ViTForImageClassification.from_pretrained(model_path)
+        model= pipeline('image-classification', model=vit, feature_extractor=processor,device=0)
+
+        start_time = time.time()
+        results=[]
+        image_count=0
+        for filename in os.listdir(image_dir):
+
+            if filename.endswith((".png", ".jpg", ".jpeg")):
+                img_path = os.path.join(image_dir, filename)
+                image = Image.open(img_path)
+
+
+                output = model(image)
+                predicted_label = max(output, key=lambda x: x['score'])['label'].capitalize()
+
+                results.append({"Filepath": filename, "Language": predicted_label})
+                image_count+=1
+
+        elapsed_time = time.time() - start_time
+
+        if time_show:
+            print(f"Time taken to process {image_count} images: {elapsed_time:.2f} seconds")
+
+        with open(output_csv, mode="w", newline="", encoding="utf-8") as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=["Filepath", "Language"])
+            writer.writeheader()
+            writer.writerows(results)
+
+        return output_csv
+
+
+# if __name__ == "__main__":
+#     # Argument parser for command line usage
+#     parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
+#     parser.add_argument("--image_path", type=str, help="Path to the input image")
+#     parser.add_argument("--image_dir", type=str, help="Path to the input image directory")
+#     parser.add_argument("--model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)")
+#     parser.add_argument("--batch", action="store_true", help="Process images in batch mode if specified")
+#     parser.add_argument("--time",type=bool, nargs="?", const=True, default=False, help="Prints the time required to process a batch of images")
+
+#     args = parser.parse_args()
+
+
+#     # Choose function based on the batch parameter
+#     if args.batch:
+#         if not args.image_dir:
+#             print("Error: image_dir is required when batch is set to True.")
+#         else:
+#             result = predict_batch(args.image_dir, args.model_name, args.time)
+#             print(result)
+#     else:
+#         if not args.image_path:
+#             print("Error: image_path is required when batch is not set.")
+#         else:
+#             result = predict(args.image_path, args.model_name)
+#             print(result)
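A small usage sketch of the new identifier outside the OCR pipeline (the crop path is the example already used in ocr.py; the first call downloads the selected checkpoint from the GitHub release into IndicPhotoOCR/script_identification/vit/models):

    from IndicPhotoOCR.script_identification.vit.vit_infer import VIT_identifier

    identifier = VIT_identifier()
    # "hindi" selects the two-class hindi/english model listed in model_info
    label = identifier.identify("test_images/cropped_image/image_141_0.jpg", "hindi")
    print(label)

Note that the pipeline is constructed with device=0, so a CUDA device is assumed, and identify() returns None for files whose extension is not .png/.jpg/.jpeg.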
IndicPhotoOCR/utils/__init__.py
ADDED
File without changes
IndicPhotoOCR/utils/helper.py
ADDED
@@ -0,0 +1,220 @@
+import numpy as np
+
+# def detect_para(bbox_dict):
+#     alpha1 = 0.2
+#     alpha2 = 0.7
+#     beta1 = 0.4
+#     data = bbox_dict
+#     word_crops = list(data.keys())
+#     for i in word_crops:
+#         data[i]["x1"], data[i]["y1"], data[i]["x2"], data[i]["y2"] = data[i]["bbox"]
+#         data[i]["xc"] = (data[i]["x1"] + data[i]["x2"]) / 2
+#         data[i]["yc"] = (data[i]["y1"] + data[i]["y2"]) / 2
+#         data[i]["w"] = data[i]["x2"] - data[i]["x1"]
+#         data[i]["h"] = data[i]["y2"] - data[i]["y1"]
+
+#     patch_info = {}
+#     while word_crops:
+#         img_name = word_crops[0].split("_")[0]
+#         word_crop_collection = [
+#             word_crop for word_crop in word_crops if word_crop.startswith(img_name)
+#         ]
+#         centroids = {}
+#         lines = []
+#         img_word_crops = word_crop_collection.copy()
+#         para = []
+#         while img_word_crops:
+#             clusters = []
+#             para_words_group = [
+#                 img_word_crops[0],
+#             ]
+#             added = [
+#                 img_word_crops[0],
+#             ]
+#             img_word_crops.remove(img_word_crops[0])
+#             ## determining the paragraph
+#             while added:
+#                 word_crop = added.pop()
+#                 for i in range(len(img_word_crops)):
+#                     word_crop_ = img_word_crops[i]
+#                     if (
+#                         abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
+#                         < data[word_crop]["h"] * alpha1
+#                     ):
+#                         if data[word_crop]["xc"] > data[word_crop_]["xc"]:
+#                             if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[
+#                                 word_crop
+#                             ]["h"] * alpha2:
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#                         else:
+#                             if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[
+#                                 word_crop
+#                             ]["h"] * alpha2:
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#                     else:
+#                         if data[word_crop]["yc"] > data[word_crop_]["yc"]:
+#                             if (data[word_crop]["y1"] - data[word_crop_]["y2"]) < data[
+#                                 word_crop
+#                             ]["h"] * beta1 and (
+#                                 (
+#                                     (data[word_crop_]["x1"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x1"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop_]["x2"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x2"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop]["x1"] > data[word_crop_]["x1"])
+#                                     and (data[word_crop]["x2"] < data[word_crop_]["x2"])
+#                                 )
+#                             ):
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#                         else:
+#                             if (data[word_crop_]["y1"] - data[word_crop]["y2"]) < data[
+#                                 word_crop
+#                             ]["h"] * beta1 and (
+#                                 (
+#                                     (data[word_crop_]["x1"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x1"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop_]["x2"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x2"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop]["x1"] > data[word_crop_]["x1"])
+#                                     and (data[word_crop]["x2"] < data[word_crop_]["x2"])
+#                                 )
+#                             ):
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#             img_word_crops = [p for p in img_word_crops if p not in para_words_group]
+#             ## processing for the line
+#             while para_words_group:
+#                 line_words_group = [
+#                     para_words_group[0],
+#                 ]
+#                 added = [
+#                     para_words_group[0],
+#                 ]
+#                 para_words_group.remove(para_words_group[0])
+#                 ## determining the line
+#                 while added:
+#                     word_crop = added.pop()
+#                     for i in range(len(para_words_group)):
+#                         word_crop_ = para_words_group[i]
+#                         if (
+#                             abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
+#                             < data[word_crop]["h"] * alpha1
+#                         ):
+#                             if data[word_crop]["xc"] > data[word_crop_]["xc"]:
+#                                 if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[
+#                                     word_crop
+#                                 ]["h"] * alpha2:
+#                                     line_words_group.append(word_crop_)
+#                                     added.append(word_crop_)
+#                             else:
+#                                 if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[
+#                                     word_crop
+#                                 ]["h"] * alpha2:
+#                                     line_words_group.append(word_crop_)
+#                                     added.append(word_crop_)
+#                 para_words_group = [
+#                     p for p in para_words_group if p not in line_words_group
+#                 ]
+#                 xc = [data[word_crop]["xc"] for word_crop in line_words_group]
+#                 idxs = np.argsort(xc)
+#                 patch_cluster_ = [line_words_group[i] for i in idxs]
+#                 line_words_group = patch_cluster_
+#                 x1 = [data[word_crop]["x1"] for word_crop in line_words_group]
+#                 x2 = [data[word_crop]["x2"] for word_crop in line_words_group]
+#                 y1 = [data[word_crop]["y1"] for word_crop in line_words_group]
+#                 y2 = [data[word_crop]["y2"] for word_crop in line_words_group]
+#                 txt_line = [data[word_crop]["txt"] for word_crop in line_words_group]
+#                 txt = " ".join(txt_line)
+#                 x = [x1[0]]
+#                 y1_ = [y1[0]]
+#                 y2_ = [y2[0]]
+#                 l = [len(txt_l) for txt_l in txt_line]
+#                 for i in range(1, len(x1)):
+#                     x.append((x1[i] + x2[i - 1]) / 2)
+#                     y1_.append((y1[i] + y1[i - 1]) / 2)
+#                     y2_.append((y2[i] + y2[i - 1]) / 2)
+#                 x.append(x2[-1])
+#                 y1_.append(y1[-1])
+#                 y2_.append(y2[-1])
+#                 line_info = {
+#                     "x": x,
+#                     "y1": y1_,
+#                     "y2": y2_,
+#                     "l": l,
+#                     "txt": txt,
+#                     "word_crops": line_words_group,
+#                 }
+#                 clusters.append(line_info)
+#             y_ = [clusters[i]["y1"][0] for i in range(len(clusters))]
+#             idxs = np.argsort(y_)
+#             clusters_ = [clusters[i] for i in idxs]
+#             txt = [clusters[i]["txt"] for i in idxs]
+#             l = [len(t) for t in txt]
+#             txt = " ".join(txt)
+#             para_info = {"lines": clusters_, "l": l, "txt": txt}
+#             para.append(para_info)
+
+#         for word_crop in word_crop_collection:
+#             word_crops.remove(word_crop)
+#     return "\n".join([para[i]["txt"] for i in range(len(para))])
+
+
+def detect_para(recognized_texts):
+    """
+    Sort words into lines based on horizontal overlap of bounding boxes.
+
+    Args:
+        recognized_texts (dict): A dictionary with recognized texts as keys and bounding boxes as values.
+                                 Each bounding box is a list of points [x1, y1, x2, y2].
+
+    Returns:
+        list: A list of lists where each sublist contains words sorted by x-coordinate for a single line.
+    """
+    def calculate_overlap(bbox1, bbox2):
+        """Calculate the vertical overlap between two bounding boxes."""
+        # Extract bounding box coordinates
+        x1_1, y1_1, x2_1, y2_1 = bbox1
+        x1_2, y1_2, x2_2, y2_2 = bbox2
+
+        overlap = max(0, min(y2_1, y2_2) - max(y1_1, y1_2))
+        height = min(y2_1 - y1_1, y2_2 - y1_2)
+        return overlap / height if height > 0 else 0
+
+    # Convert recognized_texts dictionary to a list of tuples for processing
+    items = list(recognized_texts.items())
+    lines = []
+
+    while items:
+        current_image, current_data = items.pop(0)
+        current_text, current_bbox = current_data['txt'], current_data['bbox']
+        current_line = [(current_text, current_bbox)]
+
+        remaining_items = []
+        for image, data in items:
+            text, bbox = data['txt'], data['bbox']
+            if calculate_overlap(current_bbox, bbox) > 0.4:
+                current_line.append((text, bbox))
+            else:
+                remaining_items.append((image, data))
+
+        items = remaining_items
+        lines.append(current_line)
+
+    # Sort words within each line based on x1 (horizontal position)
+    sorted_lines = [
+        [text for text, bbox in sorted(line, key=lambda x: x[1][0])] for line in lines
+    ]
+    return sorted_lines
+
+
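The new detect_para groups words into lines by vertical overlap of their boxes (two words share a line when their boxes overlap vertically by more than 40% of the shorter box height) and then orders each line by its left x coordinate. A small example with made-up boxes:

    from IndicPhotoOCR.utils.helper import detect_para

    recognized_texts = {
        "img_0": {"txt": "Hello", "bbox": [10, 10, 60, 30]},
        "img_1": {"txt": "world", "bbox": [70, 12, 130, 32]},
        "img_2": {"txt": "below", "bbox": [10, 50, 70, 70]},
    }
    print(detect_para(recognized_texts))
    # [['Hello', 'world'], ['below']]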
app.py
CHANGED
@@ -1,12 +1,12 @@
-# This is a working demo - textbpn++ - CLIP - Parseq
 import gradio as gr
 from PIL import Image
 import os
 from IndicPhotoOCR.ocr import OCR # Ensure OCR class is saved in a file named ocr.py
 from IndicPhotoOCR.theme import Seafoam
+from IndicPhotoOCR.utils.helper import detect_para
 
 # Initialize the OCR object for text detection and recognition
-ocr = OCR(
+ocr = OCR(verbose=False)
 
 def process_image(image):
     """
@@ -36,21 +36,37 @@ def process_image(image):
     output_image = Image.open("output_image.png")
 
     # Initialize list to hold recognized text from each detected area
-    recognized_texts = []
+    recognized_texts = {}
     pil_image = Image.open(image_path)
 
-    # Process each detected bounding box for script identification and text recognition
-    for bbox in detections:
+    # # Process each detected bounding box for script identification and text recognition
+    # for bbox in detections:
+    #     # Identify the script and crop the image to this region
+    #     script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
+
+    #     if script_lang: # Only proceed if a script language is identified
+    #         # Recognize text in the cropped area
+    #         recognized_text = ocr.recognise(cropped_path, script_lang)
+    #         recognized_texts.append(recognized_text)
+    for id, bbox in enumerate(detections):
         # Identify the script and crop the image to this region
         script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
-
-        if script_lang: # Only proceed if a script language is identified
-            # Recognize text in the cropped area
+
+        # Calculate bounding box coordinates
+        x1 = min([bbox[i][0] for i in range(len(bbox))])
+        y1 = min([bbox[i][1] for i in range(len(bbox))])
+        x2 = max([bbox[i][0] for i in range(len(bbox))])
+        y2 = max([bbox[i][1] for i in range(len(bbox))])
+
+        if script_lang:
             recognized_text = ocr.recognise(cropped_path, script_lang)
-            recognized_texts.append(recognized_text)
-
+            recognized_texts[f"img_{id}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}
+
     # Combine recognized texts into a single string for display
-    recognized_texts_combined = " ".join(recognized_texts)
+    # recognized_texts_combined = " ".join(recognized_texts)
+    string = detect_para(recognized_texts)
+    recognized_texts_combined = '\n'.join([' '.join(line) for line in string])
+
     return output_image, recognized_texts_combined
 
 # Custom HTML for interface header with logos and alignment
@@ -110,10 +126,10 @@ demo = gr.Interface(
     examples=examples
 )
 
-#
-
-
-
-
+# Server setup and launch configuration
+if __name__ == "__main__":
+    server = "0.0.0.0" # IP address for server
+    port = 7865 # Port to run the server on
+    demo.launch(server_name=server, server_port=port, share=True)
 
-demo.launch()
+# demo.launch()
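With the new __main__ guard, the demo can still be started locally from the repository root; the host 0.0.0.0 and port 7865 come from the block above, and share=True additionally requests a temporary public Gradio link:

    python app.py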
requirements.txt
CHANGED
@@ -43,4 +43,5 @@ torch==2.5.0
 torchvision==0.20.0
 easydict==1.13
 scipy==1.13.1
-
+transformers==4.45.1
+datasets==3.1.0