# https://huggingface.co/docs/transformers/custom_models

from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
from torch.nn.functional import log_softmax

# The tutorial also defines a custom config class, but we'll just reuse GPTNeoXConfig
# The norm is to subclass PreTrainedModel directly; subclassing GPTNeoXForCausalLM is a shortcut
class CustomModel3(GPTNeoXForCausalLM):
    config_class = GPTNeoXConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, *args, **kwargs):
        # See https://huggingface.co/docs/transformers/main_classes/output
        # Replace the raw logits with log-probabilities. Note: any loss that
        # super().forward() computed from `labels` is still based on the raw
        # logits, and this assumes return_dict=True (the transformers default).
        out = super().forward(*args, **kwargs)
        out.logits = log_softmax(out.logits, dim=-1)
        return out

    @classmethod
    def copy_from_neox(cls, *args, **kwargs):
        # Load a stock GPT-NeoX checkpoint, then copy its weights into a fresh
        # instance of this class with the same dtype and device.
        m0 = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
        m1 = cls(m0.config).to(dtype=m0.dtype, device=m0.device)
        m1.load_state_dict(m0.state_dict())
        return m1

CustomModel3.register_for_auto_class('AutoModelForCausalLM')
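
# A minimal usage sketch. "EleutherAI/pythia-70m" is only an example GPT-NeoX
# checkpoint, not one prescribed by this file. Thanks to the
# register_for_auto_class() call above, a checkpoint written by
# save_pretrained() can later be reloaded via
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True), provided
# this module is saved alongside it (see the custom_models docs linked above).
if __name__ == "__main__":
    import torch

    model = CustomModel3.copy_from_neox("EleutherAI/pythia-70m")

    # forward() now returns log-probabilities, so exponentiating and summing
    # over the vocabulary should give ~1.0 at every position.
    input_ids = torch.tensor([[1, 2, 3]])
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print(torch.exp(out.logits).sum(dim=-1))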