Commit 8b5c39e
Parent(s): 0fa2b48

Upload 11 files

- added_tokens.json +40 -0
- config.json +26 -0
- configuration_mixformer_sequential.py +53 -0
- generation_config.json +4 -0
- merges.txt +0 -0
- modeling_mixformer_sequential.py +771 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +5 -0
- tokenizer.json +0 -0
- tokenizer_config.json +323 -0
- vocab.json +0 -0
    	
added_tokens.json
ADDED
@@ -0,0 +1,40 @@
{
  "\t\t": 50294,
  "\t\t\t": 50293,
  "\t\t\t\t": 50292,
  "\t\t\t\t\t": 50291,
  "\t\t\t\t\t\t": 50290,
  "\t\t\t\t\t\t\t": 50289,
  "\t\t\t\t\t\t\t\t": 50288,
  "\t\t\t\t\t\t\t\t\t": 50287,
  "  ": 50286,
  "   ": 50285,
  "    ": 50284,
  "     ": 50283,
  "      ": 50282,
  "       ": 50281,
  "        ": 50280,
  "         ": 50279,
  "          ": 50278,
  "           ": 50277,
  "            ": 50276,
  "             ": 50275,
  "              ": 50274,
  "               ": 50273,
  "                ": 50272,
  "                 ": 50271,
  "                  ": 50270,
  "                   ": 50269,
  "                    ": 50268,
  "                     ": 50267,
  "                      ": 50266,
  "                       ": 50265,
  "                        ": 50264,
  "                         ": 50263,
  "                          ": 50262,
  "                           ": 50261,
  "                            ": 50260,
  "                             ": 50259,
  "                              ": 50258,
  "                               ": 50257
}
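
The entries above map runs of spaces (2 to 31) and tabs (2 to 9) to single vocabulary ids (50257-50294), so indentation in code encodes to one token per run. A quick sketch of how this surfaces through the tokenizer; the local path ./phi-1_5 is an assumption and not part of this commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./phi-1_5")  # hypothetical local checkout of this repo

# Four spaces and two tabs resolve to the single added tokens listed above.
print(tokenizer.convert_tokens_to_ids("    "))   # 50284
print(tokenizer.convert_tokens_to_ids("\t\t"))   # 50294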
    	
config.json
ADDED
@@ -0,0 +1,26 @@
{
  "_name_or_path": "/Users/danielbarcenas/Downloads/Modelo/phi-1_5",
  "activation_function": "gelu_new",
  "architectures": [
    "MixFormerSequentialForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
    "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
  },
  "embd_pdrop": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mixformer-sequential",
  "n_embd": 2048,
  "n_head": 32,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary_dim": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.35.0.dev0",
  "vocab_size": 51200
}
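
Because "auto_map" points at the custom configuration and modeling files in this repository, loading through the Auto classes requires trust_remote_code=True. A minimal loading sketch; the local path is an assumption, substitute the repo id or your own download location:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./phi-1_5"  # hypothetical local checkout of this repo
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,   # matches "torch_dtype" above
    trust_remote_code=True,      # lets auto_map import the custom classes
)
tokenizer = AutoTokenizer.from_pretrained(model_path)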
    	
configuration_mixformer_sequential.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Any, Dict, List, Optional, Union

from transformers import PretrainedConfig


class MixFormerSequentialConfig(PretrainedConfig):
    """MixFormer (sequential for DeepSpeed) configuration."""

    model_type = "mixformer-sequential"

    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size: Optional[int] = 50304,
        n_positions: Optional[int] = 2048,
        n_embd: Optional[int] = 1024,
        n_layer: Optional[int] = 20,
        n_inner: Optional[int] = None,
        n_head: Optional[int] = 16,
        rotary_dim: Optional[int] = 32,
        activation_function: Optional[str] = "gelu_new",
        embd_pdrop: Optional[float] = 0.0,
        resid_pdrop: Optional[float] = 0.0,
        layer_norm_epsilon: Optional[float] = 1e-5,
        initializer_range: Optional[float] = 0.02,
        tie_word_embeddings: Optional[bool] = False,
        pad_vocab_size_multiple: Optional[int] = 64,
        **kwargs
    ) -> None:
        self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
        self.rotary_dim = min(rotary_dim, n_embd // n_head)
        self.activation_function = activation_function
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
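
A small sketch of the two derived fields in the constructor above: vocab_size is rounded up to the next multiple of pad_vocab_size_multiple, and rotary_dim is capped at the per-head dimension. It assumes the file above is importable from the current directory; the argument values are illustrative:

from configuration_mixformer_sequential import MixFormerSequentialConfig

config = MixFormerSequentialConfig(vocab_size=50295, n_embd=2048, n_head=32, n_layer=24)

print(config.vocab_size)  # 50304 = ceil(50295 / 64) * 64
print(config.rotary_dim)  # 32 = min(32, 2048 // 32)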
    	
generation_config.json
ADDED
@@ -0,0 +1,4 @@
{
  "_from_model_config": true,
  "transformers_version": "4.35.0.dev0"
}
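
The generation config only records that the defaults were taken from the model config, so decoding settings are left to the caller. Continuing the loading sketch under config.json above, a greedy-decoding example with an arbitrary prompt and token budget:

inputs = tokenizer("def print_prime(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))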
    	
merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
    	
modeling_mixformer_sequential.py
ADDED
@@ -0,0 +1,771 @@
| 1 | 
            +
            # Copyright (c) Microsoft Corporation.
         | 
| 2 | 
            +
            # Licensed under the MIT license.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # BSD 3-Clause License
         | 
| 5 | 
            +
            # 
         | 
| 6 | 
            +
            # Copyright (c) 2022, Tri Dao, [email protected].
         | 
| 7 | 
            +
            # All rights reserved.
         | 
| 8 | 
            +
            # 
         | 
| 9 | 
            +
            # Redistribution and use in source and binary forms, with or without
         | 
| 10 | 
            +
            # modification, are permitted provided that the following conditions are met:
         | 
| 11 | 
            +
            # 
         | 
| 12 | 
            +
            # * Redistributions of source code must retain the above copyright notice, this
         | 
| 13 | 
            +
            #   list of conditions and the following disclaimer.
         | 
| 14 | 
            +
            # 
         | 
| 15 | 
            +
            # * Redistributions in binary form must reproduce the above copyright notice,
         | 
| 16 | 
            +
            #   this list of conditions and the following disclaimer in the documentation
         | 
| 17 | 
            +
            #   and/or other materials provided with the distribution.
         | 
| 18 | 
            +
            # 
         | 
| 19 | 
            +
            # * Neither the name of the copyright holder nor the names of its
         | 
| 20 | 
            +
            #   contributors may be used to endorse or promote products derived from
         | 
| 21 | 
            +
            #   this software without specific prior written permission.
         | 
| 22 | 
            +
            # 
         | 
| 23 | 
            +
            # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
         | 
| 24 | 
            +
            # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
         | 
| 25 | 
            +
            # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
         | 
| 26 | 
            +
            # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
         | 
| 27 | 
            +
            # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
         | 
| 28 | 
            +
            # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
         | 
| 29 | 
            +
            # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
         | 
| 30 | 
            +
            # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
         | 
| 31 | 
            +
            # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
         | 
| 32 | 
            +
            # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            from __future__ import annotations
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            import math
         | 
| 37 | 
            +
            import copy
         | 
| 38 | 
            +
            from typing import Any, Dict, Optional, Tuple, Union
         | 
| 39 | 
            +
            from dataclasses import dataclass, field
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            import torch
         | 
| 42 | 
            +
            import torch.nn as nn
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            from einops import rearrange
         | 
| 45 | 
            +
            from transformers.activations import ACT2FN
         | 
| 46 | 
            +
            from transformers import PretrainedConfig, PreTrainedModel
         | 
| 47 | 
            +
            from transformers.modeling_outputs import CausalLMOutputWithPast
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            from .configuration_mixformer_sequential import MixFormerSequentialConfig
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            @dataclass
         | 
| 52 | 
            +
            class InferenceParams:
         | 
| 53 | 
            +
                """Inference parameters passed to model to efficiently calculate
         | 
| 54 | 
            +
                and store context during inference.
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                Reference:
         | 
| 57 | 
            +
                    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                Args:
         | 
| 60 | 
            +
                    max_sequence_len: Maximum sequence length.
         | 
| 61 | 
            +
                    max_batch_size: Maximum batch size.
         | 
| 62 | 
            +
                    sequence_len_offset: Sequence length offset.
         | 
| 63 | 
            +
                    batch_size_offset: Batch size offset.
         | 
| 64 | 
            +
                    key_value_memory_dict: Key value memory dictionary.
         | 
| 65 | 
            +
                    fused_ft_kernel: Whether to use fused kernel for fast inference.
         | 
| 66 | 
            +
                    lengths_per_sample: Lengths per sample.
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                """
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                max_batch_size: int = field(metadata={"help": "Maximum batch size."})
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                key_value_memory_dict: Dict[str, Any] = field(
         | 
| 79 | 
            +
                    default_factory=dict, metadata={"help": "Key value memory dictionary."}
         | 
| 80 | 
            +
                )
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            class Embedding(nn.Module):
         | 
| 88 | 
            +
                """Token embedding with dropout."""
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def __init__(self, config: PretrainedConfig) -> None:
         | 
| 91 | 
            +
                    super().__init__()
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    self.wte = nn.Embedding(config.vocab_size, config.n_embd)
         | 
| 94 | 
            +
                    self.drop = nn.Dropout(config.embd_pdrop)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
         | 
| 97 | 
            +
                    input_shape = input_ids.size()
         | 
| 98 | 
            +
                    input_ids = input_ids.view(-1, input_shape[-1])
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    hidden_states = self.wte(input_ids)
         | 
| 101 | 
            +
                    hidden_states = self.drop(hidden_states)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    return hidden_states
         | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
            class RotaryEmbedding(nn.Module):
         | 
| 107 | 
            +
                """Rotary embeddings.
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                Reference:
         | 
| 110 | 
            +
                    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
         | 
| 111 | 
            +
                
         | 
| 112 | 
            +
                """
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def __init__(
         | 
| 115 | 
            +
                    self,
         | 
| 116 | 
            +
                    dim: int,
         | 
| 117 | 
            +
                    base: int = 10000,
         | 
| 118 | 
            +
                    scale_base: Optional[float] = None,
         | 
| 119 | 
            +
                    device: Optional[str] = None,
         | 
| 120 | 
            +
                    **kwargs,
         | 
| 121 | 
            +
                ) -> None:
         | 
| 122 | 
            +
                    super().__init__()
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    if scale_base is not None:
         | 
| 125 | 
            +
                        raise NotImplementedError
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    # Generate and save the inverse frequency buffer (non-trainable)
         | 
| 128 | 
            +
                    self.dim = dim
         | 
| 129 | 
            +
                    self.base = base
         | 
| 130 | 
            +
                    self.scale_base = scale_base
         | 
| 131 | 
            +
                    self.device = device
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
         | 
| 134 | 
            +
                    self.register_buffer("inv_freq", inv_freq)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    scale = (
         | 
| 137 | 
            +
                        (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
         | 
| 138 | 
            +
                        if scale_base is not None
         | 
| 139 | 
            +
                        else None
         | 
| 140 | 
            +
                    )
         | 
| 141 | 
            +
                    self.register_buffer("scale", scale)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    self._seq_len_cached = 0
         | 
| 144 | 
            +
                    self._cos_cached = None
         | 
| 145 | 
            +
                    self._sin_cached = None
         | 
| 146 | 
            +
                    self._cos_k_cached = None
         | 
| 147 | 
            +
                    self._sin_k_cached = None
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
         | 
| 150 | 
            +
                    # Reset the tables if the sequence length has changed,
         | 
| 151 | 
            +
                    # or if we're on a new device (possibly due to tracing for instance)
         | 
| 152 | 
            +
                    seqlen = x.shape[1] + seqlen_offset
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    # Re-generate the inverse frequency buffer if it's not fp32
         | 
| 155 | 
            +
                    # (for instance if model.half() was called)
         | 
| 156 | 
            +
                    if self.inv_freq.dtype != "torch.float32":
         | 
| 157 | 
            +
                        self.inv_freq = 1.0 / (
         | 
| 158 | 
            +
                            self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
         | 
| 159 | 
            +
                        )
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
         | 
| 162 | 
            +
                        self._seq_len_cached = seqlen
         | 
| 163 | 
            +
                        t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                        # Don't do einsum, it converts fp32 to fp16
         | 
| 166 | 
            +
                        # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         | 
| 167 | 
            +
                        freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
         | 
| 168 | 
            +
                        if self.scale is None:
         | 
| 169 | 
            +
                            self._cos_cached = torch.cos(freqs).to(x.dtype)
         | 
| 170 | 
            +
                            self._sin_cached = torch.sin(freqs).to(x.dtype)
         | 
| 171 | 
            +
                        else:
         | 
| 172 | 
            +
                            power = (
         | 
| 173 | 
            +
                                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
         | 
| 174 | 
            +
                            ) / self.scale_base
         | 
| 175 | 
            +
                            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                            # We want the multiplication by scale to happen in fp32
         | 
| 178 | 
            +
                            self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
         | 
| 179 | 
            +
                            self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
         | 
| 180 | 
            +
                            self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
         | 
| 181 | 
            +
                            self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                def _apply_rotary_emb_qkv(
         | 
| 184 | 
            +
                    self,
         | 
| 185 | 
            +
                    qkv: torch.FloatTensor,
         | 
| 186 | 
            +
                    sin: torch.FloatTensor,
         | 
| 187 | 
            +
                    cos: torch.FloatTensor,
         | 
| 188 | 
            +
                    sin_k: Optional[torch.FloatTensor] = None,
         | 
| 189 | 
            +
                    cos_k: Optional[torch.FloatTensor] = None,
         | 
| 190 | 
            +
                ) -> torch.FloatTensor:
         | 
| 191 | 
            +
                    _, seqlen, three, _, headdim = qkv.shape
         | 
| 192 | 
            +
                    assert three == 3
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    rotary_seqlen, rotary_dim = cos.shape
         | 
| 195 | 
            +
                    rotary_dim *= 2
         | 
| 196 | 
            +
                    assert rotary_dim <= headdim
         | 
| 197 | 
            +
                    assert seqlen <= rotary_seqlen
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    cos_k = cos if cos_k is None else cos_k
         | 
| 200 | 
            +
                    sin_k = sin if sin_k is None else sin_k
         | 
| 201 | 
            +
                    assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    q_rot = qkv[:, :, 0, :, :rotary_dim]
         | 
| 204 | 
            +
                    q_pass = qkv[:, :, 0, :, rotary_dim:]
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    k_rot = qkv[:, :, 1, :, :rotary_dim]
         | 
| 207 | 
            +
                    k_pass = qkv[:, :, 1, :, rotary_dim:]
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    # Splits the queries and keys in half
         | 
| 210 | 
            +
                    q1, q2 = q_rot.chunk(2, dim=-1)
         | 
| 211 | 
            +
                    k1, k2 = k_rot.chunk(2, dim=-1)
         | 
| 212 | 
            +
                    c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    # Casts to fp32 are necessary to prevent fp16 overflow issues
         | 
| 215 | 
            +
                    q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    # Computes the new keys and queries, recasting to original dtype
         | 
| 218 | 
            +
                    q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
         | 
| 219 | 
            +
                    k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    return torch.cat(
         | 
| 222 | 
            +
                        [
         | 
| 223 | 
            +
                            torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
         | 
| 224 | 
            +
                            torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
         | 
| 225 | 
            +
                            qkv[:, :, 2:3, :, :],
         | 
| 226 | 
            +
                        ],
         | 
| 227 | 
            +
                        axis=2,
         | 
| 228 | 
            +
                    )
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 231 | 
            +
                    # `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
         | 
| 232 | 
            +
                    self._update_cos_sin_cache(qkv, seqlen_offset)
         | 
| 233 | 
            +
                    return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
         | 
| 234 | 
            +
             | 
| 235 | 
            +
             | 
| 236 | 
            +
            class MLP(nn.Module):
         | 
| 237 | 
            +
                """Multi-Layer Perceptron.
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                Reference:
         | 
| 240 | 
            +
                    Attention Is All You Need.
         | 
| 241 | 
            +
                    https://arxiv.org/pdf/1706.03762.pdf.
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                """
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
         | 
| 246 | 
            +
                    super().__init__()
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    act_fn = config.activation_function if act_fn is None else act_fn
         | 
| 249 | 
            +
                    assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         | 
| 252 | 
            +
                    n_inner = n_inner if n_inner is not None else 4 * config.n_embd
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    self.fc1 = nn.Linear(config.n_embd, n_inner)
         | 
| 255 | 
            +
                    self.fc2 = nn.Linear(n_inner, config.n_embd)
         | 
| 256 | 
            +
                    self.act = ACT2FN[act_fn]
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         | 
| 259 | 
            +
                    hidden_states = self.fc1(hidden_states)
         | 
| 260 | 
            +
                    hidden_states = self.act(hidden_states)
         | 
| 261 | 
            +
                    hidden_states = self.fc2(hidden_states)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    return hidden_states
         | 
| 264 | 
            +
             | 
| 265 | 
            +
             | 
| 266 | 
            +
            class SelfAttention(nn.Module):
         | 
| 267 | 
            +
                """Self-attention layer (compatible with PyTorch).
         | 
| 268 | 
            +
                
         | 
| 269 | 
            +
                Reference:
         | 
| 270 | 
            +
                    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                """
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                def __init__(
         | 
| 275 | 
            +
                    self,
         | 
| 276 | 
            +
                    causal: bool = True,
         | 
| 277 | 
            +
                    softmax_scale: Optional[float] = None,
         | 
| 278 | 
            +
                    attention_dropout: float = 0.0,
         | 
| 279 | 
            +
                ) -> None:
         | 
| 280 | 
            +
                    super().__init__()
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    self.causal = causal
         | 
| 283 | 
            +
                    self.softmax_scale = softmax_scale
         | 
| 284 | 
            +
                    self.drop = nn.Dropout(attention_dropout)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                def forward(
         | 
| 287 | 
            +
                    self,
         | 
| 288 | 
            +
                    qkv: torch.FloatTensor,
         | 
| 289 | 
            +
                    causal: bool = None,
         | 
| 290 | 
            +
                    attention_mask: Optional[torch.BoolTensor] = None,
         | 
| 291 | 
            +
                    **kwargs,
         | 
| 292 | 
            +
                ) -> torch.FloatTensor:
         | 
| 293 | 
            +
                    causal = self.causal if causal is None else causal
         | 
| 294 | 
            +
                    batch_size, seq_len = qkv.shape[0], qkv.shape[1]
         | 
| 295 | 
            +
                    q, k, v = qkv.unbind(dim=2)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         | 
| 298 | 
            +
                    scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    if attention_mask is not None:
         | 
| 301 | 
            +
                        padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
         | 
| 302 | 
            +
                        padding_mask.masked_fill_(attention_mask, 0.0)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                        scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    if causal:
         | 
| 307 | 
            +
                        causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
         | 
| 308 | 
            +
                        scores = scores + causal_mask.to(dtype=scores.dtype)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                    attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         | 
| 311 | 
            +
                    attention = self.drop(attention)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    output = torch.einsum("bhts,bshd->bthd", attention, v)
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    return output
         | 
| 316 | 
            +
             | 
| 317 | 
            +
             | 
| 318 | 
            +
            class CrossAttention(nn.Module):
         | 
| 319 | 
            +
                """Cross-attention layer (compatible with PyTorch).
         | 
| 320 | 
            +
                
         | 
| 321 | 
            +
                Reference:
         | 
| 322 | 
            +
                    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
         | 
| 323 | 
            +
                
         | 
| 324 | 
            +
                """
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                def __init__(
         | 
| 327 | 
            +
                    self,
         | 
| 328 | 
            +
                    causal: bool = True,
         | 
| 329 | 
            +
                    softmax_scale: Optional[float] = None,
         | 
| 330 | 
            +
                    attention_dropout: float = 0.0,
         | 
| 331 | 
            +
                ) -> None:
         | 
| 332 | 
            +
                    super().__init__()
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    self.causal = causal
         | 
| 335 | 
            +
                    self.softmax_scale = softmax_scale
         | 
| 336 | 
            +
                    self.drop = nn.Dropout(attention_dropout)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                def forward(
         | 
| 339 | 
            +
                    self,
         | 
| 340 | 
            +
                    q: torch.FloatTensor,
         | 
| 341 | 
            +
                    kv: torch.FloatTensor,
         | 
| 342 | 
            +
                    causal: bool = None,
         | 
| 343 | 
            +
                    attention_mask: Optional[torch.BoolTensor] = None,
         | 
| 344 | 
            +
                    **kwargs,
         | 
| 345 | 
            +
                ) -> torch.FloatTensor:
         | 
| 346 | 
            +
                    causal = self.causal if causal is None else causal
         | 
| 347 | 
            +
                    batch_size, seq_len_q = q.shape[0], q.shape[1]
         | 
| 348 | 
            +
                    assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    seq_len_k = kv.shape[1]
         | 
| 351 | 
            +
                    k, v = kv.unbind(dim=2)
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         | 
| 354 | 
            +
                    scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    if attention_mask is not None:
         | 
| 357 | 
            +
                        padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
         | 
| 358 | 
            +
                        padding_mask.masked_fill_(attention_mask, 0.0)
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                        scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                    if causal:
         | 
| 363 | 
            +
                        causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
         | 
| 364 | 
            +
                        scores = scores + causal_mask.to(dtype=scores.dtype)
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         | 
| 367 | 
            +
                    attention = self.drop(attention)
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    output = torch.einsum("bhts,bshd->bthd", attention, v)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    return output
         | 
| 372 | 
            +
             | 
| 373 | 
            +
             | 
| 374 | 
            +
            def find_mha_dims(
         | 
| 375 | 
            +
                config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
         | 
| 376 | 
            +
            ) -> Tuple[int, int]:
         | 
| 377 | 
            +
                """Validate and return the number of heads and head dimension for multi-head attention.
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                Args:
         | 
| 380 | 
            +
                    config: Model configuration.
         | 
| 381 | 
            +
                    n_head: Number of heads.
         | 
| 382 | 
            +
                    head_dim: Head dimension.
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                Returns:
         | 
| 385 | 
            +
                    Number of heads and head dimension.
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                """
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                assert all(
         | 
| 390 | 
            +
                    hasattr(config, attr) for attr in ["n_embd", "n_head"]
         | 
| 391 | 
            +
                ), "`config` must have `n_embd` and `n_head` attributes."
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                if head_dim is None:
         | 
| 394 | 
            +
                    assert (
         | 
| 395 | 
            +
                        config.n_embd % config.n_head == 0
         | 
| 396 | 
            +
                    ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                if n_head is None and head_dim is None:
         | 
| 399 | 
            +
                    head_dim = config.n_embd // config.n_head
         | 
| 400 | 
            +
                    n_head = config.n_head
         | 
| 401 | 
            +
                elif n_head is None or head_dim is None:
         | 
| 402 | 
            +
                    raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                return n_head, head_dim
         | 
| 405 | 
            +
             | 
| 406 | 
            +
             | 
| 407 | 
            +
            def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
         | 
| 408 | 
            +
                """Update the key-value cache for inference.
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                Reference:
         | 
| 411 | 
            +
                    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                Args:
         | 
| 414 | 
            +
                    kv: Key-value tensor.
         | 
| 415 | 
            +
                    inference_params: Inference parameters.
         | 
| 416 | 
            +
                    layer_idx: Layer index.
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                Returns:
         | 
| 419 | 
            +
                    Updated key-value tensor.
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                """
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                num_heads, head_dim = kv.shape[-2:]
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                if layer_idx not in inference_params.key_value_memory_dict:
         | 
| 426 | 
            +
                    kv_cache = torch.empty(
         | 
| 427 | 
            +
                        inference_params.max_batch_size,
         | 
| 428 | 
            +
                        inference_params.max_sequence_len,
         | 
| 429 | 
            +
                        2,
         | 
| 430 | 
            +
                        num_heads,
         | 
| 431 | 
            +
                        head_dim,
         | 
| 432 | 
            +
                        dtype=kv.dtype,
         | 
| 433 | 
            +
                        device=kv.device,
         | 
| 434 | 
            +
                    )
         | 
| 435 | 
            +
                    inference_params.key_value_memory_dict[layer_idx] = kv_cache
         | 
| 436 | 
            +
                else:
         | 
| 437 | 
            +
                    if not inference_params.fused_ft_kernel:
         | 
| 438 | 
            +
                        kv_cache = inference_params.key_value_memory_dict[layer_idx]
         | 
| 439 | 
            +
                    else:
         | 
| 440 | 
            +
                        k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
         | 
| 441 | 
            +
                        kv_cache = None
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                batch_start = inference_params.batch_size_offset
         | 
| 444 | 
            +
                batch_end = batch_start + kv.shape[0]
         | 
| 445 | 
            +
                assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                sequence_start = inference_params.sequence_len_offset
         | 
| 448 | 
            +
                sequence_end = sequence_start + kv.shape[1]
         | 
| 449 | 
            +
                assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                if not inference_params.fused_ft_kernel:
         | 
| 452 | 
            +
                    assert kv_cache is not None
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
         | 
| 455 | 
            +
                    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                    return kv
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                assert inference_params.sequence_len_offset == 0
         | 
| 460 | 
            +
                assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                packsize = 4 if kv.dtype == torch.float32 else 8
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                if kv_cache is not None:
         | 
| 465 | 
            +
                    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
         | 
| 466 | 
            +
                    k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
         | 
| 467 | 
            +
                    v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
         | 
| 468 | 
            +
                    inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
         | 
| 469 | 
            +
                else:
         | 
| 470 | 
            +
                    k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
         | 
| 471 | 
            +
                        kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
         | 
| 472 | 
            +
                    )
         | 
| 473 | 
            +
                    v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                return kv
         | 
| 476 | 
            +
             | 
| 477 | 
            +
             | 
class MHA(nn.Module):
    """Multi-head attention layer."""

    def __init__(
        self,
        config: PretrainedConfig,
        dtype: Optional[torch.dtype] = None,
        device: Optional[str] = None,
        rotary_dim: Optional[int] = None,
        rotary_emb_scale_base: Optional[float] = None,
        n_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        bias: bool = True,
        causal: bool = True,
        softmax_scale: Optional[float] = None,
        dropout: float = 0.0,
        layer_idx: Optional[int] = None,
        return_residual: bool = False,
        checkpointing: bool = False,
    ) -> None:
        super().__init__()

        # Rotary embedding
        self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
        if self.rotary_emb_dim > 0:
            rotary_kwargs = {"device": device}
            if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                rotary_kwargs["scale_base"] = rotary_emb_scale_base
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)

        # Attention projections (fused QKV and output projection)
        self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
        op_size = self.n_head * self.head_dim
        hidden_size = config.n_embd

        self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)

        # Attention
        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)

        self.layer_idx = layer_idx
        self.return_residual = return_residual
        self.checkpointing = checkpointing

    def forward(
        self,
        x: torch.FloatTensor,
        past_key_values: Optional[InferenceParams] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        max_seqlen: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

        seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
        if self.rotary_emb_dim > 0:
            qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)

        if past_key_values is not None:
            kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)

        if attention_mask is not None:
            attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
            attention_mask = attention_mask.bool().to(qkv.device)

        attention_kwargs = {"attention_mask": attention_mask}

        if past_key_values is None or seqlen_offset == 0:
            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
            else:
                attn_output = self.inner_attn(qkv, **attention_kwargs)
        else:
            q = qkv[:, :, 0]
            causal = None if past_key_values.sequence_len_offset == 0 else False
            attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)

        output = rearrange(attn_output, "... h d -> ... (h d)")
        output = self.out_proj(output)

        return output if not self.return_residual else (output, x)


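Aside (not part of the file): MHA.forward projects the input once into a fused QKV tensor and splits it with a single rearrange; during decoding, the key/value slice is what gets written into the cache updated above. A minimal sketch of that split, with placeholder sizes:

# Illustrative sketch, not part of modeling_mixformer_sequential.py: the fused-QKV split used in MHA.forward.
import torch
from einops import rearrange

batch, seqlen, n_head, head_dim = 2, 8, 4, 16
hidden_size = n_head * head_dim

Wqkv = torch.nn.Linear(hidden_size, 3 * hidden_size)
x = torch.randn(batch, seqlen, hidden_size)

qkv = Wqkv(x)                                                    # (b, s, 3 * h * d)
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=head_dim)
print(qkv.shape)                                                 # torch.Size([2, 8, 3, 4, 16])

q, kv = qkv[:, :, 0], qkv[:, :, 1:]                              # q for the new tokens, kv goes to the cache
print(q.shape, kv.shape)                                         # (2, 8, 4, 16) and (2, 8, 2, 4, 16)
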
class ParallelBlock(nn.Module):
    """Parallel block.

    This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).

    """

    def __init__(
        self,
        config: PretrainedConfig,
        block_idx: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.block_idx = block_idx

        self.mixer = MHA(config, layer_idx=block_idx)
        self.mlp = MLP(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states
        hidden_states = self.ln(hidden_states)

        attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask)
        if isinstance(attn_outputs, tuple):
            attn_outputs = attn_outputs[0]

        attn_outputs = self.resid_dropout(attn_outputs)
        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))

        hidden_states = attn_outputs + feed_forward_hidden_states + residual

        return hidden_states


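Aside (not part of the file): ParallelBlock feeds the same LayerNorm output to both the attention mixer and the MLP and sums the two branches with the residual, instead of chaining them sequentially. A minimal sketch of that dataflow with stand-in linear layers:

# Illustrative sketch, not part of modeling_mixformer_sequential.py: the parallel (GPT-J style) residual pattern.
import torch
import torch.nn as nn

n_embd = 32
ln = nn.LayerNorm(n_embd)
mixer = nn.Linear(n_embd, n_embd)   # stand-in for MHA
mlp = nn.Linear(n_embd, n_embd)     # stand-in for the model's MLP
resid_dropout = nn.Dropout(0.0)

hidden_states = torch.randn(2, 8, n_embd)
residual = hidden_states
normed = ln(hidden_states)

# Both branches read the same normalized input; their outputs are summed with the residual.
out = resid_dropout(mixer(normed)) + resid_dropout(mlp(normed)) + residual
print(out.shape)  # torch.Size([2, 8, 32])
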
class CausalLMHead(nn.Module):
    """Causal Language Modeling head.

    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.

    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.ln(hidden_states)
        logits = self.linear(hidden_states).to(torch.float32)

        return logits


class CausalLMLoss(nn.Module):
    """Causal Language Modeling loss.

    Reference:
        Improving Language Understanding by Generative Pre-Training.
        https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.

    """

    def __init__(self, shift_labels: bool = True) -> None:
        super().__init__()

        self.shift_labels = shift_labels
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
        if self.shift_labels:
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return loss


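Aside (not part of the file): CausalLMLoss shifts logits and labels by one position so that position t is scored on predicting token t + 1. A minimal sketch on dummy data:

# Illustrative sketch, not part of modeling_mixformer_sequential.py: next-token loss with shifted labels.
import torch
import torch.nn as nn

vocab_size, batch, seqlen = 100, 2, 6
logits = torch.randn(batch, seqlen, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seqlen))

# Drop the last logit and the first label so position t predicts token t + 1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())
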
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
    """MixFormer (sequential for DeepSpeed) pre-trained model."""

    config_class = MixFormerSequentialConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs) -> None:
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        if attention_mask is not None and torch.any(~attention_mask.bool()):
            total_seq_len = torch.sum(attention_mask, dim=1)
            max_seq_len = torch.max(total_seq_len)

            total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
            cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
            attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
        else:
            attention_mask = None

        if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
            past_key_values = InferenceParams(
                max_batch_size=input_ids.shape[0],
                max_sequence_len=self.config.n_positions,
                sequence_len_offset=0,
                batch_size_offset=0,
                fused_ft_kernel=False,
                key_value_memory_dict={},
            )
        else:
            # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
            past_key_values.sequence_len_offset = len(input_ids[0]) - 1
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
        }

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, MixFormerSequentialPreTrainedModel):
            module.gradient_checkpointing = value


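Aside (not part of the file): when padding is present, prepare_inputs_for_generation converts the 2-D attention mask into a tuple of (boolean mask, cumulative sequence lengths, maximum length). A minimal sketch of that cumulative-sum step on a toy mask:

# Illustrative sketch, not part of modeling_mixformer_sequential.py: cumulative sequence lengths from a padding mask.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])      # batch of 2; first row has one pad position

total_seq_len = torch.sum(attention_mask, dim=1)                 # tensor([3, 4])
max_seq_len = torch.max(total_seq_len)                           # tensor(4)

total_seq_len = torch.cat((torch.tensor([0]), total_seq_len)).unsqueeze(1)
cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
print(cumulative_seq_len)                                        # tensor([0, 3, 7], dtype=torch.int32)
print((attention_mask.bool(), cumulative_seq_len, max_seq_len.item()))
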
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
    """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""

    _keys_to_ignore_on_load_missing = [""]
    _keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
    _no_split_modules = ["ParallelBlock"]

    def __init__(self, config: MixFormerSequentialConfig) -> None:
        super().__init__(config)

        modules = [Embedding(config)]
        modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
        modules.append(CausalLMHead(config))

        self.layers = nn.Sequential(*modules)
        self.loss = CausalLMLoss()

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.layers[0].wte

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.layers[0].wte = new_embeddings

    def get_output_embeddings(self) -> nn.Linear:
        return self.layers[-1].linear

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.layers[-1].linear = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        if past_key_values is None and attention_mask is None:
            lm_logits = self.layers(input_ids)
        else:
            hidden_layer = self.layers[0](input_ids)
            for module in self.layers[1:-1]:
                hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
            lm_logits = self.layers[-1](hidden_layer)

        loss = None
        if labels is not None:
            loss = self.loss(lm_logits, labels)

        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
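
Aside (not part of the repository): a hedged sketch of loading this checkpoint through the transformers auto classes, assuming the uploaded config registers the custom MixFormerSequentialForCausalLM architecture; the repository id below is a placeholder, not the real name.

# Illustrative usage sketch; "your-org/your-model" is a placeholder for this repository's id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-model"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.float32, trust_remote_code=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0]))
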
    	
pytorch_model.bin
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ada7d8edf7851d2e3d7fa8c886888f0173d9a88cdbf5262210fefe5ee9adaac6
size 2836625217
    	
special_tokens_map.json
ADDED

@@ -0,0 +1,5 @@
{
  "bos_token": "<|endoftext|>",
  "eos_token": "<|endoftext|>",
  "unk_token": "<|endoftext|>"
}
    	
tokenizer.json
ADDED

The diff for this file is too large to render. See raw diff.
    	
tokenizer_config.json
ADDED

@@ -0,0 +1,323 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "50256": {"content": "<|endoftext|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "50257": {"content": "                               ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50258": {"content": "                              ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50259": {"content": "                             ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50260": {"content": "                            ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50261": {"content": "                           ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50262": {"content": "                          ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50263": {"content": "                         ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50264": {"content": "                        ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50265": {"content": "                       ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50266": {"content": "                      ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50267": {"content": "                     ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50268": {"content": "                    ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50269": {"content": "                   ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50270": {"content": "                  ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50271": {"content": "                 ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50272": {"content": "                ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50273": {"content": "               ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50274": {"content": "              ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50275": {"content": "             ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50276": {"content": "            ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50277": {"content": "           ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50278": {"content": "          ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50279": {"content": "         ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50280": {"content": "        ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50281": {"content": "       ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50282": {"content": "      ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50283": {"content": "     ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50284": {"content": "    ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50285": {"content": "   ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50286": {"content": "  ", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50287": {"content": "\t\t\t\t\t\t\t\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50288": {"content": "\t\t\t\t\t\t\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50289": {"content": "\t\t\t\t\t\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50290": {"content": "\t\t\t\t\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50291": {"content": "\t\t\t\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50292": {"content": "\t\t\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50293": {"content": "\t\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false},
    "50294": {"content": "\t\t", "lstrip": false, "normalized": true, "rstrip": false, "single_word": false, "special": false}
  },
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": true,
  "eos_token": "<|endoftext|>",
  "model_max_length": 2048,
  "tokenizer_class": "CodeGenTokenizer",
  "unk_token": "<|endoftext|>"
}
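
Aside (not part of the repository): the configuration above registers runs of spaces and tabs (ids 50257-50294) as added tokens on top of the CodeGen tokenizer, so code indentation can be encoded as a single token. A hedged sketch, with a placeholder repository id:

# Illustrative sketch; "your-org/your-model" is a placeholder for this repository's id.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/your-model")

line = "        return kv"          # eight leading spaces
ids = tokenizer(line)["input_ids"]
print(ids)                          # the 8-space run is expected to map to a single added token (id 50280)
print(tokenizer.convert_ids_to_tokens(ids))
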
    	
vocab.json
ADDED

The diff for this file is too large to render. See raw diff.