Add package code, update docs
- .gitignore +17 -0
- README.md +25 -95
- pyproject.toml +26 -0
- src/rad_dino/__init__.py +70 -0
- src/rad_dino/__main__.py +6 -0
- src/rad_dino/utils.py +22 -0
.gitignore
ADDED
@@ -0,0 +1,17 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
+
+.vscode/
+
+*.ipynb
+uv.lock
+*.txt
+.python-version
README.md
CHANGED
@@ -56,113 +56,43 @@ Fine-tuning RAD-DINO is typically not necessary to obtain good performance in do
 RAD-DINO was trained with data from three countries, therefore it might be biased towards population in the training data.
 Underlying biases of the training datasets may not be well characterized.
 
-## Getting started
+## Installation
 
-
-
-Let us first write an auxiliary function to download a chest X-ray.
-
-```python
->>> import requests
->>> from PIL import Image
->>> def download_sample_image() -> Image.Image:
-...     """Download chest X-ray with CC license."""
-...     base_url = "https://upload.wikimedia.org/wikipedia/commons"
-...     image_url = f"{base_url}/2/20/Chest_X-ray_in_influenza_and_Haemophilus_influenzae.jpg"
-...     headers = {"User-Agent": "RAD-DINO"}
-...     response = requests.get(image_url, headers=headers, stream=True)
-...     return Image.open(response.raw)
-...
+```shell
+pip install rad-dino
 ```
 
-
-
-Now let us download the model and encode an image.
-
-```python
->>> import torch
->>> from transformers import AutoModel
->>> from transformers import AutoImageProcessor
->>>
->>> # Download the model
->>> repo = "microsoft/rad-dino"
->>> rad_dino = AutoModel.from_pretrained(repo)
->>>
->>> # The processor takes a PIL image, performs resizing, center-cropping, and
->>> # intensity normalization using stats from MIMIC-CXR, and returns a
->>> # dictionary with a PyTorch tensor ready for the encoder
->>> processor = AutoImageProcessor.from_pretrained(repo)
-```
+## Usage
 
 ### Encode an image
 
 ```python
->>>
+>>> from rad_dino import RadDino
+>>> from rad_dino.utils import download_sample_image
+>>> encoder = RadDino()
 >>> image = download_sample_image()
->>> image
-
->>>
->>> inputs = processor(images=image, return_tensors="pt")
-
->>> with torch.inference_mode():
->>>     outputs = rad_dino(**inputs)
->>>
->>> # Look at the CLS embeddings
->>> cls_embeddings = outputs.pooler_output
->>> cls_embeddings.shape # (batch_size, num_channels)
-torch.Size([1, 768])
-```
-
-If we are interested in the feature maps, we can reshape the patch embeddings into a grid.
-We will use [`einops`](https://einops.rocks/) (install with `pip install einops`) for this.
-
-```python
->>> def reshape_patch_embeddings(flat_tokens: torch.Tensor) -> torch.Tensor:
-...     """Reshape flat list of patch tokens into a nice grid."""
-...     from einops import rearrange
-...     image_size = processor.crop_size["height"]
-...     patch_size = model.config.patch_size
-...     embeddings_size = image_size // patch_size
-...     patches_grid = rearrange(flat_tokens, "b (h w) c -> b c h w", h=embeddings_size)
-...     return patches_grid
-...
->>> flat_patch_embeddings = outputs.last_hidden_state[:, 1:] # first token is CLS
->>> reshaped_patch_embeddings = reshape_patch_embeddings(flat_patch_embeddings)
->>> reshaped_patch_embeddings.shape # (batch_size, num_channels, height, width)
-torch.Size([1, 768, 37, 37])
+>>> image
+<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2765x2505 at 0x7CCD5C014050>
+>>> cls_token, patch_tokens = encoder.extract_features(image)
+>>> cls_token.shape, patch_tokens.shape
+(torch.Size([1, 768]), torch.Size([1, 768, 37, 37]))
 ```
 
 ### Weights for fine-tuning
 
-We have released a checkpoint compatible with
-[the original DINOv2 code](https://github.com/facebookresearch/dinov2) to help
-researchers fine-tune our model.
-
-First, let us write code to load a
-[`safetensors` checkpoint](https://huggingface.co/docs/safetensors).
-
-```python
->>> import safetensors
->>> def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
-...     state_dict = {}
-...     with safe_open(checkpoint_path, framework="pt") as ckpt_file:
-...         for key in ckpt_file.keys():
-...             state_dict[key] = ckpt_file.get_tensor(key)
-...     return state_dict
-...
-```
+We have released a checkpoint compatible with [the original DINOv2 code](https://github.com/facebookresearch/dinov2) to help researchers fine-tune our model.
 
-We can
+We can use the hub model and load the RAD-DINO weights.
 Let's clone the DINOv2 repository so we can import the code for the head.
 
 ```shell
 git clone https://github.com/facebookresearch/dinov2.git
-cd dinov2
 ```
 
 ```python
 >>> import torch
->>>
+>>> from rad_dino.utils import safetensors_to_state_dict
+>>> rad_dino_gh = torch.hub.load("./dinov2", "dinov2_vitb14")
 >>> backbone_state_dict = safetensors_to_state_dict("backbone_compatible.safetensors")
 >>> rad_dino_gh.load_state_dict(backbone_state_dict, strict=True)
 <All keys matched successfully>
@@ -272,20 +202,20 @@ We used [SimpleITK](https://simpleitk.org/) and [Pydicom](https://pydicom.github
 
 ```bibtex
 @article{perez-garcia_exploring_2025,
-
-
-
-
-
-
-
-
+    title = {Exploring scalable medical image encoders beyond text supervision},
+    issn = {2522-5839},
+    url = {https://doi.org/10.1038/s42256-024-00965-w},
+    doi = {10.1038/s42256-024-00965-w},
+    journal = {Nature Machine Intelligence},
+    author = {P{\'e}rez-Garc{\'i}a, Fernando and Sharma, Harshita and Bond-Taylor, Sam and Bouzid, Kenza and Salvatelli, Valentina and Ilse, Maximilian and Bannur, Shruthi and Castro, Daniel C. and Schwaighofer, Anton and Lungren, Matthew P. and Wetscherek, Maria Teodora and Codella, Noel and Hyland, Stephanie L. and Alvarez-Valle, Javier and Oktay, Ozan},
+    month = jan,
+    year = {2025},
 }
 ```
 
 **APA:**
 
-> Pérez-García, F., Sharma, H., Bond-Taylor, S., Bouzid, K., Salvatelli, V., Ilse, M., Bannur, S., Castro, D. C., Schwaighofer, A., Lungren, M. P., Wetscherek, M. T., Codella, N., Hyland, S. L., Alvarez-Valle, J., & Oktay, O. (2025). *Exploring scalable medical image encoders beyond text supervision*. In Nature Machine Intelligence. Springer Science and Business Media LLC. https://doi.org/10.1038/s42256-024-00965-w
+> Pérez-García, F., Sharma, H., Bond-Taylor, S., Bouzid, K., Salvatelli, V., Ilse, M., Bannur, S., Castro, D. C., Schwaighofer, A., Lungren, M. P., Wetscherek, M. T., Codella, N., Hyland, S. L., Alvarez-Valle, J., & Oktay, O. (2025). *Exploring scalable medical image encoders beyond text supervision*. In Nature Machine Intelligence. Springer Science and Business Media LLC. <https://doi.org/10.1038/s42256-024-00965-w>
 
 ## Model card contact
 
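One thing the new README snippet does not show: `TypeInputImages` in `src/rad_dino/__init__.py` (below) also accepts a list of PIL images, so the same call batches transparently. A minimal sketch under that assumption; the output shapes are extrapolated from the single-image example, not verified here:

```python
>>> from rad_dino import RadDino
>>> from rad_dino.utils import download_sample_image
>>> encoder = RadDino()
>>> # A list of PIL images is preprocessed into a single batch tensor
>>> images = [download_sample_image(), download_sample_image()]
>>> cls_token, patch_tokens = encoder.extract_features(images)
>>> cls_token.shape, patch_tokens.shape  # batch_size == len(images)
(torch.Size([2, 768]), torch.Size([2, 768, 37, 37]))
```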
pyproject.toml
ADDED
@@ -0,0 +1,26 @@
+[project]
+name = "rad-dino"
+version = "0.1.0"
+description = "Vision encoder for chest X-rays."
+readme = "README.md"
+authors = [{ name = "Microsoft Health Futures" }]
+requires-python = ">=3.10"
+dependencies = [
+    "einops",
+    "jaxtyping",
+    "pillow",
+    "requests",
+    "safetensors",
+    "transformers[torch]",
+    "typer>=0.16.0",
+]
+
+[project.scripts]
+rad-dino = "rad_dino.__main__:main"
+
+[build-system]
+requires = ["uv_build>=0.8.3,<0.9.0"]
+build-backend = "uv_build"
+
+[dependency-groups]
+dev = ["ipykernel", "ipywidgets"]
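The `[project.scripts]` table wires a `rad-dino` console command to `rad_dino.__main__:main`, which, per the stub in `src/rad_dino/__main__.py` below, only prints a greeting for now. A quick sanity check, assuming the wheel is built and installed as declared:

```shell
pip install rad-dino
rad-dino
# Hello from rad-dino!
```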
src/rad_dino/__init__.py
ADDED
@@ -0,0 +1,70 @@
+import torch
+from einops import rearrange
+from jaxtyping import Float
+from PIL import Image
+from torch import Tensor
+from torch import nn
+from transformers import AutoImageProcessor
+from transformers import AutoModel
+from transformers.image_processing_base import BatchFeature
+
+
+__version__ = "0.1.0"
+
+TypeClsToken = Float[Tensor, "batch_size embed_dim"]
+TypePatchTokensFlat = Float[Tensor, "batch_size (height width) embed_dim"]
+TypePatchTokens = Float[Tensor, "batch_size embed_dim height width"]
+TypeInputImages = Image.Image | list[Image.Image]
+
+
+class RadDino(nn.Module):
+    _REPO = "microsoft/rad-dino"
+
+    def __init__(self):
+        super().__init__()
+        self.model = AutoModel.from_pretrained(self._REPO).eval()
+        self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)
+
+    def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
+        return self.processor(image_or_images, return_tensors="pt")
+
+    def encode(self, inputs: BatchFeature) -> tuple[TypeClsToken, TypePatchTokensFlat]:
+        outputs = self.model(**inputs)
+        cls_token = outputs.last_hidden_state[:, 0]
+        patch_tokens = outputs.last_hidden_state[:, 1:]
+        return cls_token, patch_tokens
+
+    def reshape_patch_tokens(
+        self,
+        patch_tokens_flat: TypePatchTokensFlat,
+    ) -> TypePatchTokens:
+        input_size = self.processor.crop_size["height"]
+        patch_size = self.model.config.patch_size
+        embeddings_size = input_size // patch_size
+        patches_grid = rearrange(
+            patch_tokens_flat,
+            "batch (height width) embed_dim -> batch embed_dim height width",
+            height=embeddings_size,
+        )
+        return patches_grid
+
+    @torch.inference_mode()
+    def extract_features(
+        self,
+        image_or_images: TypeInputImages,
+    ) -> tuple[TypeClsToken, TypePatchTokens]:
+        inputs = self.preprocess(image_or_images)
+        cls_token, patch_tokens_flat = self.encode(inputs)
+        patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
+        return cls_token, patch_tokens
+
+    def extract_cls_token(self, image_or_images: TypeInputImages) -> TypeClsToken:
+        cls_token, _ = self.extract_features(image_or_images)
+        return cls_token
+
+    def extract_patch_tokens(self, image_or_images: TypeInputImages) -> TypePatchTokens:
+        _, patch_tokens = self.extract_features(image_or_images)
+        return patch_tokens
+
+    def forward(self, *args) -> tuple[TypeClsToken, TypePatchTokens]:
+        return self.extract_features(*args)
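Besides `extract_features`, the class exposes `extract_cls_token` and `extract_patch_tokens` for callers that need only one kind of embedding; note each convenience method runs a full forward pass, so call `extract_features` once when both outputs are needed. A short sketch, assuming the weights download from the hub succeeds:

```python
from rad_dino import RadDino
from rad_dino.utils import download_sample_image

encoder = RadDino()
image = download_sample_image()

# Image-level embedding, e.g. for retrieval or linear probing
cls_token = encoder.extract_cls_token(image)  # (1, 768)

# Spatial feature map, e.g. as input to a segmentation head
patch_tokens = encoder.extract_patch_tokens(image)  # (1, 768, 37, 37)
```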
src/rad_dino/__main__.py
ADDED
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from rad-dino!")
+
+
+if __name__ == "__main__":
+    main()
|
src/rad_dino/utils.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from safetensors import safe_open
|
5 |
+
|
6 |
+
|
7 |
+
def download_sample_image() -> Image.Image:
|
8 |
+
"""Download chest X-ray with CC license."""
|
9 |
+
base_url = "https://upload.wikimedia.org/wikipedia/commons"
|
10 |
+
path = "2/20/Chest_X-ray_in_influenza_and_Haemophilus_influenzae.jpg"
|
11 |
+
image_url = f"{base_url}/{path}"
|
12 |
+
headers = {"User-Agent": "RAD-DINO"}
|
13 |
+
response = requests.get(image_url, headers=headers, stream=True)
|
14 |
+
return Image.open(response.raw)
|
15 |
+
|
16 |
+
|
17 |
+
def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
|
18 |
+
state_dict = {}
|
19 |
+
with safe_open(checkpoint_path, framework="pt") as ckpt_file:
|
20 |
+
for key in ckpt_file.keys():
|
21 |
+
state_dict[key] = ckpt_file.get_tensor(key)
|
22 |
+
return state_dict
|
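`safetensors_to_state_dict` supports the fine-tuning recipe in the README: it eagerly materializes every tensor in a `safetensors` file into a plain `dict` that `load_state_dict` accepts. A sketch for inspecting a converted checkpoint, assuming `backbone_compatible.safetensors` (the filename used in the README diff above) has been downloaded:

```python
from rad_dino.utils import safetensors_to_state_dict

state_dict = safetensors_to_state_dict("backbone_compatible.safetensors")
# Peek at a few parameter names and shapes before loading into a DINOv2 backbone
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```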