# dictabert-large-char-menaked / BertForDiacritization.py

import re
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.utils import ModelOutput
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

# MAT_LECT => Matres Lectionis, known in Hebrew as Em Kriaa.
MAT_LECT_TOKEN = '<MAT_LECT>'
NIKUD_CLASSES = ['', MAT_LECT_TOKEN, '\u05BC', '\u05B0', '\u05B1', '\u05B2', '\u05B3', '\u05B4', '\u05B5', '\u05B6', '\u05B7', '\u05B8', '\u05B9', '\u05BA', '\u05BB', '\u05BC\u05B0', '\u05BC\u05B1', '\u05BC\u05B2', '\u05BC\u05B3', '\u05BC\u05B4', '\u05BC\u05B5', '\u05BC\u05B6', '\u05BC\u05B7', '\u05BC\u05B8', '\u05BC\u05B9', '\u05BC\u05BA', '\u05BC\u05BB', '\u05C7', '\u05BC\u05C7']
SHIN_CLASSES = ['\u05C1', '\u05C2'] # shin, sin
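# For reference: U+05B0..U+05BB above are the niqqud vowel points (sheva, hataf segol,
# hataf patah, hataf qamats, hiriq, tsere, segol, patah, qamats, holam, holam haser
# for vav, qubuts), U+05BC is dagesh/mapiq, and U+05C7 is qamats qatan; each vowel can
# be predicted alone or combined with a dagesh (the '\u05BC…' entries).
# U+05C1/U+05C2 are the shin and sin dots.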


@dataclass
class MenakedLogitsOutput(ModelOutput):
    nikud_logits: torch.FloatTensor = None
    shin_logits: torch.FloatTensor = None

    def detach(self):
        return MenakedLogitsOutput(self.nikud_logits.detach(), self.shin_logits.detach())


@dataclass
class MenakedOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[MenakedLogitsOutput] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MenakedLabels(ModelOutput):
    nikud_labels: Optional[torch.FloatTensor] = None
    shin_labels: Optional[torch.FloatTensor] = None

    def detach(self):
        return MenakedLabels(self.nikud_labels.detach(), self.shin_labels.detach())

    def to(self, device):
        return MenakedLabels(self.nikud_labels.to(device), self.shin_labels.to(device))


class BertMenakedHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        if not hasattr(config, 'nikud_classes'):
            config.nikud_classes = NIKUD_CLASSES
            config.shin_classes = SHIN_CLASSES
            config.mat_lect_token = MAT_LECT_TOKEN

        self.num_nikud_classes = len(config.nikud_classes)
        self.num_shin_classes = len(config.shin_classes)

        # create our classifiers
        self.nikud_cls = nn.Linear(config.hidden_size, self.num_nikud_classes)
        self.shin_cls = nn.Linear(config.hidden_size, self.num_shin_classes)

    def forward(
        self,
        hidden_states: torch.Tensor,
        labels: Optional[MenakedLabels] = None,
    ):
        # run each of the classifiers on the transformed output
        nikud_logits = self.nikud_cls(hidden_states)
        shin_logits = self.shin_cls(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(nikud_logits.view(-1, self.num_nikud_classes), labels.nikud_labels.view(-1))
            loss += loss_fct(shin_logits.view(-1, self.num_shin_classes), labels.shin_labels.view(-1))

        return loss, MenakedLogitsOutput(nikud_logits, shin_logits)
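
# For reference: given hidden_states of shape (batch, seq_len, hidden_size), the head
# above produces nikud_logits of shape (batch, seq_len, len(nikud_classes)) and
# shin_logits of shape (batch, seq_len, len(shin_classes)); when labels are provided,
# the two per-character cross-entropy losses are summed into a single scalar.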


class BertForDiacritization(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.bert = BertModel(config, add_pooling_layer=False)

        classifier_dropout = config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.menaked = BertMenakedHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[MenakedLabels] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MenakedOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = bert_outputs[0]
        hidden_states = self.dropout(hidden_states)

        loss, logits = self.menaked(hidden_states, labels)

        if not return_dict:
            return (loss, logits) + bert_outputs[2:]

        return MenakedOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
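
    # Inference helper: predict() below strips any existing niqqud from the input
    # strings, tokenizes them character-by-character with an offset mapping, takes
    # the argmax of each classifier per character, and re-inserts the predicted
    # shin/sin dot and niqqud mark after every Hebrew letter.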
    def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, mark_matres_lectionis: Optional[str] = None, padding='longest'):
        sentences = [remove_nikkud(sentence) for sentence in sentences]
        # assert the lengths aren't out of range
        assert all(len(sentence) + 2 <= tokenizer.model_max_length for sentence in sentences), \
            f'All sentences must be at most {tokenizer.model_max_length - 2} characters, please segment and try again'

        # tokenize the inputs and convert them to the relevant device
        inputs = tokenizer(sentences, padding=padding, truncation=True, return_tensors='pt', return_offsets_mapping=True)
        offset_mapping = inputs.pop('offset_mapping')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # calculate the predictions
        logits = self.forward(**inputs, return_dict=True).logits
        nikud_predictions = logits.nikud_logits.argmax(dim=-1).tolist()
        shin_predictions = logits.shin_logits.argmax(dim=-1).tolist()

        ret = []
        for sent_idx, (sentence, sent_offsets) in enumerate(zip(sentences, offset_mapping)):
            # assign the nikud to each letter!
            output = []
            prev_index = 0
            for idx, offsets in enumerate(sent_offsets):
                # add in anything we missed
                if offsets[0] > prev_index:
                    output.append(sentence[prev_index:offsets[0]])
                if offsets[1] - offsets[0] != 1:
                    continue

                # get our next char
                char = sentence[offsets[0]:offsets[1]]
                prev_index = offsets[1]

                if not is_hebrew_letter(char):
                    output.append(char)
                    continue

                nikud = self.config.nikud_classes[nikud_predictions[sent_idx][idx]]
                shin = '' if char != 'ש' else self.config.shin_classes[shin_predictions[sent_idx][idx]]

                # check for matres lectionis
                if nikud == self.config.mat_lect_token:
                    if not is_matres_letter(char):
                        nikud = ''  # don't allow matres on irrelevant letters
                    elif mark_matres_lectionis is not None:
                        nikud = mark_matres_lectionis
                    else:
                        continue

                output.append(char + shin + nikud)

            output.append(sentence[prev_index:])
            ret.append(''.join(output))

        return ret


ALEF_ORD = ord('א')
TAF_ORD = ord('ת')


def is_hebrew_letter(char):
    # the base Hebrew letters occupy the contiguous range U+05D0 (alef) .. U+05EA (tav)
    return ALEF_ORD <= ord(char) <= TAF_ORD


MATRES_LETTERS = list('אוי')


def is_matres_letter(char):
    return char in MATRES_LETTERS


nikud_pattern = re.compile(r'[\u05B0-\u05BD\u05C1\u05C2\u05C7]')


def remove_nikkud(text):
    return nikud_pattern.sub('', text)
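

if __name__ == '__main__':
    # Minimal usage sketch (not part of the model code itself). It assumes this file is
    # shipped as custom model code on the Hugging Face Hub; the repo id
    # 'dicta-il/dictabert-large-char-menaked' and the example sentence are assumptions,
    # not taken from this file.
    from transformers import AutoModel, AutoTokenizer

    repo_id = 'dicta-il/dictabert-large-char-menaked'  # assumed repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    with torch.no_grad():
        # predict() strips any existing niqqud and returns the sentences with the
        # predicted marks re-inserted after each Hebrew letter.
        print(model.predict(['שלום עולם'], tokenizer))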