Jason Caro committed · Commit 4263541 · Parent(s): 0cf2500
Update app.py
app.py
CHANGED
@@ -12,13 +12,12 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
 )
 
+# Initialize text splitter and other settings
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-
 system_template = """
 Use the following pieces of context to answer the user's question. If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."
 ----------------
 {context}"""
-
 messages = [
     SystemMessagePromptTemplate.from_template(system_template),
     HumanMessagePromptTemplate.from_template("{question}"),
@@ -28,31 +27,25 @@ chain_type_kwargs = {"prompt": prompt}
 
 @cl.author_rename
 def rename(orig_author: str):
-    rename_dict = {"RetrievalQA": "
+    rename_dict = {"RetrievalQA": "PageTurn"}
     return rename_dict.get(orig_author, orig_author)
 
+# Initialize the index and other setup
 @cl.on_chat_start
 async def init():
     msg = cl.Message(content=f"Building Index...")
     await msg.send()
 
-    # Read text from a .txt file
     with open('./data/aerodynamic_drag.txt', 'r', encoding='Windows-1252') as f:
         aerodynamic_drag_data = f.read()
 
-    # Split the text into smaller chunks
     documents = text_splitter.create_documents([aerodynamic_drag_data])
-
-    # Create a local file store for caching
     store = LocalFileStore("./cache/")
     core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
         core_embeddings_model, store, namespace=core_embeddings_model.model
     )
-
-    # Make async docsearch
     docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
-
     chain = RetrievalQA.from_chain_type(
         ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
         chain_type="stuff",
@@ -63,9 +56,9 @@ async def init():
 
     msg.content = f"Index built!"
     await msg.send()
-
     cl.user_session.set("chain", chain)
 
+# Main function to handle incoming queries
 @cl.on_message
 async def main(message):
     chain = cl.user_session.get("chain")
@@ -74,28 +67,5 @@ async def main(message):
     )
     cb.answer_reached = True
     res = await chain.acall(message, callbacks=[cb], )
-
     answer = res["result"]
-
-    visited_sources = set()
-
-    # Get the documents from the user session
-    docs = res["source_documents"]
-    metadatas = [doc.metadata for doc in docs]
-    all_sources = [m["source"] for m in metadatas]
-
-    for source in all_sources:
-        if source in visited_sources:
-            continue
-        visited_sources.add(source)
-        # Create the text element referenced in the message
-        source_elements.append(
-            cl.Text(content="https://www.imdb.com" + source, name="Review URL")
-        )
-
-    if source_elements:
-        answer += f"\nSources: {', '.join([e.content.decode('utf-8') for e in source_elements])}"
-    else:
-        answer += "\nNo sources found"
-
-    await cl.Message(content=answer, elements=source_elements).send()
+    await cl.Message(content=answer).send()
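Note: the embedder retained above is what lets this Space rebuild quickly. CacheBackedEmbeddings wraps OpenAIEmbeddings in a byte-store cache, so unchanged chunks are never re-embedded across restarts. A minimal, self-contained sketch of that pattern, using the same langchain APIs app.py imports (the sample text is hypothetical):

from langchain.embeddings import CacheBackedEmbeddings, OpenAIEmbeddings
from langchain.storage import LocalFileStore

store = LocalFileStore("./cache/")      # embedding vectors persisted as bytes on disk
core = OpenAIEmbeddings()
embedder = CacheBackedEmbeddings.from_bytes_store(
    core, store, namespace=core.model   # key the cache by embedding model name
)

# The first call computes vectors via the OpenAI API and writes them under
# ./cache/; calling it again with identical text is served from disk.
vectors = embedder.embed_documents(["Aerodynamic drag grows with the square of speed."])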
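Note: the hunk context cuts the RetrievalQA.from_chain_type call off after chain_type="stuff". A hedged reconstruction of how the remaining arguments likely plug in; retriever and return_source_documents are assumptions inferred from the unchanged chain_type_kwargs = {"prompt": prompt} context line and the res["source_documents"] access in the removed handler code, not confirmed by this diff:

from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI

# docsearch and chain_type_kwargs are defined earlier in app.py.
chain = RetrievalQA.from_chain_type(
    ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
    chain_type="stuff",                      # stuff all retrieved chunks into one prompt
    retriever=docsearch.as_retriever(),      # assumption: FAISS index exposed as retriever
    chain_type_kwargs=chain_type_kwargs,     # {"prompt": prompt} per the hunk header
    return_source_documents=True,            # assumption: old code read res["source_documents"]
)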