# PsTuts-RAG / app.py
# Last commit by mbudisic (8bc5e76): "Async implementation of vector store population"
# Original file size: 2.37 kB
import asyncio
import getpass
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import chainlit as cl
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

import pstuts_rag.rag
from pstuts_rag.rag import RAGChainFactory, RetrieverFactory
@dataclass
class ApplicationParameters:
    """Static configuration for the RAG chat application.

    The attributes are annotated, so ``@dataclass`` turns them into real
    fields. The defaults below apply when nothing is passed, and any field can
    be overridden at construction time, e.g.
    ``ApplicationParameters(filename="data/other.json")``.
    """

    # Path of the JSON file loaded into the retriever when it has no documents.
    filename: str = "data/test.json"
    # Name of the OpenAI embedding model (not referenced by the visible handlers).
    embedding_model: str = "text-embedding-3-small"
    # NOTE(review): presumably the number of documents retrieved per query;
    # it is not currently passed to the retriever -- confirm intended use.
    n_context_docs: int = 2
    # Chat model name passed to ChatOpenAI when the RAG chain is built.
    llm_model: str = "gpt-4.1-mini"
def set_api_key_if_not_present(key_name, prompt_message=""):
    """Ensure the environment variable *key_name* holds a non-empty value.

    If the variable is missing or empty, its value is read interactively with
    ``getpass.getpass`` (so the typed secret is not echoed). The value is then
    stored in ``os.environ``. An already-set, non-empty value is left untouched.

    Args:
        key_name: Name of the environment variable to check.
        prompt_message: Prompt shown to the user. Falls back to ``key_name``
            when empty or ``None``.
    """
    # `not` covers both "" and None; `len(None)` used to raise TypeError.
    if not prompt_message:
        prompt_message = key_name
    if not os.environ.get(key_name):
        os.environ[key_name] = getpass.getpass(prompt_message)
class ApplicationState:
    """Mutable runtime state shared by the Chainlit handlers.

    Attributes start as ``None`` (``docs`` starts as an empty list). Some of
    them (``client``, ``retriever_factory``, ``rag_factory``, ``llm``,
    ``rag_chain``) are assigned later by the chat-start handler. Constructing
    an instance loads variables from a ``.env`` file and makes sure
    ``OPENAI_API_KEY`` is set, prompting for it if necessary.
    """

    embeddings: Optional[OpenAIEmbeddings] = None
    # Assigned per instance in __init__: a class-level ``[]`` default would be
    # one list object shared by every ApplicationState instance.
    docs: List[Document]
    qdrantclient: Optional[QdrantClient] = None
    vectorstore: Optional[QdrantVectorStore] = None
    # Default to None so reading these before the chat-start handler runs
    # yields None instead of raising AttributeError.
    retriever_factory: Optional[pstuts_rag.rag.RetrieverFactory] = None
    rag_factory: Optional[pstuts_rag.rag.RAGChainFactory] = None
    # Attributes the chat-start handler assigns; declared here so the state's
    # full shape is visible in one place.
    client: Optional[QdrantClient] = None
    llm: Optional[ChatOpenAI] = None
    rag_chain: Any = None

    def __init__(self) -> None:
        self.docs = []
        load_dotenv()
        set_api_key_if_not_present("OPENAI_API_KEY")
# Module-level singletons used by the Chainlit handlers below: the mutable
# runtime state (its constructor loads .env and checks OPENAI_API_KEY),
# followed by the static configuration.
state, params = ApplicationState(), ApplicationParameters()
@cl.on_chat_start
async def on_chat_start():
    """Build the retriever and RAG chain when a chat session starts.

    The handler does the following:

    - Creates an in-memory Qdrant client.
    - Wraps it in a ``RetrieverFactory``.
    - If the store reports zero documents, loads the JSON records from
      ``params.filename`` and adds them asynchronously.
    - Stores the RAG chain factory, the chat LLM and the resulting RAG chain
      on the module-level ``state``.
    """
    # NOTE(review): ``state`` is module-global, so each new chat session
    # replaces the client/chain seen by every session, and the in-memory store
    # is rebuilt on every chat start. Consider cl.user_session or caching.
    state.client = QdrantClient(":memory:")
    state.retriever_factory = pstuts_rag.rag.RetrieverFactory(
        qdrant_client=state.client, name="local_test"
    )

    if state.retriever_factory.count_docs() == 0:
        # `with` closes the file deterministically; the bare open() leaked it.
        with open(params.filename, "rb") as fh:
            data: List[Dict[str, Any]] = json.load(fh)
        # This coroutine already runs on Chainlit's event loop, and
        # asyncio.run() raises RuntimeError when a loop is running in the same
        # thread, so await the coroutine directly.
        await state.retriever_factory.aadd_docs(raw_docs=data)

    state.rag_factory = pstuts_rag.rag.RAGChainFactory(
        retriever=state.retriever_factory.get_retriever()
    )
    # NOTE(review): params.n_context_docs is not passed to get_retriever();
    # confirm whether the retriever should be limited by it.
    state.llm = ChatOpenAI(model=params.llm_model, temperature=0)
    state.rag_chain = state.rag_factory.get_rag_chain(state.llm)
@cl.on_message
async def main(message: cl.Message):
    """Answer one incoming chat message with the RAG chain and post the reply.

    Args:
        message: The incoming Chainlit message; only its text content is used
            as the query.
    """
    question = message.content
    answer = await state.rag_chain.ainvoke(question)
    await cl.Message(content=answer.content).send()
if __name__ == "__main__":
    # `main` is an async Chainlit message handler that requires a cl.Message
    # and must be awaited, so calling `main()` here raised TypeError. Instead,
    # start the Chainlit app on this file (equivalent to `chainlit run`).
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)