from haystack.document_stores import FAISSDocumentStore | |
from haystack.utils import convert_files_to_docs, fetch_archive_from_http, clean_wiki_text | |
from haystack.nodes import DensePassageRetriever | |
from haystack.utils import print_documents, print_answers | |
from haystack.pipelines import DocumentSearchPipeline | |
from haystack.nodes import Seq2SeqGenerator | |
from haystack.pipelines import GenerativeQAPipeline | |
# %% Save/Load FAISS and embeddings
# Try out this script. Make sure you have deleted any old saves of the document store,
# including the file called faiss_document_store.db that is saved and loaded by default.
# # Convert files to dicts
# dicts = convert_files_to_dicts(dir_path=doc_dir, clean_func=clean_wiki_text, split_paragraphs=True)[:10]
# document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", vector_dim=128)
# # document_store = FAISSDocumentStore(sql_url="sqlite:///faiss_document_store.db")
# retriever = EmbeddingRetriever(document_store=document_store,
#                                embedding_model="yjernite/retribert-base-uncased",
#                                model_format="retribert",
#                                use_gpu=False)
# # Now, let's write the dicts containing documents to our DB.
# document_store.write_documents(dicts)
# document_store.update_embeddings(retriever)
# document_store.save("my_faiss_index.faiss")
# new_document_store = FAISSDocumentStore.load("my_faiss_index.faiss")
# # new_document_store = FAISSDocumentStore.load(faiss_file_path="testfile_path", sql_url="sqlite:///faiss_document_store.db")
# %% ------------------------------------------------------------------------------------------------------------
def prepare():
    """Assemble the long-form QA pipeline.

    Loads a previously saved FAISS index from disk, wires a DPR retriever
    to it, and combines it with a BART seq2seq generator.

    Returns:
        A GenerativeQAPipeline ready to answer queries.
    """
    # %% Document Store: reload the persisted FAISS index
    # (the matching faiss_document_store.db is picked up by default).
    store = FAISSDocumentStore.load("faiss_index.faiss")

    # %% Retriever: Dense Passage Retriever fine-tuned for LFQA over Wikipedia.
    dpr_retriever = DensePassageRetriever(
        document_store=store,
        query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
        passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
        use_gpu=False,
    )

    # %% Generator: seq2seq answer generation with the vblagoje/bart_lfqa
    # model (https://huggingface.co/vblagoje/bart_lfqa).
    lfqa_generator = Seq2SeqGenerator(
        model_name_or_path="vblagoje/bart_lfqa",
        use_gpu=False,
    )

    # %% Pipeline: generator consumes what the retriever fetches.
    return GenerativeQAPipeline(lfqa_generator, dpr_retriever)
def answer(pipe, question, k_retriever=3):
    """Run a question through the QA pipeline and return its result.

    Args:
        pipe: a pipeline exposing ``run(query=..., params=...)``, e.g. the
            GenerativeQAPipeline built by ``prepare``.
        question: the natural-language question to answer.
        k_retriever: how many documents the retriever should fetch
            (forwarded as the Retriever node's ``top_k``).

    Returns:
        The pipeline's result dict (query, answers, documents, ...).
    """
    # Pass the query by keyword: Haystack's Pipeline.run is keyword-driven,
    # and the documented usage (see the examples elsewhere in this file)
    # is `pipe.run(query=..., params=...)`. Positional passing is fragile
    # against signature changes.
    return pipe.run(query=question, params={"Retriever": {"top_k": k_retriever}})
if __name__ == '__main__':
    # Demo run: build the pipeline once, ask a single question,
    # and print a compact view of the generated answers.
    demo_question = 'Tell me something about Arya Stark?'
    qa_pipeline = prepare()
    result = answer(qa_pipeline, demo_question)
    print_answers(result, details="minimum")