from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss, ModuleList
from torch.nn.functional import log_softmax

from .configuration_custom4 import CustomConfig4

|
class CustomModel4(PreTrainedModel):
    """Causal LM that ensembles several pretrained causal LMs by taking a
    coefficient-weighted sum of their per-token log-probabilities."""

    config_class = CustomConfig4

    def __init__(self, config):
        super().__init__(config)

    def forward(self, *args, labels=None, **kwargs):
        loss = None
        logits = None
        # Run every sub-model and combine them in log-probability space; the
        # returned "logits" are the weighted sum of each model's log_softmax.
        for model, coeff in zip(self.models, self.coeffs):
            logp = log_softmax(model.forward(*args, **kwargs).logits, dim=-1)
            logits = coeff * logp if logits is None else logits + coeff * logp

        if labels is not None:
            # Standard causal-LM loss: shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @classmethod
    def combine_models(cls, *args, coeffs=None, **kwargs):
        """Build an ensemble from the checkpoint names/paths given in *args.

        coeffs holds one weight per model; when omitted, all models are
        weighted equally with 1/len(args).
        """
        models = []
        for model in args:
            models.append(AutoModelForCausalLM.from_pretrained(model, **kwargs).eval())
        if not coeffs:
            coeffs = [1 / len(args)] * len(args)
        m = cls(models[0].config)
        m.models = ModuleList(models)
        m.coeffs = coeffs
        return m

# Register the custom classes so that AutoConfig / AutoModel /
# AutoModelForCausalLM resolve the "custom4" model type, and so the class
# definitions travel with checkpoints pushed to the Hub.
CustomConfig4.register_for_auto_class()
CustomModel4.register_for_auto_class('AutoModelForCausalLM')
CustomModel4.register_for_auto_class('AutoModel')
AutoConfig.register("custom4", CustomConfig4)
AutoModel.register(CustomConfig4, CustomModel4)
AutoModelForCausalLM.register(CustomConfig4, CustomModel4)
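
# Minimal usage sketch, run only when this module is executed directly. The
# checkpoint names ("gpt2", "distilgpt2") and the coefficients are illustrative
# assumptions; any causal-LM checkpoints sharing a tokenizer/vocabulary work.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ensemble = CustomModel4.combine_models("gpt2", "distilgpt2", coeffs=[0.7, 0.3])
    inputs = tokenizer("Hello, world", return_tensors="pt")
    outputs = ensemble(**inputs, labels=inputs["input_ids"])
    print(outputs.loss, outputs.logits.shape)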
|
|