AleksanderObuchowski
commited on
Commit
•
1640642
1
Parent(s):
5f10544
Add files using upload-large-folder tool
Browse files- .gitignore +10 -0
- .python-version +1 -0
- README.md +0 -3
- example.py +44 -0
- flask_app.py +57 -0
- medimageinsightmodel.py +239 -0
- pyproject.toml +29 -0
- requirements.txt +18 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python-generated files
|
2 |
+
__pycache__/
|
3 |
+
*.py[oc]
|
4 |
+
build/
|
5 |
+
dist/
|
6 |
+
wheels/
|
7 |
+
*.egg-info
|
8 |
+
|
9 |
+
# Virtual environments
|
10 |
+
.venv
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.8.19
|
README.md
CHANGED
@@ -1,3 +0,0 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
---
|
|
|
|
|
|
|
|
example.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Initialize classifier
|
2 |
+
from medimageinsightmodel import MedImageInsight
|
3 |
+
import base64
|
4 |
+
|
5 |
+
|
6 |
+
classifier = MedImageInsight(
|
7 |
+
model_dir="2024.09.27",
|
8 |
+
vision_model_name="medimageinsigt-v1.0.0.pt",
|
9 |
+
language_model_name="language_model.pth"
|
10 |
+
)
|
11 |
+
|
12 |
+
def read_image(image_path):
|
13 |
+
with open(image_path, "rb") as f:
|
14 |
+
return f.read()
|
15 |
+
|
16 |
+
# Load model
|
17 |
+
classifier.load_model()
|
18 |
+
|
19 |
+
import urllib.request
|
20 |
+
|
21 |
+
image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
|
22 |
+
image_path = "CXR145_IM-0290-1001.png"
|
23 |
+
|
24 |
+
urllib.request.urlretrieve(image_url, image_path)
|
25 |
+
print(f"Image downloaded to {image_path}")
|
26 |
+
|
27 |
+
|
28 |
+
image = base64.encodebytes(read_image(image_path)).decode("utf-8")
|
29 |
+
|
30 |
+
# Example inference
|
31 |
+
images = [image]
|
32 |
+
labels = ["normal", "Pneumonia", "unclear"]
|
33 |
+
|
34 |
+
#Zero-shot classification
|
35 |
+
results = classifier.predict(images, labels)
|
36 |
+
print(results)
|
37 |
+
|
38 |
+
#Image embeddings
|
39 |
+
results = classifier.encode(images = images)
|
40 |
+
print(results)
|
41 |
+
|
42 |
+
#Text embeddings
|
43 |
+
results = classifier.encode(texts = labels)
|
44 |
+
print(results)
|
flask_app.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import List
|
4 |
+
import uvicorn
|
5 |
+
from medimageinsightmodel import MedImageInsight
|
6 |
+
import base64
|
7 |
+
|
8 |
+
# Initialize FastAPI app
|
9 |
+
app = FastAPI(title="Medical Image Analysis API")
|
10 |
+
|
11 |
+
# Initialize model
|
12 |
+
classifier = MedImageInsight(
|
13 |
+
model_dir="2024.09.27",
|
14 |
+
vision_model_name="medimageinsigt-v1.0.0.pt",
|
15 |
+
language_model_name="language_model.pth"
|
16 |
+
)
|
17 |
+
classifier.load_model()
|
18 |
+
|
19 |
+
|
20 |
+
class ClassificationRequest(BaseModel):
|
21 |
+
images: List[str] # Base64 encoded images
|
22 |
+
labels: List[str]
|
23 |
+
multilabel : bool = False
|
24 |
+
|
25 |
+
class EmbeddingRequest(BaseModel):
|
26 |
+
images: List[str] = None # Base64 encoded images
|
27 |
+
texts: List[str] = None
|
28 |
+
|
29 |
+
@app.post("/predict")
|
30 |
+
async def predict(request: ClassificationRequest):
|
31 |
+
try:
|
32 |
+
results = classifier.predict(
|
33 |
+
images=request.images,
|
34 |
+
labels=request.labels,
|
35 |
+
multilabel = request.multilabel
|
36 |
+
)
|
37 |
+
return {"predictions": results}
|
38 |
+
except Exception as e:
|
39 |
+
raise HTTPException(status_code=500, detail=str(e))
|
40 |
+
|
41 |
+
@app.post("/encode")
|
42 |
+
async def encode(request: EmbeddingRequest):
|
43 |
+
try:
|
44 |
+
results = classifier.encode(images=request.images, texts= request.texts)
|
45 |
+
results["image_embeddings"] = results["image_embeddings"].tolist() if results["image_embeddings"] is not None else None
|
46 |
+
results["text_embeddings"] = results["text_embeddings"].tolist() if results["text_embeddings"] is not None else None
|
47 |
+
|
48 |
+
return results
|
49 |
+
except Exception as e:
|
50 |
+
raise HTTPException(status_code=500, detail=str(e))
|
51 |
+
|
52 |
+
@app.get("/health")
|
53 |
+
async def health():
|
54 |
+
return {"status": "healthy"}
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
medimageinsightmodel.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Medical Image Classification model wrapper class that loads the model, preprocesses inputs and performs inference."""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import pandas as pd
|
6 |
+
from typing import List, Tuple
|
7 |
+
import os
|
8 |
+
import tempfile
|
9 |
+
import base64
|
10 |
+
import io
|
11 |
+
|
12 |
+
from MedImageInsight.UniCLModel import build_unicl_model
|
13 |
+
from MedImageInsight.Utils.Arguments import load_opt_from_config_files
|
14 |
+
from MedImageInsight.ImageDataLoader import build_transforms
|
15 |
+
from MedImageInsight.LangEncoder import build_tokenizer
|
16 |
+
|
17 |
+
|
18 |
+
class MedImageInsight:
|
19 |
+
"""Wrapper class for medical image classification model."""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
model_dir: str,
|
24 |
+
vision_model_name: str,
|
25 |
+
language_model_name: str
|
26 |
+
) -> None:
|
27 |
+
"""Initialize the medical image classifier.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
model_dir: Directory containing model files and config
|
31 |
+
vision_model_name: Name of the vision model
|
32 |
+
language_model_name: Name of the language model
|
33 |
+
"""
|
34 |
+
self.model_dir = model_dir
|
35 |
+
self.vision_model_name = vision_model_name
|
36 |
+
self.language_model_name = language_model_name
|
37 |
+
self.model = None
|
38 |
+
self.device = None
|
39 |
+
self.tokenize = None
|
40 |
+
self.preprocess = None
|
41 |
+
self.opt = None
|
42 |
+
|
43 |
+
def load_model(self) -> None:
|
44 |
+
"""Load the model and necessary components."""
|
45 |
+
try:
|
46 |
+
# Load configuration
|
47 |
+
config_path = os.path.join(self.model_dir, 'config.yaml')
|
48 |
+
self.opt = load_opt_from_config_files([config_path])
|
49 |
+
|
50 |
+
# Set paths
|
51 |
+
self.opt['LANG_ENCODER']['PRETRAINED_TOKENIZER'] = os.path.join(
|
52 |
+
self.model_dir,
|
53 |
+
'language_model',
|
54 |
+
'clip_tokenizer_4.16.2'
|
55 |
+
)
|
56 |
+
self.opt['UNICL_MODEL']['PRETRAINED'] = os.path.join(
|
57 |
+
self.model_dir,
|
58 |
+
'vision_model',
|
59 |
+
self.vision_model_name
|
60 |
+
)
|
61 |
+
|
62 |
+
# Initialize components
|
63 |
+
self.preprocess = build_transforms(self.opt, False)
|
64 |
+
self.model = build_unicl_model(self.opt)
|
65 |
+
|
66 |
+
# Set device
|
67 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
68 |
+
self.model.to(self.device)
|
69 |
+
|
70 |
+
# Load tokenizer
|
71 |
+
self.tokenize = build_tokenizer(self.opt['LANG_ENCODER'])
|
72 |
+
self.max_length = self.opt['LANG_ENCODER']['CONTEXT_LENGTH']
|
73 |
+
|
74 |
+
print(f"Model loaded successfully on device: {self.device}")
|
75 |
+
|
76 |
+
except Exception as e:
|
77 |
+
print("Failed to load the model:")
|
78 |
+
raise e
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def decode_base64_image(base64_str: str) -> Image.Image:
|
82 |
+
"""Decode base64 string to PIL Image and ensure RGB format.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
base64_str: Base64 encoded image string
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
PIL Image object in RGB format
|
89 |
+
"""
|
90 |
+
try:
|
91 |
+
# Remove header if present
|
92 |
+
if ',' in base64_str:
|
93 |
+
base64_str = base64_str.split(',')[1]
|
94 |
+
|
95 |
+
image_bytes = base64.b64decode(base64_str)
|
96 |
+
image = Image.open(io.BytesIO(image_bytes))
|
97 |
+
|
98 |
+
# Convert grayscale (L) or grayscale with alpha (LA) to RGB
|
99 |
+
if image.mode in ('L', 'LA'):
|
100 |
+
image = image.convert('RGB')
|
101 |
+
|
102 |
+
return image
|
103 |
+
except Exception as e:
|
104 |
+
raise ValueError(f"Failed to decode base64 image: {str(e)}")
|
105 |
+
|
106 |
+
def predict(self, images: List[str], labels: List[str], multilabel: bool = False) -> List[dict]:
|
107 |
+
"""Perform zero shot classification on the input images.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
images: List of base64 encoded image strings
|
111 |
+
labels: List of candidate labels for classification
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
DataFrame with columns ["probabilities", "labels"]
|
115 |
+
"""
|
116 |
+
if not self.model:
|
117 |
+
raise RuntimeError("Model not loaded. Call load_model() first.")
|
118 |
+
|
119 |
+
if not labels:
|
120 |
+
raise ValueError("No labels provided")
|
121 |
+
|
122 |
+
# Create temporary directory for processing
|
123 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
124 |
+
# Process images
|
125 |
+
image_list = []
|
126 |
+
for img_base64 in images:
|
127 |
+
try:
|
128 |
+
img = self.decode_base64_image(img_base64)
|
129 |
+
image_list.append(img)
|
130 |
+
except Exception as e:
|
131 |
+
raise ValueError(f"Failed to process image: {str(e)}")
|
132 |
+
|
133 |
+
# Run inference
|
134 |
+
probs = self.run_inference_batch(image_list, labels, multilabel)
|
135 |
+
probs_np = probs.cpu().numpy()
|
136 |
+
results = []
|
137 |
+
for prob_row in probs_np:
|
138 |
+
# Create label-prob pairs and sort by probability
|
139 |
+
label_probs = [(label, float(prob)) for label, prob in zip(labels, prob_row)]
|
140 |
+
label_probs.sort(key=lambda x: x[1], reverse=True)
|
141 |
+
|
142 |
+
# Create ordered dictionary from sorted pairs
|
143 |
+
results.append({
|
144 |
+
label: prob
|
145 |
+
for label, prob in label_probs
|
146 |
+
})
|
147 |
+
|
148 |
+
return results
|
149 |
+
|
150 |
+
def encode(self, images: List[str] = None, texts: List[str] = None):
|
151 |
+
|
152 |
+
output = {
|
153 |
+
"image_embeddings" : None,
|
154 |
+
"text_embeddings" : None,
|
155 |
+
}
|
156 |
+
|
157 |
+
if not self.model:
|
158 |
+
raise RuntimeError("Model not loaded. Call load_model() first.")
|
159 |
+
|
160 |
+
if not images and not texts:
|
161 |
+
raise ValueError("You must provide either images or texts")
|
162 |
+
|
163 |
+
if images is not None:
|
164 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
165 |
+
# Process images
|
166 |
+
image_list = []
|
167 |
+
for img_base64 in images:
|
168 |
+
try:
|
169 |
+
img = self.decode_base64_image(img_base64)
|
170 |
+
image_list.append(img)
|
171 |
+
except Exception as e:
|
172 |
+
raise ValueError(f"Failed to process image: {str(e)}")
|
173 |
+
images = torch.stack([self.preprocess(img) for img in image_list]).to(self.device)
|
174 |
+
with torch.no_grad():
|
175 |
+
output["image_embeddings"] = self.model.encode_image(images).cpu().numpy()
|
176 |
+
|
177 |
+
if texts is not None:
|
178 |
+
text_tokens = self.tokenize(
|
179 |
+
texts,
|
180 |
+
padding='max_length',
|
181 |
+
max_length=self.max_length,
|
182 |
+
truncation=True,
|
183 |
+
return_tensors='pt'
|
184 |
+
)
|
185 |
+
|
186 |
+
# Move text tensors to the correct device
|
187 |
+
text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()}
|
188 |
+
output["text_embeddings"] = self.model.encode_text(text_tokens).cpu().numpy()
|
189 |
+
|
190 |
+
|
191 |
+
return output
|
192 |
+
|
193 |
+
def run_inference_batch(
|
194 |
+
self,
|
195 |
+
images: List[Image.Image],
|
196 |
+
texts: List[str],
|
197 |
+
multilabel: bool = False
|
198 |
+
) -> torch.Tensor:
|
199 |
+
"""Perform inference on batch of input images.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
images: List of PIL Image objects
|
203 |
+
texts: List of text labels
|
204 |
+
multilabel: If True, use sigmoid for multilabel classification.
|
205 |
+
If False, use softmax for single-label classification.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
Tensor of prediction probabilities
|
209 |
+
"""
|
210 |
+
# Prepare inputs
|
211 |
+
images = torch.stack([self.preprocess(img) for img in images]).to(self.device)
|
212 |
+
|
213 |
+
# Process text
|
214 |
+
text_tokens = self.tokenize(
|
215 |
+
texts,
|
216 |
+
padding='max_length',
|
217 |
+
max_length=self.max_length,
|
218 |
+
truncation=True,
|
219 |
+
return_tensors='pt'
|
220 |
+
)
|
221 |
+
|
222 |
+
# Move text tensors to the correct device
|
223 |
+
text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()}
|
224 |
+
|
225 |
+
# Run inference
|
226 |
+
with torch.no_grad():
|
227 |
+
outputs = self.model(image=images, text=text_tokens)
|
228 |
+
logits_per_image = outputs[0] @ outputs[1].t() * outputs[2]
|
229 |
+
|
230 |
+
if multilabel:
|
231 |
+
# Use sigmoid for independent probabilities per label
|
232 |
+
probs = torch.sigmoid(logits_per_image)
|
233 |
+
else:
|
234 |
+
# Use softmax for single-label classification
|
235 |
+
probs = logits_per_image.softmax(dim=1)
|
236 |
+
|
237 |
+
return probs
|
238 |
+
|
239 |
+
|
pyproject.toml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "MedImageInsights"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Add your description here"
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = "==3.8.19"
|
7 |
+
dependencies = [
|
8 |
+
"mlflow==2.14.3",
|
9 |
+
"cffi==1.17.1",
|
10 |
+
"cloudpickle==3.0.0",
|
11 |
+
"colorama==0.4.6",
|
12 |
+
"einops==0.8.0",
|
13 |
+
"ftfy==6.2.3",
|
14 |
+
"fvcore==0.1.5.post20221221",
|
15 |
+
"mup==1.0.0",
|
16 |
+
"numpy==1.24.4",
|
17 |
+
"packaging==24.1",
|
18 |
+
"pandas==2.0.3",
|
19 |
+
"pyyaml==6.0.2",
|
20 |
+
"requests==2.32.3",
|
21 |
+
"sentencepiece==0.2.0",
|
22 |
+
"tenacity==9.0.0",
|
23 |
+
"timm==1.0.9",
|
24 |
+
"tornado==6.4.1",
|
25 |
+
"transformers==4.46.0",
|
26 |
+
# "huggingface-hub==0.26.1",
|
27 |
+
"fastapi[standard]>=0.115.3",
|
28 |
+
# "opencv-python>=4.10.0.84",
|
29 |
+
]
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mlflow==2.14.3
|
2 |
+
cffi==1.17.1
|
3 |
+
cloudpickle==3.0.0
|
4 |
+
colorama==0.4.6
|
5 |
+
einops==0.8.0
|
6 |
+
ftfy==6.2.3
|
7 |
+
fvcore==0.1.5.post20221221
|
8 |
+
mup==1.0.0
|
9 |
+
numpy==1.24.4
|
10 |
+
packaging==24.1
|
11 |
+
pandas==2.0.3
|
12 |
+
pyyaml==6.0.2
|
13 |
+
requests==2.32.3
|
14 |
+
sentencepiece==0.2.0
|
15 |
+
tenacity==9.0.0
|
16 |
+
timm==1.0.9
|
17 |
+
tornado==6.4.1
|
18 |
+
transformers==4.16.2
|
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|