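"""Chainlit app that answers questions over tutorial transcripts.

On chat start, the app loads JSON transcript files into an in-memory
Qdrant collection, builds a retrieval-augmented generation (RAG) chain
over it, and then streams answers, along with linked video references,
back to the user.
"""
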
import asyncio
import getpass
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import chainlit as cl
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import Runnable
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
import pstuts_rag.datastore
import pstuts_rag.rag
from pstuts_rag.datastore import load_json_files


@dataclass
class ApplicationParameters:
    """Static configuration: data files, model names, retrieval depth."""

    # Without type annotations @dataclass generates no fields; the mutable
    # list default also needs field(default_factory=...).
    filename: List[str] = field(
        default_factory=lambda: [f"data/{f}.json" for f in ["dev"]]
    )
    embedding_model: str = "text-embedding-3-small"
    n_context_docs: int = 2
    llm_model: str = "gpt-4.1-mini"


def set_api_key_if_not_present(key_name: str, prompt_message: str = "") -> None:
    """Prompt for a secret interactively unless the env var is already set."""
    if not prompt_message:
        prompt_message = key_name
    if key_name not in os.environ or not os.environ[key_name]:
        os.environ[key_name] = getpass.getpass(prompt_message)


class ApplicationState:
    """Mutable runtime state shared across the Chainlit callbacks."""

    embeddings: Optional[OpenAIEmbeddings]
    docs: List[Document]
    qdrant_client: Optional[QdrantClient]
    vector_store: Optional[QdrantVectorStore]
    datastore_manager: pstuts_rag.datastore.DatastoreManager
    rag_factory: pstuts_rag.rag.RAGChainFactory
    llm: BaseChatModel
    rag_chain: Runnable
    has_loaded: asyncio.Event
    points_loaded: int

    def __init__(self) -> None:
        self.embeddings = None
        self.docs = []
        self.qdrant_client = None
        self.vector_store = None
        self.has_loaded = asyncio.Event()
        self.points_loaded = 0
        load_dotenv()
        set_api_key_if_not_present("OPENAI_API_KEY")
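
# Module-level singletons shared by the Chainlit lifecycle hooks below.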
state = ApplicationState()
params = ApplicationParameters()


async def fill_the_db():
    """Load transcript JSON into the vector store if it is still empty."""
    if state.datastore_manager.count_docs() == 0:
        data: List[Dict[str, Any]] = await load_json_files(params.filename)
        state.points_loaded = await state.datastore_manager.embed_chunks(
            raw_docs=data
        )
        await cl.Message(
            content=f"✅ The database has been loaded with {state.points_loaded} elements!"
        ).send()


async def build_the_chain():
    """Assemble the RAG chain from the retriever and the chat model."""
    state.rag_factory = pstuts_rag.rag.RAGChainFactory(
        retriever=state.datastore_manager.get_retriever()
    )
    state.llm = ChatOpenAI(model=params.llm_model, temperature=0)
    state.rag_chain = state.rag_factory.get_rag_chain(state.llm)


@cl.on_chat_start
async def on_chat_start():
    """Create the in-memory Qdrant store, then load data and build the chain."""
    state.qdrant_client = QdrantClient(":memory:")
    state.datastore_manager = pstuts_rag.datastore.DatastoreManager(
        qdrant_client=state.qdrant_client, name="local_test"
    )
    # asyncio.run() cannot be called from inside Chainlit's already-running
    # event loop, so the coroutines are awaited directly.
    await fill_the_db()
    await build_the_chain()


@cl.on_message
async def main(message: cl.Message):
    """Answer the user's question and attach the referenced videos."""
    msg = cl.Message(content="")
    response = await state.rag_chain.ainvoke({"question": message.content})
    text, references = pstuts_rag.rag.RAGChainFactory.unpack_references(
        response.content
    )
    if isinstance(text, str):
        # Stream the answer character by character for a typing effect.
        for token in text:
            await msg.stream_token(token)
    await msg.send()

    references = json.loads(references)
    print(references)
    # Each reference is expected to carry "title", "source" (a URL), and
    # "start" (seconds); the #t= URL fragment starts playback there.
    msg_references = [
        (
            f"Watch {ref['title']} from timestamp "
            f"{round(ref['start'] // 60)}m:{round(ref['start'] % 60)}s",
            cl.Video(
                name=ref["title"],
                url=f"{ref['source']}#t={ref['start']}",
                display="side",
            ),
        )
        for ref in references
    ]
    await cl.Message(content="Related videos").send()
    for e in msg_references:
        await cl.Message(content=e[0], elements=[e[1]]).send()
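
# Calling main() directly would neither await the coroutine nor supply the
# required message argument; Chainlit invokes the decorated callbacks itself.
# Launch the app with the Chainlit CLI, e.g.:
#     chainlit run app.py
# (substitute this file's actual name for app.py).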