import streamlit as st
import torch
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

# Run all inference on CPU.
device = 'cpu'

@st.cache_resource  # load the model once and reuse it across Streamlit reruns
def get_model_and_tokenizer():
    model_name = "FacebookAI/roberta-base"
    num_labels = 157  # one label per arXiv category
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    # Load the fine-tuned classification weights on top of the base checkpoint.
    chkp = torch.load("arxiv_roberta_final.pt", map_location=device)
    model.load_state_dict(chkp['model'])
    return model, tokenizer
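
# Note: the checkpoint is assumed to store the state_dict under the 'model'
# key, i.e. saved roughly as:
#   torch.save({'model': model.state_dict()}, "arxiv_roberta_final.pt")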

@st.cache_resource  # the category list is static, so fetch it only once
def get_categories():
    categories = load_dataset("TimSchopf/arxiv_categories", "arxiv_category_descriptions")
    descriptions = categories['arxiv_category_descriptions']
    cat2id = {cat: id for id, cat in enumerate(descriptions['tag'])}
    id2cat = descriptions['tag']
    names = descriptions['name']
    return cat2id, id2cat, names
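
# id2cat maps an integer label id to an arXiv tag (e.g. "cs.CL"), names holds
# the matching human-readable descriptions, and cat2id is the inverse lookup,
# so id2cat[cat2id["cs.CL"]] == "cs.CL" (illustrative; ids follow dataset order).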

model, tokenizer = get_model_and_tokenizer()
cat2id, id2cat, cat_names = get_categories()

def predict_and_decode(model, title='', abstract=''):
    model.eval()
    # Encode title and abstract as a text pair, truncated to the model's 512-token limit.
    inputs = tokenizer(title, abstract, return_tensors='pt', truncation=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits[0].cpu()
    # Multi-label setup: each category gets an independent sigmoid probability.
    df = pd.DataFrame([
        (id2cat[cat_id], cat_names[cat_id], prob.item())
        for cat_id, prob in enumerate(torch.sigmoid(logits))
    ], columns=("tag", "name", "probability"))
    df.sort_values(by="probability", ascending=False, inplace=True)
    return df.reset_index(drop=True)
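
# Illustrative call (module-level names as defined above):
#   df = predict_and_decode(model, title="Attention Is All You Need")
#   df.head(3)  # columns (tag, name, probability), sorted by probability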

st.header("Paper Category Classifier")
st.text("Input a title and/or an abstract of a scientific paper to get a classification according to arxiv.org categories.")

input_container = st.container(border=True)
with input_container:
    title_default = "Attention Is All You Need"
    abstract_default = (
        "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks "
        "in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through "
        "an attention mechanism. We propose a new simple network architecture, the Transformer..."
    )
    line_height = 34  # approximate pixel height of one line in the text area
    n_lines = 10
    title = st.text_input("Paper title", value=title_default, help="Type in the paper's title")
    abstract = st.text_area("Paper abstract", value=abstract_default, height=line_height * n_lines,
                            help="Type in the paper's abstract")

if title or abstract:
    result = predict_and_decode(model, title=title, abstract=abstract)

    # Show the single most probable category.
    main_cnt = st.container(border=True)
    with main_cnt:
        st.markdown("#### Top category")
        st.markdown(f"**{result.tag[0]}** -- {result.name[0]}")
        st.markdown(f"Probability: {result.probability[0] * 100:.2f}%")

    # Show up to five runner-up categories above the probability threshold
    # (always at least one, so the table is never empty).
    rest_cnt = st.container(border=True)
    with rest_cnt:
        threshold = 0.55
        st.text("Other top categories:")
        max_len = min(max(1, sum(result.iloc[1:].probability > threshold)), 5)

        def format_p(row):
            # Render the probability as a percentage string for display.
            row.probability = f"{row.probability * 100:.2f}%"
            return row

        st.table(result.iloc[1:1 + max_len].apply(format_p, axis=1))
else:
    st.warning("Type a title and/or an abstract to get started!")
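
# To try the app locally (assuming this script is saved as app.py and the
# arxiv_roberta_final.pt checkpoint sits in the working directory):
#   streamlit run app.py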