"""CubeLM: GPT-2 language model with auxiliary cube-state prediction heads."""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .CubeConfig import CubeConfig

# Label value that nn.CrossEntropyLoss skips (the PyTorch default ignore_index).
IGNORE_INDEX = -100


@dataclass
class CubeLMOutput(CausalLMOutputWithCrossAttentions):
    """Causal-LM output extended with the auxiliary cube objective.

    `loss` (inherited) carries the combined training loss; the individual
    loss terms and the cube-head logits are exposed alongside it.
    """

    lm_loss: Optional[torch.FloatTensor] = None
    cube_loss: Optional[torch.FloatTensor] = None
    cube_logits: Optional[torch.FloatTensor] = None


class CubeLM(GPT2LMHeadModel):
    """GPT-2 language model with optional cube-state classification heads."""

    def __init__(self, config, task="sft", num_heads=24, num_classes=6):
        super().__init__(config)
        assert task in ["sft", "pretrain", "joint"], f"unknown task: {task}"

        self.task = task
        # Weight of the auxiliary cube loss in the joint objective, taken
        # from the config when present.
        self.alpha = getattr(config, "alpha", None)
        self.vocab_size = config.vocab_size

        # A single linear projection implements num_heads independent
        # num_classes-way classifiers over each hidden state.
        self.cube_heads = None
        if task in ["pretrain", "joint"]:
            self.cube_heads = nn.Linear(
                config.n_embd, num_heads * num_classes, bias=False
            )
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.config = config

    def forward(
        self, input_ids, attention_mask=None, labels=None, cube_states=None, **kwargs
    ):
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )

        hidden_states = outputs.last_hidden_state

        lm_logits = None
        lm_loss = None

        if self.task in ["sft", "joint"]:
            lm_logits = self.lm_head(hidden_states)

            if labels is not None:
                # Standard causal-LM shift: position t predicts token t + 1.
                # Compute the loss against the provided labels (which may
                # mask positions with IGNORE_INDEX), not the raw input_ids.
                shift_logits = lm_logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
                loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
                lm_loss = loss_fn(
                    shift_logits.view(-1, self.vocab_size),
                    shift_labels.view(-1),
                )

        cube_logits = None
        cube_loss = None
        if self.cube_heads is not None:
            # Reshape to (batch, seq, num_heads, num_classes) so every head
            # has its own num_classes-way distribution.
            cube_logits = self.cube_heads(hidden_states).view(
                hidden_states.size(0),
                hidden_states.size(1),
                self.num_heads,
                self.num_classes,
            )

            if cube_states is not None:
                _logits = cube_logits.view(-1, self.num_heads, self.num_classes)
                _labels = cube_states.view(-1, self.num_heads)
                loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
                # Average the cross-entropy loss over the heads.
                losses = [
                    loss_fn(_logits[:, head_idx, :], _labels[:, head_idx])
                    for head_idx in range(self.num_heads)
                ]
                cube_loss = sum(losses) / self.num_heads

        # Combine the objectives: the joint task weights the cube loss by
        # alpha; otherwise whichever single loss was computed is returned.
        total_loss = None
        if lm_loss is not None and cube_loss is not None:
            assert self.alpha is not None, "joint training requires config.alpha"
            total_loss = lm_loss + self.alpha * cube_loss
        elif lm_loss is not None:
            total_loss = lm_loss
        elif cube_loss is not None:
            total_loss = cube_loss

        return CubeLMOutput(
            loss=total_loss,
            lm_loss=lm_loss,
            cube_loss=cube_loss,
            logits=lm_logits,
            cube_logits=cube_logits,
        )
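

# A minimal forward-pass smoke test, as a sketch: the GPT2Config sizes, the
# alpha value, and the (batch, seq, num_heads) layout of `cube_states` are
# illustrative assumptions, not values taken from the actual training setup.
# Run as `python -m <package>.<module>` so the relative import above resolves.
if __name__ == "__main__":
    from transformers import GPT2Config

    config = GPT2Config(
        vocab_size=128, n_positions=64, n_embd=48, n_layer=2, n_head=4
    )
    config.alpha = 0.5  # assumed weight for the joint objective
    model = CubeLM(config, task="joint")

    batch, seq = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch, seq))
    labels = input_ids.clone()
    # One 6-way cube-state label per head at every position (assumed layout).
    cube_states = torch.randint(0, 6, (batch, seq, model.num_heads))

    out = model(input_ids=input_ids, labels=labels, cube_states=cube_states)
    print(out.loss, out.lm_loss, out.cube_loss)
    print(out.logits.shape, out.cube_logits.shape)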