| | """
|
| | Module định nghĩa các mô hình cho spam review detection
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from transformers import AutoModel, AutoConfig, AutoModelForSequenceClassification
|
| | from .custom_models import TextCNN, BiLSTM, RoBERTaGRU, SPhoBERT
|
| |
|
class TransformerForSpamDetection(nn.Module):
    """Pretrained Hugging Face backbone plus a linear head for spam review detection.

    Encoder-only checkpoints (e.g. PhoBERT, ViSoBERT, XLM-R, mBERT) are pooled
    at position 0, the leading special token.

    For encoder-decoder checkpoints (e.g. BARTpho, loaded via ``AutoModel`` as
    an MBart model), ``last_hidden_state`` is the *decoder* output. Its first
    position has only self-attended to the decoder start token. Those models
    are instead pooled at the last non-padding position, which is the same idea
    as BART's sequence-classification head using the EOS token.
    """

    # Batch keys that are never forwarded to the backbone (e.g. the
    # ``num_items_in_batch`` bookkeeping value some Trainer versions inject).
    _IGNORED_FORWARD_KWARGS = ("num_items_in_batch", "position_ids")

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.1):
        """Load the backbone and build the classification head.

        Args:
            model_name: Hugging Face checkpoint name or local path.
            num_labels: Number of target classes.
            dropout: Dropout probability applied to the pooled vector.
        """
        super().__init__()
        config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
        self.encoder = AutoModel.from_pretrained(model_name, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.dropout = nn.Dropout(dropout)
        # Decides the pooling strategy in ``_pool``.
        self._is_encoder_decoder = bool(getattr(config, "is_encoder_decoder", False))

    def _pool(self, hidden, attention_mask):
        """Reduce backbone states of shape (batch, seq, hidden) to (batch, hidden)."""
        if not self._is_encoder_decoder:
            return hidden[:, 0]
        if attention_mask is None:
            return hidden[:, -1]
        # Assumes right padding (the tokenizer default; verify if padding_side
        # is changed): the last real token sits at (number of unmasked
        # positions - 1). clamp() guards fully-masked rows.
        last = (attention_mask.long().sum(dim=1) - 1).clamp(min=0)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[rows, last.to(hidden.device)]

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Classify a batch.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Padding mask passed to the backbone and used for pooling.
            labels: Optional class indices; when given, cross-entropy loss is computed.
            **kwargs: Extra backbone inputs; keys in ``_IGNORED_FORWARD_KWARGS`` are dropped.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` of size ``num_labels`` in the last dimension.
        """
        backbone_kwargs = {
            k: v for k, v in kwargs.items() if k not in self._IGNORED_FORWARD_KWARGS
        }
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **backbone_kwargs)
        pooled = self._pool(out.last_hidden_state, attention_mask)
        logits = self.classifier(self.dropout(pooled))
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
|
| |
|
class ViT5ForSpamDetection(nn.Module):
    """ViT5 (Vietnamese T5) classifier for spam review detection, encoder only.

    Only the T5 encoder is loaded. T5 tokenizers append ``</s>`` but do not
    prepend a [CLS]-style summary token. Position 0 is therefore just the
    first subword of the review, and pooling there (the previous behaviour)
    ignores most of the text.

    The default is a padding-aware mean over the encoder states.
    ``pooling="first"`` restores the old first-token behaviour, e.g. for
    weights trained that way.
    """

    # Batch keys that are never forwarded to the T5 encoder.
    _IGNORED_FORWARD_KWARGS = ("num_items_in_batch", "position_ids")
    _POOLING_MODES = ("mean", "first")

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.1, pooling: str = "mean"):
        """Load the T5 encoder and build the classification head.

        Args:
            model_name: Hugging Face checkpoint name or local path of a T5 model.
            num_labels: Number of target classes.
            dropout: Dropout probability applied to the pooled vector.
            pooling: ``"mean"`` (masked mean over tokens) or ``"first"`` (token 0).

        Raises:
            ValueError: If ``pooling`` is not a supported mode.
        """
        super().__init__()
        from transformers import T5EncoderModel, T5Config

        # Validate before the (slow) checkpoint download/load.
        if pooling not in self._POOLING_MODES:
            raise ValueError(f"Unknown pooling {pooling!r}; expected one of {self._POOLING_MODES}")
        config = T5Config.from_pretrained(model_name)
        self.t5_encoder = T5EncoderModel.from_pretrained(model_name, config=config)
        self.classifier = nn.Linear(config.d_model, num_labels)
        self.dropout = nn.Dropout(dropout)
        self.pooling = pooling

    def _pool(self, hidden, attention_mask):
        """Reduce encoder states of shape (batch, seq, d_model) to (batch, d_model)."""
        if self.pooling == "first":
            return hidden[:, 0]
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp() avoids division by zero for fully-padded rows (they pool to 0).
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Classify a batch.

        Args:
            input_ids: Token ids for the T5 encoder.
            attention_mask: Padding mask passed to the encoder and used for pooling.
            labels: Optional class indices; when given, cross-entropy loss is computed.
            **kwargs: Extra encoder inputs; keys in ``_IGNORED_FORWARD_KWARGS`` are dropped.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and ``"logits"``.
        """
        encoder_kwargs = {
            k: v for k, v in kwargs.items() if k not in self._IGNORED_FORWARD_KWARGS
        }
        encoder_outputs = self.t5_encoder(
            input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs
        )
        pooled = self._pool(encoder_outputs.last_hidden_state, attention_mask)
        logits = self.classifier(self.dropout(pooled))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits}
|
| |
|
def get_model(model_name: str, num_labels: int, vocab_size: int = None):
    """Build a spam-review classifier from its short model name.

    Args:
        model_name: Short model key, e.g. ``"phobert-v2"``, ``"vit5"`` or
            ``"textcnn"``; the ValueError message lists every accepted key.
        num_labels: Number of target classes.
        vocab_size: Currently unused; none of the models built here consume it.
            Kept only so existing callers that pass it keep working.

    Returns:
        A ``ViT5ForSpamDetection`` (for ``"vit5"``), a
        ``TransformerForSpamDetection`` (other pretrained checkpoints), or one
        of the custom architectures imported from ``custom_models``.

    Raises:
        ValueError: If ``model_name`` is not a known key.
    """
    # Short name -> Hugging Face checkpoint for the pretrained-transformer models.
    transformer_checkpoints = {
        "phobert-v1": "vinai/phobert-base",
        "phobert-v2": "vinai/phobert-base-v2",
        "bartpho": "vinai/bartpho-syllable",
        "visobert": "uitnlp/visobert",
        "xlm-r": "xlm-roberta-large",
        "mbert": "bert-base-multilingual-cased",
        "vit5": "VietAI/vit5-base",
    }

    # Custom architectures, all built on the PhoBERT-v2 checkpoint. Lambdas
    # defer construction (and name lookup) until the chosen entry is called.
    custom_base = "vinai/phobert-base-v2"
    custom_builders = {
        "textcnn": lambda: TextCNN(custom_base, num_labels),
        "bilstm": lambda: BiLSTM(custom_base, num_labels),
        "roberta-gru": lambda: RoBERTaGRU(custom_base, num_labels),
        "sphobert": lambda: SPhoBERT(custom_base, num_labels),
        # NOTE(review): "bilstm-crf" builds the plain BiLSTM; no CRF layer is
        # added here. Kept as an alias for compatibility; add a real CRF head
        # (or drop the alias) before reporting BiLSTM-CRF results.
        "bilstm-crf": lambda: BiLSTM(custom_base, num_labels),
    }

    # ViT5 needs the T5-encoder wrapper rather than the generic AutoModel one.
    if model_name == "vit5":
        return ViT5ForSpamDetection(transformer_checkpoints[model_name], num_labels)
    if model_name in transformer_checkpoints:
        return TransformerForSpamDetection(transformer_checkpoints[model_name], num_labels)

    builder = custom_builders.get(model_name)
    if builder is None:
        available = list(transformer_checkpoints) + list(custom_builders)
        raise ValueError(f"Unknown model name: {model_name}. Available models: {available}")
    return builder()
|
| |
|
def get_model_config(model_name: str):
    """Return the training configuration for a short model name.

    Args:
        model_name: Short model key (the same keys ``get_model`` accepts).

    Returns:
        A freshly built dict with ``model_name`` (the checkpoint),
        ``description``, ``max_length`` and ``learning_rate``. Entries for the
        custom architectures also carry ``custom_model=True``.

    Raises:
        ValueError: If ``model_name`` has no configuration.
    """

    def entry(checkpoint, description, learning_rate, custom_model=False):
        # Every model currently uses the same 256-token limit.
        settings = {
            "model_name": checkpoint,
            "description": description,
            "max_length": 256,
            "learning_rate": learning_rate,
        }
        if custom_model:
            settings["custom_model"] = True
        return settings

    phobert_v2 = "vinai/phobert-base-v2"
    configs = {
        "phobert-v1": entry("vinai/phobert-base", "PhoBERT v1 - Pre-trained BERT for Vietnamese", 5e-5),
        "phobert-v2": entry(phobert_v2, "PhoBERT v2 - Improved PhoBERT for Vietnamese", 5e-5),
        "bartpho": entry("vinai/bartpho-syllable", "BART Pho - Vietnamese BART model", 5e-5),
        "visobert": entry("uitnlp/visobert", "ViSoBERT - Vietnamese Social BERT", 5e-5),
        "xlm-r": entry("xlm-roberta-large", "XLM-RoBERTa Large - Multilingual model", 3e-5),
        "mbert": entry("bert-base-multilingual-cased", "mBERT - Multilingual BERT model", 5e-5),
        "vit5": entry("VietAI/vit5-base", "ViT5 - Vietnamese T5", 5e-5),
        # Custom architectures; all point at the PhoBERT-v2 checkpoint.
        "textcnn": entry(
            phobert_v2, "TextCNN - Convolutional Neural Network for text", 1e-3, custom_model=True
        ),
        "bilstm": entry(
            phobert_v2, "BiLSTM - Bidirectional LSTM for text classification", 1e-3, custom_model=True
        ),
        "roberta-gru": entry(
            phobert_v2, "RoBERTa-GRU - Hybrid RoBERTa + GRU model", 5e-5, custom_model=True
        ),
        "sphobert": entry(
            phobert_v2, "SPhoBERT - PhoBERT + SentenceBERT embedding fusion", 5e-5, custom_model=True
        ),
        "bilstm-crf": entry(
            phobert_v2, "BiLSTM-CRF - Bidirectional LSTM with CRF", 1e-3, custom_model=True
        ),
    }

    selected = configs.get(model_name)
    if selected is None:
        raise ValueError(f"Model {model_name} not found. Available models: {list(configs.keys())}")
    return selected