gheinrich committed on
Commit
271fef8
·
verified ·
1 Parent(s): 8b56af0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -0
app.py CHANGED
@@ -12,3 +12,164 @@ model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
12
  model.eval().cuda()
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  model.eval().cuda()
13
 
14
 
15
# ---------------------------------------------------------------------------
# Static content and component lists handed to the Gradio interface.
# ---------------------------------------------------------------------------

# Heading displayed at the top of the demo page.
title = "RADIO: Reduce All Domains Into One"

# Markdown body displayed under the heading (model summary + usage hint).
description = """
# RADIO

AM-RADIO is a framework to distill Large Vision Foundation models into a single one.
RADIO, a new vision foundation model, excels across visual domains, serving as a superior replacement for vision backbones.
Integrating CLIP variants, DINOv2, and SAM through distillation, it preserves unique features like text grounding and segmentation correspondence.
Outperforming teachers in ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear probing segmentation (+3.8%) and vision-language models (LLaVa 1.5 up to 1.5%), it scales to any resolution, supports non-square images.

# Instructions

Simply paste an image or pick one from the gallery of examples and then click the "Submit" button.
"""

# Single input component: the uploaded picture, delivered as a PIL image.
inputs = [gr.Image(type="pil")]

# Example pictures offered in the gallery; every file is named IMG_<id>.jpeg.
examples = [
    f"IMG_{image_id}.jpeg"
    for image_id in (
        "0996",
        "1061",
        "1338",
        "4319",
        "5104",
        "5139",
        "6225",
        "6814",
        "7459",
        "7577",
        "7687",
        "9862",
    )
]

# Two output components: the feature-tensor shape (text) and the PCA image.
outputs = [gr.Textbox(label="Feature Shape"), gr.Image()]
52
+
53
def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    """Fit a 3-component PCA projection and outlier-robust per-channel bounds.

    Args:
        features: 2-D tensor of shape (N, C), one feature vector per row.
        m: Inlier threshold. A projected value is an inlier for its channel
            when its absolute deviation from the channel median is less than
            ``m`` times the channel's median absolute deviation.
        remove_first_component: When true, the projection is re-fitted using
            only the rows whose min-max normalised first component is below
            0.2, and the bounds are computed from those rows only.

    Returns:
        ``(reduction_mat, rgb_min, rgb_max)`` where ``reduction_mat`` is the
        projection matrix taken from ``torch.pca_lowrank`` (q=3) and
        ``rgb_min`` / ``rgb_max`` are length-3 tensors, cast to the dtype and
        device of ``reduction_mat``, holding the min / max inlier value of
        each projected channel.

    Notes:
        If a channel has no inliers (for example its median absolute deviation
        is zero, so every ratio is NaN or inf), that channel's bounds fall
        back to the min / max over *all* rows of that channel. Other channels
        keep their robust bounds.
    """
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        # Keep only rows at the low end of the first component and re-fit.
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
        fg_colors = colors[fg_mask]
    else:
        fg_colors = colors

    # Deviation of every value from its channel median, in units of the
    # channel's median absolute deviation.
    d = torch.abs(fg_colors - torch.median(fg_colors, dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    # NaN and inf ratios compare False here, so they are never inliers.
    inliers = s < m

    lows = []
    highs = []
    for channel in range(fg_colors.shape[1]):
        values = fg_colors[inliers[:, channel], channel]
        if values.numel() == 0:
            # No inliers for this channel: use its full range instead.
            values = colors[:, channel]
        lows.append(values.min())
        highs.append(values.max())

    rgb_min = torch.stack(lows)
    rgb_max = torch.stack(highs)
    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
85
+
86
+
87
def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """Render the feature map of a single image as a 3-channel PCA color image.

    Args:
        feature_map: Features of one image, either (h, w, C) or (1, h, w, C).
        img_size: Target ``(height, width)`` passed to ``F.interpolate``.
        interpolation: Resampling mode passed to ``F.interpolate``.
        return_pca_stats: When true, also return the PCA statistics used.
        pca_stats: Optional ``(reduct_mat, color_min, color_max)`` to reuse;
            when ``None`` they are computed with ``get_robust_pca``.

    Returns:
        A NumPy array of shape ``(*img_size, 3)`` with values in [0, 1], or
        ``(array, (reduct_mat, color_min, color_max))`` when
        ``return_pca_stats`` is true.
    """
    # Decide on rank rather than on the leading size, so that an unbatched
    # map with a single row of patches, shape (1, w, C), is handled too.
    if feature_map.ndim == 3:
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1])
        )
    else:
        reduct_mat, color_min, color_max = pca_stats

    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    # A channel with color_max == color_min yields 0/0 = NaN; map those to 0
    # so the result stays a valid image (inf values are handled by the clamp).
    pca_color = torch.nan_to_num(pca_color, nan=0.0).clamp(0, 1)
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    # Bicubic resampling can overshoot the input range, so clamp once more.
    pca_color = pca_color.clamp(0, 1)
    pca_color = pca_color.cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
118
+
119
+
120
def pad_image_to_multiple_of_16(image, multiple=16):
    """Pad an image with black borders so both sides are multiples of ``multiple``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and accepted
            by ``ImageOps.expand``.
        multiple: Positive integer each side must be divisible by. Defaults
            to 16, matching the function name.

    Returns:
        The padded image. The padding is split as evenly as possible between
        opposite sides, with any odd pixel going to the right / bottom. When
        no padding is required the input image itself is returned.

    Raises:
        ValueError: If ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")

    width, height = image.size
    # Pixels missing to reach the next multiple (0 when already aligned).
    pad_width = -width % multiple
    pad_height = -height % multiple
    if pad_width == 0 and pad_height == 0:
        return image

    left = pad_width // 2
    top = pad_height // 2
    border = (left, top, pad_width - left, pad_height - top)
    return ImageOps.expand(image, border, fill='black')
139
+
140
+
141
@spaces.GPU
def infer_radio(image):
    """Run the model on one image and return its feature shape and PCA map.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A tuple ``(shape_text, pca_viz)``: the string form of the rearranged
        feature tensor's shape, and the array produced by ``get_pca_map``
        at the padded image's size.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Submitting with an empty input passes None; report it to the user
        # instead of failing on the attribute access below.
        raise gr.Error("Please provide an image.")

    image = pad_image_to_multiple_of_16(image)
    width, height = image.size
    pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
    pixel_values = pixel_values.to(torch.bfloat16).cuda()

    # Inference only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        _, features = model(pixel_values)

    # NOTE(review): the image is padded to a multiple of 16 while the grid is
    # derived from model.patch_size, and this assumes image_processor keeps
    # the padded resolution — confirm both against the loaded model.
    num_rows = height // model.patch_size
    num_cols = width // model.patch_size

    features = features.detach()
    features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()

    pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')

    return f"{features.shape}", pca_viz
161
+
162
+
163
# Wire the inference callback, components and page text into one interface.
demo = gr.Interface(
    title=title,
    description=description,
    fn=infer_radio,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
175
+