|
import logging
import os
from zipfile import ZipFile

import streamlit as st

from lfqa import prepare, answer
|
|
|
|
|
# Unpack the prebuilt document store into the working directory.
# Streamlit re-executes this whole script on every widget interaction, so only
# extract when some archive member is missing instead of rewriting the files on
# every rerun (which could clobber files the pipeline already has open).
# NOTE(review): an updated doc_store.zip whose member names are unchanged will
# not be re-extracted; delete the extracted files to force a refresh.
with ZipFile("doc_store.zip", "r") as zip_ref:
    if not all(os.path.exists(member) for member in zip_ref.namelist()):
        zip_ref.extractall(".")
|
|
|
|
|
# Initial value of the sidebar "Max. number of documents from retriever"
# slider; can be overridden through the environment variable of the same name.
DEFAULT_DOCS_FROM_RETRIEVER = int(os.environ.get("DEFAULT_DOCS_FROM_RETRIEVER", "3"))

# Question pre-filled in the text box on first load; overridable via env var.
DEFAULT_QUESTION_AT_STARTUP = os.environ.get(
    "DEFAULT_QUESTION_AT_STARTUP",
    "Tell me something about Arya Stark?",
)
|
|
|
|
|
def set_state_if_absent(key, value):
    """Seed ``st.session_state[key]`` with *value* on first use only.

    An existing entry (from an earlier rerun of the script) is left untouched.
    """
    if key in st.session_state:
        return
    st.session_state[key] = value
|
|
|
def reset_results(*args):
    """Forget the previous answer and results stored in the session.

    Extra positional arguments are accepted and ignored so the function can be
    passed directly as a widget ``on_change`` callback.
    """
    for state_key in ("answer", "results"):
        st.session_state[state_key] = None
|
|
|
def main(pipe):
    """Render the long-form question answering demo page.

    Streamlit re-executes this on every widget interaction; the last question and
    its results are kept in ``st.session_state`` and cleared by ``reset_results``
    whenever an input changes.

    Args:
        pipe: Pipeline object produced by ``lfqa.prepare()``; it is only passed
            through to ``lfqa.answer``.
    """
    st.set_page_config(page_title="Haystack Demo", page_icon="https://haystack.deepset.ai/img/HaystackIcon.png")

    set_state_if_absent("question", DEFAULT_QUESTION_AT_STARTUP)
    set_state_if_absent("results", None)

    st.write("# Long-Form Question Answering")
    st.markdown(
        "This demo takes its data from a selection of Wikipedia pages "
        "on the topic of the **Game of Thrones** TV series"
    )

    st.sidebar.header("Options")
    top_k_retriever = _sidebar_top_k_slider()
    st.sidebar.markdown(_FOOTER_HTML, unsafe_allow_html=True)

    question = st.text_input(
        value=st.session_state.question,
        max_chars=100,
        on_change=reset_results,
        label="question",
        label_visibility="hidden",
    )
    col1, _ = st.columns(2)
    # Stretch the Run button to the full width of its column.
    col1.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    run_pressed = col1.button("Run")

    # Run either on an explicit click or when the committed text differs from
    # the last question that was answered.
    run_query = run_pressed or (question != st.session_state.question)

    if run_query and question:
        reset_results()
        st.session_state.question = question
        with st.spinner("🧠 Performing neural search on documents... \n "):
            try:
                st.session_state.results = answer(pipe, question, top_k_retriever)
            except Exception as exc:
                # Top-level UI boundary: keep the traceback in the server log and
                # show the user a friendly message instead.
                logging.getLogger(__name__).exception("LFQA request failed")
                message = str(exc)
                if "The server is busy processing requests" in message or "503" in message:
                    st.error("🧑‍🌾 All our workers are busy! Try again later.")
                else:
                    st.error("🐞 An error occurred during the request.")
                return

    if st.session_state.results:
        _show_top_answer(st.session_state.results, top_k_retriever)


# Sidebar footer: link styling, "Built with Haystack" credits and dataset licence.
_FOOTER_HTML = """
<style>
a { text-decoration: none; }
.haystack-footer { text-align: center; }
.haystack-footer h4 { margin: 0.1rem; padding: 0; }
footer { opacity: 0; }
</style>
<div class="haystack-footer">
<hr />
<h4>Built with <a href="https://www.deepset.ai/haystack">Haystack</a></h4>
<p>Get it on <a href="https://github.com/deepset-ai/haystack/">GitHub</a> - Read the <a href="https://haystack.deepset.ai/overview/intro">Docs</a></p>
<small>Dataset link: <a href="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt12.zip">Game of Thrones Wiki</a> <br />See the <a href="https://creativecommons.org/licenses/by-sa/3.0/">License</a> (CC BY-SA 3.0).</small>
</div>
"""


def _sidebar_top_k_slider():
    """Render the retriever top-k slider in the sidebar and return its value."""
    min_docs, max_docs = 1, 10
    # st.slider raises if the default lies outside [min_value, max_value], so
    # clamp the environment-provided default into range.
    default_docs = min(max(DEFAULT_DOCS_FROM_RETRIEVER, min_docs), max_docs)
    return st.sidebar.slider(
        "Max. number of documents from retriever",
        min_value=min_docs,
        max_value=max_docs,
        value=default_docs,
        step=1,
        on_change=reset_results,
    )


def _show_top_answer(results, top_k):
    """Render the first answer in *results* with its doc IDs, scores and up to *top_k* passages."""
    answers = results["answers"]
    if not answers:
        st.warning("No answer was found for this question.")
        return

    top_answer = answers[0]
    st.session_state.answer = top_answer.answer
    st.write(st.session_state.answer)
    st.write('Doc IDs:')
    st.write(top_answer.meta['doc_ids'])
    st.write('Doc Scores:')
    st.write(top_answer.meta['doc_scores'])
    # The retriever may return fewer passages than requested; slicing avoids
    # the IndexError that indexing up to top_k would raise.
    for passage in top_answer.meta['content'][:top_k]:
        st.write(passage)
        st.markdown('---\n')
|
|
|
# Streamlit re-executes this whole script on every interaction. Building the
# pipeline through a resource cache means prepare() runs once per server process
# instead of on every rerun. Streamlit releases older than st.cache_resource
# provide the predecessor decorator st.experimental_singleton.
# NOTE(review): the cached pipeline is shared by all sessions; this assumes the
# object returned by lfqa.prepare() tolerates concurrent use — confirm.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


# show_spinner=False: a spinner would be a Streamlit element rendered before
# main() calls st.set_page_config, which must be the first Streamlit command.
@_cache_resource(show_spinner=False)
def _load_pipeline():
    """Return the LFQA pipeline, calling ``prepare()`` only on the first request."""
    return prepare()


# `streamlit run` executes the script as __main__; the guard stops a plain
# import of this module from launching the app.
if __name__ == "__main__":
    pipe = _load_pipeline()
    main(pipe)
|
|