import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from transformers import AutoTokenizer
from transformers import TFAutoModelForQuestionAnswering
from datasets import Dataset
import streamlit as st
# Load the saved roberta-base tokenizer to turn raw text into the input IDs the model expects.
model_checkpoint = "Modfiededition/roberta-fine-tuned-tweet-sentiment-extractor"
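# The app frames sentiment extraction as extractive question answering: the selected sentiment is
# passed as the "question" and the tweet as the "context", and the model predicts the span of the
# tweet that expresses that sentiment.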
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_tokenizer():
    return AutoTokenizer.from_pretrained(model_checkpoint)
tokenizer = load_tokenizer()
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    return TFAutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
model = load_model()
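# Note: st.cache with allow_output_mutation=True keeps the tokenizer and model in memory across
# Streamlit reruns; newer Streamlit releases expose st.cache_resource for the same purpose.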
# Prompts
st.title("Tweet Sentiment Extractor...")
# Take the text/tweet input.
textbox = st.text_area('Write your text in this box:', '', height=100, max_chars=500)
option = st.selectbox(
    'Which sentiment would you like to extract?',
    ('positive', 'negative', 'neutral'))
st.write(option)
python_dict = {"text":[textbox], "sentiment":[option]}
dataset = Dataset.from_dict(python_dict)
MAX_LENGTH = 105
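# Assumption: 105 tokens is enough to hold the sentiment "question" plus a tweet of up to
# 500 characters after sub-word tokenization (based on the max_chars limit of the text box above).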
button = st.button('Extract text of the given sentiment..')
if button:
    with st.spinner('In progress.......'):
        def process_data(examples):
            questions = examples["sentiment"]
            context = examples["text"]
            inputs = tokenizer(
                questions,
                context,
                max_length=MAX_LENGTH,
                padding="max_length",
                return_offsets_mapping=True,
            )
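            # return_offsets_mapping records each token's (start, end) character positions in the
            # original string, which later lets us map predicted token indices back to a substring.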
            # Set the offset mapping to None for every token that is not a context token,
            # so that only spans from the tweet itself can be selected as the answer.
            for i in range(len(inputs["input_ids"])):
                offset = inputs["offset_mapping"][i]
                sequence_ids = inputs.sequence_ids(i)
                inputs["offset_mapping"][i] = [
                    o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
                ]
            return inputs
        processed_raw_data = dataset.map(
            process_data,
            batched=True
        )
        tf_raw_dataset = processed_raw_data.to_tf_dataset(
            columns=["input_ids", "attention_mask"],
            shuffle=False,
            batch_size=1,
        )
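        # to_tf_dataset wraps the tokenized columns as a tf.data.Dataset of input_ids/attention_mask
        # batches that model.predict can consume directly.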
        # Final predictions.
        outputs = model.predict(tf_raw_dataset)
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits
        # Post-processing: use start_logits and end_logits to build the final answer from the context.
        n_best = 20
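        # Look at the n_best highest-scoring start and end positions and keep the first valid
        # combination (both inside the context, end not before start) as the extracted span.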
        def predict_answers(inputs):
            predicted_answer = []
            for i in range(len(inputs["offset_mapping"])):
                start_logit = inputs["start_logits"][i]
                end_logit = inputs["end_logits"][i]
                context = inputs["text"][i]
                offset = inputs["offset_mapping"][i]
                start_indexes = np.argsort(start_logit)[-1: -n_best - 1: -1].tolist()
                end_indexes = np.argsort(end_logit)[-1: -n_best - 1: -1].tolist()
                flag = False
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Skip answers that are not fully inside the context.
                        if offset[start_index] is None or offset[end_index] is None:
                            continue
                        # Skip spans whose end comes before their start.
                        if end_index < start_index:
                            continue
                        flag = True
                        answer = context[offset[start_index][0]: offset[end_index][1]]
                        predicted_answer.append(answer)
                        break
                    if flag:
                        break
                # Fall back to the full text when no valid span was found.
                if not flag:
                    predicted_answer.append(context)
            return {"predicted_answer": predicted_answer}
        processed_raw_data.set_format("pandas")
        processed_raw_df = processed_raw_data[:]
        processed_raw_df["start_logits"] = start_logits.tolist()
        processed_raw_df["end_logits"] = end_logits.tolist()
        processed_raw_df["text"] = python_dict["text"]
        final_data = Dataset.from_pandas(processed_raw_df)
        final_data = final_data.map(predict_answers, batched=True)
        st.markdown("## " + final_data["predicted_answer"][0])