Commit 99b73a0
1 Parent(s): 699e2ed

Refactor model initialization and prediction logic; enhance mock prediction to handle modality and targets

Files changed:
- inference_utils/init_predict.py (+39 -36)
- inference_utils/init_predict_mock.py (+44 -41)
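The refactor moves model setup out of import time and behind a small Model wrapper: constructing it downloads the BiomedParse checkpoint (using HF_TOKEN from the environment) and builds the network once, and predict(image, modality_type, targets) runs inference and returns an overlay image. A minimal caller-side sketch of the new interface; the image path, modality, and target names below are illustrative assumptions, not part of this commit:

from PIL import Image
from inference_utils.init_predict import Model

model = Model()                          # downloads biomedparse_v1.pt and builds the network once
image = Image.open("example_scan.png")   # hypothetical input image
overlay = model.predict(image, modality_type="CT", targets=["liver", "tumor"])
overlay.save("prediction_overlay.png")   # predict() returns a PIL overlay image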
inference_utils/init_predict.py
CHANGED
@@ -10,15 +10,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from PIL import Image
 from huggingface_hub import hf_hub_download
 import matplotlib.pyplot as plt
 import numpy as np
+import torch
 from inference_utils.inference import interactive_infer_image
-from main import model
-
-
-import gradio as gr
 
 from modeling import build_model
 from modeling.BaseModel import BaseModel
@@ -27,29 +25,40 @@ from utilities.constants import BIOMED_CLASSES
 from utilities.distributed import init_distributed
 
 
-    return colors
+class Model:
+    def __init__(self):
+        self._model = init_model()
+
+    def predict(self, image: Image, modality_type: str, targets: list[str]) -> Image:
+        return predict(self._model, image, targets)
+
+
+def init_model():
+    # Download model
+    model_file = hf_hub_download(
+        repo_id="microsoft/BiomedParse",
+        filename="biomedparse_v1.pt",
+        token=os.getenv("HF_TOKEN"),
+    )
+
+    # Initialize model
+    conf_files = "configs/biomedparse_inference.yaml"
+    opt = load_opt_from_config_files([conf_files])
+    opt = init_distributed(opt)
+
+    model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
+    with torch.no_grad():
+        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
+            BIOMED_CLASSES + ["background"], is_eval=True
+        )
+
+    return model
 
 
-def predict(image: gr.Image, prompts: str):
+def predict(model, image: Image, prompts: list[str]):
     if not prompts:
         return None
 
-    # Convert string input to list
-    prompts = [p.strip() for p in prompts.split(",")]
-
     # Convert to RGB if needed
     if image.mode != "RGB":
         image = image.convert("RGB")
@@ -66,23 +75,17 @@ def predict(image: gr.Image, prompts: str):
     return pred_overlay
 
 
-        filename="biomedparse_v1.pt",
-        token=os.getenv("HF_TOKEN"),
-    )
-
-    # Initialize model
-    conf_files = "configs/biomedparse_inference.yaml"
-    opt = load_opt_from_config_files([conf_files])
-    opt = init_distributed(opt)
-
-    return model
+def generate_colors(n):
+    cmap = plt.get_cmap("tab10")
+    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
+    return colors
 
 
+def overlay_masks(image, masks, colors):
+    overlay = image.copy()
+    overlay = np.array(overlay, dtype=np.uint8)
+    for mask, color in zip(masks, colors):
+        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(
+            np.uint8
+        )
+    return Image.fromarray(overlay)
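The two helpers added at the bottom of the file are self-contained: generate_colors picks n distinct colors from matplotlib's tab10 palette, and overlay_masks alpha-blends binary masks onto the image (40% image, 60% color). A small sketch of how they compose, assuming the repository's dependencies are installed; the dummy image, masks, and output filename are illustrative, not real model output:

import numpy as np
from PIL import Image
from inference_utils.init_predict import generate_colors, overlay_masks

# Dummy 64x64 grey image and two fake binary masks, purely for illustration.
image = Image.new("RGB", (64, 64), color=(200, 200, 200))
masks = [np.zeros((64, 64), dtype=np.uint8) for _ in range(2)]
masks[0][10:30, 10:30] = 1
masks[1][35:55, 35:55] = 1

colors = generate_colors(len(masks))           # one tab10 color per mask
overlay = overlay_masks(image, masks, colors)  # new PIL image with blended masks
overlay.save("overlay_demo.png")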
inference_utils/init_predict_mock.py
CHANGED
@@ -12,48 +12,51 @@
 
 
 from typing import Tuple
+from PIL import ImageDraw, ImageFont
+from PIL.Image import Image
 import gradio as gr
 import random
 
 
+class Model:
+    def __init__(self):
+        pass
+
+    def predict(
+        image: Image, modality_type: str, targets: list[str]
+    ) -> Tuple[Image, str]:
+        # Randomly split targets into found and not found
+        targets_found = random.sample(targets, k=len(targets) // 2)
+        targets_not_found = [t for t in targets if t not in targets_found]
+
+        # Create a copy of the image to draw on
+        image_with_text = image.copy()
+        draw = ImageDraw.Draw(image_with_text)
+
+        # Draw found targets on the image with larger font
+        font_size = 36
+        try:
+            font = ImageFont.truetype("DejaVuSans.ttf", font_size)
+        except OSError:
+            # Fallback to default font if DejaVuSans is not available
+            font = ImageFont.load_default()
+
+        # Calculate starting position from bottom
+        line_height = 50
+        total_height = len(targets_found) * line_height
+        padding = 20
+
+        # Start from bottom and work upwards
+        y_position = image_with_text.height - total_height - padding
+        for target in targets_found:
+            draw.text((20, y_position), target, fill="red", font=font)
+            y_position += line_height
+
+        # Format targets_not_found as a string with one target per line
+        targets_not_found_str = (
+            "\n".join(targets_not_found)
+            if targets_not_found
+            else "All targets were found!"
+        )
+
+        return image_with_text, targets_not_found_str
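The mock runs no model: random.sample marks half of the requested targets (rounded down) as "found" and draws their names onto the image, while the rest come back as the not-found text. Unlike the real Model.predict in init_predict.py, which returns only the overlay image, the mock returns an (image, text) tuple. The split itself is easy to see in isolation; the target names here are illustrative:

import random

targets = ["liver", "tumor", "cyst"]                           # hypothetical targets
targets_found = random.sample(targets, k=len(targets) // 2)    # e.g. ["tumor"]
targets_not_found = [t for t in targets if t not in targets_found]
print(targets_found, targets_not_found)                        # one drawn on the image, two reported as missing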