File size: 5,548 Bytes
975f394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from logging import PlaceHolder
from re import sub
import streamlit as st
import imp, time, random
import base64
import io
import nbformat
from PIL import Image
from datasets import load_from_disk, load_dataset
import os
from transformers import pipeline


# Configure the page to use the "wide" layout; this is the first Streamlit
# call in the script.
st.set_page_config(layout="wide")

def set_submitted_true():
    """Mark the current session as submitted.

    Stores ``True`` under the ``submitted`` entry of Streamlit's session
    state. It takes no arguments and returns nothing.
    """
    st.session_state["submitted"] = True

# Inject a page-wide stylesheet that switches text inputs, text areas, h2
# headings, tab lists, alerts and anything carrying the ``rtl`` class to
# right-to-left direction. ``unsafe_allow_html=True`` is passed so the raw
# ``<style>`` element is emitted as HTML.
# NOTE(review): ``unicode-bidi: bidi-override`` forces every character into the
# declared direction, so Latin text or digits in these elements may render
# reversed -- confirm this is intended.
st.markdown("""
<style>
input, .rtl {
  unicode-bidi:bidi-override;
  direction: RTL;
}
textarea, .rtl {
  unicode-bidi:bidi-override;
  direction: RTL;
}
h2, .rtl {
  unicode-bidi:bidi-override;
  direction: RTL;
}
div[role=tablist], .rtl {
  unicode-bidi:bidi-override;
  direction: RTL;
}
div[role=alert], .rtl {
  unicode-bidi:bidi-override;
  direction: RTL;
}
</style>
    """, unsafe_allow_html=True)

# Reserve an empty placeholder and a progress bar (initialised to 0) at the
# top of the page.
# NOTE(review): neither object is updated anywhere in the code visible here.
latest_iteration = st.empty()
bar = st.progress(0)


# Page heading (Persian for "Persian question answering system"), followed by
# an empty markdown element.
st.markdown("## سیستم پرسش و پاسخ فارسی")
st.markdown("")

# Two tabs, labelled "demo" and "documentation" in Persian.
tab1, tab2 = st.tabs(["دمو", "مستندات"])


# Display name of each dataset -> the address that is handed to
# ``load_dataset`` when the datasets are loaded.
datasets_names_addresses = {
    "small-persian-QA": "Hamid-reza/small-persian-QA",
    "addsent-small-persian-QA": "Hamid-reza/Adv-small-persian-QA",
    "addany-small-persian-QA": "mohammadhossein/addany-dataset",
    "back-translation-small-persian-QA": "jalalnb/back_translation_hy_on_small_persian_QA",
    "invisible-char-small-persian-QA": "jalalnb/invisible_char_on_small_persian_QA",
}

@st.cache(allow_output_mutation=True)
def load_datasets(datasets_names_addresses):
    """Load the ``"validation"`` split of every dataset in the mapping.

    Args:
        datasets_names_addresses: Mapping of dataset display name to the
            address passed to ``load_dataset``.

    Returns:
        dict: Display name -> the ``"validation"`` split of the loaded
        dataset, in the iteration order of the input mapping.
    """
    validation_splits = {}
    for name, address in datasets_names_addresses.items():
        loaded = load_dataset(address)
        validation_splits[name] = loaded["validation"]
    return validation_splits

# Load all validation splits through the cached loader, then let the user
# choose one of them with a sidebar radio group.
datasets_names_content = load_datasets(datasets_names_addresses)

selected_dataset_name = st.sidebar.radio(
    label=':دیتاست مورد نظر خود را انتخاب نمایید',
    options=list(datasets_names_addresses),
)
selected_dataset = datasets_names_content[selected_dataset_name]


# Display name of each model -> (model address, tokenizer address); both
# addresses are handed to ``pipeline`` when the models are loaded.
models_names_addresses = {
    "mbert": (
        "arashmarioriyad/mbert_v3",
        "arashmarioriyad/mbert_tokenizer_v3",
    ),
    "parsbert": (
        "arashmarioriyad/parsbert_v1",
        "arashmarioriyad/parsbert_tokenizer_v1",
    ),
    "addsent-mbert": (
        "arashmarioriyad/addsent_mbert_v1",
        "arashmarioriyad/addsent_mbert_tokenizer_v1",
    ),
    "addsent-parsbert": (
        "arashmarioriyad/addsent_parsbert_v1",
        "arashmarioriyad/addsent_parsbert_tokenizer_v1",
    ),
    "addany-mbert": (
        "arashmarioriyad/addany_mbert_v1",
        "arashmarioriyad/addany_mbert_tokenizer_v1",
    ),
    "addany-parsbert": (
        "arashmarioriyad/addany_parsbert_v1",
        "arashmarioriyad/addany_parsbert_tokenizer_v1",
    ),
    "back-translation-mbert": (
        "arashmarioriyad/bt_hy_mbert_v1",
        "arashmarioriyad/bt_hy_mbert_tokenizer_v1",
    ),
    "back-translation-parsbert": (
        "arashmarioriyad/bt_hy_parsbert_v1",
        "arashmarioriyad/bt_hy_parsbert_tokenizer_v1",
    ),
    "invisible-char-mbert": (
        "arashmarioriyad/ic_mbert_v1",
        "arashmarioriyad/ic_mbert_tokenizer_v1",
    ),
    "invisible-char-parsbert": (
        "arashmarioriyad/ic_parsbert_v1",
        "arashmarioriyad/ic_parsbert_tokenizer_v1",
    ),
}

@st.cache(allow_output_mutation=True)
def load_models(models_names_addresses):
    """Build a question-answering pipeline for every entry of the mapping.

    Args:
        models_names_addresses: Mapping of model display name to a sequence
            whose first item is the model address and whose second item is
            the tokenizer address.

    Returns:
        dict: Display name -> ``pipeline("question-answering", ...)`` object,
        in the iteration order of the input mapping.

    Note:
        Every pipeline is constructed eagerly in a single call, so all
        models in the mapping are loaded even though only one is selected
        at a time.
    """
    # Use the value yielded by ``items()`` directly rather than looking the
    # same entry up in the mapping again for each argument.
    return {
        model_name: pipeline(
            "question-answering",
            model=addresses[0],
            tokenizer=addresses[1],
        )
        for model_name, addresses in models_names_addresses.items()
    }

# Build the pipelines through the cached loader, then let the user choose one
# of them with a sidebar radio group.
models_names_contents = load_models(models_names_addresses)

selected_model_name = st.sidebar.radio(
    label=':مدل مورد نظر خود را انتخاب نمایید',
    options=list(models_names_addresses),
)
selected_model = models_names_contents[selected_model_name]


# Sidebar info box (Persian) linking to the project's GitHub page, where the
# message says all data, code and model evaluation results are available.
st.sidebar.info("تمامی دادگان، کد ها و نتایج ارزیابی مدل ها در [صفحه گیت هاب پروژه](https://github.com/NLP-Final-Projects/Adversarial-QA/) قابل دسترسی است", icon="ℹ️")



# Demo form: optionally pre-fill the inputs with a random sample from the
# selected dataset, then run the selected model on the (question, context)
# pair and display the answer it returns.
with tab1.form("my_form", clear_on_submit=False):

    # Three columns are created but only the first is used, which keeps the
    # random-sample button one column wide.
    col1, _, _ = st.columns(3)
    with col1:
        generate_random_data = st.form_submit_button("تولید داده‌ی تصادفی")
        if generate_random_data:
            # Fetch the sampled row once and read both fields from it.
            sample = selected_dataset[random.randrange(len(selected_dataset))]
            st.session_state.context = sample["context"]
            st.session_state.question = sample["question"]

    if 'context' in st.session_state and st.session_state.context is not None:
        # The keyed widgets take their content from Session State. Passing
        # ``value=`` as well makes Streamlit emit a "default value + Session
        # State" warning, so only ``key`` is given here.
        context = st.text_area(label="Context", key="context", height=300)
        question = st.text_input(label="Question", key="question")
    else:
        context = st.text_area(label="Context", height=300, placeholder="متن مورد نظر را اینجا وارد کنید ...")
        question = st.text_input(label="Question", placeholder="سوال خود از متن را اینجا بپرسید  ...")

    submitted = st.form_submit_button("Get Answer")
    if submitted or ('submitted' in st.session_state and st.session_state.submitted):
        st.session_state.submitted = False
        if not context.strip() or not question.strip():
            # The question-answering pipeline rejects empty inputs with a
            # ValueError; ask for both fields instead of crashing the app.
            # (Persian: "Please enter the text and the question.")
            st.warning("لطفاً متن و سوال را وارد کنید")
        else:
            selected_prediction = selected_model(question=question, context=context)["answer"]
            # Fall back to the Persian "no answer" text when the model
            # returns an empty answer.
            st.text_area(
                label=f"Answer ({selected_model_name}):",
                value=selected_prediction or "بدون پاسخ",
            )