Commit ff0a602 (verified) by datawithsuman · Parent(s): 77320f3

Modularized app.py

Files changed (1): app.py (+155 -265)

app.py CHANGED
@@ -1,291 +1,181 @@
 
  import os
- import streamlit as st
- import streamlit.components.v1 as components
- import openai
- from llama_index.llms.openai import OpenAI
-
- import os
- from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext, PropertyGraphIndex
- from llama_index.core.indices.property_graph import (
-     ImplicitPathExtractor,
-     SimpleLLMPathExtractor,
- )
- from llama_index.retrievers.bm25 import BM25Retriever
- from llama_index.core.retrievers import BaseRetriever
- from llama_index.core.node_parser import SentenceSplitter
- from llama_index.embeddings.openai import OpenAIEmbedding
- # from llama_index.llms.mistralai import MistralAI
- from llmlingua import PromptCompressor
- from rouge_score import rouge_scorer
- from semantic_text_similarity.models import WebBertSimilarity
- import nest_asyncio
-
- # Apply nest_asyncio
- nest_asyncio.apply()
-
- # OpenAI credentials
- key = os.getenv('OPENAI_API_KEY')
- openai.api_key = key
- os.environ["OPENAI_API_KEY"] = key
-
- # key = os.getenv('MISTRAL_API_KEY')
- # os.environ["MISTRAL_API_KEY"] = key

- # Anthropic credentials
- # key = os.getenv('CLAUDE_API_KEY')
- # os.environ["ANTHROPIC_API_KEY"] = key

- # Streamlit UI
- st.title("Prompt Optimization for a Policy Bot")
-
- uploaded_files = st.file_uploader("Upload a Policy document in pdf format", type="pdf", accept_multiple_files=True)

- if uploaded_files:
-     for uploaded_file in uploaded_files:
-         with open(f"./data/{uploaded_file.name}", 'wb') as f:
              f.write(uploaded_file.getbuffer())
-         reader = SimpleDirectoryReader(input_files=[f"./data/{uploaded_file.name}"])
          documents = reader.load_data()
-         st.success("File uploaded...")
-
-         # # Indexing
-         # index = PropertyGraphIndex.from_documents(
-         #     documents,
-         #     embed_model=OpenAIEmbedding(model_name="text-embedding-3-small"),
-         #     kg_extractors=[
-         #         ImplicitPathExtractor(),
-         #         SimpleLLMPathExtractor(
-         #             llm=OpenAI(model="gpt-3.5-turbo", temperature=0.3),
-         #             num_workers=4,
-         #             max_paths_per_chunk=10,
-         #         ),
-         #     ],
-         #     show_progress=True,
-         # )
-
-         # # Save Knowlege Graph
-         # index.property_graph_store.save_networkx_graph(name="./data/kg.html")
-
-         # # Display the graph in Streamlit
-         # st.success("File Processed...")
-         # st.success("Creating Knowledge Graph...")
-         # HtmlFile = open("./data/kg.html", 'r', encoding='utf-8')
-         # source_code = HtmlFile.read()
-         # components.html(source_code, height= 500, width=700)
-
-         # # Retrieval
-         # kg_retriever = index.as_retriever(
-         #     include_text=True, # include source text, default True
-         # )

-
-         # Indexing
-         splitter = SentenceSplitter(chunk_size=256)
-         nodes = splitter.get_nodes_from_documents(documents)
          storage_context = StorageContext.from_defaults()
          storage_context.docstore.add_documents(nodes)
-         index = VectorStoreIndex(nodes=nodes, storage_context=storage_context)
-
-         # Retrieval
-         bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=1)
-         vector_retriever = index.as_retriever(similarity_top_k=1)
-
-         # Hybrid Retriever class
-         class HybridRetriever(BaseRetriever):
-             def __init__(self, vector_retriever, bm25_retriever):
-                 self.vector_retriever = vector_retriever
-                 self.bm25_retriever = bm25_retriever
-                 super().__init__()
-
-             # def _retrieve(self, query, **kwargs):
-             #     bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
-             #     vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
-             #     all_nodes = []
-             #     node_ids = set()
-             #     for n in bm25_nodes + vector_nodes:
-             #         if n.node.node_id not in node_ids:
-             #             all_nodes.append(n)
-             #             node_ids.add(n.node.node_id)
-             #     return all_nodes
-
-             def _retrieve(self, query, **kwargs):
-                 # bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
-                 vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
-                 all_nodes = []
-                 node_ids = set()
-                 for n in vector_nodes:
-                     if n.node.node_id not in node_ids:
-                         all_nodes.append(n)
-                         node_ids.add(n.node.node_id)
-                 return all_nodes

-         hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)
-
-         # Generation
-         model = "gpt-3.5-turbo"
-         # model = "claude-3-opus-20240229"
-
-         # def get_context(query):
-         #     contexts = kg_retriever.retrieve(query)
-         #     context_list = [n.text for n in contexts]
-         #     return context_list

-         def get_context(query):
-             contexts = hybrid_retriever.retrieve(query)
-             context_list = [n.get_content() for n in contexts]
-             return context_list

-
-         def res(prompt):
-
-             response = openai.chat.completions.create(
-                 model=model,
-                 messages=[
-                     {"role":"system",
-                      "content":"You are a helpful assistant who answers from the following context. If the answer can't be found in context, politely refuse"
-                     },
-                     {"role": "user",
-                      "content": prompt,
-                     }
-                 ]
-             )
-
-             return [response.usage.prompt_tokens, response.usage.completion_tokens, response.usage.total_tokens, response.choices[0].message.content]
-
-
-         # Summary
-         def summary(prompt, temp):
-
-             response = openai.chat.completions.create(
-                 model=model,
-                 temperature=temp,
-                 messages=[
-                     {"role":"system",
-                      "content":"Summarize the following context:"
-                     },
-                     {"role": "user",
-                      "content": prompt,
-                     }
-                 ]
-             )
-             return response.choices[0].message.content
-
-
-         full_prompt = documents[0].text
-         st.success("Input text")
-         st.markdown(full_prompt)
-
-         st.success("Reference summary")
-         gen_summ = summary(full_prompt, temp = 0.6)
-         st.markdown(gen_summ)
-         st.success("Generated summary")
-         ref_summ = summary(full_prompt, temp = 0.8)
-         st.markdown(ref_summ)
-
-
-         # Initialize session state for token summary, evaluation details, and chat messages
          if "token_summary" not in st.session_state:
              st.session_state.token_summary = []
          if "messages" not in st.session_state:
              st.session_state.messages = []
-
-         # Display chat messages from history on app rerun
          for message in st.session_state.messages:
              with st.chat_message(message["role"]):
                  st.markdown(message["content"])
-
-
-         # Accept user input
          if prompt := st.chat_input("Enter your query:"):
-             st.success("Fetching info...")
-             # Add user message to chat history
              st.session_state.messages.append({"role": "user", "content": prompt})
              with st.chat_message("user"):
                  st.markdown(prompt)
-
-             # Generate response
-             # st.success("Fetching info...")
-             context_list = get_context(prompt)
-             context = " ".join(context_list)
-             st.success("Getting context")
-             st.markdown(context)
-
-             # # Summarize
-             # full_prompt = "\n\n".join([context + prompt])
-             # orig_res = res(full_prompt)

-
-             # Original prompt response
-             full_prompt = "\n\n".join([context + prompt])
-             orig_res = res(full_prompt)
-             st.session_state.messages.append({"role": "assistant", "content": "Generating Original prompt response..."})
-             st.session_state.messages.append({"role": "assistant", "content": orig_res[3]})
-             st.success("Generating Original prompt response...")
-             with st.chat_message("assistant"):
-                 st.markdown(orig_res[3])
-
-             # # Compressed Response
-             # st.session_state.messages.append({"role": "assistant", "content": "Generating Optimized prompt response..."})
-             # st.success("Generating Optimized prompt response...")
-
-             # llm_lingua = PromptCompressor(
-             #     model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
-             #     use_llmlingua2=True, device_map="cpu"
-             # )
-
-             # def prompt_compression(context, rate=0.5):
-             #     compressed_context = llm_lingua.compress_prompt(
-             #         context,
-             #         rate=rate,
-             #         force_tokens=["!", ".", "?", "\n"],
-             #         drop_consecutive=True,
-             #     )
-             #     return compressed_context
-             # compressed_context = prompt_compression(context)
-             # full_opt_prompt = "\n\n".join([compressed_context['compressed_prompt'] + prompt])
-             # compressed_res = res(full_opt_prompt)
-             # st.session_state.messages.append({"role": "assistant", "content": compressed_res[3]})
-             # with st.chat_message("assistant"):
-             #     st.markdown(compressed_res[3])
-
-             # # Save token summary and evaluation details to session state
-             # scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
-             # scores = scorer.score(compressed_res[3],orig_res[3])
-             # webert_model = WebBertSimilarity(device='cpu')
-             # similarity_score = webert_model.predict([(compressed_res[3], orig_res[3])])[0] / 5 * 100
-
-
-             # # Display token summary
-             # st.session_state.messages.append({"role": "assistant", "content": "Token Length Summary..."})
-             # st.success('Token Length Summary...')
-             # st.session_state.messages.append({"role": "assistant", "content": f"Original Prompt has {orig_res[0]} tokens"})
-             # st.write(f"Original Prompt has {orig_res[0]} tokens")
-             # st.session_state.messages.append({"role": "assistant", "content": f"Optimized Prompt has {compressed_res[0]} tokens"})
-             # st.write(f"Optimized Prompt has {compressed_res[0]} tokens")
-
-             # st.session_state.messages.append({"role": "assistant", "content": "Comparing Original and Optimized Prompt Response..."})
-             # st.success("Comparing Original and Optimized Prompt Response...")
-             # st.session_state.messages.append({"role": "assistant", "content": f"Rouge Score : {scores['rougeL'].fmeasure * 100}"})
-             # st.write(f"Rouge Score : {scores['rougeL'].fmeasure * 100}")
-             # st.session_state.messages.append({"role": "assistant", "content": f"Semantic Text Similarity Score : {similarity_score}"})
-             # st.write(f"Semantic Text Similarity Score : {similarity_score}")
-
-             # st.write(" ")
-             # # origin_tokens = compressed_context['origin_tokens']
-             # # compressed_tokens = compressed_context['compressed_tokens']
-             # origin_tokens = orig_res[0]
-             # compressed_tokens = compressed_res[0]
-             # gpt_saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
-             # claude_saving = (origin_tokens - compressed_tokens) * 0.015 / 1000
-             # mistral_saving = (origin_tokens - compressed_tokens) * 0.004 / 1000
-             # # st.session_state.messages.append({"role": "assistant", "content": f"""The optimized prompt has saved ${gpt_saving:.4f} in GPT4, ${mistral_saving:.4f} in Mistral"""})
-             # # st.success(f"""The optimized prompt has saved ${gpt_saving:.4f} in GPT4, ${mistral_saving:.4f} in Mistral""")
-             # st.session_state.messages.append({"role": "assistant", "content": f"The optimized prompt has ${gpt_saving:.4f} saved in GPT-4."})
-             # st.success(f"The optimized prompt has ${gpt_saving:.4f} saved in GPT-4.")
-
-             # st.success("Downloading Optimized Prompt...")
-             # st.download_button(label = "Download Optimized Prompt",
-             #                    data = full_opt_prompt, file_name='./data/optimized_prompt.txt')
-
 
+ # config.py
  import os

+ class Config:
+     OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
+     MODEL_NAME = "gpt-3.5-turbo"
+     EMBEDDING_MODEL = "text-embedding-3-small"
+     CHUNK_SIZE = 256

+ # document_processor.py
+ from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
+ from llama_index.core.node_parser import SentenceSplitter
+ import streamlit as st

+ class DocumentProcessor:
+     def __init__(self):
+         self.splitter = SentenceSplitter(chunk_size=Config.CHUNK_SIZE)
+
+     def process_uploaded_file(self, uploaded_file):
+         file_path = f"./data/{uploaded_file.name}"
+         with open(file_path, 'wb') as f:
              f.write(uploaded_file.getbuffer())
+
+         reader = SimpleDirectoryReader(input_files=[file_path])
          documents = reader.load_data()
+         return documents
+
+     def create_index(self, documents):
+         nodes = self.splitter.get_nodes_from_documents(documents)
          storage_context = StorageContext.from_defaults()
          storage_context.docstore.add_documents(nodes)
+         return VectorStoreIndex(nodes=nodes, storage_context=storage_context), nodes

+ # retriever.py
+ from llama_index.retrievers.bm25 import BM25Retriever
+ from llama_index.core.retrievers import BaseRetriever

+ class HybridRetriever(BaseRetriever):
+     def __init__(self, vector_retriever, bm25_retriever):
+         self.vector_retriever = vector_retriever
+         self.bm25_retriever = bm25_retriever
+         super().__init__()
+
+     def _retrieve(self, query, **kwargs):
+         vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
+         all_nodes = []
+         node_ids = set()
+         for n in vector_nodes:
+             if n.node.node_id not in node_ids:
+                 all_nodes.append(n)
+                 node_ids.add(n.node.node_id)
+         return all_nodes
+
+ # llm_service.py
+ import openai

+ class LLMService:
+     def __init__(self, model_name):
+         self.model_name = model_name
+         openai.api_key = Config.OPENAI_API_KEY
+
+     def generate_response(self, prompt, system_message="You are a helpful assistant who answers from the following context. If the answer can't be found in context, politely refuse"):
+         response = openai.chat.completions.create(
+             model=self.model_name,
+             messages=[
+                 {"role": "system", "content": system_message},
+                 {"role": "user", "content": prompt}
+             ]
+         )
+         return {
+             'content': response.choices[0].message.content,
+             'usage': {
+                 'prompt_tokens': response.usage.prompt_tokens,
+                 'completion_tokens': response.usage.completion_tokens,
+                 'total_tokens': response.usage.total_tokens
+             }
+         }
+
+     def generate_summary(self, text, temperature=0.6):
+         response = openai.chat.completions.create(
+             model=self.model_name,
+             temperature=temperature,
+             messages=[
+                 {"role": "system", "content": "Summarize the following context:"},
+                 {"role": "user", "content": text}
+             ]
+         )
+         return response.choices[0].message.content
+
+ # app.py
+ import streamlit as st
+ from config import Config
+ from document_processor import DocumentProcessor
+ from retriever import HybridRetriever
+ from llm_service import LLMService
+
+ class PromptOptimizationApp:
+     def __init__(self):
+         self.doc_processor = DocumentProcessor()
+         self.llm_service = LLMService(Config.MODEL_NAME)
+         self.initialize_session_state()

+     def initialize_session_state(self):
          if "token_summary" not in st.session_state:
              st.session_state.token_summary = []
          if "messages" not in st.session_state:
              st.session_state.messages = []
+
+     def process_documents(self, uploaded_files):
+         for uploaded_file in uploaded_files:
+             documents = self.doc_processor.process_uploaded_file(uploaded_file)
+             index, nodes = self.doc_processor.create_index(documents)
+
+             bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=1)
+             vector_retriever = index.as_retriever(similarity_top_k=1)
+             hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)
+
+         return documents, hybrid_retriever
+
+     def display_summaries(self, text):
+         st.success("Reference summary")
+         ref_summary = self.llm_service.generate_summary(text, temperature=0.6)
+         st.markdown(ref_summary)
+
+         st.success("Generated summary")
+         gen_summary = self.llm_service.generate_summary(text, temperature=0.8)
+         st.markdown(gen_summary)
+
+     def handle_chat(self, prompt, hybrid_retriever):
+         st.success("Fetching info...")
+         context_list = [n.get_content() for n in hybrid_retriever.retrieve(prompt)]
+         context = " ".join(context_list)
+
+         st.success("Getting context")
+         st.markdown(context)
+
+         full_prompt = "\n\n".join([context + prompt])
+         response = self.llm_service.generate_response(full_prompt)
+
+         st.session_state.messages.append({"role": "assistant", "content": response['content']})
+         with st.chat_message("assistant"):
+             st.markdown(response['content'])
+
+         return response
+
+ def main():
+     st.title("Prompt Optimization for a Policy Bot")
+
+     app = PromptOptimizationApp()
+
+     uploaded_files = st.file_uploader(
+         "Upload a Policy document in pdf format",
+         type="pdf",
+         accept_multiple_files=True
+     )
+
+     if uploaded_files:
+         documents, hybrid_retriever = app.process_documents(uploaded_files)
+         st.success("File uploaded...")
+
+         full_text = documents[0].text
+         st.success("Input text")
+         st.markdown(full_text)
+
+         app.display_summaries(full_text)
+
+         # Display chat history
          for message in st.session_state.messages:
              with st.chat_message(message["role"]):
                  st.markdown(message["content"])
+
+         # Handle new chat input
          if prompt := st.chat_input("Enter your query:"):
              st.session_state.messages.append({"role": "user", "content": prompt})
              with st.chat_message("user"):
                  st.markdown(prompt)

+             app.handle_chat(prompt, hybrid_retriever)

+ if __name__ == "__main__":
+     main()
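
The added code is grouped into sections labeled # config.py, # document_processor.py, # retriever.py, # llm_service.py, and # app.py, but the commit still ships everything inside the single app.py file. Below is a minimal sketch of the import headers the sections would need if they were split into the separate files the comments (and the from config import Config style imports) assume; the per-file Config and BM25Retriever imports shown here are assumptions added for the split, not part of the commit, and the class bodies would be copied verbatim from the diff above.

# document_processor.py -- uses Config.CHUNK_SIZE, so it would also need Config (assumed import)
from config import Config
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
import streamlit as st

# llm_service.py -- uses Config.OPENAI_API_KEY, so it would also need Config (assumed import)
from config import Config
import openai

# app.py -- process_documents() builds a BM25Retriever directly, so this
# import would be needed here as well (assumed, not present in the diff)
from llama_index.retrievers.bm25 import BM25Retriever
from config import Config
from document_processor import DocumentProcessor
from retriever import HybridRetriever
from llm_service import LLMService

With that split in place the app would still start the usual way with streamlit run app.py, and, as in the original code, the ./data directory must already exist before a PDF is uploaded.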