File size: 4,954 Bytes
202eff6
 
 
 
 
 
 
 
 
 
 
 
e8983fc
6ba63c9
e8983fc
 
 
 
6ba63c9
6bd0d8c
287d863
 
 
 
 
f50a656
287d863
f50a656
287d863
6ba63c9
7320170
 
 
 
 
 
 
 
4e4c0a1
 
7320170
 
4e4c0a1
 
 
 
 
 
 
 
 
 
 
7320170
 
 
fbf538f
e8983fc
 
 
 
 
 
 
 
 
 
fbf538f
 
e8983fc
 
 
 
 
 
 
 
fbf538f
 
 
 
e8983fc
 
 
 
 
 
 
 
 
 
 
 
 
fbf538f
f50a656
 
fbf538f
 
 
 
 
 
 
 
e8983fc
 
 
 
 
 
 
 
fbf538f
 
 
699e2ed
 
 
fbf538f
e8983fc
 
 
 
 
 
 
 
f50a656
e8983fc
699e2ed
fbf538f
 
 
 
e8983fc
f50a656
 
fbf538f
 
 
 
 
 
e8983fc
 
 
 
 
 
 
 
fa26a4b
 
6ba63c9
287d863
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from pathlib import Path
from typing import Dict, List

from inference_utils.target_dist import modality_targets_from_target_dist

# When the DEV_MODE environment variable is set to any non-empty value, the
# mock Model from inference_utils.model_mock is imported instead of the real
# inference_utils.model.Model (so init() and predict() are mocked).
DEV_MODE = bool(os.getenv("DEV_MODE"))

import gradio as gr

if DEV_MODE:
    from inference_utils.model_mock import Model
else:
    from inference_utils.model import Model


# Let Gradio serve files from the "assets" directory; the description below
# embeds assets/gpt4_ontology_hierarchy.png through a "file/" URL.
# NOTE(review): the relative path presumably resolves against the current
# working directory, i.e. the app is expected to be launched from the repo
# root — confirm.
gr.set_static_paths(["assets"])


# Markdown rendered under the page title in run(). The "file/assets/..." image
# link only resolves because "assets" is registered via gr.set_static_paths.
description = """Upload a biomedical image and enter prompts (separated by commas) to detect specific features.

The model understands these prompts:
![gpt4_ontology_hierarchy.png](file/assets/gpt4_ontology_hierarchy.png)

Above figure is from the [BiomedParse paper](https://arxiv.org/abs/2405.12971).

The model understands these types of biomedical images:

- [Computed Tomography (CT)](https://en.wikipedia.org/wiki/Computed_tomography)
- [Magnetic Resonance Imaging (MRI)](https://en.wikipedia.org/wiki/Magnetic_resonance_imaging)
- [X-ray](https://en.wikipedia.org/wiki/X-ray)
- [Medical Ultrasound](https://en.wikipedia.org/wiki/Medical_ultrasound)
- [Pathology](https://en.wikipedia.org/wiki/Pathology)
- [Fundus (eye)](https://en.wikipedia.org/wiki/Fundus_(eye))
- [Dermoscopy](https://en.wikipedia.org/wiki/Dermoscopy)
- [Endoscopy](https://en.wikipedia.org/wiki/Endoscopy)
- [Optical Coherence Tomography (OCT)](https://en.wikipedia.org/wiki/Optical_coherence_tomography)

This Space is based on the [BiomedParse model](https://microsoft.github.io/BiomedParse/).
"""


# Example rows for gr.Examples in run(): [image path, modality type, targets].
# The column order mirrors the `inputs` components; every row starts with no
# targets selected (each row gets its own fresh empty list).
examples = [
    [image_path, modality_type, []]
    for image_path, modality_type in (
        ("examples/144DME_as_F.jpeg", "OCT"),
        ("examples/C3_EndoCV2021_00462.jpg", "Endoscopy"),
        ("examples/CT-abdomen.png", "CT-Abdomen"),
        ("examples/covid_1585.png", "X-Ray-Chest"),
        ("examples/ISIC_0015551.jpg", "Dermoscopy"),
        ("examples/LIDC-IDRI-0140_143_280_CT_lung.png", "CT-Chest"),
        ("examples/Part_1_516_pathology_breast.png", "Pathology"),
        ("examples/T0011.jpg", "Fundus"),
        ("examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png", "MRI-FLAIR-Brain"),
    )
]


def load_modality_targets(
    target_dist_json_path: Path = Path("inference_utils/target_dist.json"),
) -> Dict[str, List[str]]:
    """Load the modality -> targets mapping from the target distribution JSON.

    Args:
        target_dist_json_path: Path (or path string) of the target-distribution
            JSON file. The default is relative to the current working
            directory, like the other relative paths used by this app
            ("assets", "examples/...").

    Returns:
        Whatever ``modality_targets_from_target_dist`` builds from the parsed
        JSON; per the annotation, presumably modality name -> list of target
        names (confirm against that helper).
    """
    # JSON text is UTF-8 by specification (RFC 8259); read it that way instead
    # of relying on the platform's locale encoding (e.g. cp1252 on Windows).
    with Path(target_dist_json_path).open("r", encoding="utf-8") as f:
        target_dist = json.load(f)

    return modality_targets_from_target_dist(target_dist)


# Modality name -> selectable targets, loaded once at import time. The keys
# populate the modality dropdown and the values the Targets checkbox group
# (see run() and update_input_targets()).
MODALITY_TARGETS = load_modality_targets()
# Modality pre-selected in the dropdown. It must be a key of MODALITY_TARGETS,
# because run() indexes MODALITY_TARGETS[DEFAULT_MODALITY] directly.
DEFAULT_MODALITY = "CT-Abdomen"


def run():
    """Build the BiomedParse Gradio interface.

    Creates and initialises the Model, lays out the input and output
    components, and wires the modality dropdown, the Submit button and the
    example gallery to their callbacks.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    predictor = Model()
    predictor.init()

    with gr.Blocks() as app:
        gr.Markdown("# BiomedParse Demo")
        gr.Markdown(description)

        with gr.Row():
            # Left column: everything the user provides.
            with gr.Column():
                image_in = gr.Image(type="pil", label="Input Image")
                modality_dropdown = gr.Dropdown(
                    choices=list(MODALITY_TARGETS),
                    label="Modality Type",
                    value=DEFAULT_MODALITY,
                )
                targets_checkbox = gr.CheckboxGroup(
                    choices=MODALITY_TARGETS[DEFAULT_MODALITY],
                    label="Targets",
                )
            # Right column: model outputs.
            with gr.Column():
                image_out = gr.Image(type="pil", label="Prediction")
                missing_targets_box = gr.Textbox(
                    label="Targets Not Found", lines=4, max_lines=10
                )

        # Changing the modality swaps the list of selectable targets.
        modality_dropdown.change(
            fn=update_input_targets,
            inputs=modality_dropdown,
            outputs=targets_checkbox,
        )

        # predict() is fed (image, modality, targets) and fills both outputs;
        # each consumer below gets its own fresh list of these components.
        predict_inputs = (image_in, modality_dropdown, targets_checkbox)
        predict_outputs = (image_out, missing_targets_box)

        gr.Button("Submit").click(
            fn=predictor.predict,
            inputs=list(predict_inputs),
            outputs=list(predict_outputs),
        )

        gr.Examples(
            examples=examples,
            inputs=list(predict_inputs),
            outputs=list(predict_outputs),
            fn=predictor.predict,
            cache_examples=False,
        )

    return app


def update_input_targets(input_modality_type):
    """Rebuild the Targets checkbox group for a newly selected modality.

    Registered in run() as the ``change`` callback of the modality dropdown.
    Previously checked targets are cleared, because they may not exist for the
    new modality.

    Args:
        input_modality_type: The selected modality name, normally a key of
            MODALITY_TARGETS. NOTE(review): may be None if the dropdown can be
            cleared in the Gradio version in use — confirm.

    Returns:
        A ``gr.CheckboxGroup`` carrying the new choices and an empty selection.
        An unknown or missing modality yields no choices instead of raising
        ``KeyError`` inside the event handler.
    """
    return gr.CheckboxGroup(
        choices=MODALITY_TARGETS.get(input_modality_type, []),
        value=[],
        label="Targets",
    )


# Build the app at import time so the module exposes a top-level `demo`.
# NOTE(review): presumably this is for hosting runtimes (e.g. a Space or the
# Gradio reloader) that import this file and look up `demo` instead of running
# it as a script — confirm against the deployment setup.
demo = run()

if __name__ == "__main__":
    # When executed directly, listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)