Create pages/RAG with xModels.py
pages/RAG with xModels.py
ADDED
@@ -0,0 +1,122 @@
import os
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
from PyPDF2 import PdfReader

# Load environment variables
load_dotenv()

# Define your supported models
model_links = {
    "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "Gemma-7B": "google/gemma-1.1-7b-it",
    "Gemma-2B": "google/gemma-1.1-2b-it",
    "Zephyr-7B-β": "HuggingFaceH4/zephyr-7b-beta",
}
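
# Note: the values above are Hugging Face Hub repo IDs served through the
# hosted Inference API; HuggingFaceHub expects a HUGGINGFACEHUB_API_TOKEN
# in the environment (loaded from .env by load_dotenv() above).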

# Function to read PDF files and extract text along with their names
@st.cache_data
def read_pdf_files(directory):
    documents = []
    for filename in os.listdir(directory):
        if filename.endswith(".pdf"):
            with open(os.path.join(directory, filename), "rb") as file:
                reader = PdfReader(file)
                text = ""
                for page in reader.pages:
                    # extract_text() can return None for pages without a
                    # text layer, so guard against it
                    text += page.extract_text() or ""
                documents.append((filename, text))
    return documents
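
# For illustration, the return value looks like (filenames are whatever the
# directory happens to contain): [("report.pdf", "full text of report..."), ...]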

# Initialize ChromaDB with PDF data
@st.cache_resource
def initialize_chromadb_from_pdfs(directory):
    documents = read_pdf_files(directory)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = []
    for doc_name, doc_text in documents:
        chunks = text_splitter.split_text(doc_text)
        split_docs.extend([(doc_name, chunk) for chunk in chunks])
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    chromadb = Chroma.from_texts(
        [chunk for _, chunk in split_docs],
        embeddings,
        metadatas=[{"source": name} for name, _ in split_docs],
    )
    return chromadb, split_docs
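
# Design note: each chunk is stored with a {"source": filename} metadata entry
# so retrieved chunks can be traced back to their PDF. With no persist_directory
# given, this Chroma index lives in memory and is rebuilt whenever the cached
# resource is invalidated.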

# Initialize the ChromaDB retriever
chromadb, split_docs = initialize_chromadb_from_pdfs("docs")
retriever = chromadb.as_retriever(search_type="similarity", search_kwargs={"k": 5})
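# The retriever returns the 5 chunks whose embeddings are most similar to the
# query, e.g. retriever.get_relevant_documents("some question") -> [Document, ...]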

# Create the sidebar with the dropdown for model selection
selected_model = st.sidebar.selectbox("Select Model", model_links.keys())

# Create temperature slider
temp_values = st.sidebar.slider('Select a temperature value', 0.0, 1.0, 0.5)

# Add reset button to clear conversation
st.sidebar.button('Reset Chat', on_click=lambda: st.session_state.clear())

# Pull in the selected model
repo_id = model_links[selected_model]

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input(f"Hi, I'm {selected_model}, ask me a question"):
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Perform RAG
    with st.spinner('Processing query with RAG...'):
        # Pass the selected temperature through to the model; the Inference
        # API rejects a temperature of exactly 0, so clamp it to a small
        # positive value
        llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": max(temp_values, 0.01)})
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )
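        # chain_type="stuff" simply concatenates all retrieved chunks into a
        # single prompt alongside the question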
        response = qa_chain({"query": prompt})

    helpful_answer = response['result']
    source_documents = response['source_documents']

    # Ensure the answer is complete by checking for truncation
    if helpful_answer.endswith("..."):
        # If truncated, try to get more context from the source documents
        for doc in source_documents:
            if doc is not None:
                doc_name = doc.metadata["source"]
                doc_text = next((text for name, text in read_pdf_files("docs") if name == doc_name), "")
                # Extract relevant context
                start_idx = doc_text.find(helpful_answer)
                if start_idx != -1:
                    end_idx = start_idx + len(helpful_answer) + 100  # Add some extra context
                    helpful_answer += "\n\n" + doc_text[start_idx:end_idx]
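
    # Note: this continuation heuristic assumes the model's answer appears
    # verbatim in the source PDF text, which will often not be the case; when
    # find() misses, the answer is left as-is.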

    # Display assistant response
    with st.chat_message("assistant"):
        st.markdown(helpful_answer)

    # Display references in an expander
    if source_documents:
        with st.expander("References", expanded=False):
            for doc in source_documents:
                doc_name = doc.metadata["source"]
                st.markdown(f"- **{doc_name}**")

    # Only add the helpful answer and references to the session state
    references = "\n".join([f"- **{doc.metadata['source']}**" for doc in source_documents if doc])
    st.session_state.messages.append({"role": "assistant", "content": f"{helpful_answer}\n\n**References:**\n{references}"})
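
# Setup notes (assumptions based on the code above): the app expects a "docs/"
# folder of PDFs next to this script and a .env file providing
# HUGGINGFACEHUB_API_TOKEN. It can be run as a Streamlit page, e.g.:
#   streamlit run "pages/RAG with xModels.py"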