Craig Pretzinger committed
Commit b1f5115 · 2 Parent(s): 4d2f914 3d9cfb9

Forgot to commit

Files changed (1)
app.py +81 -38
app.py CHANGED
@@ -1,6 +1,6 @@
import gradio as gr
from huggingface_hub import InferenceClient
- from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
+ from transformers import BertTokenizer, BertForSequenceClassification
import openai
import os
import faiss
@@ -8,13 +8,11 @@ import numpy as np
import requests
from datasets import load_dataset

- ds = load_dataset("epfl-llm/guidelines")
+ # Load OpenAI API key and organization ID from environment variables
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+ openai.Organization = os.getenv("OPENAI_ORG_ID")

- # Load OpenAI and Serper API keys from Hugging Face secrets
- openai.api_key = os.getenv("OPENAI_API_KEY")  # Ensure the OpenAI API key is pulled correctly
- serper_api_key = os.getenv("SERPER_API_KEY")  # Ensure the Serper API key is pulled correctly
-
- # Load PubMedBERT tokenizer and model for FDA-related processing
+ # Load PubMedBERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2)

@@ -22,24 +20,30 @@ model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMe
dimension = 768  # PubMedBERT embedding size
index = faiss.IndexFlatL2(dimension)

+ # Embed text using PubMedBERT
def embed_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
-     outputs = model(**inputs, output_hidden_states=True)  # Ensure hidden states are returned
-     hidden_state = outputs.hidden_states[-1]  # Get the last hidden state
-     return hidden_state.mean(dim=1).detach().numpy()  # Take the mean across the sequence
+     outputs = model(**inputs, output_hidden_states=True)
+     hidden_state = outputs.hidden_states[-1]
+     return hidden_state.mean(dim=1).detach().numpy()

- # Example: Embed past conversation and store in FAISS
+ # Add past conversation embedding to FAISS index
past_conversation = "FDA approval for companion diagnostics requires careful documentation."
past_embedding = embed_text(past_conversation)
+ past_embedding = np.array(past_embedding)  # Convert to numpy array
+
+ # Reshape if necessary (e.g., (1, 768) for PubMedBERT)
+ past_embedding = past_embedding.reshape(1, -1)
+
index.add(past_embedding)

- # Embed the incoming query and search for related memory
+ # Search past conversations/memory using FAISS
def search_memory(query):
    query_embedding = embed_text(query)
-     D, I = index.search(query_embedding, k=1)  # Retrieve most similar past conversation
+     D, I = index.search(query_embedding, k=1)
    return I

- # Function to handle FDA-specific queries with PubMedBERT
+ # Handle FDA-specific queries with PubMedBERT
def handle_fda_query(query):
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True)
    outputs = model(**inputs)
@@ -47,20 +51,23 @@ def handle_fda_query(query):
    response = "Processed FDA-related query via PubMedBERT"
    return response

- # Function to handle general queries using GPT-4o
+ # Handle general queries using GPT-4O
def handle_openai_query(prompt):
-     response = openai.Completion.create(
-         engine="gpt-4o",  # Using GPT-4o as per instruction
-         prompt=prompt,
+     response = openai.Chat.create(
+         model="gpt-4-0314-16k-512",
+         messages=[
+             {"role": "user", "content": prompt}
+         ],
+         temperature=0.7,
        max_tokens=100
    )
-     return response.choices[0].text.strip()
+     return response.choices[0].message.content

# Web search with Serper API
def web_search(query):
    url = f"https://google.serper.dev/search"
    headers = {
-         "X-API-KEY": serper_api_key
+         "X-API-KEY": os.getenv("SERPER_API_KEY")
    }
    params = {
        "q": query
@@ -68,7 +75,45 @@ def web_search(query):
    response = requests.get(url, headers=headers, params=params)
    return response.json()

- # Main assistant function that delegates to either OpenAI, PubMedBERT, or Serper (web search)
+ # Contextual Short-Term Memory (CSTM)
+ cstm = []
+
+ # Long-Term Memory (LTM)
+ ltm = []  # Load knowledge base articles or FAQs
+
+ # Semantic search function
+ def semantic_search(query, cstm, ltm):
+     # Generate embeddings for query and CSTM/LTM
+     query_embedding = embed_text(query)
+     cstm_embeddings = [embed_text(text) for text in cstm]
+     ltm_embeddings = [embed_text(text) for text in ltm]
+
+     # Calculate similarity scores
+     cstm_scores = calculate_similarity(query_embedding, cstm_embeddings)
+     ltm_scores = calculate_similarity(query_embedding, ltm_embeddings)
+
+     # Retrieve top relevant results from CSTM and LTM
+     top_cstm = np.argmax(cstm_scores)
+     top_ltm = np.argmax(ltm_scores)
+
+     return top_cstm, top_ltm
+
+ # Calculate similarity between embeddings
+ def calculate_similarity(query_embedding, embeddings):
+     similarity_scores = []
+     for embedding in embeddings:
+         score = cosine_similarity(query_embedding, embedding)
+         similarity_scores.append(score)
+     return similarity_scores
+
+ # Cosine similarity function
+ def cosine_similarity(a, b):
+     dot_product = np.dot(a, b)
+     magnitude_a = np.linalg.norm(a)
+     magnitude_b = np.linalg.norm(b)
+     return dot_product / (magnitude_a * magnitude_b)
+
+ # Main assistant function
def respond(
    message,
    history: list[tuple[str, str]],
@@ -77,7 +122,7 @@ def respond(
    temperature,
    top_p,
):
-     # Prepare the context for OpenAI and PubMedBERT
+     # Prepare context for OpenAI and PubMedBERT
    messages = [{"role": "system", "content": system_message}]

    for val in history:
@@ -88,35 +133,32 @@ def respond(

    messages.append({"role": "user", "content": message})

-     # Check if the query is related to FDA
+     # Check if query is FDA-related
    openai_response = handle_openai_query(f"Is this query FDA-related: {message}")

    if "FDA" in openai_response or "regulatory" in openai_response:
        # Search past conversations/memory using FAISS
        memory_index = search_memory(message)
        if memory_index:
-             return f"Found relevant past memory: {past_conversation}"  # Return past context from memory
+             return f"Found relevant past memory: {past_conversation}"

        # If no memory match, proceed with PubMedBERT
        return handle_fda_query(message)

-     # If query asks for a web search, perform web search
+     # If query asks for web search, perform web search
    if "search the web" in message.lower():
        return web_search(message)

-     # General conversational handling with GPT-4o
-     response = ""
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
+     # Perform semantic search on CSTM and LTM
+     top_cstm, top_ltm = semantic_search(message, cstm, ltm)
+     if top_cstm:
+         return f"Found relevant context: {cstm[top_cstm]}"
+     elif top_ltm:
+         return f"Found relevant knowledge: {ltm[top_ltm]}"

-         response += token
-         yield response
+     # General conversational handling with GPT-4O
+     response = handle_openai_query(message)
+     return response


# Create Gradio ChatInterface for interaction
@@ -130,5 +172,6 @@ demo = gr.ChatInterface(
    ],
)

+
if __name__ == "__main__":
- demo.launch(share=True)
+ demo.launch()
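
A few notes on the merged app.py. First, search_memory still returns the raw FAISS label array I, and respond gates on `if memory_index:`. For a single k=1 result, numpy truthiness reduces to the label value itself, which is 0 exactly when the nearest neighbour is the one conversation actually stored, so the memory branch never fires for the only vector in the index. A minimal sketch of a distance-gated variant, reusing the file's embed_text and index (the 50.0 cutoff is an illustrative assumption, not a tuned value):

def search_memory(query, max_distance=50.0):
    if index.ntotal == 0:
        return None  # nothing stored yet
    query_embedding = embed_text(query).reshape(1, -1)  # FAISS expects shape (n, d)
    D, I = index.search(query_embedding, k=1)
    if D[0][0] > max_distance:
        return None  # nearest memory is too far away to count as relevant
    return int(I[0][0])

The caller would then test `if memory_index is not None:`, which also distinguishes a hit at label 0 from a miss.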
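Second, `openai.Chat.create` is not an attribute of the openai Python SDK, and "gpt-4-0314-16k-512" is not a published OpenAI model name, so handle_openai_query will raise on first use; `openai.Organization` likewise sets an attribute the SDK never reads (the recognised one is lowercase `organization`). A sketch of what the function could look like against the pre-1.0 SDK implied by the module-level `openai.api_key` assignment, with "gpt-4" as an assumed stand-in model:

import os
import openai

openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_ORG_ID")  # lowercase attribute in the SDK

def handle_openai_query(prompt):
    # openai.ChatCompletion is the chat endpoint in openai<1.0
    response = openai.ChatCompletion.create(
        model="gpt-4",  # assumed stand-in; any accessible chat model works
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=100,
    )
    return response.choices[0].message.content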
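Third, Serper's /search endpoint is documented as a POST with a JSON body rather than a GET with query params, so the unchanged web_search is likely to be rejected. A hedged sketch of the documented shape:

import os
import requests

def web_search(query):
    url = "https://google.serper.dev/search"
    headers = {
        "X-API-KEY": os.getenv("SERPER_API_KEY"),
        "Content-Type": "application/json",
    }
    # POST with a JSON payload, per Serper's API reference
    response = requests.post(url, headers=headers, json={"q": query})
    response.raise_for_status()
    return response.json()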
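Finally, embed_text returns a (1, 768) array, so the new cosine_similarity calls np.dot on two 2-D arrays and fails, and np.argmax raises ValueError when it is reached with the empty cstm/ltm lists the file initialises. A sketch that flattens the vectors and guards the empty case, again reusing the file's embed_text:

import numpy as np

def cosine_similarity(a, b):
    a, b = a.ravel(), b.ravel()  # flatten (1, 768) embeddings to 1-D
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def semantic_search(query, cstm, ltm):
    query_embedding = embed_text(query)
    cstm_scores = [cosine_similarity(query_embedding, embed_text(t)) for t in cstm]
    ltm_scores = [cosine_similarity(query_embedding, embed_text(t)) for t in ltm]
    # np.argmax([]) raises, so return None when a memory store is empty
    top_cstm = int(np.argmax(cstm_scores)) if cstm_scores else None
    top_ltm = int(np.argmax(ltm_scores)) if ltm_scores else None
    return top_cstm, top_ltm

respond would then check `if top_cstm is not None:` and `elif top_ltm is not None:`, since the committed `if top_cstm:` treats a valid best match at position 0 as a miss.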