|
import logging |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
from transformers import BertModel, BertPreTrainedModel, RobertaConfig
# BertOnlyMLMHead (used by Layoutlmv1ForMaskedLM) lives in transformers.modeling_bert
# in the transformers < 4.0 API that this file targets.
from transformers.modeling_bert import BertOnlyMLMHead
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP = {} |
|
|
|
LAYOUTLMV1_PRETRAINED_CONFIG_ARCHIVE_MAP = {} |
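
# Both archive maps are left empty: no hosted pretrained checkpoints are registered
# for this model.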
|
|
|
|
|
class Layoutlmv1Config(RobertaConfig): |
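    """RoBERTa-style configuration extended for layout-aware modelling.

    ``max_2d_position_embeddings`` sets the size of the 2D (x/y/height/width)
    position embedding tables; ``add_linear`` is simply stored on the config and
    is not used elsewhere in this module.
    """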
|
pretrained_config_archive_map = LAYOUTLMV1_PRETRAINED_CONFIG_ARCHIVE_MAP |
|
model_type = "bert" |
|
|
|
    def __init__(self, max_2d_position_embeddings=1024, add_linear=False, **kwargs):
        super().__init__(**kwargs)
        self.max_2d_position_embeddings = max_2d_position_embeddings
        self.add_linear = add_linear
|
|
|
|
|
class Layoutlmv1Embeddings(nn.Module): |
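    """Sums word, 1D position, token-type, and 2D layout embeddings.

    ``bbox`` is expected with shape ``(batch, seq_len, 4)`` holding
    ``(x0, y0, x1, y1)`` per token, with every coordinate (and box width/height)
    non-negative and smaller than ``config.max_2d_position_embeddings``. The
    spatial embeddings are mixed by a small MLP before being added to the text
    embeddings, followed by LayerNorm and dropout.
    """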
|
def __init__(self, config): |
|
super(Layoutlmv1Embeddings, self).__init__() |
|
|
|
self.config = config |
|
|
|
self.word_embeddings = nn.Embedding( |
|
config.vocab_size, config.hidden_size, padding_idx=0 |
|
) |
|
self.position_embeddings = nn.Embedding( |
|
config.max_position_embeddings, config.hidden_size |
|
) |
|
        # Size of the 2D position tables; fall back to 1024 only if the config was
        # created without Layoutlmv1Config (which sets this attribute by default).
        if not hasattr(config, "max_2d_position_embeddings"):
            config.max_2d_position_embeddings = 1024
|
self.x_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.y_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.h_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.w_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.token_type_embeddings = nn.Embedding( |
|
config.type_vocab_size, config.hidden_size |
|
) |
|
|
|
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
        # Two-layer MLP (doc_linear1 -> ReLU -> doc_linear2) applied to the summed
        # spatial embeddings in forward(); doc_linear3 and doc_linear4 are defined
        # but not used in forward().
        self.doc_linear1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.doc_linear2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.doc_linear3 = nn.Linear(config.hidden_size, config.hidden_size)
        self.doc_linear4 = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
self.relu = nn.ReLU() |
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
bbox, |
|
token_type_ids=None, |
|
position_ids=None, |
|
inputs_embeds=None, |
|
): |
|
seq_length = input_ids.size(1) |
|
if position_ids is None: |
|
position_ids = torch.arange( |
|
seq_length, dtype=torch.long, device=input_ids.device |
|
) |
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) |
|
if token_type_ids is None: |
|
token_type_ids = torch.zeros_like(input_ids) |
|
|
|
words_embeddings = self.word_embeddings(input_ids) |
|
position_embeddings = self.position_embeddings(position_ids) |
|
token_type_embeddings = self.token_type_embeddings(token_type_ids) |
|
|
|
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) |
|
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) |
|
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) |
|
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) |
|
h_position_embeddings = self.h_position_embeddings( |
|
bbox[:, :, 3] - bbox[:, :, 1] |
|
) |
|
w_position_embeddings = self.w_position_embeddings( |
|
bbox[:, :, 2] - bbox[:, :, 0] |
|
) |
|
|
|
|
|
temp_embeddings = self.doc_linear2(self.relu(self.doc_linear1( |
|
left_position_embeddings |
|
+ upper_position_embeddings |
|
+ right_position_embeddings |
|
+ lower_position_embeddings |
|
+ h_position_embeddings |
|
+ w_position_embeddings |
|
))) |
|
|
|
embeddings = ( |
|
words_embeddings |
|
+ position_embeddings |
|
+ temp_embeddings |
|
+ token_type_embeddings |
|
) |
|
|
|
embeddings = self.LayerNorm(embeddings) |
|
embeddings = self.dropout(embeddings) |
|
return embeddings |
|
|
|
|
|
class Layoutlmv1Model(BertModel): |
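    """BertModel with its embedding layer replaced by Layoutlmv1Embeddings, so the
    forward pass additionally takes per-token bounding boxes (``bbox``).
    """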
|
|
|
config_class = Layoutlmv1Config |
|
pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP |
|
base_model_prefix = "bert" |
|
|
|
def __init__(self, config): |
|
super(Layoutlmv1Model, self).__init__(config) |
|
self.embeddings = Layoutlmv1Embeddings(config) |
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
bbox, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
): |
|
if attention_mask is None: |
|
attention_mask = torch.ones_like(input_ids) |
|
if token_type_ids is None: |
|
            token_type_ids = torch.zeros_like(input_ids)

        # Broadcast the 2D attention mask to [batch, 1, 1, seq_len] so it can be
        # added to the raw attention scores, then convert it to an additive mask:
        # 0.0 for positions to attend to and -10000.0 for masked positions.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype  # match the model dtype (fp16 safe)
        )
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Expand head_mask (if given) so it broadcasts to
        # [num_hidden_layers, batch, num_heads, seq_len, seq_len];
        # otherwise pass a list with one None per layer.
if head_mask is not None: |
|
if head_mask.dim() == 1: |
|
head_mask = ( |
|
head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) |
|
) |
|
head_mask = head_mask.expand( |
|
self.config.num_hidden_layers, -1, -1, -1, -1 |
|
) |
|
elif head_mask.dim() == 2: |
|
head_mask = ( |
|
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) |
|
) |
|
head_mask = head_mask.to( |
|
dtype=next(self.parameters()).dtype |
|
) |
|
else: |
|
head_mask = [None] * self.config.num_hidden_layers |
|
|
|
embedding_output = self.embeddings( |
|
input_ids, bbox, position_ids=position_ids, token_type_ids=token_type_ids |
|
) |
|
encoder_outputs = self.encoder( |
|
embedding_output, extended_attention_mask, head_mask=head_mask |
|
) |
|
sequence_output = encoder_outputs[0] |
|
pooled_output = self.pooler(sequence_output) |
|
|
|
        outputs = (sequence_output, pooled_output) + encoder_outputs[1:]
|
return outputs |
|
|
|
|
|
class Layoutlmv1ForTokenClassification(BertPreTrainedModel): |
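    """Layoutlmv1Model backbone (exposed as ``self.roberta``) with a dropout +
    linear token-classification head; when ``labels`` are given, a cross-entropy
    loss restricted to positions where ``attention_mask == 1`` is prepended to
    the outputs.
    """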
|
config_class = Layoutlmv1Config |
|
pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP |
|
base_model_prefix = "bert" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
self.roberta = Layoutlmv1Model(config) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
bbox, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
labels=None, |
|
): |
|
|
|
outputs = self.roberta( |
|
input_ids=input_ids, |
|
bbox=bbox, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
|
|
sequence_output = self.dropout(sequence_output) |
|
logits = self.classifier(sequence_output) |
|
|
|
        outputs = (logits,) + outputs[2:]
|
if labels is not None: |
|
loss_fct = CrossEntropyLoss() |
|
|
|
if attention_mask is not None: |
|
active_loss = attention_mask.view(-1) == 1 |
|
active_logits = logits.view(-1, self.num_labels)[active_loss] |
|
active_labels = labels.view(-1)[active_loss] |
|
loss = loss_fct(active_logits, active_labels) |
|
else: |
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
outputs = (loss,) + outputs |
|
|
|
return outputs |
|
|
|
|
|
class Layoutlmv1ForMaskedLM(BertPreTrainedModel): |
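    """Layoutlmv1Model backbone (exposed as ``self.bert``) with the BERT MLM head;
    when ``masked_lm_labels`` are given, the masked-LM cross-entropy loss is
    prepended to the outputs.
    """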
|
config_class = Layoutlmv1Config |
|
pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP |
|
base_model_prefix = "bert" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
self.bert = Layoutlmv1Model(config) |
|
self.cls = BertOnlyMLMHead(config) |
|
|
|
self.init_weights() |
|
|
|
def get_input_embeddings(self): |
|
return self.bert.embeddings.word_embeddings |
|
|
|
def get_output_embeddings(self): |
|
return self.cls.predictions.decoder |
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
bbox, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
masked_lm_labels=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
lm_labels=None, |
|
): |
|
|
|
        outputs = self.bert(
|
input_ids, |
|
bbox, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
prediction_scores = self.cls(sequence_output) |
|
|
|
        outputs = (prediction_scores,) + outputs[2:]

if masked_lm_labels is not None: |
|
loss_fct = CrossEntropyLoss() |
|
masked_lm_loss = loss_fct( |
|
prediction_scores.view(-1, self.config.vocab_size), |
|
masked_lm_labels.view(-1), |
|
) |
|
outputs = (masked_lm_loss,) + outputs |
|
        return outputs
|
|
|
|
|
|
|
class Layoutlmv1ForQuestionAnswering(BertPreTrainedModel): |
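    """Layoutlmv1Model backbone with a span-classification head that produces start
    and end logits; when ``start_positions``/``end_positions`` are given, the
    averaged start/end cross-entropy loss is prepended to the outputs.
    """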
|
config_class = Layoutlmv1Config |
|
pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP |
|
base_model_prefix = "bert" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.bert = Layoutlmv1Model(config) |
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
bbox, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
        head_mask=None,
        start_positions=None,
        end_positions=None,
    ):
r""" |
|
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): |
|
Labels for position (index) of the start of the labelled span for computing the token classification loss. |
|
Positions are clamped to the length of the sequence (`sequence_length`). |
|
            Positions outside of the sequence are not taken into account for computing the loss.
|
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): |
|
Labels for position (index) of the end of the labelled span for computing the token classification loss. |
|
Positions are clamped to the length of the sequence (`sequence_length`). |
|
            Positions outside of the sequence are not taken into account for computing the loss.
|
""" |
|
|
|
|
|
outputs = self.bert( |
|
input_ids=input_ids, |
|
bbox=bbox, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
|
|
logits = self.qa_outputs(sequence_output) |
|
start_logits, end_logits = logits.split(1, dim=-1) |
|
start_logits = start_logits.squeeze(-1) |
|
end_logits = end_logits.squeeze(-1) |
|
|
|
total_loss = None |
|
if start_positions is not None and end_positions is not None: |
|
|
|
if len(start_positions.size()) > 1: |
|
start_positions = start_positions.squeeze(-1) |
|
if len(end_positions.size()) > 1: |
|
end_positions = end_positions.squeeze(-1) |
|
|
|
            # Out-of-range positions are clamped to sequence_length and then ignored
            # by the loss via ignore_index.
            ignored_index = start_logits.size(1)
|
start_positions.clamp_(0, ignored_index) |
|
end_positions.clamp_(0, ignored_index) |
|
|
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) |
|
start_loss = loss_fct(start_logits, start_positions) |
|
end_loss = loss_fct(end_logits, end_positions) |
|
total_loss = (start_loss + end_loss) / 2 |
|
|
|
|
|
output = (start_logits, end_logits) + outputs[2:] |
|
|
|
|
|
return ((total_loss,) + output) if total_loss is not None else output |
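

# Minimal smoke-test sketch (illustrative values only): build a tiny random config
# and run the token-classification head once. Bounding-box coordinates must be
# non-negative and below config.max_2d_position_embeddings (1024 by default).
if __name__ == "__main__":
    config = Layoutlmv1Config(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=3,
    )
    model = Layoutlmv1ForTokenClassification(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    # Random boxes, sorted so that x0 <= x1 and y0 <= y1 (width/height stay >= 0).
    corners = torch.randint(0, 1000, (1, 8, 2, 2))
    corners, _ = torch.sort(corners, dim=2)
    bbox = corners.view(1, 8, 4)

    with torch.no_grad():
        logits = model(input_ids=input_ids, bbox=bbox)[0]
    print(logits.shape)  # expected: torch.Size([1, 8, 3])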
|
|