mMonika committed
Commit 6dc96e0 · verified · 1 Parent(s): 7f1ecff

Update app.py

Files changed (1)
  1. app.py +114 -62
app.py CHANGED
@@ -1,3 +1,7 @@
+# from langgraph.graph import Graph
+# from langchain_groq import ChatGroq
+# llm = langchain_groq(model="llama3-70b-8192")
+# llm.invoke("hi how are you")
 import streamlit as st
 import os
 import base64
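The four commented-out lines added at the top are scratch code: `langchain_groq` is a module, so calling it directly would raise a `TypeError`. For reference, a working version of that smoke test would go through `ChatGroq` (a minimal sketch, assuming `GROQ_API_KEY` is set in the environment):

```python
# Sketch of the smoke test the commented-out lines seem to intend.
# Assumes GROQ_API_KEY is available in the environment.
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama3-70b-8192")
print(llm.invoke("hi how are you").content)  # .invoke returns an AIMessage
```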
@@ -11,34 +15,60 @@ from langchain.agents import Tool, initialize_agent
 from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
 from groq import Groq
 import open_clip
-import torch
-from PIL import Image
+from open_clip import create_model_from_pretrained, get_tokenizer  # works on open-clip-torch>=2.23.0, timm>=0.9.8
 
-# Load environment variables
+model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
+tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
 load_dotenv()
 groq_api_key = os.getenv("GROQ_API_KEY")
+
 if not groq_api_key:
     st.error("Groq API Key not found in .env file")
     st.stop()
 
-# Configure Streamlit
 st.set_page_config(page_title="Medical Bot", page_icon="👨‍🔬")
 st.title("Medical Bot")
-
-# Initialize LLM models
-llm_text = ChatGroq(model="llama-3.3-70b-versatile", groq_api_key=groq_api_key)
+llm_text = ChatGroq(model="gemma2-9b-it", groq_api_key=groq_api_key)
 llm_image = ChatGroq(model="llama-3.2-90b-vision-preview", groq_api_key=groq_api_key)
 
-# Load BiomedCLIP model for image classification
-model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
-tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
+wikipedia_wrapper = WikipediaAPIWrapper()
+wikipedia_tool = Tool(
+    name="Wikipedia",
+    func=wikipedia_wrapper.run,
+    description="A tool for searching the Internet to find various information on the topics mentioned."
+)
+math_chain = LLMMathChain.from_llm(llm=llm_text)
+calculator = Tool(
+    name="Calculator",
+    func=math_chain.run,
+    description="A tool for solving mathematical problems. Provide only the mathematical expressions."
+)
 
+prompt = """
+You are a mathematical problem-solving assistant tasked with helping users solve their questions. Arrive at the solution logically, providing a clear and step-by-step explanation. Present your response in a structured point-wise format for better understanding.
+Question: {question}
+Answer:
+"""
+
+prompt_template = PromptTemplate(
+    input_variables=["question"],
+    template=prompt
+)
+# Combine all the tools into a chain for text questions
+chain = LLMChain(llm=llm_text, prompt=prompt_template)
+
+reasoning_tool = Tool(
+    name="Reasoning Tool",
+    func=chain.run,
+    description="A tool for answering logic-based and reasoning questions."
+)
 def classify_image(image_path: str) -> str:
     """Classifies a medical image using BiomedCLIP."""
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
     model.to(device).eval()
 
-    image = preprocess_val(Image.open(image_path)).unsqueeze(0).to(device)
+    # Open and preprocess image
+    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
     labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
     texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)
 
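Two things stand out in this hunk: the new revision drops `import torch` and `from PIL import Image` even though `classify_image` still uses both names, and the rewritten preprocessing line calls `preprocess`, which this diff never defines (the loader above binds `preprocess_train` and `preprocess_val`). The hunk also elides the scoring step between `texts = tokenizer(...)` and the `top_class = ...` line shown in the next hunk; a minimal sketch of that zero-shot scoring, assuming the BiomedCLIP `model`, `image`, `texts`, and `labels` from the surrounding lines:

```python
import torch

# Sketch of the elided similarity step: encode image and label prompts,
# scale by the model's learned temperature, softmax, then rank labels.
with torch.no_grad():
    image_features, text_features, logit_scale = model(image, texts)
    logits = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)
    sorted_indices = torch.argsort(logits, dim=-1, descending=True)
# sorted_indices[0][0].item() then picks the top label, matching the
# visible `top_class = labels[...]` line below.
```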
 
@@ -50,43 +80,26 @@ def classify_image(image_path: str) -> str:
     top_class = labels[sorted_indices[0][0].item()]
     return f"The image is classified as {top_class}."
 
-# Define tools
-wikipedia_tool = Tool(name="Wikipedia", func=WikipediaAPIWrapper().run, description="A tool for searching information.")
-math_chain = LLMMathChain.from_llm(llm=llm_text)
-calculator = Tool(name="Calculator", func=math_chain.run, description="Solves mathematical problems.")
-
-prompt_template = PromptTemplate(input_variables=["question"], template="""
-You are a mathematical problem-solving assistant. Solve the question step by step.
-Question: {question}
-Answer:
-""")
-chain = LLMChain(llm=llm_text, prompt=prompt_template)
-reasoning_tool = Tool(name="Reasoning Tool", func=chain.run, description="Answers logic-based questions.")
-
-biomed_clip_tool = Tool(name="BiomedCLIP Image Classifier", func=classify_image, description="Classifies medical images.")
-
-# Initialize agents
+# Wrap BiomedCLIP as a LangChain tool
+biomed_clip_tool = Tool(
+    name="BiomedCLIP Image Classifier",
+    func=classify_image,
+    description="Classifies medical images into categories like MRI, X-ray, histopathology, etc."
+)
+# Initialize the agents for text questions
 assistant_agent_text = initialize_agent(
-    tools=[wikipedia_tool, calculator, reasoning_tool],
+    tools=[wikipedia_tool, calculator, reasoning_tool, biomed_clip_tool],
     llm=llm_text,
     agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
     verbose=False,
     handle_parsing_errors=True
 )
 
-assistant_agent_image = initialize_agent(
-    tools=[biomed_clip_tool],
-    llm=llm_image,
-    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
-    verbose=False,
-    handle_parsing_errors=True
-)
-
-# Streamlit session state for chat messages
 if "messages" not in st.session_state:
-    st.session_state["messages"] = [{"role": "assistant", "content": "Welcome! How can I help you today?"}]
+    st.session_state["messages"] = [
+        {"role": "assistant", "content": "Welcome! I am your Assistant. How can I help you today?"}
+    ]
 
-# Chat Interface
 for msg in st.session_state.messages:
     if msg["role"] == "user" and "image" in msg:
         st.chat_message(msg["role"]).write(msg['content'])
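This hunk removes the separate image agent: the BiomedCLIP classifier becomes a fourth tool on the single text agent, and `llm_image` is left defined but unused (the image section below calls the Groq client directly instead). A quick smoke test of the merged agent might look like this (the file path is hypothetical):

```python
# Hypothetical local path; the BiomedCLIP tool expects an image file on disk.
print(assistant_agent_text.run("Classify the medical image at ./temp_scan.png"))
```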
@@ -99,48 +112,87 @@ if st.sidebar.button("Text Question"):
     st.session_state["section"] = "text"
 if st.sidebar.button("Image Question"):
     st.session_state["section"] = "image"
+
 if "section" not in st.session_state:
     st.session_state["section"] = "text"
 
 def clean_response(response):
-    return response.split("```")[-1].strip() if "```" in response else response
+    if "```" in response:
+        response = response.split("```")[1].strip()
+    return response
 
 if st.session_state["section"] == "text":
     st.header("Text Question")
+    st.write("Please enter your question below, and I will provide a detailed description of the problem and suggest a solution for it.")
     question = st.text_area("Your Question:")
     if st.button("Get Answer"):
         if question:
             with st.spinner("Generating response..."):
                 st.session_state.messages.append({"role": "user", "content": question})
                 st.chat_message("user").write(question)
-
-                response = assistant_agent_text.run(question)
-                cleaned_response = clean_response(response)
-                st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
-                st.write('### Response:')
-                st.success(cleaned_response)
+
+                st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=False)
+                try:
+                    response = assistant_agent_text.run(st.session_state.messages, callbacks=[st_cb])
+                    cleaned_response = clean_response(response)
+                    st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
+                    st.write('### Response:')
+                    st.success(cleaned_response)
+                except ValueError as e:
+                    st.error(f"An error occurred: {e}")
         else:
-            st.warning("Please enter a question.")
+            st.warning("Please enter a question to get an answer.")
 
 elif st.session_state["section"] == "image":
     st.header("Image Question")
-    question = st.text_area("Your Question:")
+    st.write("Please enter your question below and upload the medical image. I will provide a detailed description of the problem and suggest a solution for it.")
+    question = st.text_area("Your Question:", "Example: What is the patient suffering from?")
     uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+
     if st.button("Get Answer"):
-        if question and uploaded_file:
+        if question and uploaded_file is not None:
             with st.spinner("Generating response..."):
-                image_path = f"temp_{uploaded_file.name}"
-                with open(image_path, "wb") as f:
-                    f.write(uploaded_file.read())
-
-                st.session_state.messages.append({"role": "user", "content": question, "image": image_path})
+                image_data = uploaded_file.read()
+                image_data_url = f"data:image/jpeg;base64,{base64.b64encode(image_data).decode()}"
+                st.session_state.messages.append({"role": "user", "content": question, "image": image_data})
                 st.chat_message("user").write(question)
-                st.image(image_path, caption='Uploaded Image', use_column_width=True)
-
-                response = assistant_agent_image.run(image_path)
-                cleaned_response = clean_response(response)
-                st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
-                st.write('### Response:')
-                st.success(cleaned_response)
+                st.image(image_data, caption='Uploaded Image', use_column_width=True)
+
+                client = Groq()
+
+                messages = [
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": question
+                            },
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": image_data_url
+                                }
+                            }
+                        ]
+                    }
+                ]
+                try:
+                    completion = client.chat.completions.create(
+                        model="llama-3.2-90b-vision-preview",
+                        messages=messages,
+                        temperature=1,
+                        max_tokens=1024,
+                        top_p=1,
+                        stream=False,
+                        stop=None,
+                    )
+                    response = completion.choices[0].message.content
+                    cleaned_response = clean_response(response)
+                    st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
+                    st.write('### Response:')
+                    st.success(cleaned_response)
+                except ValueError as e:
+                    st.error(f"An error occurred: {e}")
         else:
-            st.warning("Please enter a question and upload an image.")
+            st.warning("Please enter a question and upload an image to get an answer.")
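One wrinkle in the new upload path: the data URL is hard-coded to `image/jpeg` even though the uploader also accepts PNG. A safer variant would derive the MIME type from the upload (a sketch; `to_data_url` is not part of this commit):

```python
import base64

def to_data_url(image_bytes: bytes, mime: str = "image/jpeg") -> str:
    """Build the data URL the Groq vision endpoint expects."""
    return f"data:{mime};base64,{base64.b64encode(image_bytes).decode()}"

# Streamlit's UploadedFile exposes the browser-reported MIME type:
# image_data_url = to_data_url(uploaded_file.read(), uploaded_file.type)
```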
 
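Also note the behavior change in `clean_response`: the old one-liner kept everything after the last triple-backtick fence, while the new version keeps the text between the first pair of fences. On a typical fenced reply the two disagree:

```python
sample = "Thought text\n```\nfinal answer\n```\ntrailing"
print(sample.split("```")[-1].strip())  # old behavior: 'trailing'
print(sample.split("```")[1].strip())   # new behavior: 'final answer'
```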