import os

import yaml
from dotenv import load_dotenv
from pinecone import Pinecone

from llama_index.core import PromptTemplate, Settings, VectorStoreIndex
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.selectors import LLMMultiSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.pinecone import PineconeVectorStore

import nest_asyncio

# Allow nested event loops; the LlamaIndex query engines run async code internally.
nest_asyncio.apply()

# Load GROQ_API_KEY and PINECONE_API_KEY from a local .env file.
load_dotenv()


def load_config(config_path):
    """Load the YAML configuration file and return it as a dict."""
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config


def load_prompt_template(prompt_template_path):
    """Load the YAML prompt templates and return them as a dict."""
    with open(prompt_template_path, 'r') as file:
        prompt_template = yaml.safe_load(file)
    return prompt_template

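# The two loaders above expect YAML files shaped roughly as follows. This is an assumed
# layout inferred from the keys read in this script (config['embeddings']['model_name'],
# config['model']['model_name'], config['pinecone']['index_name'],
# config['pinecone']['summary_index_name'], plus the QA_PROMPT_TMPL and
# TREE_SUMMARIZE_PROMPT_TMPL entries); the model names and index names are placeholders.
#
# configs/config.yaml
#   embeddings:
#     model_name: "BAAI/bge-small-en-v1.5"      # any HuggingFace embedding model
#   model:
#     model_name: "llama3-70b-8192"             # any Groq-hosted chat model
#   pinecone:
#     index_name: "docs-index"
#     summary_index_name: "docs-summary-index"
#
# model/prompt_template.yaml
#   QA_PROMPT_TMPL: |
#     Context information is below.
#     ---------------------
#     {context_str}
#     ---------------------
#     Given the context information and not prior knowledge, answer the query.
#     Query: {query_str}
#     Answer:
#   TREE_SUMMARIZE_PROMPT_TMPL: |
#     Context information from multiple sources is below.
#     ---------------------
#     {context_str}
#     ---------------------
#     Given the information from multiple sources, answer the query.
#     Query: {query_str}
#     Answer: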

def index_connection(config_path):
    """
    Initializes the Pinecone client, configures the global LLM and embedding model,
    and retrieves the two Pinecone indexes named in the YAML configuration.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        tuple: The initialized Pinecone vector index and summary index.
    """
    config = load_config(config_path)

    # Embedding model used to encode queries (must match the model used at ingestion time).
    embed_model_name = config['embeddings']['model_name']
    embed_model = HuggingFaceEmbedding(model_name=embed_model_name)

    # Groq-hosted LLM used for answering, routing, and summarization.
    model_name = config['model']['model_name']
    Settings.llm = Groq(model=model_name, api_key=os.getenv('GROQ_API_KEY'))
    Settings.embed_model = embed_model

    pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))
    index_name = config['pinecone']['index_name']
    summary_index_name = config['pinecone']['summary_index_name']
    index = pc.Index(index_name)
    summary_index = pc.Index(summary_index_name)

    return index, summary_index

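# Optional sanity check (illustrative, not part of the original flow): the Pinecone
# Index objects returned above expose describe_index_stats(), which confirms that both
# indexes are reachable and contain vectors before any querying, e.g.:
#
#   pinecone_index, summary_pinecone_index = index_connection("configs/config.yaml")
#   print(pinecone_index.describe_index_stats())
#   print(summary_pinecone_index.describe_index_stats())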

def initialize_retriever(pinecone_index, summary_index):
    """
    Wraps the two Pinecone indexes in LlamaIndex vector stores and builds a
    VectorStoreIndex on top of each.

    Args:
        pinecone_index: The Pinecone index holding the document chunks.
        summary_index: The Pinecone index holding the document summaries.

    Returns:
        tuple: The VectorStoreIndex objects for the chunk store and the summary store.
    """
    vector_store = PineconeVectorStore(pinecone_index=pinecone_index, text_key="_node_content")
    summary_vector_store = PineconeVectorStore(pinecone_index=summary_index, text_key="_node_content")

    index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
    summary_index = VectorStoreIndex.from_vector_store(vector_store=summary_vector_store)

    return index, summary_index

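# Optional sanity check (illustrative usage, not in the original script): each returned
# VectorStoreIndex can also be queried through a plain retriever, which is a quick way to
# verify that text_key="_node_content" matches how the vectors were ingested, e.g.:
#
#   vector_index, summary_index = initialize_retriever(pinecone_index, summary_pinecone_index)
#   nodes = vector_index.as_retriever(similarity_top_k=3).retrieve("COVID-19 tests")
#   for node in nodes:
#       print(node.score, node.get_content()[:200])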

def index_retrieval(index, summary_index, query_text):
    """
    Routes the query across the vector and summary indexes and returns the response.

    Args:
        index: The VectorStoreIndex over the document chunks.
        summary_index: The VectorStoreIndex over the document summaries.
        query_text (str): The text query to search for.

    Returns:
        Response: The routed query result.
    """
    # Load the prompt templates from model/prompt_template.yaml (relative to the project root).
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(script_dir)
    prompt_template_path = os.path.join(base_dir, 'model', 'prompt_template.yaml')
    prompt_template = load_prompt_template(prompt_template_path)
    QA_PROMPT = PromptTemplate(prompt_template['QA_PROMPT_TMPL'])

    vector_query_engine = index.as_query_engine(text_qa_template=QA_PROMPT)
    summary_query_engine = summary_index.as_query_engine(text_qa_template=QA_PROMPT)

    # The tool descriptions guide the router's selector, so they should distinguish the two engines.
    vector_tool = QueryEngineTool.from_defaults(
        query_engine=vector_query_engine,
        description="Useful for answering specific questions about the document content",
    )

    summary_tool = QueryEngineTool.from_defaults(
        query_engine=summary_query_engine,
        description="Useful for answering questions that require a summary of the documents",
    )

    # Combine answers from the selected engines with a tree summarizer.
    tree_summarize = TreeSummarize(
        summary_template=PromptTemplate(prompt_template['TREE_SUMMARIZE_PROMPT_TMPL'])
    )

    query_engine = RouterQueryEngine(
        selector=LLMMultiSelector.from_defaults(),
        query_engine_tools=[
            vector_tool,
            summary_tool,
        ],
        summarizer=tree_summarize,
    )

    response = query_engine.query(query_text)
    return response

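# The object returned by index_retrieval() behaves like a standard LlamaIndex Response:
# str(response) yields the answer text, and response.source_nodes lists the retrieved
# chunks, which helps spot-check what the router actually consulted (illustrative, not
# part of the original script):
#
#   for node in response.source_nodes:
#       print(node.score, node.get_content()[:200])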

if __name__ == "__main__":
    # Resolve configs/config.yaml relative to the project root (one level above this script).
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.dirname(script_dir)
    config_path = os.path.join(base_dir, 'configs', 'config.yaml')

    # Connect to the Pinecone indexes and configure the LLM and embedding model.
    pinecone_index, summary_pinecone_index = index_connection(config_path=config_path)

    # Build a VectorStoreIndex over each Pinecone index.
    vector_index, summary_index = initialize_retriever(pinecone_index, summary_pinecone_index)

    query_text = (
        "How much can the Minister of Health pay out of the Consolidated Revenue Fund "
        "in relation to coronavirus disease 2019 (COVID-19) tests?"
    )
    response = index_retrieval(vector_index, summary_index, query_text)
    print(response)