Cricles commited on
Commit
80214a5
·
verified ·
1 Parent(s): f6488eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -36
app.py CHANGED
@@ -19,10 +19,9 @@ def get_image_base64(path):
19
 
20
  @st.cache_resource # Кэширование модели для ускорения работы
21
  def load_model():
22
- model = fasttext.load_model('./model_fasttext.bin')
23
- return model
24
 
25
- @st.cache_resource
26
  class ModelWrapper(object):
27
  MODELS_DIR: str = "./new_models/"
28
  MODEL_NAME: str = "model"
@@ -54,10 +53,9 @@ class ModelWrapper(object):
54
  ) # output is logits for huggingfcae transformers
55
  predicted = torch.nn.functional.softmax(outputs.logits, dim=1)
56
  predicted_id = torch.argmax(predicted, dim=1).numpy()[0]
57
- return self.id2label[predicted_id]
58
 
59
- model_fasttext = load_model()
60
- model_BERT=ModelWrapper()
61
 
62
  bin_str = get_image_base64("./билли.png")
63
  page_bg_img = '''
@@ -112,41 +110,22 @@ model_type = st.radio(
112
  )
113
 
114
  def highlight_obscene_words(text, model_type):
115
- if model_type=="fasttext":
116
- label,_=model_fasttext.predict(text.lower())
117
- if label[0]=='__label__positive':
118
- st.markdown(
119
- "<span style='background:#47916B;'>{}|приемлемо</span>".format(text),
120
- unsafe_allow_html=True
121
- )
122
- else:
123
- st.markdown(
124
- "<span style='background:#ffcccc;'>{}|токсично</span>".format(text),
125
- unsafe_allow_html=True
126
- )
127
- elif model_type=="ru-BERT":
128
- label=model_BERT(text.lower())
129
- if label=='__label__positive':
130
- st.markdown(
131
- "<span style='background:#47916B;'>{}|приемлемо</span>".format(text),
132
- unsafe_allow_html=True
133
- )
134
- else:
135
- st.markdown(
136
- "<span style='background:#ffcccc;'>{}|токсично</span>".format(text),
137
- unsafe_allow_html=True
138
- )
139
  else:
140
- #ЗАГЛУШКА
141
  st.markdown(
142
- "<span style='background:#47916B;'>{}|приемлемо</span>".format(text),
143
- unsafe_allow_html=True
144
- )
145
-
146
  if st.button("Проверить текст"):
147
  if user_input.strip():
148
  st.subheader("Результат:")
149
- result = re.split(r'[.\n]+', user_input)
150
  result = [part for part in result if part.strip() != ""]
151
  if result!=[]:
152
  for text in result:
 
19
 
20
  @st.cache_resource # Кэширование модели для ускорения работы
21
  def load_model():
22
+ return ModelWrapper()
 
23
 
24
+ """@st.cache_resource
25
  class ModelWrapper(object):
26
  MODELS_DIR: str = "./new_models/"
27
  MODEL_NAME: str = "model"
 
53
  ) # output is logits for huggingfcae transformers
54
  predicted = torch.nn.functional.softmax(outputs.logits, dim=1)
55
  predicted_id = torch.argmax(predicted, dim=1).numpy()[0]
56
+ return self.id2label[predicted_id]"""
57
 
58
+ model_wrapper= load_model()
 
59
 
60
  bin_str = get_image_base64("./билли.png")
61
  page_bg_img = '''
 
110
  )
111
 
112
  def highlight_obscene_words(text, model_type):
113
+ label=model_wrapper(text.lower(),model_type)
114
+ if label=='__label__positive':
115
+ st.markdown(
116
+ "<span style='background:#47916B;'>{}|приемлемо</span>".format(text),
117
+ unsafe_allow_html=True
118
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  else:
 
120
  st.markdown(
121
+ "<span style='background:#ffcccc;'>{}|токсично</span>".format(text),
122
+ unsafe_allow_html=True
123
+ )
124
+
125
  if st.button("Проверить текст"):
126
  if user_input.strip():
127
  st.subheader("Результат:")
128
+ result = re.split(r'[.\n!?]+', user_input)
129
  result = [part for part in result if part.strip() != ""]
130
  if result!=[]:
131
  for text in result: