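"""
CubeLM: a GPT-2 based causal language model with optional auxiliary "cube"
state-prediction heads. Depending on the `task` setting, the forward pass
computes the next-token LM loss, a per-head cube-state cross-entropy loss,
or an alpha-weighted sum of both.
"""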
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from .CubeConfig import CubeConfig
from transformers import (
GPT2Model,
GPT2LMHeadModel,
GenerationMixin,
GPT2PreTrainedModel,
PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils import ModelOutput
IGNORE_INDEX = -100
@dataclass
class CubeLMOutput(CausalLMOutputWithCrossAttentions):
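    """
    Output of CubeLM. Extends CausalLMOutputWithCrossAttentions (which already
    defines `loss` and `logits`) with the auxiliary cube loss and logits.
    """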
    # `loss` and `logits` are inherited from the parent output class.
    lm_loss: Optional[torch.FloatTensor] = None
    cube_loss: Optional[torch.FloatTensor] = None
    cube_logits: Optional[torch.FloatTensor] = None
class CubeLM(GPT2LMHeadModel):
def __init__(self, config, task="sft", num_heads=24, num_classes=6):
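        """
        Args:
            config: GPT-2 style config; may define an `alpha` attribute that
                weights the cube loss when both losses are computed.
            task: "sft" (LM loss only), "pretrain" (cube loss only), or
                "joint" (both).
            num_heads: number of auxiliary cube-state classification heads.
            num_classes: number of classes predicted by each head.
        """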
super().__init__(config)
assert task in ["sft", "pretrain", "joint"]
self.task = task
self.alpha = None
if hasattr(config, "alpha"):
self.alpha = config.alpha
self.vocab_size = config.vocab_size
self.cube_heads = None
if task in ["pretrain", "joint"]:
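            # A single shared linear layer that produces num_heads independent
            # num_classes-way classifiers over the hidden states.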
self.cube_heads = nn.Linear(
config.n_embd, num_heads * num_classes, bias=False
)
self.num_heads = num_heads
self.num_classes = num_classes
self.config = config
def forward(
self, input_ids, attention_mask=None, labels=None, cube_states=None, **kwargs
):
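        """
        Args:
            input_ids: [batch, seq] token ids.
            attention_mask: optional [batch, seq] attention mask.
            labels: optional [batch, seq] LM targets; if given, the shifted
                cross-entropy LM loss is computed.
            cube_states: optional [batch, seq, num_heads] integer cube-state
                labels; if given, the per-head cube loss is computed.

        Returns:
            CubeLMOutput with `loss`, `lm_loss`, `cube_loss`, `logits`, and
            `cube_logits`.
        """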
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
)
# [batch, seq, d_model]
hidden_states = outputs.last_hidden_state
lm_logits = None
lm_loss = None
        if self.task in ["sft", "joint"]:
            # [batch, seq, vocab_size]
            lm_logits = self.lm_head(hidden_states)
            if labels is not None:
                # Standard causal-LM shift: each position predicts the next token.
                shift_logits = lm_logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
lm_loss = loss_fn(
shift_logits.view(-1, self.vocab_size),
shift_labels.view(-1),
)
cube_logits = None
cube_loss = None
        if self.cube_heads is not None:
            # [batch, seq, n_heads * n_classes]
            cube_logits = self.cube_heads(hidden_states)
if cube_states is not None:
cube_logits = cube_logits.view(
hidden_states.size(0),
hidden_states.size(1),
self.num_heads,
self.num_classes,
)
# [batch * seq, n_heads, n_classes]
_logits = cube_logits.view(-1, self.num_heads, self.num_classes)
# [batch * seq, n_heads]
_labels = cube_states.view(-1, self.num_heads)
loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
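                # Cross-entropy per cube head (IGNORE_INDEX positions skipped),
                # averaged over all heads.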
losses = []
for head_idx in range(self.num_heads):
losses.append(
loss_fn(_logits[:, head_idx, :], _labels[:, head_idx])
)
cube_loss = sum(losses) / self.num_heads
total_loss = None
if lm_loss is not None and cube_loss is not None:
assert self.alpha is not None
total_loss = lm_loss + self.alpha * cube_loss
elif lm_loss is not None:
total_loss = lm_loss
elif cube_loss is not None:
total_loss = cube_loss
return CubeLMOutput(
loss=total_loss,
lm_loss=lm_loss,
cube_loss=cube_loss,
logits=lm_logits,
cube_logits=cube_logits,
)
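

# Minimal usage sketch. This block is not part of the original upload; the
# config values and random tensors below are illustrative assumptions only.
if __name__ == "__main__":
    from transformers import GPT2Config

    config = GPT2Config(
        vocab_size=64, n_positions=32, n_embd=32, n_layer=2, n_head=2
    )
    config.alpha = 0.5  # weight on the cube loss for the "joint" task
    model = CubeLM(config, task="joint", num_heads=24, num_classes=6)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    cube_states = torch.randint(0, 6, (2, 16, 24))  # [batch, seq, num_heads]
    out = model(input_ids, labels=input_ids, cube_states=cube_states)
    print(out.loss, out.lm_loss, out.cube_loss)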