import torch
import clip
from PIL import Image
from io import BytesIO
import os
import requests
# Model registry: maps each model key to its checkpoint path, download URL,
# and the language subcategories (class labels) it predicts.
_RELEASE_URL = "https://github.com/anikde/STscriptdetect/releases/download/V1"

# Third script paired with the base hindi+english pair for each model key
# (None = base hindi/english model only). Checkpoint filenames follow the
# pattern clip_finetuned_hindienglish<third>_real.pth.
_THIRD_SCRIPT = {
    "hindi": None,
    "hinengasm": "assamese",
    "hinengben": "bengali",
    "hinengguj": "gujarati",
    "hinengkan": "kannada",
    "hinengmal": "malayalam",
    "hinengmar": "marathi",
    "hinengmei": "meitei",
    "hinengodi": "odia",
    "hinengpun": "punjabi",
    "hinengtam": "tamil",
    "hinengtel": "telugu",
    "hinengurd": "urdu",
}

model_info = {}
for _key, _third in _THIRD_SCRIPT.items():
    _fname = f"clip_finetuned_hindienglish{_third or ''}_real.pth"
    model_info[_key] = {
        "path": f"models/{_fname}",
        "url": f"{_RELEASE_URL}/{_fname}",
        "subcategories": ["hindi", "english"] + ([_third] if _third else []),
    }
# Set device to CUDA if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the base CLIP ViT-B/32 backbone and its image preprocessing transform
# once at import time; both are shared by every CLIPFineTuner instance below.
clip_model, preprocess = clip.load("ViT-B/32", device=device)
class CLIPFineTuner(torch.nn.Module):
    """Wrap a frozen CLIP backbone with a trainable linear classification head.

    Attributes:
        model (torch.nn.Module): Pretrained CLIP model used as a fixed
            image-feature extractor.
        classifier (torch.nn.Linear): Linear head mapping CLIP image
            features to per-class logits.
    """

    def __init__(self, model, num_classes):
        """Build the classifier head on top of the given CLIP model.

        Args:
            model (torch.nn.Module): The base CLIP model.
            num_classes (int): Number of target classes for classification.
        """
        super().__init__()
        self.model = model
        self.classifier = torch.nn.Linear(model.visual.output_dim, num_classes)

    def forward(self, x):
        """Classify a batch of preprocessed images.

        Args:
            x (torch.Tensor): Preprocessed image batch.

        Returns:
            torch.Tensor: Logits for each class.
        """
        # The backbone is deliberately frozen: gradients only ever flow
        # into the linear head.
        with torch.no_grad():
            image_features = self.model.encode_image(x).float()
        return self.classifier(image_features)
class CLIPidentifier:
    """Script identification using fine-tuned CLIP checkpoints from `model_info`."""

    def __init__(self):
        pass

    # Ensure model file exists; download directly if not
    def ensure_model(self, model_name):
        """Return the local checkpoint path for `model_name`, downloading it if absent.

        Args:
            model_name (str): Key into `model_info`.

        Returns:
            str: Path of the checkpoint file on disk.

        Raises:
            requests.HTTPError: If the download request returns an error status.
        """
        model_path = model_info[model_name]["path"]
        url = model_info[model_name]["url"]
        root_model_dir = "IndicPhotoOCR/script_identification/"
        model_path = os.path.join(root_model_dir, model_path)

        if not os.path.exists(model_path):
            print(f"Model not found locally. Downloading {model_name} from {url}...")
            # Derive the directory from the actual target path instead of
            # hard-coding "models", so a changed path in model_info still works.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            # Stream to disk in chunks and fail loudly on HTTP errors. The
            # previous code buffered the whole file via response.content
            # (defeating stream=True) and would silently save an HTML error
            # page as the checkpoint on a non-200 response.
            response = requests.get(url, stream=True, timeout=60)
            response.raise_for_status()
            with open(model_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            print(f"Downloaded model for {model_name}.")

        return model_path

    # Prediction function to verify and load the model
    def identify(self, image_path, model_name):
        """
        Predicts the class of an input image using a fine-tuned CLIP model.

        Args:
            image_path (str): Path to the input image file.
            model_name (str): Name of the model (e.g., hindi, hinengpun, hinengguj)
                as specified in `model_info`.

        Returns:
            str: Predicted subcategory name (e.g. "hindi") on success.
            dict: `{"error": <message>}` if the model name is invalid or any
                step fails.

        Example usage:
            result = CLIPidentifier().identify("sample_image.jpg", "hinengguj")
            print(result)  # Output might be 'hindi'
        """
        try:
            # Validate model name and retrieve associated subcategories
            if model_name not in model_info:
                return {"error": "Invalid model name"}

            # Ensure the model file is downloaded and accessible
            model_path = self.ensure_model(model_name)
            subcategories = model_info[model_name]["subcategories"]
            num_classes = len(subcategories)

            # Attach a linear head sized for this model's classes to the
            # shared CLIP backbone, then load the fine-tuned weights.
            model_ft = CLIPFineTuner(clip_model, num_classes)
            model_ft.load_state_dict(torch.load(model_path, map_location=device))
            model_ft = model_ft.to(device)
            model_ft.eval()

            # Load and preprocess the image
            image = Image.open(image_path).convert("RGB")
            input_tensor = preprocess(image).unsqueeze(0).to(device)

            # Inference only: no autograd graph needed for the classifier head.
            with torch.no_grad():
                outputs = model_ft(input_tensor)
            # Argmax over class logits selects the predicted subcategory.
            _, predicted_idx = torch.max(outputs, 1)
            return subcategories[predicted_idx.item()]
        except Exception as e:
            # Preserve the original best-effort contract: surface failures as
            # a dict instead of raising to the caller.
            return {"error": str(e)}
# if __name__ == "__main__":
#     import argparse
#     # Argument parser for command line usage
#     parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
#     parser.add_argument("image_path", type=str, help="Path to the input image")
#     parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hindi, hinengpun, hinengguj)")
#     args = parser.parse_args()
#     # Execute prediction with command line inputs
#     result = CLIPidentifier().identify(args.image_path, args.model_name)
#     print(result)