File size: 6,029 Bytes
10e8311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import yaml
from dotenv import load_dotenv
from pinecone import Pinecone
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.core.response.pprint_utils import pprint_source_node
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.core.tools import QueryEngineTool
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core.selectors import LLMSingleSelector, LLMMultiSelector
from llama_index.core.selectors import (
    PydanticMultiSelector,
    PydanticSingleSelector,
)
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
import nest_asyncio
import asyncio
nest_asyncio.apply()
# Load environment variables from the .env file
load_dotenv()

# Function to load YAML configuration
def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        The parsed YAML contents (typically a dict of config sections).
    """
    # Explicit encoding avoids platform-dependent default codecs on Windows.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
def load_prompt_template(prompt_template_path):
    """Load prompt templates from a YAML file.

    Args:
        prompt_template_path (str): Path to the prompt-template YAML file.

    Returns:
        The parsed YAML contents (a mapping of template names to strings).
    """
    with open(prompt_template_path, 'r') as template_file:
        parsed = yaml.safe_load(template_file)
    return parsed
# Pinecone Index Connection
def index_connection(config_path):
    """Connect to the two Pinecone indexes named in the YAML config.

    Also configures the global LlamaIndex ``Settings`` as a side effect:
    sets ``Settings.llm`` to a Groq model and ``Settings.embed_model`` to a
    HuggingFace embedding model, both read from the config file. API keys
    are taken from the GROQ_API_KEY and PINECONE_API_KEY environment
    variables (loaded from .env at import time via load_dotenv()).

    Args:
        config_path (str): Path to the YAML configuration file. Must contain
            keys ``embeddings.model_name``, ``model.model_name``,
            ``pinecone.index_name`` and ``pinecone.summary_index_name``.

    Returns:
        tuple: ``(index, summary_index)`` — the main and the summary
        Pinecone index handles.
    """
    # Load the configuration from a YAML file
    config = load_config(config_path)
    embed_model_name = config['embeddings']['model_name']
    embed_model = HuggingFaceEmbedding(model_name=embed_model_name)
    model_name = config['model']['model_name']
    # Global side effect: every downstream query engine picks these up.
    Settings.llm = Groq(model=model_name, api_key=os.getenv('GROQ_API_KEY'))
    Settings.embed_model = embed_model   
    # Initialize the Pinecone client
    pc = Pinecone(
        api_key=os.getenv('PINECONE_API_KEY')  # Get the Pinecone API key from the environment
    )
    index_name = config['pinecone']['index_name']
    summary_index_name = config['pinecone']['summary_index_name']
    index = pc.Index(index_name) 
    summary_index = pc.Index(summary_index_name)  # Get the Pinecone index using the index name from the config
    return index,summary_index

# Initialize Pinecone Vector Store and Retriever
def initialize_retriever(pinecone_index, summary_index):
    """Wrap the two Pinecone indexes in LlamaIndex VectorStoreIndex objects.

    Args:
        pinecone_index: Pinecone index holding the main document vectors.
        summary_index: Pinecone index holding the summary vectors.

    Returns:
        tuple: ``(index, summary_index)`` — VectorStoreIndex wrappers over
        the main and summary vector stores, respectively.
    """
    # text_key names the metadata field the store reads node text from —
    # presumably "_node_content" matches how the indexes were populated.
    main_store = PineconeVectorStore(
        pinecone_index=pinecone_index, text_key="_node_content"
    )
    summary_store = PineconeVectorStore(
        pinecone_index=summary_index, text_key="_node_content"
    )

    main_vs_index = VectorStoreIndex.from_vector_store(vector_store=main_store)
    summary_vs_index = VectorStoreIndex.from_vector_store(vector_store=summary_store)
    return main_vs_index, summary_vs_index

# Query the Pinecone Index
def index_retrieval(index, summary_index, query_text):
    """Answer ``query_text`` by routing it between a vector and a summary index.

    Builds one query engine per index (sharing the QA prompt template from
    ``../model/prompt_template.yaml``), wraps each in a QueryEngineTool, and
    lets a RouterQueryEngine with an LLM multi-selector choose which tool(s)
    handle the query; multi-tool results are merged by a TreeSummarize
    synthesizer.

    Args:
        index: VectorStoreIndex backed by the main Pinecone index.
        summary_index: VectorStoreIndex backed by the summary Pinecone index.
        query_text (str): The natural-language question to answer.

    Returns:
        The response object produced by the router query engine.
    """
    # Locate ../model/prompt_template.yaml relative to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(script_dir)
    prompt_template_path = os.path.join(base_dir, 'model', 'prompt_template.yaml')
    prompt_template = load_prompt_template(prompt_template_path)
    QA_PROMPT = PromptTemplate(prompt_template['QA_PROMPT_TMPL'])

    vector_query_engine = index.as_query_engine(text_qa_template=QA_PROMPT)
    summary_query_engine = summary_index.as_query_engine(text_qa_template=QA_PROMPT)

    # BUG FIX: both tools previously carried the identical description
    # ("Useful for answering questions about this context"), giving the LLM
    # selector no signal for choosing between them. Distinct descriptions let
    # the router send fact-lookup questions to the vector tool and
    # overview/summary questions to the summary tool.
    vector_tool = QueryEngineTool.from_defaults(
        query_engine=vector_query_engine,
        description=(
            "Useful for retrieving specific facts and details from the "
            "document context."
        ),
    )
    summary_tool = QueryEngineTool.from_defaults(
        query_engine=summary_query_engine,
        description=(
            "Useful for high-level summaries or overview questions about "
            "the document context."
        ),
    )

    tree_summarize = TreeSummarize(
        summary_template=PromptTemplate(prompt_template['TREE_SUMMARIZE_PROMPT_TMPL'])
    )
    query_engine = RouterQueryEngine(
        selector=LLMMultiSelector.from_defaults(),
        query_engine_tools=[
            vector_tool,
            summary_tool,
        ],
        summarizer=tree_summarize,
    )
    return query_engine.query(query_text)
# Example usage
if __name__ == "__main__":
    # Resolve <project_root>/configs/config.yaml relative to this script.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(this_dir)
    config_path = os.path.join(project_root, 'configs', 'config.yaml')

    # Step 1: connect to Pinecone (also configures global LLM/embed settings).
    pinecone_index, summary_index = index_connection(config_path=config_path)

    # Step 2: wrap the raw Pinecone indexes in LlamaIndex vector-store indexes.
    retriever, summary_retriever = initialize_retriever(pinecone_index, summary_index)

    # Step 3: route the question through the router query engine and print it.
    query_text = """How much can the Minister of Health pay out of the Consolidated Revenue Fund in relation to coronavirus disease 2019 (COVID-19) tests"""
    response = index_retrieval(retriever, summary_retriever, query_text)
    print(response)