kernel-luso-comfort commited on
Commit
e8983fc
·
1 Parent(s): fbf538f

Refactor predict function to accept modality type and targets; update main.py to integrate modality selection and target loading

Browse files
Files changed (2) hide show
  1. inference_utils/init_predict_mock.py +1 -1
  2. main.py +62 -22
inference_utils/init_predict_mock.py CHANGED
@@ -15,5 +15,5 @@ def init_model():
15
  return None
16
 
17
 
18
- def predict(image, prompts):
19
  return image
 
15
  return None
16
 
17
 
18
+ def predict(image, modality_type: str, targets: list[str]):
19
  return image
main.py CHANGED
@@ -10,7 +10,12 @@
10
  # See the License for the specific language governing permissions and
11
  # limitations under the License.
12
 
 
13
  import os
 
 
 
 
14
 
15
  # If True, then mock init_model() and predict() functions will be used.
16
  DEV_MODE = True if os.getenv("DEV_MODE") else False
@@ -50,27 +55,43 @@ This Space is based on the [BiomedParse model](https://microsoft.github.io/Biome
50
 
51
 
52
  examples = [
53
- ["examples/144DME_as_F.jpeg", "edema"],
54
- ["examples/C3_EndoCV2021_00462.jpg", "polyp"],
55
- ["examples/CT-abdomen.png", "liver, pancreas, spleen"],
56
- ["examples/covid_1585.png", "left lung"],
57
- ["examples/covid_1585.png", "right lung"],
58
- ["examples/covid_1585.png", "COVID-19 infection"],
59
- ["examples/ISIC_0015551.jpg", "lesion"],
60
- ["examples/LIDC-IDRI-0140_143_280_CT_lung.png", "lung nodule"],
61
- ["examples/LIDC-IDRI-0140_143_280_CT_lung.png", "COVID-19 infection"],
62
- ["examples/Part_1_516_pathology_breast.png", "connective tissue cells"],
63
- ["examples/Part_1_516_pathology_breast.png", "neoplastic cells"],
64
  [
65
  "examples/Part_1_516_pathology_breast.png",
66
- "neoplastic cells, inflammatory cells",
 
 
 
 
 
 
 
67
  ],
68
- ["examples/T0011.jpg", "optic disc"],
69
- ["examples/T0011.jpg", "optic cup"],
70
- ["examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png", "glioma"],
71
  ]
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def run():
75
  global model
76
  model = init_model()
@@ -82,23 +103,34 @@ def run():
82
  with gr.Row():
83
  with gr.Column():
84
  input_image = gr.Image(type="pil", label="Input Image")
85
- input_text = gr.Textbox(
86
- label="Prompts",
87
- placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
 
 
 
 
 
88
  )
89
  with gr.Column():
90
  output_image = gr.Image(type="pil", label="Prediction")
91
 
92
- predict_btn = gr.Button("Submit")
93
- predict_btn.click(
 
 
 
 
 
 
94
  fn=predict,
95
- inputs=[input_image, input_text],
96
  outputs=output_image,
97
  )
98
 
99
  gr.Examples(
100
  examples=examples,
101
- inputs=[input_image, input_text],
102
  outputs=output_image,
103
  fn=predict,
104
  cache_examples=False,
@@ -107,6 +139,14 @@ def run():
107
  return demo
108
 
109
 
 
 
 
 
 
 
 
 
110
  demo = run()
111
 
112
  if __name__ == "__main__":
 
10
  # See the License for the specific language governing permissions and
11
  # limitations under the License.
12
 
13
+ import json
14
  import os
15
+ from pathlib import Path
16
+ from typing import Dict, List
17
+
18
+ from inference_utils.target_dist import modality_targets_from_target_dist
19
 
20
  # If True, then mock init_model() and predict() functions will be used.
21
  DEV_MODE = True if os.getenv("DEV_MODE") else False
 
55
 
56
 
57
  examples = [
58
+ ["examples/144DME_as_F.jpeg", "OCT", []],
59
+ ["examples/C3_EndoCV2021_00462.jpg", "Endoscopy", []],
60
+ ["examples/CT-abdomen.png", "CT-Abdomen", []],
61
+ ["examples/covid_1585.png", "X-Ray-Chest", []],
62
+ ["examples/ISIC_0015551.jpg", "Dermoscopy", []],
63
+ [
64
+ "examples/LIDC-IDRI-0140_143_280_CT_lung.png",
65
+ "CT-Chest",
66
+ [],
67
+ ],
 
68
  [
69
  "examples/Part_1_516_pathology_breast.png",
70
+ "Pathology",
71
+ [],
72
+ ],
73
+ ["examples/T0011.jpg", "Fundus", []],
74
+ [
75
+ "examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png",
76
+ "MRI-FLAIR-Brain",
77
+ [],
78
  ],
 
 
 
79
  ]
80
 
81
 
82
+ def load_modality_targets() -> Dict[str, List[str]]:
83
+ target_dist_json_path = Path("inference_utils/target_dist.json")
84
+ with open(target_dist_json_path, "r") as f:
85
+ target_dist = json.load(f)
86
+
87
+ modality_targets = modality_targets_from_target_dist(target_dist)
88
+ return modality_targets
89
+
90
+
91
+ MODALITY_TARGETS = load_modality_targets()
92
+ DEFAULT_MODALITY = "CT-Abdomen"
93
+
94
+
95
  def run():
96
  global model
97
  model = init_model()
 
103
  with gr.Row():
104
  with gr.Column():
105
  input_image = gr.Image(type="pil", label="Input Image")
106
+ input_modality_type = gr.Dropdown(
107
+ choices=list(MODALITY_TARGETS.keys()),
108
+ label="Modality Type",
109
+ value=DEFAULT_MODALITY,
110
+ )
111
+ input_targets = gr.CheckboxGroup(
112
+ choices=MODALITY_TARGETS[DEFAULT_MODALITY],
113
+ label="Targets",
114
  )
115
  with gr.Column():
116
  output_image = gr.Image(type="pil", label="Prediction")
117
 
118
+ input_modality_type.change(
119
+ fn=update_input_targets,
120
+ inputs=input_modality_type,
121
+ outputs=input_targets,
122
+ )
123
+
124
+ submit_btn = gr.Button("Submit")
125
+ submit_btn.click(
126
  fn=predict,
127
+ inputs=[input_image, input_modality_type, input_targets],
128
  outputs=output_image,
129
  )
130
 
131
  gr.Examples(
132
  examples=examples,
133
+ inputs=[input_image, input_modality_type, input_targets],
134
  outputs=output_image,
135
  fn=predict,
136
  cache_examples=False,
 
139
  return demo
140
 
141
 
142
+ def update_input_targets(input_modality_type):
143
+ return gr.CheckboxGroup(
144
+ choices=MODALITY_TARGETS[input_modality_type],
145
+ value=[],
146
+ label="Targets",
147
+ )
148
+
149
+
150
  demo = run()
151
 
152
  if __name__ == "__main__":