vincentb25 committed
Commit 408074d · 1 Parent(s): c24e460

Added n-class support

Files changed (1): app.py (+98, -32)
app.py CHANGED
@@ -25,7 +25,7 @@ args.cx = '06d75168141bc47f1'
 
 
 # model
-device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = get_model(args)
 model.to(device)
 checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
@@ -57,25 +57,29 @@ def denormalize(x, mean, std):
 
 
 # Gradio UI
-def inference(query, class1_name="class1", support_imgs=None, class2_name="class2", support_imgs2=None):
+def inference(query, *support_text_box_and_files):
     '''
     query: PIL image
-    labels: list of class names
+    support_text_box_and_files: alternating class-name strings and support image file lists
     '''
 
-    #first, open the images
-    support_imgs = [Image.open(img) for img in support_imgs]
-    support_imgs2 = [Image.open(img) for img in support_imgs2]
-
-    labels = [class1_name, class2_name]
+    labels = support_text_box_and_files[0::2]
+    support_images = support_text_box_and_files[1::2]
+
+    print(f"Support images: {support_images}")
+
+    # first, open the images
+    support_images = [[Image.open(img) for img in imgs] for imgs in support_images if imgs is not None]
 
     supp_x = []
     supp_y = []
-
-    for i, (class_name, support_img) in enumerate(zip([class1_name, class2_name], [support_imgs, support_imgs2])):
-        for img in support_img:
+    for i, support_imgs in enumerate(support_images):
+        if len(support_imgs) == 0:
+            continue
+        for img in support_imgs:
             x_im = preprocess(img)
             supp_x.append(x_im)
             supp_y.append(i)
@@ -93,11 +97,14 @@ def inference(query, class1_name="class1", support_imgs=None, class2_name="class
 
     with torch.cuda.amp.autocast(True):
+        start_time = time.time()
         output = model(supp_x, supp_y, query) # (1, 1, n_labels)
+        exec_time = time.time() - start_time
 
     probs = output.softmax(dim=-1).detach().cpu().numpy()
 
-    return {k: float(v) for k, v in zip(labels, probs[0, 0])}
+    return {k: float(v) for k, v in zip(labels, probs[0, 0])}, exec_time
 
 
 # DEBUG
@@ -109,25 +116,84 @@ def inference(query, class1_name="class1", support_imgs=None, class2_name="class
 #print(output)
 
 
-title = "P>M>F few-shot learning pipeline"
+title = "# P>M>F few-shot learning pipeline"
 description = "Short description: We take a ViT-small backbone, which is pre-trained with DINO and meta-trained on Meta-Dataset; for few-shot classification, we use a ProtoNet classifier. The demo can be viewed as zero-shot since the support set is built by searching images from Google. Note that you may need to play with GIS parameters to get good support examples. Besides, GIS is not very stable, as search requests may fail for many reasons (e.g., the number of requests reaches the daily limit). This code is heavily inspired by the original HF space <a href='https://huggingface.co/spaces/hushell/pmf_with_gis' target='_blank'>here</a>"
 article = "<p style='text-align: center'><a href='http://arxiv.org/abs/2204.07305' target='_blank'>Arxiv</a></p>"
-gr.Interface(fn=inference,
-             inputs=[
-                 gr.Image(label="Image to classify", type="pil"),
-                 gr.Textbox(lines=1, label="First class name :", placeholder="Enter first class name",),
-                 gr.File(label="First class example images", file_types=["image"], file_count="multiple"),
-                 gr.Textbox(lines=1, label="Second class name :", placeholder="Enter second class name",),
-                 gr.File(label="Second class example iamges", file_types=["image"], file_count="multiple"),
-             ],
-             theme="grass",
-             outputs=[
-                 gr.Label(label="Predicted class probabilities"),
-             ],
-             title=title,
-             description=description,
-             article=article,
-             examples=[
-                 ["./example_images/2007_000033.jpg", "plane", ["./example_images/2007_000738.jpg", "./example_images/2007_000256.jpg"], "cat", ["./example_images/2007_000528.jpg", "./example_images/2007_000549.jpg"]]
-             ]
-).launch(debug=True)
+
+min_classes = 2
+max_classes = 10
+
+
+def variable_outputs(k):
+    k = int(k)
+    inputs = []
+    for _ in range(k):
+        inputs.append(gr.Textbox(visible=True))
+        inputs.append(gr.File(visible=True))
+
+    for _ in range(max_classes - k):
+        inputs.append(gr.Textbox(visible=False))
+        inputs.append(gr.File(visible=False))
+
+    return inputs
+
+
+with gr.Blocks() as demo:
+
+    with gr.Row():
+        gr.Markdown(title)
+    with gr.Row():
+        gr.Markdown(description)
+    with gr.Row():
+        gr.Markdown(article)
+    with gr.Row():
+        with gr.Column():
+
+            query = gr.Image(label="Image to classify", type="pil")
+            num_classes_slider = gr.Slider(minimum=min_classes, maximum=max_classes, value=2, label="Number of classes", step=1)
+
+            textboxes_and_files = []
+            for i in range(max_classes):
+                is_visible = (i < 2)
+                t = gr.Textbox(label=f"Class {i+1} name", placeholder=f"Enter class {i+1} name", visible=is_visible)
+                textboxes_and_files.append(t)
+                f = gr.File(label=f"Support images for class {i+1}", type="filepath", visible=is_visible, file_count="multiple")
+                textboxes_and_files.append(f)
+
+            num_classes_slider.change(variable_outputs, inputs=[num_classes_slider], outputs=textboxes_and_files)
+
+            run_button = gr.Button("Run Inference")
+
+        with gr.Column():
+            output = gr.Label(label="Predicted class probabilities")
+            exec_time = gr.Textbox(label="Execution time (s)")
+
+    run_button.click(
+        fn=inference,
+        inputs=[query] + textboxes_and_files,
+        outputs=[output, exec_time]
+    )
+
+
+# this does nothing, it seems
+demo.examples = [
+    ["./example_images/2007_000033.jpg", "plane", ["./example_images/2007_000738.jpg", "./example_images/2007_000256.jpg"], "cat", ["./example_images/2007_000528.jpg", "./example_images/2007_000549.jpg"]]
+]
+demo.launch()
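Note on the new variadic signature: Gradio passes the components in `textboxes_and_files` as positional arguments, so `inference` recovers the per-class pairs by slicing even and odd indices. One caveat: `labels` keeps an entry for every textbox, including hidden ones, while `support_images` drops `None` file lists, so `zip(labels, probs[0, 0])` can misalign names and probabilities when an intermediate class has no images. A standalone sketch of the unpacking (the example values below are hypothetical):

```python
# Standalone sketch of the even/odd unpacking used by inference();
# the argument values here are hypothetical.
def split_pairs(*textbox_and_file_values):
    names = textbox_and_file_values[0::2]  # gr.Textbox values
    files = textbox_and_file_values[1::2]  # gr.File values (lists of paths, or None)
    return list(zip(names, files))

print(split_pairs("plane", ["a.jpg"], "cat", ["b.jpg", "c.jpg"]))
# [('plane', ['a.jpg']), ('cat', ['b.jpg', 'c.jpg'])]
```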
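For context on the ProtoNet classifier mentioned in the description: each class prototype is the mean of that class's support embeddings, and the query is scored by its negative distance to the prototypes. A generic sketch of the idea, not this repo's actual implementation:

```python
import torch

# Generic ProtoNet scoring (illustrative, not this repo's code):
# prototypes are per-class means of support embeddings; the query is
# classified by negative squared Euclidean distance to each prototype.
def protonet_logits(supp_emb, supp_y, query_emb, n_classes):
    # supp_emb: (n_support, d), supp_y: (n_support,), query_emb: (d,)
    prototypes = torch.stack([
        supp_emb[supp_y == c].mean(dim=0) for c in range(n_classes)
    ])  # (n_classes, d)
    dists = ((prototypes - query_emb) ** 2).sum(dim=-1)  # (n_classes,)
    return -dists  # softmax over these yields class probabilities
```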
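On the timing added around the forward pass (which assumes `import time` exists elsewhere in app.py): CUDA kernels launch asynchronously, so `time.time()` can under-report GPU execution time unless the device is synchronized. A hedged variant, reusing the `model`, `supp_x`, `supp_y`, `query`, and `device` names from the diff:

```python
import time
import torch

# Sketch: synchronize before and after the forward pass so the measured
# interval covers actual GPU execution, not just kernel launch.
def timed_forward(model, supp_x, supp_y, query, device):
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    with torch.cuda.amp.autocast(True):
        output = model(supp_x, supp_y, query)  # (1, 1, n_labels)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return output, time.time() - start
```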
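`variable_outputs` returns brand-new components on every slider change. Gradio also accepts `gr.update(visible=...)` for the existing components, which avoids re-instantiating them and preserves anything already typed or uploaded. A sketch under the same `max_classes` and (textbox, file) output ordering:

```python
# Sketch: toggle visibility on the existing components rather than
# creating new ones; output order matches textboxes_and_files.
def variable_outputs(k):
    k = int(k)
    updates = []
    for i in range(max_classes):
        visible = i < k
        updates.append(gr.update(visible=visible))  # class-name textbox
        updates.append(gr.update(visible=visible))  # support-image files
    return updates
```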
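On the `# this does nothing, it seems` note: `gr.Blocks` ignores an `examples` attribute assigned after construction; examples are attached with the `gr.Examples` helper inside the Blocks context. A sketch, placed before `demo.launch()` inside the `with gr.Blocks() as demo:` block; the `[:4]` slice covers the first two textbox/file pairs to match the five example values, and multi-file example entries may need adjusting for the installed Gradio version:

```python
# Sketch: attach the example row via gr.Examples inside the Blocks
# context; clicking it pre-fills the query and the first two classes.
gr.Examples(
    examples=[[
        "./example_images/2007_000033.jpg",
        "plane",
        ["./example_images/2007_000738.jpg", "./example_images/2007_000256.jpg"],
        "cat",
        ["./example_images/2007_000528.jpg", "./example_images/2007_000549.jpg"],
    ]],
    inputs=[query] + textboxes_and_files[:4],
)
```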