import os
import base64

import streamlit as st
import torch
from PIL import Image
from dotenv import load_dotenv
from groq import Groq
from langchain_groq import ChatGroq
from langchain.chains import LLMMathChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.utilities import WikipediaAPIWrapper
from langchain.agents.agent_types import AgentType
from langchain.agents import Tool, initialize_agent
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
import open_clip  # requires open-clip-torch>=2.23.0, timm>=0.9.8

# Load BiomedCLIP and its preprocessing transforms/tokenizer from the Hugging Face Hub
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
)
tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")

# st.set_page_config must be the first Streamlit command in the script
st.set_page_config(page_title="Medical Bot", page_icon="👨‍🔬")
st.title("Medical Bot")

if not groq_api_key:
    st.error("Groq API Key not found in .env file")
    st.stop()

llm_text = ChatGroq(model="gemma2-9b-it", groq_api_key=groq_api_key)
llm_image = ChatGroq(model="llama-3.2-90b-vision-preview", groq_api_key=groq_api_key)

# Wikipedia lookup tool for background information
wikipedia_wrapper = WikipediaAPIWrapper()
wikipedia_tool = Tool(
    name="Wikipedia",
    func=wikipedia_wrapper.run,
    description="A tool for searching Wikipedia to find information on the topics mentioned."
)
# Calculator tool backed by LLMMathChain; expects a bare mathematical expression
math_chain = LLMMathChain.from_llm(llm=llm_text)
calculator = Tool(
    name="Calculator",
    func=math_chain.run,
    description="A tool for solving mathematical problems. Provide only the mathematical expression."
)

prompt = """
You are a mathematical problem-solving assistant tasked with helping users solve their questions. Arrive at the solution logically, providing a clear and step-by-step explanation. Present your response in a structured point-wise format for better understanding.
Question: {question}
Answer:
"""

prompt_template = PromptTemplate(
    input_variables=["question"],
    template=prompt
)
# Chain the reasoning prompt with the text LLM
chain = LLMChain(llm=llm_text, prompt=prompt_template)

reasoning_tool = Tool(
    name="Reasoning Tool",
    func=chain.run,
    description="A tool for answering logic-based and reasoning questions."
)

def classify_image(image_path: str) -> str:
    """Classifies a medical image into a modality category using BiomedCLIP."""
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device).eval()

    # Open and preprocess the image with the model's validation transform
    image = preprocess_val(Image.open(image_path)).unsqueeze(0).to(device)
    labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
    texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)

    # Zero-shot classification: rank the labels by image-text similarity
    with torch.no_grad():
        image_features, text_features, logit_scale = model(image, texts)
        logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
        sorted_indices = torch.argsort(logits, dim=-1, descending=True)

    top_class = labels[sorted_indices[0][0].item()]
    return f"The image is classified as {top_class}."

# Wrap BiomedCLIP as a LangChain tool
biomed_clip_tool = Tool(
    name="BiomedCLIP Image Classifier",
    func=classify_image,
    description="Classifies a medical image, given its file path, into categories like MRI, X-ray, histopathology, etc."
)
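
# Example (hypothetical local path): classify_image("samples/chest_xray.png")
# would return something like "The image is classified as X-ray."
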
# Initialize the ReAct agent used for text questions
assistant_agent_text = initialize_agent(
    tools=[wikipedia_tool, calculator, reasoning_tool, biomed_clip_tool],
    llm=llm_text,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
    handle_parsing_errors=True
)

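# Seed the chat history with a greeting on first load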
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "assistant", "content": "Welcome! I am your Assistant. How can I help you today?"}
    ]

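# Replay the conversation so far, rendering any attached images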
for msg in st.session_state.messages:
    if msg["role"] == "user" and "image" in msg:
        st.chat_message(msg["role"]).write(msg['content'])
        st.image(msg["image"], caption='Uploaded Image', use_column_width=True)
    else:
        st.chat_message(msg["role"]).write(msg['content'])

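# Sidebar navigation between the text and image workflows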
st.sidebar.header("Navigation")
if st.sidebar.button("Text Question"):
    st.session_state["section"] = "text"
if st.sidebar.button("Image Question"):
    st.session_state["section"] = "image"

if "section" not in st.session_state:
    st.session_state["section"] = "text"

def clean_response(response):
    """Strips a surrounding Markdown code fence from the model's response, if present."""
    if "```" in response:
        response = response.split("```")[1].strip()
    return response
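
# e.g. clean_response("```\nFinal answer\n```") returns "Final answer"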

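# Text workflow: route the question through the LangChain agent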
if st.session_state["section"] == "text":
    st.header("Text Question")
    st.write("Please enter your question below, and I will provide a detailed description of the problem and suggest a solution for it.")
    question = st.text_area("Your Question:")
    if st.button("Get Answer"):
        if question:
            with st.spinner("Generating response..."):
                st.session_state.messages.append({"role": "user", "content": question})
                st.chat_message("user").write(question)

                # Stream the agent's intermediate reasoning into the Streamlit UI
                st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=False)
                try:
                    response = assistant_agent_text.run(question, callbacks=[st_cb])
                    cleaned_response = clean_response(response)
                    st.session_state.messages.append({"role": "assistant", "content": cleaned_response})
                    st.write('### Response:')
                    st.success(cleaned_response)
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter a question to get an answer.")

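# Image workflow: send the question and image straight to the Groq vision model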
elif st.session_state["section"] == "image":
    st.header("Image Question")
    st.write("Please enter your question below and upload the medical image. I will provide a detailed description of the problem and suggest a solution for it.")
    question = st.text_area("Your Question:", "Example: What is the patient suffering from?")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    if st.button("Get Answer"):
        if question and uploaded_file is not None:
            with st.spinner("Generating response..."):
                image_data = uploaded_file.read()
                # Build a data URL using the file's actual MIME type (jpg/jpeg/png)
                mime_type = uploaded_file.type or "image/jpeg"
                image_data_url = f"data:{mime_type};base64,{base64.b64encode(image_data).decode()}"
                st.session_state.messages.append({"role": "user", "content": question, "image": image_data})
                st.chat_message("user").write(question)
                st.image(image_data, caption='Uploaded Image', use_container_width=True)

                client = Groq(api_key=groq_api_key)

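                # Multimodal payload: the text question plus the image as a base64 data URL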
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": question
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": image_data_url
                                }
                            }
                        ]
                    }
                ]
                try:
                    completion = client.chat.completions.create(
                        model="llama-3.2-90b-vision-preview",
                        messages=messages,
                        temperature=1,
                        max_tokens=1024,
                        top_p=1,
                        stream=False,
                        stop=None,
                    )
                    response = completion.choices[0].message.content
                    cleaned_response = clean_response(response)
                    st.session_state.messages.append({"role": "assistant", "content": cleaned_response})
                    st.write('### Response:')
                    st.success(cleaned_response)
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter a question and upload an image to get an answer.")