from typing import List, Optional

from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
class ClimateQARetriever(BaseRetriever):
    """Retriever that runs a scored similarity search and annotates the hits.

    Hits whose similarity score falls below ``threshold`` are dropped; the
    remaining documents get their score and a few display fields written
    into ``metadata`` before being returned.
    """

    vectorstore: VectorStore
    # Only type-checked in this class, not used for filtering; presumably a
    # list of source names consumed elsewhere -- verify against callers.
    sources: list = []
    # Not read by this class.
    reports: list = []
    # Hits with a score strictly below this value are discarded.
    threshold: float = 0.01
    # Not read by this class.
    k_summary: int = 3
    # Number of candidates requested from the vector store (``k``).
    k_total: int = 7
    # Not read by this class.
    min_size: int = 200
    # Forwarded as-is as the ``filter`` argument of the vector store search.
    filter: Optional[dict] = None

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents relevant to ``query``.

        Args:
            query: Text to search for.
            run_manager: Callback manager supplied by ``BaseRetriever``;
                not used here.

        Returns:
            Documents from ``vectorstore.similarity_search_with_score`` whose
            score is at least ``threshold``, in the order the store returned
            them. Each returned document has ``similarity_score``,
            ``content``, ``chunk_type`` and ``page_number`` set in its
            metadata (the documents are modified in place).

        Raises:
            TypeError: If ``sources`` is not a list.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(self.sources, list):
            raise TypeError(
                f"sources must be a list, got {type(self.sources).__name__}"
            )

        docs_and_scores = self.vectorstore.similarity_search_with_score(
            query=query, k=self.k_total, filter=self.filter
        )

        results = []
        for doc, score in docs_and_scores:
            # Drop hits that score below the threshold.
            if score < self.threshold:
                continue
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["chunk_type"] = "text"
            # NOTE(review): this overwrites any page_number already present
            # in the document's metadata -- confirm that is intended.
            doc.metadata["page_number"] = 1
            results.append(doc)
        return results