# from langgraph.graph import Graph
# from langchain_groq import ChatGroq
# llm = ChatGroq(model="llama3-70b-8192")
# llm.invoke("hi how are you")
import streamlit as st
import os
import base64
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.chains import LLMMathChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.utilities import WikipediaAPIWrapper
from langchain.agents.agent_types import AgentType
from langchain.agents import Tool, initialize_agent
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
from groq import Groq
import torch
from PIL import Image
import open_clip  # requires open-clip-torch>=2.23.0, timm>=0.9.8
# Load BiomedCLIP and its tokenizer once at startup; preprocess_val is the inference transform.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    st.error("Groq API Key not found in .env file")
    st.stop()
st.set_page_config(page_title="Medical Bot", page_icon="👨‍🔬")
st.title("Medical Bot")
llm_text = ChatGroq(model="gemma2-9b-it", groq_api_key=groq_api_key)
llm_image = ChatGroq(model="llama-3.2-90b-vision-preview", groq_api_key=groq_api_key)
wikipedia_wrapper = WikipediaAPIWrapper()
wikipedia_tool = Tool(
    name="Wikipedia",
    func=wikipedia_wrapper.run,
    description="A tool for searching the Internet to find various information on the topics mentioned."
)
math_chain = LLMMathChain.from_llm(llm=llm_text)
calculator = Tool(
    name="Calculator",
    func=math_chain.run,
    description="A tool for solving mathematical problems. Provide only the mathematical expressions."
)
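# The Calculator above wraps LLMMathChain, which has the LLM translate a word problem
# into a math expression and then evaluates it; the prompt below handles general
# step-by-step reasoning instead.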
prompt = """
You are a mathematical problem-solving assistant tasked with helping users solve their questions. Arrive at the solution logically, providing a clear and step-by-step explanation. Present your response in a structured point-wise format for better understanding.
Question: {question}
Answer:
"""
prompt_template = PromptTemplate(
    input_variables=["question"],
    template=prompt
)
# Combine all the tools into a chain for text questions
chain = LLMChain(llm=llm_text, prompt=prompt_template)
reasoning_tool = Tool(
    name="Reasoning Tool",
    func=chain.run,
    description="A tool for answering logic-based and reasoning questions."
)
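# Zero-shot classification with BiomedCLIP: embed the image and one text prompt per
# candidate label, softmax over the scaled image-text similarities, and return the
# best-matching label.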
def classify_image(image_path: str) -> str:
    """Classifies a medical image using BiomedCLIP."""
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device).eval()
    # Open and preprocess the image with the model's validation transform
    image = preprocess_val(Image.open(image_path)).unsqueeze(0).to(device)
    labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
    texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)
    with torch.no_grad():
        image_features, text_features, logit_scale = model(image, texts)
        logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
        sorted_indices = torch.argsort(logits, dim=-1, descending=True)
    top_class = labels[sorted_indices[0][0].item()]
    return f"The image is classified as {top_class}."
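# Example usage (hypothetical file path, for illustration only):
#   classify_image("chest_scan.png")  # -> "The image is classified as X-ray."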
# Wrap BiomedCLIP as a LangChain tool
biomed_clip_tool = Tool(
    name="BiomedCLIP Image Classifier",
    func=classify_image,
    description="Classifies medical images into categories like MRI, X-ray, histopathology, etc."
)
# Initialize the agent that routes text questions across the tools above
assistant_agent_text = initialize_agent(
    tools=[wikipedia_tool, calculator, reasoning_tool, biomed_clip_tool],
    llm=llm_text,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
    handle_parsing_errors=True
)
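# Chat history lives in st.session_state so it survives Streamlit reruns.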
if "messages" not in st.session_state:
st.session_state["messages"] = [
{"role": "assistant", "content": "Welcome! I am your Assistant. How can I help you today?"}
]
for msg in st.session_state.messages:
    if msg["role"] == "user" and "image" in msg:
        st.chat_message(msg["role"]).write(msg["content"])
        st.image(msg["image"], caption="Uploaded Image", use_column_width=True)
    else:
        st.chat_message(msg["role"]).write(msg["content"])
st.sidebar.header("Navigation")
if st.sidebar.button("Text Question"):
st.session_state["section"] = "text"
if st.sidebar.button("Image Question"):
st.session_state["section"] = "image"
if "section" not in st.session_state:
st.session_state["section"] = "text"
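# Strip a Markdown code fence from the model output, keeping only the fenced body.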
def clean_response(response):
    if "```" in response:
        response = response.split("```")[1].strip()
    return response
if st.session_state["section"] == "text":
st.header("Text Question")
st.write("Please enter your question below, and I will provide a detailed description of the problem and suggest a solution for it.")
question = st.text_area("Your Question:")
if st.button("Get Answer"):
if question:
with st.spinner("Generating response..."):
st.session_state.messages.append({"role": "user", "content": question})
st.chat_message("user").write(question)
st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=False)
try:
response = assistant_agent_text.run(st.session_state.messages, callbacks=[st_cb])
cleaned_response = clean_response(response)
st.session_state.messages.append({'role': 'assistant', "content": cleaned_response})
st.write('### Response:')
st.success(cleaned_response)
except ValueError as e:
st.error(f"An error occurred: {e}")
else:
st.warning("Please enter a question to get an answer.")
elif st.session_state["section"] == "image":
    st.header("Image Question")
    st.write("Please enter your question below and upload the medical image. I will provide a detailed description of the problem and suggest a solution for it.")
    question = st.text_area("Your Question:", "Example: What is the patient suffering from?")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if st.button("Get Answer"):
        if question and uploaded_file is not None:
            with st.spinner("Generating response..."):
                image_data = uploaded_file.read()
                # Use the uploaded file's actual MIME type (falling back to JPEG) when building the data URL
                mime_type = uploaded_file.type or "image/jpeg"
                image_data_url = f"data:{mime_type};base64,{base64.b64encode(image_data).decode()}"
                st.session_state.messages.append({"role": "user", "content": question, "image": image_data})
                st.chat_message("user").write(question)
                st.image(image_data, caption="Uploaded Image", use_column_width=True)
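                # Send one multimodal user message to the Groq vision model: a text part
                # (the question) plus an image_url part carrying the base64 data URL.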
                client = Groq()
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": question},
                            {"type": "image_url", "image_url": {"url": image_data_url}}
                        ]
                    }
                ]
                try:
                    completion = client.chat.completions.create(
                        model="llama-3.2-90b-vision-preview",
                        messages=messages,
                        temperature=1,
                        max_tokens=1024,
                        top_p=1,
                        stream=False,
                        stop=None,
                    )
                    response = completion.choices[0].message.content
                    cleaned_response = clean_response(response)
                    st.session_state.messages.append({"role": "assistant", "content": cleaned_response})
                    st.write("### Response:")
                    st.success(cleaned_response)
                except ValueError as e:
                    st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter a question and upload an image to get an answer.")