# pharmacist1/app.py
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQAWithSourcesChain
from dotenv import load_dotenv
from streamlit_chat import message
import os
import time
import textwrap
load_dotenv()
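# The OpenAI API key is read from the "my_secret" environment variable (e.g. a Hugging Face Space secret) and exposed to LangChain.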
os.environ["OPENAI_API_KEY"] = os.environ['my_secret']
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
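
# Build the retrieval QA chain once and cache it in session_state so the PDF
# is not re-loaded and re-embedded on every Streamlit rerun.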
if 'chain' not in st.session_state:
    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)
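    # Load the interview PDF and split it into 800-character chunks with no overlap.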
    loader = PyPDFLoader("/home/user/app/docs.pdf")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
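    # Embed the chunks with OpenAI embeddings, index them in an in-memory Chroma store,
    # and retrieve the 3 most similar chunks per question.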
    embeddings = OpenAIEmbeddings()
    vector_store = Chroma.from_documents(texts, embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    from langchain.prompts.chat import (
        ChatPromptTemplate,
        SystemMessagePromptTemplate,
        HumanMessagePromptTemplate,
    )
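    # System prompt: the model role-plays a pharmacist, answers from the retrieved
    # {summaries}, and must reply in Korean using Markdown.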
system_template="""You act like a successful pharmacist. Talk to students about the career path of a pharmacist.
Never reveal that it is an AI. Use the following pieces of context to answer the users question shortly.
Given the following summaries of a long document and a question, create a final answer with references.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}
You MUST answer in Korean and in Markdown format"""
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    chain_type_kwargs = {"prompt": prompt}
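    # "stuff" chain: the retrieved chunks are stuffed into the {summaries} slot of the prompt;
    # reduce_k_below_max_tokens drops retrieved documents that would exceed the model's token limit.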
    st.session_state['chain'] = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs=chain_type_kwargs,
        reduce_k_below_max_tokens=True,
        verbose=True,
    )
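
# Run the chain on the user's question and append numbered source references
# (source file and page number) to the answer text.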
def generate_response(user_input):
    result = st.session_state['chain'](user_input)
    bot_message = result['answer']
    for i, doc in enumerate(result['source_documents']):
        bot_message += '[' + str(i + 1) + '] ' + doc.metadata['source'] + '(' + str(doc.metadata['page']) + ') '
    return bot_message
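
# Hard-wrap messages at roughly 40 characters per line so the chat bubbles stay readable.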
def wrap_text(text, max_length=40):
    return '\n'.join(textwrap.wrap(text, max_length))
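
# Page header (Korean): "[μƒμƒμ§„λ‘œν†΅] The Pharmacist's Path: try a career consultation based on real pharmacist interviews."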
st.header("[μƒμƒμ§„λ‘œν†΅] μ•½μ‚¬μ˜ κΈΈ \n μ‹€μ œ μ•½μ‚¬μ˜ 인터뷰 λ‚΄μš©μ„ 기반으둜 μ§„λ‘œ 상담을 ν•΄λ³΄μ„Έμš”")
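
# Replay every exchange except the most recent one as static chat bubbles ("약사" = "pharmacist").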
if st.session_state['chat_history']:
    for idx, chat in enumerate(st.session_state['chat_history'][:-1]):
        message(chat['User'], is_user=True, key=str(idx) + '_user')
        message(wrap_text("약사: " + chat['Bot']), key=str(idx))
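
# Render the most recent answer with a typewriter effect, one character every 50 ms.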
if st.session_state['chat_history']:
    last_chat = st.session_state['chat_history'][-1]
    message(last_chat['User'], is_user=True, key=str(len(st.session_state['chat_history']) - 1) + '_user')
    new_placeholder = st.empty()
    sender_name = "약사: "
    for j in range(len(last_chat['Bot'])):
        new_placeholder.text(wrap_text(sender_name + last_chat['Bot'][:j + 1]))
        time.sleep(0.05)
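
# Question input form; clear_on_submit empties the text box after each submission.
# The spinner text means "Generating a response...".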
with st.form('form', clear_on_submit=True):
    user_input = st.text_input('You: ', '', key='input')
    submitted = st.form_submit_button('Send')
    if submitted and user_input:
        with st.spinner('응닡을 μƒμ„±μ€‘μž…λ‹ˆλ‹€...'):
            output = generate_response(user_input)
        st.session_state.chat_history.append({"User": user_input, "Bot": output})