Shaltiel committed
Commit 76546bb (verified) · 1 parent: e17a555

Upload BertForDiacritization.py

Files changed (1)
  1. BertForDiacritization.py +190 -0
BertForDiacritization.py ADDED
@@ -0,0 +1,190 @@
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+ import torch
+ from torch import nn
+ from transformers.utils import ModelOutput
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
+
+ # MAT_LECT => Matres Lectionis, known in Hebrew as Em Kriaa.
+ MAT_LECT_TOKEN = '<MAT_LECT>'
+ NIKUD_CLASSES = ['', MAT_LECT_TOKEN, '\u05BC', '\u05B0', '\u05B1', '\u05B2', '\u05B3', '\u05B4', '\u05B5', '\u05B6', '\u05B7', '\u05B8', '\u05B9', '\u05BA', '\u05BB', '\u05BC\u05B0', '\u05BC\u05B1', '\u05BC\u05B2', '\u05BC\u05B3', '\u05BC\u05B4', '\u05BC\u05B5', '\u05BC\u05B6', '\u05BC\u05B7', '\u05BC\u05B8', '\u05BC\u05B9', '\u05BC\u05BA', '\u05BC\u05BB', '\u05C7', '\u05BC\u05C7']
+ SHIN_CLASSES = ['\u05C1', '\u05C2'] # shin, sin
+
+ @dataclass
+ class MenakedLogitsOutput(ModelOutput):
+     nikud_logits: torch.FloatTensor = None
+     shin_logits: torch.FloatTensor = None
+
+     def detach(self):
+         return MenakedLogitsOutput(self.nikud_logits.detach(), self.shin_logits.detach())
+
+ @dataclass
+ class MenakedOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[MenakedLogitsOutput] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+ @dataclass
+ class MenakedLabels(ModelOutput):
+     nikud_labels: Optional[torch.FloatTensor] = None
+     shin_labels: Optional[torch.FloatTensor] = None
+
+     def detach(self):
+         return MenakedLabels(self.nikud_labels.detach(), self.shin_labels.detach())
+
+     def to(self, device):
+         return MenakedLabels(self.nikud_labels.to(device), self.shin_labels.to(device))
+
+ class BertMenakedHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         if not hasattr(config, 'nikud_classes'):
+             config.nikud_classes = NIKUD_CLASSES
+             config.shin_classes = SHIN_CLASSES
+             config.mat_lect_token = MAT_LECT_TOKEN
+
+         self.num_nikud_classes = len(config.nikud_classes)
+         self.num_shin_classes = len(config.shin_classes)
+
+         # create our classifiers
+         self.nikud_cls = nn.Linear(config.hidden_size, self.num_nikud_classes)
+         self.shin_cls = nn.Linear(config.hidden_size, self.num_shin_classes)
+
+     def forward(
+             self,
+             hidden_states: torch.Tensor,
+             labels: Optional[MenakedLabels] = None):
+
+         # run each of the classifiers on the transformed output
+         nikud_logits = self.nikud_cls(hidden_states)
+         shin_logits = self.shin_cls(hidden_states)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(nikud_logits.view(-1, self.num_nikud_classes), labels.nikud_labels.view(-1))
+             loss += loss_fct(shin_logits.view(-1, self.num_shin_classes), labels.shin_labels.view(-1))
+
+         return loss, MenakedLogitsOutput(nikud_logits, shin_logits)
+
+ class BertForDiacritization(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.bert = BertModel(config, add_pooling_layer=False)
+
+         classifier_dropout = config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         self.dropout = nn.Dropout(classifier_dropout)
+
+         self.menaked = BertMenakedHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], MenakedOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         bert_outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = bert_outputs[0]
+         hidden_states = self.dropout(hidden_states)
+
+         loss, logits = self.menaked(hidden_states, labels)
+
+         if not return_dict:
+             return (loss, logits) + bert_outputs[2:]
+
+         return MenakedOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=bert_outputs.hidden_states,
+             attentions=bert_outputs.attentions,
+         )
+
+     def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, mark_matres_lectionis: str = None, padding='longest'):
+         sentences = [remove_nikkud(sentence) for sentence in sentences]
+         # assert the lengths aren't out of range
+         assert all(len(sentence) + 2 <= tokenizer.model_max_length for sentence in sentences), f'All sentences must be <= {tokenizer.model_max_length}, please segment and try again'
+
+         # tokenize the inputs and convert them to relevant device
+         inputs = tokenizer(sentences, padding=padding, truncation=True, return_tensors='pt', return_offsets_mapping=True)
+         offset_mapping = inputs.pop('offset_mapping')
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # calculate the predictions
+         logits = self.forward(**inputs, return_dict=True).logits
+         nikud_predictions = logits.nikud_logits.argmax(dim=-1).tolist()
+         shin_predictions = logits.shin_logits.argmax(dim=-1).tolist()
+
+         ret = []
+         for sent_idx, (sentence, sent_offsets) in enumerate(zip(sentences, offset_mapping)):
+             # assign the nikud to each letter!
+             output = []
+             prev_index = 0
+             for idx, offsets in enumerate(sent_offsets):
+                 # add in anything we missed
+                 if offsets[0] > prev_index:
+                     output.append(sentence[prev_index:offsets[0]])
+                 if offsets[1] - offsets[0] != 1: continue
+
+                 # get our next char
+                 char = sentence[offsets[0]:offsets[1]]
+                 prev_index = offsets[1]
+                 if not is_hebrew_letter(char):
+                     output.append(char)
+                     continue
+
+                 nikud = self.config.nikud_classes[nikud_predictions[sent_idx][idx]]
+                 shin = '' if char != 'ש' else self.config.shin_classes[shin_predictions[sent_idx][idx]]
+
+                 # check for matres lectionis
+                 if nikud == self.config.mat_lect_token:
+                     if not is_matres_letter(char): nikud = '' # don't allow matres on irrelevant letters
+                     elif mark_matres_lectionis is not None: nikud = mark_matres_lectionis
+                     else: nikud = '' # no marker requested; keep the matres-lectionis letter bare
+
+                 output.append(char + shin + nikud)
+             output.append(sentence[prev_index:])
+             ret.append(''.join(output))
+
+         return ret
+
+ ALEF_ORD = ord('א')
+ TAF_ORD = ord('ת')
+ def is_hebrew_letter(char):
+     return ALEF_ORD <= ord(char) <= TAF_ORD
+
+ MATRES_LETTERS = list('אוי')
+ def is_matres_letter(char):
+     return char in MATRES_LETTERS
+
+ import re
+ nikud_pattern = re.compile(r'[\u05B0-\u05BD\u05C1\u05C2\u05C7]')
+ def remove_nikkud(text):
+     return nikud_pattern.sub('', text)
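
For reference, a minimal usage sketch of the predict API defined above. The repository id below is a placeholder (not part of this commit), and loading through AutoModel with trust_remote_code=True assumes the hosting repo maps its architecture to this BertForDiacritization class via auto_map; adapt both to the actual checkpoint.

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = 'org-name/hebrew-diacritization-model'  # placeholder repo id, an assumption for illustration
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()  # disable dropout for inference

sentences = ['שלום עולם']  # inputs may already contain nikud; predict() strips it first
with torch.no_grad():
    vocalized = model.predict(sentences, tokenizer, mark_matres_lectionis=None)
print(vocalized)  # list of sentences with the predicted nikud / shin-sin marks attached to each letter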