Amal17 committed
Commit ec1ed73 · 1 Parent(s): 66db869

add bert-lstm

app.py CHANGED
@@ -1,28 +1,59 @@
import gradio as gr
import torch
from bert_gru_classifier import BERTBiGRUClassifier
+from bert_lstm_classifier import BERTBiLSTMClassifier
from transformers import AutoTokenizer

CLASS_MAP = {0: "Negative", 1: "Neutral", 2: "Positive" }
-model = BERTBiGRUClassifier.from_pretrained("Amal17/NusaBERT-BiGRU-NusaX-ace")
+
+# Load the tokenizer (the same tokenizer is used for every model)
tokenizer = AutoTokenizer.from_pretrained("LazarusNLP/NusaBERT-large")
-model.eval()

-def run(input):
-    text_tokenized = tokenizer(
-        input,
+# Load models
+bigru_model = BERTBiGRUClassifier.from_pretrained("Amal17/NusaBERT-concate-BiGRU-NusaX-ace")
+bigru_model.eval()
+
+bilstm_model = BERTBiLSTMClassifier.from_pretrained("Amal17/NusaBERT-concate-BiLSTM-NusaX-ace")
+bilstm_model.eval()
+
+# Inference helper
+def predict_with_model(model, text):
+    inputs = tokenizer(
+        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )
-
-    outputs = model(**text_tokenized)
-    logits = outputs['logits']
-    probs = torch.nn.functional.softmax(logits, dim=1)
-    print(probs)
-    preds = torch.argmax(probs, dim=1).tolist()
-    return f"Prediction: Class {preds[0]} ({CLASS_MAP[preds[0]]}) With Probs : {probs[0][preds[0]]}"
+    with torch.no_grad():
+        outputs = model(**inputs)
+    logits = outputs["logits"]
+    probs = torch.softmax(logits, dim=1)
+    pred = torch.argmax(probs, dim=1).item()
+    confidence = probs[0][pred].item()
+    return pred, confidence
+
+# Gradio interface function
+def compare_models(text):
+    pred_a, conf_a = predict_with_model(bigru_model, text)
+    pred_b, conf_b = predict_with_model(bilstm_model, text)
+
+    return (
+        f"BiGRU → Class: {pred_a}", f"Confidence: {conf_a:.4f}",
+        f"BiLSTM → Class: {pred_b}", f"Confidence: {conf_b:.4f}"
+    )

-demo = gr.Interface(fn=run, inputs="text", outputs="text")
-demo.launch()
+# Build the Gradio UI
+interface = gr.Interface(
+    fn=compare_models,
+    inputs=gr.Textbox(label="Input Text"),
+    outputs=[
+        gr.Textbox(label="BiGRU Prediction"),
+        gr.Textbox(label="BiGRU Confidence"),
+        gr.Textbox(label="BiLSTM Prediction"),
+        gr.Textbox(label="BiLSTM Confidence"),
+    ],
+    title="Model Comparison: BiGRU vs BiLSTM"
+)
+
+interface.launch()
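For reference (not part of this commit): a minimal sketch of exercising the new BiLSTM checkpoint outside the Gradio UI, mirroring what predict_with_model does above. The input sentence is a placeholder, and the snippet assumes the Hub checkpoint and tokenizer named in app.py can be downloaded.

# Illustrative smoke test only; mirrors app.py's inference path for the BiLSTM model.
import torch
from transformers import AutoTokenizer
from bert_lstm_classifier import BERTBiLSTMClassifier

CLASS_MAP = {0: "Negative", 1: "Neutral", 2: "Positive"}
tokenizer = AutoTokenizer.from_pretrained("LazarusNLP/NusaBERT-large")
model = BERTBiLSTMClassifier.from_pretrained("Amal17/NusaBERT-concate-BiLSTM-NusaX-ace")
model.eval()

inputs = tokenizer("your sentence here", padding="max_length", truncation=True,
                   max_length=128, return_tensors="pt")
with torch.no_grad():
    probs = torch.softmax(model(**inputs)["logits"], dim=1)
pred = int(probs.argmax(dim=1))
confidence = probs[0][pred].item()
print(f"{CLASS_MAP[pred]} ({confidence:.4f})")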
 
bert_lstm_classifier/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .config import ConfigHybridBiLSTMModel
+from .model import BERTBiLSTMClassifier
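Not part of the commit, but worth noting: since the package re-exports both the config and the model class, the custom architecture could also be registered with transformers' Auto classes so that generic loading code resolves it. A hedged sketch using the public registration API (whether the Space actually needs this is an assumption):

# Illustrative only: make the Auto* factories aware of the custom classes.
from transformers import AutoConfig, AutoModelForSequenceClassification
from bert_lstm_classifier import ConfigHybridBiLSTMModel, BERTBiLSTMClassifier

AutoConfig.register("bert-bilstm", ConfigHybridBiLSTMModel)
AutoModelForSequenceClassification.register(ConfigHybridBiLSTMModel, BERTBiLSTMClassifier)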
bert_lstm_classifier/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (287 Bytes)
 
bert_lstm_classifier/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.12 kB)
 
bert_lstm_classifier/__pycache__/model.cpython-310.pyc ADDED
Binary file (5.59 kB)
 
bert_lstm_classifier/config.py ADDED
@@ -0,0 +1,32 @@
+from transformers import PretrainedConfig
+
+class ConfigHybridBiLSTMModel(PretrainedConfig):
+    model_type = "bert-bilstm"
+
+    def __init__(self,
+                 bert_model_name="bert-base-uncased",
+                 tokenizer_name="bert-base-uncased",
+                 hidden_dim=128,
+                 num_classes=2,
+                 lstm_layers=1,
+                 bidirectional=True,
+                 dropout=0.3,
+                 concat_layers=0,
+                 pooling="last",
+                 freeze_bert=False,
+                 freeze_n_layers=0,        # number of BERT layers to freeze
+                 freeze_from_start=False,  # freeze from the start or from the end
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.bert_model_name = bert_model_name
+        self.tokenizer_name = tokenizer_name
+        self.hidden_dim = hidden_dim
+        self.num_classes = num_classes
+        self.lstm_layers = lstm_layers
+        self.bidirectional = bidirectional
+        self.dropout = dropout
+        self.concat_layers = concat_layers
+        self.pooling = pooling
+        self.freeze_bert = freeze_bert
+        self.freeze_n_layers = freeze_n_layers
+        self.freeze_from_start = freeze_from_start
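Not part of the commit: a short sketch of how this config could be instantiated. The field names and defaults come from the class above; all argument values below are illustrative assumptions, not values read from the published checkpoints.

# Illustrative only: a 3-class BiLSTM head on top of NusaBERT.
from bert_lstm_classifier import ConfigHybridBiLSTMModel

config = ConfigHybridBiLSTMModel(
    bert_model_name="LazarusNLP/NusaBERT-large",  # assumed backbone, as used in app.py
    num_classes=3,          # Negative / Neutral / Positive, matching CLASS_MAP
    hidden_dim=128,
    concat_layers=4,        # assumed: concatenate the last 4 BERT hidden layers
    pooling="mean",         # assumed: the class default is "last"
    freeze_bert=True,
    freeze_n_layers=6,      # assumed: freeze 6 encoder layers
    freeze_from_start=True,
)
print(config.model_type)    # "bert-bilstm"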
bert_lstm_classifier/model.py ADDED
@@ -0,0 +1,175 @@
+import torch
+import torch.nn as nn
+from transformers import BertModel, PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+from .config import ConfigHybridBiLSTMModel
+
+import logging
+
+class BERTBiLSTMClassifier(PreTrainedModel):
+    """
+    BERT + BiLSTM + classifier head for sequence classification tasks.
+    """
+    config_class = ConfigHybridBiLSTMModel
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        # Setup logging configuration
+        logging.basicConfig(level=logging.INFO,  # INFO level for production
+                            format='%(asctime)s - %(levelname)s - %(message)s',
+                            handlers=[logging.StreamHandler()])
+
+        logger = logging.getLogger(__name__)
+        self.logger = logger
+
+        # ===== Load BERT backbone =====
+        self.bert = BertModel.from_pretrained(config.bert_model_name, output_hidden_states=True)
+        logger.info("Model initialized with BERT model: %s", config.bert_model_name)
+        # ===== Freeze BERT parameters if needed =====
+        if config.freeze_bert:
+            assert hasattr(self.bert, "encoder"), "BERT model must have encoder layers"
+            total_layers = len(self.bert.encoder.layer)
+            # Validate freeze_n_layers
+            if config.freeze_n_layers > total_layers or config.freeze_n_layers < 0:
+                raise ValueError(f"freeze_n_layers ({config.freeze_n_layers}) is out of valid range (0-{total_layers})")
+            # Select which layers to freeze
+            if config.freeze_from_start:
+                layers_to_freeze = list(range(config.freeze_n_layers))  # freeze from the start
+                logger.info(f"Freezing the first {config.freeze_n_layers} layers of BERT.")
+            else:
+                layers_to_freeze = list(range(total_layers - config.freeze_n_layers, total_layers))  # freeze from the end
+                logger.info(f"Freezing the last {config.freeze_n_layers} layers of BERT.")
+            # Apply freezing
+            for idx, layer in enumerate(self.bert.encoder.layer):
+                if idx in layers_to_freeze:
+                    for param in layer.parameters():
+                        param.requires_grad = False
+                else:
+                    for param in layer.parameters():
+                        param.requires_grad = True
+        # ===== Define BiLSTM layer =====
+        # Update input_size to account for concatenation:
+        # BERT's hidden_size multiplied by the number of concatenated layers
+        input_size = self.bert.config.hidden_size * config.concat_layers if config.concat_layers > 0 else self.bert.config.hidden_size
+        self.lstm = nn.LSTM(
+            input_size=input_size,
+            hidden_size=config.hidden_dim,
+            num_layers=config.lstm_layers,
+            bidirectional=config.bidirectional,
+            batch_first=True
+        )
+        # ===== Define dropout layer =====
+        self.dropout = nn.Dropout(config.dropout)
+        # ===== Define final classification head =====
+        self.classifier = nn.Linear(
+            config.hidden_dim * 2 if config.bidirectional else config.hidden_dim,
+            config.num_classes
+        )
+        # ===== Define loss function =====
+        self.loss_fn = nn.CrossEntropyLoss()
+        # ===== Print model summary =====
+        self._print_trainable_parameters()
+        self.post_init()
+
+    def _print_trainable_parameters(self):
+        """
+        Utility function to print the number of total and trainable parameters.
+        """
+        total_params = sum(p.numel() for p in self.parameters())
+        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+        print("="*50)
+        print(f"Total Parameters: {total_params:,}")
+        print(f"Trainable Parameters: {trainable_params:,}")
+        print(f"Frozen Parameters: {total_params - trainable_params:,}")
+        print("="*50)
+
+    def _print_named_parameters(self):
+        """
+        Utility function to print each parameter name and whether it's trainable.
+        """
+        print("="*70)
+        print(f"{'Parameter Name':45} | {'Trainable'}")
+        print("-"*70)
+        for name, param in self.named_parameters():
+            print(f"{name:45} | {'Yes' if param.requires_grad else 'No'}")
+        print("="*70)
+
+    def _get_freeze_summary(self):
+        """
+        Returns a summary of frozen and trainable layers in the BERT model.
+        """
+        summary = []
+        for idx, layer in enumerate(self.bert.encoder.layer):
+            layer_info = {
+                "layer": f"bert.encoder.layer.{idx}",
+                "trainable": any(param.requires_grad for param in layer.parameters())
+            }
+            summary.append(layer_info)
+
+        return summary
+
+
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
+        """
+        Forward pass through BERT -> BiLSTM -> Pooling -> Classifier.
+        """
+        # ===== BERT forward pass =====
+        bert_output = self.bert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids
+        )
+        # ===== Handle concatenation of the last hidden states if configured =====
+        if self.config.concat_layers > 0:
+            hidden_states = bert_output.hidden_states
+            concat_layers = min(self.config.concat_layers, len(hidden_states))
+            selected_layers = hidden_states[-concat_layers:]
+            sequence_output = torch.cat(selected_layers, dim=-1)
+        else:
+            sequence_output = bert_output.last_hidden_state
+        # ===== Pass through BiLSTM =====
+        lstm_output, _ = self.lstm(sequence_output)
+        # ===== Pooling layer (CLS / Last / Mean / Max) =====
+        if self.config.pooling == "cls":
+            pooled_output = lstm_output[:, 0, :]
+
+        elif self.config.pooling == "last":
+            pooled_output = lstm_output[:, -1, :]
+
+        elif self.config.pooling == "mean":
+            if attention_mask is not None:
+                mask = attention_mask.unsqueeze(-1).expand(lstm_output.size())
+                masked_output = lstm_output * mask
+                sum_masked_output = masked_output.sum(dim=1)
+                lengths = mask.sum(dim=1).clamp(min=1e-9)
+                pooled_output = sum_masked_output / lengths
+            else:
+                pooled_output = lstm_output.mean(dim=1)
+
+        elif self.config.pooling == "max":
+            if attention_mask is not None:
+                mask = attention_mask.unsqueeze(-1).expand(lstm_output.size())
+                masked_output = lstm_output.masked_fill(mask == 0, -1e9)
+                pooled_output, _ = masked_output.max(dim=1)
+            else:
+                pooled_output, _ = lstm_output.max(dim=1)
+
+        else:
+            raise ValueError(f"Unknown pooling type: {self.config.pooling}")
+
+        # ===== Dropout + Classification Head =====
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        # ===== Compute loss if labels provided =====
+        loss = None
+        if labels is not None:
+            loss = self.loss_fn(logits, labels)
+            return SequenceClassifierOutput(
+                loss=loss,
+                logits=logits
+            )
+        else:
+            return {"logits": logits}
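Finally, also not from the commit: a minimal end-to-end sketch that builds the classifier from a freshly constructed config and runs one forward pass. The backbone name and max_length are assumptions carried over from app.py and the config defaults; with no labels passed, forward returns a plain dict containing "logits".

# Illustrative only: config -> model -> single forward pass.
import torch
from transformers import AutoTokenizer
from bert_lstm_classifier import BERTBiLSTMClassifier, ConfigHybridBiLSTMModel

config = ConfigHybridBiLSTMModel(
    bert_model_name="LazarusNLP/NusaBERT-large",  # assumed BERT-compatible backbone
    num_classes=3,
)
model = BERTBiLSTMClassifier(config)  # downloads the backbone and prints the parameter summary
model.eval()

tokenizer = AutoTokenizer.from_pretrained(config.bert_model_name)
batch = tokenizer("example sentence", padding="max_length", truncation=True,
                  max_length=128, return_tensors="pt")

with torch.no_grad():
    out = model(**batch)       # no labels, so the model returns {"logits": ...}
print(out["logits"].shape)     # expected: torch.Size([1, 3])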