agshiv92 committed on
Commit
610ad42
·
verified ·
1 Parent(s): 1c9980f

Delete query_engine.py

Browse files
Files changed (1) hide show
  1. query_engine.py +0 -143
query_engine.py DELETED
@@ -1,143 +0,0 @@
1
- import os
2
- import yaml
3
- from dotenv import load_dotenv
4
- from pinecone import Pinecone
5
- from llama_index.vector_stores.pinecone import PineconeVectorStore
6
- from llama_index.core import VectorStoreIndex
7
- from llama_index.core.response.pprint_utils import pprint_source_node
8
- from llama_index.core import Settings
9
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
10
- from llama_index.llms.groq import Groq
11
- from llama_index.core.tools import QueryEngineTool
12
- from llama_index.core.query_engine import RouterQueryEngine
13
- from llama_index.core.selectors import LLMSingleSelector, LLMMultiSelector
14
- from llama_index.core.selectors import (
15
- PydanticMultiSelector,
16
- PydanticSingleSelector,
17
- )
18
- from llama_index.core import PromptTemplate
19
- from llama_index.core.response_synthesizers import TreeSummarize
20
- import nest_asyncio
21
- import asyncio
22
- nest_asyncio.apply()
23
- # Load environment variables from the .env file
24
- load_dotenv()
25
-
26
- # Function to load YAML configuration
27
def load_config(config_path):
    """
    Loads a YAML configuration file.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        The parsed YAML document, exactly as produced by ``yaml.safe_load``.
    """
    # Decode as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding (which differs between Linux/macOS and Windows).
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
31
def load_prompt_template(prompt_template_path):
    """
    Loads the prompt templates stored in a YAML file.

    Args:
        prompt_template_path (str): Path to the prompt template YAML file.

    Returns:
        The parsed YAML document, exactly as produced by ``yaml.safe_load``.
    """
    # Decode as UTF-8 explicitly so prompt text containing non-ASCII characters
    # is read the same way on every platform.
    with open(prompt_template_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
35
- # Pinecone Index Connection
36
def index_connection(config_path):
    """
    Sets the global LLM and embedding model, then opens the two Pinecone indexes
    named in the YAML configuration.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        tuple: ``(index, summary_index)`` - the Pinecone index objects for
        ``pinecone.index_name`` and ``pinecone.summary_index_name``.
    """
    cfg = load_config(config_path)

    # Register the embedding model and the LLM on the shared llama_index Settings
    embedder = HuggingFaceEmbedding(model_name=cfg['embeddings']['model_name'])
    Settings.llm = Groq(
        model=cfg['model']['model_name'],
        api_key=os.getenv('GROQ_API_KEY'),
    )
    Settings.embed_model = embedder

    # The Pinecone API key comes from the environment (see load_dotenv at import time)
    client = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))

    # Resolve both index names before opening either index
    pinecone_cfg = cfg['pinecone']
    names = [pinecone_cfg['index_name'], pinecone_cfg['summary_index_name']]
    main_index, summary_index = [client.Index(name) for name in names]
    return main_index, summary_index
62
-
63
- # Initialize Pinecone Vector Store and Retriever
64
def initialize_retriever(pinecone_index,summary_index):
    """
    Wraps the two Pinecone indexes in llama_index ``VectorStoreIndex`` objects.

    Args:
        pinecone_index: The Pinecone index object for the main index.
        summary_index: The Pinecone index object for the summary index.

    Returns:
        tuple: ``(index, summary_index)`` - one ``VectorStoreIndex`` per input,
        each built on a ``PineconeVectorStore`` created with
        ``text_key="_node_content"``.
    """
    # Build both vector stores first, then one VectorStoreIndex on top of each
    stores = [
        PineconeVectorStore(pinecone_index=raw_index, text_key="_node_content")
        for raw_index in (pinecone_index, summary_index)
    ]
    wrapped = [
        VectorStoreIndex.from_vector_store(vector_store=store)
        for store in stores
    ]
    return wrapped[0], wrapped[1]
82
-
83
- # Query the Pinecone Index
84
def index_retrieval(index, summary_index, query_text, prompt_template_path=None):
    """
    Answers a query by routing it across a vector query engine and a summary
    query engine and combining the selected answers.

    Args:
        index: Index object exposing ``as_query_engine`` (the first value
            returned by ``initialize_retriever``).
        summary_index: Index object exposing ``as_query_engine`` (the second
            value returned by ``initialize_retriever``).
        query_text (str): The text query to search for.
        prompt_template_path (str, optional): Path to the prompt template YAML.
            When omitted, ``model/prompt_template.yaml`` under the parent of
            this script's directory is used. The file must define
            ``QA_PROMPT_TMPL`` and ``TREE_SUMMARIZE_PROMPT_TMPL``.

    Returns:
        The response object returned by ``RouterQueryEngine.query``.
    """
    if prompt_template_path is None:
        # Default location: <parent of this script's directory>/model/prompt_template.yaml
        script_dir = os.path.dirname(os.path.abspath(__file__))
        base_dir = os.path.dirname(script_dir)
        prompt_template_path = os.path.join(base_dir, 'model', 'prompt_template.yaml')

    prompt_template = load_prompt_template(prompt_template_path)
    qa_prompt = PromptTemplate(prompt_template['QA_PROMPT_TMPL'])

    # Both engines answer with the same QA prompt
    vector_query_engine = index.as_query_engine(text_qa_template=qa_prompt)
    summary_query_engine = summary_index.as_query_engine(text_qa_template=qa_prompt)

    # NOTE(review): both tools carry the same description, so the selector has
    # nothing to tell them apart by - confirm what each index holds and give
    # each tool a distinct description.
    vector_tool = QueryEngineTool.from_defaults(
        query_engine=vector_query_engine,
        description="Useful for answering questions about this context",
    )
    summary_tool = QueryEngineTool.from_defaults(
        query_engine=summary_query_engine,
        description="Useful for answering questions about this context",
    )

    # Summarizer handed to the router for combining the selected engines' answers
    tree_summarize = TreeSummarize(
        summary_template=PromptTemplate(prompt_template['TREE_SUMMARIZE_PROMPT_TMPL'])
    )
    query_engine = RouterQueryEngine(
        selector=LLMMultiSelector.from_defaults(),
        query_engine_tools=[vector_tool, summary_tool],
        summarizer=tree_summarize,
    )
    return query_engine.query(query_text)
126
- # Example usage
127
if __name__ == "__main__":
    # Locate configs/config.yaml under the parent of this script's directory
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    config_file = os.path.join(project_root, 'configs', 'config.yaml')

    # Step 1: Open the two Pinecone indexes
    raw_index, raw_summary_index = index_connection(config_path=config_file)

    # Step 2: Wrap them in VectorStoreIndex objects
    vector_index, summary_vector_index = initialize_retriever(raw_index, raw_summary_index)

    # Step 3: Run a sample query through the router query engine and print the answer
    question = """How much can the Minister of Health pay out of the Consolidated Revenue Fund in relation to coronavirus disease 2019 (COVID-19) tests"""
    answer = index_retrieval(vector_index, summary_vector_index, question)
    print(answer)