File size: 8,456 Bytes
01bb3bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201

import torch
import clip
from PIL import Image
from io import BytesIO
import os
import requests

# Model information dictionary containing model paths and language subcategories
# Registry of fine-tuned CLIP checkpoints for script identification.
# Each entry maps a model key to:
#   path          - local checkpoint location, relative to the package dir
#   url           - GitHub release asset to fetch when the file is absent
#   subcategories - class labels, ordered to match the classifier's outputs
_RELEASE_BASE_URL = "https://github.com/anikde/STscriptdetect/releases/download/V1"


def _model_entry(language_tag, subcategories):
    # Every released checkpoint follows the same naming scheme, so the
    # path/url pair can be derived from the language tag alone.
    checkpoint_name = f"clip_finetuned_{language_tag}_real.pth"
    return {
        "path": f"models/{checkpoint_name}",
        "url": f"{_RELEASE_BASE_URL}/{checkpoint_name}",
        "subcategories": subcategories,
    }


model_info = {
    "hindi": _model_entry("hindienglish", ["hindi", "english"]),
    "hinengasm": _model_entry("hindienglishassamese", ["hindi", "english", "assamese"]),
    "hinengben": _model_entry("hindienglishbengali", ["hindi", "english", "bengali"]),
    "hinengguj": _model_entry("hindienglishgujarati", ["hindi", "english", "gujarati"]),
    "hinengkan": _model_entry("hindienglishkannada", ["hindi", "english", "kannada"]),
    "hinengmal": _model_entry("hindienglishmalayalam", ["hindi", "english", "malayalam"]),
    "hinengmar": _model_entry("hindienglishmarathi", ["hindi", "english", "marathi"]),
    "hinengmei": _model_entry("hindienglishmeitei", ["hindi", "english", "meitei"]),
    "hinengodi": _model_entry("hindienglishodia", ["hindi", "english", "odia"]),
    "hinengpun": _model_entry("hindienglishpunjabi", ["hindi", "english", "punjabi"]),
    "hinengtam": _model_entry("hindienglishtamil", ["hindi", "english", "tamil"]),
    "hinengtel": _model_entry("hindienglishtelugu", ["hindi", "english", "telugu"]),
    "hinengurd": _model_entry("hindienglishurdu", ["hindi", "english", "urdu"]),
}


# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained ViT-B/32 CLIP backbone and its matching image
# preprocessing pipeline at import time; both are shared by every
# fine-tuned classifier built below.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

class CLIPFineTuner(torch.nn.Module):
    """
    Wraps a frozen CLIP backbone with a trainable linear classification head.

    Attributes:
        model (torch.nn.Module): CLIP backbone used purely as a feature extractor.
        classifier (torch.nn.Linear): Linear head mapping image features to class logits.
    """

    def __init__(self, model, num_classes):
        """
        Builds the classification head on top of the given CLIP backbone.

        Args:
            model (torch.nn.Module): The base CLIP model.
            num_classes (int): Number of target classes.
        """
        super().__init__()
        self.model = model
        # The head's input width must match the backbone's visual embedding size.
        self.classifier = torch.nn.Linear(model.visual.output_dim, num_classes)

    def forward(self, x):
        """
        Computes class logits for a batch of preprocessed images.

        Args:
            x (torch.Tensor): Preprocessed image tensor batch.

        Returns:
            torch.Tensor: Logits for each class.
        """
        # Backbone stays frozen: gradients flow only through the linear head.
        with torch.no_grad():
            image_features = self.model.encode_image(x).float()
        return self.classifier(image_features)

class CLIPidentifier:
    """
    Identifies the script of a word image using a fine-tuned CLIP
    classifier selected from `model_info`.
    """

    def __init__(self):
        pass

    # Ensure model file exists; download directly if not
    def ensure_model(self, model_name):
        """
        Ensures the checkpoint for `model_name` is present locally,
        downloading it from the GitHub release if necessary.

        Args:
            model_name (str): A key of `model_info`.

        Returns:
            str: Filesystem path to the checkpoint file.

        Raises:
            requests.HTTPError: If the download request returns an error status.
        """
        model_path = model_info[model_name]["path"]
        url = model_info[model_name]["url"]
        root_model_dir = "IndicPhotoOCR/script_identification/"
        model_path = os.path.join(root_model_dir, model_path)

        if not os.path.exists(model_path):
            print(f"Model not found locally. Downloading {model_name} from {url}...")
            response = requests.get(url, stream=True)
            # Fail loudly instead of silently saving an HTML error page as a checkpoint.
            response.raise_for_status()
            # Derive the target directory from the path itself rather than
            # hardcoding the "models" subdirectory.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            # Download to a temp file first so an interrupted transfer never
            # leaves a truncated checkpoint at the final path.
            tmp_path = model_path + ".part"
            with open(tmp_path, "wb") as f:
                # Stream in chunks; response.content would buffer the whole
                # file in memory, defeating stream=True.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
            os.replace(tmp_path, model_path)
            print(f"Downloaded model for {model_name}.")

        return model_path

    # Prediction function to verify and load the model
    def identify(self, image_path, model_name):
        """
        Predicts the script class of an input image using a fine-tuned CLIP model.

        Args:
            image_path (str): Path to the input image file.
            model_name (str): Name of the model (e.g., hindi, hinengpun, hinengguj)
                as specified in `model_info`.

        Returns:
            str: The predicted subcategory name on success.
            dict: `{"error": ...}` if the model name is invalid or any
                exception occurs.

        Example usage:
            result = CLIPidentifier().identify("sample_image.jpg", "hinengguj")
            print(result)  # Output might be 'hindi'
        """
        try:
            # Validate model name and retrieve associated subcategories
            if model_name not in model_info:
                return {"error": "Invalid model name"}

            # Ensure the model file is downloaded and accessible
            model_path = self.ensure_model(model_name)

            subcategories = model_info[model_name]["subcategories"]
            num_classes = len(subcategories)

            # Rebuild the classifier head and load the fine-tuned weights
            model_ft = CLIPFineTuner(clip_model, num_classes)
            model_ft.load_state_dict(torch.load(model_path, map_location=device))
            model_ft = model_ft.to(device)
            model_ft.eval()

            # Load and preprocess the image
            image = Image.open(image_path).convert("RGB")
            input_tensor = preprocess(image).unsqueeze(0).to(device)

            # Run inference without autograd bookkeeping and map the argmax
            # logit back to its class name
            with torch.no_grad():
                outputs = model_ft(input_tensor)
            _, predicted_idx = torch.max(outputs, 1)
            predicted_class = subcategories[predicted_idx.item()]

            return predicted_class

        except Exception as e:
            # Callers receive failures as a dict rather than an exception,
            # mirroring the invalid-model-name path above.
            return {"error": str(e)}


# if __name__ == "__main__":
#     import argparse

#     # Argument parser for command line usage
#     parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model")
#     parser.add_argument("image_path", type=str, help="Path to the input image")
#     parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hindi, hinengpun, hinengguj)")

#     args = parser.parse_args()

#     # Execute prediction with command line inputs
#     result = CLIPidentifier().identify(args.image_path, args.model_name)
#     print(result)