mMonika committed
Commit ebeeb35 · verified · 1 Parent(s): e52d88e

Update app.py

Files changed (1): app.py (+30, -3)
app.py CHANGED
@@ -14,7 +14,11 @@ from langchain.agents.agent_types import AgentType
 from langchain.agents import Tool, initialize_agent
 from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
 from groq import Groq
+import open_clip
+from open_clip import create_model_from_pretrained, get_tokenizer  # works on open-clip-torch>=2.23.0, timm>=0.9.8
 
+model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
+tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
 load_dotenv()
 groq_api_key = os.getenv("GROQ_API_KEY")
 
@@ -58,10 +62,33 @@ reasoning_tool = Tool(
     func=chain.run,
     description="A tool for answering logic-based and reasoning questions."
 )
-
+def classify_image(image_path: str) -> str:
+    """Classifies a medical image using BiomedCLIP zero-shot classification."""
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    model.to(device).eval()
+
+    # Open and preprocess the image with the eval transform (preprocess_val)
+    # returned by create_model_and_transforms above
+    image = preprocess_val(Image.open(image_path)).unsqueeze(0).to(device)
+    labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
+    texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)
+
+    with torch.no_grad():
+        image_features, text_features, logit_scale = model(image, texts)
+        logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
+        sorted_indices = torch.argsort(logits, dim=-1, descending=True)
+
+    top_class = labels[sorted_indices[0][0].item()]
+    return f"The image is classified as {top_class}."
+
+# Wrap BiomedCLIP as a LangChain tool
+biomed_clip_tool = Tool(
+    name="BiomedCLIP Image Classifier",
+    func=classify_image,
+    description="Classifies medical images into categories like MRI, X-ray, histopathology, etc."
+)
 # Initialize the agents for text questions
 assistant_agent_text = initialize_agent(
-    tools=[wikipedia_tool, calculator, reasoning_tool],
+    tools=[wikipedia_tool, calculator, reasoning_tool, biomed_clip_tool],
     llm=llm_text,
     agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
     verbose=False,
@@ -168,4 +195,4 @@ elif st.session_state["section"] == "image":
     except ValueError as e:
         st.error(f"An error occurred: {e}")
     else:
-        st.warning("Please enter a question and upload an image to get an answer.")
+        st.warning("Please enter a question and upload an image to get an answer.")
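For reference, the zero-shot classification step this commit adds can be exercised outside the Streamlit app. A minimal standalone sketch, assuming open-clip-torch>=2.23.0 and timm>=0.9.8 are installed (per the comment in the diff); scan.png is a hypothetical input file:

# Standalone zero-shot check mirroring the diff; "scan.png" is hypothetical.
import torch
from PIL import Image
import open_clip

MODEL_ID = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
model, _, preprocess_val = open_clip.create_model_and_transforms(MODEL_ID)
tokenizer = open_clip.get_tokenizer(MODEL_ID)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()

labels = ["MRI scan", "X-ray", "histopathology", "CT scan", "ultrasound", "medical chart"]
image = preprocess_val(Image.open("scan.png")).unsqueeze(0).to(device)
texts = tokenizer([f"this is a photo of {l}" for l in labels], context_length=256).to(device)

with torch.no_grad():
    image_features, text_features, logit_scale = model(image, texts)
    probs = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)

# Print per-label probabilities, highest first
values, indices = probs[0].sort(descending=True)
for p, i in zip(values.tolist(), indices.tolist()):
    print(f"{labels[i]}: {p:.3f}")

The 256 in the model id refers to the PubMedBERT text tower's token budget, which is why context_length=256 appears in both the diff and the sketch.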
 
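One wiring detail worth noting: a ZERO_SHOT_REACT_DESCRIPTION agent passes each tool a single string, so classify_image is only reachable if the agent is handed a file path it can forward. A hedged sketch of how an upload might be persisted and the tool called directly; uploaded_file (a st.file_uploader result) and the temp-file route are illustrative assumptions, not code from app.py:

# Persist a Streamlit upload so the path-based tool can read it from disk.
# `uploaded_file` and `biomed_clip_tool` are assumed to exist in app.py.
import tempfile

with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
    tmp.write(uploaded_file.getvalue())
    image_path = tmp.name

print(biomed_clip_tool.run(image_path))  # e.g. "The image is classified as X-ray."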