import streamlit as st
import torch
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

device = 'cpu'


@st.cache_resource
def get_model_and_tokenizer():
    """Load the fine-tuned RoBERTa classifier and its tokenizer (cached across reruns)."""
    model_name = "FacebookAI/roberta-base"
    num_labels = 157
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    chkp = torch.load("arxiv_roberta_final.pt", map_location=device)
    model.load_state_dict(chkp['model'])
    return model, tokenizer


@st.cache_data
def get_categories():
    """Fetch arXiv category tags and their human-readable names (cached across reruns)."""
    categories = load_dataset("TimSchopf/arxiv_categories", "arxiv_category_descriptions")
    cat2id = {cat: id for id, cat in enumerate(categories['arxiv_category_descriptions']['tag'])}
    id2cat = categories['arxiv_category_descriptions']['tag']
    names = categories['arxiv_category_descriptions']['name']
    return cat2id, id2cat, names


model, tokenizer = get_model_and_tokenizer()
cat2id, id2cat, cat_names = get_categories()


@torch.no_grad()
def predict_and_decode(model, title='', abstract=''):
    """Classify (title, abstract) and return a DataFrame of per-category probabilities."""
    model.eval()
    inputs = tokenizer(title, abstract, return_tensors='pt', truncation=True, max_length=512).to(device)
    logits = model(**inputs)['logits'][0].cpu()
    # Multi-label head: an independent sigmoid per category logit.
    df = pd.DataFrame(
        [
            (id2cat[cat_id], cat_names[cat_id], prob.item())
            for cat_id, prob in enumerate(torch.sigmoid(logits))
        ],
        columns=("tag", "name", "probability"),
    )
    df.sort_values(by="probability", ascending=False, inplace=True)
    return df.reset_index(drop=True)


st.header("Paper Category Classifier")
st.text("Input a title and/or an abstract of a scientific paper, and get a classification according to arxiv.org categories.")

input_container = st.container(border=True)
with input_container:
    title_default = "Attention Is All You Need"
    abstract_default = (
        "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks "
        "in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through "
        "an attention mechanism. We propose a new simple network architecture, the Transformer..."
    )
    line_height = 34  # approximate pixel height of one line in the text area
    n_lines = 10
    title = st.text_input("Paper title", value=title_default, help="Type in the paper's title")
    abstract = st.text_area(
        "Paper abstract",
        value=abstract_default,
        height=line_height * n_lines,
        help="Type in the paper's abstract",
    )

if title or abstract:
    result = predict_and_decode(model, title=title, abstract=abstract)

    main_cnt = st.container(border=True)
    with main_cnt:
        st.markdown("#### Top category")
        st.markdown(f"**{result.tag[0]}** -- {result.name[0]}")
        st.markdown(f"Probability: {result.probability[0] * 100:.2f}%")

    rest_cnt = st.container(border=True)
    with rest_cnt:
        threshold = 0.55
        st.text("Other top categories:")
        # Show up to 5 runner-up categories whose probability clears the threshold;
        # always show at least one row.
        max_len = min(max(1, sum(result.iloc[1:].probability > threshold)), 5)

        def format_p(example):
            example.probability = f"{example.probability * 100:.2f}%"
            return example

        st.table(result.iloc[1:1 + max_len].apply(format_p, axis=1))
else:
    st.warning("Type a title and/or an abstract to get started!")
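
# --- Notes (assumptions, not part of the app logic) ---
# The checkpoint file "arxiv_roberta_final.pt" is assumed to be a dict that
# stores the fine-tuned weights under the key 'model', i.e. saved roughly as:
#
#     torch.save({'model': model.state_dict()}, "arxiv_roberta_final.pt")
#
# (hypothetical save call; only the 'model' key is read above). The app is
# launched with the standard Streamlit CLI, assuming this file is app.py:
#
#     streamlit run app.py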