import os
import yaml
from dotenv import load_dotenv
from pinecone import Pinecone
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.core.response.pprint_utils import pprint_source_node
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.core.tools import QueryEngineTool
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core.selectors import LLMSingleSelector, LLMMultiSelector
from llama_index.core.selectors import (
    PydanticMultiSelector,
    PydanticSingleSelector,
)
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
import nest_asyncio
import asyncio
nest_asyncio.apply()
# Load environment variables from the .env file
load_dotenv()
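# The .env file is expected to provide GROQ_API_KEY and PINECONE_API_KEY,
# which are read below via os.getenv().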
# Function to load YAML configuration
def load_config(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

def load_prompt_template(prompt_template_path):
    with open(prompt_template_path, 'r') as file:
        prompt_template = yaml.safe_load(file)
    return prompt_template

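# Illustrative sketch only (an assumed layout, not the project's actual files) of
# the YAML structures this script reads, based on the keys accessed below:
#
#   configs/config.yaml
#     embeddings:
#       model_name: <HuggingFace embedding model, e.g. "BAAI/bge-small-en-v1.5">
#     model:
#       model_name: <Groq-hosted LLM name>
#     pinecone:
#       index_name: <Pinecone index holding the document chunks>
#       summary_index_name: <Pinecone index holding the document summaries>
#
#   model/prompt_template.yaml
#     QA_PROMPT_TMPL: <QA prompt using the {context_str} and {query_str} placeholders>
#     TREE_SUMMARIZE_PROMPT_TMPL: <summary prompt using the same placeholders>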
# Pinecone Index Connection
def index_connection(config_path):
    """
    Initializes the embedding model, the Groq LLM, and the Pinecone client,
    then retrieves the two Pinecone indexes named in the YAML configuration.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        tuple: (index, summary_index) - the initialized Pinecone indexes for
        the document chunks and the document summaries.
    """
    # Load the configuration from the YAML file
    config = load_config(config_path)

    # Configure the global embedding model and LLM used by LlamaIndex
    embed_model_name = config['embeddings']['model_name']
    embed_model = HuggingFaceEmbedding(model_name=embed_model_name)
    model_name = config['model']['model_name']
    Settings.llm = Groq(model=model_name, api_key=os.getenv('GROQ_API_KEY'))
    Settings.embed_model = embed_model

    # Initialize the Pinecone client
    pc = Pinecone(
        api_key=os.getenv('PINECONE_API_KEY')  # Pinecone API key from the environment
    )

    # Retrieve both Pinecone indexes using the names from the config
    index_name = config['pinecone']['index_name']
    summary_index_name = config['pinecone']['summary_index_name']
    index = pc.Index(index_name)
    summary_index = pc.Index(summary_index_name)
    return index, summary_index

# Initialize Pinecone Vector Store and Retriever
def initialize_retriever(pinecone_index, summary_index):
    """
    Wraps the Pinecone indexes in LlamaIndex vector stores and builds the
    corresponding VectorStoreIndex objects used for querying.

    Args:
        pinecone_index: The Pinecone index holding the document chunks.
        summary_index: The Pinecone index holding the document summaries.

    Returns:
        tuple: (index, summary_index) - VectorStoreIndex objects backed by the
        two Pinecone indexes.
    """
    # Initialize the Pinecone vector stores; the node text is assumed to be
    # stored under the "_node_content" metadata key in both indexes
    vector_store = PineconeVectorStore(pinecone_index=pinecone_index, text_key="_node_content")
    summary_vector_store = PineconeVectorStore(pinecone_index=summary_index, text_key="_node_content")

    # Build LlamaIndex indexes on top of the existing vector stores
    index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
    summary_index = VectorStoreIndex.from_vector_store(vector_store=summary_vector_store)
    return index, summary_index

# Query the Pinecone Index
def index_retrieval(index, summary_index, query_text):
    """
    Queries the vector and summary indexes through a router query engine and
    returns the response.

    Args:
        index: VectorStoreIndex over the document chunks.
        summary_index: VectorStoreIndex over the document summaries.
        query_text (str): The text query to search for.

    Returns:
        Response: The query result produced by the router query engine.
    """
    # Load the prompt templates from model/prompt_template.yaml
    script_dir = os.path.dirname(os.path.abspath(__file__))  # Current script directory
    base_dir = os.path.dirname(script_dir)
    prompt_template_path = os.path.join(base_dir, 'model', 'prompt_template.yaml')
    prompt_template = load_prompt_template(prompt_template_path)
    QA_PROMPT = PromptTemplate(prompt_template['QA_PROMPT_TMPL'])

    # Build one query engine per index, both using the QA prompt
    vector_query_engine = index.as_query_engine(text_qa_template=QA_PROMPT)
    summary_query_engine = summary_index.as_query_engine(text_qa_template=QA_PROMPT)

    # Wrap the query engines as tools so the router can choose between them
    vector_tool = QueryEngineTool.from_defaults(
        query_engine=vector_query_engine,
        description="Useful for answering specific questions about the indexed documents",
    )
    summary_tool = QueryEngineTool.from_defaults(
        query_engine=summary_query_engine,
        description="Useful for summarization questions about the indexed documents",
    )

    # Summarizer used to combine the answers of the selected tools
    tree_summarize = TreeSummarize(
        summary_template=PromptTemplate(prompt_template['TREE_SUMMARIZE_PROMPT_TMPL'])
    )

    # Route the query to one or both tools and synthesize the final answer
    query_engine = RouterQueryEngine(
        selector=LLMMultiSelector.from_defaults(),
        query_engine_tools=[
            vector_tool,
            summary_tool,
        ],
        summarizer=tree_summarize,
    )
    response = query_engine.query(query_text)
    return response

# Example usage
if __name__ == "__main__":
    # Dynamically determine the path to the config file
    script_dir = os.path.dirname(os.path.abspath(__file__))  # Current script directory
    base_dir = os.path.dirname(script_dir)  # Go one level up
    config_path = os.path.join(base_dir, 'configs', 'config.yaml')  # 'config.yaml' in the 'configs' directory

    # Step 1: Initialize the Pinecone connection and retrieve both indexes
    pinecone_index, summary_index = index_connection(config_path=config_path)

    # Step 2: Build the VectorStoreIndex objects on top of the Pinecone indexes
    vector_index, summary_vector_index = initialize_retriever(pinecone_index, summary_index)

    # Step 3: Query the indexes through the router query engine and print the result
    query_text = """How much can the Minister of Health pay out of the Consolidated Revenue Fund in relation to coronavirus disease 2019 (COVID-19) tests"""
    response = index_retrieval(vector_index, summary_vector_index, query_text)
    print(response)