# Spaces: Sleeping
# (stray Hugging Face Spaces status text captured during extraction;
#  converted to a comment so the file remains valid Python)
import os | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch | |
from langchain_core.prompts import PromptTemplate | |
from langchain.chains import RetrievalQA | |
from langchain_openai import ChatOpenAI | |
import logging | |
from dotenv import load_dotenv | |
load_dotenv() | |
INDEX_NAME = "vector_index" | |
DATABASE_NAME = "scraped_data_db" | |
def mongo_rag_tool(query: str) -> tuple[str, str]:
    """
    Answer a query with RAG over documents stored in MongoDB Atlas.

    Documents are embedded with OpenAI embeddings and retrieved through an
    Atlas vector-search index (MMR search); the retrieved context is then fed
    to a chat model via a "stuff" RetrievalQA chain to produce the answer.

    args:
        query: str: The query that you want to use to retrieve documents
    returns:
        tuple[str, str]: (answer, formatted source excerpts). On failure both
        elements carry the error message.

    Environment variables read: MONGODB_COLLECTION_NAME, OPENAI_API_KEY,
    MONGO_CONNECTION_STRING.
    """
    try:
        collection_name = os.getenv("MONGODB_COLLECTION_NAME")
        openai_api_key = os.getenv("OPENAI_API_KEY")
        uri = os.getenv("MONGO_CONNECTION_STRING")
        # Fail fast with a clear message instead of an opaque driver/client
        # error when any required environment variable is missing.
        if not (collection_name and openai_api_key and uri):
            raise ValueError(
                "MONGODB_COLLECTION_NAME, OPENAI_API_KEY and "
                "MONGO_CONNECTION_STRING must all be set"
            )
        # Connect to the MongoDB database
        embeddings = OpenAIEmbeddings(
            openai_api_key=openai_api_key,
            disallowed_special=(),
            model="text-embedding-3-small",
        )
        logging.info("Creating the mongo vector search object")
        vector_search = MongoDBAtlasVectorSearch.from_connection_string(
            uri,
            DATABASE_NAME + "." + collection_name,
            embeddings,
            index_name=INDEX_NAME,
        )
        logging.info("Retrieving the documents and answering the query")
        # Retrieve the documents that are most semantically close to the query;
        # project only the fields used downstream (plus score/embedding for debugging).
        post_filter = [{"$project": {"_id": 0, "text": 1, "source": 1, "score": 1, "embedding": 1}}]
        qa_retriever = vector_search.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 10, "fetch_k": 100, "post_filter_pipeline": post_filter},
        )
        prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
If you know the answer give a comprehensive, detailed and insightful answer.
{context}
Question: {question}
"""
        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )
        qa = RetrievalQA.from_chain_type(
            llm=ChatOpenAI(api_key=openai_api_key, model="gpt-4o", temperature=0.2),
            chain_type="stuff",
            retriever=qa_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": PROMPT},
        )
        docs = qa.invoke({"query": query})
        if not docs:
            # BUG FIX: the original fell through here and implicitly returned
            # None, breaking the (answer, sources) 2-tuple contract.
            return "No answer was produced for the query.", ""
        logging.info("Saving the retrieved documents")
        sources = docs["source_documents"]
        source_list = [
            {"content": result.page_content, "source": result.metadata["source"]}
            for result in sources
        ]
        formatted_sources = "\n".join(
            f"Content: {source['content']}\nSource: {source['source']}\n"
            for source in source_list
        )
        return docs["result"], formatted_sources
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        # BUG FIX: the second element was missing the f-prefix, so callers
        # received the literal text "{str(e)}" instead of the actual error.
        return f"An error occurred: {str(e)}", f"An error occurred: {str(e)}"
#mongo_rag_tool("What do people think about caterpillar vision link fleet management app") |