# https://huggingface.co/docs/transformers/custom_models
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss, ModuleList
from torch.nn.functional import log_softmax

from .configuration_custom5 import CustomConfig5


class CustomModel5(PreTrainedModel):
    """Causal LM ensemble: mixes the log-probabilities of several pretrained models with per-model coefficients."""

    config_class = CustomConfig5

    def __init__(self, config):
        super().__init__(config)
        self.model = ModuleList([AutoModelForCausalLM.from_pretrained(m) for m in config.models])

    def forward(self, *args, labels=None, **kwargs):
        loss = None
        logits = None
        # Accumulate a coefficient-weighted sum of each sub-model's log-probabilities.
        for model, coeff in zip(self.model, self.config.coeffs):
            logp = log_softmax(model.forward(*args, **kwargs).logits, dim=-1)
            logits = coeff * logp if logits is None else logits + coeff * logp
        # The rest is copied from modeling_llama.py:
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        return CausalLMOutputWithPast(loss=loss, logits=logits)


## Which one do we use?

## You have to tell the library that you want to copy the code files of these objects when calling the
## save_pretrained method, and register them properly with a given Auto class (especially for models).
## To do that, just run:
# CustomConfig5.register_for_auto_class()
# CustomModel5.register_for_auto_class('AutoModelForCausalLM')
# CustomModel5.register_for_auto_class('AutoModel')

## If you are writing a library that extends 🤗 Transformers, you may want to extend the auto classes to
## include your own model. This is different from pushing the code to the Hub in the sense that users will
## need to import your library to get the custom models (as opposed to automatically downloading the model
## code from the Hub).
## As long as your config has a model_type attribute that is different from existing model types, and your
## model classes have the right config_class attributes, you can just add them to the auto classes like this:
# AutoConfig.register("custom5", CustomConfig5)
# AutoModel.register(CustomConfig5, CustomModel5)
# AutoModelForCausalLM.register(CustomConfig5, CustomModel5)
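
## Example usage — a minimal sketch, not part of the original file. It assumes CustomConfig5 accepts
## `models` (a list of Hub model IDs), `coeffs` (one float per model), and a `vocab_size` matching the
## shared tokenizer; the model IDs below are placeholders, and the sub-models must share a vocabulary
## for the log-probability mixing (and the loss) to make sense:
# config = CustomConfig5(models=["gpt2", "distilgpt2"], coeffs=[0.5, 0.5], vocab_size=50257)
# model = CustomModel5(config)
# tokenizer = AutoTokenizer.from_pretrained(config.models[0])
# inputs = tokenizer("Hello world", return_tensors="pt")
# outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], labels=inputs["input_ids"])
# print(outputs.loss, outputs.logits.shape)

## After running the AutoConfig/AutoModelForCausalLM registrations above, loading should go through the
## usual Auto API — again a sketch, with "./custom5-ensemble" as a hypothetical local path:
# model.save_pretrained("./custom5-ensemble")
# reloaded = AutoModelForCausalLM.from_pretrained("./custom5-ensemble")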