DanilO0o commited on
Commit
edcd390
·
1 Parent(s): e1607f1

added new model

Browse files
rugpt.py → app.py RENAMED
File without changes
models/Sasha_best_lstm_model3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebd4ef0cd62eb779c9f9b0dcf90bc23d63ec46884c035884a44258a5763ec1c6
3
+ size 46108066
models/Sasha_best_model_bert.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39d23a0fa07df356ffb6207961cab88d4efaab0b9d2a3fad4ad620f3f89a73bd
3
+ size 117128865
models/Sasha_logistic_model2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2467bdf90ee233cd38e6b3336fc5746dc6fd848d0ca1a53b002da85485366ac1
3
+ size 401015
models/bert_classifier.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import AutoModel
4
+
5
+ class MyTinyBERT(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.bert = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
9
+ for param in self.bert.parameters():
10
+ param.requires_grad = True
11
+ # Разморозка последних слоёв
12
+ for name, param in self.bert.named_parameters():
13
+ if any(layer in name for layer in ['layer.7', 'layer.8', 'layer.9', 'layer.10', 'layer.11']):
14
+ param.requires_grad = True
15
+
16
+ self.linear = nn.Sequential(
17
+ nn.Linear(312, 256),
18
+ nn.ReLU(),
19
+ nn.Dropout(0.3),
20
+ nn.Linear(256, 10)) # Для 10 классов
21
+
22
+ def forward(self, input_dict):
23
+ # Ожидается словарь с ключами "input_ids" и "attention_mask"
24
+ bert_out = self.bert(**input_dict)
25
+ # Используем скрытое состояние для [CLS] токена
26
+ normed_bert_out = nn.functional.normalize(bert_out.last_hidden_state[:, 0, :])
27
+ return self.linear(normed_bert_out)
models/lstm_attention.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ class LSTMAttention(nn.Module):
5
+ def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
6
+ super(LSTMAttention, self).__init__()
7
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
8
+ self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
9
+ self.attention = nn.Linear(hidden_dim * 2, 1)
10
+ self.fc = nn.Linear(hidden_dim * 2, output_dim)
11
+ self.dropout = nn.Dropout(0.5)
12
+
13
+ def forward(self, input_ids):
14
+ # Embedding слой
15
+ embedded = self.embedding(input_ids) # (batch_size, seq_len, embedding_dim)
16
+
17
+ # LSTM слой
18
+ lstm_out, _ = self.lstm(embedded) # (batch_size, seq_len, hidden_dim*2)
19
+
20
+ # Механизм внимания
21
+ attn_weights = torch.softmax(self.attention(lstm_out), dim=1) # (batch_size, seq_len, 1)
22
+
23
+ # Вектор контекста
24
+ context_vector = torch.sum(attn_weights * lstm_out, dim=1) # (batch_size, hidden_dim*2)
25
+
26
+ # Классификатор
27
+ output = self.fc(self.dropout(context_vector)) # (batch_size, output_dim)
28
+
29
+ return output, attn_weights.squeeze(-1)
models/text_preprocessor.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #text_preprocessor.py
2
+
3
+ import pandas as pd
4
+ import re
5
+ import string
6
+ import pymorphy3
7
+ from sklearn.base import BaseEstimator, TransformerMixin
8
+
9
+ class MyCustomTextPreprocessor(BaseEstimator, TransformerMixin):
10
+ def __init__(self):
11
+ self.stop_words = self.get_stopwords_list()
12
+ self.morph = pymorphy3.MorphAnalyzer()
13
+
14
+ def fit(self, X, y=None):
15
+ return self
16
+
17
+ def transform(self, texts, y=None, lemmatize=True):
18
+ return [self.preprocess(text, lemmatize=lemmatize) for text in texts]
19
+
20
+ def get_stopwords_list(self):
21
+ url = "https://raw.githubusercontent.com/stopwords-iso/stopwords-ru/master/stopwords-ru.txt"
22
+ stopwords_cust = set(pd.read_csv(url, header=None, names=["stopwords"], encoding="utf-8")['stopwords'])
23
+ return stopwords_cust
24
+
25
+ def clean(self, text):
26
+ text = text.lower()
27
+ text = re.sub(r'http\S+', " ", text)
28
+ text = re.sub(r'@\w+', ' ', text)
29
+ text = re.sub(r'#\w+', ' ', text)
30
+ text = re.sub(r'\d+', ' ', text)
31
+ text = re.sub(r'[^\w\s,]', '', text)
32
+ text = text.translate(str.maketrans('', '', string.punctuation))
33
+ text = re.sub(r'<.*?>', ' ', text)
34
+ text = re.sub(r'[\u00A0\u2000-\u206F]', ' ', text)
35
+ text = re.sub(r'[a-zA-Z]', '', text)
36
+ text = re.sub(r'\s+', ' ', text).strip()
37
+ return text
38
+
39
+ def remove_stopwords(self, text):
40
+ return ' '.join([word for word in text.split() if word not in self.stop_words])
41
+
42
+ def lemmatize(self, text):
43
+ morph = self.morph
44
+ lemmatized_text = ''
45
+ for word in text.split():
46
+ lemmatized_text += morph.parse(word)[0].normal_form + " "
47
+ return lemmatized_text
48
+
49
+ def preprocess(self, text, lemmatize=True):
50
+ """Общая функция обработки текста с возможностью отключить лемматизацию"""
51
+ text = self.clean(text)
52
+ text = self.remove_stopwords(text)
53
+ if lemmatize:
54
+ text = self.lemmatize(text)
55
+ return text
models/vectorizer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad82a7fee20e9a6f702eb77245b35260e61b5cf4931616a3872da2ce3cb352ff
3
+ size 238499
pages/sasha_main_page_final.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.bert_classifier import MyTinyBERT
2
+ from models.lstm_attention import LSTMAttention
3
+ from models.text_preprocessor import MyCustomTextPreprocessor
4
+ import streamlit as st
5
+ from sklearn.utils.class_weight import compute_class_weight
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import joblib
9
+ from torch import nn
10
+ from sklearn.base import BaseEstimator, TransformerMixin
11
+ from transformers import AutoTokenizer, AutoModel
12
+ from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
13
+ from sklearn.linear_model import LogisticRegression
14
+ from sklearn.model_selection import train_test_split
15
+ from torch.utils.data import DataLoader, TensorDataset
16
+ from time import time
17
+ from sklearn.feature_extraction.text import TfidfVectorizer
18
+ import pymorphy3
19
+ import string
20
+ import re
21
+ import pandas as pd
22
+ import numpy as np
23
+ import torch
24
+ import sklearn
25
+ import matplotlib.pyplot as plt
26
+ import warnings
27
+ warnings.simplefilter("ignore")
28
+ # Metrics
29
+ # custom
30
+
31
+
32
+ # ======= Глобальная инициализация токенизатора =======
33
+ tokenizer = AutoTokenizer.from_pretrained(
34
+ "cointegrated/rubert-tiny2") # Для LSTM и BERT
35
+
36
+ # ======= Инициализация обработчика текста =======
37
+ preprocessor = MyCustomTextPreprocessor()
38
+
39
+ # ======= Загрузка моделей и векторизатора =======
40
+ # @st.cache_resource
41
+
42
+
43
+ def load_resources():
44
+ # Загрузка TF-IDF векторизатора
45
+ vectorizer = joblib.load('models/vectorizer.pkl') # TF-IDF
46
+
47
+ # Загрузка модели логистической регрессии
48
+ # Логистическая регрессия
49
+ model1 = joblib.load('models/Sasha_logistic_model2.pkl')
50
+
51
+ # Настройка модели LSTM
52
+ # Используем уже загруженный токенизатор
53
+ VOCAB_SIZE = len(tokenizer.get_vocab())
54
+ EMBEDDING_DIM = 128
55
+ HIDDEN_DIM = 256
56
+ OUTPUT_DIM = 10
57
+ model2 = LSTMAttention(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)
58
+ model2.load_state_dict(torch.load(
59
+ 'models/Sasha_best_lstm_model3.pth', map_location=torch.device('cpu')))
60
+ model2.eval()
61
+
62
+ # Настройка модели BERT
63
+ model3 = MyTinyBERT()
64
+ model3.load_state_dict(torch.load(
65
+ 'models/Sasha_best_model_bert.pth', map_location=torch.device('cpu')))
66
+ model3.eval()
67
+
68
+ return model1, model2, model3, vectorizer
69
+
70
+
71
+ # Загружаем ресурсы
72
+ model1, model2, model3, vectorizer = load_resources()
73
+
74
+ # ======= Предобработка текста =======
75
+
76
+
77
+ def preprocess_for_model1(text):
78
+ """TF-IDF векторизация для логистической регрессии"""
79
+ processed_text = preprocessor.preprocess(
80
+ text, lemmatize=True) # Лемматизация включена
81
+ return vectorizer.transform([processed_text])
82
+
83
+
84
+ def preprocess_for_model2_and_model3(text):
85
+ """Общая обработка для LSTM и BERT моделей (без лемматизации)"""
86
+ processed_text = preprocessor.preprocess(
87
+ text, lemmatize=False) # Лемматизация выключена
88
+ return processed_text
89
+
90
+
91
+ def preprocess_for_model2(text, tokenizer):
92
+ """Токенизация для LSTM модели"""
93
+ processed_text = preprocess_for_model2_and_model3(text)
94
+ tokenized_data = tokenizer(
95
+ [processed_text],
96
+ padding=True,
97
+ truncation=True,
98
+ return_tensors="pt",
99
+ max_length=256
100
+ )
101
+ return tokenized_data["input_ids"], tokenized_data["attention_mask"]
102
+
103
+
104
+ def preprocess_for_model3(text, tokenizer):
105
+ """Токенизация для BERT модели"""
106
+ processed_text = preprocess_for_model2_and_model3(text)
107
+ tokenized_data = tokenizer(
108
+ [processed_text],
109
+ padding=True,
110
+ truncation=True,
111
+ return_tensors="pt",
112
+ max_length=256
113
+ )
114
+ return tokenized_data
115
+
116
+
117
+ # ======= Прогноз и визуализация =======
118
+ def predict_and_visualize(text):
119
+ # ======= Модель 1 (Logistic Regression) =======
120
+ start_time = time() # Начало времени предсказания
121
+ vectorized_text = preprocess_for_model1(text)
122
+ probs1 = model1.predict_proba(vectorized_text)[0]
123
+ model1_time = time() - start_time # Рассчитываем время предсказания для модели 1
124
+
125
+ # ======= Модель 2 (LSTM & Attention) =======
126
+ start_time = time() # Начало времени предсказания
127
+ input_ids, _ = preprocess_for_model2(
128
+ text, tokenizer) # Получаем только input_ids
129
+ with torch.no_grad():
130
+ logits2, attn_weights = model2(input_ids) # Передаём только input_ids
131
+ probs2 = torch.softmax(logits2, dim=1).numpy()[0]
132
+ attention_vector = attn_weights.cpu().numpy()[0]
133
+ model2_time = time() - start_time # Рассчитываем время предсказ��ния для модели 2
134
+
135
+ # ======= Модель 3 (BERT) =======
136
+ start_time = time() # Начало времени предсказания
137
+ tokenized_text = preprocess_for_model3(text, tokenizer)
138
+ with torch.no_grad():
139
+ logits3 = model3(tokenized_text)
140
+ probs3 = torch.softmax(logits3, dim=1).numpy()[0]
141
+ model3_time = time() - start_time # Рассчитываем время предсказания для модели 3
142
+
143
+ # ======= Финальное предсказание =======
144
+ final_probs = (probs1 + probs2 + probs3) / 3
145
+ final_class = np.argmax(final_probs)
146
+
147
+ # ======= Визуализация =======
148
+ st.subheader("Распределение вероятностей")
149
+ for probs, model_name in zip([probs1, probs2, probs3], ['Model 1 (Logistic Regression)', 'Model 2 (LSTM)', 'Model 3 (BERT)']):
150
+ fig, ax = plt.subplots()
151
+ ax.bar(range(1, len(probs) + 1), probs) # Сдвиг индекса на +1
152
+ ax.set_title(f'{model_name} Probabilities')
153
+ ax.set_xlabel('Class (1-10)')
154
+ ax.set_ylabel('Probability')
155
+ st.pyplot(fig)
156
+
157
+ # ======= Визуализация внимания (LSTM) =======
158
+ st.subheader("Веса внимания (LSTM)")
159
+
160
+ # Проверяем наличие attention weights
161
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
162
+ tokens = tokens[:len(attention_vector)]
163
+ attention_vector = attention_vector[:len(tokens)]
164
+
165
+ fig, ax = plt.subplots(figsize=(12, 6))
166
+ ax.bar(range(len(tokens)), attention_vector, align="center")
167
+ ax.set_xticks(range(len(tokens)))
168
+ ax.set_xticklabels(tokens, rotation=45, ha="right")
169
+ ax.set_title("Attention Weights (LSTM)")
170
+ ax.set_xlabel("Токены")
171
+ ax.set_ylabel("Вес внимания")
172
+ st.pyplot(fig)
173
+
174
+ # Итоговое предсказание
175
+ st.subheader("Итоговое предсказание")
176
+ # Смещение на +1
177
+ st.write(f"Наиболее вероятный класс: **{final_class + 1}**")
178
+
179
+ # Вывод времени выполнения
180
+ st.subheader("Время выполнения моделей")
181
+ st.write(f"Модель 1 (Logistic Regression): {model1_time:.4f} секунд")
182
+ st.write(f"Модель 2 (LSTM): {model2_time:.4f} секунд")
183
+ st.write(f"Модель 3 (BERT): {model3_time:.4f} секунд")
184
+
185
+ return final_class
186
+
187
+
188
+ # ======= Streamlit UI =======
189
+ st.title("Классификация текстов с 3 моделями")
190
+ st.write("Введите текст отзыва, чтобы получить результаты классификации от трёх моделей.")
191
+
192
+ # Ввод текста пользователем
193
+ user_input = st.text_area("Введите текст отзыва:", "")
194
+
195
+ if st.button("Классифицировать"):
196
+ if user_input.strip():
197
+ # Прогноз и визуализация
198
+ predict_and_visualize(user_input)
199
+ else:
200
+ st.warning("Введите текст для анализа.")
201
+
202
+ st.subheader("F1 macro, валидационная выборка")
203
+ st.write(f'f1 macro valid logreg=0.2516')
204
+ st.write(f'f1 macro valid lstm=0.2515')
205
+ st.write(f'f1 macro valid bert=0.2709')
requirements.txt CHANGED
@@ -1,3 +1,69 @@
1
  streamlit
2
  torch
3
  transformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  streamlit
2
  torch
3
  transformers
4
+ altair==5.5.0
5
+ attrs==24.2.0
6
+ blinker==1.9.0
7
+ cachetools==5.5.0
8
+ certifi==2024.8.30
9
+ charset-normalizer==3.4.0
10
+ click==8.1.7
11
+ contourpy==1.3.1
12
+ cycler==0.12.1
13
+ DAWG-Python==0.7.2
14
+ filelock
15
+ fonttools==4.55.0
16
+ fsspec
17
+ gitdb==4.0.11
18
+ GitPython==3.1.43
19
+ gmpy2
20
+ huggingface-hub==0.26.3
21
+ idna==3.10
22
+ Jinja2
23
+ joblib==1.4.2
24
+ jsonschema==4.23.0
25
+ jsonschema-specifications==2024.10.1
26
+ kiwisolver==1.4.7
27
+ markdown-it-py==3.0.0
28
+ MarkupSafe
29
+ matplotlib==3.9.2
30
+ mdurl==0.1.2
31
+ mpmath
32
+ narwhals==1.14.3
33
+ networkx
34
+ numpy
35
+ packaging==24.2
36
+ pandas==2.2.3
37
+ pillow==11.0.0
38
+ protobuf==5.29.0
39
+ pyarrow==18.1.0
40
+ pydeck==0.9.1
41
+ Pygments==2.18.0
42
+ pymorphy3==2.0.2
43
+ pymorphy3-dicts-ru==2.4.417150.4580142
44
+ pyparsing==3.2.0
45
+ python-dateutil==2.9.0.post0
46
+ pytz==2024.2
47
+ PyYAML==6.0.2
48
+ referencing==0.35.1
49
+ regex==2024.11.6
50
+ requests==2.32.3
51
+ rich==13.9.4
52
+ rpds-py==0.21.0
53
+ safetensors==0.4.5
54
+ scikit-learn==1.5.2
55
+ scipy==1.14.1
56
+ sentencepiece==0.2.0
57
+ six==1.16.0
58
+ smmap==5.0.1
59
+ sympy
60
+ tenacity==9.0.0
61
+ threadpoolctl==3.5.0
62
+ tokenizers==0.20.3
63
+ toml==0.10.2
64
+ tornado==6.4.2
65
+ tqdm==4.67.1
66
+ typing_extensions
67
+ tzdata==2024.2
68
+ urllib3==2.2.3
69
+ watchdog==6.0.0