# NOTE(review): the lines "Spaces:" / "Build error" here were web-page residue
# from the hosting UI, not source code; kept only as this comment so the file parses.
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from dotenv import load_dotenv
from streamlit_chat import message
import os
import time
import textwrap
import threading

# Load .env and expose the OpenAI key under the environment variable name
# that the langchain/openai clients look up.
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.environ['my_secret']

# Per-session transcript: a list of {"User": ..., "Bot": ...} dicts.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

# Build the retrieval-QA chain once per session and cache it in session_state,
# so Streamlit script reruns do not reload and re-embed the PDF every time.
if 'chain' not in st.session_state:
    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)

    # Index the source document: load the PDF, split it into ~800-character
    # chunks, embed the chunks, and store them in an in-memory Chroma index.
    loader = PyPDFLoader("/home/user/app/docs.pdf")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    vector_store = Chroma.from_documents(texts, embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})

    # System prompt: pharmacist persona plus grounding rules; {summaries} is
    # filled with the retrieved document chunks by the chain.
    # (Template body intentionally left-aligned: it is string content.)
    system_template = """You act like a successful pharmacist. Talk to students about the career path of a pharmacist.
Never reveal that it is an AI. Use the following pieces of context to answer the users question shortly.
Given the following summaries of a long document and a question, create a final answer with references.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}
You MUST answer in Korean and in Markdown format"""
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    chain_type_kwargs = {"prompt": prompt}
    st.session_state['chain'] = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs=chain_type_kwargs,
        reduce_k_below_max_tokens=True,
        verbose=True,
    )
def generate_response(user_input):
    """Run the cached QA chain on *user_input* and return the answer text
    with numbered source citations appended (source name and page)."""
    chain_output = st.session_state['chain'](user_input)
    citations = [
        '[' + str(idx + 1) + '] ' + doc.metadata['source'] + '(' + str(doc.metadata['page']) + ') '
        for idx, doc in enumerate(chain_output['source_documents'])
    ]
    return chain_output['answer'] + ''.join(citations)
def wrap_text(text, max_length=40):
    """Hard-wrap *text* into lines of at most *max_length* characters.

    Returns the wrapped lines joined with newlines (empty string for
    empty/whitespace-only input, mirroring textwrap.wrap).
    """
    wrapped_lines = textwrap.wrap(text, width=max_length)
    return "\n".join(wrapped_lines)
# UI: transcript replay, typewriter effect for the newest reply, input form.
# NOTE(review): the Korean UI literals below are kept byte-for-byte from the
# original file; they appear to be mojibake of the intended Korean text —
# TODO confirm and re-encode against the original source.
st.header("[μμμ§λ‘ν΅] μ½μ¬μ κΈΈ \n μ€μ μ½μ¬μ μΈν°λ·° λ΄μ©μ κΈ°λ°μΌλ‘ μ§λ‘ μλ΄μ ν΄λ³΄μΈμ")

if st.session_state['chat_history']:
    # Replay every completed exchange except the newest one.
    for idx, chat in enumerate(st.session_state['chat_history'][:-1]):
        message(chat['User'], is_user=True, key=str(idx) + '_user')
        message(wrap_text("μ½μ¬: " + chat['Bot']), key=str(idx))
    # Render the newest reply with a typewriter effect, one character per
    # 50 ms, into a single placeholder that is rewritten each step.
    last_chat = st.session_state['chat_history'][-1]
    message(last_chat['User'], is_user=True, key=str(len(st.session_state['chat_history'])-1) + '_user')
    new_placeholder = st.empty()
    sender_name = "μ½μ¬: "
    for j in range(len(last_chat['Bot'])):
        new_placeholder.text(wrap_text(sender_name + last_chat['Bot'][:j+1]))
        time.sleep(0.05)

# Input form; clear_on_submit empties the text box after each send.
with st.form('form', clear_on_submit=True):
    user_input = st.text_input('You: ', '', key='input')
    submitted = st.form_submit_button('Send')
    if submitted and user_input:
        with st.spinner('μλ΅μ μμ±μ€μ λλ€...'):
            output = generate_response(user_input)
            st.session_state.chat_history.append({"User": user_input, "Bot": output})