Bram Vanroy committed
Commit · d3a07ee
Parent(s): 79a800a

add app
app.py
CHANGED
@@ -1,3 +1,7 @@
+import base64
+from io import StringIO
+from math import ceil
+
 from utils import get_resources, simplify
 
 import streamlit as st
@@ -7,25 +11,83 @@ st.set_page_config(
     page_icon="🏃"
 )
 
+BATCH_SIZE = 8
+
+if "text_to_simplify" not in st.session_state:
+    st.session_state["text_to_simplify"] = None
+
 st.title("🏃 Text Simplification in Dutch")
 
-
-text = st.text_area(label="Input text", value="Met het naderen van de zonovergoten middaghemel op deze betoverende dag, waarbij de atmosferische omstandigheden een onbelemmerde convergentie van cumulusbewolking en uitgestrekte stratosferische azuurblauwe wijdheid faciliteren, lijken de geaggregeerde weersverschijnselen van vandaag, die variëren van sporadische plensbuien tot kalme zuchtjes wind en zeldzame opvlammingen van bliksem, de delicate balans tussen meteorologische complexiteit en eenvoud te weerspiegelen, waardoor de gepassioneerde observator met een gevoel van ontzag en verwondering wordt vervuld.")
-submitted = st.form_submit_button("Submit")
+fupload_check = st.checkbox("File upload?")
 
-
-
-
-
-
+st.markdown(
+    "Make sure that the file or text in the text box contains **one sentence per line**. Empty lines will"
+    " be removed."
+)
+if fupload_check:
+    uploaded_file = st.file_uploader("Text file", label_visibility="collapsed")
+    if uploaded_file is not None:
+        stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
+        st.session_state["text_to_simplify"] = stringio.read().strip()
+    else:
+        st.session_state["text_to_simplify"] = None
+else:
+    st.session_state["text_to_simplify"] = st.text_area(
+        label="Sentences to translate", label_visibility="collapsed", height=200,
+        value="Met het naderen van de zonovergoten middaghemel op deze betoverende dag, waarbij de atmosferische omstandigheden een onbelemmerde convergentie van cumulusbewolking en uitgestrekte stratosferische azuurblauwe wijdheid faciliteren, lijken de geaggregeerde weersverschijnselen van vandaag, die variëren van sporadische plensbuien tot kalme zuchtjes wind en zeldzame opvlammingen van bliksem, de delicate balans tussen meteorologische complexiteit en eenvoud te weerspiegelen, waardoor de gepassioneerde observator met een gevoel van ontzag en verwondering wordt vervuld."
+    ).strip()
+
+
+def _get_increment_size(num_sents) -> int:
+    if BATCH_SIZE >= num_sents:
+        return 100
     else:
-
+        return ceil(100 / (num_sents / BATCH_SIZE))
 
-
+btn_col, results_col = st.columns(2)
+btn_ct = btn_col.empty()
+error_ct = st.empty()
+simpl_ct = st.container()
+if st.session_state["text_to_simplify"]:
+    if btn_ct.button("Simplify text"):
         error_ct.empty()
+        lines = [strip_line for line in st.session_state["text_to_simplify"].splitlines() if (strip_line := line.strip())]
+        num_sentences = len(lines)
+
+        pbar = st.progress(0, text=f"Simplifying sentences in batches of {BATCH_SIZE}...")
+        increment = _get_increment_size(num_sentences)
+        percent_done = 0
+
+        model, tokenizer = get_resources()
+
+        simpl_ct.caption("Simplified text")
+        output_ct = simpl_ct.empty()
+        all_simplifications = []
+        html = "<ol>"
+        for input_batch, simplifications in simplify(lines, model, tokenizer):
+            for input_text, simplification in zip(input_batch, simplifications):
+                output_ct.empty()
+                html += f"""<li>
+<ul>
+<li><strong>Input text:</strong> {input_text}</li>
+<li><strong>Simplification:</strong> {simplification}</li>
+</ul>
+</li>"""
+                output_ct.markdown(html+"</ol>", unsafe_allow_html=True)
+
+            all_simplifications.extend(simplifications)
+
+            percent_done += increment
+            pbar.progress(min(percent_done, 100))
+        pbar.empty()
 
-
-
+        all_simplifications = "\n".join(all_simplifications) + "\n"
+        b64 = base64.b64encode(all_simplifications.encode("utf-8")).decode("utf-8")
+        results_col.markdown(f'<a download="dutch-simplifications.txt" href="data:file/txt;base64,{b64}" title="Download">Download simplifications</a>', unsafe_allow_html=True)
+else:
+    btn_ct.empty()
+    error_ct.error("Text cannot be empty!", icon="⚠️")
+    simpl_ct.container()
 
 
 ########################
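Note on the progress logic above: _get_increment_size() hands the Streamlit progress bar a fixed percentage per batch, capped at 100 at update time. A minimal standalone sketch (not part of the commit; the 20-sentence input is a made-up example) of how those increments play out:

from math import ceil

BATCH_SIZE = 8

def _get_increment_size(num_sents: int) -> int:
    # Mirrors the helper in app.py: one full bar when everything fits
    # in a single batch, otherwise an even share per batch, rounded up.
    if BATCH_SIZE >= num_sents:
        return 100
    return ceil(100 / (num_sents / BATCH_SIZE))

# Hypothetical run with 20 input sentences -> 3 batches of at most 8.
increment = _get_increment_size(20)  # ceil(100 / 2.5) == 40
percent_done = 0
for _ in range(ceil(20 / BATCH_SIZE)):
    percent_done += increment
    print(min(percent_done, 100))  # prints 40, 80, 100 (capped, as in app.py)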
utils.py
CHANGED
@@ -1,5 +1,5 @@
 from threading import Thread
-from typing import Tuple, Generator
+from typing import Tuple, Generator, List
 
 from optimum.bettertransformer import BetterTransformer
 import streamlit as st
@@ -10,7 +10,7 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, TextStreamer,
 
 
 @st.cache_resource(show_spinner=False)
-def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer
+def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer]:
     """
     """
     tokenizer = T5Tokenizer.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023", use_fast=False)
@@ -25,34 +25,36 @@ def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForCo
     model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)
 
     model.eval()
-    streamer = TextIteratorStreamer(tokenizer, decode_kwargs={"skip_special_tokens": True, "clean_up_tokenization_spaces": True})
 
-    return model, tokenizer
+    return model, tokenizer
+
+
+def batchify(iterable, batch_size=16):
+    num_items = len(iterable)
+    for idx in range(0, num_items, batch_size):
+        yield iterable[idx:min(idx + batch_size, num_items)]
 
 
 def simplify(
-
+    texts: List[str],
     model: T5ForConditionalGeneration,
     tokenizer: T5Tokenizer,
-
-) ->
+    batch_size: int = 16
+) -> List[str]:
     """
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    for new_text in streamer:
-        generated_text += new_text
-        yield generated_text
+
+    for batch_texts in batchify(texts, batch_size=batch_size):
+        nlg_batch_texts = ["[NLG] " + text for text in batch_texts]
+        encoded = tokenizer(nlg_batch_texts, return_tensors="pt", padding=True, truncation=True)
+        encoded = {k: v.to(model.device) for k, v in encoded.items()}
+        gen_kwargs = {
+            "max_new_tokens": 128,
+            "num_beams": 3,
+        }
+
+        with torch.no_grad():
+            encoded = {k: v.to(model.device) for k, v in encoded.items()}
+            generated = model.generate(**encoded, **gen_kwargs).cpu()
+
+        yield batch_texts, tokenizer.batch_decode(generated, skip_special_tokens=True)
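For reference, the new simplify() is a generator rather than a plain function: it yields one (input_batch, simplifications) tuple per batch so app.py can update the UI incrementally (the annotated return type List[str] understates this; each yielded item is a pair of lists). A minimal usage sketch outside Streamlit, assuming torch, transformers and optimum are installed and the model checkpoint can be downloaded; the example sentence is hypothetical:

from utils import get_resources, simplify

# get_resources() returns (model, tokenizer); st.cache_resource caching
# only applies when run inside a Streamlit app.
model, tokenizer = get_resources()
sentences = ["Met het naderen van de middag lijkt het weer complex maar rustig."]
for input_batch, simplifications in simplify(sentences, model, tokenizer, batch_size=8):
    for source, simple in zip(input_batch, simplifications):
        print(source, "->", simple)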