prithivMLmods committed on
Commit 49baf5e · verified · 1 parent: 22ba041
Files changed (1): app.py (+179, -179)
app.py CHANGED
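No commit message was provided. Although the viewer shows all 179 body lines as removed and re-added, the only substantive change in this diff is the removal of the custom Hub theme from the gr.Blocks constructor:

- with gr.Blocks(theme="YTheme/Minecraft") as demo:
+ with gr.Blocks() as demo:

Everything else (the imports, the classify dispatcher, the SigLIP zero-shot setup, and the UI layout) is carried over unchanged.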
@@ -1,180 +1,180 @@
- import gradio as gr
- import torch
- from transformers import AutoModel, AutoProcessor
-
- from gender_classification import gender_classification
- from emotion_classification import emotion_classification
- from dog_breed import dog_breed_classification
- from deepfake_quality import deepfake_classification
- from gym_workout_classification import workout_classification
- from augmented_waste_classifier import waste_classification
- from age_classification import age_classification
- from mnist_digits import classify_digit
- from fashion_mnist_cloth import fashion_mnist_classification
- from indian_western_food_classify import food_classification
- from bird_species import bird_classification
- from alphabet_sign_language_detection import sign_language_classification
- from rice_leaf_disease import classify_leaf_disease
- from traffic_density import traffic_density_classification
- from clip_art import clipart_classification
- from multisource_121 import multisource_classification
- from painting_126 import painting_classification
- from sketch_126 import sketch_classification # New import
-
- # Main classification function for multi-model classification.
- def classify(image, model_name):
-     if model_name == "gender":
-         return gender_classification(image)
-     elif model_name == "emotion":
-         return emotion_classification(image)
-     elif model_name == "dog breed":
-         return dog_breed_classification(image)
-     elif model_name == "deepfake":
-         return deepfake_classification(image)
-     elif model_name == "gym workout":
-         return workout_classification(image)
-     elif model_name == "waste":
-         return waste_classification(image)
-     elif model_name == "age":
-         return age_classification(image)
-     elif model_name == "mnist":
-         return classify_digit(image)
-     elif model_name == "fashion_mnist":
-         return fashion_mnist_classification(image)
-     elif model_name == "food":
-         return food_classification(image)
-     elif model_name == "bird":
-         return bird_classification(image)
-     elif model_name == "leaf disease":
-         return classify_leaf_disease(image)
-     elif model_name == "sign language":
-         return sign_language_classification(image)
-     elif model_name == "traffic density":
-         return traffic_density_classification(image)
-     elif model_name == "clip art":
-         return clipart_classification(image)
-     elif model_name == "multisource":
-         return multisource_classification(image)
-     elif model_name == "painting":
-         return painting_classification(image)
-     elif model_name == "sketch": # New option
-         return sketch_classification(image)
-     else:
-         return {"Error": "No model selected"}
-
- # Function to update the selected model and button styles.
- def select_model(model_name):
-     model_variants = {
-         "gender": "secondary", "emotion": "secondary", "dog breed": "secondary", "deepfake": "secondary",
-         "gym workout": "secondary", "waste": "secondary", "age": "secondary", "mnist": "secondary",
-         "fashion_mnist": "secondary", "food": "secondary", "bird": "secondary", "leaf disease": "secondary",
-         "sign language": "secondary", "traffic density": "secondary", "clip art": "secondary",
-         "multisource": "secondary", "painting": "secondary", "sketch": "secondary" # New model variant
-     }
-     model_variants[model_name] = "primary"
-     return (model_name, *(gr.update(variant=model_variants[key]) for key in model_variants))
-
- # Zero-Shot Classification Setup (SigLIP models)
- sg1_ckpt = "google/siglip-so400m-patch14-384"
- siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="cpu").eval()
- siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)
-
- sg2_ckpt = "google/siglip2-so400m-patch14-384"
- siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="cpu").eval()
- siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
-
- def postprocess_siglip(sg1_probs, sg2_probs, labels):
-     sg1_output = {labels[i]: sg1_probs[0][i].item() for i in range(len(labels))}
-     sg2_output = {labels[i]: sg2_probs[0][i].item() for i in range(len(labels))}
-     return sg1_output, sg2_output
-
- def siglip_detector(image, texts):
-     sg1_inputs = siglip1_processor(
-         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
-     ).to("cpu")
-     sg2_inputs = siglip2_processor(
-         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
-     ).to("cpu")
-     with torch.no_grad():
-         sg1_outputs = siglip1_model(**sg1_inputs)
-         sg2_outputs = siglip2_model(**sg2_inputs)
-     sg1_logits_per_image = sg1_outputs.logits_per_image
-     sg2_logits_per_image = sg2_outputs.logits_per_image
-     sg1_probs = torch.sigmoid(sg1_logits_per_image)
-     sg2_probs = torch.sigmoid(sg2_logits_per_image)
-     return sg1_probs, sg2_probs
-
- def infer(image, candidate_labels):
-     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
-     sg1_probs, sg2_probs = siglip_detector(image, candidate_labels)
-     return postprocess_siglip(sg1_probs, sg2_probs, labels=candidate_labels)
-
- # Build the Gradio Interface with two tabs.
- with gr.Blocks(theme="YTheme/Minecraft") as demo:
-     gr.Markdown("# Multi-Domain & Zero-Shot Image Classification")
-
-     with gr.Tabs():
-         # Tab 1: Multi-Model Classification
-         with gr.Tab("Multi-Domain Classification"):
-             with gr.Sidebar():
-                 gr.Markdown("# Choose Domain")
-                 with gr.Row():
-                     age_btn = gr.Button("Age Classification", variant="primary")
-                     gender_btn = gr.Button("Gender Classification", variant="secondary")
-                     emotion_btn = gr.Button("Emotion Classification", variant="secondary")
-                     gym_workout_btn = gr.Button("Gym Workout Classification", variant="secondary")
-                     dog_breed_btn = gr.Button("Dog Breed Classification", variant="secondary")
-                     bird_btn = gr.Button("Bird Species Classification", variant="secondary")
-                     waste_btn = gr.Button("Waste Classification", variant="secondary")
-                     deepfake_btn = gr.Button("Deepfake Quality Test", variant="secondary")
-                     traffic_density_btn = gr.Button("Traffic Density", variant="secondary")
-                     sign_language_btn = gr.Button("Alphabet Sign Language", variant="secondary")
-                     clip_art_btn = gr.Button("Clip Art 126", variant="secondary")
-                     mnist_btn = gr.Button("Digit Classify (0-9)", variant="secondary")
-                     fashion_mnist_btn = gr.Button("Fashion MNIST (only cloth)", variant="secondary")
-                     food_btn = gr.Button("Indian/Western Food Type", variant="secondary")
-                     leaf_disease_btn = gr.Button("Rice Leaf Disease", variant="secondary")
-                     multisource_btn = gr.Button("Multi Source 121", variant="secondary")
-                     painting_btn = gr.Button("Painting 126", variant="secondary")
-                     sketch_btn = gr.Button("Sketch 126", variant="secondary")
-
-                 selected_model = gr.State("age")
-                 gr.Markdown("### Current Model:")
-                 model_display = gr.Textbox(value="age", interactive=False)
-                 selected_model.change(lambda m: m, selected_model, model_display)
-
-                 buttons = [
-                     gender_btn, emotion_btn, dog_breed_btn, deepfake_btn, gym_workout_btn, waste_btn,
-                     age_btn, mnist_btn, fashion_mnist_btn, food_btn, bird_btn, leaf_disease_btn,
-                     sign_language_btn, traffic_density_btn, clip_art_btn, multisource_btn, painting_btn, sketch_btn # Include new button
-                 ]
-                 model_names = [
-                     "gender", "emotion", "dog breed", "deepfake", "gym workout", "waste",
-                     "age", "mnist", "fashion_mnist", "food", "bird", "leaf disease",
-                     "sign language", "traffic density", "clip art", "multisource", "painting", "sketch" # New model name
-                 ]
-
-                 for btn, name in zip(buttons, model_names):
-                     btn.click(fn=lambda n=name: select_model(n), inputs=[], outputs=[selected_model] + buttons)
-
-             with gr.Row():
-                 with gr.Column():
-                     image_input = gr.Image(type="numpy", label="Upload Image")
-                     analyze_btn = gr.Button("Classify / Predict")
-                     output_label = gr.Label(label="Prediction Scores")
-                     analyze_btn.click(fn=classify, inputs=[image_input, selected_model], outputs=output_label)
-
-         # Tab 2: Zero-Shot Classification (SigLIP)
-         with gr.Tab("Zero-Shot Classification"):
-             gr.Markdown("## Compare SigLIP 1 and SigLIP 2 on Zero-Shot Classification")
-             with gr.Row():
-                 with gr.Column():
-                     zs_image_input = gr.Image(type="pil", label="Upload Image")
-                     zs_text_input = gr.Textbox(label="Input a list of labels (comma separated)")
-                     zs_run_button = gr.Button("Run")
-                 with gr.Column():
-                     siglip1_output = gr.Label(label="SigLIP 1 Output", num_top_classes=3)
-                     siglip2_output = gr.Label(label="SigLIP 2 Output", num_top_classes=3)
-             zs_run_button.click(fn=infer, inputs=[zs_image_input, zs_text_input], outputs=[siglip1_output, siglip2_output])
-
+ import gradio as gr
+ import torch
+ from transformers import AutoModel, AutoProcessor
+
+ from gender_classification import gender_classification
+ from emotion_classification import emotion_classification
+ from dog_breed import dog_breed_classification
+ from deepfake_quality import deepfake_classification
+ from gym_workout_classification import workout_classification
+ from augmented_waste_classifier import waste_classification
+ from age_classification import age_classification
+ from mnist_digits import classify_digit
+ from fashion_mnist_cloth import fashion_mnist_classification
+ from indian_western_food_classify import food_classification
+ from bird_species import bird_classification
+ from alphabet_sign_language_detection import sign_language_classification
+ from rice_leaf_disease import classify_leaf_disease
+ from traffic_density import traffic_density_classification
+ from clip_art import clipart_classification
+ from multisource_121 import multisource_classification
+ from painting_126 import painting_classification
+ from sketch_126 import sketch_classification # New import
+
+ # Main classification function for multi-model classification.
+ def classify(image, model_name):
+     if model_name == "gender":
+         return gender_classification(image)
+     elif model_name == "emotion":
+         return emotion_classification(image)
+     elif model_name == "dog breed":
+         return dog_breed_classification(image)
+     elif model_name == "deepfake":
+         return deepfake_classification(image)
+     elif model_name == "gym workout":
+         return workout_classification(image)
+     elif model_name == "waste":
+         return waste_classification(image)
+     elif model_name == "age":
+         return age_classification(image)
+     elif model_name == "mnist":
+         return classify_digit(image)
+     elif model_name == "fashion_mnist":
+         return fashion_mnist_classification(image)
+     elif model_name == "food":
+         return food_classification(image)
+     elif model_name == "bird":
+         return bird_classification(image)
+     elif model_name == "leaf disease":
+         return classify_leaf_disease(image)
+     elif model_name == "sign language":
+         return sign_language_classification(image)
+     elif model_name == "traffic density":
+         return traffic_density_classification(image)
+     elif model_name == "clip art":
+         return clipart_classification(image)
+     elif model_name == "multisource":
+         return multisource_classification(image)
+     elif model_name == "painting":
+         return painting_classification(image)
+     elif model_name == "sketch": # New option
+         return sketch_classification(image)
+     else:
+         return {"Error": "No model selected"}
+
+ # Function to update the selected model and button styles.
+ def select_model(model_name):
+     model_variants = {
+         "gender": "secondary", "emotion": "secondary", "dog breed": "secondary", "deepfake": "secondary",
+         "gym workout": "secondary", "waste": "secondary", "age": "secondary", "mnist": "secondary",
+         "fashion_mnist": "secondary", "food": "secondary", "bird": "secondary", "leaf disease": "secondary",
+         "sign language": "secondary", "traffic density": "secondary", "clip art": "secondary",
+         "multisource": "secondary", "painting": "secondary", "sketch": "secondary" # New model variant
+     }
+     model_variants[model_name] = "primary"
+     return (model_name, *(gr.update(variant=model_variants[key]) for key in model_variants))
+
+ # Zero-Shot Classification Setup (SigLIP models)
+ sg1_ckpt = "google/siglip-so400m-patch14-384"
+ siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="cpu").eval()
+ siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)
+
+ sg2_ckpt = "google/siglip2-so400m-patch14-384"
+ siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="cpu").eval()
+ siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
+
+ def postprocess_siglip(sg1_probs, sg2_probs, labels):
+     sg1_output = {labels[i]: sg1_probs[0][i].item() for i in range(len(labels))}
+     sg2_output = {labels[i]: sg2_probs[0][i].item() for i in range(len(labels))}
+     return sg1_output, sg2_output
+
+ def siglip_detector(image, texts):
+     sg1_inputs = siglip1_processor(
+         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+     ).to("cpu")
+     sg2_inputs = siglip2_processor(
+         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+     ).to("cpu")
+     with torch.no_grad():
+         sg1_outputs = siglip1_model(**sg1_inputs)
+         sg2_outputs = siglip2_model(**sg2_inputs)
+     sg1_logits_per_image = sg1_outputs.logits_per_image
+     sg2_logits_per_image = sg2_outputs.logits_per_image
+     sg1_probs = torch.sigmoid(sg1_logits_per_image)
+     sg2_probs = torch.sigmoid(sg2_logits_per_image)
+     return sg1_probs, sg2_probs
+
+ def infer(image, candidate_labels):
+     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
+     sg1_probs, sg2_probs = siglip_detector(image, candidate_labels)
+     return postprocess_siglip(sg1_probs, sg2_probs, labels=candidate_labels)
+
+ # Build the Gradio Interface with two tabs.
+ with gr.Blocks() as demo:
+     gr.Markdown("# Multi-Domain & Zero-Shot Image Classification")
+
+     with gr.Tabs():
+         # Tab 1: Multi-Model Classification
+         with gr.Tab("Multi-Domain Classification"):
+             with gr.Sidebar():
+                 gr.Markdown("# Choose Domain")
+                 with gr.Row():
+                     age_btn = gr.Button("Age Classification", variant="primary")
+                     gender_btn = gr.Button("Gender Classification", variant="secondary")
+                     emotion_btn = gr.Button("Emotion Classification", variant="secondary")
+                     gym_workout_btn = gr.Button("Gym Workout Classification", variant="secondary")
+                     dog_breed_btn = gr.Button("Dog Breed Classification", variant="secondary")
+                     bird_btn = gr.Button("Bird Species Classification", variant="secondary")
+                     waste_btn = gr.Button("Waste Classification", variant="secondary")
+                     deepfake_btn = gr.Button("Deepfake Quality Test", variant="secondary")
+                     traffic_density_btn = gr.Button("Traffic Density", variant="secondary")
+                     sign_language_btn = gr.Button("Alphabet Sign Language", variant="secondary")
+                     clip_art_btn = gr.Button("Clip Art 126", variant="secondary")
+                     mnist_btn = gr.Button("Digit Classify (0-9)", variant="secondary")
+                     fashion_mnist_btn = gr.Button("Fashion MNIST (only cloth)", variant="secondary")
+                     food_btn = gr.Button("Indian/Western Food Type", variant="secondary")
+                     leaf_disease_btn = gr.Button("Rice Leaf Disease", variant="secondary")
+                     multisource_btn = gr.Button("Multi Source 121", variant="secondary")
+                     painting_btn = gr.Button("Painting 126", variant="secondary")
+                     sketch_btn = gr.Button("Sketch 126", variant="secondary")
+
+                 selected_model = gr.State("age")
+                 gr.Markdown("### Current Model:")
+                 model_display = gr.Textbox(value="age", interactive=False)
+                 selected_model.change(lambda m: m, selected_model, model_display)
+
+                 buttons = [
+                     gender_btn, emotion_btn, dog_breed_btn, deepfake_btn, gym_workout_btn, waste_btn,
+                     age_btn, mnist_btn, fashion_mnist_btn, food_btn, bird_btn, leaf_disease_btn,
+                     sign_language_btn, traffic_density_btn, clip_art_btn, multisource_btn, painting_btn, sketch_btn # Include new button
+                 ]
+                 model_names = [
+                     "gender", "emotion", "dog breed", "deepfake", "gym workout", "waste",
+                     "age", "mnist", "fashion_mnist", "food", "bird", "leaf disease",
+                     "sign language", "traffic density", "clip art", "multisource", "painting", "sketch" # New model name
+                 ]
+
+                 for btn, name in zip(buttons, model_names):
+                     btn.click(fn=lambda n=name: select_model(n), inputs=[], outputs=[selected_model] + buttons)
+
+             with gr.Row():
+                 with gr.Column():
+                     image_input = gr.Image(type="numpy", label="Upload Image")
+                     analyze_btn = gr.Button("Classify / Predict")
+                     output_label = gr.Label(label="Prediction Scores")
+                     analyze_btn.click(fn=classify, inputs=[image_input, selected_model], outputs=output_label)
+
+         # Tab 2: Zero-Shot Classification (SigLIP)
+         with gr.Tab("Zero-Shot Classification"):
+             gr.Markdown("## Compare SigLIP 1 and SigLIP 2 on Zero-Shot Classification")
+             with gr.Row():
+                 with gr.Column():
+                     zs_image_input = gr.Image(type="pil", label="Upload Image")
+                     zs_text_input = gr.Textbox(label="Input a list of labels (comma separated)")
+                     zs_run_button = gr.Button("Run")
+                 with gr.Column():
+                     siglip1_output = gr.Label(label="SigLIP 1 Output", num_top_classes=3)
+                     siglip2_output = gr.Label(label="SigLIP 2 Output", num_top_classes=3)
+             zs_run_button.click(fn=infer, inputs=[zs_image_input, zs_text_input], outputs=[siglip1_output, siglip2_output])
+
  demo.launch()
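
Note on the change: the theme argument of gr.Blocks accepts either a built-in theme object or a "user/repo" string naming a theme hosted on the Hugging Face Hub, which Gradio downloads at startup; dropping the argument reverts the Space to Gradio's default theme. A minimal sketch of a defensive alternative — assuming the YTheme/Minecraft theme repo is still available on the Hub, and assuming (the commit carries no message) that avoiding a failed theme download was the motivation — would keep the custom look while still booting on failure:

import gradio as gr

# Hypothetical fallback, not part of the commit: try to pull the community
# theme from the Hub; use the default theme if the repo is missing or the
# network call fails.
try:
    theme = gr.Theme.from_hub("YTheme/Minecraft")
except Exception:
    theme = gr.themes.Default()

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Multi-Domain & Zero-Shot Image Classification")

demo.launch()

As committed, gr.Blocks() simply uses gr.themes.Default() unconditionally.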