anikde committed
Commit 75b1563 · 1 Parent(s): 1b11f07

vit model added

.gitignore CHANGED
@@ -163,6 +163,7 @@ IndicPhotoOCR/recognition/models
 
 IndicPhotoOCR/script_identification/images
 IndicPhotoOCR/script_identification/models
+IndicPhotoOCR/script_identification/vit/models
 
 
 build/
IndicPhotoOCR/detection/east_detector.py CHANGED
@@ -74,14 +74,15 @@ class EASTdetector:
 
         return bbox_result_dict
 
-# if __name__ == "__main__":
-#     import argparse
-#     parser = argparse.ArgumentParser(description='Text detection using EAST model')
-#     parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
-#     parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
-#     parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
-#     args = parser.parse_args()
 
-#     # Run prediction and get results as dictionary
-#     detection_result = predict(args.image_path, args.device, args.model_checkpoint)
-#     print(detection_result)
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description='Text detection using EAST model')
+    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
+    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
+    parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
+    args = parser.parse_args()
 
+    # Run prediction and get results as dictionary
+    east = EASTdetector(model_path=args.model_checkpoint)
+    detection_result = east.detect(args.image_path, args.model_checkpoint, args.device)
+    # print(detection_result)
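For reference, the re-enabled entry point can also be driven programmatically; a minimal sketch, assuming a locally available EAST checkpoint (the path below is a placeholder) and the constructor/detect signature used in the __main__ block above:

    from IndicPhotoOCR.detection.east_detector import EASTdetector

    # Placeholder checkpoint path; substitute a real EAST checkpoint file.
    checkpoint = "models/east_epoch_990_checkpoint.pth.tar"
    east = EASTdetector(model_path=checkpoint)
    bbox_result_dict = east.detect("test_images/image_88.jpg", checkpoint, "cpu")
    print(bbox_result_dict)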
IndicPhotoOCR/ocr.py CHANGED
@@ -7,11 +7,14 @@ import numpy as np
 
 
 # from IndicPhotoOCR.detection.east_detector import EASTdetector
-from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier
+# from IndicPhotoOCR.script_identification.CLIP_identifier import CLIPidentifier
+from IndicPhotoOCR.script_identification.vit.vit_infer import VIT_identifier
 from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser
 import IndicPhotoOCR.detection.east_config as cfg
 from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector
 
+from IndicPhotoOCR.utils.helper import detect_para
+
 
 class OCR:
     def __init__(self, device='cuda:0', verbose=False):
@@ -22,7 +25,8 @@ class OCR:
         # self.detector = EASTdetector()
         self.detector = TextBPNpp_detector(device=self.device)
         self.recogniser = PARseqrecogniser()
-        self.identifier = CLIPidentifier()
+        # self.identifier = CLIPidentifier()
+        self.identifier = VIT_identifier()
 
     # def detect(self, image_path, detect_model_checkpoint=cfg.checkpoint):
     #     """Run the detection model to get bounding boxes of text areas."""
@@ -123,6 +127,7 @@ class OCR:
 
     def ocr(self, image_path):
        """Process the image by detecting text areas, identifying script, and recognizing text."""
+        recognized_texts = {}
        recognized_words = []
        image = Image.open(image_path)
 
@@ -130,25 +135,41 @@
        detections = self.detect(image_path)
 
        # Process each detected text area
-       for bbox in detections:
-           # Crop and identify script language
-           script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
+       # for bbox in detections:
+       #     # Crop and identify script language
+       #     script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
 
-           # Check if the script language is valid
-           if script_lang:
+       #     # Check if the script language is valid
+       #     if script_lang:
+
+       #         # Recognize text
+       #         recognized_word = self.recognise(cropped_path, script_lang)
+       #         recognized_words.append(recognized_word)
+
+       #         if self.verbose:
+       #             print(f"Recognized word: {recognized_word}")
 
-               # Recognize text
-               recognized_word = self.recognise(cropped_path, script_lang)
-               recognized_words.append(recognized_word)
+       for id, bbox in enumerate(detections):
+           # Identify the script and crop the image to this region
+           script_lang, cropped_path = self.crop_and_identify_script(image, bbox)
+
+           # Calculate bounding box coordinates
+           x1 = min([bbox[i][0] for i in range(len(bbox))])
+           y1 = min([bbox[i][1] for i in range(len(bbox))])
+           x2 = max([bbox[i][0] for i in range(len(bbox))])
+           y2 = max([bbox[i][1] for i in range(len(bbox))])
+
+           if script_lang:
+               recognized_text = self.recognise(cropped_path, script_lang)
+               recognized_texts[f"img_{id}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}
 
-           if self.verbose:
-               print(f"Recognized word: {recognized_word}")
-
-       return recognized_words
+       return detect_para(recognized_texts)
+       # return recognized_words
 
 if __name__ == '__main__':
     # detect_model_checkpoint = 'bharatSTR/East/tmp/epoch_990_checkpoint.pth.tar'
-    sample_image_path = 'test_images/image_141.jpg'
+    sample_image_path = 'test_images/image_88.jpg'
     cropped_image_path = 'test_images/cropped_image/image_141_0.jpg'
 
     ocr = OCR(device="cuda", verbose=False)
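A minimal end-to-end sketch of the updated pipeline, assuming the package and the downloaded model weights are available locally; note that ocr() now returns the detect_para() output (a list of lines, each a list of words) rather than a flat word list:

    from IndicPhotoOCR.ocr import OCR

    ocr = OCR(device="cuda:0", verbose=False)
    lines = ocr.ocr("test_images/image_88.jpg")   # e.g. [['word1', 'word2'], ['word3']]
    print("\n".join(" ".join(words) for words in lines))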
IndicPhotoOCR/script_identification/vit/__init__.py ADDED
File without changes
IndicPhotoOCR/script_identification/vit/config.py ADDED
@@ -0,0 +1,58 @@
+common_config = {
+    'pretrained_vit_model': 'google/vit-base-patch16-224-in21k'
+}
+
+train_config = {
+    'epochs': 20,
+    'max_images_real': 1900,
+    'classes': 12,
+    'hindi_path_real': '<path_for_hindi_dataset>',
+    'english_path_real': '<path_for_eng_dataset>',
+    'gujarati_path_real': '<path_for_gujarati_dataset>',
+    'punjabi_path_real': '<path_for_punjabi_dataset>',
+    'assamese_path_real': '<path_for_assamese_dataset>',
+    'bengali_path_real': '<path_for_bengali_dataset>',
+    'kannada_path_real': '<path_for_kannada_dataset>',
+    'malayalam_path_real': '<path_for_malayalam_dataset>',
+    'marathi_path_real': '<path_for_marathi_dataset>',
+    'odia_path_real': '<path_for_odia_dataset>',
+    'tamil_path_real': '<path_for_tamil_dataset>',
+    'telugu_path_real': '<path_for_telegu_dataset>',
+    'checkpoints_dir': '<path_for_model>'
+}
+train_config.update(common_config)
+
+test_config = {
+    'reload_model': '<path_for_model>',
+    'max_images': 2000,
+    'classes': 12,
+    'hindi_path_real': '<path_for_hindi_dataset>',
+    'english_path_real': '<path_for_eng_dataset>',
+    'gujarati_path_real': '<path_for_gujarati_dataset>',
+    'punjabi_path_real': '<path_for_punjabi_dataset>',
+    'assamese_path_real': '<path_for_assamese_dataset>',
+    'bengali_path_real': '<path_for_bengali_dataset>',
+    'kannada_path_real': '<path_for_kannada_dataset>',
+    'malayalam_path_real': '<path_for_malayalam_dataset>',
+    'marathi_path_real': '<path_for_marathi_dataset>',
+    'odia_path_real': '<path_for_odia_dataset>',
+    'tamil_path_real': '<path_for_tamil_dataset>',
+    'telugu_path_real': '<path_for_telegu_dataset>',
+}
+test_config.update(common_config)
+
+
+infer_config = {
+    'model_path': '<path_for_model>',
+    'img_path': 'image_path',
+    'folder_path': '<path_dataset_folder>',
+    'csv_path': '<csv_path>',
+}
+
+infer_config.update(common_config)
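Since every *_config dict pulls in common_config via update(), downstream code can read the shared ViT backbone name from any of them; a small sketch (the placeholder paths still need to be filled in before training or evaluation):

    from IndicPhotoOCR.script_identification.vit.config import infer_config, train_config

    print(infer_config['pretrained_vit_model'])   # 'google/vit-base-patch16-224-in21k'
    print(train_config['pretrained_vit_model'])   # same value, merged from common_config
    print(infer_config['model_path'])             # '<path_for_model>' placeholder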
IndicPhotoOCR/script_identification/vit/vit_infer.py ADDED
@@ -0,0 +1,213 @@
+from transformers import AutoImageProcessor, ViTForImageClassification, pipeline
+from PIL import Image
+from datasets import DatasetDict, Dataset, ClassLabel
+import torchvision.transforms as transforms
+import numpy as np
+import csv
+import os
+import argparse
+import requests
+from tqdm import tqdm
+import zipfile
+import time
+import glob
+from IndicPhotoOCR.script_identification.vit.config import infer_config as config
+
+model_info = {
+    "hindi": {
+        "path": "models/hindienglish",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglish.zip",
+        "subcategories": ["hindi", "english"]
+    },
+    "assamese": {
+        "path": "models/hindienglishassamese",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishassamese.zip",
+        "subcategories": ["hindi", "english", "assamese"]
+    },
+    "bengali": {
+        "path": "models/hindienglishbengali",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishbengali.zip",
+        "subcategories": ["hindi", "english", "bengali"]
+    },
+    "gujarati": {
+        "path": "models/hindienglishgujarati",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishgujarati.zip",
+        "subcategories": ["hindi", "english", "gujarati"]
+    },
+    "kannada": {
+        "path": "models/hindienglishkannada",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishkannada.zip",
+        "subcategories": ["hindi", "english", "kannada"]
+    },
+    "malayalam": {
+        "path": "models/hindienglishmalayalam",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmalayalam.zip",
+        "subcategories": ["hindi", "english", "malayalam"]
+    },
+    "marathi": {
+        "path": "models/hindienglishmarathi",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmarathi.zip",
+        "subcategories": ["hindi", "english", "marathi"]
+    },
+    "meitei": {
+        "path": "models/hindienglishmeitei",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishmeitei.zip",
+        "subcategories": ["hindi", "english", "meitei"]
+    },
+    "odia": {
+        "path": "models/hindienglishodia",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishodia.zip",
+        "subcategories": ["hindi", "english", "odia"]
+    },
+    "punjabi": {
+        "path": "models/hindienglishpunjabi",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishpunjabi.zip",
+        "subcategories": ["hindi", "english", "punjabi"]
+    },
+    "tamil": {
+        "path": "models/hindienglishtamil",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishtamil.zip",
+        "subcategories": ["hindi", "english", "tamil"]
+    },
+    "telugu": {
+        "path": "models/hindienglishtelugu",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/hindienglishtelugu.zip",
+        "subcategories": ["hindi", "english", "telugu"]
+    },
+    "12C": {
+        "path": "models/12_classes",
+        "url": "https://github.com/Bhashini-IITJ/ScriptIdentification/releases/download/Vit_Models/12_classes.zip",
+        "subcategories": ["hindi", "english", "assamese", "bengali", "gujarati", "kannada", "malayalam", "marathi", "odia", "punjabi", "tamil", "telegu"]
+    },
+}
+
+pretrained_vit_model = config['pretrained_vit_model']
+processor = AutoImageProcessor.from_pretrained(pretrained_vit_model, use_fast=True)
+
+
+class VIT_identifier:
+    def __init__(self):
+        pass
+
+    def unzip_file(self, zip_path, extract_to):
+        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extractall(extract_to)
+        print(f"Extracted files to {extract_to}")
+
+    def ensure_model(self, model_name):
+        model_path = model_info[model_name]["path"]
+        url = model_info[model_name]["url"]
+        root_model_dir = "IndicPhotoOCR/script_identification/vit"
+        model_path = os.path.join(root_model_dir, model_path)
+
+        if not os.path.exists(model_path):
+            print(f"Model not found locally. Downloading {model_name} from {url}...")
+
+            response = requests.get(url, stream=True)
+            zip_path = os.path.join(model_path, "temp_download.zip")
+
+            os.makedirs(model_path, exist_ok=True)
+
+            with open(zip_path, "wb") as file:
+                for chunk in response.iter_content(chunk_size=8192):
+                    file.write(chunk)
+
+            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+                zip_ref.extractall(model_path)
+
+            os.remove(zip_path)
+
+            print(f"Downloaded and extracted to {model_path}")
+        else:
+            # print(f"Model folder already exists: {model_path}")
+            pass
+
+        return model_path
+
+    def identify(self, image_path, model_name):
+        model_path = self.ensure_model(model_name)
+
+        vit = ViTForImageClassification.from_pretrained(model_path)
+        model = pipeline('image-classification', model=vit, feature_extractor=processor, device=0)
+
+        if image_path.endswith((".png", ".jpg", ".jpeg")):
+            image = Image.open(image_path)
+            output = model(image)
+            predicted_label = max(output, key=lambda x: x['score'])['label']
+
+            # print(f"image_path: {image_path}, predicted_label: {predicted_label}\n")
+
+            return predicted_label
+
+    def predict_batch(self, image_dir, model_name, time_show, output_csv="prediction.csv"):
+        model_path = self.ensure_model(model_name)
+        vit = ViTForImageClassification.from_pretrained(model_path)
+        model = pipeline('image-classification', model=vit, feature_extractor=processor, device=0)
+
+        start_time = time.time()
+        results = []
+        image_count = 0
+        for filename in os.listdir(image_dir):
+            if filename.endswith((".png", ".jpg", ".jpeg")):
+                img_path = os.path.join(image_dir, filename)
+                image = Image.open(img_path)
+
+                output = model(image)
+                predicted_label = max(output, key=lambda x: x['score'])['label'].capitalize()
+
+                results.append({"Filepath": filename, "Language": predicted_label})
+                image_count += 1
+
+        elapsed_time = time.time() - start_time
+
+        if time_show:
+            print(f"Time taken to process {image_count} images: {elapsed_time:.2f} seconds")
+
+        with open(output_csv, mode="w", newline="", encoding="utf-8") as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=["Filepath", "Language"])
+            writer.writeheader()
+            writer.writerows(results)
+
+        return output_csv
+
+
+# if __name__ == "__main__":
+#     # Argument parser for command line usage
+#     parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
+#     parser.add_argument("--image_path", type=str, help="Path to the input image")
+#     parser.add_argument("--image_dir", type=str, help="Path to the input image directory")
+#     parser.add_argument("--model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)")
+#     parser.add_argument("--batch", action="store_true", help="Process images in batch mode if specified")
+#     parser.add_argument("--time", type=bool, nargs="?", const=True, default=False, help="Prints the time required to process a batch of images")
+
+#     args = parser.parse_args()
+
+#     # Choose function based on the batch parameter
+#     if args.batch:
+#         if not args.image_dir:
+#             print("Error: image_dir is required when batch is set to True.")
+#         else:
+#             result = predict_batch(args.image_dir, args.model_name, args.time)
+#             print(result)
+#     else:
+#         if not args.image_path:
+#             print("Error: image_path is required when batch is not set.")
+#         else:
+#             result = predict(args.image_path, args.model_name)
+#             print(result)
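A hedged usage sketch for the new identifier: identify() downloads the requested model bundle from the release URL on first use, then returns the predicted script label for a single word crop. The crop path below is a placeholder, and note that the pipeline is created with device=0, so a CUDA device is expected:

    from IndicPhotoOCR.script_identification.vit.vit_infer import VIT_identifier

    identifier = VIT_identifier()
    # "hindi" selects the hindi/english bundle listed in model_info above.
    label = identifier.identify("test_images/cropped_image/image_141_0.jpg", "hindi")
    print(label)   # e.g. 'hindi' or 'english'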
IndicPhotoOCR/utils/__init__.py ADDED
File without changes
IndicPhotoOCR/utils/helper.py ADDED
@@ -0,0 +1,220 @@
+import numpy as np
+
+# def detect_para(bbox_dict):
+#     alpha1 = 0.2
+#     alpha2 = 0.7
+#     beta1 = 0.4
+#     data = bbox_dict
+#     word_crops = list(data.keys())
+#     for i in word_crops:
+#         data[i]["x1"], data[i]["y1"], data[i]["x2"], data[i]["y2"] = data[i]["bbox"]
+#         data[i]["xc"] = (data[i]["x1"] + data[i]["x2"]) / 2
+#         data[i]["yc"] = (data[i]["y1"] + data[i]["y2"]) / 2
+#         data[i]["w"] = data[i]["x2"] - data[i]["x1"]
+#         data[i]["h"] = data[i]["y2"] - data[i]["y1"]
+
+#     patch_info = {}
+#     while word_crops:
+#         img_name = word_crops[0].split("_")[0]
+#         word_crop_collection = [
+#             word_crop for word_crop in word_crops if word_crop.startswith(img_name)
+#         ]
+#         centroids = {}
+#         lines = []
+#         img_word_crops = word_crop_collection.copy()
+#         para = []
+#         while img_word_crops:
+#             clusters = []
+#             para_words_group = [
+#                 img_word_crops[0],
+#             ]
+#             added = [
+#                 img_word_crops[0],
+#             ]
+#             img_word_crops.remove(img_word_crops[0])
+#             ## determining the paragraph
+#             while added:
+#                 word_crop = added.pop()
+#                 for i in range(len(img_word_crops)):
+#                     word_crop_ = img_word_crops[i]
+#                     if (
+#                         abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
+#                         < data[word_crop]["h"] * alpha1
+#                     ):
+#                         if data[word_crop]["xc"] > data[word_crop_]["xc"]:
+#                             if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[
+#                                 word_crop
+#                             ]["h"] * alpha2:
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#                         else:
+#                             if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[
+#                                 word_crop
+#                             ]["h"] * alpha2:
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#                     else:
+#                         if data[word_crop]["yc"] > data[word_crop_]["yc"]:
+#                             if (data[word_crop]["y1"] - data[word_crop_]["y2"]) < data[
+#                                 word_crop
+#                             ]["h"] * beta1 and (
+#                                 (
+#                                     (data[word_crop_]["x1"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x1"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop_]["x2"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x2"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop]["x1"] > data[word_crop_]["x1"])
+#                                     and (data[word_crop]["x2"] < data[word_crop_]["x2"])
+#                                 )
+#                             ):
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#                         else:
+#                             if (data[word_crop_]["y1"] - data[word_crop]["y2"]) < data[
+#                                 word_crop
+#                             ]["h"] * beta1 and (
+#                                 (
+#                                     (data[word_crop_]["x1"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x1"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop_]["x2"] < data[word_crop]["x2"])
+#                                     and (data[word_crop_]["x2"] > data[word_crop]["x1"])
+#                                 )
+#                                 or (
+#                                     (data[word_crop]["x1"] > data[word_crop_]["x1"])
+#                                     and (data[word_crop]["x2"] < data[word_crop_]["x2"])
+#                                 )
+#                             ):
+#                                 para_words_group.append(word_crop_)
+#                                 added.append(word_crop_)
+#             img_word_crops = [p for p in img_word_crops if p not in para_words_group]
+#             ## processing for the line
+#             while para_words_group:
+#                 line_words_group = [
+#                     para_words_group[0],
+#                 ]
+#                 added = [
+#                     para_words_group[0],
+#                 ]
+#                 para_words_group.remove(para_words_group[0])
+#                 ## determining the line
+#                 while added:
+#                     word_crop = added.pop()
+#                     for i in range(len(para_words_group)):
+#                         word_crop_ = para_words_group[i]
+#                         if (
+#                             abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
+#                             < data[word_crop]["h"] * alpha1
+#                         ):
+#                             if data[word_crop]["xc"] > data[word_crop_]["xc"]:
+#                                 if (data[word_crop]["x1"] - data[word_crop_]["x2"]) < data[
+#                                     word_crop
+#                                 ]["h"] * alpha2:
+#                                     line_words_group.append(word_crop_)
+#                                     added.append(word_crop_)
+#                             else:
+#                                 if (data[word_crop_]["x1"] - data[word_crop]["x2"]) < data[
+#                                     word_crop
+#                                 ]["h"] * alpha2:
+#                                     line_words_group.append(word_crop_)
+#                                     added.append(word_crop_)
+#                 para_words_group = [
+#                     p for p in para_words_group if p not in line_words_group
+#                 ]
+#                 xc = [data[word_crop]["xc"] for word_crop in line_words_group]
+#                 idxs = np.argsort(xc)
+#                 patch_cluster_ = [line_words_group[i] for i in idxs]
+#                 line_words_group = patch_cluster_
+#                 x1 = [data[word_crop]["x1"] for word_crop in line_words_group]
+#                 x2 = [data[word_crop]["x2"] for word_crop in line_words_group]
+#                 y1 = [data[word_crop]["y1"] for word_crop in line_words_group]
+#                 y2 = [data[word_crop]["y2"] for word_crop in line_words_group]
+#                 txt_line = [data[word_crop]["txt"] for word_crop in line_words_group]
+#                 txt = " ".join(txt_line)
+#                 x = [x1[0]]
+#                 y1_ = [y1[0]]
+#                 y2_ = [y2[0]]
+#                 l = [len(txt_l) for txt_l in txt_line]
+#                 for i in range(1, len(x1)):
+#                     x.append((x1[i] + x2[i - 1]) / 2)
+#                     y1_.append((y1[i] + y1[i - 1]) / 2)
+#                     y2_.append((y2[i] + y2[i - 1]) / 2)
+#                 x.append(x2[-1])
+#                 y1_.append(y1[-1])
+#                 y2_.append(y2[-1])
+#                 line_info = {
+#                     "x": x,
+#                     "y1": y1_,
+#                     "y2": y2_,
+#                     "l": l,
+#                     "txt": txt,
+#                     "word_crops": line_words_group,
+#                 }
+#                 clusters.append(line_info)
+#             y_ = [clusters[i]["y1"][0] for i in range(len(clusters))]
+#             idxs = np.argsort(y_)
+#             clusters_ = [clusters[i] for i in idxs]
+#             txt = [clusters[i]["txt"] for i in idxs]
+#             l = [len(t) for t in txt]
+#             txt = " ".join(txt)
+#             para_info = {"lines": clusters_, "l": l, "txt": txt}
+#             para.append(para_info)
+
+#         for word_crop in word_crop_collection:
+#             word_crops.remove(word_crop)
+#     return "\n".join([para[i]["txt"] for i in range(len(para))])
+
+
+def detect_para(recognized_texts):
+    """
+    Sort words into lines based on vertical overlap of bounding boxes.
+
+    Args:
+        recognized_texts (dict): A dictionary keyed by crop id, where each value holds the
+            recognized text ("txt") and its bounding box ("bbox") as [x1, y1, x2, y2].
+
+    Returns:
+        list: A list of lists where each sublist contains words sorted by x-coordinate for a single line.
+    """
+    def calculate_overlap(bbox1, bbox2):
+        """Calculate the vertical overlap between two bounding boxes."""
+        # Extract bounding box coordinates
+        x1_1, y1_1, x2_1, y2_1 = bbox1
+        x1_2, y1_2, x2_2, y2_2 = bbox2
+
+        overlap = max(0, min(y2_1, y2_2) - max(y1_1, y1_2))
+        height = min(y2_1 - y1_1, y2_2 - y1_2)
+        return overlap / height if height > 0 else 0
+
+    # Convert recognized_texts dictionary to a list of tuples for processing
+    items = list(recognized_texts.items())
+    lines = []
+
+    while items:
+        current_image, current_data = items.pop(0)
+        current_text, current_bbox = current_data['txt'], current_data['bbox']
+        current_line = [(current_text, current_bbox)]
+
+        remaining_items = []
+        for image, data in items:
+            text, bbox = data['txt'], data['bbox']
+            if calculate_overlap(current_bbox, bbox) > 0.4:
+                current_line.append((text, bbox))
+            else:
+                remaining_items.append((image, data))
+
+        items = remaining_items
+        lines.append(current_line)
+
+    # Sort words within each line based on x1 (horizontal position)
+    sorted_lines = [
+        [text for text, bbox in sorted(line, key=lambda x: x[1][0])] for line in lines
+    ]
+    return sorted_lines
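A toy example of how detect_para() groups words, assuming the bbox format produced in ocr.py (an axis-aligned [x1, y1, x2, y2] per recognised crop): boxes whose vertical overlap exceeds 0.4 of the smaller box height land on the same line, and each line is then sorted left to right:

    from IndicPhotoOCR.utils.helper import detect_para

    recognized_texts = {
        "img_0": {"txt": "hello", "bbox": [10, 10, 60, 30]},
        "img_1": {"txt": "world", "bbox": [70, 12, 130, 32]},   # overlaps img_0 vertically -> same line
        "img_2": {"txt": "below", "bbox": [10, 50, 70, 70]},    # no vertical overlap -> new line
    }
    print(detect_para(recognized_texts))   # [['hello', 'world'], ['below']]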
app.py CHANGED
@@ -1,12 +1,12 @@
-# This is a working demo - textbpn++ - CLIP - Parseq
 import gradio as gr
 from PIL import Image
 import os
 from IndicPhotoOCR.ocr import OCR  # Ensure OCR class is saved in a file named ocr.py
 from IndicPhotoOCR.theme import Seafoam
+from IndicPhotoOCR.utils.helper import detect_para
 
 # Initialize the OCR object for text detection and recognition
-ocr = OCR(device="cpu", verbose=False)
+ocr = OCR(verbose=False)
 
 def process_image(image):
     """
@@ -36,21 +36,37 @@ def process_image(image):
     output_image = Image.open("output_image.png")
 
     # Initialize list to hold recognized text from each detected area
-    recognized_texts = []
+    recognized_texts = {}
     pil_image = Image.open(image_path)
 
-    # Process each detected bounding box for script identification and text recognition
-    for bbox in detections:
+    # # Process each detected bounding box for script identification and text recognition
+    # for bbox in detections:
+    #     # Identify the script and crop the image to this region
+    #     script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
+
+    #     if script_lang:  # Only proceed if a script language is identified
+    #         # Recognize text in the cropped area
+    #         recognized_text = ocr.recognise(cropped_path, script_lang)
+    #         recognized_texts.append(recognized_text)
+    for id, bbox in enumerate(detections):
         # Identify the script and crop the image to this region
         script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
-
-        if script_lang:  # Only proceed if a script language is identified
-            # Recognize text in the cropped area
+
+        # Calculate bounding box coordinates
+        x1 = min([bbox[i][0] for i in range(len(bbox))])
+        y1 = min([bbox[i][1] for i in range(len(bbox))])
+        x2 = max([bbox[i][0] for i in range(len(bbox))])
+        y2 = max([bbox[i][1] for i in range(len(bbox))])
+
+        if script_lang:
            recognized_text = ocr.recognise(cropped_path, script_lang)
-            recognized_texts.append(recognized_text)
-
+            recognized_texts[f"img_{id}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}
+
     # Combine recognized texts into a single string for display
-    recognized_texts_combined = " ".join(recognized_texts)
+    # recognized_texts_combined = " ".join(recognized_texts)
+    string = detect_para(recognized_texts)
+    recognized_texts_combined = '\n'.join([' '.join(line) for line in string])
+
     return output_image, recognized_texts_combined
 
 # Custom HTML for interface header with logos and alignment
@@ -110,10 +126,10 @@ demo = gr.Interface(
     examples=examples
 )
 
-# # Server setup and launch configuration
-# if __name__ == "__main__":
-#     server = "0.0.0.0"  # IP address for server
-#     port = 7865  # Port to run the server on
-#     demo.launch(server_name=server, server_port=port)
+# Server setup and launch configuration
+if __name__ == "__main__":
+    server = "0.0.0.0"  # IP address for server
+    port = 7865  # Port to run the server on
+    demo.launch(server_name=server, server_port=port, share=True)
 
-demo.launch()
+# demo.launch()
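The same axis-aligned box computation appears in both app.py and ocr.py; a standalone illustration, assuming each detection is a polygon given as a list of (x, y) points as returned by the TextBPN++ detector:

    # Hypothetical 4-point detection polygon.
    bbox = [(120, 40), (260, 42), (258, 88), (118, 86)]
    x1 = min(point[0] for point in bbox)
    y1 = min(point[1] for point in bbox)
    x2 = max(point[0] for point in bbox)
    y2 = max(point[1] for point in bbox)
    print([x1, y1, x2, y2])   # [118, 40, 260, 88]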
requirements.txt CHANGED
@@ -43,4 +43,5 @@ torch==2.5.0
 torchvision==0.20.0
 easydict==1.13
 scipy==1.13.1
-
+transformers==4.45.1
+datasets==3.1.0
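A quick sanity check that the two newly pinned libraries import at the expected versions (a sketch; adjust if the pins change):

    import transformers
    import datasets

    print(transformers.__version__)   # expected 4.45.1
    print(datasets.__version__)       # expected 3.1.0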