# Hugging Face Hub page residue (not part of the program): uploaded by
# shreyasvaidya via huggingface_hub, commit 01bb3bb (verified), file size 8.46 kB.
import torch
import clip
from PIL import Image
from io import BytesIO
import os
import requests
# Registry of fine-tuned script-identification checkpoints, keyed by model name.
# Each entry holds:
#   "path"          - checkpoint location, relative to the model root directory
#   "url"           - release asset the checkpoint is downloaded from
#   "subcategories" - ordered class labels; logit index i maps to subcategories[i]
model_info = {
    # Two-way Hindi/English model (its key predates the "hineng<code>" naming).
    "hindi": {
        "path": "models/clip_finetuned_hindienglish_real.pth",
        "url": "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglish_real.pth",
        "subcategories": ["hindi", "english"],
    },
    # Three-way models: Hindi + English + one regional script. The key uses a
    # three-letter code; the checkpoint file name embeds the full language name.
    **{
        f"hineng{code}": {
            "path": f"models/clip_finetuned_hindienglish{language}_real.pth",
            "url": f"https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglish{language}_real.pth",
            "subcategories": ["hindi", "english", language],
        }
        for code, language in (
            ("asm", "assamese"),
            ("ben", "bengali"),
            ("guj", "gujarati"),
            ("kan", "kannada"),
            ("mal", "malayalam"),
            ("mar", "marathi"),
            ("mei", "meitei"),
            ("odi", "odia"),
            ("pun", "punjabi"),
            ("tam", "tamil"),
            ("tel", "telugu"),
            ("urd", "urdu"),
        )
    },
}
# Choose the compute device once, at import time: CUDA whenever PyTorch can see
# a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the pretrained ViT-B/32 CLIP backbone together with its matching image
# transform. This module-level model is the backbone wrapped by every
# CLIPFineTuner built in this file.
# NOTE(review): clip.load may fetch weights on first use; confirm cache location.
clip_model, preprocess = clip.load(name="ViT-B/32", device=device)
class CLIPFineTuner(torch.nn.Module):
    """Linear classification head on top of a frozen CLIP image encoder.

    Attributes:
        model (torch.nn.Module): CLIP backbone; only ``encode_image`` is used,
            and it runs under ``torch.no_grad()``.
        classifier (torch.nn.Linear): Maps the ``model.visual.output_dim``-wide
            image embedding to ``num_classes`` logits.

    The attribute names ``model`` and ``classifier`` determine the state-dict
    keys, so they must match the saved checkpoints.
    """

    def __init__(self, model, num_classes):
        """Wrap ``model`` and attach a freshly initialised linear head.

        Args:
            model (torch.nn.Module): Base CLIP model exposing ``visual.output_dim``
                and ``encode_image``.
            num_classes (int): Number of output classes (width of the logits).
        """
        super().__init__()
        self.model = model
        embed_dim = model.visual.output_dim
        self.classifier = torch.nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Compute class logits for preprocessed image input.

        Args:
            x (torch.Tensor): Preprocessed image tensor accepted by
                ``model.encode_image``.

        Returns:
            torch.Tensor: Logits with ``num_classes`` entries per image.
        """
        # The backbone stays frozen: no autograd graph is recorded for it.
        # The cast to float32 presumably guards against half-precision CLIP
        # features so they match the head's dtype.
        with torch.no_grad():
            image_features = self.model.encode_image(x).float()
        logits = self.classifier(image_features)
        return logits
class CLIPidentifier:
    """Script identifier backed by the fine-tuned CLIP checkpoints in ``model_info``."""

    # Seconds to wait for the server between received bytes while downloading.
    _DOWNLOAD_TIMEOUT = 60
    # Bytes written per chunk while streaming a checkpoint to disk.
    _DOWNLOAD_CHUNK_SIZE = 1 << 20

    def __init__(self):
        pass

    def ensure_model(self, model_name):
        """Return the local checkpoint path for ``model_name``, downloading it if absent.

        Args:
            model_name (str): Key of ``model_info``.

        Returns:
            str: Path of the checkpoint under ``IndicPhotoOCR/script_identification/``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``model_info``.
            requests.HTTPError: If the download URL answers with an error status.
            requests.RequestException: On other network failures.
        """
        entry = model_info[model_name]
        url = entry["url"]
        # NOTE(review): this root is relative to the current working directory,
        # so the cache location depends on where the process was started.
        root_model_dir = "IndicPhotoOCR/script_identification/"
        model_path = os.path.join(root_model_dir, entry["path"])
        if not os.path.exists(model_path):
            print(f"Model not found locally. Downloading {model_name} from {url}...")
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            # Stream into a sibling temp file and rename it only once the download is
            # complete. An interrupted or failed download therefore never leaves a
            # truncated file at model_path, which the exists() check above would
            # otherwise trust forever.
            tmp_path = model_path + ".part"
            try:
                with requests.get(url, stream=True, timeout=self._DOWNLOAD_TIMEOUT) as response:
                    # Without this check, a 404/500 error page would be saved as the model.
                    response.raise_for_status()
                    with open(tmp_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=self._DOWNLOAD_CHUNK_SIZE):
                            f.write(chunk)
                os.replace(tmp_path, model_path)
            except BaseException:
                # Discard any partial file before propagating the original error.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise
            print(f"Downloaded model for {model_name}.")
        return model_path

    def identify(self, image_path, model_name):
        """Predict which script appears in an image using a fine-tuned CLIP model.

        Args:
            image_path (str): Path to the input image file.
            model_name (str): Key of ``model_info`` (e.g. "hindi", "hinengpun",
                "hinengguj") selecting the checkpoint and its label set.

        Returns:
            str | dict: The predicted label (one of the model's ``subcategories``)
            on success, or ``{"error": <message>}`` if the model name is unknown
            or any step raises.

        Example:
            identifier = CLIPidentifier()
            result = identifier.identify("sample_image.jpg", "hinengguj")
            print(result)  # e.g. 'hindi'
        """
        try:
            if model_name not in model_info:
                return {"error": "Invalid model name"}
            # Make sure the checkpoint is present locally (downloads on first use).
            model_path = self.ensure_model(model_name)
            subcategories = model_info[model_name]["subcategories"]

            # CLIPFineTuner wraps the shared module-level clip_model. A strict
            # load_state_dict therefore also writes the checkpoint's backbone weights
            # into that shared model. For this reason the checkpoint is reloaded on
            # every call rather than caching one model per name.
            model_ft = CLIPFineTuner(clip_model, len(subcategories))
            # SECURITY NOTE: torch.load unpickles the file. It is trusted only because
            # it comes from the project's own release URL.
            model_ft.load_state_dict(torch.load(model_path, map_location=device))
            model_ft = model_ft.to(device)
            model_ft.eval()

            # Close the image file promptly; convert() loads the pixels and returns
            # an independent image.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            input_tensor = preprocess(image).unsqueeze(0).to(device)

            # Inference only: skip building an autograd graph for the head as well.
            with torch.no_grad():
                outputs = model_ft(input_tensor)
            _, predicted_idx = torch.max(outputs, 1)
            return subcategories[predicted_idx.item()]
        except Exception as e:
            return {"error": str(e)}
# Example command-line usage (disabled):
# if __name__ == "__main__":
#     import argparse
#     # Argument parser for command line usage
#     parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
#     parser.add_argument("image_path", type=str, help="Path to the input image")
#     parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hindi, hinengpun, hinengguj)")
#     args = parser.parse_args()
#     # Execute prediction with command line inputs
#     result = CLIPidentifier().identify(args.image_path, args.model_name)
#     print(result)