Goodnight7 committed on
Commit
bc723cd
·
verified ·
1 Parent(s): d87ca70

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +143 -0
utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils
2
+
3
+ from langchain_chroma import Chroma
4
+ from langchain_nomic.embeddings import NomicEmbeddings
5
+ from langchain_core.documents import Document
6
+ from langchain.retrievers.document_compressors import CohereRerank
7
+ #from langchain_core import CohereRerank
8
+ #from langchain_cohere import CohereRerank
9
+
10
+ from langchain.retrievers import ContextualCompressionRetriever
11
+ from langchain.retrievers import EnsembleRetriever
12
+ from langchain.retrievers import BM25Retriever
13
+ from langchain_groq import ChatGroq
14
+
15
+ from dotenv import load_dotenv
16
+ from langchain_core.prompts import ChatPromptTemplate
17
+ from langchain_core.runnables import Runnable, RunnableMap
18
+ from langchain.schema import BaseRetriever
19
+ from qdrant_client import models
20
+
21
+
22
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
23
+
24
+ load_dotenv()
25
+
26
+ import os
27
+ LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
28
+
29
+ #Retriever
30
+
31
+ def get_retriever(n_docs=5): # renamed function
32
+ vector_database_path = "db"
33
+
34
+ embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
35
+
36
+ vectorstore = Chroma(collection_name="chromadb3",
37
+ persist_directory=vector_database_path,
38
+ embedding_function=embedding_model)
39
+
40
+ vs_retriever = vectorstore.as_retriever(k=n_docs)
41
+
42
+ # Get documents from vector store
43
+ try:
44
+ store_data = vectorstore.get()
45
+ texts = store_data['documents']
46
+ metadatas = store_data['metadatas']
47
+
48
+ if not texts: # If no documents found
49
+ print("Warning: No documents found in vector store. Using vector retriever only.")
50
+ return vs_retriever
51
+
52
+ # Create documents with explicit IDs
53
+ documents = []
54
+ for i, (text, metadata) in enumerate(zip(texts, metadatas)):
55
+ doc = Document(
56
+ page_content=text,
57
+ metadata=metadata if metadata else {},
58
+ id_=str(i) # Add explicit ID
59
+ )
60
+ documents.append(doc)
61
+
62
+ # Create BM25 retriever with explicit document handling
63
+ keyword_retriever = BM25Retriever.from_texts(
64
+ texts=[doc.page_content for doc in documents],
65
+ metadatas=[doc.metadata for doc in documents],
66
+ ids=[doc.id_ for doc in documents]
67
+ )
68
+ keyword_retriever.k = n_docs
69
+
70
+ ensemble_retriever = EnsembleRetriever(
71
+ retrievers=[vs_retriever, keyword_retriever],
72
+ weights=[0.5, 0.5]
73
+ )
74
+
75
+ compressor = CohereRerank(model="rerank-english-v3.0")
76
+ compression_retriever = ContextualCompressionRetriever(
77
+ base_compressor=compressor,
78
+ base_retriever=ensemble_retriever
79
+ )
80
+
81
+ return compression_retriever
82
+
83
+ except Exception as e:
84
+ print(f"Warning: Error creating combined retriever ({str(e)}). Using vector retriever only.")
85
+ return vs_retriever
86
+
87
+ #Retriever prompt
88
+ rag_prompt = """You are a medical chatbot designed to answer health-related questions.
89
+ The questions you will receive will primarily focus on medical topics and patient care.
90
+ Here is the context to use to answer the question:
91
+ {context}
92
+ Think carefully about the above context.
93
+ Now, review the user question:
94
+ {input}
95
+ Provide an answer to this question using only the above context.
96
+ Answer:"""
97
+
98
+ # Post-processing
99
+ def format_docs(docs):
100
+ return "\n\n".join(doc.page_content for doc in docs)
101
+
102
+ #RAG chain
103
+ def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile", temp=0 ) -> Runnable:
104
+ """Return a chain defined primarily in LangChain Expression Language"""
105
+ def retrieve_context(input_text):
106
+ # Use the retriever to fetch relevant documents
107
+ docs = retriever.get_relevant_documents(input_text)
108
+ return format_docs(docs)
109
+
110
+ ingress = RunnableMap(
111
+ {
112
+ "input": lambda x: x["input"],
113
+ "context": lambda x: retrieve_context(x["input"]),
114
+ }
115
+ )
116
+ prompt = ChatPromptTemplate.from_messages(
117
+ [
118
+ (
119
+ "system",
120
+ rag_prompt
121
+ )
122
+ ]
123
+ )
124
+ llm = ChatGroq(model=model_name,api_key="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", temperature=temp)
125
+
126
+ chain = ingress | prompt | llm
127
+ return chain
128
+
129
+ embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
130
+ #embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
131
+
132
+ #Generate embeddings for a given text
133
+ def get_embeddings(text):
134
+ return embedding_model.embed([text], task_type='search_document')[0]
135
+
136
+
137
+ # Create or connect to a Qdrant collection
138
+ def create_qdrant_collection(client, collection_name):
139
+ if collection_name not in client.get_collections().collections:
140
+ client.create_collection(
141
+ collection_name=collection_name,
142
+ vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE)
143
+ )