File size: 585 Bytes
93e6aa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from dataclasses import dataclass
from transformers import PretrainedConfig

class GPTConfig(PretrainedConfig):
    """
    Configuration class for custom GPT model.

    The hyperparameters are stored as plain instance attributes. Any
    additional keyword arguments are forwarded to ``PretrainedConfig`` so
    that the base class can initialise its own state.

    Args:
        block_size: Defaults to 768. (Presumably the maximum sequence
            length -- confirm against the model code.)
        vocab_size: Defaults to 50257. (Presumably the tokenizer
            vocabulary size -- confirm against the model code.)
        n_layer: Defaults to 8. (Presumably the number of transformer
            blocks -- confirm against the model code.)
        n_head: Defaults to 8. (Presumably the number of attention heads
            -- confirm against the model code.)
        n_embd: Defaults to 768. (Presumably the embedding width --
            confirm against the model code.)
        dropout: Defaults to 0.1. (Presumably a dropout probability --
            confirm against the model code.)
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """
    model_type = "custom_gpt"

    # NOTE: this class is intentionally NOT a ``@dataclass``. A generated
    # dataclass ``__init__`` would neither accept extra keyword arguments
    # nor call ``PretrainedConfig.__init__``; the Transformers guidelines
    # for custom configurations require both.
    def __init__(
        self,
        block_size: int = 768,
        vocab_size: int = 50257,
        n_layer: int = 8,
        n_head: int = 8,
        n_embd: int = 768,
        dropout: float = 0.1,
        **kwargs,
    ):
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        # Let the base class consume everything it knows about and set up
        # its own attributes.
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """
        Load a configuration by delegating to the parent implementation.

        All positional and keyword arguments are passed through unchanged,
        and the parent's return value is returned as-is.
        """
        return super().from_pretrained(*args, **kwargs)