Cricles commited on
Commit
52b49c5
·
1 Parent(s): 296c318

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py CHANGED
@@ -2,6 +2,9 @@ import streamlit as st
2
  import base64
3
  import fasttext
4
  import re
 
 
 
5
 
6
  st.set_page_config(
7
  page_title="detoxi.ai",
@@ -54,6 +57,40 @@ st.markdown(
54
  st.write("""<p style='text-align: center; font-size: 24px;'>Это приложение сделает твою речь менее токсичной.
55
  И даже не придётся платить 300 bucks.</p>""", unsafe_allow_html=True)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def highlight_obscene_words(text):
58
  label,_=model.predict(text.lower())
59
  if label[0]=='__label__positive':
 
2
  import base64
3
  import fasttext
4
  import re
5
+ import torch
6
+ from transformers import AutoModelForSequenceClassification
7
+ from transformers import BertTokenizerFast
8
 
9
  st.set_page_config(
10
  page_title="detoxi.ai",
 
57
  st.write("""<p style='text-align: center; font-size: 24px;'>Это приложение сделает твою речь менее токсичной.
58
  И даже не придётся платить 300 bucks.</p>""", unsafe_allow_html=True)
59
 
60
+ class ModelWrapper(object):
61
+ MODELS_DIR: str = "./new_models/"
62
+ MODEL_NAME: str = "model"
63
+ TOKENIZER: str = "tokenizer"
64
+
65
+ def __init__(self):
66
+ self.model = AutoModelForSequenceClassification.from_pretrained(
67
+ ModelWrapper.MODELS_DIR + ModelWrapper.MODEL_NAME, torchscript=True
68
+ )
69
+ self.tokenizer = BertTokenizerFast.from_pretrained(
70
+ ModelWrapper.MODELS_DIR + ModelWrapper.TOKENIZER
71
+ )
72
+ self.id2label: dict[int, str] = {0: "__label__positive", 1: "__label__negative"}
73
+
74
+ @torch.no_grad()
75
+ def __call__(self, text: str) -> str:
76
+ max_input_length = (
77
+ self.model.config.max_position_embeddings
78
+ ) # 512 for this model
79
+ inputs = self.tokenizer(
80
+ text,
81
+ max_length=max_input_length,
82
+ padding=True,
83
+ truncation=True,
84
+ return_tensors="pt",
85
+ )
86
+ outputs = self.model(
87
+ **inputs, return_dict=True
88
+ ) # output is logits for huggingfcae transformers
89
+ predicted = torch.nn.functional.softmax(outputs.logits, dim=1)
90
+ predicted_id = torch.argmax(predicted, dim=1).numpy()[0]
91
+ return self.id2label[predicted_id]
92
+
93
+
94
  def highlight_obscene_words(text):
95
  label,_=model.predict(text.lower())
96
  if label[0]=='__label__positive':