Jason Caro committed on
Commit
4263541
·
1 Parent(s): 0cf2500

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -35
app.py CHANGED
@@ -12,13 +12,12 @@ from langchain.prompts.chat import (
12
  HumanMessagePromptTemplate,
13
  )
14
 
 
15
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
16
-
17
  system_template = """
18
  Use the following pieces of context to answer the user's question. If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."
19
  ----------------
20
  {context}"""
21
-
22
  messages = [
23
  SystemMessagePromptTemplate.from_template(system_template),
24
  HumanMessagePromptTemplate.from_template("{question}"),
@@ -28,31 +27,25 @@ chain_type_kwargs = {"prompt": prompt}
28
 
29
  @cl.author_rename
30
  def rename(orig_author: str):
31
- rename_dict = {"RetrievalQA": "Consulting PageTurn"}
32
  return rename_dict.get(orig_author, orig_author)
33
 
 
34
  @cl.on_chat_start
35
  async def init():
36
  msg = cl.Message(content=f"Building Index...")
37
  await msg.send()
38
 
39
- # Read text from a .txt file
40
  with open('./data/aerodynamic_drag.txt', 'r', encoding='Windows-1252') as f:
41
  aerodynamic_drag_data = f.read()
42
 
43
- # Split the text into smaller chunks
44
  documents = text_splitter.create_documents([aerodynamic_drag_data])
45
-
46
- # Create a local file store for caching
47
  store = LocalFileStore("./cache/")
48
  core_embeddings_model = OpenAIEmbeddings()
49
  embedder = CacheBackedEmbeddings.from_bytes_store(
50
  core_embeddings_model, store, namespace=core_embeddings_model.model
51
  )
52
-
53
- # Make async docsearch
54
  docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
55
-
56
  chain = RetrievalQA.from_chain_type(
57
  ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
58
  chain_type="stuff",
@@ -63,9 +56,9 @@ async def init():
63
 
64
  msg.content = f"Index built!"
65
  await msg.send()
66
-
67
  cl.user_session.set("chain", chain)
68
 
 
69
  @cl.on_message
70
  async def main(message):
71
  chain = cl.user_session.get("chain")
@@ -74,28 +67,5 @@ async def main(message):
74
  )
75
  cb.answer_reached = True
76
  res = await chain.acall(message, callbacks=[cb], )
77
-
78
  answer = res["result"]
79
- source_elements = []
80
- visited_sources = set()
81
-
82
- # Get the documents from the user session
83
- docs = res["source_documents"]
84
- metadatas = [doc.metadata for doc in docs]
85
- all_sources = [m["source"] for m in metadatas]
86
-
87
- for source in all_sources:
88
- if source in visited_sources:
89
- continue
90
- visited_sources.add(source)
91
- # Create the text element referenced in the message
92
- source_elements.append(
93
- cl.Text(content="https://www.imdb.com" + source, name="Review URL")
94
- )
95
-
96
- if source_elements:
97
- answer += f"\nSources: {', '.join([e.content.decode('utf-8') for e in source_elements])}"
98
- else:
99
- answer += "\nNo sources found"
100
-
101
- await cl.Message(content=answer, elements=source_elements).send()
 
12
  HumanMessagePromptTemplate,
13
  )
14
 
15
+ # Initialize text splitter and other settings
16
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
 
17
  system_template = """
18
  Use the following pieces of context to answer the user's question. If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."
19
  ----------------
20
  {context}"""
 
21
  messages = [
22
  SystemMessagePromptTemplate.from_template(system_template),
23
  HumanMessagePromptTemplate.from_template("{question}"),
 
27
 
28
  @cl.author_rename
29
  def rename(orig_author: str):
30
+ rename_dict = {"RetrievalQA": "PageTurn"}
31
  return rename_dict.get(orig_author, orig_author)
32
 
33
+ # Initialize the index and other setup
34
  @cl.on_chat_start
35
  async def init():
36
  msg = cl.Message(content=f"Building Index...")
37
  await msg.send()
38
 
 
39
  with open('./data/aerodynamic_drag.txt', 'r', encoding='Windows-1252') as f:
40
  aerodynamic_drag_data = f.read()
41
 
 
42
  documents = text_splitter.create_documents([aerodynamic_drag_data])
 
 
43
  store = LocalFileStore("./cache/")
44
  core_embeddings_model = OpenAIEmbeddings()
45
  embedder = CacheBackedEmbeddings.from_bytes_store(
46
  core_embeddings_model, store, namespace=core_embeddings_model.model
47
  )
 
 
48
  docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
 
49
  chain = RetrievalQA.from_chain_type(
50
  ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
51
  chain_type="stuff",
 
56
 
57
  msg.content = f"Index built!"
58
  await msg.send()
 
59
  cl.user_session.set("chain", chain)
60
 
61
+ # Main function to handle incoming queries
62
  @cl.on_message
63
  async def main(message):
64
  chain = cl.user_session.get("chain")
 
67
  )
68
  cb.answer_reached = True
69
  res = await chain.acall(message, callbacks=[cb], )
 
70
  answer = res["result"]
71
+ await cl.Message(content=answer).send()