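# HelloWorldRAG / app1.py
# Minimal retrieval-augmented generation (RAG) demo: index a Wikipedia article
# in a Chroma vector store and answer a question with a Hugging Face LLM.
# Last commit: 3153da2 (verified) by ashutoshzade - "Rename app.py to app1.py"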
import json
import urllib.parse
import urllib.request

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import Chroma
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Load the generator model and its tokenizer. Earlier experiments used the
# Gemma checkpoints "google/gemma-2-2b" and "google/gemma-1.1-2b-it".
model_name = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Create a text generation pipeline.
# - do_sample=True: `temperature` only takes effect in sampling mode; without
#   it generation is greedy and transformers warns the temperature is ignored.
# - return_full_text=False: return only the newly generated tokens, so the
#   chain's answer does not echo the stuffed prompt back.
text_generation_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    return_full_text=False,
)
# Wrap the pipeline so LangChain chains can call it as an LLM.
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# Load and process documents.
# TextLoader only opens local file paths, so it cannot load a web URL. Fetch
# the article's plain text from the MediaWiki API with the standard library.
WIKIPEDIA_API_URL = "https://en.wikipedia.org/w/api.php"


def fetch_wikipedia_text(title, *, timeout=30.0):
    """Return the plain-text body of an English Wikipedia article.

    Uses the TextExtracts API (``prop=extracts&explaintext=1``), so the
    returned text contains no HTML markup.

    Args:
        title: Article title, e.g. ``"Artificial_neuron"`` or ``"Cheetah"``.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The article text as a single string.

    Raises:
        ValueError: If the article does not exist or has no text extract.
        urllib.error.URLError: On network failures.
    """
    query_string = urllib.parse.urlencode(
        {
            "action": "query",
            "prop": "extracts",
            "explaintext": 1,
            "redirects": 1,
            "titles": title,
            "format": "json",
            "formatversion": 2,  # "pages" is returned as a list, not a dict
        }
    )
    request = urllib.request.Request(
        f"{WIKIPEDIA_API_URL}?{query_string}",
        # Wikimedia asks API clients to identify themselves via User-Agent.
        headers={"User-Agent": "HelloWorldRAG/1.0 (LangChain RAG demo)"},
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        payload = json.load(response)
    page = payload["query"]["pages"][0]
    extract = page.get("extract")
    if page.get("missing") or not extract:
        raise ValueError(f"Wikipedia article {title!r} not found or empty")
    return extract


article_title = "Artificial_neuron"  # e.g. "Cheetah" to index that article
article_text = fetch_wikipedia_text(article_title)
# The plain-text extract appears to separate paragraphs with single newlines,
# so split on "\n" to keep chunks near chunk_size instead of whole sections.
text_splitter = CharacterTextSplitter(
    separator="\n", chunk_size=1000, chunk_overlap=0
)
texts = text_splitter.create_documents(
    [article_text],
    metadatas=[{"source": f"https://en.wikipedia.org/wiki/{article_title}"}],
)
# Embed every chunk and index the vectors in a Chroma store. No model_name is
# passed, so HuggingFaceEmbeddings uses its library-default model, and no
# persist_directory is passed to Chroma.
embeddings = HuggingFaceEmbeddings()
db = Chroma.from_documents(documents=texts, embedding=embeddings)
# Expose the store through LangChain's retriever interface for the QA chain.
retriever = db.as_retriever()
# Prompt for the QA chain: {context} receives the retrieved text and
# {question} the user's query. The model is told to admit when it does not
# know the answer rather than invent one.
template = (
    "Use the following pieces of context to answer the question at the end.\n"
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer.\n"
    "{context}\n"
    "Question: {question}\n"
    "Answer:"
)
prompt = PromptTemplate(input_variables=["context", "question"], template=template)
# Assemble the question-answering chain. The "stuff" strategy places the
# retrieved chunks into the custom prompt in a single LLM call, and the
# retrieved documents are returned alongside the answer.
# NOTE(review): newer LangChain releases deprecate RetrievalQA in favour of
# create_retrieval_chain; confirm against the pinned version.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    chain_type_kwargs=dict(prompt=prompt),
    retriever=retriever,
    return_source_documents=True,
)
# Example query. (With the Cheetah article indexed instead, try
# "How fast can a cheetah run?".)
query = "What is an artificial neuron?"
# Calling a Chain directly (Chain.__call__) is deprecated in current
# LangChain; invoke() is the supported entry point and returns the same dict.
result = qa_chain.invoke({"query": query})
print("Question:", query)
print("Answer:", result["result"])