mMonika commited on
Commit
7f1ecff
·
verified ·
1 Parent(s): a242e19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -113
app.py CHANGED
@@ -1,7 +1,3 @@
1
- # from langgraph.graph import Graph
2
- # from langchain_groq import ChatGroq
3
- # llm = langchain_groq(model="llama3-70b-8192")
4
- # llm.invoke("hi how are you")
5
  import streamlit as st
6
  import os
7
  import base64
@@ -15,60 +11,34 @@ from langchain.agents import Tool, initialize_agent
15
  from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
16
  from groq import Groq
17
  import open_clip
18
- from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
 
19
 
20
- model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
21
- tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
22
  load_dotenv()
23
  groq_api_key = os.getenv("GROQ_API_KEY")
24
-
25
  if not groq_api_key:
26
  st.error("Groq API Key not found in .env file")
27
  st.stop()
28
 
 
29
  st.set_page_config(page_title="Medical Bot", page_icon="👨‍🔬")
30
  st.title("Medical Bot")
 
 
31
  llm_text = ChatGroq(model="llama-3.3-70b-versatile", groq_api_key=groq_api_key)
32
  llm_image = ChatGroq(model="llama-3.2-90b-vision-preview", groq_api_key=groq_api_key)
33
 
34
- wikipedia_wrapper = WikipediaAPIWrapper()
35
- wikipedia_tool = Tool(
36
- name="Wikipedia",
37
- func=wikipedia_wrapper.run,
38
- description="A tool for searching the Internet to find various information on the topics mentioned."
39
- )
40
- math_chain = LLMMathChain.from_llm(llm=llm_text)
41
- calculator = Tool(
42
- name="Calculator",
43
- func=math_chain.run,
44
- description="A tool for solving mathematical problems. Provide only the mathematical expressions."
45
- )
46
-
47
- prompt = """
48
- You are a mathematical problem-solving assistant tasked with helping users solve their questions. Arrive at the solution logically, providing a clear and step-by-step explanation. Present your response in a structured point-wise format for better understanding.
49
- Question: {question}
50
- Answer:
51
- """
52
-
53
- prompt_template = PromptTemplate(
54
- input_variables=["question"],
55
- template=prompt
56
- )
57
- # Combine all the tools into a chain for text questions
58
- chain = LLMChain(llm=llm_text, prompt=prompt_template)
59
 
60
- reasoning_tool = Tool(
61
- name="Reasoning Tool",
62
- func=chain.run,
63
- description="A tool for answering logic-based and reasoning questions."
64
- )
65
  def classify_image(image_path: str) -> str:
66
  """Classifies a medical image using BiomedCLIP."""
67
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
68
  model.to(device).eval()
69
 
70
- # Open and preprocess image
71
- image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
72
  labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
73
  texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)
74
 
@@ -80,26 +50,43 @@ def classify_image(image_path: str) -> str:
80
  top_class = labels[sorted_indices[0][0].item()]
81
  return f"The image is classified as {top_class}."
82
 
83
- # Wrap BiomedCLIP as a LangChain tool
84
- biomed_clip_tool = Tool(
85
- name="BiomedCLIP Image Classifier",
86
- func=classify_image,
87
- description="Classifies medical images into categories like MRI, X-ray, histopathology, etc."
88
- )
89
- # Initialize the agents for text questions
 
 
 
 
 
 
 
 
 
90
  assistant_agent_text = initialize_agent(
91
- tools=[wikipedia_tool, calculator, reasoning_tool, biomed_clip_tool],
92
  llm=llm_text,
93
  agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
94
  verbose=False,
95
  handle_parsing_errors=True
96
  )
97
 
 
 
 
 
 
 
 
 
 
98
  if "messages" not in st.session_state:
99
- st.session_state["messages"] = [
100
- {"role": "assistant", "content": "Welcome! I am your Assistant. How can I help you today?"}
101
- ]
102
 
 
103
  for msg in st.session_state.messages:
104
  if msg["role"] == "user" and "image" in msg:
105
  st.chat_message(msg["role"]).write(msg['content'])
@@ -112,87 +99,48 @@ if st.sidebar.button("Text Question"):
112
  st.session_state["section"] = "text"
113
  if st.sidebar.button("Image Question"):
114
  st.session_state["section"] = "image"
115
-
116
  if "section" not in st.session_state:
117
  st.session_state["section"] = "text"
118
 
119
  def clean_response(response):
120
- if "```" in response:
121
- response = response.split("```")[1].strip()
122
- return response
123
 
124
  if st.session_state["section"] == "text":
125
  st.header("Text Question")
126
- st.write("Please enter your question below, and I will provide a detailed description of the problem and suggest a solution for it.")
127
  question = st.text_area("Your Question:")
128
  if st.button("Get Answer"):
129
  if question:
130
  with st.spinner("Generating response..."):
131
  st.session_state.messages.append({"role": "user", "content": question})
132
  st.chat_message("user").write(question)
133
-
134
- st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=False)
135
- try:
136
- response = assistant_agent_text.run(st.session_state.messages, callbacks=[st_cb])
137
- cleaned_response = clean_response(response)
138
- st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
139
- st.write('### Response:')
140
- st.success(cleaned_response)
141
- except ValueError as e:
142
- st.error(f"An error occurred: {e}")
143
  else:
144
- st.warning("Please enter a question to get an answer.")
145
 
146
  elif st.session_state["section"] == "image":
147
  st.header("Image Question")
148
- st.write("Please enter your question below and upload the medical image. I will provide a detailed description of the problem and suggest a solution for it.")
149
- question = st.text_area("Your Question:", "Example: What is the patient suffering from?")
150
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
151
-
152
  if st.button("Get Answer"):
153
- if question and uploaded_file is not None:
154
  with st.spinner("Generating response..."):
155
- image_data = uploaded_file.read()
156
- image_data_url = f"data:image/jpeg;base64,{base64.b64encode(image_data).decode()}"
157
- st.session_state.messages.append({"role": "user", "content": question, "image": image_data})
 
 
158
  st.chat_message("user").write(question)
159
- st.image(image_data, caption='Uploaded Image', use_column_width=True)
160
-
161
- client = Groq()
162
-
163
- messages = [
164
- {
165
- "role": "user",
166
- "content": [
167
- {
168
- "type": "text",
169
- "text": question
170
- },
171
- {
172
- "type": "image_url",
173
- "image_url": {
174
- "url": image_data_url
175
- }
176
- }
177
- ]
178
- }
179
- ]
180
- try:
181
- completion = client.chat.completions.create(
182
- model="llama-3.2-90b-vision-preview",
183
- messages=messages,
184
- temperature=1,
185
- max_tokens=1024,
186
- top_p=1,
187
- stream=False,
188
- stop=None,
189
- )
190
- response = completion.choices[0].message.content
191
- cleaned_response = clean_response(response)
192
- st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
193
- st.write('### Response:')
194
- st.success(cleaned_response)
195
- except ValueError as e:
196
- st.error(f"An error occurred: {e}")
197
  else:
198
- st.warning("Please enter a question and upload an image to get an answer.")
 
 
 
 
 
1
  import streamlit as st
2
  import os
3
  import base64
 
11
  from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
12
  from groq import Groq
13
  import open_clip
14
+ import torch
15
+ from PIL import Image
16
 
17
+ # Load environment variables
 
18
  load_dotenv()
19
  groq_api_key = os.getenv("GROQ_API_KEY")
 
20
  if not groq_api_key:
21
  st.error("Groq API Key not found in .env file")
22
  st.stop()
23
 
24
+ # Configure Streamlit
25
  st.set_page_config(page_title="Medical Bot", page_icon="👨‍🔬")
26
  st.title("Medical Bot")
27
+
28
+ # Initialize LLM models
29
  llm_text = ChatGroq(model="llama-3.3-70b-versatile", groq_api_key=groq_api_key)
30
  llm_image = ChatGroq(model="llama-3.2-90b-vision-preview", groq_api_key=groq_api_key)
31
 
32
+ # Load BiomedCLIP model for image classification
33
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
34
+ tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
 
 
36
  def classify_image(image_path: str) -> str:
37
  """Classifies a medical image using BiomedCLIP."""
38
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
39
  model.to(device).eval()
40
 
41
+ image = preprocess_val(Image.open(image_path)).unsqueeze(0).to(device)
 
42
  labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
43
  texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)
44
 
 
50
  top_class = labels[sorted_indices[0][0].item()]
51
  return f"The image is classified as {top_class}."
52
 
53
+ # Define tools
54
+ wikipedia_tool = Tool(name="Wikipedia", func=WikipediaAPIWrapper().run, description="A tool for searching information.")
55
+ math_chain = LLMMathChain.from_llm(llm=llm_text)
56
+ calculator = Tool(name="Calculator", func=math_chain.run, description="Solves mathematical problems.")
57
+
58
+ prompt_template = PromptTemplate(input_variables=["question"], template="""
59
+ You are a mathematical problem-solving assistant. Solve the question step by step.
60
+ Question: {question}
61
+ Answer:
62
+ """)
63
+ chain = LLMChain(llm=llm_text, prompt=prompt_template)
64
+ reasoning_tool = Tool(name="Reasoning Tool", func=chain.run, description="Answers logic-based questions.")
65
+
66
+ biomed_clip_tool = Tool(name="BiomedCLIP Image Classifier", func=classify_image, description="Classifies medical images.")
67
+
68
+ # Initialize agents
69
  assistant_agent_text = initialize_agent(
70
+ tools=[wikipedia_tool, calculator, reasoning_tool],
71
  llm=llm_text,
72
  agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
73
  verbose=False,
74
  handle_parsing_errors=True
75
  )
76
 
77
+ assistant_agent_image = initialize_agent(
78
+ tools=[biomed_clip_tool],
79
+ llm=llm_image,
80
+ agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
81
+ verbose=False,
82
+ handle_parsing_errors=True
83
+ )
84
+
85
+ # Streamlit session state for chat messages
86
  if "messages" not in st.session_state:
87
+ st.session_state["messages"] = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]
 
 
88
 
89
+ # Chat Interface
90
  for msg in st.session_state.messages:
91
  if msg["role"] == "user" and "image" in msg:
92
  st.chat_message(msg["role"]).write(msg['content'])
 
99
  st.session_state["section"] = "text"
100
  if st.sidebar.button("Image Question"):
101
  st.session_state["section"] = "image"
 
102
  if "section" not in st.session_state:
103
  st.session_state["section"] = "text"
104
 
105
  def clean_response(response):
106
+ return response.split("```")[-1].strip() if "```" in response else response
 
 
107
 
108
  if st.session_state["section"] == "text":
109
  st.header("Text Question")
 
110
  question = st.text_area("Your Question:")
111
  if st.button("Get Answer"):
112
  if question:
113
  with st.spinner("Generating response..."):
114
  st.session_state.messages.append({"role": "user", "content": question})
115
  st.chat_message("user").write(question)
116
+
117
+ response = assistant_agent_text.run(question)
118
+ cleaned_response = clean_response(response)
119
+ st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
120
+ st.write('### Response:')
121
+ st.success(cleaned_response)
 
 
 
 
122
  else:
123
+ st.warning("Please enter a question.")
124
 
125
  elif st.session_state["section"] == "image":
126
  st.header("Image Question")
127
+ question = st.text_area("Your Question:")
 
128
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
 
129
  if st.button("Get Answer"):
130
+ if question and uploaded_file:
131
  with st.spinner("Generating response..."):
132
+ image_path = f"temp_{uploaded_file.name}"
133
+ with open(image_path, "wb") as f:
134
+ f.write(uploaded_file.read())
135
+
136
+ st.session_state.messages.append({"role": "user", "content": question, "image": image_path})
137
  st.chat_message("user").write(question)
138
+ st.image(image_path, caption='Uploaded Image', use_column_width=True)
139
+
140
+ response = assistant_agent_image.run(image_path)
141
+ cleaned_response = clean_response(response)
142
+ st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
143
+ st.write('### Response:')
144
+ st.success(cleaned_response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  else:
146
+ st.warning("Please enter a question and upload an image.")