samim2024 committed on
Commit b4deae8 · verified · 1 Parent(s): 3c59ee7

Update app.py

Files changed (1)
  1. app.py +121 -60
app.py CHANGED
@@ -1,67 +1,128 @@
- import os
- import tempfile
  import streamlit as st
- from streamlit_chat import message
-
- from rag import ChatPDF
-
- st.set_page_config(page_title="ChatPDF")
-
-
- def display_messages():
-     st.subheader("Chat")
-     for i, (msg, is_user) in enumerate(st.session_state["messages"]):
-         message(msg, is_user=is_user, key=str(i))
-     st.session_state["thinking_spinner"] = st.empty()
-
-
- def process_input():
-     if st.session_state["user_input"] and len(st.session_state["user_input"].strip()) > 0:
-         user_text = st.session_state["user_input"].strip()
-         with st.session_state["thinking_spinner"], st.spinner("Thinking"):
-             agent_text = st.session_state["assistant"].ask(user_text)
-
-         st.session_state["messages"].append((user_text, True))
-         st.session_state["messages"].append((agent_text, False))
-
-
- def read_and_save_file():
-     st.session_state["assistant"].clear()
-     st.session_state["messages"] = []
-     st.session_state["user_input"] = ""
-
-     for file in st.session_state["file_uploader"]:
-         with tempfile.NamedTemporaryFile(delete=False) as tf:
-             tf.write(file.getbuffer())
-             file_path = tf.name
-
-         with st.session_state["ingestion_spinner"], st.spinner(f"Ingesting {file.name}"):
-             st.session_state["assistant"].ingest(file_path)
-         os.remove(file_path)
-
-
- def page():
-     if len(st.session_state) == 0:
-         st.session_state["messages"] = []
-         st.session_state["assistant"] = ChatPDF()
-
-     st.header("ChatPDF")
-
-     st.subheader("Upload a document")
-     st.file_uploader(
-         "Upload document",
-         type=["pdf"],
-         key="file_uploader",
-         on_change=read_and_save_file,
-         label_visibility="collapsed",
-         accept_multiple_files=True,
  )
-
-     st.session_state["ingestion_spinner"] = st.empty()
-
-     display_messages()
-     st.text_input("Message", key="user_input", on_change=process_input)
-
-
- if __name__ == "__main__":
-     page()
+ from langchain.chains import RetrievalQA
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.callbacks.manager import CallbackManager
+ from langchain_community.llms import Ollama
+ from langchain_community.embeddings.ollama import OllamaEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.prompts import PromptTemplate
+ from langchain.memory import ConversationBufferMemory
  import streamlit as st
+ import os
+ import time
+
+ # Working directories: 'files' for uploaded PDFs, 'jj' for the Chroma store
+ if not os.path.exists('files'):
+     os.mkdir('files')
+
+ if not os.path.exists('jj'):
+     os.mkdir('jj')
+
+ if 'template' not in st.session_state:
+     st.session_state.template = """You are a knowledgeable chatbot, here to help with the user's questions. Your tone should be professional and informative.
+
+ Context: {context}
+ History: {history}
+
+ User: {question}
+ Chatbot:"""
+ if 'prompt' not in st.session_state:
+     st.session_state.prompt = PromptTemplate(
+         input_variables=["history", "context", "question"],
+         template=st.session_state.template,
  )
+ if 'memory' not in st.session_state:
+     st.session_state.memory = ConversationBufferMemory(
+         memory_key="history",
+         return_messages=True,
+         input_key="question")
+ # Persistent Chroma store, embedded with the local Ollama 'mistral' model
+ if 'vectorstore' not in st.session_state:
+     st.session_state.vectorstore = Chroma(
+         persist_directory='jj',
+         embedding_function=OllamaEmbeddings(base_url='http://localhost:11434',
+                                             model="mistral"),
+     )
+ if 'llm' not in st.session_state:
+     st.session_state.llm = Ollama(
+         base_url="http://localhost:11434",
+         model="mistral",
+         verbose=True,
+         callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
+     )
+
+ # Initialize session state
+ if 'chat_history' not in st.session_state:
+     st.session_state.chat_history = []
+
+ st.title("PDF Chatbot")
+
+ # Upload a PDF file
+ uploaded_file = st.file_uploader("Upload your PDF", type='pdf')
+
+ # Replay the conversation so far
+ for message in st.session_state.chat_history:
+     with st.chat_message(message["role"]):
+         st.markdown(message["message"])
+
+ if uploaded_file is not None:
+     pdf_path = "files/" + uploaded_file.name + ".pdf"
+     if not os.path.isfile(pdf_path):
+         with st.status("Analyzing your document..."):
+             bytes_data = uploaded_file.read()
+             with open(pdf_path, "wb") as f:
+                 f.write(bytes_data)
+             loader = PyPDFLoader(pdf_path)
+             data = loader.load()
+
+             # Split the document into overlapping chunks
+             text_splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=1500,
+                 chunk_overlap=200,
+                 length_function=len
+             )
+             all_splits = text_splitter.split_documents(data)
+
+             # Create and persist the vector store (persist_directory must be
+             # set here, otherwise the persist() call below raises)
+             st.session_state.vectorstore = Chroma.from_documents(
+                 documents=all_splits,
+                 embedding=OllamaEmbeddings(base_url='http://localhost:11434',
+                                            model="mistral"),
+                 persist_directory='jj',
+             )
+             st.session_state.vectorstore.persist()
+
+     st.session_state.retriever = st.session_state.vectorstore.as_retriever()
+     # Initialize the QA chain
+     if 'qa_chain' not in st.session_state:
+         st.session_state.qa_chain = RetrievalQA.from_chain_type(
+             llm=st.session_state.llm,
+             chain_type='stuff',
+             retriever=st.session_state.retriever,
+             verbose=True,
+             chain_type_kwargs={
+                 "verbose": True,
+                 "prompt": st.session_state.prompt,
+                 "memory": st.session_state.memory,
+             }
+         )
+
+     # Chat input
+     if user_input := st.chat_input("You:", key="user_input"):
+         user_message = {"role": "user", "message": user_input}
+         st.session_state.chat_history.append(user_message)
+         with st.chat_message("user"):
+             st.markdown(user_input)
+         with st.chat_message("assistant"):
+             with st.spinner("Assistant is typing..."):
+                 response = st.session_state.qa_chain(user_input)
+             message_placeholder = st.empty()
+             full_response = ""
+             for chunk in response['result'].split():
+                 full_response += chunk + " "
+                 time.sleep(0.05)
+                 # Add a blinking cursor to simulate typing
+                 message_placeholder.markdown(full_response + "▌")
+             message_placeholder.markdown(full_response)
+
+         chatbot_message = {"role": "assistant", "message": response['result']}
+         st.session_state.chat_history.append(chatbot_message)
+
+ else:
+     st.write("Please upload a PDF file.")
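
Note: the updated app assumes an Ollama server listening on http://localhost:11434 with the mistral model already pulled. A minimal pre-flight sketch that could be run before `streamlit run app.py` — the script name, the `requests` dependency, and the model-name check are illustrative assumptions, not part of this commit:

# check_ollama.py — hypothetical pre-flight helper, not part of this commit.
# Confirms the Ollama endpoint and model that app.py assumes are available.
import sys

import requests  # assumed installed: pip install requests

OLLAMA_URL = "http://localhost:11434"  # same base_url app.py uses
MODEL = "mistral"                      # same model app.py uses

def main() -> None:
    try:
        # GET /api/tags lists the locally pulled models
        tags = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5).json()
    except requests.RequestException as exc:
        sys.exit(f"Ollama is not reachable at {OLLAMA_URL}: {exc}")
    names = [m["name"] for m in tags.get("models", [])]
    if not any(name.split(":")[0] == MODEL for name in names):
        sys.exit(f"Model '{MODEL}' not found; run `ollama pull {MODEL}` first.")
    print(f"Ollama is up and '{MODEL}' is available; run `streamlit run app.py`.")

if __name__ == "__main__":
    main()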