Runtime error
Runtime error
import os | |
import json | |
import numpy as np | |
import pandas as pd | |
import random | |
import streamlit as st | |
import torch | |
import torch.nn.functional as F | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
def init_model():
    """Load the tokenizer and the fine-tuned unreliable-news classifier.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``distilbert-base-cased``
        tokenizer and the ``khizon/distilbert-unreliable-news-eng-4L``
        sequence-classification checkpoint configured for two labels.
    """
    tokenizer_name = 'distilbert-base-cased'
    checkpoint_name = 'khizon/distilbert-unreliable-news-eng-4L'
    tok = DistilBertTokenizer.from_pretrained(tokenizer_name)
    clf = DistilBertForSequenceClassification.from_pretrained(
        checkpoint_name,
        num_labels=2,
    )
    return tok, clf
def download_dataset(url='', data='data/nela_gt_2018_site_split'):
    """Download the NELA-GT-2018 site-split dataset folder with ``gdown``.

    Args:
        url: shared-folder URL handed to ``gdown --folder``. The default is
            empty (as in the original hard-coded value); presumably the real
            URL was redacted — a non-empty one must be supplied.
        data: destination directory for the downloaded folder.

    Raises:
        ValueError: if ``url`` is empty. Previously gdown would run with no
            URL, fail silently, and the app would crash later on a missing file.
        subprocess.CalledProcessError: if ``gdown`` exits with a non-zero status.
        FileNotFoundError: if the ``gdown`` executable is not on PATH.
    """
    import subprocess

    if not url:
        raise ValueError(
            'download_dataset: no dataset URL configured; pass url=... '
            'pointing at the nela_gt_2018_site_split folder'
        )
    # Argument list (no shell): url/data are never interpreted by a shell,
    # and check=True surfaces download failures instead of ignoring them.
    subprocess.run(['gdown', '--folder', url, '-O', data], check=True)
def jsonl_to_df(file_path):
    """Read a JSON Lines file into a flattened DataFrame.

    Each non-blank line is parsed as one JSON object. ``pd.json_normalize``
    flattens nested objects into dotted column names (e.g. ``meta.x``).

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A DataFrame with one row per JSON object. It is empty when the file
        contains no records.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8 by specification, so don't rely on the platform default.
    with open(file_path, encoding='utf-8') as f:
        # Parse each line exactly once; skip blank lines (e.g. a trailing one).
        records = [json.loads(line) for line in f if line.strip()]
    return pd.json_normalize(records)
def load_test_df():
    """Load the NELA-GT-2018 site-split test set with one-hot label columns.

    Reads ``data/nela_gt_2018_site_split/test.jsonl`` relative to the working
    directory. It then expands the ``label`` column with ``pd.get_dummies``
    into ``label_<value>`` indicator columns.
    """
    split_dir = os.path.join('data', 'nela_gt_2018_site_split')
    raw_df = jsonl_to_df(os.path.join(split_dir, 'test.jsonl'))
    return pd.get_dummies(raw_df, columns=['label'])
def predict(model, tokenizer, data):
    """Classify one labelled dataset row and check the prediction.

    Args:
        model: classifier called as ``model(**encoding)``. Its output is
            indexed with ``'logits'``.
        tokenizer: tokenizer exposing ``encode_plus``.
        data: one dataset row (e.g. ``df.iloc[i]``) with ``title``,
            ``content``, ``label_0`` and ``label_1`` entries.

    Returns:
        ``(pred_label, correct)`` as produced by ``correct_preds``.
    """
    # A row taken from a mixed-dtype DataFrame can be object-dtype, so cast
    # the one-hot label entries explicitly before building the tensor.
    labels = torch.tensor(
        data[['label_0', 'label_1']].to_numpy(dtype=np.float32)
    )
    # Title/content are encoded as a pair; only the content is truncated.
    # The literal ' [SEP] ' prefix is kept, presumably to match how the
    # model was fine-tuned; confirm before changing it.
    encoding = tokenizer.encode_plus(
        data['title'],
        ' [SEP] ' + data['content'],
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = model(**encoding)
    return correct_preds(output['logits'], labels)
def predict_new(model, tokenizer, title, content):
    """Classify a user-supplied article as reliable or unreliable.

    Args:
        model: classifier called as ``model(**encoding)``. Its output is
            indexed with ``'logits'``, and the class scores lie along dim 1.
        tokenizer: tokenizer exposing ``encode_plus``.
        title: article title (first segment of the pair).
        content: article body (second segment; truncated to fit 512 tokens).

    Returns:
        ``'reliable'`` if the argmax class index is greater than 0, else
        ``'unreliable'``. A single article (batch of one) is assumed.
    """
    # The ' [SEP] ' prefix mirrors ``predict`` so both paths feed the model
    # identically formatted input.
    encoding = tokenizer.encode_plus(
        title,
        ' [SEP] ' + content,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = model(**encoding)
    probs = F.softmax(output['logits'], dim=1)
    # Convert to a Python int rather than comparing a tensor in an `if`.
    p_idx = torch.argmax(probs, dim=1).item()
    return 'reliable' if p_idx > 0 else 'unreliable'
def correct_preds(preds, labels):
    """Turn logits into a label string and compare them with the true label.

    Args:
        preds: logits tensor. The argmax is taken over dim 1, so a shape like
            ``(1, num_classes)`` for a single example is presumably expected.
        labels: 1-D label tensor (one-hot in ``predict``). Its argmax over
            dim 0 is the true class index.

    Returns:
        ``(pred_label, correct)``. ``pred_label`` is ``'reliable'`` when the
        predicted index is greater than 0, else ``'unreliable'``. ``correct``
        is True when the predicted index matches the label index.
    """
    probs = torch.nn.functional.softmax(preds, dim=1)
    p_idx = torch.argmax(probs, dim=1)
    l_idx = torch.argmax(labels, dim=0)
    pred_label = 'reliable' if p_idx > 0 else 'unreliable'
    # The comparison already yields a bool; no `True if ... else False` needed.
    correct = (p_idx == l_idx).sum().item() > 0
    return pred_label, correct
if __name__ == '__main__':
    # Fetch the dataset on first run, then load data and model.
    if not os.path.exists('data/nela_gt_2018_site_split/test.jsonl'):
        download_dataset()
    df = load_test_df()
    tokenizer, model = init_model()

    st.title("Unreliable News classifier")
    # NOTE(review): the widget call on this line was lost in the source; only
    # its arguments survived. A selectbox is assumed here; confirm.
    mode = st.selectbox(
        '', ('Test article', 'Input own article')
    )

    if mode == 'Test article':
        if st.button('Get random article'):
            row_idx = np.random.randint(0, len(df))
            sample = df.iloc[row_idx]
            prediction, correct = predict(model, tokenizer, sample)
            label = 'reliable' if sample['label_1'] > sample['label_0'] else 'unreliable'

            st.header(sample['title'])
            if correct:
                st.success(f'Prediction: {prediction}')
            else:
                st.error(f'Prediction: {prediction}')
            st.caption(f'Source: {sample["source"]} ({label})')

            # Mask roughly 54% of the words after the first (when
            # randint(0, 99) > 45) so the article is only partially shown.
            # Build a new string rather than writing back into `sample`,
            # which may share data with `df`.
            masked_content = ' '.join(
                word if word_idx == 0 or random.randint(0, 99) <= 45 else '▒' * len(word)
                for word_idx, word in enumerate(sample['content'].split())
            )
            st.markdown(masked_content)
    else:
        title = st.text_input('Article title', 'Test title')
        content = st.text_area('Article content', 'Lorem ipsum')
        if st.button('Submit'):
            pred = predict_new(model, tokenizer, title, content)
            st.markdown(f'Prediction: {pred}')