add hybrid models
- app.py +22 -3
- bert_gru_classifier/__init__.py +2 -0
- bert_gru_classifier/config.py +32 -0
- bert_gru_classifier/model.py +175 -0
app.py
CHANGED
@@ -1,7 +1,26 @@
 import gradio as gr
+import torch
+from bert_gru_classifier import BERTBiGRUClassifier
+from transformers import AutoTokenizer
 
-
-
+model = BERTBiGRUClassifier.from_pretrained("Amal17/NusaBERT-BiGRU-NusaX-ace")
+tokenizer = AutoTokenizer.from_pretrained("LazarusNLP/NusaBERT-large")
+model.eval()
 
-
+def run(input):
+    # Tokenize as PyTorch tensors so the batch can be unpacked straight into the model
+    text_tokenized = tokenizer(
+        input,
+        padding="max_length",
+        truncation=True,
+        max_length=128, return_tensors="pt"
+    )
+
+    outputs = model(**text_tokenized)
+    logits = outputs['logits']
+    probs = torch.nn.functional.softmax(logits, dim=1)
+    preds = torch.argmax(probs, dim=1).tolist()
+
+    return "Hello " + str(preds) + "!!"
+
+demo = gr.Interface(fn=run, inputs="text", outputs="text")
 demo.launch()
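The same inference path can be exercised locally without launching the Gradio UI; a minimal sketch, assuming the Space's requirements (torch, transformers, and this repo's bert_gru_classifier package) are installed, with a placeholder input string:

    import torch
    from transformers import AutoTokenizer
    from bert_gru_classifier import BERTBiGRUClassifier

    # Same checkpoint and tokenizer as app.py
    model = BERTBiGRUClassifier.from_pretrained("Amal17/NusaBERT-BiGRU-NusaX-ace")
    tokenizer = AutoTokenizer.from_pretrained("LazarusNLP/NusaBERT-large")
    model.eval()

    batch = tokenizer("example input text", truncation=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch)["logits"]
    print(logits.softmax(dim=1), logits.argmax(dim=1).item())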
bert_gru_classifier/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .config import ConfigHybridBiGRUModel
+from .model import BERTBiGRUClassifier
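Optionally, the exported classes could also be registered with the transformers Auto factories so the checkpoint loads via AutoModelForSequenceClassification; this registration is only a sketch and is not part of this commit:

    from transformers import AutoConfig, AutoModelForSequenceClassification
    from bert_gru_classifier import ConfigHybridBiGRUModel, BERTBiGRUClassifier

    # Map the custom "bert-bigru" model_type to the config and model classes
    AutoConfig.register("bert-bigru", ConfigHybridBiGRUModel)
    AutoModelForSequenceClassification.register(ConfigHybridBiGRUModel, BERTBiGRUClassifier)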
bert_gru_classifier/config.py
ADDED
@@ -0,0 +1,32 @@
+from transformers import PretrainedConfig
+
+class ConfigHybridBiGRUModel(PretrainedConfig):
+    model_type = "bert-bigru"
+
+    def __init__(self,
+                 bert_model_name="bert-base-uncased",
+                 tokenizer_name="bert-base-uncased",
+                 hidden_dim=128,
+                 num_classes=2,
+                 gru_layers=1,
+                 bidirectional=True,
+                 dropout=0.3,
+                 concat_layers=0,
+                 pooling="last",
+                 freeze_bert=False,
+                 freeze_n_layers=0,        # number of BERT layers to freeze
+                 freeze_from_start=False,  # freeze from the start or from the end
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.bert_model_name = bert_model_name
+        self.tokenizer_name = tokenizer_name
+        self.hidden_dim = hidden_dim
+        self.num_classes = num_classes
+        self.gru_layers = gru_layers
+        self.bidirectional = bidirectional
+        self.dropout = dropout
+        self.concat_layers = concat_layers
+        self.pooling = pooling
+        self.freeze_bert = freeze_bert
+        self.freeze_n_layers = freeze_n_layers
+        self.freeze_from_start = freeze_from_start
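For reference, a config instance can be created and serialized like any PretrainedConfig; the values below are illustrative and are not the hyperparameters of the pushed checkpoint:

    from bert_gru_classifier import ConfigHybridBiGRUModel

    config = ConfigHybridBiGRUModel(
        bert_model_name="LazarusNLP/NusaBERT-large",
        tokenizer_name="LazarusNLP/NusaBERT-large",
        hidden_dim=128,
        num_classes=3,        # e.g. a 3-way sentiment label set
        pooling="last",
    )
    config.save_pretrained("my-bert-bigru")   # writes config.json with model_type "bert-bigru"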
bert_gru_classifier/model.py
ADDED
@@ -0,0 +1,175 @@
+import torch
+import torch.nn as nn
+from transformers import BertModel, PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+from .config import ConfigHybridBiGRUModel
+
+import logging
+
+class BERTBiGRUClassifier(PreTrainedModel):
+    """
+    BERT + BiGRU + classifier head for sequence classification tasks.
+    """
+    config_class = ConfigHybridBiGRUModel
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        # Setup logging configuration
+        logging.basicConfig(level=logging.INFO,  # INFO level for production
+                            format='%(asctime)s - %(levelname)s - %(message)s',
+                            handlers=[logging.StreamHandler()])
+
+        logger = logging.getLogger(__name__)
+        self.logger = logger
+
+        # ===== Load BERT backbone =====
+        self.bert = BertModel.from_pretrained(config.bert_model_name, output_hidden_states=True)
+        logger.info("Model initialized with BERT model: %s", config.bert_model_name)
+        # ===== Freeze BERT parameters if needed =====
+        if config.freeze_bert:
+            assert hasattr(self.bert, "encoder"), "BERT model must have encoder layers"
+            total_layers = len(self.bert.encoder.layer)
+            # Validate freeze_n_layers
+            if config.freeze_n_layers > total_layers or config.freeze_n_layers < 0:
+                raise ValueError(f"freeze_n_layers ({config.freeze_n_layers}) is out of valid range (0-{total_layers})")
+            # Select which layers to freeze
+            if config.freeze_from_start:
+                layers_to_freeze = list(range(config.freeze_n_layers))  # freeze from the start
+                logger.info(f"Freezing the first {config.freeze_n_layers} layers of BERT.")
+            else:
+                layers_to_freeze = list(range(total_layers - config.freeze_n_layers, total_layers))  # freeze from the end
+                logger.info(f"Freezing the last {config.freeze_n_layers} layers of BERT.")
+            # Apply freezing
+            for idx, layer in enumerate(self.bert.encoder.layer):
+                if idx in layers_to_freeze:
+                    for param in layer.parameters():
+                        param.requires_grad = False
+                else:
+                    for param in layer.parameters():
+                        param.requires_grad = True
+        # ===== Define BiGRU layer =====
+        # Input size is the BERT hidden size (e.g. 768), multiplied by the
+        # number of concatenated hidden layers when concat_layers > 0
+        input_size = self.bert.config.hidden_size * config.concat_layers if config.concat_layers > 0 else self.bert.config.hidden_size
+        self.gru = nn.GRU(
+            input_size=input_size,
+            hidden_size=config.hidden_dim,
+            num_layers=config.gru_layers,
+            bidirectional=config.bidirectional,
+            batch_first=True,
+        )
+        # ===== Define dropout layer =====
+        self.dropout = nn.Dropout(config.dropout)
+        # ===== Define final classification head =====
+        self.classifier = nn.Linear(
+            config.hidden_dim * 2 if config.bidirectional else config.hidden_dim,
+            config.num_classes
+        )
+        # ===== Define loss function =====
+        self.loss_fn = nn.CrossEntropyLoss()
+        # ===== Print model summary =====
+        self._print_trainable_parameters()
+        self.post_init()
+
+    def _print_trainable_parameters(self):
+        """
+        Utility function to print the number of total and trainable parameters.
+        """
+        total_params = sum(p.numel() for p in self.parameters())
+        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+        print("="*50)
+        print(f"Total Parameters: {total_params:,}")
+        print(f"Trainable Parameters: {trainable_params:,}")
+        print(f"Frozen Parameters: {total_params - trainable_params:,}")
+        print("="*50)
+
+    def _print_named_parameters(self):
+        """
+        Utility function to print each parameter name and whether it's trainable.
+        """
+        print("="*70)
+        print(f"{'Parameter Name':45} | {'Trainable'}")
+        print("-"*70)
+        for name, param in self.named_parameters():
+            print(f"{name:45} | {'Yes' if param.requires_grad else 'No'}")
+        print("="*70)
+
+    def _get_freeze_summary(self):
+        """
+        Returns a summary of frozen and trainable layers in the BERT model.
+        """
+        summary = []
+        for idx, layer in enumerate(self.bert.encoder.layer):
+            layer_info = {
+                "layer": f"bert.encoder.layer.{idx}",
+                "trainable": any(param.requires_grad for param in layer.parameters())
+            }
+            summary.append(layer_info)
+
+        return summary
+
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
+        """
+        Forward pass through BERT -> BiGRU -> Pooling -> Classifier.
+        """
+        # ===== BERT forward pass =====
+        bert_output = self.bert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids
+        )
+        # ===== Concatenate the last hidden states if configured =====
+        if self.config.concat_layers > 0:
+            hidden_states = bert_output.hidden_states
+            concat_layers = min(self.config.concat_layers, len(hidden_states))
+            selected_layers = hidden_states[-concat_layers:]
+            sequence_output = torch.cat(selected_layers, dim=-1)
+        else:
+            sequence_output = bert_output.last_hidden_state
+        # ===== Pass through BiGRU =====
+        gru_output, _ = self.gru(sequence_output)
+        # ===== Pooling layer (CLS / Last / Mean / Max) =====
+        if self.config.pooling == "cls":
+            pooled_output = gru_output[:, 0, :]
+
+        elif self.config.pooling == "last":
+            pooled_output = gru_output[:, -1, :]
+
+        elif self.config.pooling == "mean":
+            if attention_mask is not None:
+                mask = attention_mask.unsqueeze(-1).expand(gru_output.size())
+                masked_output = gru_output * mask
+                sum_masked_output = masked_output.sum(dim=1)
+                lengths = mask.sum(dim=1).clamp(min=1e-9)
+                pooled_output = sum_masked_output / lengths
+            else:
+                pooled_output = gru_output.mean(dim=1)
+
+        elif self.config.pooling == "max":
+            if attention_mask is not None:
+                mask = attention_mask.unsqueeze(-1).expand(gru_output.size())
+                masked_output = gru_output.masked_fill(mask == 0, -1e9)
+                pooled_output, _ = masked_output.max(dim=1)
+            else:
+                pooled_output, _ = gru_output.max(dim=1)
+
+        else:
+            raise ValueError(f"Unknown pooling type: {self.config.pooling}")
+
+        # ===== Dropout + Classification Head =====
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        # ===== Compute loss if labels provided =====
+        loss = None
+        if labels is not None:
+            loss = self.loss_fn(logits, labels)
+            return SequenceClassifierOutput(
+                loss=loss,
+                logits=logits
+            )
+        else:
+            return {"logits": logits}
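A shape-level smoke test of the classifier with random token ids; a sketch assuming the default bert-base-uncased backbone (the GRU and classifier weights are untrained here):

    import torch
    from bert_gru_classifier import ConfigHybridBiGRUModel, BERTBiGRUClassifier

    config = ConfigHybridBiGRUModel(num_classes=3)   # defaults: bert-base-uncased backbone, pooling="last"
    model = BERTBiGRUClassifier(config)
    model.eval()

    input_ids = torch.randint(0, model.bert.config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out["logits"].shape)                       # torch.Size([2, 3])

    # With labels, forward() returns a SequenceClassifierOutput carrying the loss
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=torch.tensor([0, 2]))
    print(out.loss)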