import torch
from transformers import BertLMHeadModel

from .configuration import NewModelConfig


class NewModelForCausalLM(BertLMHeadModel):
    """BERT causal-LM head model extended with an extra linear projection.

    Behaves like ``BertLMHeadModel`` but is bound to ``NewModelConfig`` and
    registers an additional ``last_layer`` that maps features of size
    ``config.hidden_size`` to size ``config.new_hidden_size``.

    NOTE(review): nothing visible here overrides ``forward``, so
    ``last_layer`` is presumably not used by the inherited forward pass and
    does not affect outputs unless ``forward`` is overridden elsewhere —
    confirm this is intended.
    """

    # Lets ``from_pretrained`` / ``AutoModel`` resolve the matching config.
    config_class = NewModelConfig

    def __init__(self, config):
        """Build the base BERT LM-head model and add ``last_layer``.

        Args:
            config: A ``NewModelConfig`` instance; it must define
                ``hidden_size`` and ``new_hidden_size``.
        """
        super().__init__(config)
        self.last_layer = torch.nn.Linear(
            config.hidden_size, config.new_hidden_size
        )
        # BertLMHeadModel.__init__ ends with post_init(), which runs before
        # last_layer exists; call it again so the model's weight
        # initialization (_init_weights) is also applied to the new layer.
        self.post_init()