# retrieval_qa_pipeline.py
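"""
Retrieval-augmented QA over a UPF power-intent dataset.

Loads a causal LM from the Hugging Face Hub, a FAISS index built from the
dataset's power-intent descriptions, and wires them into a LangChain
RetrievalQA chain.

Dependencies (note: this script uses the classic pre-0.1 LangChain import
paths): transformers, datasets, langchain, sentence-transformers (backs
HuggingFaceEmbeddings), and faiss-cpu (backs the FAISS vector store).
"""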

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from datasets import load_dataset

def load_model_and_tokenizer(model_name: str):
    """
    Load the pre-trained model and tokenizer from the Hugging Face Hub.
    
    Args:
        model_name (str): The Hugging Face repository name of the model.
        
    Returns:
        model: The loaded model.
        tokenizer: The loaded tokenizer.
    """
    print(f"Loading model and tokenizer from {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
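
# For larger checkpoints, a common variant (an assumption, not part of the
# original pipeline) is to load in half precision with automatic device
# placement; device_map="auto" additionally requires the accelerate package:
#
#     import torch
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name, torch_dtype=torch.float16, device_map="auto"
#     )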

def load_dataset_from_hf(dataset_name: str):
    """
    Load the dataset from the Hugging Face Hub.
    
    Args:
        dataset_name (str): The Hugging Face repository name of the dataset.
        
    Returns:
        texts (list): The text descriptions from the dataset.
        metadatas (list): Metadata for each text (e.g., upf_code).
    """
    print(f"Loading dataset from {dataset_name}...")
    dataset = load_dataset(dataset_name)
    texts = dataset["train"]["power_intent_description"]
    metadatas = [{"upf_code": code} for code in dataset["train"]["upf_code"]]
    return texts, metadatas
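
# This script assumes the FAISS index already exists on disk, but the
# texts/metadatas returned above are exactly what such an index is built
# from. A minimal sketch of that step (build_faiss_index is a hypothetical
# helper, not part of the original pipeline):
def build_faiss_index(texts, metadatas, faiss_index_path: str):
    """Embed the dataset texts, build a FAISS index, and save it locally."""
    embeddings = HuggingFaceEmbeddings()  # must match load_faiss_index() below
    vectorstore = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
    vectorstore.save_local(faiss_index_path)
    return vectorstore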

def load_faiss_index(faiss_index_path: str):
    """
    Load the FAISS index and associated embeddings.
    
    Args:
        faiss_index_path (str): Path to the saved FAISS index.
        
    Returns:
        vectorstore (FAISS): The FAISS vector store.
    """
    print(f"Loading FAISS index from {faiss_index_path}...")
    embeddings = HuggingFaceEmbeddings()  # Must be the same embedding model used when the index was built
    vectorstore = FAISS.load_local(faiss_index_path, embeddings)
    return vectorstore

def build_retrieval_qa_pipeline(model, tokenizer, vectorstore):
    """
    Build the retrieval-based QA pipeline.
    
    Args:
        model: The pre-trained model.
        tokenizer: The tokenizer associated with the model.
        vectorstore (FAISS): The FAISS vector store for retrieval.
        
    Returns:
        qa_chain (RetrievalQA): The retrieval-based QA pipeline.
    """
    print("Building the retrieval-based QA pipeline...")
    hf_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=2048,
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
        return_full_text=False  # return only the completion, not the echoed prompt
    )
    
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    retriever = vectorstore.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    
    return qa_chain
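
# The chain above uses LangChain defaults ("stuff" chain type, 4 retrieved
# documents). A sketch of a more explicit variant that also surfaces the
# retrieved upf_code metadata alongside the answer (an assumption, not the
# original behavior):
#
#     retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
#     qa_chain = RetrievalQA.from_chain_type(
#         llm=llm,
#         chain_type="stuff",
#         retriever=retriever,
#         return_source_documents=True,
#     )
#     result = qa_chain({"query": "..."})
#     answer, sources = result["result"], result["source_documents"]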

def main():
    # Replace these names with your model and dataset repo names
    model_name = "username/my_fine_tuned_model"
    dataset_name = "PranavKeshav/upf_code"
    faiss_index_path = "faiss_index"
    
    print("Starting pipeline setup...")
    
    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_name)
    
    # Load dataset (texts/metadatas are what the FAISS index is built from;
    # see the build_faiss_index sketch above if the index is missing)
    texts, metadatas = load_dataset_from_hf(dataset_name)
    
    # Load FAISS index
    vectorstore = load_faiss_index(faiss_index_path)
    
    # Build QA pipeline
    qa_chain = build_retrieval_qa_pipeline(model, tokenizer, vectorstore)
    
    # Test the pipeline
    print("Pipeline is ready! You can now ask questions.")
    while True:
        query = input("Enter your query (or type 'exit' to quit): ").strip()
        if query.lower() == "exit":
            print("Exiting...")
            break
        response = qa_chain.run(query)
        print(f"Response: {response}")

if __name__ == "__main__":
    main()
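
# Example invocation (the model name in main() is a placeholder and must point
# at a real repository; the FAISS index must exist at faiss_index/):
#
#   $ python retrieval_qa_pipeline.py
#   Enter your query (or type 'exit' to quit): exit
#   Exiting...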