SimaFarazi committed on
Commit
5e0c914
·
1 Parent(s): 45a81b0

enable filtering based on file name

Browse files
app_stream_rag/app/chains.py CHANGED
@@ -54,7 +54,12 @@ input_to_rag_chain = input_1 | input_2
54
  # HistoryInput and the LLM to build the rag_chain.
55
  rag_chain = (input_to_rag_chain | rag_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
56
 
57
- # TODO: Implement the filtered_rag_chain. It should be the
58
  # same as the rag_chain but with hybrid_search = True.
59
- filtered_rag_chain = None
 
 
 
 
 
60
 
 
54
  # HistoryInput and the LLM to build the rag_chain.
55
  rag_chain = (input_to_rag_chain | rag_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
56
 
57
# Build the filtered RAG chain: identical to rag_chain, except the retrieval
# step calls the indexer with hybrid_search=True so results can be narrowed
# (e.g. filtered by file name).
def _filtered_context(x):
    # Retrieve context for the standalone question using hybrid search.
    return format_context(data_indexer.search(x['new_question'], hybrid_search=True))

input_2_filtered = {
    'context': _filtered_context,
    'standalone_question': lambda x: x['new_question'],
}

input_to_filtered_rag_chain = input_1 | input_2_filtered
filtered_rag_chain = (input_to_filtered_rag_chain | rag_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)
65
 
app_stream_rag/app/main.py CHANGED
@@ -14,7 +14,8 @@ from chains import (
14
  simple_chain,
15
  formatted_chain,
16
  history_chain,
17
- rag_chain
 
18
  )
19
 
20
  import models
@@ -148,6 +149,38 @@ async def rag_stream(request: Request, db: Session = Depends(get_db)):
148
  callbacks=[LogResponseCallback(user_request=user_request, db=db)]
149
  ))
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  if __name__ == "__main__":
152
  import uvicorn
153
  uvicorn.run("main:app", host="localhost", reload=True, port=8000)
 
14
  simple_chain,
15
  formatted_chain,
16
  history_chain,
17
+ rag_chain,
18
+ filtered_rag_chain
19
  )
20
 
21
  import models
 
149
  callbacks=[LogResponseCallback(user_request=user_request, db=db)]
150
  ))
151
 
152
@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
    """Stream an answer produced by the filtered (hybrid-search) RAG chain.

    Mirrors the /rag/stream endpoint, but routes the question through
    filtered_rag_chain, whose retrieval step runs with hybrid_search=True.

    NOTE: renamed from `rag_stream` — the original name duplicated the
    /rag/stream handler's function name, shadowing it at module level and
    producing duplicate FastAPI operation IDs. The route path is unchanged,
    so HTTP clients are unaffected.
    """
    # Parse the raw JSON payload into a typed user request.
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])
    username = user_request.username
    question = user_request.question

    # Pull the chat history of the user based on the user request.
    user_messages = crud.get_user_chat_history(db, username)

    # Persist the current question as part of the user's history.
    # NOTE(review): datetime.now() is naive local time — confirm whether the
    # rest of the app expects UTC-aware timestamps.
    message = schemas.MessageBase(
        message=question,
        type="user",
        timestamp=datetime.now()
    )
    crud.add_message(db, message, username)

    # Build the HistoryInput the chain expects from the formatted history.
    user_chat_history = prompts.format_chat_history(user_messages)
    history_input = schemas.HistoryInput(question=question, chat_history=user_chat_history)

    # Stream tokens from the filtered RAG chain over SSE; the callback logs
    # the completed response to the database.
    return EventSourceResponse(generate_stream(
        history_input,
        filtered_rag_chain,
        callbacks=[LogResponseCallback(user_request=user_request, db=db)]
    ))
183
+
184
  if __name__ == "__main__":
185
  import uvicorn
186
  uvicorn.run("main:app", host="localhost", reload=True, port=8000)
app_stream_rag/test.py CHANGED
@@ -3,7 +3,7 @@ from langserve import RemoteRunnable
3
  # If we put /simple/stream, it complains; because chain.stream will hit /simple/stream endpoint
4
  url = "https://simafarazi-backend-c.hf.space/rag"
5
  chain = RemoteRunnable(url) #Client for iteracting with LangChain runnables that are hosted as LangServe endpoints
6
- stream = chain.stream(input={"question":"Where do you recommend me for a 7 days hike in europe during April?",
7
  "username": "Sima"}) # .stream() and .invoke() are standard methods to interact with hosted runnables
8
 
9
 
 
3
# Hitting /rag directly: chain.stream appends /stream to the endpoint path,
# so pointing the URL at /rag/stream would double it and fail.
url = "https://simafarazi-backend-c.hf.space/rag"
# Client for interacting with LangChain runnables hosted as LangServe endpoints.
chain = RemoteRunnable(url)
# .stream() and .invoke() are the standard methods for interacting with hosted runnables.
stream = chain.stream(input={
    "question": "What does the following code do? input_1 = RunnablePassthrough.assign(new_question=standalone_chain)",
    "username": "Sima",
})
 
9