from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers.models.modernbert.modeling_modernbert import (
    ModernBertPredictionHead,
    _pad_modernbert_output,
    _unpad_modernbert_input,
)


class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)  # 2 for start/end position logits
        self.qa_outputs.weight.data.normal_(mean=0.0, std=0.02)
        self.qa_outputs.bias.data.zero_()
        self.drop = nn.Dropout(config.classifier_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        return self.head(output)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() # Get sequence length and batch size if not provided # if batch_size is None or seq_len is None: # batch_size, seq_len = input_ids.shape[:2] # # Handle Flash Attention 2 unpadding # if self.config._attn_implementation == "flash_attention_2": # if indices is None and cu_seqlens is None and max_seqlen is None: # if attention_mask is None: # attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) # with torch.no_grad(): # input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input( # inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids # ) outputs = self.model( input_ids, attention_mask=attention_mask, sliding_window_mask=sliding_window_mask, position_ids=position_ids, indices=indices, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, batch_size=batch_size, seq_len=seq_len, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, ) sequence_output = outputs[0] sequence_output = ( self.drop(self.compiled_head(sequence_output)) if self.config.reference_compile else self.drop(self.head(sequence_output)) ) # sequence_output = self.drop(self.head(sequence_output)) logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) # # Handle Flash Attention 2 padding # if self.config._attn_implementation == "flash_attention_2": # start_logits = _pad_modernbert_output(inputs=start_logits, indices=indices, batch=batch_size, # seqlen=seq_len) # end_logits = _pad_modernbert_output(inputs=end_logits, indices=indices, batch=batch_size, # seqlen=seq_len) total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )