Image Feature Extraction
Transformers
Safetensors
dinov2
fepegar committed
Commit 201611f · Parent: b22cb58

Add package code, update docs

.gitignore ADDED
@@ -0,0 +1,17 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ .vscode/
+
+ *.ipynb
+ uv.lock
+ *.txt
+ .python-version
README.md CHANGED
@@ -56,113 +56,43 @@ Fine-tuning RAD-DINO is typically not necessary to obtain good performance in do
  RAD-DINO was trained with data from three countries, therefore it might be biased towards population in the training data.
  Underlying biases of the training datasets may not be well characterized.

- ## Getting started
+ ## Installation

- ### Get some data
-
- Let us first write an auxiliary function to download a chest X-ray.
-
- ```python
- >>> import requests
- >>> from PIL import Image
- >>> def download_sample_image() -> Image.Image:
- ...     """Download chest X-ray with CC license."""
- ...     base_url = "https://upload.wikimedia.org/wikipedia/commons"
- ...     image_url = f"{base_url}/2/20/Chest_X-ray_in_influenza_and_Haemophilus_influenzae.jpg"
- ...     headers = {"User-Agent": "RAD-DINO"}
- ...     response = requests.get(image_url, headers=headers, stream=True)
- ...     return Image.open(response.raw)
- ...
+ ```shell
+ pip install rad-dino
  ```

- ### Load the model
-
- Now let us download the model and encode an image.
-
- ```python
- >>> import torch
- >>> from transformers import AutoModel
- >>> from transformers import AutoImageProcessor
- >>>
- >>> # Download the model
- >>> repo = "microsoft/rad-dino"
- >>> rad_dino = AutoModel.from_pretrained(repo)
- >>>
- >>> # The processor takes a PIL image, performs resizing, center-cropping, and
- >>> # intensity normalization using stats from MIMIC-CXR, and returns a
- >>> # dictionary with a PyTorch tensor ready for the encoder
- >>> processor = AutoImageProcessor.from_pretrained(repo)
- ```
+ ## Usage

  ### Encode an image

  ```python
- >>> # Download and preprocess a chest X-ray
+ >>> from rad_dino import RadDino
+ >>> from rad_dino.utils import download_sample_image
+ >>> encoder = RadDino()
  >>> image = download_sample_image()
- >>> image.size  # (width, height)
- (2765, 2505)
- >>> inputs = processor(images=image, return_tensors="pt")
- >>>
- >>> # Encode the image!
- >>> with torch.inference_mode():
- >>>     outputs = rad_dino(**inputs)
- >>>
- >>> # Look at the CLS embeddings
- >>> cls_embeddings = outputs.pooler_output
- >>> cls_embeddings.shape  # (batch_size, num_channels)
- torch.Size([1, 768])
- ```
-
- If we are interested in the feature maps, we can reshape the patch embeddings into a grid.
- We will use [`einops`](https://einops.rocks/) (install with `pip install einops`) for this.
-
- ```python
- >>> def reshape_patch_embeddings(flat_tokens: torch.Tensor) -> torch.Tensor:
- ...     """Reshape flat list of patch tokens into a nice grid."""
- ...     from einops import rearrange
- ...     image_size = processor.crop_size["height"]
- ...     patch_size = model.config.patch_size
- ...     embeddings_size = image_size // patch_size
- ...     patches_grid = rearrange(flat_tokens, "b (h w) c -> b c h w", h=embeddings_size)
- ...     return patches_grid
- ...
- >>> flat_patch_embeddings = outputs.last_hidden_state[:, 1:]  # first token is CLS
- >>> reshaped_patch_embeddings = reshape_patch_embeddings(flat_patch_embeddings)
- >>> reshaped_patch_embeddings.shape  # (batch_size, num_channels, height, width)
- torch.Size([1, 768, 37, 37])
+ >>> image
+ <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2765x2505 at 0x7CCD5C014050>
+ >>> cls_token, patch_tokens = encoder.extract_features(image)
+ >>> cls_token.shape, patch_tokens.shape
+ (torch.Size([1, 768]), torch.Size([1, 768, 37, 37]))
  ```

  ### Weights for fine-tuning

- We have released a checkpoint compatible with
- [the original DINOv2 code](https://github.com/facebookresearch/dinov2) to help
- researchers fine-tune our model.
-
- First, let us write code to load a
- [`safetensors` checkpoint](https://huggingface.co/docs/safetensors).
-
- ```python
- >>> import safetensors
- >>> def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
- ...     state_dict = {}
- ...     with safe_open(checkpoint_path, framework="pt") as ckpt_file:
- ...         for key in ckpt_file.keys():
- ...             state_dict[key] = ckpt_file.get_tensor(key)
- ...     return state_dict
- ...
- ```
+ We have released a checkpoint compatible with [the original DINOv2 code](https://github.com/facebookresearch/dinov2) to help researchers fine-tune our model.

- We can now use the hub model and load the RAD-DINO weights.
+ We can use the hub model and load the RAD-DINO weights.
  Let's clone the DINOv2 repository so we can import the code for the head.

  ```shell
  git clone https://github.com/facebookresearch/dinov2.git
- cd dinov2
  ```

  ```python
  >>> import torch
- >>> rad_dino_gh = torch.hub.load(".", "dinov2_vitb14")
+ >>> from rad_dino.utils import safetensors_to_state_dict
+ >>> rad_dino_gh = torch.hub.load("./dinov2", "dinov2_vitb14")
  >>> backbone_state_dict = safetensors_to_state_dict("backbone_compatible.safetensors")
  >>> rad_dino_gh.load_state_dict(backbone_state_dict, strict=True)
  <All keys matched successfully>
@@ -272,20 +202,20 @@ We used [SimpleITK](https://simpleitk.org/) and [Pydicom](https://pydicom.github

  ```bibtex
  @article{perez-garcia_exploring_2025,
- title = {Exploring scalable medical image encoders beyond text supervision},
- issn = {2522-5839},
- url = {https://doi.org/10.1038/s42256-024-00965-w},
- doi = {10.1038/s42256-024-00965-w},
- journal = {Nature Machine Intelligence},
- author = {P{\'e}rez-Garc{\'i}a, Fernando and Sharma, Harshita and Bond-Taylor, Sam and Bouzid, Kenza and Salvatelli, Valentina and Ilse, Maximilian and Bannur, Shruthi and Castro, Daniel C. and Schwaighofer, Anton and Lungren, Matthew P. and Wetscherek, Maria Teodora and Codella, Noel and Hyland, Stephanie L. and Alvarez-Valle, Javier and Oktay, Ozan},
- month = jan,
- year = {2025},
+ title = {Exploring scalable medical image encoders beyond text supervision},
+ issn = {2522-5839},
+ url = {https://doi.org/10.1038/s42256-024-00965-w},
+ doi = {10.1038/s42256-024-00965-w},
+ journal = {Nature Machine Intelligence},
+ author = {P{\'e}rez-Garc{\'i}a, Fernando and Sharma, Harshita and Bond-Taylor, Sam and Bouzid, Kenza and Salvatelli, Valentina and Ilse, Maximilian and Bannur, Shruthi and Castro, Daniel C. and Schwaighofer, Anton and Lungren, Matthew P. and Wetscherek, Maria Teodora and Codella, Noel and Hyland, Stephanie L. and Alvarez-Valle, Javier and Oktay, Ozan},
+ month = jan,
+ year = {2025},
  }
  ```

  **APA:**

- > Pérez-García, F., Sharma, H., Bond-Taylor, S., Bouzid, K., Salvatelli, V., Ilse, M., Bannur, S., Castro, D. C., Schwaighofer, A., Lungren, M. P., Wetscherek, M. T., Codella, N., Hyland, S. L., Alvarez-Valle, J., & Oktay, O. (2025). *Exploring scalable medical image encoders beyond text supervision*. In Nature Machine Intelligence. Springer Science and Business Media LLC. https://doi.org/10.1038/s42256-024-00965-w
+ > Pérez-García, F., Sharma, H., Bond-Taylor, S., Bouzid, K., Salvatelli, V., Ilse, M., Bannur, S., Castro, D. C., Schwaighofer, A., Lungren, M. P., Wetscherek, M. T., Codella, N., Hyland, S. L., Alvarez-Valle, J., & Oktay, O. (2025). *Exploring scalable medical image encoders beyond text supervision*. In Nature Machine Intelligence. Springer Science and Business Media LLC. <https://doi.org/10.1038/s42256-024-00965-w>

  ## Model card contact
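The usage snippet added to the README encodes a single image, but `RadDino.extract_features` (defined in `src/rad_dino/__init__.py` below) also accepts a list of PIL images, since `TypeInputImages = Image.Image | list[Image.Image]`. A minimal sketch of batched encoding, assuming the package is installed as per the new Installation section:

```python
from rad_dino import RadDino
from rad_dino.utils import download_sample_image

encoder = RadDino()
image = download_sample_image()

# The processor collates a list of PIL images into a single tensor,
# so the leading dimension of both outputs is the batch size.
cls_tokens, patch_tokens = encoder.extract_features([image, image])
print(cls_tokens.shape)    # torch.Size([2, 768])
print(patch_tokens.shape)  # torch.Size([2, 768, 37, 37])
```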
 
pyproject.toml ADDED
@@ -0,0 +1,26 @@
+ [project]
+ name = "rad-dino"
+ version = "0.1.0"
+ description = "Vision encoder for chest X-rays."
+ readme = "README.md"
+ authors = [{ name = "Microsoft Health Futures" }]
+ requires-python = ">=3.10"
+ dependencies = [
+     "einops",
+     "jaxtyping",
+     "pillow",
+     "requests",
+     "safetensors",
+     "transformers[torch]",
+     "typer>=0.16.0",
+ ]
+
+ [project.scripts]
+ rad-dino = "rad_dino.__main__:main"
+
+ [build-system]
+ requires = ["uv_build>=0.8.3,<0.9.0"]
+ build-backend = "uv_build"
+
+ [dependency-groups]
+ dev = ["ipykernel", "ipywidgets"]
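The version declared here (0.1.0) is mirrored by `__version__` in `src/rad_dino/__init__.py`, which gives a quick post-install smoke test; a minimal sketch:

```python
import rad_dino

# Should print 0.1.0, matching the version field in pyproject.toml.
print(rad_dino.__version__)
```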
src/rad_dino/__init__.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ from einops import rearrange
+ from jaxtyping import Float
+ from PIL import Image
+ from torch import Tensor
+ from torch import nn
+ from transformers import AutoImageProcessor
+ from transformers import AutoModel
+ from transformers.image_processing_base import BatchFeature
+
+
+ __version__ = "0.1.0"
+
+ TypeClsToken = Float[Tensor, "batch_size embed_dim"]
+ TypePatchTokensFlat = Float[Tensor, "batch_size (height width) embed_dim"]
+ TypePatchTokens = Float[Tensor, "batch_size embed_dim height width"]
+ TypeInputImages = Image.Image | list[Image.Image]
+
+
+ class RadDino(nn.Module):
+     _REPO = "microsoft/rad-dino"
+
+     def __init__(self):
+         super().__init__()
+         self.model = AutoModel.from_pretrained(self._REPO).eval()
+         self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
+
+     def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
+         return self.processor(image_or_images, return_tensors="pt")
+
+     def encode(self, inputs: BatchFeature) -> tuple[TypeClsToken, TypePatchTokensFlat]:
+         outputs = self.model(**inputs)
+         cls_token = outputs.last_hidden_state[:, 0]
+         patch_tokens = outputs.last_hidden_state[:, 1:]
+         return cls_token, patch_tokens
+
+     def reshape_patch_tokens(
+         self,
+         patch_tokens_flat: TypePatchTokensFlat,
+     ) -> TypePatchTokens:
+         input_size = self.processor.crop_size["height"]
+         patch_size = self.model.config.patch_size
+         embeddings_size = input_size // patch_size
+         patches_grid = rearrange(
+             patch_tokens_flat,
+             "batch (height width) embed_dim -> batch embed_dim height width",
+             height=embeddings_size,
+         )
+         return patches_grid
+
+     @torch.inference_mode()
+     def extract_features(
+         self,
+         image_or_images: TypeInputImages,
+     ) -> tuple[TypeClsToken, TypePatchTokens]:
+         inputs = self.preprocess(image_or_images)
+         cls_token, patch_tokens_flat = self.encode(inputs)
+         patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
+         return cls_token, patch_tokens
+
+     def extract_cls_token(self, image_or_images: TypeInputImages) -> TypeClsToken:
+         cls_token, _ = self.extract_features(image_or_images)
+         return cls_token
+
+     def extract_patch_tokens(self, image_or_images: TypeInputImages) -> TypePatchTokens:
+         _, patch_tokens = self.extract_features(image_or_images)
+         return patch_tokens
+
+     def forward(self, *args) -> tuple[TypeClsToken, TypePatchTokens]:
+         return self.extract_features(*args)
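When only one kind of token is needed, `extract_cls_token` and `extract_patch_tokens` avoid keeping both around. A sketch of both, plus an illustrative upsampling step (the `interpolate` call is not part of the package; 518 is the input resolution implied by the 37×37 grid and the ViT-B/14 patch size):

```python
import torch
from rad_dino import RadDino
from rad_dino.utils import download_sample_image

encoder = RadDino()
image = download_sample_image()

# Global representation: one 768-dimensional CLS token per image.
cls_token = encoder.extract_cls_token(image)
print(cls_token.shape)  # torch.Size([1, 768])

# Dense representation: patch tokens reshaped into a 37x37 grid.
patch_tokens = encoder.extract_patch_tokens(image)
print(patch_tokens.shape)  # torch.Size([1, 768, 37, 37])

# Illustrative only: upsample the grid back to the input resolution
# (518 = 37 x 14) for dense downstream tasks such as segmentation.
feature_map = torch.nn.functional.interpolate(
    patch_tokens, size=(518, 518), mode="bilinear"
)
print(feature_map.shape)  # torch.Size([1, 768, 518, 518])
```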
src/rad_dino/__main__.py ADDED
@@ -0,0 +1,6 @@
+ def main():
+     print("Hello from rad-dino!")
+
+
+ if __name__ == "__main__":
+     main()
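The entry point is a stub for now, even though `pyproject.toml` already registers it as the `rad-dino` console script and depends on `typer`. A hypothetical sketch of what a typer-based CLI here might grow into (the `encode` command, its arguments, and the output format are assumptions, not part of this commit):

```python
import typer

app = typer.Typer()


@app.command()
def encode(image_path: str, output_path: str = "features.pt") -> None:
    """Hypothetical command: encode one image and save its CLS token."""
    import torch
    from PIL import Image

    from rad_dino import RadDino

    encoder = RadDino()
    image = Image.open(image_path).convert("RGB")
    cls_token = encoder.extract_cls_token(image)
    torch.save(cls_token, output_path)


def main() -> None:
    app()


if __name__ == "__main__":
    main()
```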
src/rad_dino/utils.py ADDED
@@ -0,0 +1,22 @@
+ import requests
+ import torch
+ from PIL import Image
+ from safetensors import safe_open
+
+
+ def download_sample_image() -> Image.Image:
+     """Download chest X-ray with CC license."""
+     base_url = "https://upload.wikimedia.org/wikipedia/commons"
+     path = "2/20/Chest_X-ray_in_influenza_and_Haemophilus_influenzae.jpg"
+     image_url = f"{base_url}/{path}"
+     headers = {"User-Agent": "RAD-DINO"}
+     response = requests.get(image_url, headers=headers, stream=True)
+     return Image.open(response.raw)
+
+
+ def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
+     state_dict = {}
+     with safe_open(checkpoint_path, framework="pt") as ckpt_file:
+         for key in ckpt_file.keys():
+             state_dict[key] = ckpt_file.get_tensor(key)
+     return state_dict
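A save/load round trip with `safetensors.torch.save_file` is an easy way to sanity-check `safetensors_to_state_dict`; a minimal sketch (the file name is arbitrary):

```python
import torch
from safetensors.torch import save_file

from rad_dino.utils import safetensors_to_state_dict

# Save a toy state dict, then read it back with the helper.
original = {"weight": torch.randn(3, 4), "bias": torch.zeros(4)}
save_file(original, "toy.safetensors")

state_dict = safetensors_to_state_dict("toy.safetensors")
assert set(state_dict) == set(original)
assert torch.equal(state_dict["weight"], original["weight"])
```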