Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast | |
| import streamlit as st | |
def get_text(title: str, abstract: str):
    """Combine a paper's abstract and title into a single model input.

    Empty strings and ``None`` are both treated as "missing".

    Returns:
        ``abstract + ' ' + title`` when both are present, whichever one is
        present when only one is, or ``None`` when neither is.
    """
    if not (title or abstract):
        return None
    if not abstract:
        return title
    if not title:
        return abstract
    # Abstract goes first, followed by the title.
    return abstract + ' ' + title
def get_labels(text, model, tokenizer, count_labels=8, threshold=0.95):
    """Return the most likely subject labels for ``text``.

    Labels are taken in order of decreasing predicted probability until their
    cumulative probability exceeds ``threshold``. The label that crosses the
    threshold is included.

    Args:
        text: Input string passed to ``tokenizer``.
        model: Sequence-classification model called as ``model(**tokens)``.
            Its ``.logits`` are expected to be (1, 8), one column per entry of
            the label list below (assumed to match the training label order).
        tokenizer: Callable that turns ``text`` into model inputs.
        count_labels: Unused; kept only for backward compatibility.
        threshold: Cumulative-probability cutoff (default 0.95).

    Returns:
        List of label names, most likely first. If the cutoff is never
        exceeded (e.g. ``threshold >= 1``), every label is returned rather
        than ``None``.
    """
    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']

    # Without truncation=True, an over-long abstract exceeds the model's
    # maximum sequence length and inference fails; truncation clips the input
    # to the tokenizer's model_max_length.
    tokens = tokenizer(text, return_tensors='pt', truncation=True)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**tokens).logits
        # Pass an explicit dim instead of relying on torch.nn.Softmax()'s
        # deprecated implicit choice (which was dim=1 for 2-D input).
        probs = torch.softmax(logits, dim=-1)[0].tolist()

    # sorted() is stable even with reverse=True, so equal probabilities keep
    # their label-list order.
    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)

    result_labels = []
    cumulative = 0.0
    for prob, label in ranked:
        result_labels.append(label)
        cumulative += prob
        if cumulative > threshold:
            break
    return result_labels
def load_model(weights_path='weight_model', model_name="distilbert-base-cased", num_labels=8):
    """Build the DistilBERT classifier and tokenizer and load fine-tuned weights.

    Args:
        weights_path: Path to a state_dict saved with ``torch.save``.
            Defaults to ``'weight_model'``, relative to the working directory.
        model_name: Hugging Face checkpoint used for both the tokenizer and
            the base architecture.
        num_labels: Size of the classification head; must match the head
            stored in ``weights_path``.

    Returns:
        ``(model, tokenizer)``, with the model first.
    """
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    # map_location='cpu' lets a checkpoint that was saved from a GPU load on a
    # CPU-only host; otherwise torch.load fails when CUDA is unavailable.
    # NOTE(review): torch.load unpickles the file, so only load trusted weights.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    # Inference mode (disables dropout), stated explicitly.
    model.eval()
    # NOTE(review): this reloads everything on each call; if it is called from
    # a Streamlit script body, consider caching it — confirm at the call site.
    return model, tokenizer