rsxdalv and bigmoyan committed on
Commit 7dc384d · verified · 0 Parent(s)

Duplicate from moonshotai/Kimi-Audio-7B-Instruct


Co-authored-by: moyanwang <[email protected]>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +0 -0
  3. README.md +161 -0
  4. audio_detokenizer/config.yaml +123 -0
  5. audio_detokenizer/model.pt +3 -0
  6. config.json +44 -0
  7. configuration_moonshot_kimia.py +66 -0
  8. generation_config.json +3 -0
  9. model-1-of-35.safetensors +3 -0
  10. model-10-of-35.safetensors +3 -0
  11. model-11-of-35.safetensors +3 -0
  12. model-12-of-35.safetensors +3 -0
  13. model-13-of-35.safetensors +3 -0
  14. model-14-of-35.safetensors +3 -0
  15. model-15-of-35.safetensors +3 -0
  16. model-16-of-35.safetensors +3 -0
  17. model-17-of-35.safetensors +3 -0
  18. model-18-of-35.safetensors +3 -0
  19. model-19-of-35.safetensors +3 -0
  20. model-2-of-35.safetensors +3 -0
  21. model-20-of-35.safetensors +3 -0
  22. model-21-of-35.safetensors +3 -0
  23. model-22-of-35.safetensors +3 -0
  24. model-23-of-35.safetensors +3 -0
  25. model-24-of-35.safetensors +3 -0
  26. model-25-of-35.safetensors +3 -0
  27. model-26-of-35.safetensors +3 -0
  28. model-27-of-35.safetensors +3 -0
  29. model-28-of-35.safetensors +3 -0
  30. model-29-of-35.safetensors +3 -0
  31. model-3-of-35.safetensors +3 -0
  32. model-30-of-35.safetensors +3 -0
  33. model-31-of-35.safetensors +3 -0
  34. model-32-of-35.safetensors +3 -0
  35. model-33-of-35.safetensors +3 -0
  36. model-34-of-35.safetensors +3 -0
  37. model-35-of-35.safetensors +3 -0
  38. model-36-of-36.safetensors +3 -0
  39. model-4-of-35.safetensors +3 -0
  40. model-5-of-35.safetensors +3 -0
  41. model-6-of-35.safetensors +3 -0
  42. model-7-of-35.safetensors +3 -0
  43. model-8-of-35.safetensors +3 -0
  44. model-9-of-35.safetensors +3 -0
  45. model.safetensors.index.json +460 -0
  46. modeling_moonshot_kimia.py +917 -0
  47. special_tokens_map.json +425 -0
  48. tiktoken.model +3 -0
  49. tokenization_kimia.py +335 -0
  50. tokenizer_config.json +0 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
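
Every weight and tokenizer artifact in this commit matches one of the patterns above, so it is stored as a Git LFS pointer rather than a raw blob; a plain `git clone` without LFS support only fetches the small stub files. A minimal sketch, assuming the `huggingface_hub` client is installed (it is not part of this repo), of pulling the LFS-backed files directly:

```python
# Sketch: fetch the LFS-tracked checkpoint files without needing git-lfs locally.
# Assumes `pip install huggingface_hub`; the allow_patterns list is illustrative.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="moonshotai/Kimi-Audio-7B-Instruct",
    allow_patterns=["*.safetensors", "*.json", "*.py", "tiktoken.model", "audio_detokenizer/*"],
)
print("Checkpoint files downloaded to:", local_dir)
```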
.gitignore ADDED
File without changes
README.md ADDED
@@ -0,0 +1,161 @@
+ ---
+ license: mit
+ language:
+ - en
+ - zh
+ tags:
+ - audio
+ - audio-language-model
+ - speech-recognition
+ - audio-understanding
+ - text-to-speech
+ - audio-generation
+ - chat
+ - kimi-audio
+ ---
+
+ # Kimi-Audio
+
+ <p align="center">
+   <img src="https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_logo.png" width="400"/> <!-- TODO: Replace with actual raw image URL from your repo -->
+ </p>
+
+ <p align="center">
+   <a href="https://huggingface.co/moonshotai/Kimi-Audio-7B">🤗 Kimi-Audio-7B</a>&nbsp; | <a href="https://huggingface.co/moonshotai/Kimi-Audio-7B-Instruct">🤗 Kimi-Audio-7B-Instruct</a>&nbsp; | <a href="https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_report.pdf">📑 Paper</a>
+ </p>
+
+ ## Introduction
+
+ We present Kimi-Audio, an open-source audio foundation model excelling in **audio understanding, generation, and conversation**. This repository hosts the model checkpoints for Kimi-Audio-7B-Instruct.
+
+ Kimi-Audio is designed as a universal audio foundation model capable of handling a wide variety of audio processing tasks within a single unified framework. Key features include:
+
+ * **Universal Capabilities:** Handles diverse tasks such as speech recognition (ASR), audio question answering (AQA), audio captioning (AAC), speech emotion recognition (SER), sound event/scene classification (SEC/ASC), and end-to-end speech conversation.
+ * **State-of-the-Art Performance:** Achieves SOTA results on numerous audio benchmarks (see our [Technical Report](https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_report.pdf)).
+ * **Large-Scale Pre-training:** Pre-trained on over 13 million hours of diverse audio data (speech, music, sounds) and text data.
+ * **Novel Architecture:** Employs a hybrid audio input (continuous acoustic features plus discrete semantic tokens) and an LLM core with parallel heads for text and audio token generation.
+ * **Efficient Inference:** Features a chunk-wise streaming detokenizer based on flow matching for low-latency audio generation.
+
+ For more details, please refer to our [GitHub Repository](https://github.com/MoonshotAI/Kimi-Audio) and [Technical Report](https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_report.pdf).
+
+ ## Requirements
+
+ We recommend building a Docker image to run inference. After cloning the inference code, you can build the image with the `docker build` command:
+ ```bash
+ git clone https://github.com/MoonshotAI/Kimi-Audio
+ cd Kimi-Audio
+ git submodule update --init
+ docker build -t kimi-audio:v0.1 .
+ ```
+ Alternatively, you can use our pre-built image:
+ ```bash
+ docker pull moonshotai/kimi-audio:v0.1
+ ```
+
+ Or you can install the requirements with:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ You may refer to the Dockerfile in case of any environment issues.
+
+ ## Quickstart
+
+ This example demonstrates basic usage for generating text from audio (ASR) and generating both text and speech in a conversational turn using the `Kimi-Audio-7B-Instruct` model.
+
+ ```python
+ import soundfile as sf
+ # Assuming the KimiAudio class is available after installation
+ from kimia_infer.api.kimia import KimiAudio
+ import torch  # Ensure torch is imported if needed for device placement
+
+ # --- 1. Load Model ---
+ # Load the model from the Hugging Face Hub.
+ # Make sure you are logged in (`huggingface-cli login`) if the repo is private.
+ model_id = "moonshotai/Kimi-Audio-7B-Instruct"  # Or "moonshotai/Kimi-Audio-7B"
+ device = "cuda" if torch.cuda.is_available() else "cpu"  # Example device placement
+ # Note: The KimiAudio class might handle model loading differently.
+ # You might need to pass the model_id directly or download checkpoints manually
+ # and provide the local path as shown in the original readme_kimia.md.
+ # Please refer to the main Kimi-Audio repository for precise loading instructions.
+ # Example assuming KimiAudio takes the HF ID or a local path:
+ try:
+     model = KimiAudio(model_path=model_id, load_detokenizer=True)  # May need a device argument
+     model.to(device)  # Example device placement
+ except Exception as e:
+     print("Automatic loading from the HF Hub might require specific setup.")
+     print(f"Refer to the Kimi-Audio docs. Trying the local path example (update the path!). Error: {e}")
+     # Fallback example:
+     # model_path = "/path/to/your/downloaded/kimia-hf-ckpt"  # IMPORTANT: Update this path if loading locally
+     # model = KimiAudio(model_path=model_path, load_detokenizer=True)
+     # model.to(device)  # Example device placement
+
+ # --- 2. Define Sampling Parameters ---
+ sampling_params = {
+     "audio_temperature": 0.8,
+     "audio_top_k": 10,
+     "text_temperature": 0.0,
+     "text_top_k": 5,
+     "audio_repetition_penalty": 1.0,
+     "audio_repetition_window_size": 64,
+     "text_repetition_penalty": 1.0,
+     "text_repetition_window_size": 16,
+ }
+
+ # --- 3. Example 1: Audio-to-Text (ASR) ---
+ # TODO: Provide actual example audio files or URLs accessible to users.
+ # E.g., download sample files first or use URLs:
+ # wget https://path/to/your/asr_example.wav -O asr_example.wav
+ # wget https://path/to/your/qa_example.wav -O qa_example.wav
+ asr_audio_path = "asr_example.wav"  # IMPORTANT: Make sure this file exists
+ qa_audio_path = "qa_example.wav"    # IMPORTANT: Make sure this file exists
+
+ messages_asr = [
+     {"role": "user", "message_type": "text", "content": "Please transcribe the following audio:"},
+     {"role": "user", "message_type": "audio", "content": asr_audio_path},
+ ]
+
+ # Generate only text output.
+ # Note: Ensure the model object and generate method accept device placement if needed.
+ _, text_output = model.generate(messages_asr, **sampling_params, output_type="text")
+ print(">>> ASR Output Text: ", text_output)
+ # Expected output (example): "这并不是告别,这是一个篇章的结束,也是新篇章的开始。"
+ # (Roughly: "This is not a farewell; it is the end of one chapter and the beginning of a new one.")
+
+ # --- 4. Example 2: Audio-to-Audio/Text Conversation ---
+ messages_conversation = [
+     {"role": "user", "message_type": "audio", "content": qa_audio_path},
+ ]
+
+ # Generate both audio and text output.
+ wav_output, text_output = model.generate(messages_conversation, **sampling_params, output_type="both")
+
+ # Save the generated audio.
+ output_audio_path = "output_audio.wav"
+ # Ensure wav_output is on CPU and flattened before saving.
+ sf.write(output_audio_path, wav_output.detach().cpu().view(-1).numpy(), 24000)  # Assuming 24 kHz output
+ print(f">>> Conversational Output Audio saved to: {output_audio_path}")
+ print(">>> Conversational Output Text: ", text_output)
+ # Expected output: "A." (Example)
+
+ print("Kimi-Audio inference examples complete.")
+
+ ```
+
+ ## Citation
+
+ If you find Kimi-Audio useful in your research or applications, please cite our technical report:
+
+ ```bibtex
+ @misc{kimi_audio_2024,
+   title={Kimi-Audio Technical Report},
+   author={Kimi Team},
+   year={2024},
+   eprint={arXiv:placeholder},
+   archivePrefix={arXiv},
+   primaryClass={cs.CL}
+ }
+ ```
+
+ ## License
+
+ The model is based on and modified from [Qwen 2.5-7B](https://github.com/QwenLM/Qwen2.5). Code derived from Qwen2.5-7B is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0). Other parts of the code are licensed under the [MIT License](https://opensource.org/licenses/MIT).
audio_detokenizer/config.yaml ADDED
@@ -0,0 +1,123 @@
+ accumulate_grad_batches: 1
+ base_config: config/config_base.yaml
+ batch_max_tokens: 12000
+ batch_size: 2
+ cfg_init: 1.0
+ cfg_scale: 4.0
+ cfg_schedule: linear
+ check_val_every_n_epoch: 10
+ clip_grad_norm: 0
+ data_dir: ''
+ debug: false
+ deep_speed_strategy_stage: 2
+ drop_last: true
+ dynamic_cfg: false
+ endless_ds: false
+ filter_args:
+   lang:
+   - zh
+   - en
+   max_spk_num: 6
+   speech_ratio: 0.6
+ gradient_clip_val: 1.0
+ indexed_ds: true
+ infer: false
+ infer_exp_name: ''
+ infer_json_path: ''
+ inference_ckpt: ''
+ inference_mode: nonstreaming
+ learning_rate: 1e-4
+ limit_val_batches: 100
+ load_opt: false
+ log_interval: 10
+ logger_type: tensorboard
+ loss:
+   lambda_fm: 1.0
+   lambda_phone: 0.0
+   mel_loss: l1
+ max_epochs: 1000
+ max_eval_sentences: -1
+ max_eval_tokens: -1
+ max_prompt_ratio: 0.5
+ max_segment_cnt: 20000
+ max_sentences: -1
+ max_speech_duration: 20
+ max_tokens: 31250
+ max_training_steps: 100000
+ max_updates: 160000
+ mel_mean: -4.479605
+ mel_std: 3.4584913
+ meta_dir: null
+ min_prompt_duration: 0.5
+ min_speech_duration: -1
+ model:
+   condition_prenet_depth: 6
+   dit:
+     chunk_params:
+       hz: 50
+       max_chunk: 3.0
+       max_chunk_history: 50000000
+       min_chunk: 0.5
+       need_block_shift: false
+     condition_input_dim: 1280
+     condition_type: discrete_codes
+     depth: 16
+     ffn_act_layer: gleu_tanh
+     ffn_conv_kernel_size: 5
+     ffn_gated_glu: false
+     ffn_type: vanilla_mlp
+     hidden_size: 2304
+     input_size: 80
+     max_seq_len: 4096
+     mlp_ratio: 4.0
+     num_heads: 18
+     position_embedding_type: skip
+     prompt_cfg_dropout: 0.2
+     rope_params:
+       max_position_embeddings: 4096
+       rope_base: 10000.0
+       rope_interpolation_factor: 1.0
+     semantic_cfg_dropout: 0.2
+     semantic_vocab_size: 16384
+     use_chunk_setting: true
+     use_rope: true
+   phone_predictor:
+     blank_id: 4
+     phone_vocab_size: 5000
+   position_id_start_from: 0
+   random_position_start: true
+   restart_position_ids: false
+   use_condition_prenet: false
+ need_merge_same_speaker: true
+ need_precise_phones: false
+ no_verlap: true
+ normalize_mel: true
+ num_nodes: 1
+ num_sanity_val_steps: 0
+ num_workers: 1
+ ode_steps: 150
+ optimizer_adam_beta1: 0.9
+ optimizer_adam_beta2: 0.98
+ optimizer_class: adamw
+ pin_memory: true
+ precision: bf16-mixed
+ save_interval: 2000
+ save_topk: 10
+ seed: 1234
+ shuffle: true
+ sort_by_len: true
+ src_sample_rate: 16000
+ strategy: ddp
+ tensorboard_dir: tb_logs
+ test_num: 100
+ tgt_sample_rate: 24000
+ timescale: 80000
+ use_cfg: false
+ use_cfg_rescale: false
+ use_distributed_sampler: false
+ use_uncondition: false
+ val_check_interval: 2000000
+ vocoder_ckpt: ''
+ wandb_name: glm4_semantic_cfm_v2_debug
+ warmup_updates: 100
+ weight_decay: 0.0001
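
The detokenizer config above drives the chunk-wise streaming setup mentioned in the README: `chunk_params.hz: 50` semantic tokens per second, chunks between 0.5 s and 3.0 s, and 24 kHz waveform output (`tgt_sample_rate`). A minimal sketch, assuming PyYAML, a locally downloaded copy of the file, and the key nesting shown above:

```python
# Sketch: derive streaming chunk sizes from audio_detokenizer/config.yaml.
# The path and nesting follow the config reproduced above; adjust if your layout differs.
import yaml

with open("audio_detokenizer/config.yaml") as f:
    cfg = yaml.safe_load(f)

chunk = cfg["model"]["dit"]["chunk_params"]
tokens_per_sec = chunk["hz"]            # 50 semantic tokens per second
max_chunk_sec = chunk["max_chunk"]      # 3.0 s per streaming chunk
out_sr = cfg["tgt_sample_rate"]         # 24000 Hz generated waveform

print("Semantic tokens per max chunk:", int(tokens_per_sec * max_chunk_sec))  # 150
print("Output samples per max chunk:", int(out_sr * max_chunk_sec))           # 72000
```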
audio_detokenizer/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdeeec41e629565439cd8ef807c8a014ad6ce052cce0c259c7bfe3fe6ada3f51
3
+ size 19008505142
config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "architectures": [
+     "MoonshotKimiaForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_moonshot_kimia.KimiAudioConfig",
+     "AutoModel": "modeling_moonshot_kimia.MoonshotKimiaModel",
+     "AutoModelForCausalLM": "modeling_moonshot_kimia.MoonshotKimiaForCausalLM"
+   },
+   "bos_token_id": 151643,
+   "eos_token_ids": [
+     151644,
+     151645
+   ],
+   "hidden_act": "silu",
+   "hidden_size": 3584,
+   "initializer_range": 0.02,
+   "intermediate_size": 18944,
+   "kimia_adaptor_input_dim": 5120,
+   "kimia_audio_output_vocab": 16896,
+   "kimia_media_begin": 151661,
+   "kimia_media_end": 151663,
+   "kimia_mimo_audiodelaytokens": 5,
+   "kimia_mimo_layers": 6,
+   "kimia_mimo_transformer_from_layer_index": 21,
+   "kimia_text_output_vocab": 152064,
+   "kimia_token_offset": 152064,
+   "num_attention_heads": 28,
+   "num_audio_special_tokens": 512,
+   "num_base_tokens": 151643,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 4,
+   "pad_token_id": 152063,
+   "max_position_embeddings": 8192,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 1000000.0,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.44.1",
+   "use_cache": true,
+   "use_whisper_feature": true,
+   "vocab_size": 168448
+ }
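
The `auto_map` block above is what lets the standard `transformers` auto classes resolve the custom code shipped in this repo (`configuration_moonshot_kimia.py`, `modeling_moonshot_kimia.py`) when `trust_remote_code=True` is passed. A minimal sketch that loads only the config, which is cheap and shows the resolved class; the `KimiAudio` wrapper in the README remains the documented inference path:

```python
# Sketch: resolve the custom config class declared in auto_map.
# Requires transformers; trust_remote_code executes the repo's own .py files.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "moonshotai/Kimi-Audio-7B-Instruct", trust_remote_code=True
)
print(type(config).__name__)      # KimiAudioConfig
print(config.num_hidden_layers)   # 28
print(config.kimia_mimo_layers)   # 6
```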
configuration_moonshot_kimia.py ADDED
@@ -0,0 +1,66 @@
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+
+
+ class KimiAudioConfig(Qwen2Config):
+     def __init__(
+         self,
+         vocab_size=163840,
+         hidden_size=4096,
+         intermediate_size=11008,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=None,
+         hidden_act="silu",
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         tie_word_embeddings=False,
+         kimia_mimo_layers: int = 6,
+         kimia_mimo_audiodelaytokens: int = 5,
+         kimia_mimo_transformer_from_layer_index: int = 21,
+         kimia_audio_output_vocab: int = 16896,
+         kimia_text_output_vocab: int = 152064,
+         num_audio_special_tokens: int = 512,
+         num_base_tokens: int = 151643,
+         kimia_token_offset: int = 152064,
+         use_whisper_feature: bool = True,
+         kimia_adaptor_input_dim: int = 5120,
+         kimia_media_begin: int = 151661,
+         kimia_media_end: int = 151663,
+         **kwargs,
+     ):
+         super().__init__(
+             vocab_size=vocab_size,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             num_hidden_layers=num_hidden_layers,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             hidden_act=hidden_act,
+             initializer_range=initializer_range,
+             rms_norm_eps=rms_norm_eps,
+             use_cache=use_cache,
+             tie_word_embeddings=tie_word_embeddings,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             **kwargs,
+         )
+
+         self.kimia_mimo_layers = kimia_mimo_layers
+         self.kimia_mimo_audiodelaytokens = kimia_mimo_audiodelaytokens
+         # vocab
+         self.kimia_mimo_transformer_from_layer_index = (
+             kimia_mimo_transformer_from_layer_index
+         )
+         self.kimia_audio_output_vocab = kimia_audio_output_vocab
+         self.kimia_text_output_vocab = kimia_text_output_vocab
+         self.num_audio_special_tokens = num_audio_special_tokens
+         self.num_base_tokens = num_base_tokens
+         self.kimia_token_offset = kimia_token_offset
+         self.use_whisper_feature = use_whisper_feature
+         self.kimia_adaptor_input_dim = kimia_adaptor_input_dim
+         # special tokens
+         self.kimia_media_begin = kimia_media_begin
+         self.kimia_media_end = kimia_media_end
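
`KimiAudioConfig` is a thin extension of `Qwen2Config` that adds the `kimia_*` fields consumed by the modeling code. A minimal sketch, assuming this file is importable locally (e.g. from a downloaded checkpoint directory) and purely for illustration, of constructing it with a few values taken from `config.json`:

```python
# Sketch: instantiate KimiAudioConfig directly with values from config.json.
# Illustrative only; in practice the config is loaded via AutoConfig with trust_remote_code.
from configuration_moonshot_kimia import KimiAudioConfig

cfg = KimiAudioConfig(
    vocab_size=168448,
    hidden_size=3584,
    num_hidden_layers=28,
    num_attention_heads=28,
    num_key_value_heads=4,
    kimia_mimo_layers=6,
    kimia_mimo_audiodelaytokens=5,
)
print(cfg.kimia_token_offset)   # 152064 (default)
print(cfg.kimia_media_begin)    # 151661 (default)
```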
generation_config.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "max_length": 8192
+ }
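
This generation config only pins `max_length` to 8192, matching `max_position_embeddings` in `config.json`. A minimal sketch, for illustration only, of how `transformers` would read it; the model's own `generate()` wrapper in the README is the intended interface:

```python
# Sketch: read generation_config.json via transformers' GenerationConfig.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("moonshotai/Kimi-Audio-7B-Instruct")
print(gen_cfg.max_length)  # 8192
```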
model-1-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:462878caab2cc405a9665569e9c4a72191b070f73df5530f6e9c419108a24fe2
3
+ size 466117192
model-10-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:835a9fe443d1118f5ff46d1d35814e6080ff2a9ed578654887638ee853f47ae2
3
+ size 466117192
model-11-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c9de9a823c1b63a4813b4e1d0f6d9c522382981f1889f9423aa077cc91a8f0f
3
+ size 466117208
model-12-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e38dd5adcf335dfe2da28613ab0098570ae9f5fb0f0d137720b5b2458d68d45
3
+ size 466117208
model-13-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:895e02180f74fd811874d200229a549344e6a84a12943aa78d7bf03c3ffb6140
3
+ size 466117208
model-14-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6ebf2a6d19063146565f0e6f7de3bc810a4565a8a24693c85119969b99e2542
3
+ size 466117208
model-15-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76c3fb75b201ccbbfc1f9b909b75924445efe3e09199ac6d13db0857356975fa
3
+ size 466117208
model-16-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9af4933a1ef591f45d2e61fd7cc5b4e00969e59339a74a10a625de85ccb26362
3
+ size 466117208
model-17-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39545d953a1db3ffb63f981e57c2f7d4fd178131c515c58955ece020be96241d
3
+ size 466117208
model-18-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a782c5c67a08571730bbebdb8928dfad48cf8342c38bbe98c6bfa7329edcf7bd
3
+ size 466117208
model-19-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51e850d87b4ece3566404d66c1e27d474549a160b6bc216607bd11fc336e25ee
3
+ size 466117208
model-2-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c20fafb1f647f47084f81095ffcf2c85a55af0665c5b1caa1fb5c0504ee92433
3
+ size 466117192
model-20-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:434af95fd5e5a9e38dfe37c2123de7e55e8676c45ecc3f11ab94c0984588ab40
3
+ size 466117208
model-21-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f1dd7e0d633053d61bdd14f1ff96f3dc27e7e4ee66c707ce4ff4f4bb1490040
3
+ size 466117208
model-22-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e8e22a53c90fd6b68df56de41519c6bce4e0c3df0d75262121c0f0cdfabedf8
3
+ size 466117208
model-23-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f52cf116ea9183f84678aae29bef762e945f0a11b043385e34d150cb4e0930ee
3
+ size 466117208
model-24-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17c2348936c30f4715c6ac406d88d3fcd2ede9b860051b0e43357398b6b2c392
3
+ size 466117208
model-25-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f0f0d3852637374b66bd250abab868a48ca4120d01e70acff617ce751e2f9ce
3
+ size 466117208
model-26-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec675c4b2590c0f3b6bdfcd3afa5292f08c3723f84d816eb30ad85143b47f704
3
+ size 466117208
model-27-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95c4714d295631d3562b25c5168a6551a92450a4c63bb9592eadde18347e2a3c
3
+ size 466117208
model-28-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7c53162402f1a258fc1863a5857bc006b84f934719e6350cab328d5190d1800
3
+ size 466117208
model-29-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02c5b5edc547ab8ff6ac83fec6f10a018ce111cc1ce6f6e3ced4db432b60903f
3
+ size 466117264
model-3-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d11bd527d217b86951248ba3506205b47fdd05ef8afedbd088733b2cff236553
3
+ size 466117192
model-30-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7132a11e4e890532c9e4c30cfaf55904f2f6fcf79b74552e3903ec8f4fe38c1
3
+ size 466117264
model-31-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a836b2ad72a98873872bc7cbfda0ac0811d565dc4303212c0684abfaf83fb50f
3
+ size 466117264
model-32-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd088c37a69da02d73a779caa447f82fb8dc3ad04a1d1590704395f7043737ee
3
+ size 466117264
model-33-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e7ab0dbed449d1ba7ed1f2582e279af6585c505158b1af66a7a77521677b70a
3
+ size 466117264
model-34-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46b117a2fe7dd2c2eb3d2c6fb367bc3813a2d3b2a01bd330d983c52112a32ffd
3
+ size 466117264
model-35-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dbbe9894d818f2751c32a36fe6db1f7e5d0f7d94e9752ee2ba81535ec777e33
3
+ size 62419592
model-36-of-36.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a25c2286a3373471ab4687f3908327ea15e29a909f9cc69eb642b1b47643a2df
3
+ size 3622320648
model-4-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d5fe7491bac40d3b5c5eec23ba449f15e3248d12bdf582881fc65dafebe5a7a
3
+ size 466117192
model-5-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:193972da11808be8be48b74f439c5cd0567d825c84505565d30e208bf3dff0ed
3
+ size 466117192
model-6-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6aabe4e1e987bfc7ecc46933ae4a5f37c6e262411c1d86c663709abe633f3756
3
+ size 466117192
model-7-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0fd3c96b26299dd056c849d969f54f1f89fca3ab480f7afdb57edf437497f48
3
+ size 466117192
model-8-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54f14bd698a405430c529895e167d39827e349a48181544e7e12ff41637f28e8
3
+ size 466117192
model-9-of-35.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb0dc503bf96d67cf5ebb87bd9d899874e2cda9f0c828786e84f60e8c22cbbfc
3
+ size 466117192
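
The sharded weights above are tied together by `model.safetensors.index.json`, shown next, whose `weight_map` maps each tensor name to the shard that stores it. A minimal sketch, assuming the index and shards have already been downloaded to the working directory, of resolving one tensor to its shard and loading it with `safetensors`:

```python
# Sketch: locate and load a single tensor from the sharded checkpoint
# using the weight_map in model.safetensors.index.json (shown below).
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.0.self_attn.q_proj.weight"
shard_file = index["weight_map"][name]   # e.g. "model-1-of-35.safetensors"

with safe_open(shard_file, framework="pt") as shard:
    tensor = shard.get_tensor(name)
print(name, tuple(tensor.shape), tensor.dtype)
```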
model.safetensors.index.json ADDED
@@ -0,0 +1,460 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 19532673280
4
+ },
5
+ "weight_map": {
6
+ "model.layers.0.self_attn.q_proj.weight": "model-1-of-35.safetensors",
7
+ "model.layers.0.self_attn.k_proj.weight": "model-1-of-35.safetensors",
8
+ "model.layers.0.self_attn.v_proj.weight": "model-1-of-35.safetensors",
9
+ "model.layers.0.self_attn.o_proj.weight": "model-1-of-35.safetensors",
10
+ "model.layers.0.self_attn.q_proj.bias": "model-1-of-35.safetensors",
11
+ "model.layers.0.self_attn.k_proj.bias": "model-1-of-35.safetensors",
12
+ "model.layers.0.self_attn.v_proj.bias": "model-1-of-35.safetensors",
13
+ "model.layers.0.input_layernorm.weight": "model-1-of-35.safetensors",
14
+ "model.layers.0.post_attention_layernorm.weight": "model-1-of-35.safetensors",
15
+ "model.layers.0.mlp.gate_proj.weight": "model-1-of-35.safetensors",
16
+ "model.layers.0.mlp.down_proj.weight": "model-1-of-35.safetensors",
17
+ "model.layers.0.mlp.up_proj.weight": "model-1-of-35.safetensors",
18
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "model-1-of-35.safetensors",
19
+ "model.layers.1.self_attn.q_proj.weight": "model-2-of-35.safetensors",
20
+ "model.layers.1.self_attn.k_proj.weight": "model-2-of-35.safetensors",
21
+ "model.layers.1.self_attn.v_proj.weight": "model-2-of-35.safetensors",
22
+ "model.layers.1.self_attn.o_proj.weight": "model-2-of-35.safetensors",
23
+ "model.layers.1.self_attn.q_proj.bias": "model-2-of-35.safetensors",
24
+ "model.layers.1.self_attn.k_proj.bias": "model-2-of-35.safetensors",
25
+ "model.layers.1.self_attn.v_proj.bias": "model-2-of-35.safetensors",
26
+ "model.layers.1.input_layernorm.weight": "model-2-of-35.safetensors",
27
+ "model.layers.1.post_attention_layernorm.weight": "model-2-of-35.safetensors",
28
+ "model.layers.1.mlp.gate_proj.weight": "model-2-of-35.safetensors",
29
+ "model.layers.1.mlp.down_proj.weight": "model-2-of-35.safetensors",
30
+ "model.layers.1.mlp.up_proj.weight": "model-2-of-35.safetensors",
31
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "model-2-of-35.safetensors",
32
+ "model.layers.2.self_attn.q_proj.weight": "model-3-of-35.safetensors",
33
+ "model.layers.2.self_attn.k_proj.weight": "model-3-of-35.safetensors",
34
+ "model.layers.2.self_attn.v_proj.weight": "model-3-of-35.safetensors",
35
+ "model.layers.2.self_attn.o_proj.weight": "model-3-of-35.safetensors",
36
+ "model.layers.2.self_attn.q_proj.bias": "model-3-of-35.safetensors",
37
+ "model.layers.2.self_attn.k_proj.bias": "model-3-of-35.safetensors",
38
+ "model.layers.2.self_attn.v_proj.bias": "model-3-of-35.safetensors",
39
+ "model.layers.2.input_layernorm.weight": "model-3-of-35.safetensors",
40
+ "model.layers.2.post_attention_layernorm.weight": "model-3-of-35.safetensors",
41
+ "model.layers.2.mlp.gate_proj.weight": "model-3-of-35.safetensors",
42
+ "model.layers.2.mlp.down_proj.weight": "model-3-of-35.safetensors",
43
+ "model.layers.2.mlp.up_proj.weight": "model-3-of-35.safetensors",
44
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "model-3-of-35.safetensors",
45
+ "model.layers.3.self_attn.q_proj.weight": "model-4-of-35.safetensors",
46
+ "model.layers.3.self_attn.k_proj.weight": "model-4-of-35.safetensors",
47
+ "model.layers.3.self_attn.v_proj.weight": "model-4-of-35.safetensors",
48
+ "model.layers.3.self_attn.o_proj.weight": "model-4-of-35.safetensors",
49
+ "model.layers.3.self_attn.q_proj.bias": "model-4-of-35.safetensors",
50
+ "model.layers.3.self_attn.k_proj.bias": "model-4-of-35.safetensors",
51
+ "model.layers.3.self_attn.v_proj.bias": "model-4-of-35.safetensors",
52
+ "model.layers.3.input_layernorm.weight": "model-4-of-35.safetensors",
53
+ "model.layers.3.post_attention_layernorm.weight": "model-4-of-35.safetensors",
54
+ "model.layers.3.mlp.gate_proj.weight": "model-4-of-35.safetensors",
55
+ "model.layers.3.mlp.down_proj.weight": "model-4-of-35.safetensors",
56
+ "model.layers.3.mlp.up_proj.weight": "model-4-of-35.safetensors",
57
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "model-4-of-35.safetensors",
58
+ "model.layers.4.self_attn.q_proj.weight": "model-5-of-35.safetensors",
59
+ "model.layers.4.self_attn.k_proj.weight": "model-5-of-35.safetensors",
60
+ "model.layers.4.self_attn.v_proj.weight": "model-5-of-35.safetensors",
61
+ "model.layers.4.self_attn.o_proj.weight": "model-5-of-35.safetensors",
62
+ "model.layers.4.self_attn.q_proj.bias": "model-5-of-35.safetensors",
63
+ "model.layers.4.self_attn.k_proj.bias": "model-5-of-35.safetensors",
64
+ "model.layers.4.self_attn.v_proj.bias": "model-5-of-35.safetensors",
65
+ "model.layers.4.input_layernorm.weight": "model-5-of-35.safetensors",
66
+ "model.layers.4.post_attention_layernorm.weight": "model-5-of-35.safetensors",
67
+ "model.layers.4.mlp.gate_proj.weight": "model-5-of-35.safetensors",
68
+ "model.layers.4.mlp.down_proj.weight": "model-5-of-35.safetensors",
69
+ "model.layers.4.mlp.up_proj.weight": "model-5-of-35.safetensors",
70
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "model-5-of-35.safetensors",
71
+ "model.layers.5.self_attn.q_proj.weight": "model-6-of-35.safetensors",
72
+ "model.layers.5.self_attn.k_proj.weight": "model-6-of-35.safetensors",
73
+ "model.layers.5.self_attn.v_proj.weight": "model-6-of-35.safetensors",
74
+ "model.layers.5.self_attn.o_proj.weight": "model-6-of-35.safetensors",
75
+ "model.layers.5.self_attn.q_proj.bias": "model-6-of-35.safetensors",
76
+ "model.layers.5.self_attn.k_proj.bias": "model-6-of-35.safetensors",
77
+ "model.layers.5.self_attn.v_proj.bias": "model-6-of-35.safetensors",
78
+ "model.layers.5.input_layernorm.weight": "model-6-of-35.safetensors",
79
+ "model.layers.5.post_attention_layernorm.weight": "model-6-of-35.safetensors",
80
+ "model.layers.5.mlp.gate_proj.weight": "model-6-of-35.safetensors",
81
+ "model.layers.5.mlp.down_proj.weight": "model-6-of-35.safetensors",
82
+ "model.layers.5.mlp.up_proj.weight": "model-6-of-35.safetensors",
83
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "model-6-of-35.safetensors",
84
+ "model.layers.6.self_attn.q_proj.weight": "model-7-of-35.safetensors",
85
+ "model.layers.6.self_attn.k_proj.weight": "model-7-of-35.safetensors",
86
+ "model.layers.6.self_attn.v_proj.weight": "model-7-of-35.safetensors",
87
+ "model.layers.6.self_attn.o_proj.weight": "model-7-of-35.safetensors",
88
+ "model.layers.6.self_attn.q_proj.bias": "model-7-of-35.safetensors",
89
+ "model.layers.6.self_attn.k_proj.bias": "model-7-of-35.safetensors",
90
+ "model.layers.6.self_attn.v_proj.bias": "model-7-of-35.safetensors",
91
+ "model.layers.6.input_layernorm.weight": "model-7-of-35.safetensors",
92
+ "model.layers.6.post_attention_layernorm.weight": "model-7-of-35.safetensors",
93
+ "model.layers.6.mlp.gate_proj.weight": "model-7-of-35.safetensors",
94
+ "model.layers.6.mlp.down_proj.weight": "model-7-of-35.safetensors",
95
+ "model.layers.6.mlp.up_proj.weight": "model-7-of-35.safetensors",
96
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "model-7-of-35.safetensors",
97
+ "model.layers.7.self_attn.q_proj.weight": "model-8-of-35.safetensors",
98
+ "model.layers.7.self_attn.k_proj.weight": "model-8-of-35.safetensors",
99
+ "model.layers.7.self_attn.v_proj.weight": "model-8-of-35.safetensors",
100
+ "model.layers.7.self_attn.o_proj.weight": "model-8-of-35.safetensors",
101
+ "model.layers.7.self_attn.q_proj.bias": "model-8-of-35.safetensors",
102
+ "model.layers.7.self_attn.k_proj.bias": "model-8-of-35.safetensors",
103
+ "model.layers.7.self_attn.v_proj.bias": "model-8-of-35.safetensors",
104
+ "model.layers.7.input_layernorm.weight": "model-8-of-35.safetensors",
105
+ "model.layers.7.post_attention_layernorm.weight": "model-8-of-35.safetensors",
106
+ "model.layers.7.mlp.gate_proj.weight": "model-8-of-35.safetensors",
107
+ "model.layers.7.mlp.down_proj.weight": "model-8-of-35.safetensors",
108
+ "model.layers.7.mlp.up_proj.weight": "model-8-of-35.safetensors",
109
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "model-8-of-35.safetensors",
110
+ "model.layers.8.self_attn.q_proj.weight": "model-9-of-35.safetensors",
111
+ "model.layers.8.self_attn.k_proj.weight": "model-9-of-35.safetensors",
112
+ "model.layers.8.self_attn.v_proj.weight": "model-9-of-35.safetensors",
113
+ "model.layers.8.self_attn.o_proj.weight": "model-9-of-35.safetensors",
114
+ "model.layers.8.self_attn.q_proj.bias": "model-9-of-35.safetensors",
115
+ "model.layers.8.self_attn.k_proj.bias": "model-9-of-35.safetensors",
116
+ "model.layers.8.self_attn.v_proj.bias": "model-9-of-35.safetensors",
117
+ "model.layers.8.input_layernorm.weight": "model-9-of-35.safetensors",
118
+ "model.layers.8.post_attention_layernorm.weight": "model-9-of-35.safetensors",
119
+ "model.layers.8.mlp.gate_proj.weight": "model-9-of-35.safetensors",
120
+ "model.layers.8.mlp.down_proj.weight": "model-9-of-35.safetensors",
121
+ "model.layers.8.mlp.up_proj.weight": "model-9-of-35.safetensors",
122
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "model-9-of-35.safetensors",
123
+ "model.layers.9.self_attn.q_proj.weight": "model-10-of-35.safetensors",
124
+ "model.layers.9.self_attn.k_proj.weight": "model-10-of-35.safetensors",
125
+ "model.layers.9.self_attn.v_proj.weight": "model-10-of-35.safetensors",
126
+ "model.layers.9.self_attn.o_proj.weight": "model-10-of-35.safetensors",
127
+ "model.layers.9.self_attn.q_proj.bias": "model-10-of-35.safetensors",
128
+ "model.layers.9.self_attn.k_proj.bias": "model-10-of-35.safetensors",
129
+ "model.layers.9.self_attn.v_proj.bias": "model-10-of-35.safetensors",
130
+ "model.layers.9.input_layernorm.weight": "model-10-of-35.safetensors",
131
+ "model.layers.9.post_attention_layernorm.weight": "model-10-of-35.safetensors",
132
+ "model.layers.9.mlp.gate_proj.weight": "model-10-of-35.safetensors",
133
+ "model.layers.9.mlp.down_proj.weight": "model-10-of-35.safetensors",
134
+ "model.layers.9.mlp.up_proj.weight": "model-10-of-35.safetensors",
135
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "model-10-of-35.safetensors",
136
+ "model.layers.10.self_attn.q_proj.weight": "model-11-of-35.safetensors",
137
+ "model.layers.10.self_attn.k_proj.weight": "model-11-of-35.safetensors",
138
+ "model.layers.10.self_attn.v_proj.weight": "model-11-of-35.safetensors",
139
+ "model.layers.10.self_attn.o_proj.weight": "model-11-of-35.safetensors",
140
+ "model.layers.10.self_attn.q_proj.bias": "model-11-of-35.safetensors",
141
+ "model.layers.10.self_attn.k_proj.bias": "model-11-of-35.safetensors",
142
+ "model.layers.10.self_attn.v_proj.bias": "model-11-of-35.safetensors",
143
+ "model.layers.10.input_layernorm.weight": "model-11-of-35.safetensors",
144
+ "model.layers.10.post_attention_layernorm.weight": "model-11-of-35.safetensors",
145
+ "model.layers.10.mlp.gate_proj.weight": "model-11-of-35.safetensors",
146
+ "model.layers.10.mlp.down_proj.weight": "model-11-of-35.safetensors",
147
+ "model.layers.10.mlp.up_proj.weight": "model-11-of-35.safetensors",
148
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "model-11-of-35.safetensors",
149
+ "model.layers.11.self_attn.q_proj.weight": "model-12-of-35.safetensors",
150
+ "model.layers.11.self_attn.k_proj.weight": "model-12-of-35.safetensors",
151
+ "model.layers.11.self_attn.v_proj.weight": "model-12-of-35.safetensors",
152
+ "model.layers.11.self_attn.o_proj.weight": "model-12-of-35.safetensors",
153
+ "model.layers.11.self_attn.q_proj.bias": "model-12-of-35.safetensors",
154
+ "model.layers.11.self_attn.k_proj.bias": "model-12-of-35.safetensors",
155
+ "model.layers.11.self_attn.v_proj.bias": "model-12-of-35.safetensors",
156
+ "model.layers.11.input_layernorm.weight": "model-12-of-35.safetensors",
157
+ "model.layers.11.post_attention_layernorm.weight": "model-12-of-35.safetensors",
158
+ "model.layers.11.mlp.gate_proj.weight": "model-12-of-35.safetensors",
159
+ "model.layers.11.mlp.down_proj.weight": "model-12-of-35.safetensors",
160
+ "model.layers.11.mlp.up_proj.weight": "model-12-of-35.safetensors",
161
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "model-12-of-35.safetensors",
162
+ "model.layers.12.self_attn.q_proj.weight": "model-13-of-35.safetensors",
163
+ "model.layers.12.self_attn.k_proj.weight": "model-13-of-35.safetensors",
164
+ "model.layers.12.self_attn.v_proj.weight": "model-13-of-35.safetensors",
165
+ "model.layers.12.self_attn.o_proj.weight": "model-13-of-35.safetensors",
166
+ "model.layers.12.self_attn.q_proj.bias": "model-13-of-35.safetensors",
167
+ "model.layers.12.self_attn.k_proj.bias": "model-13-of-35.safetensors",
168
+ "model.layers.12.self_attn.v_proj.bias": "model-13-of-35.safetensors",
169
+ "model.layers.12.input_layernorm.weight": "model-13-of-35.safetensors",
170
+ "model.layers.12.post_attention_layernorm.weight": "model-13-of-35.safetensors",
171
+ "model.layers.12.mlp.gate_proj.weight": "model-13-of-35.safetensors",
172
+ "model.layers.12.mlp.down_proj.weight": "model-13-of-35.safetensors",
173
+ "model.layers.12.mlp.up_proj.weight": "model-13-of-35.safetensors",
174
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "model-13-of-35.safetensors",
175
+ "model.layers.13.self_attn.q_proj.weight": "model-14-of-35.safetensors",
176
+ "model.layers.13.self_attn.k_proj.weight": "model-14-of-35.safetensors",
177
+ "model.layers.13.self_attn.v_proj.weight": "model-14-of-35.safetensors",
178
+ "model.layers.13.self_attn.o_proj.weight": "model-14-of-35.safetensors",
179
+ "model.layers.13.self_attn.q_proj.bias": "model-14-of-35.safetensors",
180
+ "model.layers.13.self_attn.k_proj.bias": "model-14-of-35.safetensors",
181
+ "model.layers.13.self_attn.v_proj.bias": "model-14-of-35.safetensors",
182
+ "model.layers.13.input_layernorm.weight": "model-14-of-35.safetensors",
183
+ "model.layers.13.post_attention_layernorm.weight": "model-14-of-35.safetensors",
184
+ "model.layers.13.mlp.gate_proj.weight": "model-14-of-35.safetensors",
185
+ "model.layers.13.mlp.down_proj.weight": "model-14-of-35.safetensors",
186
+ "model.layers.13.mlp.up_proj.weight": "model-14-of-35.safetensors",
187
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "model-14-of-35.safetensors",
188
+ "model.layers.14.self_attn.q_proj.weight": "model-15-of-35.safetensors",
189
+ "model.layers.14.self_attn.k_proj.weight": "model-15-of-35.safetensors",
190
+ "model.layers.14.self_attn.v_proj.weight": "model-15-of-35.safetensors",
191
+ "model.layers.14.self_attn.o_proj.weight": "model-15-of-35.safetensors",
192
+ "model.layers.14.self_attn.q_proj.bias": "model-15-of-35.safetensors",
193
+ "model.layers.14.self_attn.k_proj.bias": "model-15-of-35.safetensors",
194
+ "model.layers.14.self_attn.v_proj.bias": "model-15-of-35.safetensors",
195
+ "model.layers.14.input_layernorm.weight": "model-15-of-35.safetensors",
196
+ "model.layers.14.post_attention_layernorm.weight": "model-15-of-35.safetensors",
197
+ "model.layers.14.mlp.gate_proj.weight": "model-15-of-35.safetensors",
198
+ "model.layers.14.mlp.down_proj.weight": "model-15-of-35.safetensors",
199
+ "model.layers.14.mlp.up_proj.weight": "model-15-of-35.safetensors",
200
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "model-15-of-35.safetensors",
201
+ "model.layers.15.self_attn.q_proj.weight": "model-16-of-35.safetensors",
202
+ "model.layers.15.self_attn.k_proj.weight": "model-16-of-35.safetensors",
203
+ "model.layers.15.self_attn.v_proj.weight": "model-16-of-35.safetensors",
204
+ "model.layers.15.self_attn.o_proj.weight": "model-16-of-35.safetensors",
205
+ "model.layers.15.self_attn.q_proj.bias": "model-16-of-35.safetensors",
206
+ "model.layers.15.self_attn.k_proj.bias": "model-16-of-35.safetensors",
207
+ "model.layers.15.self_attn.v_proj.bias": "model-16-of-35.safetensors",
208
+ "model.layers.15.input_layernorm.weight": "model-16-of-35.safetensors",
209
+ "model.layers.15.post_attention_layernorm.weight": "model-16-of-35.safetensors",
210
+ "model.layers.15.mlp.gate_proj.weight": "model-16-of-35.safetensors",
211
+ "model.layers.15.mlp.down_proj.weight": "model-16-of-35.safetensors",
212
+ "model.layers.15.mlp.up_proj.weight": "model-16-of-35.safetensors",
213
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "model-16-of-35.safetensors",
214
+ "model.layers.16.self_attn.q_proj.weight": "model-17-of-35.safetensors",
215
+ "model.layers.16.self_attn.k_proj.weight": "model-17-of-35.safetensors",
216
+ "model.layers.16.self_attn.v_proj.weight": "model-17-of-35.safetensors",
217
+ "model.layers.16.self_attn.o_proj.weight": "model-17-of-35.safetensors",
218
+ "model.layers.16.self_attn.q_proj.bias": "model-17-of-35.safetensors",
219
+ "model.layers.16.self_attn.k_proj.bias": "model-17-of-35.safetensors",
220
+ "model.layers.16.self_attn.v_proj.bias": "model-17-of-35.safetensors",
221
+ "model.layers.16.input_layernorm.weight": "model-17-of-35.safetensors",
222
+ "model.layers.16.post_attention_layernorm.weight": "model-17-of-35.safetensors",
223
+ "model.layers.16.mlp.gate_proj.weight": "model-17-of-35.safetensors",
224
+ "model.layers.16.mlp.down_proj.weight": "model-17-of-35.safetensors",
225
+ "model.layers.16.mlp.up_proj.weight": "model-17-of-35.safetensors",
226
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "model-17-of-35.safetensors",
227
+ "model.layers.17.self_attn.q_proj.weight": "model-18-of-35.safetensors",
228
+ "model.layers.17.self_attn.k_proj.weight": "model-18-of-35.safetensors",
229
+ "model.layers.17.self_attn.v_proj.weight": "model-18-of-35.safetensors",
230
+ "model.layers.17.self_attn.o_proj.weight": "model-18-of-35.safetensors",
231
+ "model.layers.17.self_attn.q_proj.bias": "model-18-of-35.safetensors",
232
+ "model.layers.17.self_attn.k_proj.bias": "model-18-of-35.safetensors",
233
+ "model.layers.17.self_attn.v_proj.bias": "model-18-of-35.safetensors",
234
+ "model.layers.17.input_layernorm.weight": "model-18-of-35.safetensors",
235
+ "model.layers.17.post_attention_layernorm.weight": "model-18-of-35.safetensors",
236
+ "model.layers.17.mlp.gate_proj.weight": "model-18-of-35.safetensors",
237
+ "model.layers.17.mlp.down_proj.weight": "model-18-of-35.safetensors",
238
+ "model.layers.17.mlp.up_proj.weight": "model-18-of-35.safetensors",
239
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "model-18-of-35.safetensors",
240
+ "model.layers.18.self_attn.q_proj.weight": "model-19-of-35.safetensors",
241
+ "model.layers.18.self_attn.k_proj.weight": "model-19-of-35.safetensors",
242
+ "model.layers.18.self_attn.v_proj.weight": "model-19-of-35.safetensors",
243
+ "model.layers.18.self_attn.o_proj.weight": "model-19-of-35.safetensors",
244
+ "model.layers.18.self_attn.q_proj.bias": "model-19-of-35.safetensors",
245
+ "model.layers.18.self_attn.k_proj.bias": "model-19-of-35.safetensors",
246
+ "model.layers.18.self_attn.v_proj.bias": "model-19-of-35.safetensors",
247
+ "model.layers.18.input_layernorm.weight": "model-19-of-35.safetensors",
248
+ "model.layers.18.post_attention_layernorm.weight": "model-19-of-35.safetensors",
249
+ "model.layers.18.mlp.gate_proj.weight": "model-19-of-35.safetensors",
250
+ "model.layers.18.mlp.down_proj.weight": "model-19-of-35.safetensors",
251
+ "model.layers.18.mlp.up_proj.weight": "model-19-of-35.safetensors",
252
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "model-19-of-35.safetensors",
253
+ "model.layers.19.self_attn.q_proj.weight": "model-20-of-35.safetensors",
254
+ "model.layers.19.self_attn.k_proj.weight": "model-20-of-35.safetensors",
255
+ "model.layers.19.self_attn.v_proj.weight": "model-20-of-35.safetensors",
256
+ "model.layers.19.self_attn.o_proj.weight": "model-20-of-35.safetensors",
257
+ "model.layers.19.self_attn.q_proj.bias": "model-20-of-35.safetensors",
258
+ "model.layers.19.self_attn.k_proj.bias": "model-20-of-35.safetensors",
259
+ "model.layers.19.self_attn.v_proj.bias": "model-20-of-35.safetensors",
260
+ "model.layers.19.input_layernorm.weight": "model-20-of-35.safetensors",
261
+ "model.layers.19.post_attention_layernorm.weight": "model-20-of-35.safetensors",
262
+ "model.layers.19.mlp.gate_proj.weight": "model-20-of-35.safetensors",
263
+ "model.layers.19.mlp.down_proj.weight": "model-20-of-35.safetensors",
264
+ "model.layers.19.mlp.up_proj.weight": "model-20-of-35.safetensors",
265
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "model-20-of-35.safetensors",
266
+ "model.layers.20.self_attn.q_proj.weight": "model-21-of-35.safetensors",
267
+ "model.layers.20.self_attn.k_proj.weight": "model-21-of-35.safetensors",
268
+ "model.layers.20.self_attn.v_proj.weight": "model-21-of-35.safetensors",
269
+ "model.layers.20.self_attn.o_proj.weight": "model-21-of-35.safetensors",
270
+ "model.layers.20.self_attn.q_proj.bias": "model-21-of-35.safetensors",
271
+ "model.layers.20.self_attn.k_proj.bias": "model-21-of-35.safetensors",
272
+ "model.layers.20.self_attn.v_proj.bias": "model-21-of-35.safetensors",
273
+ "model.layers.20.input_layernorm.weight": "model-21-of-35.safetensors",
274
+ "model.layers.20.post_attention_layernorm.weight": "model-21-of-35.safetensors",
275
+ "model.layers.20.mlp.gate_proj.weight": "model-21-of-35.safetensors",
276
+ "model.layers.20.mlp.down_proj.weight": "model-21-of-35.safetensors",
277
+ "model.layers.20.mlp.up_proj.weight": "model-21-of-35.safetensors",
278
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "model-21-of-35.safetensors",
279
+ "model.layers.21.self_attn.q_proj.weight": "model-22-of-35.safetensors",
280
+ "model.layers.21.self_attn.k_proj.weight": "model-22-of-35.safetensors",
281
+ "model.layers.21.self_attn.v_proj.weight": "model-22-of-35.safetensors",
282
+ "model.layers.21.self_attn.o_proj.weight": "model-22-of-35.safetensors",
283
+ "model.layers.21.self_attn.q_proj.bias": "model-22-of-35.safetensors",
284
+ "model.layers.21.self_attn.k_proj.bias": "model-22-of-35.safetensors",
285
+ "model.layers.21.self_attn.v_proj.bias": "model-22-of-35.safetensors",
286
+ "model.layers.21.input_layernorm.weight": "model-22-of-35.safetensors",
287
+ "model.layers.21.post_attention_layernorm.weight": "model-22-of-35.safetensors",
288
+ "model.layers.21.mlp.gate_proj.weight": "model-22-of-35.safetensors",
289
+ "model.layers.21.mlp.down_proj.weight": "model-22-of-35.safetensors",
290
+ "model.layers.21.mlp.up_proj.weight": "model-22-of-35.safetensors",
291
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "model-22-of-35.safetensors",
292
+ "model.layers.22.self_attn.q_proj.weight": "model-23-of-35.safetensors",
293
+ "model.layers.22.self_attn.k_proj.weight": "model-23-of-35.safetensors",
294
+ "model.layers.22.self_attn.v_proj.weight": "model-23-of-35.safetensors",
295
+ "model.layers.22.self_attn.o_proj.weight": "model-23-of-35.safetensors",
296
+ "model.layers.22.self_attn.q_proj.bias": "model-23-of-35.safetensors",
297
+ "model.layers.22.self_attn.k_proj.bias": "model-23-of-35.safetensors",
298
+ "model.layers.22.self_attn.v_proj.bias": "model-23-of-35.safetensors",
299
+ "model.layers.22.input_layernorm.weight": "model-23-of-35.safetensors",
300
+ "model.layers.22.post_attention_layernorm.weight": "model-23-of-35.safetensors",
301
+ "model.layers.22.mlp.gate_proj.weight": "model-23-of-35.safetensors",
302
+ "model.layers.22.mlp.down_proj.weight": "model-23-of-35.safetensors",
303
+ "model.layers.22.mlp.up_proj.weight": "model-23-of-35.safetensors",
304
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "model-23-of-35.safetensors",
305
+ "model.layers.23.self_attn.q_proj.weight": "model-24-of-35.safetensors",
306
+ "model.layers.23.self_attn.k_proj.weight": "model-24-of-35.safetensors",
307
+ "model.layers.23.self_attn.v_proj.weight": "model-24-of-35.safetensors",
308
+ "model.layers.23.self_attn.o_proj.weight": "model-24-of-35.safetensors",
309
+ "model.layers.23.self_attn.q_proj.bias": "model-24-of-35.safetensors",
310
+ "model.layers.23.self_attn.k_proj.bias": "model-24-of-35.safetensors",
311
+ "model.layers.23.self_attn.v_proj.bias": "model-24-of-35.safetensors",
312
+ "model.layers.23.input_layernorm.weight": "model-24-of-35.safetensors",
313
+ "model.layers.23.post_attention_layernorm.weight": "model-24-of-35.safetensors",
314
+ "model.layers.23.mlp.gate_proj.weight": "model-24-of-35.safetensors",
315
+ "model.layers.23.mlp.down_proj.weight": "model-24-of-35.safetensors",
316
+ "model.layers.23.mlp.up_proj.weight": "model-24-of-35.safetensors",
317
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "model-24-of-35.safetensors",
318
+ "model.layers.24.self_attn.q_proj.weight": "model-25-of-35.safetensors",
319
+ "model.layers.24.self_attn.k_proj.weight": "model-25-of-35.safetensors",
320
+ "model.layers.24.self_attn.v_proj.weight": "model-25-of-35.safetensors",
321
+ "model.layers.24.self_attn.o_proj.weight": "model-25-of-35.safetensors",
322
+ "model.layers.24.self_attn.q_proj.bias": "model-25-of-35.safetensors",
323
+ "model.layers.24.self_attn.k_proj.bias": "model-25-of-35.safetensors",
324
+ "model.layers.24.self_attn.v_proj.bias": "model-25-of-35.safetensors",
325
+ "model.layers.24.input_layernorm.weight": "model-25-of-35.safetensors",
326
+ "model.layers.24.post_attention_layernorm.weight": "model-25-of-35.safetensors",
327
+ "model.layers.24.mlp.gate_proj.weight": "model-25-of-35.safetensors",
328
+ "model.layers.24.mlp.down_proj.weight": "model-25-of-35.safetensors",
329
+ "model.layers.24.mlp.up_proj.weight": "model-25-of-35.safetensors",
330
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "model-25-of-35.safetensors",
331
+ "model.layers.25.self_attn.q_proj.weight": "model-26-of-35.safetensors",
332
+ "model.layers.25.self_attn.k_proj.weight": "model-26-of-35.safetensors",
333
+ "model.layers.25.self_attn.v_proj.weight": "model-26-of-35.safetensors",
334
+ "model.layers.25.self_attn.o_proj.weight": "model-26-of-35.safetensors",
335
+ "model.layers.25.self_attn.q_proj.bias": "model-26-of-35.safetensors",
336
+ "model.layers.25.self_attn.k_proj.bias": "model-26-of-35.safetensors",
337
+ "model.layers.25.self_attn.v_proj.bias": "model-26-of-35.safetensors",
338
+ "model.layers.25.input_layernorm.weight": "model-26-of-35.safetensors",
339
+ "model.layers.25.post_attention_layernorm.weight": "model-26-of-35.safetensors",
340
+ "model.layers.25.mlp.gate_proj.weight": "model-26-of-35.safetensors",
341
+ "model.layers.25.mlp.down_proj.weight": "model-26-of-35.safetensors",
342
+ "model.layers.25.mlp.up_proj.weight": "model-26-of-35.safetensors",
343
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "model-26-of-35.safetensors",
344
+ "model.layers.26.self_attn.q_proj.weight": "model-27-of-35.safetensors",
345
+ "model.layers.26.self_attn.k_proj.weight": "model-27-of-35.safetensors",
346
+ "model.layers.26.self_attn.v_proj.weight": "model-27-of-35.safetensors",
347
+ "model.layers.26.self_attn.o_proj.weight": "model-27-of-35.safetensors",
348
+ "model.layers.26.self_attn.q_proj.bias": "model-27-of-35.safetensors",
349
+ "model.layers.26.self_attn.k_proj.bias": "model-27-of-35.safetensors",
350
+ "model.layers.26.self_attn.v_proj.bias": "model-27-of-35.safetensors",
351
+ "model.layers.26.input_layernorm.weight": "model-27-of-35.safetensors",
352
+ "model.layers.26.post_attention_layernorm.weight": "model-27-of-35.safetensors",
353
+ "model.layers.26.mlp.gate_proj.weight": "model-27-of-35.safetensors",
354
+ "model.layers.26.mlp.down_proj.weight": "model-27-of-35.safetensors",
355
+ "model.layers.26.mlp.up_proj.weight": "model-27-of-35.safetensors",
356
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "model-27-of-35.safetensors",
357
+ "model.layers.27.self_attn.q_proj.weight": "model-28-of-35.safetensors",
358
+ "model.layers.27.self_attn.k_proj.weight": "model-28-of-35.safetensors",
359
+ "model.layers.27.self_attn.v_proj.weight": "model-28-of-35.safetensors",
360
+ "model.layers.27.self_attn.o_proj.weight": "model-28-of-35.safetensors",
361
+ "model.layers.27.self_attn.q_proj.bias": "model-28-of-35.safetensors",
362
+ "model.layers.27.self_attn.k_proj.bias": "model-28-of-35.safetensors",
363
+ "model.layers.27.self_attn.v_proj.bias": "model-28-of-35.safetensors",
364
+ "model.layers.27.input_layernorm.weight": "model-28-of-35.safetensors",
365
+ "model.layers.27.post_attention_layernorm.weight": "model-28-of-35.safetensors",
366
+ "model.layers.27.mlp.gate_proj.weight": "model-28-of-35.safetensors",
367
+ "model.layers.27.mlp.down_proj.weight": "model-28-of-35.safetensors",
368
+ "model.layers.27.mlp.up_proj.weight": "model-28-of-35.safetensors",
369
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "model-28-of-35.safetensors",
370
+ "model.mimo_layers.0.self_attn.q_proj.weight": "model-29-of-35.safetensors",
371
+ "model.mimo_layers.0.self_attn.k_proj.weight": "model-29-of-35.safetensors",
372
+ "model.mimo_layers.0.self_attn.v_proj.weight": "model-29-of-35.safetensors",
373
+ "model.mimo_layers.0.self_attn.o_proj.weight": "model-29-of-35.safetensors",
374
+ "model.mimo_layers.0.input_layernorm.weight": "model-29-of-35.safetensors",
375
+ "model.mimo_layers.0.post_attention_layernorm.weight": "model-29-of-35.safetensors",
376
+ "model.mimo_layers.0.mlp.gate_proj.weight": "model-29-of-35.safetensors",
377
+ "model.mimo_layers.0.mlp.down_proj.weight": "model-29-of-35.safetensors",
378
+ "model.mimo_layers.0.mlp.up_proj.weight": "model-29-of-35.safetensors",
379
+ "model.mimo_layers.0.self_attn.q_proj.bias": "model-29-of-35.safetensors",
380
+ "model.mimo_layers.0.self_attn.k_proj.bias": "model-29-of-35.safetensors",
381
+ "model.mimo_layers.0.self_attn.v_proj.bias": "model-29-of-35.safetensors",
382
+ "model.mimo_layers.0.self_attn.rotary_emb.inv_freq": "model-29-of-35.safetensors",
383
+ "model.mimo_layers.1.self_attn.q_proj.weight": "model-30-of-35.safetensors",
384
+ "model.mimo_layers.1.self_attn.k_proj.weight": "model-30-of-35.safetensors",
385
+ "model.mimo_layers.1.self_attn.v_proj.weight": "model-30-of-35.safetensors",
386
+ "model.mimo_layers.1.self_attn.o_proj.weight": "model-30-of-35.safetensors",
387
+ "model.mimo_layers.1.input_layernorm.weight": "model-30-of-35.safetensors",
388
+ "model.mimo_layers.1.post_attention_layernorm.weight": "model-30-of-35.safetensors",
389
+ "model.mimo_layers.1.mlp.gate_proj.weight": "model-30-of-35.safetensors",
390
+ "model.mimo_layers.1.mlp.down_proj.weight": "model-30-of-35.safetensors",
391
+ "model.mimo_layers.1.mlp.up_proj.weight": "model-30-of-35.safetensors",
392
+ "model.mimo_layers.1.self_attn.q_proj.bias": "model-30-of-35.safetensors",
393
+ "model.mimo_layers.1.self_attn.k_proj.bias": "model-30-of-35.safetensors",
394
+ "model.mimo_layers.1.self_attn.v_proj.bias": "model-30-of-35.safetensors",
395
+ "model.mimo_layers.1.self_attn.rotary_emb.inv_freq": "model-30-of-35.safetensors",
396
+ "model.mimo_layers.2.self_attn.q_proj.weight": "model-31-of-35.safetensors",
397
+ "model.mimo_layers.2.self_attn.k_proj.weight": "model-31-of-35.safetensors",
398
+ "model.mimo_layers.2.self_attn.v_proj.weight": "model-31-of-35.safetensors",
399
+ "model.mimo_layers.2.self_attn.o_proj.weight": "model-31-of-35.safetensors",
400
+ "model.mimo_layers.2.input_layernorm.weight": "model-31-of-35.safetensors",
401
+ "model.mimo_layers.2.post_attention_layernorm.weight": "model-31-of-35.safetensors",
402
+ "model.mimo_layers.2.mlp.gate_proj.weight": "model-31-of-35.safetensors",
403
+ "model.mimo_layers.2.mlp.down_proj.weight": "model-31-of-35.safetensors",
404
+ "model.mimo_layers.2.mlp.up_proj.weight": "model-31-of-35.safetensors",
405
+ "model.mimo_layers.2.self_attn.q_proj.bias": "model-31-of-35.safetensors",
406
+ "model.mimo_layers.2.self_attn.k_proj.bias": "model-31-of-35.safetensors",
407
+ "model.mimo_layers.2.self_attn.v_proj.bias": "model-31-of-35.safetensors",
408
+ "model.mimo_layers.2.self_attn.rotary_emb.inv_freq": "model-31-of-35.safetensors",
409
+ "model.mimo_layers.3.self_attn.q_proj.weight": "model-32-of-35.safetensors",
410
+ "model.mimo_layers.3.self_attn.k_proj.weight": "model-32-of-35.safetensors",
411
+ "model.mimo_layers.3.self_attn.v_proj.weight": "model-32-of-35.safetensors",
412
+ "model.mimo_layers.3.self_attn.o_proj.weight": "model-32-of-35.safetensors",
413
+ "model.mimo_layers.3.input_layernorm.weight": "model-32-of-35.safetensors",
414
+ "model.mimo_layers.3.post_attention_layernorm.weight": "model-32-of-35.safetensors",
415
+ "model.mimo_layers.3.mlp.gate_proj.weight": "model-32-of-35.safetensors",
416
+ "model.mimo_layers.3.mlp.down_proj.weight": "model-32-of-35.safetensors",
417
+ "model.mimo_layers.3.mlp.up_proj.weight": "model-32-of-35.safetensors",
418
+ "model.mimo_layers.3.self_attn.q_proj.bias": "model-32-of-35.safetensors",
419
+ "model.mimo_layers.3.self_attn.k_proj.bias": "model-32-of-35.safetensors",
420
+ "model.mimo_layers.3.self_attn.v_proj.bias": "model-32-of-35.safetensors",
421
+ "model.mimo_layers.3.self_attn.rotary_emb.inv_freq": "model-32-of-35.safetensors",
422
+ "model.mimo_layers.4.self_attn.q_proj.weight": "model-33-of-35.safetensors",
423
+ "model.mimo_layers.4.self_attn.k_proj.weight": "model-33-of-35.safetensors",
424
+ "model.mimo_layers.4.self_attn.v_proj.weight": "model-33-of-35.safetensors",
425
+ "model.mimo_layers.4.self_attn.o_proj.weight": "model-33-of-35.safetensors",
426
+ "model.mimo_layers.4.input_layernorm.weight": "model-33-of-35.safetensors",
427
+ "model.mimo_layers.4.post_attention_layernorm.weight": "model-33-of-35.safetensors",
428
+ "model.mimo_layers.4.mlp.gate_proj.weight": "model-33-of-35.safetensors",
429
+ "model.mimo_layers.4.mlp.down_proj.weight": "model-33-of-35.safetensors",
430
+ "model.mimo_layers.4.mlp.up_proj.weight": "model-33-of-35.safetensors",
431
+ "model.mimo_layers.4.self_attn.q_proj.bias": "model-33-of-35.safetensors",
432
+ "model.mimo_layers.4.self_attn.k_proj.bias": "model-33-of-35.safetensors",
433
+ "model.mimo_layers.4.self_attn.v_proj.bias": "model-33-of-35.safetensors",
434
+ "model.mimo_layers.4.self_attn.rotary_emb.inv_freq": "model-33-of-35.safetensors",
435
+ "model.mimo_layers.5.self_attn.q_proj.weight": "model-34-of-35.safetensors",
436
+ "model.mimo_layers.5.self_attn.k_proj.weight": "model-34-of-35.safetensors",
437
+ "model.mimo_layers.5.self_attn.v_proj.weight": "model-34-of-35.safetensors",
438
+ "model.mimo_layers.5.self_attn.o_proj.weight": "model-34-of-35.safetensors",
439
+ "model.mimo_layers.5.input_layernorm.weight": "model-34-of-35.safetensors",
440
+ "model.mimo_layers.5.post_attention_layernorm.weight": "model-34-of-35.safetensors",
441
+ "model.mimo_layers.5.mlp.gate_proj.weight": "model-34-of-35.safetensors",
442
+ "model.mimo_layers.5.mlp.down_proj.weight": "model-34-of-35.safetensors",
443
+ "model.mimo_layers.5.mlp.up_proj.weight": "model-34-of-35.safetensors",
444
+ "model.mimo_layers.5.self_attn.q_proj.bias": "model-34-of-35.safetensors",
445
+ "model.mimo_layers.5.self_attn.k_proj.bias": "model-34-of-35.safetensors",
446
+ "model.mimo_layers.5.self_attn.v_proj.bias": "model-34-of-35.safetensors",
447
+ "model.mimo_layers.5.self_attn.rotary_emb.inv_freq": "model-34-of-35.safetensors",
448
+ "model.vq_adaptor.layers.0.weight": "model-35-of-35.safetensors",
449
+ "model.vq_adaptor.layers.0.bias": "model-35-of-35.safetensors",
450
+ "model.vq_adaptor.layers.3.weight": "model-35-of-35.safetensors",
451
+ "model.vq_adaptor.layers.3.bias": "model-35-of-35.safetensors",
452
+ "model.vq_adaptor.layers.4.weight": "model-35-of-35.safetensors",
453
+ "model.vq_adaptor.layers.4.bias": "model-35-of-35.safetensors",
454
+ "model.embed_tokens.weight": "model-36-of-36.safetensors",
455
+ "model.norm.weight": "model-36-of-36.safetensors",
456
+ "lm_head.weight": "model-36-of-36.safetensors",
457
+ "mimo_output.weight": "model-36-of-36.safetensors",
458
+ "model.mimo_norm.weight": "model-36-of-36.safetensors"
459
+ }
460
+ }
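The weight_map entries above follow the standard Hugging Face sharded-safetensors index format: each parameter name maps to the shard file that stores it. As a minimal sketch of how such an index can be consumed directly (the shard and index file names are the ones listed in this commit; the local-checkout layout and everything else here is an assumption, not code from the repository):

    # sketch: assemble a state dict from a sharded safetensors checkpoint via its index
    import json
    from safetensors.torch import load_file

    with open("model.safetensors.index.json") as f:
        index = json.load(f)

    state_dict = {}
    for shard in sorted(set(index["weight_map"].values())):
        # each shard contributes only the tensors the index assigns to it
        state_dict.update(load_file(shard))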
modeling_moonshot_kimia.py ADDED
@@ -0,0 +1,917 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The Moonshot AI Team, Qwen Team, and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # The code is based on Qwen2.5-7B, but modified for KimiAudio.
5
+ #
6
+ # Licensing Information:
7
+ # - Code derived from Qwen2.5-7B is licensed under the Apache License, Version 2.0.
8
+ # - Other parts of the code are licensed under the MIT License.
9
+ #
10
+ # Apache License, Version 2.0:
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ #
23
+ # MIT License:
24
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
25
+ # of this software and associated documentation files (the "Software"), to deal
26
+ # in the Software without restriction, including without limitation the rights
27
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
28
+ # copies of the Software, and to permit persons to whom the Software is
29
+ # furnished to do so, subject to the following conditions:
30
+ #
31
+ # The above copyright notice and this permission notice shall be included in all
32
+ # copies or substantial portions of the Software.
33
+ #
34
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
38
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
39
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40
+ # SOFTWARE.
41
+ """PyTorch KimiAudio model."""
42
+
43
+ from typing import List, Optional, Tuple, Union
44
+ import torch
45
+ import torch.utils.checkpoint
46
+ from torch import nn
47
+
48
+ import transformers
49
+ from packaging import version
50
+
51
+ assert version.parse(transformers.__version__) >= version.parse("4.34.1")
52
+
53
+ from transformers.modeling_outputs import (
54
+ BaseModelOutputWithPast,
55
+ CausalLMOutputWithPast,
56
+ )
57
+ from transformers.utils import (
58
+ logging,
59
+ )
60
+ from .configuration_moonshot_kimia import KimiAudioConfig
61
+ import torch.nn.functional as F
62
+ from transformers.models.qwen2.modeling_qwen2 import (
63
+ Qwen2RMSNorm,
64
+ Qwen2MLP,
65
+ Qwen2PreTrainedModel,
66
+ )
67
+ from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
68
+
69
+ if version.parse(transformers.__version__) >= version.parse("4.35.0"):
70
+ from transformers.utils import is_flash_attn_2_available as is_flash_attn_available
71
+ else:
72
+ from transformers.utils import is_flash_attn_available
73
+
74
+ if is_flash_attn_available():
75
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
76
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
77
+ else:
78
+ raise RuntimeError("flash attention must be installed")
79
+
80
+
81
+ logger = logging.get_logger(__name__)
82
+
83
+
84
+ def _get_unpad_data(padding_mask):
85
+ seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
86
+ indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
87
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
88
+ cu_seqlens = F.pad(
89
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
90
+ )
91
+ return (
92
+ indices,
93
+ cu_seqlens,
94
+ max_seqlen_in_batch,
95
+ )
96
+
97
+
98
+ def _upad_input(query_layer, key_layer, value_layer, padding_mask, query_length):
99
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
100
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
101
+ num_heads = query_layer.shape[2]
102
+
103
+ key_layer = index_first_axis(
104
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
105
+ indices_k,
106
+ )
107
+ value_layer = index_first_axis(
108
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
109
+ indices_k,
110
+ )
111
+ if query_length == kv_seq_len:
112
+ query_layer = index_first_axis(
113
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
114
+ )
115
+ cu_seqlens_q = cu_seqlens_k
116
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
117
+ indices_q = indices_k
118
+ elif query_length == 1:
119
+ max_seqlen_in_batch_q = 1
120
+ cu_seqlens_q = torch.arange(
121
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
122
+ ) # There is a memcpy here, that is very bad.
123
+ indices_q = cu_seqlens_q[:-1]
124
+ query_layer = query_layer.squeeze(1)
125
+ else:
126
+ # The -q_len: slice assumes left padding.
127
+ padding_mask = padding_mask[:, -query_length:]
128
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
129
+ query_layer, padding_mask
130
+ )
131
+
132
+ return (
133
+ query_layer,
134
+ key_layer,
135
+ value_layer,
136
+ indices_q,
137
+ (cu_seqlens_q, cu_seqlens_k),
138
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
139
+ )
140
+
141
+
142
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
143
+ def _make_causal_mask(
144
+ input_ids_shape: torch.Size,
145
+ dtype: torch.dtype,
146
+ device: torch.device,
147
+ past_key_values_length: int = 0,
148
+ ):
149
+ """
150
+ Make causal mask used for bi-directional self-attention.
151
+ """
152
+ bsz, tgt_len = input_ids_shape
153
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
154
+ mask_cond = torch.arange(mask.size(-1), device=device)
155
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
156
+ mask = mask.to(dtype)
157
+
158
+ if past_key_values_length > 0:
159
+ mask = torch.cat(
160
+ [
161
+ torch.zeros(
162
+ tgt_len, past_key_values_length, dtype=dtype, device=device
163
+ ),
164
+ mask,
165
+ ],
166
+ dim=-1,
167
+ )
168
+ return mask[None, None, :, :].expand(
169
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
170
+ )
171
+
172
+
173
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
174
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
175
+ """
176
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
177
+ """
178
+ bsz, src_len = mask.size()
179
+ tgt_len = tgt_len if tgt_len is not None else src_len
180
+
181
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
182
+
183
+ inverted_mask = 1.0 - expanded_mask
184
+
185
+ return inverted_mask.masked_fill(
186
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
187
+ )
188
+
189
+
190
+ class RotaryEmbedding(nn.Module):
191
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
192
+ super().__init__()
193
+
194
+ self.dim = dim
195
+ self.max_position_embeddings = max_position_embeddings
196
+ self.base = base
197
+ inv_freq = 1.0 / (
198
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
199
+ )
200
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
201
+
202
+ # Build here to make `torch.jit.trace` work.
203
+ self._set_cos_sin_cache(
204
+ seq_len=max_position_embeddings,
205
+ device=self.inv_freq.device,
206
+ dtype=torch.get_default_dtype(),
207
+ )
208
+
209
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
210
+ self.max_seq_len_cached = seq_len
211
+ t = torch.arange(
212
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
213
+ )
214
+
215
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
216
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
217
+ emb = torch.cat((freqs, freqs), dim=-1)
218
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
219
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
220
+
221
+ def forward(self, x, seq_len=None):
222
+ # x: [bs, num_attention_heads, seq_len, head_size]
223
+ if seq_len > self.max_seq_len_cached:
224
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
225
+
226
+ return (
227
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
228
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
229
+ )
230
+
231
+
232
+ class MoonshotAttention(nn.Module):
233
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
234
+
235
+ def __init__(self, config: KimiAudioConfig):
236
+ super().__init__()
237
+ self.config = config
238
+ self.hidden_size = config.hidden_size
239
+ self.num_heads = config.num_attention_heads
240
+ self.head_dim = self.hidden_size // self.num_heads
241
+ self.num_key_value_heads = config.num_key_value_heads
242
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
243
+ self.max_position_embeddings = config.max_position_embeddings
244
+ self.rope_theta = config.rope_theta
245
+ if (self.head_dim * self.num_heads) != self.hidden_size:
246
+ raise ValueError(
247
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
248
+ f" and `num_heads`: {self.num_heads})."
249
+ )
250
+ self.q_proj = nn.Linear(
251
+ self.hidden_size, self.num_heads * self.head_dim, bias=True
252
+ )
253
+ self.k_proj = nn.Linear(
254
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
255
+ )
256
+ self.v_proj = nn.Linear(
257
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
258
+ )
259
+ self.o_proj = nn.Linear(
260
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
261
+ )
262
+
263
+ self._init_rope()
264
+
265
+ def _init_rope(self):
266
+
267
+ self.rotary_emb = RotaryEmbedding(
268
+ self.head_dim,
269
+ max_position_embeddings=self.max_position_embeddings,
270
+ base=self.rope_theta,
271
+ )
272
+
273
+ def forward(
274
+ self,
275
+ hidden_states: torch.Tensor,
276
+ attention_mask: Optional[torch.Tensor] = None,
277
+ position_ids: Optional[torch.LongTensor] = None,
278
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
279
+ output_attentions: bool = False,
280
+ use_cache: bool = False,
281
+ padding_mask: Optional[torch.LongTensor] = None,
282
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
283
+ # LlamaFlashAttention2 attention does not support output_attentions
284
+
285
+ output_attentions = False
286
+
287
+ bsz, q_len, _ = hidden_states.size()
288
+
289
+ query_states = self.q_proj(hidden_states)
290
+ key_states = self.k_proj(hidden_states)
291
+ value_states = self.v_proj(hidden_states)
292
+
293
+ # Flash attention requires the input to have the shape
294
+ # batch_size x seq_length x num_heads x head_dim
295
+ # therefore we just need to keep the original shape
296
+ query_states = query_states.view(
297
+ bsz, q_len, self.num_heads, self.head_dim
298
+ ).transpose(1, 2)
299
+ key_states = key_states.view(
300
+ bsz, q_len, self.num_key_value_heads, self.head_dim
301
+ ).transpose(1, 2)
302
+ value_states = value_states.view(
303
+ bsz, q_len, self.num_key_value_heads, self.head_dim
304
+ ).transpose(1, 2)
305
+
306
+ kv_seq_len = key_states.shape[-2]
307
+ if past_key_value is not None:
308
+ kv_seq_len += past_key_value[0].shape[-2]
309
+
310
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
311
+ cos = cos[position_ids]
312
+ sin = sin[position_ids]
313
+ query_states, key_states = apply_rotary_pos_emb(
314
+ query_states, key_states, cos, sin, position_ids
315
+ )
316
+
317
+ if past_key_value is not None:
318
+ # reuse k, v, self_attention
319
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
320
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
321
+
322
+ past_key_value = (key_states, value_states) if use_cache else None
323
+
324
+ query_states = query_states.transpose(1, 2)
325
+ key_states = key_states.transpose(1, 2)
326
+ value_states = value_states.transpose(1, 2)
327
+
328
+ # TODO: llama does not have dropout in the config??
329
+ # It is recommended to use dropout with FA according to the docs
330
+ # when training.
331
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
332
+
333
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
334
+ # therefore the input hidden states get silently cast to float32. Hence, we need
335
+ # to cast them back to float16 just to be sure everything works as expected.
336
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
337
+ # in fp32. (LlamaRMSNorm handles it correctly)
338
+ input_dtype = query_states.dtype
339
+ if input_dtype == torch.float32:
340
+ logger.warning_once(
341
+ "The input hidden states seems to be silently casted in float32, this might be related to"
342
+ " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
343
+ " float16."
344
+ )
345
+
346
+ query_states = query_states.to(torch.float16)
347
+ key_states = key_states.to(torch.float16)
348
+ value_states = value_states.to(torch.float16)
349
+
350
+ attn_output = self._flash_attention_forward(
351
+ query_states,
352
+ key_states,
353
+ value_states,
354
+ padding_mask,
355
+ q_len,
356
+ dropout=dropout_rate,
357
+ )
358
+
359
+ if input_dtype == torch.float32:
360
+ attn_output = attn_output.to(torch.float32)
361
+
362
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
363
+ attn_output = self.o_proj(attn_output)
364
+
365
+ if not output_attentions:
366
+ attn_weights = None
367
+
368
+ return attn_output, attn_weights, past_key_value
369
+
370
+ def _flash_attention_forward(
371
+ self,
372
+ query_states,
373
+ key_states,
374
+ value_states,
375
+ padding_mask,
376
+ query_length,
377
+ dropout=0.0,
378
+ softmax_scale=None,
379
+ ):
380
+ """
381
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
382
+ it first unpads the input, then computes the attention scores and pads the final attention scores.
383
+
384
+ Args:
385
+ query_states (`torch.Tensor`):
386
+ Input query states to be passed to Flash Attention API
387
+ key_states (`torch.Tensor`):
388
+ Input key states to be passed to Flash Attention API
389
+ value_states (`torch.Tensor`):
390
+ Input value states to be passed to Flash Attention API
391
+ padding_mask (`torch.Tensor`):
392
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
393
+ position of padding tokens and 1 for the position of non-padding tokens.
394
+ dropout (`float`, *optional*):
395
+ Attention dropout
396
+ softmax_scale (`float`, *optional*):
397
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
398
+ """
399
+ # Contains at least one padding token in the sequence
400
+ if padding_mask is not None:
401
+ batch_size = query_states.shape[0]
402
+ (
403
+ query_states,
404
+ key_states,
405
+ value_states,
406
+ indices_q,
407
+ cu_seq_lens,
408
+ max_seq_lens,
409
+ ) = _upad_input(
410
+ query_states, key_states, value_states, padding_mask, query_length
411
+ )
412
+
413
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
414
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
415
+
416
+ attn_output_unpad = flash_attn_varlen_func(
417
+ query_states,
418
+ key_states,
419
+ value_states,
420
+ cu_seqlens_q=cu_seqlens_q,
421
+ cu_seqlens_k=cu_seqlens_k,
422
+ max_seqlen_q=max_seqlen_in_batch_q,
423
+ max_seqlen_k=max_seqlen_in_batch_k,
424
+ dropout_p=dropout,
425
+ softmax_scale=softmax_scale,
426
+ causal=True,
427
+ )
428
+
429
+ attn_output = pad_input(
430
+ attn_output_unpad, indices_q, batch_size, query_length
431
+ )
432
+ else:
433
+ attn_output = flash_attn_func(
434
+ query_states,
435
+ key_states,
436
+ value_states,
437
+ dropout,
438
+ softmax_scale=softmax_scale,
439
+ causal=True,
440
+ )
441
+
442
+ return attn_output
443
+
444
+
445
+ class MoonshotDecoderLayer(nn.Module):
446
+ def __init__(self, config: KimiAudioConfig):
447
+ super().__init__()
448
+ self.hidden_size = config.hidden_size
449
+ self.config = config
450
+
451
+ logger.warning_once("using normal flash attention")
452
+ self.self_attn = MoonshotAttention(config=config)
453
+
454
+ self.mlp = Qwen2MLP(config)
455
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
456
+ self.post_attention_layernorm = Qwen2RMSNorm(
457
+ config.hidden_size, eps=config.rms_norm_eps
458
+ )
459
+
460
+ def forward(
461
+ self,
462
+ hidden_states: torch.Tensor,
463
+ attention_mask: Optional[torch.Tensor] = None,
464
+ position_ids: Optional[torch.LongTensor] = None,
465
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
466
+ output_attentions: Optional[bool] = False,
467
+ use_cache: Optional[bool] = False,
468
+ padding_mask: Optional[torch.LongTensor] = None,
469
+ ) -> Tuple[
470
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
471
+ ]:
472
+ """
473
+ Args:
474
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
475
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
476
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
477
+ output_attentions (`bool`, *optional*):
478
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
479
+ returned tensors for more detail.
480
+ use_cache (`bool`, *optional*):
481
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
482
+ (see `past_key_values`).
483
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
484
+ """
485
+
486
+ residual = hidden_states
487
+
488
+ hidden_states = self.input_layernorm(hidden_states)
489
+
490
+ # Self Attention
491
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
492
+ hidden_states=hidden_states,
493
+ attention_mask=attention_mask,
494
+ position_ids=position_ids,
495
+ past_key_value=past_key_value,
496
+ output_attentions=output_attentions,
497
+ use_cache=use_cache,
498
+ padding_mask=padding_mask,
499
+ )
500
+ hidden_states = residual + hidden_states
501
+
502
+ # Fully Connected
503
+ residual = hidden_states
504
+ hidden_states = self.post_attention_layernorm(hidden_states)
505
+ hidden_states = self.mlp(hidden_states)
506
+ hidden_states = residual + hidden_states
507
+
508
+ outputs = (hidden_states,)
509
+
510
+ if output_attentions:
511
+ outputs += (self_attn_weights,)
512
+
513
+ if use_cache:
514
+ outputs += (present_key_value,)
515
+
516
+ return outputs
517
+
518
+
519
+ class VQAdaptor(nn.Module):
520
+ def __init__(self, config):
521
+ super().__init__()
522
+ self.layers = nn.Sequential(
523
+ nn.Linear(config.kimia_adaptor_input_dim, config.hidden_size, bias=True),
524
+ nn.SiLU(),
525
+ nn.Dropout(0.0),
526
+ nn.Linear(config.hidden_size, config.hidden_size, bias=True),
527
+ nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=True),
528
+ )
529
+
530
+ def forward(self, x):
531
+ return self.layers(x)
532
+
533
+
534
+ class MoonshotKimiaModel(Qwen2PreTrainedModel):
535
+ """
536
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshotDecoderLayer`]
537
+
538
+ Args:
539
+ config: KimiAudioConfig
540
+ """
541
+
542
+ config_class = KimiAudioConfig
543
+
544
+ def __init__(self, config: KimiAudioConfig):
545
+ super().__init__(config)
546
+ self.padding_idx = config.pad_token_id
547
+ self.vocab_size = config.vocab_size
548
+ self.kimia_mimo_transformer_from_layer_index = (
549
+ config.kimia_mimo_transformer_from_layer_index
550
+ )
551
+
552
+ self.embed_tokens = nn.Embedding(
553
+ config.vocab_size, config.hidden_size, self.padding_idx
554
+ )
555
+ self.layers = nn.ModuleList(
556
+ [MoonshotDecoderLayer(config) for _ in range(config.num_hidden_layers)]
557
+ )
558
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
559
+
560
+ # extra 1B audio transformers
561
+ self.mimo_layers = nn.ModuleList(
562
+ [MoonshotDecoderLayer(config) for _ in range(config.kimia_mimo_layers)]
563
+ )
564
+ self.mimo_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
565
+ self.use_whisper_feature = config.use_whisper_feature
566
+ if self.use_whisper_feature:
567
+ self.vq_adaptor = VQAdaptor(config)
568
+ self.kimia_media_begin = config.kimia_media_begin
569
+ self.kimia_media_end = config.kimia_media_end
570
+
571
+ self.gradient_checkpointing = False
572
+ # Initialize weights and apply final processing
573
+ self.post_init()
574
+
575
+ def get_input_embeddings(self):
576
+ return self.embed_tokens
577
+
578
+ def set_input_embeddings(self, value):
579
+ self.embed_tokens = value
580
+
581
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
582
+ def _prepare_decoder_attention_mask(
583
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
584
+ ):
585
+ # create causal mask
586
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
587
+ combined_attention_mask = None
588
+ if input_shape[-1] > 1:
589
+ combined_attention_mask = _make_causal_mask(
590
+ input_shape,
591
+ inputs_embeds.dtype,
592
+ device=inputs_embeds.device,
593
+ past_key_values_length=past_key_values_length,
594
+ )
595
+
596
+ if attention_mask is not None:
597
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
598
+ expanded_attn_mask = _expand_mask(
599
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
600
+ ).to(inputs_embeds.device)
601
+ combined_attention_mask = (
602
+ expanded_attn_mask
603
+ if combined_attention_mask is None
604
+ else expanded_attn_mask + combined_attention_mask
605
+ )
606
+
607
+ return combined_attention_mask
608
+
609
+ def forward(
610
+ self,
611
+ input_ids: torch.LongTensor = None,
612
+ text_input_ids: torch.LongTensor = None,
613
+ whisper_input_feature: Optional[torch.FloatTensor] = None,
614
+ is_continuous_mask: Optional[torch.Tensor] = None,
615
+ attention_mask: Optional[torch.Tensor] = None,
616
+ position_ids: Optional[torch.LongTensor] = None,
617
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
618
+ inputs_embeds: Optional[torch.FloatTensor] = None,
619
+ use_cache: Optional[bool] = None,
620
+ output_attentions: Optional[bool] = None,
621
+ output_hidden_states: Optional[bool] = None,
622
+ return_dict: Optional[bool] = None,
623
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
624
+ output_attentions = (
625
+ output_attentions
626
+ if output_attentions is not None
627
+ else self.config.output_attentions
628
+ )
629
+ output_hidden_states = (
630
+ output_hidden_states
631
+ if output_hidden_states is not None
632
+ else self.config.output_hidden_states
633
+ )
634
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
635
+
636
+ return_dict = (
637
+ return_dict if return_dict is not None else self.config.use_return_dict
638
+ )
639
+
640
+ # retrieve input_ids and inputs_embeds
641
+ if input_ids is not None and inputs_embeds is not None:
642
+ raise ValueError(
643
+ "You cannot specify both input_ids and inputs_embeds at the same time"
644
+ )
645
+ elif input_ids is not None:
646
+ batch_size, seq_length = input_ids.shape
647
+ elif inputs_embeds is not None:
648
+ batch_size, seq_length, _ = inputs_embeds.shape
649
+ else:
650
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
651
+
652
+ seq_length_with_past = seq_length
653
+ past_key_values_length = 0
654
+
655
+ if past_key_values is not None:
656
+ past_key_values_length = past_key_values[0][0].shape[2]
657
+ seq_length_with_past = seq_length_with_past + past_key_values_length
658
+ if position_ids is None:
659
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
660
+ position_ids = torch.arange(
661
+ past_key_values_length,
662
+ seq_length + past_key_values_length,
663
+ dtype=torch.long,
664
+ device=device,
665
+ )
666
+ position_ids = position_ids.unsqueeze(0)
667
+
668
+ if inputs_embeds is None:
669
+ # shape: batch, seq_len, hidden_size
670
+ input_ids = input_ids.to(torch.cuda.current_device())
671
+ text_input_ids = text_input_ids.to(torch.cuda.current_device())
672
+ audio_emb = self.embed_tokens(input_ids)
673
+ if self.use_whisper_feature and whisper_input_feature is not None:
674
+ if not isinstance(whisper_input_feature, list):
675
+ whisper_input_feature = whisper_input_feature.squeeze(0)
676
+ whisper_input_feature = [whisper_input_feature]
677
+
678
+ media_start_idx = (input_ids == self.kimia_media_begin).nonzero()
679
+ media_end_idx = (input_ids == self.kimia_media_end).nonzero()
680
+ # shape: batch, seq_len, hidden_size
681
+ whisper_input_dim = whisper_input_feature[0].shape[-1]
682
+ whisper_dtype = whisper_input_feature[0].dtype
683
+ expanded_whisper = (
684
+ torch.zeros(audio_emb.shape[1], whisper_input_dim)
685
+ .to(torch.cuda.current_device())
686
+ .to(whisper_dtype)
687
+ )
688
+ for (seg_idx, start_idx), (_, end_idx) in zip(
689
+ media_start_idx, media_end_idx
690
+ ):
691
+ # assert whisper_emb.shape[1] == end_idx - (start_idx + 1)
692
+
693
+ feat_len = end_idx - (start_idx + 1)
694
+ whisper_input_feature_i = whisper_input_feature[seg_idx].squeeze(0)
695
+ assert feat_len == is_continuous_mask[seg_idx].sum()
696
+ expanded_whisper[start_idx + 1 : end_idx, :] = (
697
+ whisper_input_feature_i[:feat_len, :]
698
+ )
699
+
700
+ expanded_whisper = expanded_whisper.unsqueeze(0)
701
+ whisper_emb = self.vq_adaptor(
702
+ expanded_whisper.transpose(0, 1)
703
+ ).transpose(0, 1)
704
+ is_continuous_mask = is_continuous_mask.to(torch.cuda.current_device())
705
+ whisper_emb = whisper_emb.to(torch.cuda.current_device())
706
+ whisper_emb = whisper_emb * is_continuous_mask[:, :, None]
707
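+ # note: at positions flagged by is_continuous_mask, the continuous whisper features are
+ # added to the discrete audio-token embeddings and the sum is rescaled by sqrt(2) below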
+
708
+ encoder_input_addwith_discrete_token = (
709
+ audio_emb + whisper_emb
710
+ ) * torch.sqrt(
711
+ torch.tensor(
712
+ 2.0, dtype=whisper_emb.dtype, device=torch.cuda.current_device()
713
+ )
714
+ )
715
+ audio_emb = (
716
+ audio_emb * (~is_continuous_mask[:, :, None])
717
+ + encoder_input_addwith_discrete_token
718
+ * is_continuous_mask[:, :, None]
719
+ )
720
+ if text_input_ids is not None and text_input_ids.sum() != 0:
721
+ inputs_embeds = audio_emb + self.embed_tokens(text_input_ids)
722
+ else:
723
+ inputs_embeds = audio_emb
724
+ # embed positions
725
+ # TODO kill attention_mask for prefill
726
+ padding_mask = attention_mask
727
+
728
+ hidden_states = inputs_embeds
729
+
730
+ # decoder layers
731
+ all_hidden_states = () if output_hidden_states else None
732
+ all_self_attns = () if output_attentions else None
733
+ next_decoder_cache = () if use_cache else None
734
+ for idx, decoder_layer in enumerate(self.layers):
735
+ if output_hidden_states:
736
+ all_hidden_states += (hidden_states,)
737
+
738
+ past_key_value = (
739
+ past_key_values[idx] if past_key_values is not None else None
740
+ )
741
+ layer_outputs = decoder_layer(
742
+ hidden_states,
743
+ attention_mask=attention_mask,
744
+ position_ids=position_ids,
745
+ past_key_value=past_key_value,
746
+ output_attentions=output_attentions,
747
+ use_cache=use_cache,
748
+ padding_mask=padding_mask,
749
+ )
750
+
751
+ hidden_states = layer_outputs[0]
752
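+ # note: a copy of the hidden states is taken at the configured intermediate layer; it is
+ # later passed through the extra mimo (audio) transformer stack after the main decoder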
+ if idx == self.kimia_mimo_transformer_from_layer_index:
753
+ mimo_hidden_states = hidden_states.clone()
754
+
755
+ if use_cache:
756
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
757
+
758
+ if output_attentions:
759
+ all_self_attns += (layer_outputs[1],)
760
+
761
+ hidden_states = self.norm(hidden_states)
762
+ if output_hidden_states:
763
+ all_hidden_states += (hidden_states,)
764
+
765
+ # apply audio transformer layers
766
+ for idx, decoder_layer in enumerate(self.mimo_layers):
767
+ if output_hidden_states:
768
+ all_hidden_states += (mimo_hidden_states,)
769
+
770
+ past_key_value = (
771
+ past_key_values[idx + len(self.layers)]
772
+ if past_key_values is not None
773
+ else None
774
+ )
775
+ layer_outputs = decoder_layer(
776
+ mimo_hidden_states,
777
+ attention_mask=attention_mask,
778
+ position_ids=position_ids,
779
+ past_key_value=past_key_value,
780
+ output_attentions=output_attentions,
781
+ use_cache=use_cache,
782
+ padding_mask=padding_mask,
783
+ )
784
+
785
+ mimo_hidden_states = layer_outputs[0]
786
+
787
+ if use_cache:
788
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
789
+
790
+ mimo_hidden_states = self.mimo_norm(mimo_hidden_states)
791
+
792
+ # add hidden states from the last decoder layer
793
+ if output_hidden_states:
794
+ all_hidden_states += (mimo_hidden_states,)
795
+
796
+ next_cache = next_decoder_cache if use_cache else None
797
+ if not return_dict:
798
+ return tuple(
799
+ v
800
+ for v in [
801
+ hidden_states,
802
+ mimo_hidden_states,
803
+ next_cache,
804
+ all_hidden_states,
806
+ all_self_attns,
807
+ ]
808
+ if v is not None
809
+ )
810
+ return BaseModelOutputWithPast(
811
+ last_hidden_state=(hidden_states, mimo_hidden_states),
812
+ past_key_values=next_cache,
813
+ hidden_states=all_hidden_states,
814
+ attentions=all_self_attns,
815
+ )
816
+
817
+
818
+ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
819
+ _tied_weights_keys = ["lm_head.weight", "mimo_output.weight"]
820
+ config_class = KimiAudioConfig
821
+
822
+ def __init__(self, config):
823
+ super().__init__(config)
824
+ self.model = MoonshotKimiaModel(config)
825
+ self.vocab_size = config.vocab_size
826
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
827
+ self.mimo_output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
828
+
829
+ # Initialize weights and apply final processing
830
+ self.post_init()
831
+
832
+ def get_input_embeddings(self):
833
+ return self.model.embed_tokens
834
+
835
+ def set_input_embeddings(self, value):
836
+ self.model.embed_tokens = value
837
+
838
+ def get_output_embeddings(self):
839
+ return self.lm_head
840
+
841
+ def set_output_embeddings(self, new_embeddings):
842
+ self.lm_head = new_embeddings
843
+
844
+ def set_decoder(self, decoder):
845
+ self.model = decoder
846
+
847
+ def get_decoder(self):
848
+ return self.model
849
+
850
+ def forward(
851
+ self,
852
+ input_ids: torch.LongTensor = None,
853
+ text_input_ids: torch.LongTensor = None,
854
+ whisper_input_feature: Optional[torch.FloatTensor] = None,
855
+ is_continuous_mask: Optional[torch.Tensor] = None,
856
+ attention_mask: Optional[torch.Tensor] = None,
857
+ position_ids: Optional[torch.LongTensor] = None,
858
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
859
+ inputs_embeds: Optional[torch.FloatTensor] = None,
860
+ labels: Optional[torch.LongTensor] = None,
861
+ use_cache: Optional[bool] = None,
862
+ output_attentions: Optional[bool] = None,
863
+ output_hidden_states: Optional[bool] = None,
864
+ generation_mode: Optional[bool] = None,
865
+ return_dict: Optional[bool] = None,
866
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
867
+
868
+ output_attentions = (
869
+ output_attentions
870
+ if output_attentions is not None
871
+ else self.config.output_attentions
872
+ )
873
+ output_hidden_states = (
874
+ output_hidden_states
875
+ if output_hidden_states is not None
876
+ else self.config.output_hidden_states
877
+ )
878
+ return_dict = (
879
+ return_dict if return_dict is not None else self.config.use_return_dict
880
+ )
881
+
882
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
883
+ outputs = self.model(
884
+ input_ids=input_ids,
885
+ text_input_ids=text_input_ids,
886
+ whisper_input_feature=whisper_input_feature,
887
+ is_continuous_mask=is_continuous_mask,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ past_key_values=past_key_values,
891
+ inputs_embeds=inputs_embeds,
892
+ use_cache=use_cache,
893
+ output_attentions=output_attentions,
894
+ output_hidden_states=output_hidden_states,
895
+ return_dict=return_dict,
896
+ )
897
+ if return_dict:
898
+ hidden_states, mimo_hidden_states = (
899
+ outputs.last_hidden_state[0],
900
+ outputs.last_hidden_state[1],
901
+ )
902
+ else:
903
+ hidden_states, mimo_hidden_states = outputs[0], outputs[1]
904
+
905
+ audio_logits = self.lm_head(hidden_states)
906
+ text_logits = self.mimo_output(mimo_hidden_states)
907
+
908
+ if not return_dict:
909
+ output = (text_logits, audio_logits) + outputs[2:]
910
+ return output
911
+ return CausalLMOutputWithPast(
912
+ loss=None,
913
+ logits=(text_logits, audio_logits),
914
+ past_key_values=outputs.past_key_values,
915
+ hidden_states=outputs.hidden_states,
916
+ attentions=outputs.attentions,
917
+ )
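MoonshotKimiaForCausalLM.forward above returns two logit streams: lm_head projects the main decoder states (audio logits) and mimo_output projects the parallel mimo states (text logits), packed as logits=(text_logits, audio_logits). A minimal sketch of one greedy decoding step over that dual-head output, assuming an already loaded model instance and pre-built id tensors (the helper name and inputs are placeholders, not part of this repository; flash-attn and a CUDA device are required by the code above):

    # sketch: one greedy step over the (text_logits, audio_logits) output
    import torch

    @torch.no_grad()
    def greedy_step(model, audio_input_ids, text_input_ids, past_key_values=None):
        out = model(
            input_ids=audio_input_ids,
            text_input_ids=text_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        text_logits, audio_logits = out.logits  # each: (batch, seq_len, vocab_size)
        next_text = text_logits[:, -1, :].argmax(dim=-1)
        next_audio = audio_logits[:, -1, :].argmax(dim=-1)
        return next_text, next_audio, out.past_key_values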
special_tokens_map.json ADDED
@@ -0,0 +1,425 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_msg_end|>",
4
+ "<|im_user_msg_start|>",
5
+ "<|im_assistant_msg_start|>",
6
+ "<|reserved_token_0|>",
7
+ "<|reserved_token_1|>",
8
+ "<|reserved_token_2|>",
9
+ "<|reserved_token_3|>",
10
+ "[EOT]",
11
+ "<|reserved_token_4|>",
12
+ "<|reserved_token_5|>",
13
+ "<|reserved_token_6|>",
14
+ "<|reserved_token_7|>",
15
+ "<|reserved_token_8|>",
16
+ "<|reserved_token_9|>",
17
+ "<|reserved_token_10|>",
18
+ "<|reserved_token_11|>",
19
+ "<|im_media_begin|>",
20
+ "<|reserved_token_12|>",
21
+ "<|im_media_end|>",
22
+ "<|reserved_token_13|>",
23
+ "<|reserved_token_14|>",
24
+ "<|im_kimia_text_blank|>",
25
+ "<|im_kimia_text_eos|>",
26
+ "<|reserved_token_15|>",
27
+ "<|reserved_token_16|>",
28
+ "<|im_kimia_user_msg_start|>",
29
+ "<|im_kimia_assistant_msg_start|>",
30
+ "<|reserved_token_17|>",
31
+ "<|reserved_token_18|>",
32
+ "<|reserved_token_19|>",
33
+ "<|im_kimia_speech_ct_id|>",
34
+ "<|im_kimia_speech_ctd_id|>",
35
+ "<|reserved_token_20|>",
36
+ "<|reserved_token_21|>",
37
+ "<|reserved_token_22|>",
38
+ "<|reserved_token_23|>",
39
+ "<|reserved_token_24|>",
40
+ "<|reserved_token_25|>",
41
+ "<|reserved_token_26|>",
42
+ "<|reserved_token_27|>",
43
+ "<|reserved_token_28|>",
44
+ "<|reserved_token_29|>",
45
+ "<|reserved_token_30|>",
46
+ "<|reserved_token_31|>",
47
+ "<|reserved_token_32|>",
48
+ "<|reserved_token_33|>",
49
+ "<|reserved_token_34|>",
50
+ "<|reserved_token_35|>",
51
+ "<|reserved_token_36|>",
52
+ "<|reserved_token_37|>",
53
+ "<|reserved_token_38|>",
54
+ "<|reserved_token_39|>",
55
+ "<|reserved_token_40|>",
56
+ "<|reserved_token_41|>",
57
+ "<|reserved_token_42|>",
58
+ "<|reserved_token_43|>",
59
+ "<|reserved_token_44|>",
60
+ "<|reserved_token_45|>",
61
+ "<|reserved_token_46|>",
62
+ "<|reserved_token_47|>",
63
+ "<|reserved_token_48|>",
64
+ "<|reserved_token_49|>",
65
+ "<|reserved_token_50|>",
66
+ "<|reserved_token_51|>",
67
+ "<|reserved_token_52|>",
68
+ "<|reserved_token_53|>",
69
+ "<|reserved_token_54|>",
70
+ "<|reserved_token_55|>",
71
+ "<|reserved_token_56|>",
72
+ "<|reserved_token_57|>",
73
+ "<|reserved_token_58|>",
74
+ "<|reserved_token_59|>",
75
+ "<|reserved_token_60|>",
76
+ "<|reserved_token_61|>",
77
+ "<|reserved_token_62|>",
78
+ "<|reserved_token_63|>",
79
+ "<|reserved_token_64|>",
80
+ "<|reserved_token_65|>",
81
+ "<|reserved_token_66|>",
82
+ "<|reserved_token_67|>",
83
+ "<|reserved_token_68|>",
84
+ "<|reserved_token_69|>",
85
+ "<|reserved_token_70|>",
86
+ "<|reserved_token_71|>",
87
+ "<|reserved_token_72|>",
88
+ "<|reserved_token_73|>",
89
+ "<|reserved_token_74|>",
90
+ "<|reserved_token_75|>",
91
+ "<|reserved_token_76|>",
92
+ "<|reserved_token_77|>",
93
+ "<|reserved_token_78|>",
94
+ "<|reserved_token_79|>",
95
+ "<|reserved_token_80|>",
96
+ "<|reserved_token_81|>",
97
+ "<|reserved_token_82|>",
98
+ "<|reserved_token_83|>",
99
+ "<|reserved_token_84|>",
100
+ "<|reserved_token_85|>",
101
+ "<|reserved_token_86|>",
102
+ "<|reserved_token_87|>",
103
+ "<|reserved_token_88|>",
104
+ "<|reserved_token_89|>",
105
+ "<|reserved_token_90|>",
106
+ "<|reserved_token_91|>",
107
+ "<|reserved_token_92|>",
108
+ "<|reserved_token_93|>",
109
+ "<|reserved_token_94|>",
110
+ "<|reserved_token_95|>",
111
+ "<|reserved_token_96|>",
112
+ "<|reserved_token_97|>",
113
+ "<|reserved_token_98|>",
114
+ "<|reserved_token_99|>",
115
+ "<|reserved_token_100|>",
116
+ "<|reserved_token_101|>",
117
+ "<|reserved_token_102|>",
118
+ "<|reserved_token_103|>",
119
+ "<|reserved_token_104|>",
120
+ "<|reserved_token_105|>",
121
+ "<|reserved_token_106|>",
122
+ "<|reserved_token_107|>",
123
+ "<|reserved_token_108|>",
124
+ "<|reserved_token_109|>",
125
+ "<|reserved_token_110|>",
126
+ "<|reserved_token_111|>",
127
+ "<|reserved_token_112|>",
128
+ "<|reserved_token_113|>",
129
+ "<|reserved_token_114|>",
130
+ "<|reserved_token_115|>",
131
+ "<|reserved_token_116|>",
132
+ "<|reserved_token_117|>",
133
+ "<|reserved_token_118|>",
134
+ "<|reserved_token_119|>",
135
+ "<|reserved_token_120|>",
136
+ "<|reserved_token_121|>",
137
+ "<|reserved_token_122|>",
138
+ "<|reserved_token_123|>",
139
+ "<|reserved_token_124|>",
140
+ "<|reserved_token_125|>",
141
+ "<|reserved_token_126|>",
142
+ "<|reserved_token_127|>",
143
+ "<|reserved_token_128|>",
144
+ "<|reserved_token_129|>",
145
+ "<|reserved_token_130|>",
146
+ "<|reserved_token_131|>",
147
+ "<|reserved_token_132|>",
148
+ "<|reserved_token_133|>",
149
+ "<|reserved_token_134|>",
150
+ "<|reserved_token_135|>",
151
+ "<|reserved_token_136|>",
152
+ "<|reserved_token_137|>",
153
+ "<|reserved_token_138|>",
154
+ "<|reserved_token_139|>",
155
+ "<|reserved_token_140|>",
156
+ "<|reserved_token_141|>",
157
+ "<|reserved_token_142|>",
158
+ "<|reserved_token_143|>",
159
+ "<|reserved_token_144|>",
160
+ "<|reserved_token_145|>",
161
+ "<|reserved_token_146|>",
162
+ "<|reserved_token_147|>",
163
+ "<|reserved_token_148|>",
164
+ "<|reserved_token_149|>",
165
+ "<|reserved_token_150|>",
166
+ "<|reserved_token_151|>",
167
+ "<|reserved_token_152|>",
168
+ "<|reserved_token_153|>",
169
+ "<|reserved_token_154|>",
170
+ "<|reserved_token_155|>",
171
+ "<|reserved_token_156|>",
172
+ "<|reserved_token_157|>",
173
+ "<|reserved_token_158|>",
174
+ "<|reserved_token_159|>",
175
+ "<|reserved_token_160|>",
176
+ "<|reserved_token_161|>",
177
+ "<|reserved_token_162|>",
178
+ "<|reserved_token_163|>",
179
+ "<|reserved_token_164|>",
180
+ "<|reserved_token_165|>",
181
+ "<|reserved_token_166|>",
182
+ "<|reserved_token_167|>",
183
+ "<|reserved_token_168|>",
184
+ "<|reserved_token_169|>",
185
+ "<|reserved_token_170|>",
186
+ "<|reserved_token_171|>",
187
+ "<|reserved_token_172|>",
188
+ "<|reserved_token_173|>",
189
+ "<|reserved_token_174|>",
190
+ "<|reserved_token_175|>",
191
+ "<|reserved_token_176|>",
192
+ "<|reserved_token_177|>",
193
+ "<|reserved_token_178|>",
194
+ "<|reserved_token_179|>",
195
+ "<|reserved_token_180|>",
196
+ "<|reserved_token_181|>",
197
+ "<|reserved_token_182|>",
198
+ "<|reserved_token_183|>",
199
+ "<|reserved_token_184|>",
200
+ "<|reserved_token_185|>",
201
+ "<|reserved_token_186|>",
202
+ "<|reserved_token_187|>",
203
+ "<|reserved_token_188|>",
204
+ "<|reserved_token_189|>",
205
+ "<|reserved_token_190|>",
206
+ "<|reserved_token_191|>",
207
+ "<|reserved_token_192|>",
208
+ "<|reserved_token_193|>",
209
+ "<|reserved_token_194|>",
210
+ "<|reserved_token_195|>",
211
+ "<|reserved_token_196|>",
212
+ "<|reserved_token_197|>",
213
+ "<|reserved_token_198|>",
214
+ "<|reserved_token_199|>",
215
+ "<|reserved_token_200|>",
216
+ "<|reserved_token_201|>",
217
+ "<|reserved_token_202|>",
218
+ "<|reserved_token_203|>",
219
+ "<|reserved_token_204|>",
220
+ "<|reserved_token_205|>",
221
+ "<|reserved_token_206|>",
222
+ "<|reserved_token_207|>",
223
+ "<|reserved_token_208|>",
224
+ "<|reserved_token_209|>",
225
+ "<|reserved_token_210|>",
226
+ "<|reserved_token_211|>",
227
+ "<|reserved_token_212|>",
228
+ "<|reserved_token_213|>",
229
+ "<|reserved_token_214|>",
230
+ "<|reserved_token_215|>",
231
+ "<|reserved_token_216|>",
232
+ "<|reserved_token_217|>",
233
+ "<|reserved_token_218|>",
234
+ "<|reserved_token_219|>",
235
+ "<|reserved_token_220|>",
236
+ "<|reserved_token_221|>",
237
+ "<|reserved_token_222|>",
238
+ "<|reserved_token_223|>",
239
+ "<|reserved_token_224|>",
240
+ "<|reserved_token_225|>",
241
+ "<|reserved_token_226|>",
242
+ "<|reserved_token_227|>",
243
+ "<|reserved_token_228|>",
244
+ "<|reserved_token_229|>",
245
+ "<|reserved_token_230|>",
246
+ "<|reserved_token_231|>",
247
+ "<|reserved_token_232|>",
248
+ "<|reserved_token_233|>",
249
+ "<|reserved_token_234|>",
250
+ "<|reserved_token_235|>",
251
+ "<|reserved_token_236|>",
252
+ "<|reserved_token_237|>",
253
+ "<|reserved_token_238|>",
254
+ "<|reserved_token_239|>",
255
+ "<|reserved_token_240|>",
256
+ "<|reserved_token_241|>",
257
+ "<|reserved_token_242|>",
258
+ "<|reserved_token_243|>",
259
+ "<|reserved_token_244|>",
260
+ "<|reserved_token_245|>",
261
+ "<|reserved_token_246|>",
262
+ "<|reserved_token_247|>",
263
+ "<|reserved_token_248|>",
264
+ "<|reserved_token_249|>",
265
+ "<|reserved_token_250|>",
266
+ "<|reserved_token_251|>",
267
+ "<|reserved_token_252|>",
268
+ "<|reserved_token_253|>",
269
+ "<|reserved_token_254|>",
270
+ "<|reserved_token_255|>",
271
+ "<|reserved_token_256|>",
272
+ "<|reserved_token_257|>",
273
+ "<|reserved_token_258|>",
274
+ "<|reserved_token_259|>",
275
+ "<|reserved_token_260|>",
276
+ "<|reserved_token_261|>",
277
+ "<|reserved_token_262|>",
278
+ "<|reserved_token_263|>",
279
+ "<|reserved_token_264|>",
280
+ "<|reserved_token_265|>",
281
+ "<|reserved_token_266|>",
282
+ "<|reserved_token_267|>",
283
+ "<|reserved_token_268|>",
284
+ "<|reserved_token_269|>",
285
+ "<|reserved_token_270|>",
286
+ "<|reserved_token_271|>",
287
+ "<|reserved_token_272|>",
288
+ "<|reserved_token_273|>",
289
+ "<|reserved_token_274|>",
290
+ "<|reserved_token_275|>",
291
+ "<|reserved_token_276|>",
292
+ "<|reserved_token_277|>",
293
+ "<|reserved_token_278|>",
294
+ "<|reserved_token_279|>",
295
+ "<|reserved_token_280|>",
296
+ "<|reserved_token_281|>",
297
+ "<|reserved_token_282|>",
298
+ "<|reserved_token_283|>",
299
+ "<|reserved_token_284|>",
300
+ "<|reserved_token_285|>",
301
+ "<|reserved_token_286|>",
302
+ "<|reserved_token_287|>",
303
+ "<|reserved_token_288|>",
304
+ "<|reserved_token_289|>",
305
+ "<|reserved_token_290|>",
306
+ "<|reserved_token_291|>",
307
+ "<|reserved_token_292|>",
308
+ "<|reserved_token_293|>",
309
+ "<|reserved_token_294|>",
310
+ "<|reserved_token_295|>",
311
+ "<|reserved_token_296|>",
312
+ "<|reserved_token_297|>",
313
+ "<|reserved_token_298|>",
314
+ "<|reserved_token_299|>",
315
+ "<|reserved_token_300|>",
316
+ "<|reserved_token_301|>",
317
+ "<|reserved_token_302|>",
318
+ "<|reserved_token_303|>",
319
+ "<|reserved_token_304|>",
320
+ "<|reserved_token_305|>",
321
+ "<|reserved_token_306|>",
322
+ "<|reserved_token_307|>",
323
+ "<|reserved_token_308|>",
324
+ "<|reserved_token_309|>",
325
+ "<|reserved_token_310|>",
326
+ "<|reserved_token_311|>",
327
+ "<|reserved_token_312|>",
328
+ "<|reserved_token_313|>",
329
+ "<|reserved_token_314|>",
330
+ "<|reserved_token_315|>",
331
+ "<|reserved_token_316|>",
332
+ "<|reserved_token_317|>",
333
+ "<|reserved_token_318|>",
334
+ "<|reserved_token_319|>",
335
+ "<|reserved_token_320|>",
336
+ "<|reserved_token_321|>",
337
+ "<|reserved_token_322|>",
338
+ "<|reserved_token_323|>",
339
+ "<|reserved_token_324|>",
340
+ "<|reserved_token_325|>",
341
+ "<|reserved_token_326|>",
342
+ "<|reserved_token_327|>",
343
+ "<|reserved_token_328|>",
344
+ "<|reserved_token_329|>",
345
+ "<|reserved_token_330|>",
346
+ "<|reserved_token_331|>",
347
+ "<|reserved_token_332|>",
348
+ "<|reserved_token_333|>",
349
+ "<|reserved_token_334|>",
350
+ "<|reserved_token_335|>",
351
+ "<|reserved_token_336|>",
352
+ "<|reserved_token_337|>",
353
+ "<|reserved_token_338|>",
354
+ "<|reserved_token_339|>",
355
+ "<|reserved_token_340|>",
356
+ "<|reserved_token_341|>",
357
+ "<|reserved_token_342|>",
358
+ "<|reserved_token_343|>",
359
+ "<|reserved_token_344|>",
360
+ "<|reserved_token_345|>",
361
+ "<|reserved_token_346|>",
362
+ "<|reserved_token_347|>",
363
+ "<|reserved_token_348|>",
364
+ "<|reserved_token_349|>",
365
+ "<|reserved_token_350|>",
366
+ "<|reserved_token_351|>",
367
+ "<|reserved_token_352|>",
368
+ "<|reserved_token_353|>",
369
+ "<|reserved_token_354|>",
370
+ "<|reserved_token_355|>",
371
+ "<|reserved_token_356|>",
372
+ "<|reserved_token_357|>",
373
+ "<|reserved_token_358|>",
374
+ "<|reserved_token_359|>",
375
+ "<|reserved_token_360|>",
376
+ "<|reserved_token_361|>",
377
+ "<|reserved_token_362|>",
378
+ "<|reserved_token_363|>",
379
+ "<|reserved_token_364|>",
380
+ "<|reserved_token_365|>",
381
+ "<|reserved_token_366|>",
382
+ "<|reserved_token_367|>",
383
+ "<|reserved_token_368|>",
384
+ "<|reserved_token_369|>",
385
+ "<|reserved_token_370|>",
386
+ "<|reserved_token_371|>",
387
+ "<|reserved_token_372|>",
388
+ "<|reserved_token_373|>",
389
+ "<|reserved_token_374|>",
390
+ "<|reserved_token_375|>",
391
+ "<|reserved_token_376|>",
392
+ "<|reserved_token_377|>",
393
+ "<|reserved_token_378|>",
394
+ "<|reserved_token_379|>",
395
+ "<|reserved_token_380|>",
396
+ "<|reserved_token_381|>",
397
+ "<|reserved_token_382|>",
398
+ "<|reserved_token_383|>",
399
+ "<|reserved_token_384|>",
400
+ "<|reserved_token_385|>",
401
+ "<|reserved_token_386|>",
402
+ "<|reserved_token_387|>",
403
+ "<|reserved_token_388|>",
404
+ "<|reserved_token_389|>",
405
+ "<|reserved_token_390|>",
406
+ "<|reserved_token_391|>",
407
+ "<|reserved_token_392|>",
408
+ "<|reserved_token_393|>",
409
+ "<|reserved_token_394|>",
410
+ "<|reserved_token_395|>",
411
+ "<|reserved_token_396|>",
412
+ "<|reserved_token_397|>",
413
+ "<|reserved_token_398|>",
414
+ "<|reserved_token_399|>",
415
+ "<|reserved_token_400|>",
416
+ "<|reserved_token_401|>",
417
+ "<|reserved_token_402|>",
418
+ "<|reserved_token_403|>",
419
+ "<|reserved_token_404|>"
420
+ ],
421
+ "bos_token": "[BOS]",
422
+ "eos_token": "[EOS]",
423
+ "pad_token": "<|reserved_token_406|>",
424
+ "unk_token": "<|reserved_token_405|>"
425
+ }
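The <|im_media_begin|> and <|im_media_end|> markers above are the special tokens whose ids the modeling code reads as kimia_media_begin and kimia_media_end when it splices continuous whisper features into the embedded sequence. A small sketch of recovering those spans from a token-id tensor, mirroring the nonzero() lookups in MoonshotKimiaModel.forward (the function name and ids are illustrative assumptions):

    # sketch: (start, end) index pairs delimited by the media begin/end token ids
    import torch

    def media_spans(input_ids: torch.Tensor, media_begin_id: int, media_end_id: int):
        starts = (input_ids == media_begin_id).nonzero(as_tuple=False)
        ends = (input_ids == media_end_id).nonzero(as_tuple=False)
        # positions strictly between a begin/end pair carry the continuous features
        return [(int(s[-1]) + 1, int(e[-1])) for s, e in zip(starts, ends)]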
tiktoken.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2b1b8dfb5cc5f024bafc373121c6aba3f66f9a5a0269e243470a1de16a33186
3
+ size 2561218
tokenization_kimia.py ADDED
@@ -0,0 +1,335 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+
3
+ """Megatron tokenizers."""
4
+ from transformers.tokenization_utils import PreTrainedTokenizer
5
+ from typing import Union
6
+ from typing import (
7
+ AbstractSet,
8
+ cast,
9
+ Collection,
10
+ Dict,
11
+ Iterator,
12
+ List,
13
+ Literal,
14
+ Sequence,
15
+ Union,
16
+ Optional,
17
+ )
18
+ from tiktoken.load import load_tiktoken_bpe
19
+ import tiktoken
20
+ from pathlib import Path
21
+ import os
22
+ import logging
23
+ from tokenizers import AddedToken
24
+
25
+ logger = logging.getLogger(__name__)
26
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
27
+
28
+
29
+ class TikTokenTokenizer(PreTrainedTokenizer):
30
+ """
31
+ Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
32
+ """
33
+
34
+ special_tokens: Dict[str, int]
35
+
36
+ num_reserved_special_tokens = 293 + 128
37
+
38
+ pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
39
+
40
+ vocab_files_names = VOCAB_FILES_NAMES
41
+
42
+ def __init__(
43
+ self,
44
+ vocab_file,
45
+ bos_token: Union[str, AddedToken] = "[BOS]",
46
+ eos_token: Union[str, AddedToken] = "[EOS]",
47
+ unk_token: Union[str, AddedToken] = "[UNK]",
48
+ pad_token: Union[str, AddedToken] = "[PAD]",
49
+ additional_special_tokens: Optional[List[str]] = None,
50
+ added_tokens_decoder: Optional[dict] = None,
51
+ **kwargs,
52
+ ):
53
+ """
54
+ Initializes the Tokenizer with a Tiktoken model.
55
+
56
+ Args:
57
+ vocab_file (str): The path to the Tiktoken model file.
58
+ """
59
+ assert os.path.isfile(vocab_file), vocab_file
60
+
61
+ mergeable_ranks = load_tiktoken_bpe(vocab_file)
62
+ num_base_tokens = len(mergeable_ranks)
63
+
64
+ used_special_tokens = [
65
+ "[BOS]",
66
+ "[EOS]",
67
+ "<|im_msg_end|>", # 0
68
+ "<|im_user_msg_start|>", # 1
69
+ "<|im_assistant_msg_start|>", # 2
70
+ "<|reserved_token_0|>", # 3
71
+ "<|reserved_token_1|>",
72
+ "<|reserved_token_2|>",
73
+ "<|reserved_token_3|>", # 4
74
+ "[EOT]",
75
+ "<|reserved_token_4|>", # 5
76
+ "<|reserved_token_5|>", # 6
77
+ "<|reserved_token_6|>", # 7
78
+ "<|reserved_token_7|>", # 8
79
+ "<|reserved_token_8|>", # 9
80
+ "<|reserved_token_9|>", # 10
81
+ "<|reserved_token_10|>", # 11
82
+ "<|reserved_token_11|>", # 12
83
+ "<|im_media_begin|>", # 13
84
+ "<|reserved_token_12|>", # 14
85
+ "<|im_media_end|>", # 15
86
+ "<|reserved_token_13|>", # 16
87
+ "<|reserved_token_14|>", # 17
88
+ "<|im_kimia_text_blank|>", # 18
89
+ "<|im_kimia_text_eos|>", # 19
90
+ "<|reserved_token_15|>", # 20
91
+ "<|reserved_token_16|>", # 21
92
+ "<|im_kimia_user_msg_start|>", # 22
93
+ "<|im_kimia_assistant_msg_start|>", # 23
94
+ "<|reserved_token_17|>", # 24
95
+ "<|reserved_token_18|>", # 25
96
+ "<|reserved_token_19|>", # 26
97
+ "<|im_kimia_speech_ct_id|>", # 27
98
+ "<|im_kimia_speech_ctd_id|>", # 28
99
+ ]
100
+ autoset_special_tokens = [
101
+ f"<|reserved_token_{i}|>"
102
+ for i in range(
103
+ 20, self.num_reserved_special_tokens - len(used_special_tokens) + 20
104
+ )
105
+ ]
106
+ special_tokens = used_special_tokens + autoset_special_tokens
107
+ self.special_tokens = {
108
+ token: num_base_tokens + i for i, token in enumerate(special_tokens)
109
+ }
110
+ self.model = tiktoken.Encoding(
111
+ name=Path(vocab_file).name,
112
+ pat_str=self.pat_str,
113
+ mergeable_ranks=mergeable_ranks,
114
+ special_tokens=self.special_tokens,
115
+ )
116
+ logger.info(f"Reloaded tiktoken model from {vocab_file}")
117
+
118
+ self.n_words: int = self.model.n_vocab
119
+ # BOS / EOS token IDs
120
+ self.bos_token = "[BOS]"
121
+ self.bos_id: int = self.special_tokens["[BOS]"]
122
+ self.eos_token = "[EOS]"
123
+ self.eos_id: int = self.special_tokens["[EOS]"]
124
+
125
+ # use the last special token as the pad token; the second-to-last is the unk token
126
+ self.pad_token: str = special_tokens[-1]
127
+ self.pad_id: int = self.special_tokens[self.pad_token]
128
+
129
+ self.unk_token: str = special_tokens[-2]
130
+ self.unk_id: int = self.special_tokens[self.pad_token]
131
+
132
+ self.stop_tokens = {
133
+ self.special_tokens["[EOS]"],
134
+ self.special_tokens["[EOT]"],
135
+ }
136
+
137
+ logger.info(
138
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
139
+ )
140
+
141
+ def encode(
142
+ self,
143
+ s: str,
144
+ *,
145
+ bos: bool,
146
+ eos: bool,
147
+ allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
148
+ disallowed_special: Union[Literal["all"], Collection[str]] = (),
149
+ ) -> List[int]:
150
+ """
151
+ Encodes a string into a list of token IDs.
152
+
153
+ Args:
154
+ s (str): The input string to be encoded.
155
+ bos (bool): Whether to prepend the beginning-of-sequence token.
156
+ eos (bool): Whether to append the end-of-sequence token.
157
+ allowed_special ("all" | set[str]): special tokens allowed to appear in the string
158
+ disallowed_special ("all" | Collection[str]): special tokens that raise an error if present in the string
159
+
160
+ Returns:
161
+ list[int]: A list of token IDs.
162
+
163
+ By default, setting disallowed_special=() encodes a string by ignoring
164
+ special tokens. Specifically:
165
+ - Setting `disallowed_special` to () will cause all text corresponding
166
+ to special tokens to be encoded as natural text (instead of raising
167
+ an error).
168
+ - Setting `allowed_special` to "all" will cause all text corresponding
169
+ to special tokens to be encoded as special tokens.
170
+ """
171
+ assert type(s) is str
172
+
173
+ # The tiktoken tokenizer can handle <=400k chars without
174
+ # pyo3_runtime.PanicException.
175
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
176
+
177
+ # https://github.com/openai/tiktoken/issues/195
178
+ # Here we iterate over subsequences and split if we exceed the limit
179
+ # of max consecutive non-whitespace or whitespace characters.
180
+ MAX_NO_WHITESPACES_CHARS = 25_000
181
+
182
+ substrs = (
183
+ substr
184
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
185
+ for substr in self._split_whitespaces_or_nonwhitespaces(
186
+ s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
187
+ )
188
+ )
189
+ t: List[int] = []
190
+ for substr in substrs:
191
+ t.extend(
192
+ self.model.encode(
193
+ substr,
194
+ allowed_special=allowed_special,
195
+ disallowed_special=disallowed_special,
196
+ )
197
+ )
198
+ if bos:
199
+ t.insert(0, self.bos_id)
200
+ if eos:
201
+ t.append(self.eos_id)
202
+ return t
203
+
204
+ def decode(self, t: Sequence[int]) -> str:
205
+ """
206
+ Decodes a list of token IDs into a string.
207
+
208
+ Args:
209
+ t (List[int]): The list of token IDs to be decoded.
210
+
211
+ Returns:
212
+ str: The decoded string.
213
+ """
214
+ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
215
+ return self.model.decode(cast(List[int], t))
216
+
217
+ @staticmethod
218
+ def _split_whitespaces_or_nonwhitespaces(
219
+ s: str, max_consecutive_slice_len: int
220
+ ) -> Iterator[str]:
221
+ """
222
+ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
223
+ consecutive whitespaces or consecutive non-whitespaces.
224
+ """
225
+ current_slice_len = 0
226
+ current_slice_is_space = s[0].isspace() if len(s) > 0 else False
227
+ slice_start = 0
228
+
229
+ for i in range(len(s)):
230
+ is_now_space = s[i].isspace()
231
+
232
+ if current_slice_is_space ^ is_now_space:
233
+ current_slice_len = 1
234
+ current_slice_is_space = is_now_space
235
+ else:
236
+ current_slice_len += 1
237
+ if current_slice_len > max_consecutive_slice_len:
238
+ yield s[slice_start:i]
239
+ slice_start = i
240
+ current_slice_len = 1
241
+ yield s[slice_start:]
242
+
243
+ """ ----- Below are the abstract methods required by megatron ----- """
244
+
245
+ @property
246
+ def vocab_size(self):
247
+ return self.n_words
248
+
249
+ @property
250
+ def vocab(self):
251
+ if hasattr(self, "str_vocab"):
252
+ return self.str_vocab
253
+ self.str_vocab = {}
254
+
255
+ # convert mergeable_ranks from bytes to string
256
+ utf8_num, unicode_num = 0, 0
257
+ for byte_key, index in self.model._mergeable_ranks.items():
258
+ try:
259
+ str_key = byte_key.decode("utf-8")
260
+ utf8_num += 1
261
+ except UnicodeDecodeError:
262
+ # use backslashreplace so we still end up with vocab_size distinct string keys
263
+ # see: https://docs.python.org/3/howto/unicode.html
264
+ # this vocab is only used for offline processing, so this is fine
265
+ str_key = byte_key.decode("utf-8", "backslashreplace") + "_unicode_"
266
+ unicode_num += 1
267
+
268
+ self.str_vocab[str_key] = index
269
+ logger.info(f"num utf8: {utf8_num}, num unicode: {unicode_num}")
270
+
271
+ # add all special tokens to the dictionary
272
+ self.str_vocab.update(self.model._special_tokens)
273
+
274
+ assert len(self.str_vocab) == self.vocab_size
275
+ return self.str_vocab
276
+
277
+ @property
278
+ def inv_vocab(self):
279
+ return {v: k for k, v in self.vocab.items()}
280
+
281
+ def tokenize(self, text, eos=True):
282
+ # BOS: always add bos token
283
+ # EOS:
284
+ # Most cases should be true when we are tokenizing a full sequence
285
+ # Only setting to false when we are running a inference
286
+ return self.encode(text, bos=True, eos=eos)
287
+
288
+ def detokenize(self, tokens):
289
+ # convert a tensor to a list if needed
290
+ if not isinstance(tokens, list):
291
+ tokens = tokens.tolist()
292
+ return self.decode(tokens)
293
+
294
+ @property
295
+ def eod(self):
296
+ return self.eos_id
297
+
298
+ def bod(self):
299
+ return self.bos_id
300
+
301
+ @property
302
+ def msk_start_id(self):
303
+ return self.msk_start
304
+
305
+ @property
306
+ def msk_end_id(self):
307
+ return self.msk_end
308
+
309
+ def _get_index_2_bytes(self):
310
+ if hasattr(self, "index_2_bytes"):
311
+ return self.index_2_bytes
312
+
313
+ # use array rather than dict for faster access
314
+ self.index_2_bytes = [0] * self.model.n_vocab
315
+ for byte_key, index in self.model._mergeable_ranks.items():
316
+ self.index_2_bytes[index] = len(byte_key)
317
+
318
+ for _, index in self.model._special_tokens.items():
319
+ # special tokens are not entries in the BPE byte table,
320
+ # so count each of them as a single byte here
321
+ self.index_2_bytes[index] = 1
322
+
323
+ return self.index_2_bytes
324
+
325
+ def get_array_bytes(self, array):
326
+ index_2_bytes = self._get_index_2_bytes()
327
+ return sum(index_2_bytes[i] for i in array)
328
+
329
+ @property
330
+ def eos_token_id(self):
331
+ return self.eos_id
332
+
333
+ @property
334
+ def pad_token_id(self):
335
+ return self.pad_id
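tokenization_kimia.py wraps a tiktoken Encoding behind the PreTrainedTokenizer interface: `encode` takes keyword-only `bos`/`eos` flags, `decode` maps IDs straight back to text, and `tokenize`/`detokenize` are thin Megatron-style wrappers. A minimal usage sketch, assuming this file and `tiktoken.model` sit in the working directory:

    # sketch: round-trip text through TikTokenTokenizer (local paths are assumptions)
    from tokenization_kimia import TikTokenTokenizer

    tok = TikTokenTokenizer(vocab_file="./tiktoken.model")

    ids = tok.encode("Hello, Kimi-Audio!", bos=True, eos=True)
    print(ids[0] == tok.bos_id, ids[-1] == tok.eos_id)  # True True
    print(tok.decode(ids))  # "[BOS]Hello, Kimi-Audio![EOS]" -- special IDs decode to their literal strings

    # tokenize()/detokenize() wrap encode()/decode() for Megatron-style callers
    round_trip = tok.detokenize(tok.tokenize("Hello again", eos=False))
    print(round_trip)  # "[BOS]Hello again" -- bos is always added, eos only when requested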
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff