import re
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers import BertModel, BertPreTrainedModel, BertTokenizerFast
from transformers.utils import ModelOutput

# MAT_LECT => Matres Lectionis, known in Hebrew as Em Kriaa.
MAT_LECT_TOKEN = '<MAT_LECT>'
# Nikud classes: no mark, the matres lectionis marker, a dagesh (U+05BC), the
# vowel points U+05B0..U+05BB, each vowel point combined with a dagesh, and
# qamats qatan (U+05C7) with and without a dagesh.
NIKUD_CLASSES = ['', MAT_LECT_TOKEN, '\u05BC', '\u05B0', '\u05B1', '\u05B2', '\u05B3', '\u05B4', '\u05B5', '\u05B6', '\u05B7', '\u05B8', '\u05B9', '\u05BA', '\u05BB', '\u05BC\u05B0', '\u05BC\u05B1', '\u05BC\u05B2', '\u05BC\u05B3', '\u05BC\u05B4', '\u05BC\u05B5', '\u05BC\u05B6', '\u05BC\u05B7', '\u05BC\u05B8', '\u05BC\u05B9', '\u05BC\u05BA', '\u05BC\u05BB', '\u05C7', '\u05BC\u05C7']
SHIN_CLASSES = ['\u05C1', '\u05C2']  # shin dot (U+05C1), sin dot (U+05C2)
        
@dataclass
class MenakedLogitsOutput(ModelOutput):
    nikud_logits: torch.FloatTensor = None
    shin_logits: torch.FloatTensor = None
    
    def detach(self):
        return MenakedLogitsOutput(self.nikud_logits.detach(), self.shin_logits.detach())

@dataclass
class MenakedOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[MenakedLogitsOutput] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class MenakedLabels(ModelOutput):
    nikud_labels: Optional[torch.LongTensor] = None
    shin_labels: Optional[torch.LongTensor] = None
    
    def detach(self):
        return MenakedLabels(self.nikud_labels.detach(), self.shin_labels.detach())
    
    def to(self, device):
        return MenakedLabels(self.nikud_labels.to(device), self.shin_labels.to(device))
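
# Label tensors hold one class index per token position. A minimal sketch of
# building padded labels (hypothetical shapes; -100 is CrossEntropyLoss's
# default ignore_index, so those positions are excluded from the loss):
#
#   labels = MenakedLabels(
#       nikud_labels=torch.full((batch_size, seq_len), -100, dtype=torch.long),
#       shin_labels=torch.full((batch_size, seq_len), -100, dtype=torch.long),
#   )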

class BertMenakedHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # fall back to the default class inventories when the config doesn't define them
        if not hasattr(config, 'nikud_classes'):
            config.nikud_classes = NIKUD_CLASSES
            config.shin_classes = SHIN_CLASSES
            config.mat_lect_token = MAT_LECT_TOKEN
        
        self.num_nikud_classes = len(config.nikud_classes)
        self.num_shin_classes = len(config.shin_classes)
        
        # create our classifiers
        self.nikud_cls = nn.Linear(config.hidden_size, self.num_nikud_classes)
        self.shin_cls = nn.Linear(config.hidden_size, self.num_shin_classes)
        
    def forward(
            self,
            hidden_states: torch.Tensor,
            labels: Optional[MenakedLabels] = None):
        
        # run each of the classifiers on the shared hidden states
        nikud_logits = self.nikud_cls(hidden_states)
        shin_logits = self.shin_cls(hidden_states)
        
        loss = None
        if labels is not None:
            # sum the two cross-entropy losses; positions labeled with the
            # default ignore_index (-100) don't contribute to either term
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(nikud_logits.view(-1, self.num_nikud_classes), labels.nikud_labels.view(-1))
            loss += loss_fct(shin_logits.view(-1, self.num_shin_classes), labels.shin_labels.view(-1))
        
        return loss, MenakedLogitsOutput(nikud_logits, shin_logits)
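
# A minimal shape check for the head in isolation (a sketch; assumes default
# BertConfig values, which this file does not otherwise depend on):
#
#   from transformers import BertConfig
#   head = BertMenakedHead(BertConfig())       # hidden_size defaults to 768
#   _, out = head(torch.zeros(2, 16, 768))     # labels=None, so loss is None
#   out.nikud_logits.shape                     # torch.Size([2, 16, 29])
#   out.shin_logits.shape                      # torch.Size([2, 16, 2])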
        
class BertForDiacritization(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        self.menaked = BertMenakedHead(config)
        
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[MenakedLabels] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MenakedOutput]:
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = bert_outputs[0]
        hidden_states = self.dropout(hidden_states)

        loss, logits = self.menaked(hidden_states, labels)
        
        if not return_dict:
            return (loss, logits) + bert_outputs[2:]
        
        return MenakedOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
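
    # A training-style call (a sketch; `batch` and the label tensors are
    # hypothetical and must match the tokenizer's output shapes):
    #
    #   out = model(input_ids=batch['input_ids'],
    #               attention_mask=batch['attention_mask'],
    #               labels=MenakedLabels(nikud_labels=..., shin_labels=...))
    #   out.loss.backward()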
    
    def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, mark_matres_lectionis: Optional[str] = None, padding='longest'):
        sentences = [remove_nikkud(sentence) for sentence in sentences]
        # make sure the lengths aren't out of range; the +2 accounts for the [CLS] and [SEP] tokens
        assert all(len(sentence) + 2 <= tokenizer.model_max_length for sentence in sentences), \
            f'All sentences must be at most {tokenizer.model_max_length - 2} characters long, please segment and try again'

        # tokenize the inputs and move them to the model's device
        inputs = tokenizer(sentences, padding=padding, truncation=True, return_tensors='pt', return_offsets_mapping=True)
        offset_mapping = inputs.pop('offset_mapping')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # calculate the predictions; no gradients are needed at inference time
        with torch.no_grad():
            logits = self.forward(**inputs, return_dict=True).logits
        nikud_predictions = logits.nikud_logits.argmax(dim=-1).tolist()
        shin_predictions = logits.shin_logits.argmax(dim=-1).tolist()

        ret = []
        for sent_idx, (sentence, sent_offsets) in enumerate(zip(sentences, offset_mapping)):
            # assign the nikud to each letter
            output = []
            prev_index = 0
            for idx, offsets in enumerate(sent_offsets):
                # add in anything the tokenizer skipped over (e.g. whitespace)
                if offsets[0] > prev_index:
                    output.append(sentence[prev_index:offsets[0]])
                # skip special tokens and any token that doesn't map to exactly one character
                if offsets[1] - offsets[0] != 1:
                    continue
                
                # get our next char
                char = sentence[offsets[0]:offsets[1]]
                prev_index = offsets[1]
                if not is_hebrew_letter(char):
                    output.append(char)
                    continue
                
                nikud = self.config.nikud_classes[nikud_predictions[sent_idx][idx]]
                shin = '' if char != 'ש' else self.config.shin_classes[shin_predictions[sent_idx][idx]] 

                # check for matres lectionis
                if nikud == self.config.mat_lect_token:
                    if not is_matres_letter(char):
                        nikud = ''  # don't allow matres on irrelevant letters
                    elif mark_matres_lectionis is not None:
                        nikud = mark_matres_lectionis
                    else:
                        nikud = ''  # keep the bare letter rather than dropping it from the output

                output.append(char + shin + nikud)
            output.append(sentence[prev_index:])
            ret.append(''.join(output))
        
        return ret

ALEF_ORD = ord('א')
TAF_ORD = ord('ת')
def is_hebrew_letter(char):
    return ALEF_ORD <= ord(char) <= TAF_ORD

MATRES_LETTERS = list('אוי')
def is_matres_letter(char):
    return char in MATRES_LETTERS

nikud_pattern = re.compile(r'[\u05B0-\u05BD\u05C1\u05C2\u05C7]')
def remove_nikkud(text):
    return nikud_pattern.sub('', text)
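
# Example usage (a sketch; '<checkpoint>' is a placeholder for a compatible
# character-level checkpoint, not something defined in this file):
#
#   from transformers import BertTokenizerFast
#   tokenizer = BertTokenizerFast.from_pretrained('<checkpoint>')
#   model = BertForDiacritization.from_pretrained('<checkpoint>')
#   model.eval()
#   print(model.predict(['שלום עולם'], tokenizer))
#   # pass mark_matres_lectionis='*' to mark matres lectionis letters explicitly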