kernel-luso-comfort committed
Commit 99b73a0 · 1 Parent(s): 699e2ed

Refactor model initialization and prediction logic; enhance mock prediction to handle modality and targets

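For orientation: the diff below replaces the module-level init_model()/predict() helpers with a Model wrapper class. A minimal usage sketch of that interface follows; the app-side wiring, file name, modality, and target prompts are illustrative assumptions, not part of this commit.

from PIL import Image

from inference_utils.init_predict import Model

model = Model()
model.init()  # downloads biomedparse_v1.pt (needs HF_TOKEN) and builds the model on CUDA

image = Image.open("example_scan.png")                     # hypothetical input file
overlay = model.predict(image, "CT", ["tumor", "kidney"])  # example modality and target prompts
overlay.save("prediction_overlay.png")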
inference_utils/init_predict.py CHANGED
@@ -10,15 +10,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from PIL import Image
 from huggingface_hub import hf_hub_download
 import matplotlib.pyplot as plt
 import numpy as np
+import torch
 from inference_utils.inference import interactive_infer_image
-from main import model
-
-
-import gradio as gr
 
 from modeling import build_model
 from modeling.BaseModel import BaseModel
@@ -27,29 +25,40 @@ from utilities.constants import BIOMED_CLASSES
 from utilities.distributed import init_distributed
 
 
-def generate_colors(n):
-    cmap = plt.get_cmap("tab10")
-    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
-    return colors
+class Model:
+    def init(self):
+        self._model = init_model()
 
+    def predict(self, image: Image, modality_type: str, targets: list[str]) -> Image:
+        return predict(self._model, image, targets)
 
-def overlay_masks(image, masks, colors):
-    overlay = image.copy()
-    overlay = np.array(overlay, dtype=np.uint8)
-    for mask, color in zip(masks, colors):
-        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(
-            np.uint8
+
+def init_model():
+    # Download model
+    model_file = hf_hub_download(
+        repo_id="microsoft/BiomedParse",
+        filename="biomedparse_v1.pt",
+        token=os.getenv("HF_TOKEN"),
+    )
+
+    # Initialize model
+    conf_files = "configs/biomedparse_inference.yaml"
+    opt = load_opt_from_config_files([conf_files])
+    opt = init_distributed(opt)
+
+    model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
+    with torch.no_grad():
+        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
+            BIOMED_CLASSES + ["background"], is_eval=True
         )
-    return Image.fromarray(overlay)
+
+    return model
 
 
-def predict(image: gr.Image, prompts: str):
+def predict(model, image: Image, prompts: list[str]):
     if not prompts:
         return None
 
-    # Convert string input to list
-    prompts = [p.strip() for p in prompts.split(",")]
-
     # Convert to RGB if needed
     if image.mode != "RGB":
         image = image.convert("RGB")
@@ -66,23 +75,17 @@ def predict(image: gr.Image, prompts: str):
     return pred_overlay
 
 
-def init_model():
-    # Download model
-    model_file = hf_hub_download(
-        repo_id="microsoft/BiomedParse",
-        filename="biomedparse_v1.pt",
-        token=os.getenv("HF_TOKEN"),
-    )
+def generate_colors(n):
+    cmap = plt.get_cmap("tab10")
+    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
+    return colors
 
-    # Initialize model
-    conf_files = "configs/biomedparse_inference.yaml"
-    opt = load_opt_from_config_files([conf_files])
-    opt = init_distributed(opt)
 
-    model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
-    with torch.no_grad():
-        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
-            BIOMED_CLASSES + ["background"], is_eval=True
+def overlay_masks(image, masks, colors):
+    overlay = image.copy()
+    overlay = np.array(overlay, dtype=np.uint8)
+    for mask, color in zip(masks, colors):
+        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(
+            np.uint8
         )
-
-    return model
+    return Image.fromarray(overlay)
 
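As a standalone sketch of the two helpers now kept at the bottom of init_predict.py, here is how generate_colors and overlay_masks compose; the synthetic image and masks are invented for illustration, and importing them pulls in the full model dependencies.

import numpy as np
from PIL import Image

from inference_utils.init_predict import generate_colors, overlay_masks

image = Image.new("RGB", (128, 128), "black")               # synthetic RGB image
masks = [np.zeros((128, 128), dtype=np.uint8) for _ in range(2)]
masks[0][20:60, 20:60] = 1                                  # fake mask for a first target
masks[1][70:110, 70:110] = 1                                # fake mask for a second target

colors = generate_colors(len(masks))                        # tab10 colors as 0-255 RGB tuples
blended = overlay_masks(image, masks, colors)               # 40% image / 60% color where mask > 0
blended.save("overlay_demo.png")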
inference_utils/init_predict_mock.py CHANGED
@@ -12,48 +12,51 @@
 
 
 from typing import Tuple
-from PIL import Image, ImageDraw, ImageFont
+from PIL import ImageDraw, ImageFont
+from PIL.Image import Image
 import gradio as gr
 import random
 
 
-def init_model():
-    return None
-
-
-def predict(
-    image: Image, modality_type: str, targets: list[str]
-) -> Tuple[gr.Image, str]:
-    # Randomly split targets into found and not found
-    targets_found = random.sample(targets, k=len(targets) // 2)
-    targets_not_found = [t for t in targets if t not in targets_found]
-
-    # Create a copy of the image to draw on
-    image_with_text = image.copy()
-    draw = ImageDraw.Draw(image_with_text)
-
-    # Draw found targets on the image with larger font
-    font_size = 36
-    try:
-        font = ImageFont.truetype("DejaVuSans.ttf", font_size)
-    except OSError:
-        # Fallback to default font if DejaVuSans is not available
-        font = ImageFont.load_default()
-
-    # Calculate starting position from bottom
-    line_height = 50
-    total_height = len(targets_found) * line_height
-    padding = 20
-
-    # Start from bottom and work upwards
-    y_position = image_with_text.height - total_height - padding
-    for target in targets_found:
-        draw.text((20, y_position), target, fill="red", font=font)
-        y_position += line_height
-
-    # Format targets_not_found as a string with one target per line
-    targets_not_found_str = (
-        "\n".join(targets_not_found) if targets_not_found else "All targets were found!"
-    )
-
-    return image_with_text, targets_not_found_str
+class Model:
+    def init(self):
+        pass
+
+    def predict(
+        self, image: Image, modality_type: str, targets: list[str]
+    ) -> Tuple[Image, str]:
+        # Randomly split targets into found and not found
+        targets_found = random.sample(targets, k=len(targets) // 2)
+        targets_not_found = [t for t in targets if t not in targets_found]
+
+        # Create a copy of the image to draw on
+        image_with_text = image.copy()
+        draw = ImageDraw.Draw(image_with_text)
+
+        # Draw found targets on the image with larger font
+        font_size = 36
+        try:
+            font = ImageFont.truetype("DejaVuSans.ttf", font_size)
+        except OSError:
+            # Fallback to default font if DejaVuSans is not available
+            font = ImageFont.load_default()
+
+        # Calculate starting position from bottom
+        line_height = 50
+        total_height = len(targets_found) * line_height
+        padding = 20
+
+        # Start from bottom and work upwards
+        y_position = image_with_text.height - total_height - padding
+        for target in targets_found:
+            draw.text((20, y_position), target, fill="red", font=font)
+            y_position += line_height
+
+        # Format targets_not_found as a string with one target per line
+        targets_not_found_str = (
+            "\n".join(targets_not_found)
+            if targets_not_found
+            else "All targets were found!"
+        )
+
+        return image_with_text, targets_not_found_str
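A quick manual check of the mock interface, assuming the module stays importable as inference_utils.init_predict_mock; the placeholder image and target names are invented for the example.

from PIL import Image
from inference_utils.init_predict_mock import Model

model = Model()
model.init()  # no-op for the mock

scan = Image.new("RGB", (512, 512), "gray")   # placeholder input image
annotated, not_found = model.predict(scan, "CT", ["tumor", "kidney", "liver"])
annotated.save("mock_overlay.png")            # image with the randomly "found" targets drawn in red
print(not_found)                              # remaining targets, one per line, or "All targets were found!"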