JUNGU commited on
Commit
b5a32d7
·
1 Parent(s): 6c831d2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+ import streamlit as st
4
+ from langchain import OpenAI
5
+ from langchain.callbacks import get_openai_callback
6
+ from langchain.chains import ConversationChain
7
+ from langchain.chains.conversation.memory import ConversationSummaryMemory
8
+ import streamlit.components.v1 as components
9
+
10
+ @dataclass
11
+ class Message:
12
+ """Class for keeping track of a chat message."""
13
+ origin: Literal["human", "ai"]
14
+ message: str
15
+
16
+ def load_css():
17
+ with open("static/styles.css", "r") as f:
18
+ css = f"<style>{f.read()}</style>"
19
+ st.markdown(css, unsafe_allow_html=True)
20
+
21
+ def initialize_session_state():
22
+ if "history" not in st.session_state:
23
+ st.session_state.history = []
24
+ if "token_count" not in st.session_state:
25
+ st.session_state.token_count = 0
26
+ if 'chain' not in st.session_state:
27
+ llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)
28
+ loader = PyPDFLoader("/home/user/app/docs.pdf")
29
+ documents = loader.load()
30
+
31
+ text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
32
+ texts = text_splitter.split_documents(documents)
33
+
34
+ embeddings = OpenAIEmbeddings()
35
+ vector_store = Chroma.from_documents(texts, embeddings)
36
+ retriever = vector_store.as_retriever(search_kwargs={"k": 3})
37
+
38
+ from langchain.prompts.chat import (
39
+ ChatPromptTemplate,
40
+ SystemMessagePromptTemplate,
41
+ HumanMessagePromptTemplate,
42
+ )
43
+
44
+ system_template="""You act like a successful pharmacist. Talk to students about the career path of a pharmacist.
45
+ Never reveal that it is an AI. Use the following pieces of context to answer the users question shortly.
46
+ Given the following summaries of a long document and a question, create a final answer with references.
47
+ If you don't know the answer, just say that "I don't know", don't try to make up an answer.
48
+ ----------------
49
+ {summaries}
50
+ You MUST answer in Korean and in Markdown format"""
51
+
52
+ messages = [
53
+ SystemMessagePromptTemplate.from_template(system_template),
54
+ HumanMessagePromptTemplate.from_template("{question}")
55
+ ]
56
+
57
+ prompt = ChatPromptTemplate.from_messages(messages)
58
+
59
+ chain_type_kwargs = {"prompt": prompt}
60
+
61
+ st.session_state['chain'] = RetrievalQAWithSourcesChain.from_chain_type(
62
+ llm=llm,
63
+ chain_type="stuff",
64
+ retriever=retriever,
65
+ return_source_documents=True,
66
+ chain_type_kwargs=chain_type_kwargs,
67
+ reduce_k_below_max_tokens=True,
68
+ verbose=True,
69
+ )
70
+
71
+ def generate_response(user_input):
72
+ result = st.session_state['chain'](user_input)
73
+
74
+ bot_message = result['answer']
75
+
76
+ for i, doc in enumerate(result['source_documents']):
77
+ bot_message += '[' + str(i+1) + '] ' + doc.metadata['source'] + '(' + str(doc.metadata['page']) + ') '
78
+
79
+ return bot_message
80
+
81
+ def on_click_callback():
82
+ with get_openai_callback() as cb:
83
+ human_prompt = st.session_state.human_prompt
84
+ llm_response = generate_response(human_prompt)
85
+ st.session_state.history.append(
86
+ Message("human", human_prompt)
87
+ )
88
+ st.session_state.history.append(
89
+ Message("ai", llm_response)
90
+ )
91
+ st.session_state.token_count += cb.total_tokens
92
+
93
+ load_css()
94
+ initialize_session_state()
95
+
96
+ st.title("Hello Custom CSS Chatbot 🤖")
97
+
98
+ chat_placeholder = st.container()
99
+ prompt_placeholder = st.form("chat-form")
100
+ credit_card_placeholder = st.empty()
101
+
102
+ with chat_placeholder:
103
+ for chat in st.session_state.history:
104
+ div = f"""
105
+ <div class="chat-row
106
+ {'' if chat.origin == 'ai' else 'row-reverse'}">
107
+ <img class="chat-icon" src="app/static/{
108
+ 'ai_icon.png' if chat.origin == 'ai'
109
+ else 'user_icon.png'}"
110
+ width=32 height=32>
111
+ <div class="chat-bubble
112
+ {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
113
+ &#8203;{chat.message}
114
+ </div>
115
+ </div>
116
+ """
117
+ st.markdown(div, unsafe_allow_html=True)
118
+
119
+ for _ in range(3):
120
+ st.markdown("")
121
+
122
+ with prompt_placeholder:
123
+ st.markdown("**Chat**")
124
+ cols = st.columns((6, 1))
125
+ cols[0].text_input(
126
+ "Chat",
127
+ value="Hello bot",
128
+ label_visibility="collapsed",
129
+ key="human_prompt",
130
+ )
131
+ cols[1].form_submit_button(
132
+ "Submit",
133
+ type="primary",
134
+ on_click=on_click_callback,
135
+ )
136
+
137
+ credit_card_placeholder.caption(f"""
138
+ Used {st.session_state.token_count} tokens \n
139
+ Debug Langchain conversation:
140
+ {st.session_state.conversation.memory.buffer}
141
+ """)
142
+
143
+ components.html("""
144
+ <script>
145
+ const streamlitDoc = window.parent.document;
146
+
147
+ const buttons = Array.from(
148
+ streamlitDoc.querySelectorAll('.stButton > button')
149
+ );
150
+ const submitButton = buttons.find(
151
+ el => el.innerText === 'Submit'
152
+ );
153
+
154
+ streamlitDoc.addEventListener('keydown', function(e) {
155
+ switch (e.key) {
156
+ case 'Enter':
157
+ submitButton.click();
158
+ break;
159
+ }
160
+ });
161
+ </script>
162
+ """,
163
+ height=0,
164
+ width=0,
165
+ )