import logging
import os

from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

load_dotenv()

INDEX_NAME = "vector_index"
DATABASE_NAME = "scraped_data_db"


def mongo_rag_tool(query: str, collection_name: str) -> tuple[str, str]:
    """
    Retrieve the documents from a MongoDB collection that are most
    semantically similar to the query, then answer the query with a RAG chain
    over them.

    Args:
        query: The query to answer and to use for retrieval.
        collection_name: The name of the collection in the MongoDB database.

    Returns:
        A tuple of (answer, formatted source documents).
    """
    try:
        # The embedding model must match the one used when the collection was indexed.
        openai_api_key = os.getenv("OPENAI_API_KEY")
        embeddings = OpenAIEmbeddings(
            openai_api_key=openai_api_key,
            disallowed_special=(),
            model="text-embedding-3-small",
        )

        # Connect to the MongoDB Atlas vector store.
        uri = os.getenv("MONGO_CONNECTION_STRING")
        logging.info("Creating the Mongo vector search object")
        vector_search = MongoDBAtlasVectorSearch.from_connection_string(
            uri,
            DATABASE_NAME + "." + collection_name,
            embeddings,
            index_name=INDEX_NAME,
        )

        logging.info("Retrieving the documents and answering the query")
        # Project only the fields we need from each retrieved document.
        post_filter = [
            {"$project": {"_id": 0, "text": 1, "source": 1, "score": 1, "embedding": 1}}
        ]
        # MMR retrieval: fetch 100 candidates, keep the 10 most relevant and diverse.
        qa_retriever = vector_search.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 10, "fetch_k": 100, "post_filter_pipeline": post_filter},
        )

        prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
If you know the answer, give a comprehensive, detailed and insightful answer.

{context}

Question: {question}
"""
        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )

        qa = RetrievalQA.from_chain_type(
            llm=ChatOpenAI(api_key=openai_api_key, model="gpt-4o", temperature=0.2),
            chain_type="stuff",
            retriever=qa_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": PROMPT},
        )

        docs = qa.invoke({"query": query})

        logging.info("Formatting the retrieved source documents")
        sources = docs["source_documents"]
        source_list = [
            {"content": doc.page_content, "source": doc.metadata.get("source", "")}
            for doc in sources
        ]
        formatted_sources = "\n".join(
            f"Content: {source['content']}\nSource: {source['source']}\n"
            for source in source_list
        )
        return docs["result"], formatted_sources
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        return f"An error occurred: {str(e)}", f"An error occurred: {str(e)}"
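
# mongo_rag_tool assumes an Atlas Vector Search index named "vector_index"
# already exists on the target collection. The helper below is a minimal
# sketch of creating one with pymongo (assumes pymongo >= 4.7 and an Atlas
# cluster; this helper is not part of the original tool). The 1536 dimensions
# match text-embedding-3-small, and the "embedding" path matches the field
# projected in the post-filter above.
def create_vector_index(collection_name: str) -> None:
    """Create the Atlas Vector Search index this tool expects (a sketch)."""
    from pymongo import MongoClient
    from pymongo.operations import SearchIndexModel

    client = MongoClient(os.getenv("MONGO_CONNECTION_STRING"))
    collection = client[DATABASE_NAME][collection_name]
    collection.create_search_index(
        SearchIndexModel(
            definition={
                "fields": [
                    {
                        "type": "vector",
                        "path": "embedding",
                        "numDimensions": 1536,  # text-embedding-3-small output size
                        "similarity": "cosine",
                    }
                ]
            },
            name=INDEX_NAME,
            type="vectorSearch",
        )
    )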
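
# A minimal usage sketch, assuming OPENAI_API_KEY and MONGO_CONNECTION_STRING
# are set in a .env file or the environment. The "reviews" collection name is
# a hypothetical placeholder; the original left the collection unspecified.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    answer, sources = mongo_rag_tool(
        "What do people think about caterpillar vision link fleet management app",
        "reviews",
    )
    print(answer)
    print(sources)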