import os
import torch
import cv2
import numpy as np
import time
import warnings
import IndicPhotoOCR.detection.east_config as cfg
from IndicPhotoOCR.detection.east_utils import ModelManager
from IndicPhotoOCR.detection.east_model import East
import IndicPhotoOCR.detection.east_utils as utils

# Suppress warnings
warnings.filterwarnings("ignore")


class EASTdetector:
    """Text detector wrapping the EAST model.

    ``detect`` loads the checkpoint it is given, runs one forward pass on the
    image and returns quadrilateral text boxes in original-image pixels.
    """

    def __init__(self, model_name="east", model_path=None):
        """Store the checkpoint path.

        Args:
            model_name: Currently unused; kept for interface compatibility.
            model_path: Checkpoint path stored on the instance. ``detect``
                does not read it; it takes the checkpoint path explicitly.
        """
        self.model_path = model_path

    @staticmethod
    def _load_model(model_checkpoint, torch_device):
        """Build a DataParallel-wrapped East model, load weights, set eval mode."""
        # Wrapped in DataParallel before loading, as in the original code.
        # NOTE(review): this presumably matches a checkpoint saved from a
        # DataParallel model ("module."-prefixed keys) — confirm.
        model = torch.nn.DataParallel(East(), device_ids=cfg.gpu_ids)
        # weights_only=True restricts torch.load's unpickler to tensors and
        # plain containers instead of arbitrary pickled objects.
        checkpoint = torch.load(
            model_checkpoint, map_location=torch_device, weights_only=True
        )
        model.load_state_dict(checkpoint['state_dict'])
        # load_state_dict copies values into the existing parameters, which
        # are still on CPU. Move them so DataParallel on a CUDA machine finds
        # them on the expected device.
        model.to(torch_device)
        model.eval()
        return model

    @staticmethod
    def _boxes_to_detections(boxes, ratio_h, ratio_w, min_side=5):
        """Rescale raw EAST boxes to original-image pixels and drop tiny ones.

        Args:
            boxes: Array whose first 8 columns are the four (x, y) corners of
                each box in resized-image coordinates, or ``None``.
            ratio_h, ratio_w: Resize ratios returned by ``utils.resize_image``.
            min_side: Boxes with side 0-1 or side 3-0 shorter than this many
                pixels (after int conversion) are discarded.

        Returns:
            List of boxes, each a list of four ``[x, y]`` int points.
        """
        detections = []
        if boxes is None:
            return detections
        quads = boxes[:, :8].reshape((-1, 4, 2))
        # Undo the resize so coordinates refer to the original image.
        quads[:, :, 0] /= ratio_w
        quads[:, :, 1] /= ratio_h
        for quad in quads:
            quad = utils.sort_poly(quad.astype(np.int32))
            if (np.linalg.norm(quad[0] - quad[1]) < min_side
                    or np.linalg.norm(quad[3] - quad[0]) < min_side):
                continue
            detections.append([[int(coord[0]), int(coord[1])] for coord in quad])
        return detections

    def detect(self, image_path, model_checkpoint, device):
        """Detect text regions in an image with the EAST model.

        Args:
            image_path: Path to the input image, read with ``cv2.imread``
                (BGR channel order; no conversion is applied).
            model_checkpoint: Path to a checkpoint whose ``'state_dict'`` entry
                is loaded into a ``DataParallel``-wrapped ``East`` model.
            device: Torch device string (e.g. ``"cpu"`` or ``"cuda"``) used to
                load the checkpoint and run the model.

        Returns:
            dict: ``{'detections': [...]}``, where each detection is a list of
            four ``[x, y]`` integer corners in original-image pixels, ordered
            by ``utils.sort_poly``.

        Raises:
            FileNotFoundError: If ``cv2.imread`` cannot read ``image_path``.
        """
        im = cv2.imread(image_path)
        if im is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        torch_device = torch.device(device)
        model = self._load_model(model_checkpoint, torch_device)

        # Resize image and convert to a 1xCxHxW float tensor. The transpose
        # moves the channel axis first (assumes resize_image keeps HxWxC).
        im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
        im_resized = im_resized.astype(np.float32).transpose(2, 0, 1)
        im_tensor = torch.from_numpy(im_resized).unsqueeze(0).to(torch_device)

        # Inference: pure forward pass, so skip autograd bookkeeping.
        timer = {'net': 0, 'restore': 0, 'nms': 0}
        start = time.time()
        with torch.no_grad():
            score, geometry = model(im_tensor)
        timer['net'] = time.time() - start

        # Channels-last numpy arrays for the post-processing utilities.
        score = score.permute(0, 2, 3, 1).cpu().numpy()
        geometry = geometry.permute(0, 2, 3, 1).cpu().numpy()

        boxes, timer = utils.detect(
            score_map=score,
            geo_map=geometry,
            timer=timer,
            score_map_thresh=cfg.score_map_thresh,
            box_thresh=cfg.box_thresh,
            # NOTE(review): the NMS threshold reuses box_thresh; this looks
            # like a copy-paste slip — confirm whether cfg defines a
            # dedicated NMS threshold.
            nms_thres=cfg.box_thresh,
        )

        return {'detections': self._boxes_to_detections(boxes, ratio_h, ratio_w)}


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Text detection using EAST model')
    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
    parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
    args = parser.parse_args()

    # Run prediction and get results as dictionary
    east = EASTdetector(model_path=args.model_checkpoint)
    detection_result = east.detect(args.image_path, args.model_checkpoint, args.device)