Duplicate from moonshotai/Kimi-Audio-7B-Instruct
Browse filesCo-authored-by: moyanwang <[email protected]>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +0 -0
- README.md +161 -0
- audio_detokenizer/config.yaml +123 -0
- audio_detokenizer/model.pt +3 -0
- config.json +44 -0
- configuration_moonshot_kimia.py +66 -0
- generation_config.json +3 -0
- model-1-of-35.safetensors +3 -0
- model-10-of-35.safetensors +3 -0
- model-11-of-35.safetensors +3 -0
- model-12-of-35.safetensors +3 -0
- model-13-of-35.safetensors +3 -0
- model-14-of-35.safetensors +3 -0
- model-15-of-35.safetensors +3 -0
- model-16-of-35.safetensors +3 -0
- model-17-of-35.safetensors +3 -0
- model-18-of-35.safetensors +3 -0
- model-19-of-35.safetensors +3 -0
- model-2-of-35.safetensors +3 -0
- model-20-of-35.safetensors +3 -0
- model-21-of-35.safetensors +3 -0
- model-22-of-35.safetensors +3 -0
- model-23-of-35.safetensors +3 -0
- model-24-of-35.safetensors +3 -0
- model-25-of-35.safetensors +3 -0
- model-26-of-35.safetensors +3 -0
- model-27-of-35.safetensors +3 -0
- model-28-of-35.safetensors +3 -0
- model-29-of-35.safetensors +3 -0
- model-3-of-35.safetensors +3 -0
- model-30-of-35.safetensors +3 -0
- model-31-of-35.safetensors +3 -0
- model-32-of-35.safetensors +3 -0
- model-33-of-35.safetensors +3 -0
- model-34-of-35.safetensors +3 -0
- model-35-of-35.safetensors +3 -0
- model-36-of-36.safetensors +3 -0
- model-4-of-35.safetensors +3 -0
- model-5-of-35.safetensors +3 -0
- model-6-of-35.safetensors +3 -0
- model-7-of-35.safetensors +3 -0
- model-8-of-35.safetensors +3 -0
- model-9-of-35.safetensors +3 -0
- model.safetensors.index.json +460 -0
- modeling_moonshot_kimia.py +917 -0
- special_tokens_map.json +425 -0
- tiktoken.model +3 -0
- tokenization_kimia.py +335 -0
- tokenizer_config.json +0 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
File without changes
|
README.md
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
- zh
|
6 |
+
tags:
|
7 |
+
- audio
|
8 |
+
- audio-language-model
|
9 |
+
- speech-recognition
|
10 |
+
- audio-understanding
|
11 |
+
- text-to-speech
|
12 |
+
- audio-generation
|
13 |
+
- chat
|
14 |
+
- kimi-audio
|
15 |
+
---
|
16 |
+
|
17 |
+
# Kimi-Audio
|
18 |
+
|
19 |
+
<p align="center">
|
20 |
+
<img src="https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_logo.png" width="400"/> <!-- TODO: Replace with actual raw image URL from your repo -->
|
21 |
+
<p>
|
22 |
+
|
23 |
+
<p align="center">
|
24 |
+
<a href="https://huggingface.co/moonshotai/Kimi-Audio-7B">🤗 Kimi-Audio-7B</a> | <a href="https://huggingface.co/moonshotai/Kimi-Audio-7B-Instruct">🤗 Kimi-Audio-7B-Instruct </a> | <a href="https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_report.pdf">📑 Paper</a>
|
25 |
+
</p>
|
26 |
+
|
27 |
+
## Introduction
|
28 |
+
|
29 |
+
We present Kimi-Audio, an open-source audio foundation model excelling in **audio understanding, generation, and conversation**. This repository hosts the model checkpoints for Kimi-Audio-7B-Instruct.
|
30 |
+
|
31 |
+
Kimi-Audio is designed as a universal audio foundation model capable of handling a wide variety of audio processing tasks within a single unified framework. Key features include:
|
32 |
+
|
33 |
+
* **Universal Capabilities:** Handles diverse tasks like speech recognition (ASR), audio question answering (AQA), audio captioning (AAC), speech emotion recognition (SER), sound event/scene classification (SEC/ASC) and end-to-end speech conversation.
|
34 |
+
* **State-of-the-Art Performance:** Achieves SOTA results on numerous audio benchmarks (see our [Technical Report](https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_report.pdf)).
|
35 |
+
* **Large-Scale Pre-training:** Pre-trained on over 13 million hours of diverse audio data (speech, music, sounds) and text data.
|
36 |
+
* **Novel Architecture:** Employs a hybrid audio input (continuous acoustic + discrete semantic tokens) and an LLM core with parallel heads for text and audio token generation.
|
37 |
+
* **Efficient Inference:** Features a chunk-wise streaming detokenizer based on flow matching for low-latency audio generation.
|
38 |
+
|
39 |
+
For more details, please refer to our [GitHub Repository](https://github.com/MoonshotAI/Kimi-Audio) and [Technical Report](https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/master/assets/kimia_report.pdf).
|
40 |
+
|
41 |
+
## Requirements
|
42 |
+
|
43 |
+
We recommend that you build a Docker image to run the inference. After cloning the inference code, you can construct the image using the `docker build` command.
|
44 |
+
```bash
|
45 |
+
git clone https://github.com/MoonshotAI/Kimi-Audio
|
46 |
+
git submodule update --init
|
47 |
+
cd Kimi-Audio
|
48 |
+
docker build -t kimi-audio:v0.1 .
|
49 |
+
```
|
50 |
+
Alternatively, You can also use our pre-built image:
|
51 |
+
```bash
|
52 |
+
docker pull moonshotai/kimi-audio:v0.1
|
53 |
+
```
|
54 |
+
|
55 |
+
Or, you can install requirments by:
|
56 |
+
```bash
|
57 |
+
pip install -r requirements.txt
|
58 |
+
```
|
59 |
+
|
60 |
+
You may refer to the Dockerfile in case of any environment issues.
|
61 |
+
|
62 |
+
## Quickstart
|
63 |
+
|
64 |
+
This example demonstrates basic usage for generating text from audio (ASR) and generating both text and speech in a conversational turn using the `Kimi-Audio-7B-Instruct` model.
|
65 |
+
|
66 |
+
```python
|
67 |
+
import soundfile as sf
|
68 |
+
# Assuming the KimiAudio class is available after installation
|
69 |
+
from kimia_infer.api.kimia import KimiAudio
|
70 |
+
import torch # Ensure torch is imported if needed for device placement
|
71 |
+
|
72 |
+
# --- 1. Load Model ---
|
73 |
+
# Load the model from Hugging Face Hub
|
74 |
+
# Make sure you are logged in (`huggingface-cli login`) if the repo is private.
|
75 |
+
model_id = "moonshotai/Kimi-Audio-7B-Instruct" # Or "Kimi/Kimi-Audio-7B"
|
76 |
+
device = "cuda" if torch.cuda.is_available() else "cpu" # Example device placement
|
77 |
+
# Note: The KimiAudio class might handle model loading differently.
|
78 |
+
# You might need to pass the model_id directly or download checkpoints manually
|
79 |
+
# and provide the local path as shown in the original readme_kimia.md.
|
80 |
+
# Please refer to the main Kimi-Audio repository for precise loading instructions.
|
81 |
+
# Example assuming KimiAudio takes the HF ID or a local path:
|
82 |
+
try:
|
83 |
+
model = KimiAudio(model_path=model_id, load_detokenizer=True) # May need device argument
|
84 |
+
model.to(device) # Example device placement
|
85 |
+
except Exception as e:
|
86 |
+
print(f"Automatic loading from HF Hub might require specific setup.")
|
87 |
+
print(f"Refer to Kimi-Audio docs. Trying local path example (update path!). Error: {e}")
|
88 |
+
# Fallback example:
|
89 |
+
# model_path = "/path/to/your/downloaded/kimia-hf-ckpt" # IMPORTANT: Update this path if loading locally
|
90 |
+
# model = KimiAudio(model_path=model_path, load_detokenizer=True)
|
91 |
+
# model.to(device) # Example device placement
|
92 |
+
|
93 |
+
# --- 2. Define Sampling Parameters ---
|
94 |
+
sampling_params = {
|
95 |
+
"audio_temperature": 0.8,
|
96 |
+
"audio_top_k": 10,
|
97 |
+
"text_temperature": 0.0,
|
98 |
+
"text_top_k": 5,
|
99 |
+
"audio_repetition_penalty": 1.0,
|
100 |
+
"audio_repetition_window_size": 64,
|
101 |
+
"text_repetition_penalty": 1.0,
|
102 |
+
"text_repetition_window_size": 16,
|
103 |
+
}
|
104 |
+
|
105 |
+
# --- 3. Example 1: Audio-to-Text (ASR) ---
|
106 |
+
# TODO: Provide actual example audio files or URLs accessible to users
|
107 |
+
# E.g., download sample files first or use URLs
|
108 |
+
# wget https://path/to/your/asr_example.wav -O asr_example.wav
|
109 |
+
# wget https://path/to/your/qa_example.wav -O qa_example.wav
|
110 |
+
asr_audio_path = "asr_example.wav" # IMPORTANT: Make sure this file exists
|
111 |
+
qa_audio_path = "qa_example.wav" # IMPORTANT: Make sure this file exists
|
112 |
+
|
113 |
+
messages_asr = [
|
114 |
+
{"role": "user", "message_type": "text", "content": "Please transcribe the following audio:"},
|
115 |
+
{"role": "user", "message_type": "audio", "content": asr_audio_path}
|
116 |
+
]
|
117 |
+
|
118 |
+
# Generate only text output
|
119 |
+
# Note: Ensure the model object and generate method accept device placement if needed
|
120 |
+
_, text_output = model.generate(messages_asr, **sampling_params, output_type="text")
|
121 |
+
print(">>> ASR Output Text: ", text_output)
|
122 |
+
# Expected output: "这并不是告别,这是一个篇章的结束,也是新篇章的开始。" (Example)
|
123 |
+
|
124 |
+
# --- 4. Example 2: Audio-to-Audio/Text Conversation ---
|
125 |
+
messages_conversation = [
|
126 |
+
{"role": "user", "message_type": "audio", "content": qa_audio_path}
|
127 |
+
]
|
128 |
+
|
129 |
+
# Generate both audio and text output
|
130 |
+
wav_output, text_output = model.generate(messages_conversation, **sampling_params, output_type="both")
|
131 |
+
|
132 |
+
# Save the generated audio
|
133 |
+
output_audio_path = "output_audio.wav"
|
134 |
+
# Ensure wav_output is on CPU and flattened before saving
|
135 |
+
sf.write(output_audio_path, wav_output.detach().cpu().view(-1).numpy(), 24000) # Assuming 24kHz output
|
136 |
+
print(f">>> Conversational Output Audio saved to: {output_audio_path}")
|
137 |
+
print(">>> Conversational Output Text: ", text_output)
|
138 |
+
# Expected output: "A." (Example)
|
139 |
+
|
140 |
+
print("Kimi-Audio inference examples complete.")
|
141 |
+
|
142 |
+
```
|
143 |
+
|
144 |
+
## Citation
|
145 |
+
|
146 |
+
If you find Kimi-Audio useful in your research or applications, please cite our technical report:
|
147 |
+
|
148 |
+
```bibtex
|
149 |
+
@misc{kimi_audio_2024,
|
150 |
+
title={Kimi-Audio Technical Report},
|
151 |
+
author={Kimi Team},
|
152 |
+
year={2024},
|
153 |
+
eprint={arXiv:placeholder},
|
154 |
+
archivePrefix={arXiv},
|
155 |
+
primaryClass={cs.CL}
|
156 |
+
}
|
157 |
+
```
|
158 |
+
|
159 |
+
## License
|
160 |
+
|
161 |
+
The model is based and modified from [Qwen 2.5-7B](https://github.com/QwenLM/Qwen2.5). Code derived from Qwen2.5-7B is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0). Other parts of the code are licensed under the [MIT License](https://opensource.org/licenses/MIT).
|
audio_detokenizer/config.yaml
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accumulate_grad_batches: 1
|
2 |
+
base_config: config/config_base.yaml
|
3 |
+
batch_max_tokens: 12000
|
4 |
+
batch_size: 2
|
5 |
+
cfg_init: 1.0
|
6 |
+
cfg_scale: 4.0
|
7 |
+
cfg_schedule: linear
|
8 |
+
check_val_every_n_epoch: 10
|
9 |
+
clip_grad_norm: 0
|
10 |
+
data_dir: ''
|
11 |
+
debug: false
|
12 |
+
deep_speed_strategy_stage: 2
|
13 |
+
drop_last: true
|
14 |
+
dynamic_cfg: false
|
15 |
+
endless_ds: false
|
16 |
+
filter_args:
|
17 |
+
lang:
|
18 |
+
- zh
|
19 |
+
- en
|
20 |
+
max_spk_num: 6
|
21 |
+
speech_ratio: 0.6
|
22 |
+
gradient_clip_val: 1.0
|
23 |
+
indexed_ds: true
|
24 |
+
infer: false
|
25 |
+
infer_exp_name: ''
|
26 |
+
infer_json_path: ''
|
27 |
+
inference_ckpt: ''
|
28 |
+
inference_mode: nonstreaming
|
29 |
+
learning_rate: 1e-4
|
30 |
+
limit_val_batches: 100
|
31 |
+
load_opt: false
|
32 |
+
log_interval: 10
|
33 |
+
logger_type: tensorboard
|
34 |
+
loss:
|
35 |
+
lambda_fm: 1.0
|
36 |
+
lambda_phone: 0.0
|
37 |
+
mel_loss: l1
|
38 |
+
max_epochs: 1000
|
39 |
+
max_eval_sentences: -1
|
40 |
+
max_eval_tokens: -1
|
41 |
+
max_prompt_ratio: 0.5
|
42 |
+
max_segment_cnt: 20000
|
43 |
+
max_sentences: -1
|
44 |
+
max_speech_duration: 20
|
45 |
+
max_tokens: 31250
|
46 |
+
max_training_steps: 100000
|
47 |
+
max_updates: 160000
|
48 |
+
mel_mean: -4.479605
|
49 |
+
mel_std: 3.4584913
|
50 |
+
meta_dir: null
|
51 |
+
min_prompt_duration: 0.5
|
52 |
+
min_speech_duration: -1
|
53 |
+
model:
|
54 |
+
condition_prenet_depth: 6
|
55 |
+
dit:
|
56 |
+
chunk_params:
|
57 |
+
hz: 50
|
58 |
+
max_chunk: 3.0
|
59 |
+
max_chunk_history: 50000000
|
60 |
+
min_chunk: 0.5
|
61 |
+
need_block_shift: false
|
62 |
+
condition_input_dim: 1280
|
63 |
+
condition_type: discrete_codes
|
64 |
+
depth: 16
|
65 |
+
ffn_act_layer: gleu_tanh
|
66 |
+
ffn_conv_kernel_size: 5
|
67 |
+
ffn_gated_glu: false
|
68 |
+
ffn_type: vanilla_mlp
|
69 |
+
hidden_size: 2304
|
70 |
+
input_size: 80
|
71 |
+
max_seq_len: 4096
|
72 |
+
mlp_ratio: 4.0
|
73 |
+
num_heads: 18
|
74 |
+
position_embedding_type: skip
|
75 |
+
prompt_cfg_dropout: 0.2
|
76 |
+
rope_params:
|
77 |
+
max_position_embeddings: 4096
|
78 |
+
rope_base: 10000.0
|
79 |
+
rope_interpolation_factor: 1.0
|
80 |
+
semantic_cfg_dropout: 0.2
|
81 |
+
semantic_vocab_size: 16384
|
82 |
+
use_chunk_setting: true
|
83 |
+
use_rope: true
|
84 |
+
phone_predictor:
|
85 |
+
blank_id: 4
|
86 |
+
phone_vocab_size: 5000
|
87 |
+
position_id_start_from: 0
|
88 |
+
random_position_start: true
|
89 |
+
restart_position_ids: false
|
90 |
+
use_condition_prenet: false
|
91 |
+
need_merge_same_speaker: true
|
92 |
+
need_precise_phones: false
|
93 |
+
no_verlap: true
|
94 |
+
normalize_mel: true
|
95 |
+
num_nodes: 1
|
96 |
+
num_sanity_val_steps: 0
|
97 |
+
num_workers: 1
|
98 |
+
ode_steps: 150
|
99 |
+
optimizer_adam_beta1: 0.9
|
100 |
+
optimizer_adam_beta2: 0.98
|
101 |
+
optimizer_class: adamw
|
102 |
+
pin_memory: true
|
103 |
+
precision: bf16-mixed
|
104 |
+
save_interval: 2000
|
105 |
+
save_topk: 10
|
106 |
+
seed: 1234
|
107 |
+
shuffle: true
|
108 |
+
sort_by_len: true
|
109 |
+
src_sample_rate: 16000
|
110 |
+
strategy: ddp
|
111 |
+
tensorboard_dir: tb_logs
|
112 |
+
test_num: 100
|
113 |
+
tgt_sample_rate: 24000
|
114 |
+
timescale: 80000
|
115 |
+
use_cfg: false
|
116 |
+
use_cfg_rescale: false
|
117 |
+
use_distributed_sampler: false
|
118 |
+
use_uncondition: false
|
119 |
+
val_check_interval: 2000000
|
120 |
+
vocoder_ckpt: ''
|
121 |
+
wandb_name: glm4_semantic_cfm_v2_debug
|
122 |
+
warmup_updates: 100
|
123 |
+
weight_decay: 0.0001
|
audio_detokenizer/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cdeeec41e629565439cd8ef807c8a014ad6ce052cce0c259c7bfe3fe6ada3f51
|
3 |
+
size 19008505142
|
config.json
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"MoonshotKimiaForCausalLM"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_moonshot_kimia.KimiAudioConfig",
|
7 |
+
"AutoModel": "modeling_moonshot_kimia.MoonshotKimiaModel",
|
8 |
+
"AutoModelForCausalLM": "modeling_moonshot_kimia.MoonshotKimiaForCausalLM"
|
9 |
+
},
|
10 |
+
"bos_token_id": 151643,
|
11 |
+
"eos_token_ids": [
|
12 |
+
151644,
|
13 |
+
151645
|
14 |
+
],
|
15 |
+
"hidden_act": "silu",
|
16 |
+
"hidden_size": 3584,
|
17 |
+
"initializer_range": 0.02,
|
18 |
+
"intermediate_size": 18944,
|
19 |
+
"kimia_adaptor_input_dim": 5120,
|
20 |
+
"kimia_audio_output_vocab": 16896,
|
21 |
+
"kimia_media_begin": 151661,
|
22 |
+
"kimia_media_end": 151663,
|
23 |
+
"kimia_mimo_audiodelaytokens": 5,
|
24 |
+
"kimia_mimo_layers": 6,
|
25 |
+
"kimia_mimo_transformer_from_layer_index": 21,
|
26 |
+
"kimia_text_output_vocab": 152064,
|
27 |
+
"kimia_token_offset": 152064,
|
28 |
+
"num_attention_heads": 28,
|
29 |
+
"num_audio_special_tokens": 512,
|
30 |
+
"num_base_tokens": 151643,
|
31 |
+
"num_hidden_layers": 28,
|
32 |
+
"num_key_value_heads": 4,
|
33 |
+
"pad_token_id": 152063,
|
34 |
+
"max_position_embeddings": 8192,
|
35 |
+
"rms_norm_eps": 1e-06,
|
36 |
+
"rope_scaling": null,
|
37 |
+
"rope_theta": 1000000.0,
|
38 |
+
"tie_word_embeddings": false,
|
39 |
+
"torch_dtype": "bfloat16",
|
40 |
+
"transformers_version": "4.44.1",
|
41 |
+
"use_cache": true,
|
42 |
+
"use_whisper_feature": true,
|
43 |
+
"vocab_size": 168448
|
44 |
+
}
|
configuration_moonshot_kimia.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
2 |
+
|
3 |
+
|
4 |
+
class KimiAudioConfig(Qwen2Config):
|
5 |
+
def __init__(
|
6 |
+
self,
|
7 |
+
vocab_size=163840,
|
8 |
+
hidden_size=4096,
|
9 |
+
intermediate_size=11008,
|
10 |
+
num_hidden_layers=32,
|
11 |
+
num_attention_heads=32,
|
12 |
+
num_key_value_heads=None,
|
13 |
+
hidden_act="silu",
|
14 |
+
initializer_range=0.02,
|
15 |
+
rms_norm_eps=1e-6,
|
16 |
+
use_cache=True,
|
17 |
+
rope_theta=10000.0,
|
18 |
+
rope_scaling=None,
|
19 |
+
tie_word_embeddings=False,
|
20 |
+
kimia_mimo_layers: int = 6,
|
21 |
+
kimia_mimo_audiodelaytokens: int = 5,
|
22 |
+
kimia_mimo_transformer_from_layer_index: int = 21,
|
23 |
+
kimia_audio_output_vocab: int = 16896,
|
24 |
+
kimia_text_output_vocab: int = 152064,
|
25 |
+
num_audio_special_tokens: int = 512,
|
26 |
+
num_base_tokens: int = 151643,
|
27 |
+
kimia_token_offset: int = 152064,
|
28 |
+
use_whisper_feature: bool = True,
|
29 |
+
kimia_adaptor_input_dim: int = 5120,
|
30 |
+
kimia_media_begin: int = 151661,
|
31 |
+
kimia_media_end: int = 151663,
|
32 |
+
**kwargs,
|
33 |
+
):
|
34 |
+
super().__init__(
|
35 |
+
vocab_size=vocab_size,
|
36 |
+
hidden_size=hidden_size,
|
37 |
+
intermediate_size=intermediate_size,
|
38 |
+
num_hidden_layers=num_hidden_layers,
|
39 |
+
num_attention_heads=num_attention_heads,
|
40 |
+
num_key_value_heads=num_key_value_heads,
|
41 |
+
hidden_act=hidden_act,
|
42 |
+
initializer_range=initializer_range,
|
43 |
+
rms_norm_eps=rms_norm_eps,
|
44 |
+
use_cache=use_cache,
|
45 |
+
tie_word_embeddings=tie_word_embeddings,
|
46 |
+
rope_theta=rope_theta,
|
47 |
+
rope_scaling=rope_scaling,
|
48 |
+
**kwargs,
|
49 |
+
)
|
50 |
+
|
51 |
+
self.kimia_mimo_layers = kimia_mimo_layers
|
52 |
+
self.kimia_mimo_audiodelaytokens = kimia_mimo_audiodelaytokens
|
53 |
+
# vocab
|
54 |
+
self.kimia_mimo_transformer_from_layer_index = (
|
55 |
+
kimia_mimo_transformer_from_layer_index
|
56 |
+
)
|
57 |
+
self.kimia_audio_output_vocab = kimia_audio_output_vocab
|
58 |
+
self.kimia_text_output_vocab = kimia_text_output_vocab
|
59 |
+
self.num_audio_special_tokens = num_audio_special_tokens
|
60 |
+
self.num_base_tokens = num_base_tokens
|
61 |
+
self.kimia_token_offset = kimia_token_offset
|
62 |
+
self.use_whisper_feature = use_whisper_feature
|
63 |
+
self.kimia_adaptor_input_dim = kimia_adaptor_input_dim
|
64 |
+
# special tokens
|
65 |
+
self.kimia_media_begin = kimia_media_begin
|
66 |
+
self.kimia_media_end = kimia_media_end
|
generation_config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_length": 8192
|
3 |
+
}
|
model-1-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:462878caab2cc405a9665569e9c4a72191b070f73df5530f6e9c419108a24fe2
|
3 |
+
size 466117192
|
model-10-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:835a9fe443d1118f5ff46d1d35814e6080ff2a9ed578654887638ee853f47ae2
|
3 |
+
size 466117192
|
model-11-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c9de9a823c1b63a4813b4e1d0f6d9c522382981f1889f9423aa077cc91a8f0f
|
3 |
+
size 466117208
|
model-12-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4e38dd5adcf335dfe2da28613ab0098570ae9f5fb0f0d137720b5b2458d68d45
|
3 |
+
size 466117208
|
model-13-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:895e02180f74fd811874d200229a549344e6a84a12943aa78d7bf03c3ffb6140
|
3 |
+
size 466117208
|
model-14-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6ebf2a6d19063146565f0e6f7de3bc810a4565a8a24693c85119969b99e2542
|
3 |
+
size 466117208
|
model-15-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:76c3fb75b201ccbbfc1f9b909b75924445efe3e09199ac6d13db0857356975fa
|
3 |
+
size 466117208
|
model-16-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9af4933a1ef591f45d2e61fd7cc5b4e00969e59339a74a10a625de85ccb26362
|
3 |
+
size 466117208
|
model-17-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:39545d953a1db3ffb63f981e57c2f7d4fd178131c515c58955ece020be96241d
|
3 |
+
size 466117208
|
model-18-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a782c5c67a08571730bbebdb8928dfad48cf8342c38bbe98c6bfa7329edcf7bd
|
3 |
+
size 466117208
|
model-19-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51e850d87b4ece3566404d66c1e27d474549a160b6bc216607bd11fc336e25ee
|
3 |
+
size 466117208
|
model-2-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c20fafb1f647f47084f81095ffcf2c85a55af0665c5b1caa1fb5c0504ee92433
|
3 |
+
size 466117192
|
model-20-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:434af95fd5e5a9e38dfe37c2123de7e55e8676c45ecc3f11ab94c0984588ab40
|
3 |
+
size 466117208
|
model-21-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4f1dd7e0d633053d61bdd14f1ff96f3dc27e7e4ee66c707ce4ff4f4bb1490040
|
3 |
+
size 466117208
|
model-22-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e8e22a53c90fd6b68df56de41519c6bce4e0c3df0d75262121c0f0cdfabedf8
|
3 |
+
size 466117208
|
model-23-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f52cf116ea9183f84678aae29bef762e945f0a11b043385e34d150cb4e0930ee
|
3 |
+
size 466117208
|
model-24-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:17c2348936c30f4715c6ac406d88d3fcd2ede9b860051b0e43357398b6b2c392
|
3 |
+
size 466117208
|
model-25-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f0f0d3852637374b66bd250abab868a48ca4120d01e70acff617ce751e2f9ce
|
3 |
+
size 466117208
|
model-26-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec675c4b2590c0f3b6bdfcd3afa5292f08c3723f84d816eb30ad85143b47f704
|
3 |
+
size 466117208
|
model-27-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95c4714d295631d3562b25c5168a6551a92450a4c63bb9592eadde18347e2a3c
|
3 |
+
size 466117208
|
model-28-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d7c53162402f1a258fc1863a5857bc006b84f934719e6350cab328d5190d1800
|
3 |
+
size 466117208
|
model-29-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:02c5b5edc547ab8ff6ac83fec6f10a018ce111cc1ce6f6e3ced4db432b60903f
|
3 |
+
size 466117264
|
model-3-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d11bd527d217b86951248ba3506205b47fdd05ef8afedbd088733b2cff236553
|
3 |
+
size 466117192
|
model-30-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f7132a11e4e890532c9e4c30cfaf55904f2f6fcf79b74552e3903ec8f4fe38c1
|
3 |
+
size 466117264
|
model-31-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a836b2ad72a98873872bc7cbfda0ac0811d565dc4303212c0684abfaf83fb50f
|
3 |
+
size 466117264
|
model-32-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cd088c37a69da02d73a779caa447f82fb8dc3ad04a1d1590704395f7043737ee
|
3 |
+
size 466117264
|
model-33-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e7ab0dbed449d1ba7ed1f2582e279af6585c505158b1af66a7a77521677b70a
|
3 |
+
size 466117264
|
model-34-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:46b117a2fe7dd2c2eb3d2c6fb367bc3813a2d3b2a01bd330d983c52112a32ffd
|
3 |
+
size 466117264
|
model-35-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7dbbe9894d818f2751c32a36fe6db1f7e5d0f7d94e9752ee2ba81535ec777e33
|
3 |
+
size 62419592
|
model-36-of-36.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a25c2286a3373471ab4687f3908327ea15e29a909f9cc69eb642b1b47643a2df
|
3 |
+
size 3622320648
|
model-4-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9d5fe7491bac40d3b5c5eec23ba449f15e3248d12bdf582881fc65dafebe5a7a
|
3 |
+
size 466117192
|
model-5-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:193972da11808be8be48b74f439c5cd0567d825c84505565d30e208bf3dff0ed
|
3 |
+
size 466117192
|
model-6-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6aabe4e1e987bfc7ecc46933ae4a5f37c6e262411c1d86c663709abe633f3756
|
3 |
+
size 466117192
|
model-7-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a0fd3c96b26299dd056c849d969f54f1f89fca3ab480f7afdb57edf437497f48
|
3 |
+
size 466117192
|
model-8-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54f14bd698a405430c529895e167d39827e349a48181544e7e12ff41637f28e8
|
3 |
+
size 466117192
|
model-9-of-35.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fb0dc503bf96d67cf5ebb87bd9d899874e2cda9f0c828786e84f60e8c22cbbfc
|
3 |
+
size 466117192
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 19532673280
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"model.layers.0.self_attn.q_proj.weight": "model-1-of-35.safetensors",
|
7 |
+
"model.layers.0.self_attn.k_proj.weight": "model-1-of-35.safetensors",
|
8 |
+
"model.layers.0.self_attn.v_proj.weight": "model-1-of-35.safetensors",
|
9 |
+
"model.layers.0.self_attn.o_proj.weight": "model-1-of-35.safetensors",
|
10 |
+
"model.layers.0.self_attn.q_proj.bias": "model-1-of-35.safetensors",
|
11 |
+
"model.layers.0.self_attn.k_proj.bias": "model-1-of-35.safetensors",
|
12 |
+
"model.layers.0.self_attn.v_proj.bias": "model-1-of-35.safetensors",
|
13 |
+
"model.layers.0.input_layernorm.weight": "model-1-of-35.safetensors",
|
14 |
+
"model.layers.0.post_attention_layernorm.weight": "model-1-of-35.safetensors",
|
15 |
+
"model.layers.0.mlp.gate_proj.weight": "model-1-of-35.safetensors",
|
16 |
+
"model.layers.0.mlp.down_proj.weight": "model-1-of-35.safetensors",
|
17 |
+
"model.layers.0.mlp.up_proj.weight": "model-1-of-35.safetensors",
|
18 |
+
"model.layers.0.self_attn.rotary_emb.inv_freq": "model-1-of-35.safetensors",
|
19 |
+
"model.layers.1.self_attn.q_proj.weight": "model-2-of-35.safetensors",
|
20 |
+
"model.layers.1.self_attn.k_proj.weight": "model-2-of-35.safetensors",
|
21 |
+
"model.layers.1.self_attn.v_proj.weight": "model-2-of-35.safetensors",
|
22 |
+
"model.layers.1.self_attn.o_proj.weight": "model-2-of-35.safetensors",
|
23 |
+
"model.layers.1.self_attn.q_proj.bias": "model-2-of-35.safetensors",
|
24 |
+
"model.layers.1.self_attn.k_proj.bias": "model-2-of-35.safetensors",
|
25 |
+
"model.layers.1.self_attn.v_proj.bias": "model-2-of-35.safetensors",
|
26 |
+
"model.layers.1.input_layernorm.weight": "model-2-of-35.safetensors",
|
27 |
+
"model.layers.1.post_attention_layernorm.weight": "model-2-of-35.safetensors",
|
28 |
+
"model.layers.1.mlp.gate_proj.weight": "model-2-of-35.safetensors",
|
29 |
+
"model.layers.1.mlp.down_proj.weight": "model-2-of-35.safetensors",
|
30 |
+
"model.layers.1.mlp.up_proj.weight": "model-2-of-35.safetensors",
|
31 |
+
"model.layers.1.self_attn.rotary_emb.inv_freq": "model-2-of-35.safetensors",
|
32 |
+
"model.layers.2.self_attn.q_proj.weight": "model-3-of-35.safetensors",
|
33 |
+
"model.layers.2.self_attn.k_proj.weight": "model-3-of-35.safetensors",
|
34 |
+
"model.layers.2.self_attn.v_proj.weight": "model-3-of-35.safetensors",
|
35 |
+
"model.layers.2.self_attn.o_proj.weight": "model-3-of-35.safetensors",
|
36 |
+
"model.layers.2.self_attn.q_proj.bias": "model-3-of-35.safetensors",
|
37 |
+
"model.layers.2.self_attn.k_proj.bias": "model-3-of-35.safetensors",
|
38 |
+
"model.layers.2.self_attn.v_proj.bias": "model-3-of-35.safetensors",
|
39 |
+
"model.layers.2.input_layernorm.weight": "model-3-of-35.safetensors",
|
40 |
+
"model.layers.2.post_attention_layernorm.weight": "model-3-of-35.safetensors",
|
41 |
+
"model.layers.2.mlp.gate_proj.weight": "model-3-of-35.safetensors",
|
42 |
+
"model.layers.2.mlp.down_proj.weight": "model-3-of-35.safetensors",
|
43 |
+
"model.layers.2.mlp.up_proj.weight": "model-3-of-35.safetensors",
|
44 |
+
"model.layers.2.self_attn.rotary_emb.inv_freq": "model-3-of-35.safetensors",
|
45 |
+
"model.layers.3.self_attn.q_proj.weight": "model-4-of-35.safetensors",
|
46 |
+
"model.layers.3.self_attn.k_proj.weight": "model-4-of-35.safetensors",
|
47 |
+
"model.layers.3.self_attn.v_proj.weight": "model-4-of-35.safetensors",
|
48 |
+
"model.layers.3.self_attn.o_proj.weight": "model-4-of-35.safetensors",
|
49 |
+
"model.layers.3.self_attn.q_proj.bias": "model-4-of-35.safetensors",
|
50 |
+
"model.layers.3.self_attn.k_proj.bias": "model-4-of-35.safetensors",
|
51 |
+
"model.layers.3.self_attn.v_proj.bias": "model-4-of-35.safetensors",
|
52 |
+
"model.layers.3.input_layernorm.weight": "model-4-of-35.safetensors",
|
53 |
+
"model.layers.3.post_attention_layernorm.weight": "model-4-of-35.safetensors",
|
54 |
+
"model.layers.3.mlp.gate_proj.weight": "model-4-of-35.safetensors",
|
55 |
+
"model.layers.3.mlp.down_proj.weight": "model-4-of-35.safetensors",
|
56 |
+
"model.layers.3.mlp.up_proj.weight": "model-4-of-35.safetensors",
|
57 |
+
"model.layers.3.self_attn.rotary_emb.inv_freq": "model-4-of-35.safetensors",
|
58 |
+
"model.layers.4.self_attn.q_proj.weight": "model-5-of-35.safetensors",
|
59 |
+
"model.layers.4.self_attn.k_proj.weight": "model-5-of-35.safetensors",
|
60 |
+
"model.layers.4.self_attn.v_proj.weight": "model-5-of-35.safetensors",
|
61 |
+
"model.layers.4.self_attn.o_proj.weight": "model-5-of-35.safetensors",
|
62 |
+
"model.layers.4.self_attn.q_proj.bias": "model-5-of-35.safetensors",
|
63 |
+
"model.layers.4.self_attn.k_proj.bias": "model-5-of-35.safetensors",
|
64 |
+
"model.layers.4.self_attn.v_proj.bias": "model-5-of-35.safetensors",
|
65 |
+
"model.layers.4.input_layernorm.weight": "model-5-of-35.safetensors",
|
66 |
+
"model.layers.4.post_attention_layernorm.weight": "model-5-of-35.safetensors",
|
67 |
+
"model.layers.4.mlp.gate_proj.weight": "model-5-of-35.safetensors",
|
68 |
+
"model.layers.4.mlp.down_proj.weight": "model-5-of-35.safetensors",
|
69 |
+
"model.layers.4.mlp.up_proj.weight": "model-5-of-35.safetensors",
|
70 |
+
"model.layers.4.self_attn.rotary_emb.inv_freq": "model-5-of-35.safetensors",
|
71 |
+
"model.layers.5.self_attn.q_proj.weight": "model-6-of-35.safetensors",
|
72 |
+
"model.layers.5.self_attn.k_proj.weight": "model-6-of-35.safetensors",
|
73 |
+
"model.layers.5.self_attn.v_proj.weight": "model-6-of-35.safetensors",
|
74 |
+
"model.layers.5.self_attn.o_proj.weight": "model-6-of-35.safetensors",
|
75 |
+
"model.layers.5.self_attn.q_proj.bias": "model-6-of-35.safetensors",
|
76 |
+
"model.layers.5.self_attn.k_proj.bias": "model-6-of-35.safetensors",
|
77 |
+
"model.layers.5.self_attn.v_proj.bias": "model-6-of-35.safetensors",
|
78 |
+
"model.layers.5.input_layernorm.weight": "model-6-of-35.safetensors",
|
79 |
+
"model.layers.5.post_attention_layernorm.weight": "model-6-of-35.safetensors",
|
80 |
+
"model.layers.5.mlp.gate_proj.weight": "model-6-of-35.safetensors",
|
81 |
+
"model.layers.5.mlp.down_proj.weight": "model-6-of-35.safetensors",
|
82 |
+
"model.layers.5.mlp.up_proj.weight": "model-6-of-35.safetensors",
|
83 |
+
"model.layers.5.self_attn.rotary_emb.inv_freq": "model-6-of-35.safetensors",
|
84 |
+
"model.layers.6.self_attn.q_proj.weight": "model-7-of-35.safetensors",
|
85 |
+
"model.layers.6.self_attn.k_proj.weight": "model-7-of-35.safetensors",
|
86 |
+
"model.layers.6.self_attn.v_proj.weight": "model-7-of-35.safetensors",
|
87 |
+
"model.layers.6.self_attn.o_proj.weight": "model-7-of-35.safetensors",
|
88 |
+
"model.layers.6.self_attn.q_proj.bias": "model-7-of-35.safetensors",
|
89 |
+
"model.layers.6.self_attn.k_proj.bias": "model-7-of-35.safetensors",
|
90 |
+
"model.layers.6.self_attn.v_proj.bias": "model-7-of-35.safetensors",
|
91 |
+
"model.layers.6.input_layernorm.weight": "model-7-of-35.safetensors",
|
92 |
+
"model.layers.6.post_attention_layernorm.weight": "model-7-of-35.safetensors",
|
93 |
+
"model.layers.6.mlp.gate_proj.weight": "model-7-of-35.safetensors",
|
94 |
+
"model.layers.6.mlp.down_proj.weight": "model-7-of-35.safetensors",
|
95 |
+
"model.layers.6.mlp.up_proj.weight": "model-7-of-35.safetensors",
|
96 |
+
"model.layers.6.self_attn.rotary_emb.inv_freq": "model-7-of-35.safetensors",
|
97 |
+
"model.layers.7.self_attn.q_proj.weight": "model-8-of-35.safetensors",
|
98 |
+
"model.layers.7.self_attn.k_proj.weight": "model-8-of-35.safetensors",
|
99 |
+
"model.layers.7.self_attn.v_proj.weight": "model-8-of-35.safetensors",
|
100 |
+
"model.layers.7.self_attn.o_proj.weight": "model-8-of-35.safetensors",
|
101 |
+
"model.layers.7.self_attn.q_proj.bias": "model-8-of-35.safetensors",
|
102 |
+
"model.layers.7.self_attn.k_proj.bias": "model-8-of-35.safetensors",
|
103 |
+
"model.layers.7.self_attn.v_proj.bias": "model-8-of-35.safetensors",
|
104 |
+
"model.layers.7.input_layernorm.weight": "model-8-of-35.safetensors",
|
105 |
+
"model.layers.7.post_attention_layernorm.weight": "model-8-of-35.safetensors",
|
106 |
+
"model.layers.7.mlp.gate_proj.weight": "model-8-of-35.safetensors",
|
107 |
+
"model.layers.7.mlp.down_proj.weight": "model-8-of-35.safetensors",
|
108 |
+
"model.layers.7.mlp.up_proj.weight": "model-8-of-35.safetensors",
|
109 |
+
"model.layers.7.self_attn.rotary_emb.inv_freq": "model-8-of-35.safetensors",
|
110 |
+
"model.layers.8.self_attn.q_proj.weight": "model-9-of-35.safetensors",
|
111 |
+
"model.layers.8.self_attn.k_proj.weight": "model-9-of-35.safetensors",
|
112 |
+
"model.layers.8.self_attn.v_proj.weight": "model-9-of-35.safetensors",
|
113 |
+
"model.layers.8.self_attn.o_proj.weight": "model-9-of-35.safetensors",
|
114 |
+
"model.layers.8.self_attn.q_proj.bias": "model-9-of-35.safetensors",
|
115 |
+
"model.layers.8.self_attn.k_proj.bias": "model-9-of-35.safetensors",
|
116 |
+
"model.layers.8.self_attn.v_proj.bias": "model-9-of-35.safetensors",
|
117 |
+
"model.layers.8.input_layernorm.weight": "model-9-of-35.safetensors",
|
118 |
+
"model.layers.8.post_attention_layernorm.weight": "model-9-of-35.safetensors",
|
119 |
+
"model.layers.8.mlp.gate_proj.weight": "model-9-of-35.safetensors",
|
120 |
+
"model.layers.8.mlp.down_proj.weight": "model-9-of-35.safetensors",
|
121 |
+
"model.layers.8.mlp.up_proj.weight": "model-9-of-35.safetensors",
|
122 |
+
"model.layers.8.self_attn.rotary_emb.inv_freq": "model-9-of-35.safetensors",
|
123 |
+
"model.layers.9.self_attn.q_proj.weight": "model-10-of-35.safetensors",
|
124 |
+
"model.layers.9.self_attn.k_proj.weight": "model-10-of-35.safetensors",
|
125 |
+
"model.layers.9.self_attn.v_proj.weight": "model-10-of-35.safetensors",
|
126 |
+
"model.layers.9.self_attn.o_proj.weight": "model-10-of-35.safetensors",
|
127 |
+
"model.layers.9.self_attn.q_proj.bias": "model-10-of-35.safetensors",
|
128 |
+
"model.layers.9.self_attn.k_proj.bias": "model-10-of-35.safetensors",
|
129 |
+
"model.layers.9.self_attn.v_proj.bias": "model-10-of-35.safetensors",
|
130 |
+
"model.layers.9.input_layernorm.weight": "model-10-of-35.safetensors",
|
131 |
+
"model.layers.9.post_attention_layernorm.weight": "model-10-of-35.safetensors",
|
132 |
+
"model.layers.9.mlp.gate_proj.weight": "model-10-of-35.safetensors",
|
133 |
+
"model.layers.9.mlp.down_proj.weight": "model-10-of-35.safetensors",
|
134 |
+
"model.layers.9.mlp.up_proj.weight": "model-10-of-35.safetensors",
|
135 |
+
"model.layers.9.self_attn.rotary_emb.inv_freq": "model-10-of-35.safetensors",
|
136 |
+
"model.layers.10.self_attn.q_proj.weight": "model-11-of-35.safetensors",
|
137 |
+
"model.layers.10.self_attn.k_proj.weight": "model-11-of-35.safetensors",
|
138 |
+
"model.layers.10.self_attn.v_proj.weight": "model-11-of-35.safetensors",
|
139 |
+
"model.layers.10.self_attn.o_proj.weight": "model-11-of-35.safetensors",
|
140 |
+
"model.layers.10.self_attn.q_proj.bias": "model-11-of-35.safetensors",
|
141 |
+
"model.layers.10.self_attn.k_proj.bias": "model-11-of-35.safetensors",
|
142 |
+
"model.layers.10.self_attn.v_proj.bias": "model-11-of-35.safetensors",
|
143 |
+
"model.layers.10.input_layernorm.weight": "model-11-of-35.safetensors",
|
144 |
+
"model.layers.10.post_attention_layernorm.weight": "model-11-of-35.safetensors",
|
145 |
+
"model.layers.10.mlp.gate_proj.weight": "model-11-of-35.safetensors",
|
146 |
+
"model.layers.10.mlp.down_proj.weight": "model-11-of-35.safetensors",
|
147 |
+
"model.layers.10.mlp.up_proj.weight": "model-11-of-35.safetensors",
|
148 |
+
"model.layers.10.self_attn.rotary_emb.inv_freq": "model-11-of-35.safetensors",
|
149 |
+
"model.layers.11.self_attn.q_proj.weight": "model-12-of-35.safetensors",
|
150 |
+
"model.layers.11.self_attn.k_proj.weight": "model-12-of-35.safetensors",
|
151 |
+
"model.layers.11.self_attn.v_proj.weight": "model-12-of-35.safetensors",
|
152 |
+
"model.layers.11.self_attn.o_proj.weight": "model-12-of-35.safetensors",
|
153 |
+
"model.layers.11.self_attn.q_proj.bias": "model-12-of-35.safetensors",
|
154 |
+
"model.layers.11.self_attn.k_proj.bias": "model-12-of-35.safetensors",
|
155 |
+
"model.layers.11.self_attn.v_proj.bias": "model-12-of-35.safetensors",
|
156 |
+
"model.layers.11.input_layernorm.weight": "model-12-of-35.safetensors",
|
157 |
+
"model.layers.11.post_attention_layernorm.weight": "model-12-of-35.safetensors",
|
158 |
+
"model.layers.11.mlp.gate_proj.weight": "model-12-of-35.safetensors",
|
159 |
+
"model.layers.11.mlp.down_proj.weight": "model-12-of-35.safetensors",
|
160 |
+
"model.layers.11.mlp.up_proj.weight": "model-12-of-35.safetensors",
|
161 |
+
"model.layers.11.self_attn.rotary_emb.inv_freq": "model-12-of-35.safetensors",
|
162 |
+
"model.layers.12.self_attn.q_proj.weight": "model-13-of-35.safetensors",
|
163 |
+
"model.layers.12.self_attn.k_proj.weight": "model-13-of-35.safetensors",
|
164 |
+
"model.layers.12.self_attn.v_proj.weight": "model-13-of-35.safetensors",
|
165 |
+
"model.layers.12.self_attn.o_proj.weight": "model-13-of-35.safetensors",
|
166 |
+
"model.layers.12.self_attn.q_proj.bias": "model-13-of-35.safetensors",
|
167 |
+
"model.layers.12.self_attn.k_proj.bias": "model-13-of-35.safetensors",
|
168 |
+
"model.layers.12.self_attn.v_proj.bias": "model-13-of-35.safetensors",
|
169 |
+
"model.layers.12.input_layernorm.weight": "model-13-of-35.safetensors",
|
170 |
+
"model.layers.12.post_attention_layernorm.weight": "model-13-of-35.safetensors",
|
171 |
+
"model.layers.12.mlp.gate_proj.weight": "model-13-of-35.safetensors",
|
172 |
+
"model.layers.12.mlp.down_proj.weight": "model-13-of-35.safetensors",
|
173 |
+
"model.layers.12.mlp.up_proj.weight": "model-13-of-35.safetensors",
|
174 |
+
"model.layers.12.self_attn.rotary_emb.inv_freq": "model-13-of-35.safetensors",
|
175 |
+
"model.layers.13.self_attn.q_proj.weight": "model-14-of-35.safetensors",
|
176 |
+
"model.layers.13.self_attn.k_proj.weight": "model-14-of-35.safetensors",
|
177 |
+
"model.layers.13.self_attn.v_proj.weight": "model-14-of-35.safetensors",
|
178 |
+
"model.layers.13.self_attn.o_proj.weight": "model-14-of-35.safetensors",
|
179 |
+
"model.layers.13.self_attn.q_proj.bias": "model-14-of-35.safetensors",
|
180 |
+
"model.layers.13.self_attn.k_proj.bias": "model-14-of-35.safetensors",
|
181 |
+
"model.layers.13.self_attn.v_proj.bias": "model-14-of-35.safetensors",
|
182 |
+
"model.layers.13.input_layernorm.weight": "model-14-of-35.safetensors",
|
183 |
+
"model.layers.13.post_attention_layernorm.weight": "model-14-of-35.safetensors",
|
184 |
+
"model.layers.13.mlp.gate_proj.weight": "model-14-of-35.safetensors",
|
185 |
+
"model.layers.13.mlp.down_proj.weight": "model-14-of-35.safetensors",
|
186 |
+
"model.layers.13.mlp.up_proj.weight": "model-14-of-35.safetensors",
|
187 |
+
"model.layers.13.self_attn.rotary_emb.inv_freq": "model-14-of-35.safetensors",
|
188 |
+
"model.layers.14.self_attn.q_proj.weight": "model-15-of-35.safetensors",
|
189 |
+
"model.layers.14.self_attn.k_proj.weight": "model-15-of-35.safetensors",
|
190 |
+
"model.layers.14.self_attn.v_proj.weight": "model-15-of-35.safetensors",
|
191 |
+
"model.layers.14.self_attn.o_proj.weight": "model-15-of-35.safetensors",
|
192 |
+
"model.layers.14.self_attn.q_proj.bias": "model-15-of-35.safetensors",
|
193 |
+
"model.layers.14.self_attn.k_proj.bias": "model-15-of-35.safetensors",
|
194 |
+
"model.layers.14.self_attn.v_proj.bias": "model-15-of-35.safetensors",
|
195 |
+
"model.layers.14.input_layernorm.weight": "model-15-of-35.safetensors",
|
196 |
+
"model.layers.14.post_attention_layernorm.weight": "model-15-of-35.safetensors",
|
197 |
+
"model.layers.14.mlp.gate_proj.weight": "model-15-of-35.safetensors",
|
198 |
+
"model.layers.14.mlp.down_proj.weight": "model-15-of-35.safetensors",
|
199 |
+
"model.layers.14.mlp.up_proj.weight": "model-15-of-35.safetensors",
|
200 |
+
"model.layers.14.self_attn.rotary_emb.inv_freq": "model-15-of-35.safetensors",
|
201 |
+
"model.layers.15.self_attn.q_proj.weight": "model-16-of-35.safetensors",
|
202 |
+
"model.layers.15.self_attn.k_proj.weight": "model-16-of-35.safetensors",
|
203 |
+
"model.layers.15.self_attn.v_proj.weight": "model-16-of-35.safetensors",
|
204 |
+
"model.layers.15.self_attn.o_proj.weight": "model-16-of-35.safetensors",
|
205 |
+
"model.layers.15.self_attn.q_proj.bias": "model-16-of-35.safetensors",
|
206 |
+
"model.layers.15.self_attn.k_proj.bias": "model-16-of-35.safetensors",
|
207 |
+
"model.layers.15.self_attn.v_proj.bias": "model-16-of-35.safetensors",
|
208 |
+
"model.layers.15.input_layernorm.weight": "model-16-of-35.safetensors",
|
209 |
+
"model.layers.15.post_attention_layernorm.weight": "model-16-of-35.safetensors",
|
210 |
+
"model.layers.15.mlp.gate_proj.weight": "model-16-of-35.safetensors",
|
211 |
+
"model.layers.15.mlp.down_proj.weight": "model-16-of-35.safetensors",
|
212 |
+
"model.layers.15.mlp.up_proj.weight": "model-16-of-35.safetensors",
|
213 |
+
"model.layers.15.self_attn.rotary_emb.inv_freq": "model-16-of-35.safetensors",
|
214 |
+
"model.layers.16.self_attn.q_proj.weight": "model-17-of-35.safetensors",
|
215 |
+
"model.layers.16.self_attn.k_proj.weight": "model-17-of-35.safetensors",
|
216 |
+
"model.layers.16.self_attn.v_proj.weight": "model-17-of-35.safetensors",
|
217 |
+
"model.layers.16.self_attn.o_proj.weight": "model-17-of-35.safetensors",
|
218 |
+
"model.layers.16.self_attn.q_proj.bias": "model-17-of-35.safetensors",
|
219 |
+
"model.layers.16.self_attn.k_proj.bias": "model-17-of-35.safetensors",
|
220 |
+
"model.layers.16.self_attn.v_proj.bias": "model-17-of-35.safetensors",
|
221 |
+
"model.layers.16.input_layernorm.weight": "model-17-of-35.safetensors",
|
222 |
+
"model.layers.16.post_attention_layernorm.weight": "model-17-of-35.safetensors",
|
223 |
+
"model.layers.16.mlp.gate_proj.weight": "model-17-of-35.safetensors",
|
224 |
+
"model.layers.16.mlp.down_proj.weight": "model-17-of-35.safetensors",
|
225 |
+
"model.layers.16.mlp.up_proj.weight": "model-17-of-35.safetensors",
|
226 |
+
"model.layers.16.self_attn.rotary_emb.inv_freq": "model-17-of-35.safetensors",
|
227 |
+
"model.layers.17.self_attn.q_proj.weight": "model-18-of-35.safetensors",
|
228 |
+
"model.layers.17.self_attn.k_proj.weight": "model-18-of-35.safetensors",
|
229 |
+
"model.layers.17.self_attn.v_proj.weight": "model-18-of-35.safetensors",
|
230 |
+
"model.layers.17.self_attn.o_proj.weight": "model-18-of-35.safetensors",
|
231 |
+
"model.layers.17.self_attn.q_proj.bias": "model-18-of-35.safetensors",
|
232 |
+
"model.layers.17.self_attn.k_proj.bias": "model-18-of-35.safetensors",
|
233 |
+
"model.layers.17.self_attn.v_proj.bias": "model-18-of-35.safetensors",
|
234 |
+
"model.layers.17.input_layernorm.weight": "model-18-of-35.safetensors",
|
235 |
+
"model.layers.17.post_attention_layernorm.weight": "model-18-of-35.safetensors",
|
236 |
+
"model.layers.17.mlp.gate_proj.weight": "model-18-of-35.safetensors",
|
237 |
+
"model.layers.17.mlp.down_proj.weight": "model-18-of-35.safetensors",
|
238 |
+
"model.layers.17.mlp.up_proj.weight": "model-18-of-35.safetensors",
|
239 |
+
"model.layers.17.self_attn.rotary_emb.inv_freq": "model-18-of-35.safetensors",
|
240 |
+
"model.layers.18.self_attn.q_proj.weight": "model-19-of-35.safetensors",
|
241 |
+
"model.layers.18.self_attn.k_proj.weight": "model-19-of-35.safetensors",
|
242 |
+
"model.layers.18.self_attn.v_proj.weight": "model-19-of-35.safetensors",
|
243 |
+
"model.layers.18.self_attn.o_proj.weight": "model-19-of-35.safetensors",
|
244 |
+
"model.layers.18.self_attn.q_proj.bias": "model-19-of-35.safetensors",
|
245 |
+
"model.layers.18.self_attn.k_proj.bias": "model-19-of-35.safetensors",
|
246 |
+
"model.layers.18.self_attn.v_proj.bias": "model-19-of-35.safetensors",
|
247 |
+
"model.layers.18.input_layernorm.weight": "model-19-of-35.safetensors",
|
248 |
+
"model.layers.18.post_attention_layernorm.weight": "model-19-of-35.safetensors",
|
249 |
+
"model.layers.18.mlp.gate_proj.weight": "model-19-of-35.safetensors",
|
250 |
+
"model.layers.18.mlp.down_proj.weight": "model-19-of-35.safetensors",
|
251 |
+
"model.layers.18.mlp.up_proj.weight": "model-19-of-35.safetensors",
|
252 |
+
"model.layers.18.self_attn.rotary_emb.inv_freq": "model-19-of-35.safetensors",
|
253 |
+
"model.layers.19.self_attn.q_proj.weight": "model-20-of-35.safetensors",
|
254 |
+
"model.layers.19.self_attn.k_proj.weight": "model-20-of-35.safetensors",
|
255 |
+
"model.layers.19.self_attn.v_proj.weight": "model-20-of-35.safetensors",
|
256 |
+
"model.layers.19.self_attn.o_proj.weight": "model-20-of-35.safetensors",
|
257 |
+
"model.layers.19.self_attn.q_proj.bias": "model-20-of-35.safetensors",
|
258 |
+
"model.layers.19.self_attn.k_proj.bias": "model-20-of-35.safetensors",
|
259 |
+
"model.layers.19.self_attn.v_proj.bias": "model-20-of-35.safetensors",
|
260 |
+
"model.layers.19.input_layernorm.weight": "model-20-of-35.safetensors",
|
261 |
+
"model.layers.19.post_attention_layernorm.weight": "model-20-of-35.safetensors",
|
262 |
+
"model.layers.19.mlp.gate_proj.weight": "model-20-of-35.safetensors",
|
263 |
+
"model.layers.19.mlp.down_proj.weight": "model-20-of-35.safetensors",
|
264 |
+
"model.layers.19.mlp.up_proj.weight": "model-20-of-35.safetensors",
|
265 |
+
"model.layers.19.self_attn.rotary_emb.inv_freq": "model-20-of-35.safetensors",
|
266 |
+
"model.layers.20.self_attn.q_proj.weight": "model-21-of-35.safetensors",
|
267 |
+
"model.layers.20.self_attn.k_proj.weight": "model-21-of-35.safetensors",
|
268 |
+
"model.layers.20.self_attn.v_proj.weight": "model-21-of-35.safetensors",
|
269 |
+
"model.layers.20.self_attn.o_proj.weight": "model-21-of-35.safetensors",
|
270 |
+
"model.layers.20.self_attn.q_proj.bias": "model-21-of-35.safetensors",
|
271 |
+
"model.layers.20.self_attn.k_proj.bias": "model-21-of-35.safetensors",
|
272 |
+
"model.layers.20.self_attn.v_proj.bias": "model-21-of-35.safetensors",
|
273 |
+
"model.layers.20.input_layernorm.weight": "model-21-of-35.safetensors",
|
274 |
+
"model.layers.20.post_attention_layernorm.weight": "model-21-of-35.safetensors",
|
275 |
+
"model.layers.20.mlp.gate_proj.weight": "model-21-of-35.safetensors",
|
276 |
+
"model.layers.20.mlp.down_proj.weight": "model-21-of-35.safetensors",
|
277 |
+
"model.layers.20.mlp.up_proj.weight": "model-21-of-35.safetensors",
|
278 |
+
"model.layers.20.self_attn.rotary_emb.inv_freq": "model-21-of-35.safetensors",
|
279 |
+
"model.layers.21.self_attn.q_proj.weight": "model-22-of-35.safetensors",
|
280 |
+
"model.layers.21.self_attn.k_proj.weight": "model-22-of-35.safetensors",
|
281 |
+
"model.layers.21.self_attn.v_proj.weight": "model-22-of-35.safetensors",
|
282 |
+
"model.layers.21.self_attn.o_proj.weight": "model-22-of-35.safetensors",
|
283 |
+
"model.layers.21.self_attn.q_proj.bias": "model-22-of-35.safetensors",
|
284 |
+
"model.layers.21.self_attn.k_proj.bias": "model-22-of-35.safetensors",
|
285 |
+
"model.layers.21.self_attn.v_proj.bias": "model-22-of-35.safetensors",
|
286 |
+
"model.layers.21.input_layernorm.weight": "model-22-of-35.safetensors",
|
287 |
+
"model.layers.21.post_attention_layernorm.weight": "model-22-of-35.safetensors",
|
288 |
+
"model.layers.21.mlp.gate_proj.weight": "model-22-of-35.safetensors",
|
289 |
+
"model.layers.21.mlp.down_proj.weight": "model-22-of-35.safetensors",
|
290 |
+
"model.layers.21.mlp.up_proj.weight": "model-22-of-35.safetensors",
|
291 |
+
"model.layers.21.self_attn.rotary_emb.inv_freq": "model-22-of-35.safetensors",
|
292 |
+
"model.layers.22.self_attn.q_proj.weight": "model-23-of-35.safetensors",
|
293 |
+
"model.layers.22.self_attn.k_proj.weight": "model-23-of-35.safetensors",
|
294 |
+
"model.layers.22.self_attn.v_proj.weight": "model-23-of-35.safetensors",
|
295 |
+
"model.layers.22.self_attn.o_proj.weight": "model-23-of-35.safetensors",
|
296 |
+
"model.layers.22.self_attn.q_proj.bias": "model-23-of-35.safetensors",
|
297 |
+
"model.layers.22.self_attn.k_proj.bias": "model-23-of-35.safetensors",
|
298 |
+
"model.layers.22.self_attn.v_proj.bias": "model-23-of-35.safetensors",
|
299 |
+
"model.layers.22.input_layernorm.weight": "model-23-of-35.safetensors",
|
300 |
+
"model.layers.22.post_attention_layernorm.weight": "model-23-of-35.safetensors",
|
301 |
+
"model.layers.22.mlp.gate_proj.weight": "model-23-of-35.safetensors",
|
302 |
+
"model.layers.22.mlp.down_proj.weight": "model-23-of-35.safetensors",
|
303 |
+
"model.layers.22.mlp.up_proj.weight": "model-23-of-35.safetensors",
|
304 |
+
"model.layers.22.self_attn.rotary_emb.inv_freq": "model-23-of-35.safetensors",
|
305 |
+
"model.layers.23.self_attn.q_proj.weight": "model-24-of-35.safetensors",
|
306 |
+
"model.layers.23.self_attn.k_proj.weight": "model-24-of-35.safetensors",
|
307 |
+
"model.layers.23.self_attn.v_proj.weight": "model-24-of-35.safetensors",
|
308 |
+
"model.layers.23.self_attn.o_proj.weight": "model-24-of-35.safetensors",
|
309 |
+
"model.layers.23.self_attn.q_proj.bias": "model-24-of-35.safetensors",
|
310 |
+
"model.layers.23.self_attn.k_proj.bias": "model-24-of-35.safetensors",
|
311 |
+
"model.layers.23.self_attn.v_proj.bias": "model-24-of-35.safetensors",
|
312 |
+
"model.layers.23.input_layernorm.weight": "model-24-of-35.safetensors",
|
313 |
+
"model.layers.23.post_attention_layernorm.weight": "model-24-of-35.safetensors",
|
314 |
+
"model.layers.23.mlp.gate_proj.weight": "model-24-of-35.safetensors",
|
315 |
+
"model.layers.23.mlp.down_proj.weight": "model-24-of-35.safetensors",
|
316 |
+
"model.layers.23.mlp.up_proj.weight": "model-24-of-35.safetensors",
|
317 |
+
"model.layers.23.self_attn.rotary_emb.inv_freq": "model-24-of-35.safetensors",
|
318 |
+
"model.layers.24.self_attn.q_proj.weight": "model-25-of-35.safetensors",
|
319 |
+
"model.layers.24.self_attn.k_proj.weight": "model-25-of-35.safetensors",
|
320 |
+
"model.layers.24.self_attn.v_proj.weight": "model-25-of-35.safetensors",
|
321 |
+
"model.layers.24.self_attn.o_proj.weight": "model-25-of-35.safetensors",
|
322 |
+
"model.layers.24.self_attn.q_proj.bias": "model-25-of-35.safetensors",
|
323 |
+
"model.layers.24.self_attn.k_proj.bias": "model-25-of-35.safetensors",
|
324 |
+
"model.layers.24.self_attn.v_proj.bias": "model-25-of-35.safetensors",
|
325 |
+
"model.layers.24.input_layernorm.weight": "model-25-of-35.safetensors",
|
326 |
+
"model.layers.24.post_attention_layernorm.weight": "model-25-of-35.safetensors",
|
327 |
+
"model.layers.24.mlp.gate_proj.weight": "model-25-of-35.safetensors",
|
328 |
+
"model.layers.24.mlp.down_proj.weight": "model-25-of-35.safetensors",
|
329 |
+
"model.layers.24.mlp.up_proj.weight": "model-25-of-35.safetensors",
|
330 |
+
"model.layers.24.self_attn.rotary_emb.inv_freq": "model-25-of-35.safetensors",
|
331 |
+
"model.layers.25.self_attn.q_proj.weight": "model-26-of-35.safetensors",
|
332 |
+
"model.layers.25.self_attn.k_proj.weight": "model-26-of-35.safetensors",
|
333 |
+
"model.layers.25.self_attn.v_proj.weight": "model-26-of-35.safetensors",
|
334 |
+
"model.layers.25.self_attn.o_proj.weight": "model-26-of-35.safetensors",
|
335 |
+
"model.layers.25.self_attn.q_proj.bias": "model-26-of-35.safetensors",
|
336 |
+
"model.layers.25.self_attn.k_proj.bias": "model-26-of-35.safetensors",
|
337 |
+
"model.layers.25.self_attn.v_proj.bias": "model-26-of-35.safetensors",
|
338 |
+
"model.layers.25.input_layernorm.weight": "model-26-of-35.safetensors",
|
339 |
+
"model.layers.25.post_attention_layernorm.weight": "model-26-of-35.safetensors",
|
340 |
+
"model.layers.25.mlp.gate_proj.weight": "model-26-of-35.safetensors",
|
341 |
+
"model.layers.25.mlp.down_proj.weight": "model-26-of-35.safetensors",
|
342 |
+
"model.layers.25.mlp.up_proj.weight": "model-26-of-35.safetensors",
|
343 |
+
"model.layers.25.self_attn.rotary_emb.inv_freq": "model-26-of-35.safetensors",
|
344 |
+
"model.layers.26.self_attn.q_proj.weight": "model-27-of-35.safetensors",
|
345 |
+
"model.layers.26.self_attn.k_proj.weight": "model-27-of-35.safetensors",
|
346 |
+
"model.layers.26.self_attn.v_proj.weight": "model-27-of-35.safetensors",
|
347 |
+
"model.layers.26.self_attn.o_proj.weight": "model-27-of-35.safetensors",
|
348 |
+
"model.layers.26.self_attn.q_proj.bias": "model-27-of-35.safetensors",
|
349 |
+
"model.layers.26.self_attn.k_proj.bias": "model-27-of-35.safetensors",
|
350 |
+
"model.layers.26.self_attn.v_proj.bias": "model-27-of-35.safetensors",
|
351 |
+
"model.layers.26.input_layernorm.weight": "model-27-of-35.safetensors",
|
352 |
+
"model.layers.26.post_attention_layernorm.weight": "model-27-of-35.safetensors",
|
353 |
+
"model.layers.26.mlp.gate_proj.weight": "model-27-of-35.safetensors",
|
354 |
+
"model.layers.26.mlp.down_proj.weight": "model-27-of-35.safetensors",
|
355 |
+
"model.layers.26.mlp.up_proj.weight": "model-27-of-35.safetensors",
|
356 |
+
"model.layers.26.self_attn.rotary_emb.inv_freq": "model-27-of-35.safetensors",
|
357 |
+
"model.layers.27.self_attn.q_proj.weight": "model-28-of-35.safetensors",
|
358 |
+
"model.layers.27.self_attn.k_proj.weight": "model-28-of-35.safetensors",
|
359 |
+
"model.layers.27.self_attn.v_proj.weight": "model-28-of-35.safetensors",
|
360 |
+
"model.layers.27.self_attn.o_proj.weight": "model-28-of-35.safetensors",
|
361 |
+
"model.layers.27.self_attn.q_proj.bias": "model-28-of-35.safetensors",
|
362 |
+
"model.layers.27.self_attn.k_proj.bias": "model-28-of-35.safetensors",
|
363 |
+
"model.layers.27.self_attn.v_proj.bias": "model-28-of-35.safetensors",
|
364 |
+
"model.layers.27.input_layernorm.weight": "model-28-of-35.safetensors",
|
365 |
+
"model.layers.27.post_attention_layernorm.weight": "model-28-of-35.safetensors",
|
366 |
+
"model.layers.27.mlp.gate_proj.weight": "model-28-of-35.safetensors",
|
367 |
+
"model.layers.27.mlp.down_proj.weight": "model-28-of-35.safetensors",
|
368 |
+
"model.layers.27.mlp.up_proj.weight": "model-28-of-35.safetensors",
|
369 |
+
"model.layers.27.self_attn.rotary_emb.inv_freq": "model-28-of-35.safetensors",
|
370 |
+
"model.mimo_layers.0.self_attn.q_proj.weight": "model-29-of-35.safetensors",
|
371 |
+
"model.mimo_layers.0.self_attn.k_proj.weight": "model-29-of-35.safetensors",
|
372 |
+
"model.mimo_layers.0.self_attn.v_proj.weight": "model-29-of-35.safetensors",
|
373 |
+
"model.mimo_layers.0.self_attn.o_proj.weight": "model-29-of-35.safetensors",
|
374 |
+
"model.mimo_layers.0.input_layernorm.weight": "model-29-of-35.safetensors",
|
375 |
+
"model.mimo_layers.0.post_attention_layernorm.weight": "model-29-of-35.safetensors",
|
376 |
+
"model.mimo_layers.0.mlp.gate_proj.weight": "model-29-of-35.safetensors",
|
377 |
+
"model.mimo_layers.0.mlp.down_proj.weight": "model-29-of-35.safetensors",
|
378 |
+
"model.mimo_layers.0.mlp.up_proj.weight": "model-29-of-35.safetensors",
|
379 |
+
"model.mimo_layers.0.self_attn.q_proj.bias": "model-29-of-35.safetensors",
|
380 |
+
"model.mimo_layers.0.self_attn.k_proj.bias": "model-29-of-35.safetensors",
|
381 |
+
"model.mimo_layers.0.self_attn.v_proj.bias": "model-29-of-35.safetensors",
|
382 |
+
"model.mimo_layers.0.self_attn.rotary_emb.inv_freq": "model-29-of-35.safetensors",
|
383 |
+
"model.mimo_layers.1.self_attn.q_proj.weight": "model-30-of-35.safetensors",
|
384 |
+
"model.mimo_layers.1.self_attn.k_proj.weight": "model-30-of-35.safetensors",
|
385 |
+
"model.mimo_layers.1.self_attn.v_proj.weight": "model-30-of-35.safetensors",
|
386 |
+
"model.mimo_layers.1.self_attn.o_proj.weight": "model-30-of-35.safetensors",
|
387 |
+
"model.mimo_layers.1.input_layernorm.weight": "model-30-of-35.safetensors",
|
388 |
+
"model.mimo_layers.1.post_attention_layernorm.weight": "model-30-of-35.safetensors",
|
389 |
+
"model.mimo_layers.1.mlp.gate_proj.weight": "model-30-of-35.safetensors",
|
390 |
+
"model.mimo_layers.1.mlp.down_proj.weight": "model-30-of-35.safetensors",
|
391 |
+
"model.mimo_layers.1.mlp.up_proj.weight": "model-30-of-35.safetensors",
|
392 |
+
"model.mimo_layers.1.self_attn.q_proj.bias": "model-30-of-35.safetensors",
|
393 |
+
"model.mimo_layers.1.self_attn.k_proj.bias": "model-30-of-35.safetensors",
|
394 |
+
"model.mimo_layers.1.self_attn.v_proj.bias": "model-30-of-35.safetensors",
|
395 |
+
"model.mimo_layers.1.self_attn.rotary_emb.inv_freq": "model-30-of-35.safetensors",
|
396 |
+
"model.mimo_layers.2.self_attn.q_proj.weight": "model-31-of-35.safetensors",
|
397 |
+
"model.mimo_layers.2.self_attn.k_proj.weight": "model-31-of-35.safetensors",
|
398 |
+
"model.mimo_layers.2.self_attn.v_proj.weight": "model-31-of-35.safetensors",
|
399 |
+
"model.mimo_layers.2.self_attn.o_proj.weight": "model-31-of-35.safetensors",
|
400 |
+
"model.mimo_layers.2.input_layernorm.weight": "model-31-of-35.safetensors",
|
401 |
+
"model.mimo_layers.2.post_attention_layernorm.weight": "model-31-of-35.safetensors",
|
402 |
+
"model.mimo_layers.2.mlp.gate_proj.weight": "model-31-of-35.safetensors",
|
403 |
+
"model.mimo_layers.2.mlp.down_proj.weight": "model-31-of-35.safetensors",
|
404 |
+
"model.mimo_layers.2.mlp.up_proj.weight": "model-31-of-35.safetensors",
|
405 |
+
"model.mimo_layers.2.self_attn.q_proj.bias": "model-31-of-35.safetensors",
|
406 |
+
"model.mimo_layers.2.self_attn.k_proj.bias": "model-31-of-35.safetensors",
|
407 |
+
"model.mimo_layers.2.self_attn.v_proj.bias": "model-31-of-35.safetensors",
|
408 |
+
"model.mimo_layers.2.self_attn.rotary_emb.inv_freq": "model-31-of-35.safetensors",
|
409 |
+
"model.mimo_layers.3.self_attn.q_proj.weight": "model-32-of-35.safetensors",
|
410 |
+
"model.mimo_layers.3.self_attn.k_proj.weight": "model-32-of-35.safetensors",
|
411 |
+
"model.mimo_layers.3.self_attn.v_proj.weight": "model-32-of-35.safetensors",
|
412 |
+
"model.mimo_layers.3.self_attn.o_proj.weight": "model-32-of-35.safetensors",
|
413 |
+
"model.mimo_layers.3.input_layernorm.weight": "model-32-of-35.safetensors",
|
414 |
+
"model.mimo_layers.3.post_attention_layernorm.weight": "model-32-of-35.safetensors",
|
415 |
+
"model.mimo_layers.3.mlp.gate_proj.weight": "model-32-of-35.safetensors",
|
416 |
+
"model.mimo_layers.3.mlp.down_proj.weight": "model-32-of-35.safetensors",
|
417 |
+
"model.mimo_layers.3.mlp.up_proj.weight": "model-32-of-35.safetensors",
|
418 |
+
"model.mimo_layers.3.self_attn.q_proj.bias": "model-32-of-35.safetensors",
|
419 |
+
"model.mimo_layers.3.self_attn.k_proj.bias": "model-32-of-35.safetensors",
|
420 |
+
"model.mimo_layers.3.self_attn.v_proj.bias": "model-32-of-35.safetensors",
|
421 |
+
"model.mimo_layers.3.self_attn.rotary_emb.inv_freq": "model-32-of-35.safetensors",
|
422 |
+
"model.mimo_layers.4.self_attn.q_proj.weight": "model-33-of-35.safetensors",
|
423 |
+
"model.mimo_layers.4.self_attn.k_proj.weight": "model-33-of-35.safetensors",
|
424 |
+
"model.mimo_layers.4.self_attn.v_proj.weight": "model-33-of-35.safetensors",
|
425 |
+
"model.mimo_layers.4.self_attn.o_proj.weight": "model-33-of-35.safetensors",
|
426 |
+
"model.mimo_layers.4.input_layernorm.weight": "model-33-of-35.safetensors",
|
427 |
+
"model.mimo_layers.4.post_attention_layernorm.weight": "model-33-of-35.safetensors",
|
428 |
+
"model.mimo_layers.4.mlp.gate_proj.weight": "model-33-of-35.safetensors",
|
429 |
+
"model.mimo_layers.4.mlp.down_proj.weight": "model-33-of-35.safetensors",
|
430 |
+
"model.mimo_layers.4.mlp.up_proj.weight": "model-33-of-35.safetensors",
|
431 |
+
"model.mimo_layers.4.self_attn.q_proj.bias": "model-33-of-35.safetensors",
|
432 |
+
"model.mimo_layers.4.self_attn.k_proj.bias": "model-33-of-35.safetensors",
|
433 |
+
"model.mimo_layers.4.self_attn.v_proj.bias": "model-33-of-35.safetensors",
|
434 |
+
"model.mimo_layers.4.self_attn.rotary_emb.inv_freq": "model-33-of-35.safetensors",
|
435 |
+
"model.mimo_layers.5.self_attn.q_proj.weight": "model-34-of-35.safetensors",
|
436 |
+
"model.mimo_layers.5.self_attn.k_proj.weight": "model-34-of-35.safetensors",
|
437 |
+
"model.mimo_layers.5.self_attn.v_proj.weight": "model-34-of-35.safetensors",
|
438 |
+
"model.mimo_layers.5.self_attn.o_proj.weight": "model-34-of-35.safetensors",
|
439 |
+
"model.mimo_layers.5.input_layernorm.weight": "model-34-of-35.safetensors",
|
440 |
+
"model.mimo_layers.5.post_attention_layernorm.weight": "model-34-of-35.safetensors",
|
441 |
+
"model.mimo_layers.5.mlp.gate_proj.weight": "model-34-of-35.safetensors",
|
442 |
+
"model.mimo_layers.5.mlp.down_proj.weight": "model-34-of-35.safetensors",
|
443 |
+
"model.mimo_layers.5.mlp.up_proj.weight": "model-34-of-35.safetensors",
|
444 |
+
"model.mimo_layers.5.self_attn.q_proj.bias": "model-34-of-35.safetensors",
|
445 |
+
"model.mimo_layers.5.self_attn.k_proj.bias": "model-34-of-35.safetensors",
|
446 |
+
"model.mimo_layers.5.self_attn.v_proj.bias": "model-34-of-35.safetensors",
|
447 |
+
"model.mimo_layers.5.self_attn.rotary_emb.inv_freq": "model-34-of-35.safetensors",
|
448 |
+
"model.vq_adaptor.layers.0.weight": "model-35-of-35.safetensors",
|
449 |
+
"model.vq_adaptor.layers.0.bias": "model-35-of-35.safetensors",
|
450 |
+
"model.vq_adaptor.layers.3.weight": "model-35-of-35.safetensors",
|
451 |
+
"model.vq_adaptor.layers.3.bias": "model-35-of-35.safetensors",
|
452 |
+
"model.vq_adaptor.layers.4.weight": "model-35-of-35.safetensors",
|
453 |
+
"model.vq_adaptor.layers.4.bias": "model-35-of-35.safetensors",
|
454 |
+
"model.embed_tokens.weight": "model-36-of-36.safetensors",
|
455 |
+
"model.norm.weight": "model-36-of-36.safetensors",
|
456 |
+
"lm_head.weight": "model-36-of-36.safetensors",
|
457 |
+
"mimo_output.weight": "model-36-of-36.safetensors",
|
458 |
+
"model.mimo_norm.weight": "model-36-of-36.safetensors"
|
459 |
+
}
|
460 |
+
}
|
modeling_moonshot_kimia.py
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2025 The Moonshot AI Team, Qwen Team, and HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# The code is based on Qwen2.5-7B, but modified for KimiAudio.
|
5 |
+
#
|
6 |
+
# Licensing Information:
|
7 |
+
# - Code derived from Qwen2.5-7B is licensed under the Apache License, Version 2.0.
|
8 |
+
# - Other parts of the code are licensed under the MIT License.
|
9 |
+
#
|
10 |
+
# Apache License, Version 2.0:
|
11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
12 |
+
# you may not use this file except in compliance with the License.
|
13 |
+
# You may obtain a copy of the License at
|
14 |
+
#
|
15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
16 |
+
#
|
17 |
+
# Unless required by applicable law or agreed to in writing, software
|
18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
20 |
+
# See the License for the specific language governing permissions and
|
21 |
+
# limitations under the License.
|
22 |
+
#
|
23 |
+
# MIT License:
|
24 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
25 |
+
# of this software and associated documentation files (the "Software"), to deal
|
26 |
+
# in the Software without restriction, including without limitation the rights
|
27 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
28 |
+
# copies of the Software, and to permit persons to whom the Software is
|
29 |
+
# furnished to do so, subject to the following conditions:
|
30 |
+
#
|
31 |
+
# The above copyright notice and this permission notice shall be included in all
|
32 |
+
# copies or substantial portions of the Software.
|
33 |
+
#
|
34 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
35 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
36 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
37 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
38 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
39 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
40 |
+
# SOFTWARE.
|
41 |
+
"""PyTorch KimiAudio model."""
|
42 |
+
|
43 |
+
from typing import List, Optional, Tuple, Union
|
44 |
+
import torch
|
45 |
+
import torch.utils.checkpoint
|
46 |
+
from torch import nn
|
47 |
+
|
48 |
+
import transformers
|
49 |
+
from packaging import version
|
50 |
+
|
51 |
+
assert version.parse(transformers.__version__) >= version.parse("4.34.1")
|
52 |
+
|
53 |
+
from transformers.modeling_outputs import (
|
54 |
+
BaseModelOutputWithPast,
|
55 |
+
CausalLMOutputWithPast,
|
56 |
+
)
|
57 |
+
from transformers.utils import (
|
58 |
+
logging,
|
59 |
+
)
|
60 |
+
from .configuration_moonshot_kimia import KimiAudioConfig
|
61 |
+
import torch.nn.functional as F
|
62 |
+
from transformers.models.qwen2.modeling_qwen2 import (
|
63 |
+
Qwen2RMSNorm,
|
64 |
+
Qwen2MLP,
|
65 |
+
Qwen2PreTrainedModel,
|
66 |
+
)
|
67 |
+
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
|
68 |
+
|
69 |
+
if version.parse(transformers.__version__) >= version.parse("4.35.0"):
|
70 |
+
from transformers.utils import is_flash_attn_2_available as is_flash_attn_available
|
71 |
+
else:
|
72 |
+
from transformers.utils import is_flash_attn_available
|
73 |
+
|
74 |
+
if is_flash_attn_available():
|
75 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
76 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
77 |
+
else:
|
78 |
+
raise RuntimeError("flash attention must be installed")
|
79 |
+
|
80 |
+
|
81 |
+
logger = logging.get_logger(__name__)
|
82 |
+
|
83 |
+
|
84 |
+
def _get_unpad_data(padding_mask):
|
85 |
+
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
|
86 |
+
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
|
87 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
88 |
+
cu_seqlens = F.pad(
|
89 |
+
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
90 |
+
)
|
91 |
+
return (
|
92 |
+
indices,
|
93 |
+
cu_seqlens,
|
94 |
+
max_seqlen_in_batch,
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def _upad_input(query_layer, key_layer, value_layer, padding_mask, query_length):
|
99 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
|
100 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
101 |
+
num_heads = query_layer.shape[2]
|
102 |
+
|
103 |
+
key_layer = index_first_axis(
|
104 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
105 |
+
indices_k,
|
106 |
+
)
|
107 |
+
value_layer = index_first_axis(
|
108 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
109 |
+
indices_k,
|
110 |
+
)
|
111 |
+
if query_length == kv_seq_len:
|
112 |
+
query_layer = index_first_axis(
|
113 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
114 |
+
)
|
115 |
+
cu_seqlens_q = cu_seqlens_k
|
116 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
117 |
+
indices_q = indices_k
|
118 |
+
elif query_length == 1:
|
119 |
+
max_seqlen_in_batch_q = 1
|
120 |
+
cu_seqlens_q = torch.arange(
|
121 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
122 |
+
) # There is a memcpy here, that is very bad.
|
123 |
+
indices_q = cu_seqlens_q[:-1]
|
124 |
+
query_layer = query_layer.squeeze(1)
|
125 |
+
else:
|
126 |
+
# The -q_len: slice assumes left padding.
|
127 |
+
padding_mask = padding_mask[:, -query_length:]
|
128 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
129 |
+
query_layer, padding_mask
|
130 |
+
)
|
131 |
+
|
132 |
+
return (
|
133 |
+
query_layer,
|
134 |
+
key_layer,
|
135 |
+
value_layer,
|
136 |
+
indices_q,
|
137 |
+
(cu_seqlens_q, cu_seqlens_k),
|
138 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
139 |
+
)
|
140 |
+
|
141 |
+
|
142 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
143 |
+
def _make_causal_mask(
|
144 |
+
input_ids_shape: torch.Size,
|
145 |
+
dtype: torch.dtype,
|
146 |
+
device: torch.device,
|
147 |
+
past_key_values_length: int = 0,
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
Make causal mask used for bi-directional self-attention.
|
151 |
+
"""
|
152 |
+
bsz, tgt_len = input_ids_shape
|
153 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
154 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
155 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
156 |
+
mask = mask.to(dtype)
|
157 |
+
|
158 |
+
if past_key_values_length > 0:
|
159 |
+
mask = torch.cat(
|
160 |
+
[
|
161 |
+
torch.zeros(
|
162 |
+
tgt_len, past_key_values_length, dtype=dtype, device=device
|
163 |
+
),
|
164 |
+
mask,
|
165 |
+
],
|
166 |
+
dim=-1,
|
167 |
+
)
|
168 |
+
return mask[None, None, :, :].expand(
|
169 |
+
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
170 |
+
)
|
171 |
+
|
172 |
+
|
173 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
174 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
175 |
+
"""
|
176 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
177 |
+
"""
|
178 |
+
bsz, src_len = mask.size()
|
179 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
180 |
+
|
181 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
182 |
+
|
183 |
+
inverted_mask = 1.0 - expanded_mask
|
184 |
+
|
185 |
+
return inverted_mask.masked_fill(
|
186 |
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
187 |
+
)
|
188 |
+
|
189 |
+
|
190 |
+
class RotaryEmbedding(nn.Module):
|
191 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
192 |
+
super().__init__()
|
193 |
+
|
194 |
+
self.dim = dim
|
195 |
+
self.max_position_embeddings = max_position_embeddings
|
196 |
+
self.base = base
|
197 |
+
inv_freq = 1.0 / (
|
198 |
+
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
199 |
+
)
|
200 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
201 |
+
|
202 |
+
# Build here to make `torch.jit.trace` work.
|
203 |
+
self._set_cos_sin_cache(
|
204 |
+
seq_len=max_position_embeddings,
|
205 |
+
device=self.inv_freq.device,
|
206 |
+
dtype=torch.get_default_dtype(),
|
207 |
+
)
|
208 |
+
|
209 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
210 |
+
self.max_seq_len_cached = seq_len
|
211 |
+
t = torch.arange(
|
212 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
213 |
+
)
|
214 |
+
|
215 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
216 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
217 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
218 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
219 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
220 |
+
|
221 |
+
def forward(self, x, seq_len=None):
|
222 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
223 |
+
if seq_len > self.max_seq_len_cached:
|
224 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
225 |
+
|
226 |
+
return (
|
227 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
228 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
229 |
+
)
|
230 |
+
|
231 |
+
|
232 |
+
class MoonshotAttention(nn.Module):
|
233 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
234 |
+
|
235 |
+
def __init__(self, config: KimiAudioConfig):
|
236 |
+
super().__init__()
|
237 |
+
self.config = config
|
238 |
+
self.hidden_size = config.hidden_size
|
239 |
+
self.num_heads = config.num_attention_heads
|
240 |
+
self.head_dim = self.hidden_size // self.num_heads
|
241 |
+
self.num_key_value_heads = config.num_key_value_heads
|
242 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
243 |
+
self.max_position_embeddings = config.max_position_embeddings
|
244 |
+
self.rope_theta = config.rope_theta
|
245 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
246 |
+
raise ValueError(
|
247 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
248 |
+
f" and `num_heads`: {self.num_heads})."
|
249 |
+
)
|
250 |
+
self.q_proj = nn.Linear(
|
251 |
+
self.hidden_size, self.num_heads * self.head_dim, bias=True
|
252 |
+
)
|
253 |
+
self.k_proj = nn.Linear(
|
254 |
+
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
255 |
+
)
|
256 |
+
self.v_proj = nn.Linear(
|
257 |
+
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
258 |
+
)
|
259 |
+
self.o_proj = nn.Linear(
|
260 |
+
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
261 |
+
)
|
262 |
+
|
263 |
+
self._init_rope()
|
264 |
+
|
265 |
+
def _init_rope(self):
|
266 |
+
|
267 |
+
self.rotary_emb = RotaryEmbedding(
|
268 |
+
self.head_dim,
|
269 |
+
max_position_embeddings=self.max_position_embeddings,
|
270 |
+
base=self.rope_theta,
|
271 |
+
)
|
272 |
+
|
273 |
+
def forward(
|
274 |
+
self,
|
275 |
+
hidden_states: torch.Tensor,
|
276 |
+
attention_mask: Optional[torch.Tensor] = None,
|
277 |
+
position_ids: Optional[torch.LongTensor] = None,
|
278 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
279 |
+
output_attentions: bool = False,
|
280 |
+
use_cache: bool = False,
|
281 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
282 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
283 |
+
# LlamaFlashAttention2 attention does not support output_attentions
|
284 |
+
|
285 |
+
output_attentions = False
|
286 |
+
|
287 |
+
bsz, q_len, _ = hidden_states.size()
|
288 |
+
|
289 |
+
query_states = self.q_proj(hidden_states)
|
290 |
+
key_states = self.k_proj(hidden_states)
|
291 |
+
value_states = self.v_proj(hidden_states)
|
292 |
+
|
293 |
+
# Flash attention requires the input to have the shape
|
294 |
+
# batch_size x seq_length x head_dime x hidden_dim
|
295 |
+
# therefore we just need to keep the original shape
|
296 |
+
query_states = query_states.view(
|
297 |
+
bsz, q_len, self.num_heads, self.head_dim
|
298 |
+
).transpose(1, 2)
|
299 |
+
key_states = key_states.view(
|
300 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
301 |
+
).transpose(1, 2)
|
302 |
+
value_states = value_states.view(
|
303 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
304 |
+
).transpose(1, 2)
|
305 |
+
|
306 |
+
kv_seq_len = key_states.shape[-2]
|
307 |
+
if past_key_value is not None:
|
308 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
309 |
+
|
310 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
311 |
+
cos = cos[position_ids]
|
312 |
+
sin = sin[position_ids]
|
313 |
+
query_states, key_states = apply_rotary_pos_emb(
|
314 |
+
query_states, key_states, cos, sin, position_ids
|
315 |
+
)
|
316 |
+
|
317 |
+
if past_key_value is not None:
|
318 |
+
# reuse k, v, self_attention
|
319 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
320 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
321 |
+
|
322 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
323 |
+
|
324 |
+
query_states = query_states.transpose(1, 2)
|
325 |
+
key_states = key_states.transpose(1, 2)
|
326 |
+
value_states = value_states.transpose(1, 2)
|
327 |
+
|
328 |
+
# TODO: llama does not have dropout in the config??
|
329 |
+
# It is recommended to use dropout with FA according to the docs
|
330 |
+
# when training.
|
331 |
+
dropout_rate = 0.0 # if not self.training else self.attn_dropout
|
332 |
+
|
333 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
334 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
335 |
+
# cast them back in float16 just to be sure everything works as expected.
|
336 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
337 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
338 |
+
input_dtype = query_states.dtype
|
339 |
+
if input_dtype == torch.float32:
|
340 |
+
logger.warning_once(
|
341 |
+
"The input hidden states seems to be silently casted in float32, this might be related to"
|
342 |
+
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
343 |
+
" float16."
|
344 |
+
)
|
345 |
+
|
346 |
+
query_states = query_states.to(torch.float16)
|
347 |
+
key_states = key_states.to(torch.float16)
|
348 |
+
value_states = value_states.to(torch.float16)
|
349 |
+
|
350 |
+
attn_output = self._flash_attention_forward(
|
351 |
+
query_states,
|
352 |
+
key_states,
|
353 |
+
value_states,
|
354 |
+
padding_mask,
|
355 |
+
q_len,
|
356 |
+
dropout=dropout_rate,
|
357 |
+
)
|
358 |
+
|
359 |
+
if input_dtype == torch.float32:
|
360 |
+
attn_output = attn_output.to(torch.float32)
|
361 |
+
|
362 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
363 |
+
attn_output = self.o_proj(attn_output)
|
364 |
+
|
365 |
+
if not output_attentions:
|
366 |
+
attn_weights = None
|
367 |
+
|
368 |
+
return attn_output, attn_weights, past_key_value
|
369 |
+
|
370 |
+
def _flash_attention_forward(
|
371 |
+
self,
|
372 |
+
query_states,
|
373 |
+
key_states,
|
374 |
+
value_states,
|
375 |
+
padding_mask,
|
376 |
+
query_length,
|
377 |
+
dropout=0.0,
|
378 |
+
softmax_scale=None,
|
379 |
+
):
|
380 |
+
"""
|
381 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
382 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
query_states (`torch.Tensor`):
|
386 |
+
Input query states to be passed to Flash Attention API
|
387 |
+
key_states (`torch.Tensor`):
|
388 |
+
Input key states to be passed to Flash Attention API
|
389 |
+
value_states (`torch.Tensor`):
|
390 |
+
Input value states to be passed to Flash Attention API
|
391 |
+
padding_mask (`torch.Tensor`):
|
392 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
393 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
394 |
+
dropout (`int`, *optional*):
|
395 |
+
Attention dropout
|
396 |
+
softmax_scale (`float`, *optional*):
|
397 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
398 |
+
"""
|
399 |
+
# Contains at least one padding token in the sequence
|
400 |
+
if padding_mask is not None:
|
401 |
+
batch_size = query_states.shape[0]
|
402 |
+
(
|
403 |
+
query_states,
|
404 |
+
key_states,
|
405 |
+
value_states,
|
406 |
+
indices_q,
|
407 |
+
cu_seq_lens,
|
408 |
+
max_seq_lens,
|
409 |
+
) = _upad_input(
|
410 |
+
query_states, key_states, value_states, padding_mask, query_length
|
411 |
+
)
|
412 |
+
|
413 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
414 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
415 |
+
|
416 |
+
attn_output_unpad = flash_attn_varlen_func(
|
417 |
+
query_states,
|
418 |
+
key_states,
|
419 |
+
value_states,
|
420 |
+
cu_seqlens_q=cu_seqlens_q,
|
421 |
+
cu_seqlens_k=cu_seqlens_k,
|
422 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
423 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
424 |
+
dropout_p=dropout,
|
425 |
+
softmax_scale=softmax_scale,
|
426 |
+
causal=True,
|
427 |
+
)
|
428 |
+
|
429 |
+
attn_output = pad_input(
|
430 |
+
attn_output_unpad, indices_q, batch_size, query_length
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
attn_output = flash_attn_func(
|
434 |
+
query_states,
|
435 |
+
key_states,
|
436 |
+
value_states,
|
437 |
+
dropout,
|
438 |
+
softmax_scale=softmax_scale,
|
439 |
+
causal=True,
|
440 |
+
)
|
441 |
+
|
442 |
+
return attn_output
|
443 |
+
|
444 |
+
|
445 |
+
class MoonshotDecoderLayer(nn.Module):
|
446 |
+
def __init__(self, config: KimiAudioConfig):
|
447 |
+
super().__init__()
|
448 |
+
self.hidden_size = config.hidden_size
|
449 |
+
self.config = config
|
450 |
+
|
451 |
+
logger.warning_once("using normal flash attention")
|
452 |
+
self.self_attn = MoonshotAttention(config=config)
|
453 |
+
|
454 |
+
self.mlp = Qwen2MLP(config)
|
455 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
456 |
+
self.post_attention_layernorm = Qwen2RMSNorm(
|
457 |
+
config.hidden_size, eps=config.rms_norm_eps
|
458 |
+
)
|
459 |
+
|
460 |
+
def forward(
|
461 |
+
self,
|
462 |
+
hidden_states: torch.Tensor,
|
463 |
+
attention_mask: Optional[torch.Tensor] = None,
|
464 |
+
position_ids: Optional[torch.LongTensor] = None,
|
465 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
466 |
+
output_attentions: Optional[bool] = False,
|
467 |
+
use_cache: Optional[bool] = False,
|
468 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
469 |
+
) -> Tuple[
|
470 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
471 |
+
]:
|
472 |
+
"""
|
473 |
+
Args:
|
474 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
475 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
476 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
477 |
+
output_attentions (`bool`, *optional*):
|
478 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
479 |
+
returned tensors for more detail.
|
480 |
+
use_cache (`bool`, *optional*):
|
481 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
482 |
+
(see `past_key_values`).
|
483 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
484 |
+
"""
|
485 |
+
|
486 |
+
residual = hidden_states
|
487 |
+
|
488 |
+
hidden_states = self.input_layernorm(hidden_states)
|
489 |
+
|
490 |
+
# Self Attention
|
491 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
492 |
+
hidden_states=hidden_states,
|
493 |
+
attention_mask=attention_mask,
|
494 |
+
position_ids=position_ids,
|
495 |
+
past_key_value=past_key_value,
|
496 |
+
output_attentions=output_attentions,
|
497 |
+
use_cache=use_cache,
|
498 |
+
padding_mask=padding_mask,
|
499 |
+
)
|
500 |
+
hidden_states = residual + hidden_states
|
501 |
+
|
502 |
+
# Fully Connected
|
503 |
+
residual = hidden_states
|
504 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
505 |
+
hidden_states = self.mlp(hidden_states)
|
506 |
+
hidden_states = residual + hidden_states
|
507 |
+
|
508 |
+
outputs = (hidden_states,)
|
509 |
+
|
510 |
+
if output_attentions:
|
511 |
+
outputs += (self_attn_weights,)
|
512 |
+
|
513 |
+
if use_cache:
|
514 |
+
outputs += (present_key_value,)
|
515 |
+
|
516 |
+
return outputs
|
517 |
+
|
518 |
+
|
519 |
+
class VQAdaptor(nn.Module):
|
520 |
+
def __init__(self, config):
|
521 |
+
super().__init__()
|
522 |
+
self.layers = nn.Sequential(
|
523 |
+
nn.Linear(config.kimia_adaptor_input_dim, config.hidden_size, bias=True),
|
524 |
+
nn.SiLU(),
|
525 |
+
nn.Dropout(0.0),
|
526 |
+
nn.Linear(config.hidden_size, config.hidden_size, bias=True),
|
527 |
+
nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=True),
|
528 |
+
)
|
529 |
+
|
530 |
+
def forward(self, x):
|
531 |
+
return self.layers(x)
|
532 |
+
|
533 |
+
|
534 |
+
class MoonshotKimiaModel(Qwen2PreTrainedModel):
|
535 |
+
"""
|
536 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`QwenDecoderLayer`]
|
537 |
+
|
538 |
+
Args:
|
539 |
+
config: KimiAudioConfig
|
540 |
+
"""
|
541 |
+
|
542 |
+
config_class = KimiAudioConfig
|
543 |
+
|
544 |
+
def __init__(self, config: KimiAudioConfig):
|
545 |
+
super().__init__(config)
|
546 |
+
self.padding_idx = config.pad_token_id
|
547 |
+
self.vocab_size = config.vocab_size
|
548 |
+
self.kimia_mimo_transformer_from_layer_index = (
|
549 |
+
config.kimia_mimo_transformer_from_layer_index
|
550 |
+
)
|
551 |
+
|
552 |
+
self.embed_tokens = nn.Embedding(
|
553 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
554 |
+
)
|
555 |
+
self.layers = nn.ModuleList(
|
556 |
+
[MoonshotDecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
557 |
+
)
|
558 |
+
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
559 |
+
|
560 |
+
# extra 1B audio transformers
|
561 |
+
self.mimo_layers = nn.ModuleList(
|
562 |
+
[MoonshotDecoderLayer(config) for _ in range(config.kimia_mimo_layers)]
|
563 |
+
)
|
564 |
+
self.mimo_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
565 |
+
self.use_whisper_feature = config.use_whisper_feature
|
566 |
+
if self.use_whisper_feature:
|
567 |
+
self.vq_adaptor = VQAdaptor(config)
|
568 |
+
self.kimia_media_begin = config.kimia_media_begin
|
569 |
+
self.kimia_media_end = config.kimia_media_end
|
570 |
+
|
571 |
+
self.gradient_checkpointing = False
|
572 |
+
# Initialize weights and apply final processing
|
573 |
+
self.post_init()
|
574 |
+
|
575 |
+
def get_input_embeddings(self):
|
576 |
+
return self.embed_tokens
|
577 |
+
|
578 |
+
def set_input_embeddings(self, value):
|
579 |
+
self.embed_tokens = value
|
580 |
+
|
581 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
582 |
+
def _prepare_decoder_attention_mask(
|
583 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
584 |
+
):
|
585 |
+
# create causal mask
|
586 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
587 |
+
combined_attention_mask = None
|
588 |
+
if input_shape[-1] > 1:
|
589 |
+
combined_attention_mask = _make_causal_mask(
|
590 |
+
input_shape,
|
591 |
+
inputs_embeds.dtype,
|
592 |
+
device=inputs_embeds.device,
|
593 |
+
past_key_values_length=past_key_values_length,
|
594 |
+
)
|
595 |
+
|
596 |
+
if attention_mask is not None:
|
597 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
598 |
+
expanded_attn_mask = _expand_mask(
|
599 |
+
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
600 |
+
).to(inputs_embeds.device)
|
601 |
+
combined_attention_mask = (
|
602 |
+
expanded_attn_mask
|
603 |
+
if combined_attention_mask is None
|
604 |
+
else expanded_attn_mask + combined_attention_mask
|
605 |
+
)
|
606 |
+
|
607 |
+
return combined_attention_mask
|
608 |
+
|
609 |
+
def forward(
|
610 |
+
self,
|
611 |
+
input_ids: torch.LongTensor = None,
|
612 |
+
text_input_ids: torch.LongTensor = None,
|
613 |
+
whisper_input_feature: Optional[torch.FloatTensor] = None,
|
614 |
+
is_continuous_mask: Optional[torch.Tensor] = None,
|
615 |
+
attention_mask: Optional[torch.Tensor] = None,
|
616 |
+
position_ids: Optional[torch.LongTensor] = None,
|
617 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
618 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
619 |
+
use_cache: Optional[bool] = None,
|
620 |
+
output_attentions: Optional[bool] = None,
|
621 |
+
output_hidden_states: Optional[bool] = None,
|
622 |
+
return_dict: Optional[bool] = None,
|
623 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
624 |
+
output_attentions = (
|
625 |
+
output_attentions
|
626 |
+
if output_attentions is not None
|
627 |
+
else self.config.output_attentions
|
628 |
+
)
|
629 |
+
output_hidden_states = (
|
630 |
+
output_hidden_states
|
631 |
+
if output_hidden_states is not None
|
632 |
+
else self.config.output_hidden_states
|
633 |
+
)
|
634 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
635 |
+
|
636 |
+
return_dict = (
|
637 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
638 |
+
)
|
639 |
+
|
640 |
+
# retrieve input_ids and inputs_embeds
|
641 |
+
if input_ids is not None and inputs_embeds is not None:
|
642 |
+
raise ValueError(
|
643 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
644 |
+
)
|
645 |
+
elif input_ids is not None:
|
646 |
+
batch_size, seq_length = input_ids.shape
|
647 |
+
elif inputs_embeds is not None:
|
648 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
649 |
+
else:
|
650 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
651 |
+
|
652 |
+
seq_length_with_past = seq_length
|
653 |
+
past_key_values_length = 0
|
654 |
+
|
655 |
+
if past_key_values is not None:
|
656 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
657 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
658 |
+
if position_ids is None:
|
659 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
660 |
+
position_ids = torch.arange(
|
661 |
+
past_key_values_length,
|
662 |
+
seq_length + past_key_values_length,
|
663 |
+
dtype=torch.long,
|
664 |
+
device=device,
|
665 |
+
)
|
666 |
+
position_ids = position_ids.unsqueeze(0)
|
667 |
+
|
668 |
+
if inputs_embeds is None:
|
669 |
+
# shape: batch, seq_len, hidden_size
|
670 |
+
input_ids = input_ids.to(torch.cuda.current_device())
|
671 |
+
text_input_ids = text_input_ids.to(torch.cuda.current_device())
|
672 |
+
audio_emb = self.embed_tokens(input_ids)
|
673 |
+
if self.use_whisper_feature and whisper_input_feature is not None:
|
674 |
+
if not isinstance(whisper_input_feature, list):
|
675 |
+
whisper_input_feature = whisper_input_feature.squeeze(0)
|
676 |
+
whisper_input_feature = [whisper_input_feature]
|
677 |
+
|
678 |
+
media_start_idx = (input_ids == self.kimia_media_begin).nonzero()
|
679 |
+
media_end_idx = (input_ids == self.kimia_media_end).nonzero()
|
680 |
+
# shape: batch, seq_len, hidden_size
|
681 |
+
whisper_input_dim = whisper_input_feature[0].shape[-1]
|
682 |
+
whisper_dtype = whisper_input_feature[0].dtype
|
683 |
+
expanded_whisper = (
|
684 |
+
torch.zeros(audio_emb.shape[1], whisper_input_dim)
|
685 |
+
.to(torch.cuda.current_device())
|
686 |
+
.to(whisper_dtype)
|
687 |
+
)
|
688 |
+
for (seg_idx, start_idx), (_, end_idx) in zip(
|
689 |
+
media_start_idx, media_end_idx
|
690 |
+
):
|
691 |
+
# assert whisper_emb.shape[1] == end_idx - (start_idx + 1)
|
692 |
+
|
693 |
+
feat_len = end_idx - (start_idx + 1)
|
694 |
+
whisper_input_feature_i = whisper_input_feature[seg_idx].squeeze(0)
|
695 |
+
assert feat_len == is_continuous_mask[seg_idx].sum()
|
696 |
+
expanded_whisper[start_idx + 1 : end_idx, :] = (
|
697 |
+
whisper_input_feature_i[:feat_len, :]
|
698 |
+
)
|
699 |
+
|
700 |
+
expanded_whisper = expanded_whisper.unsqueeze(0)
|
701 |
+
whisper_emb = self.vq_adaptor(
|
702 |
+
expanded_whisper.transpose(0, 1)
|
703 |
+
).transpose(0, 1)
|
704 |
+
is_continuous_mask = is_continuous_mask.to(torch.cuda.current_device())
|
705 |
+
whisper_emb = whisper_emb.to(torch.cuda.current_device())
|
706 |
+
whisper_emb = whisper_emb * is_continuous_mask[:, :, None]
|
707 |
+
|
708 |
+
encoder_input_addwith_discrete_token = (
|
709 |
+
audio_emb + whisper_emb
|
710 |
+
) * torch.sqrt(
|
711 |
+
torch.tensor(
|
712 |
+
2.0, dtype=whisper_emb.dtype, device=torch.cuda.current_device()
|
713 |
+
)
|
714 |
+
)
|
715 |
+
audio_emb = (
|
716 |
+
audio_emb * (~is_continuous_mask[:, :, None])
|
717 |
+
+ encoder_input_addwith_discrete_token
|
718 |
+
* is_continuous_mask[:, :, None]
|
719 |
+
)
|
720 |
+
if text_input_ids is not None and text_input_ids.sum() != 0:
|
721 |
+
inputs_embeds = audio_emb + self.embed_tokens(text_input_ids)
|
722 |
+
else:
|
723 |
+
inputs_embeds = audio_emb
|
724 |
+
# embed positions
|
725 |
+
# TODO kill attention_mask for prefill
|
726 |
+
padding_mask = attention_mask
|
727 |
+
|
728 |
+
hidden_states = inputs_embeds
|
729 |
+
|
730 |
+
# decoder layers
|
731 |
+
all_hidden_states = () if output_hidden_states else None
|
732 |
+
all_self_attns = () if output_attentions else None
|
733 |
+
next_decoder_cache = () if use_cache else None
|
734 |
+
for idx, decoder_layer in enumerate(self.layers):
|
735 |
+
if output_hidden_states:
|
736 |
+
all_hidden_states += (hidden_states,)
|
737 |
+
|
738 |
+
past_key_value = (
|
739 |
+
past_key_values[idx] if past_key_values is not None else None
|
740 |
+
)
|
741 |
+
layer_outputs = decoder_layer(
|
742 |
+
hidden_states,
|
743 |
+
attention_mask=attention_mask,
|
744 |
+
position_ids=position_ids,
|
745 |
+
past_key_value=past_key_value,
|
746 |
+
output_attentions=output_attentions,
|
747 |
+
use_cache=use_cache,
|
748 |
+
padding_mask=padding_mask,
|
749 |
+
)
|
750 |
+
|
751 |
+
hidden_states = layer_outputs[0]
|
752 |
+
if idx == self.kimia_mimo_transformer_from_layer_index:
|
753 |
+
mimo_hidden_states = hidden_states.clone()
|
754 |
+
|
755 |
+
if use_cache:
|
756 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
757 |
+
|
758 |
+
if output_attentions:
|
759 |
+
all_self_attns += (layer_outputs[1],)
|
760 |
+
|
761 |
+
hidden_states = self.norm(hidden_states)
|
762 |
+
if output_hidden_states:
|
763 |
+
all_hidden_states += (hidden_states,)
|
764 |
+
|
765 |
+
# apply audio transformer layers
|
766 |
+
for idx, decoder_layer in enumerate(self.mimo_layers):
|
767 |
+
if output_hidden_states:
|
768 |
+
all_hidden_states += (mimo_hidden_states,)
|
769 |
+
|
770 |
+
past_key_value = (
|
771 |
+
past_key_values[idx + len(self.layers)]
|
772 |
+
if past_key_values is not None
|
773 |
+
else None
|
774 |
+
)
|
775 |
+
layer_outputs = decoder_layer(
|
776 |
+
mimo_hidden_states,
|
777 |
+
attention_mask=attention_mask,
|
778 |
+
position_ids=position_ids,
|
779 |
+
past_key_value=past_key_value,
|
780 |
+
output_attentions=output_attentions,
|
781 |
+
use_cache=use_cache,
|
782 |
+
padding_mask=padding_mask,
|
783 |
+
)
|
784 |
+
|
785 |
+
mimo_hidden_states = layer_outputs[0]
|
786 |
+
|
787 |
+
if use_cache:
|
788 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
789 |
+
|
790 |
+
mimo_hidden_states = self.mimo_norm(mimo_hidden_states)
|
791 |
+
|
792 |
+
# add hidden states from the last decoder layer
|
793 |
+
if output_hidden_states:
|
794 |
+
all_hidden_states += (mimo_hidden_states,)
|
795 |
+
|
796 |
+
next_cache = next_decoder_cache if use_cache else None
|
797 |
+
if not return_dict:
|
798 |
+
return tuple(
|
799 |
+
v
|
800 |
+
for v in [
|
801 |
+
hidden_states,
|
802 |
+
mimo_hidden_states,
|
803 |
+
next_cache,
|
804 |
+
all_hidden_states,
|
805 |
+
all_hidden_states,
|
806 |
+
all_self_attns,
|
807 |
+
]
|
808 |
+
if v is not None
|
809 |
+
)
|
810 |
+
return BaseModelOutputWithPast(
|
811 |
+
last_hidden_state=(hidden_states, mimo_hidden_states),
|
812 |
+
past_key_values=next_cache,
|
813 |
+
hidden_states=all_hidden_states,
|
814 |
+
attentions=all_self_attns,
|
815 |
+
)
|
816 |
+
|
817 |
+
|
818 |
+
class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
|
819 |
+
_tied_weights_keys = ["lm_head.weight", "mimo_output.weight"]
|
820 |
+
config_class = KimiAudioConfig
|
821 |
+
|
822 |
+
def __init__(self, config):
|
823 |
+
super().__init__(config)
|
824 |
+
self.model = MoonshotKimiaModel(config)
|
825 |
+
self.vocab_size = config.vocab_size
|
826 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
827 |
+
self.mimo_output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
828 |
+
|
829 |
+
# Initialize weights and apply final processing
|
830 |
+
self.post_init()
|
831 |
+
|
832 |
+
def get_input_embeddings(self):
|
833 |
+
return self.model.embed_tokens
|
834 |
+
|
835 |
+
def set_input_embeddings(self, value):
|
836 |
+
self.model.embed_tokens = value
|
837 |
+
|
838 |
+
def get_output_embeddings(self):
|
839 |
+
return self.lm_head
|
840 |
+
|
841 |
+
def set_output_embeddings(self, new_embeddings):
|
842 |
+
self.lm_head = new_embeddings
|
843 |
+
|
844 |
+
def set_decoder(self, decoder):
|
845 |
+
self.model = decoder
|
846 |
+
|
847 |
+
def get_decoder(self):
|
848 |
+
return self.model
|
849 |
+
|
850 |
+
def forward(
|
851 |
+
self,
|
852 |
+
input_ids: torch.LongTensor = None,
|
853 |
+
text_input_ids: torch.LongTensor = None,
|
854 |
+
whisper_input_feature: Optional[torch.FloatTensor] = None,
|
855 |
+
is_continuous_mask: Optional[torch.Tensor] = None,
|
856 |
+
attention_mask: Optional[torch.Tensor] = None,
|
857 |
+
position_ids: Optional[torch.LongTensor] = None,
|
858 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
859 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
860 |
+
labels: Optional[torch.LongTensor] = None,
|
861 |
+
use_cache: Optional[bool] = None,
|
862 |
+
output_attentions: Optional[bool] = None,
|
863 |
+
output_hidden_states: Optional[bool] = None,
|
864 |
+
generation_mode: Optional[bool] = None,
|
865 |
+
return_dict: Optional[bool] = None,
|
866 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
867 |
+
|
868 |
+
output_attentions = (
|
869 |
+
output_attentions
|
870 |
+
if output_attentions is not None
|
871 |
+
else self.config.output_attentions
|
872 |
+
)
|
873 |
+
output_hidden_states = (
|
874 |
+
output_hidden_states
|
875 |
+
if output_hidden_states is not None
|
876 |
+
else self.config.output_hidden_states
|
877 |
+
)
|
878 |
+
return_dict = (
|
879 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
880 |
+
)
|
881 |
+
|
882 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
883 |
+
outputs = self.model(
|
884 |
+
input_ids=input_ids,
|
885 |
+
text_input_ids=text_input_ids,
|
886 |
+
whisper_input_feature=whisper_input_feature,
|
887 |
+
is_continuous_mask=is_continuous_mask,
|
888 |
+
attention_mask=attention_mask,
|
889 |
+
position_ids=position_ids,
|
890 |
+
past_key_values=past_key_values,
|
891 |
+
inputs_embeds=inputs_embeds,
|
892 |
+
use_cache=use_cache,
|
893 |
+
output_attentions=output_attentions,
|
894 |
+
output_hidden_states=output_hidden_states,
|
895 |
+
return_dict=return_dict,
|
896 |
+
)
|
897 |
+
if return_dict:
|
898 |
+
hidden_states, mimo_hidden_states = (
|
899 |
+
outputs.last_hidden_state[0],
|
900 |
+
outputs.last_hidden_state[1],
|
901 |
+
)
|
902 |
+
else:
|
903 |
+
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
904 |
+
|
905 |
+
audio_logits = self.lm_head(hidden_states)
|
906 |
+
text_logits = self.mimo_output(mimo_hidden_states)
|
907 |
+
|
908 |
+
if not return_dict:
|
909 |
+
output = (text_logits, audio_logits) + outputs[2:]
|
910 |
+
return output
|
911 |
+
return CausalLMOutputWithPast(
|
912 |
+
loss=None,
|
913 |
+
logits=(text_logits, audio_logits),
|
914 |
+
past_key_values=outputs.past_key_values,
|
915 |
+
hidden_states=outputs.hidden_states,
|
916 |
+
attentions=outputs.attentions,
|
917 |
+
)
|
special_tokens_map.json
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_msg_end|>",
|
4 |
+
"<|im_user_msg_start|>",
|
5 |
+
"<|im_assistant_msg_start|>",
|
6 |
+
"<|reserved_token_0|>",
|
7 |
+
"<|reserved_token_1|>",
|
8 |
+
"<|reserved_token_2|>",
|
9 |
+
"<|reserved_token_3|>",
|
10 |
+
"[EOT]",
|
11 |
+
"<|reserved_token_4|>",
|
12 |
+
"<|reserved_token_5|>",
|
13 |
+
"<|reserved_token_6|>",
|
14 |
+
"<|reserved_token_7|>",
|
15 |
+
"<|reserved_token_8|>",
|
16 |
+
"<|reserved_token_9|>",
|
17 |
+
"<|reserved_token_10|>",
|
18 |
+
"<|reserved_token_11|>",
|
19 |
+
"<|im_media_begin|>",
|
20 |
+
"<|reserved_token_12|>",
|
21 |
+
"<|im_media_end|>",
|
22 |
+
"<|reserved_token_13|>",
|
23 |
+
"<|reserved_token_14|>",
|
24 |
+
"<|im_kimia_text_blank|>",
|
25 |
+
"<|im_kimia_text_eos|>",
|
26 |
+
"<|reserved_token_15|>",
|
27 |
+
"<|reserved_token_16|>",
|
28 |
+
"<|im_kimia_user_msg_start|>",
|
29 |
+
"<|im_kimia_assistant_msg_start|>",
|
30 |
+
"<|reserved_token_17|>",
|
31 |
+
"<|reserved_token_18|>",
|
32 |
+
"<|reserved_token_19|>",
|
33 |
+
"<|im_kimia_speech_ct_id|>",
|
34 |
+
"<|im_kimia_speech_ctd_id|>",
|
35 |
+
"<|reserved_token_20|>",
|
36 |
+
"<|reserved_token_21|>",
|
37 |
+
"<|reserved_token_22|>",
|
38 |
+
"<|reserved_token_23|>",
|
39 |
+
"<|reserved_token_24|>",
|
40 |
+
"<|reserved_token_25|>",
|
41 |
+
"<|reserved_token_26|>",
|
42 |
+
"<|reserved_token_27|>",
|
43 |
+
"<|reserved_token_28|>",
|
44 |
+
"<|reserved_token_29|>",
|
45 |
+
"<|reserved_token_30|>",
|
46 |
+
"<|reserved_token_31|>",
|
47 |
+
"<|reserved_token_32|>",
|
48 |
+
"<|reserved_token_33|>",
|
49 |
+
"<|reserved_token_34|>",
|
50 |
+
"<|reserved_token_35|>",
|
51 |
+
"<|reserved_token_36|>",
|
52 |
+
"<|reserved_token_37|>",
|
53 |
+
"<|reserved_token_38|>",
|
54 |
+
"<|reserved_token_39|>",
|
55 |
+
"<|reserved_token_40|>",
|
56 |
+
"<|reserved_token_41|>",
|
57 |
+
"<|reserved_token_42|>",
|
58 |
+
"<|reserved_token_43|>",
|
59 |
+
"<|reserved_token_44|>",
|
60 |
+
"<|reserved_token_45|>",
|
61 |
+
"<|reserved_token_46|>",
|
62 |
+
"<|reserved_token_47|>",
|
63 |
+
"<|reserved_token_48|>",
|
64 |
+
"<|reserved_token_49|>",
|
65 |
+
"<|reserved_token_50|>",
|
66 |
+
"<|reserved_token_51|>",
|
67 |
+
"<|reserved_token_52|>",
|
68 |
+
"<|reserved_token_53|>",
|
69 |
+
"<|reserved_token_54|>",
|
70 |
+
"<|reserved_token_55|>",
|
71 |
+
"<|reserved_token_56|>",
|
72 |
+
"<|reserved_token_57|>",
|
73 |
+
"<|reserved_token_58|>",
|
74 |
+
"<|reserved_token_59|>",
|
75 |
+
"<|reserved_token_60|>",
|
76 |
+
"<|reserved_token_61|>",
|
77 |
+
"<|reserved_token_62|>",
|
78 |
+
"<|reserved_token_63|>",
|
79 |
+
"<|reserved_token_64|>",
|
80 |
+
"<|reserved_token_65|>",
|
81 |
+
"<|reserved_token_66|>",
|
82 |
+
"<|reserved_token_67|>",
|
83 |
+
"<|reserved_token_68|>",
|
84 |
+
"<|reserved_token_69|>",
|
85 |
+
"<|reserved_token_70|>",
|
86 |
+
"<|reserved_token_71|>",
|
87 |
+
"<|reserved_token_72|>",
|
88 |
+
"<|reserved_token_73|>",
|
89 |
+
"<|reserved_token_74|>",
|
90 |
+
"<|reserved_token_75|>",
|
91 |
+
"<|reserved_token_76|>",
|
92 |
+
"<|reserved_token_77|>",
|
93 |
+
"<|reserved_token_78|>",
|
94 |
+
"<|reserved_token_79|>",
|
95 |
+
"<|reserved_token_80|>",
|
96 |
+
"<|reserved_token_81|>",
|
97 |
+
"<|reserved_token_82|>",
|
98 |
+
"<|reserved_token_83|>",
|
99 |
+
"<|reserved_token_84|>",
|
100 |
+
"<|reserved_token_85|>",
|
101 |
+
"<|reserved_token_86|>",
|
102 |
+
"<|reserved_token_87|>",
|
103 |
+
"<|reserved_token_88|>",
|
104 |
+
"<|reserved_token_89|>",
|
105 |
+
"<|reserved_token_90|>",
|
106 |
+
"<|reserved_token_91|>",
|
107 |
+
"<|reserved_token_92|>",
|
108 |
+
"<|reserved_token_93|>",
|
109 |
+
"<|reserved_token_94|>",
|
110 |
+
"<|reserved_token_95|>",
|
111 |
+
"<|reserved_token_96|>",
|
112 |
+
"<|reserved_token_97|>",
|
113 |
+
"<|reserved_token_98|>",
|
114 |
+
"<|reserved_token_99|>",
|
115 |
+
"<|reserved_token_100|>",
|
116 |
+
"<|reserved_token_101|>",
|
117 |
+
"<|reserved_token_102|>",
|
118 |
+
"<|reserved_token_103|>",
|
119 |
+
"<|reserved_token_104|>",
|
120 |
+
"<|reserved_token_105|>",
|
121 |
+
"<|reserved_token_106|>",
|
122 |
+
"<|reserved_token_107|>",
|
123 |
+
"<|reserved_token_108|>",
|
124 |
+
"<|reserved_token_109|>",
|
125 |
+
"<|reserved_token_110|>",
|
126 |
+
"<|reserved_token_111|>",
|
127 |
+
"<|reserved_token_112|>",
|
128 |
+
"<|reserved_token_113|>",
|
129 |
+
"<|reserved_token_114|>",
|
130 |
+
"<|reserved_token_115|>",
|
131 |
+
"<|reserved_token_116|>",
|
132 |
+
"<|reserved_token_117|>",
|
133 |
+
"<|reserved_token_118|>",
|
134 |
+
"<|reserved_token_119|>",
|
135 |
+
"<|reserved_token_120|>",
|
136 |
+
"<|reserved_token_121|>",
|
137 |
+
"<|reserved_token_122|>",
|
138 |
+
"<|reserved_token_123|>",
|
139 |
+
"<|reserved_token_124|>",
|
140 |
+
"<|reserved_token_125|>",
|
141 |
+
"<|reserved_token_126|>",
|
142 |
+
"<|reserved_token_127|>",
|
143 |
+
"<|reserved_token_128|>",
|
144 |
+
"<|reserved_token_129|>",
|
145 |
+
"<|reserved_token_130|>",
|
146 |
+
"<|reserved_token_131|>",
|
147 |
+
"<|reserved_token_132|>",
|
148 |
+
"<|reserved_token_133|>",
|
149 |
+
"<|reserved_token_134|>",
|
150 |
+
"<|reserved_token_135|>",
|
151 |
+
"<|reserved_token_136|>",
|
152 |
+
"<|reserved_token_137|>",
|
153 |
+
"<|reserved_token_138|>",
|
154 |
+
"<|reserved_token_139|>",
|
155 |
+
"<|reserved_token_140|>",
|
156 |
+
"<|reserved_token_141|>",
|
157 |
+
"<|reserved_token_142|>",
|
158 |
+
"<|reserved_token_143|>",
|
159 |
+
"<|reserved_token_144|>",
|
160 |
+
"<|reserved_token_145|>",
|
161 |
+
"<|reserved_token_146|>",
|
162 |
+
"<|reserved_token_147|>",
|
163 |
+
"<|reserved_token_148|>",
|
164 |
+
"<|reserved_token_149|>",
|
165 |
+
"<|reserved_token_150|>",
|
166 |
+
"<|reserved_token_151|>",
|
167 |
+
"<|reserved_token_152|>",
|
168 |
+
"<|reserved_token_153|>",
|
169 |
+
"<|reserved_token_154|>",
|
170 |
+
"<|reserved_token_155|>",
|
171 |
+
"<|reserved_token_156|>",
|
172 |
+
"<|reserved_token_157|>",
|
173 |
+
"<|reserved_token_158|>",
|
174 |
+
"<|reserved_token_159|>",
|
175 |
+
"<|reserved_token_160|>",
|
176 |
+
"<|reserved_token_161|>",
|
177 |
+
"<|reserved_token_162|>",
|
178 |
+
"<|reserved_token_163|>",
|
179 |
+
"<|reserved_token_164|>",
|
180 |
+
"<|reserved_token_165|>",
|
181 |
+
"<|reserved_token_166|>",
|
182 |
+
"<|reserved_token_167|>",
|
183 |
+
"<|reserved_token_168|>",
|
184 |
+
"<|reserved_token_169|>",
|
185 |
+
"<|reserved_token_170|>",
|
186 |
+
"<|reserved_token_171|>",
|
187 |
+
"<|reserved_token_172|>",
|
188 |
+
"<|reserved_token_173|>",
|
189 |
+
"<|reserved_token_174|>",
|
190 |
+
"<|reserved_token_175|>",
|
191 |
+
"<|reserved_token_176|>",
|
192 |
+
"<|reserved_token_177|>",
|
193 |
+
"<|reserved_token_178|>",
|
194 |
+
"<|reserved_token_179|>",
|
195 |
+
"<|reserved_token_180|>",
|
196 |
+
"<|reserved_token_181|>",
|
197 |
+
"<|reserved_token_182|>",
|
198 |
+
"<|reserved_token_183|>",
|
199 |
+
"<|reserved_token_184|>",
|
200 |
+
"<|reserved_token_185|>",
|
201 |
+
"<|reserved_token_186|>",
|
202 |
+
"<|reserved_token_187|>",
|
203 |
+
"<|reserved_token_188|>",
|
204 |
+
"<|reserved_token_189|>",
|
205 |
+
"<|reserved_token_190|>",
|
206 |
+
"<|reserved_token_191|>",
|
207 |
+
"<|reserved_token_192|>",
|
208 |
+
"<|reserved_token_193|>",
|
209 |
+
"<|reserved_token_194|>",
|
210 |
+
"<|reserved_token_195|>",
|
211 |
+
"<|reserved_token_196|>",
|
212 |
+
"<|reserved_token_197|>",
|
213 |
+
"<|reserved_token_198|>",
|
214 |
+
"<|reserved_token_199|>",
|
215 |
+
"<|reserved_token_200|>",
|
216 |
+
"<|reserved_token_201|>",
|
217 |
+
"<|reserved_token_202|>",
|
218 |
+
"<|reserved_token_203|>",
|
219 |
+
"<|reserved_token_204|>",
|
220 |
+
"<|reserved_token_205|>",
|
221 |
+
"<|reserved_token_206|>",
|
222 |
+
"<|reserved_token_207|>",
|
223 |
+
"<|reserved_token_208|>",
|
224 |
+
"<|reserved_token_209|>",
|
225 |
+
"<|reserved_token_210|>",
|
226 |
+
"<|reserved_token_211|>",
|
227 |
+
"<|reserved_token_212|>",
|
228 |
+
"<|reserved_token_213|>",
|
229 |
+
"<|reserved_token_214|>",
|
230 |
+
"<|reserved_token_215|>",
|
231 |
+
"<|reserved_token_216|>",
|
232 |
+
"<|reserved_token_217|>",
|
233 |
+
"<|reserved_token_218|>",
|
234 |
+
"<|reserved_token_219|>",
|
235 |
+
"<|reserved_token_220|>",
|
236 |
+
"<|reserved_token_221|>",
|
237 |
+
"<|reserved_token_222|>",
|
238 |
+
"<|reserved_token_223|>",
|
239 |
+
"<|reserved_token_224|>",
|
240 |
+
"<|reserved_token_225|>",
|
241 |
+
"<|reserved_token_226|>",
|
242 |
+
"<|reserved_token_227|>",
|
243 |
+
"<|reserved_token_228|>",
|
244 |
+
"<|reserved_token_229|>",
|
245 |
+
"<|reserved_token_230|>",
|
246 |
+
"<|reserved_token_231|>",
|
247 |
+
"<|reserved_token_232|>",
|
248 |
+
"<|reserved_token_233|>",
|
249 |
+
"<|reserved_token_234|>",
|
250 |
+
"<|reserved_token_235|>",
|
251 |
+
"<|reserved_token_236|>",
|
252 |
+
"<|reserved_token_237|>",
|
253 |
+
"<|reserved_token_238|>",
|
254 |
+
"<|reserved_token_239|>",
|
255 |
+
"<|reserved_token_240|>",
|
256 |
+
"<|reserved_token_241|>",
|
257 |
+
"<|reserved_token_242|>",
|
258 |
+
"<|reserved_token_243|>",
|
259 |
+
"<|reserved_token_244|>",
|
260 |
+
"<|reserved_token_245|>",
|
261 |
+
"<|reserved_token_246|>",
|
262 |
+
"<|reserved_token_247|>",
|
263 |
+
"<|reserved_token_248|>",
|
264 |
+
"<|reserved_token_249|>",
|
265 |
+
"<|reserved_token_250|>",
|
266 |
+
"<|reserved_token_251|>",
|
267 |
+
"<|reserved_token_252|>",
|
268 |
+
"<|reserved_token_253|>",
|
269 |
+
"<|reserved_token_254|>",
|
270 |
+
"<|reserved_token_255|>",
|
271 |
+
"<|reserved_token_256|>",
|
272 |
+
"<|reserved_token_257|>",
|
273 |
+
"<|reserved_token_258|>",
|
274 |
+
"<|reserved_token_259|>",
|
275 |
+
"<|reserved_token_260|>",
|
276 |
+
"<|reserved_token_261|>",
|
277 |
+
"<|reserved_token_262|>",
|
278 |
+
"<|reserved_token_263|>",
|
279 |
+
"<|reserved_token_264|>",
|
280 |
+
"<|reserved_token_265|>",
|
281 |
+
"<|reserved_token_266|>",
|
282 |
+
"<|reserved_token_267|>",
|
283 |
+
"<|reserved_token_268|>",
|
284 |
+
"<|reserved_token_269|>",
|
285 |
+
"<|reserved_token_270|>",
|
286 |
+
"<|reserved_token_271|>",
|
287 |
+
"<|reserved_token_272|>",
|
288 |
+
"<|reserved_token_273|>",
|
289 |
+
"<|reserved_token_274|>",
|
290 |
+
"<|reserved_token_275|>",
|
291 |
+
"<|reserved_token_276|>",
|
292 |
+
"<|reserved_token_277|>",
|
293 |
+
"<|reserved_token_278|>",
|
294 |
+
"<|reserved_token_279|>",
|
295 |
+
"<|reserved_token_280|>",
|
296 |
+
"<|reserved_token_281|>",
|
297 |
+
"<|reserved_token_282|>",
|
298 |
+
"<|reserved_token_283|>",
|
299 |
+
"<|reserved_token_284|>",
|
300 |
+
"<|reserved_token_285|>",
|
301 |
+
"<|reserved_token_286|>",
|
302 |
+
"<|reserved_token_287|>",
|
303 |
+
"<|reserved_token_288|>",
|
304 |
+
"<|reserved_token_289|>",
|
305 |
+
"<|reserved_token_290|>",
|
306 |
+
"<|reserved_token_291|>",
|
307 |
+
"<|reserved_token_292|>",
|
308 |
+
"<|reserved_token_293|>",
|
309 |
+
"<|reserved_token_294|>",
|
310 |
+
"<|reserved_token_295|>",
|
311 |
+
"<|reserved_token_296|>",
|
312 |
+
"<|reserved_token_297|>",
|
313 |
+
"<|reserved_token_298|>",
|
314 |
+
"<|reserved_token_299|>",
|
315 |
+
"<|reserved_token_300|>",
|
316 |
+
"<|reserved_token_301|>",
|
317 |
+
"<|reserved_token_302|>",
|
318 |
+
"<|reserved_token_303|>",
|
319 |
+
"<|reserved_token_304|>",
|
320 |
+
"<|reserved_token_305|>",
|
321 |
+
"<|reserved_token_306|>",
|
322 |
+
"<|reserved_token_307|>",
|
323 |
+
"<|reserved_token_308|>",
|
324 |
+
"<|reserved_token_309|>",
|
325 |
+
"<|reserved_token_310|>",
|
326 |
+
"<|reserved_token_311|>",
|
327 |
+
"<|reserved_token_312|>",
|
328 |
+
"<|reserved_token_313|>",
|
329 |
+
"<|reserved_token_314|>",
|
330 |
+
"<|reserved_token_315|>",
|
331 |
+
"<|reserved_token_316|>",
|
332 |
+
"<|reserved_token_317|>",
|
333 |
+
"<|reserved_token_318|>",
|
334 |
+
"<|reserved_token_319|>",
|
335 |
+
"<|reserved_token_320|>",
|
336 |
+
"<|reserved_token_321|>",
|
337 |
+
"<|reserved_token_322|>",
|
338 |
+
"<|reserved_token_323|>",
|
339 |
+
"<|reserved_token_324|>",
|
340 |
+
"<|reserved_token_325|>",
|
341 |
+
"<|reserved_token_326|>",
|
342 |
+
"<|reserved_token_327|>",
|
343 |
+
"<|reserved_token_328|>",
|
344 |
+
"<|reserved_token_329|>",
|
345 |
+
"<|reserved_token_330|>",
|
346 |
+
"<|reserved_token_331|>",
|
347 |
+
"<|reserved_token_332|>",
|
348 |
+
"<|reserved_token_333|>",
|
349 |
+
"<|reserved_token_334|>",
|
350 |
+
"<|reserved_token_335|>",
|
351 |
+
"<|reserved_token_336|>",
|
352 |
+
"<|reserved_token_337|>",
|
353 |
+
"<|reserved_token_338|>",
|
354 |
+
"<|reserved_token_339|>",
|
355 |
+
"<|reserved_token_340|>",
|
356 |
+
"<|reserved_token_341|>",
|
357 |
+
"<|reserved_token_342|>",
|
358 |
+
"<|reserved_token_343|>",
|
359 |
+
"<|reserved_token_344|>",
|
360 |
+
"<|reserved_token_345|>",
|
361 |
+
"<|reserved_token_346|>",
|
362 |
+
"<|reserved_token_347|>",
|
363 |
+
"<|reserved_token_348|>",
|
364 |
+
"<|reserved_token_349|>",
|
365 |
+
"<|reserved_token_350|>",
|
366 |
+
"<|reserved_token_351|>",
|
367 |
+
"<|reserved_token_352|>",
|
368 |
+
"<|reserved_token_353|>",
|
369 |
+
"<|reserved_token_354|>",
|
370 |
+
"<|reserved_token_355|>",
|
371 |
+
"<|reserved_token_356|>",
|
372 |
+
"<|reserved_token_357|>",
|
373 |
+
"<|reserved_token_358|>",
|
374 |
+
"<|reserved_token_359|>",
|
375 |
+
"<|reserved_token_360|>",
|
376 |
+
"<|reserved_token_361|>",
|
377 |
+
"<|reserved_token_362|>",
|
378 |
+
"<|reserved_token_363|>",
|
379 |
+
"<|reserved_token_364|>",
|
380 |
+
"<|reserved_token_365|>",
|
381 |
+
"<|reserved_token_366|>",
|
382 |
+
"<|reserved_token_367|>",
|
383 |
+
"<|reserved_token_368|>",
|
384 |
+
"<|reserved_token_369|>",
|
385 |
+
"<|reserved_token_370|>",
|
386 |
+
"<|reserved_token_371|>",
|
387 |
+
"<|reserved_token_372|>",
|
388 |
+
"<|reserved_token_373|>",
|
389 |
+
"<|reserved_token_374|>",
|
390 |
+
"<|reserved_token_375|>",
|
391 |
+
"<|reserved_token_376|>",
|
392 |
+
"<|reserved_token_377|>",
|
393 |
+
"<|reserved_token_378|>",
|
394 |
+
"<|reserved_token_379|>",
|
395 |
+
"<|reserved_token_380|>",
|
396 |
+
"<|reserved_token_381|>",
|
397 |
+
"<|reserved_token_382|>",
|
398 |
+
"<|reserved_token_383|>",
|
399 |
+
"<|reserved_token_384|>",
|
400 |
+
"<|reserved_token_385|>",
|
401 |
+
"<|reserved_token_386|>",
|
402 |
+
"<|reserved_token_387|>",
|
403 |
+
"<|reserved_token_388|>",
|
404 |
+
"<|reserved_token_389|>",
|
405 |
+
"<|reserved_token_390|>",
|
406 |
+
"<|reserved_token_391|>",
|
407 |
+
"<|reserved_token_392|>",
|
408 |
+
"<|reserved_token_393|>",
|
409 |
+
"<|reserved_token_394|>",
|
410 |
+
"<|reserved_token_395|>",
|
411 |
+
"<|reserved_token_396|>",
|
412 |
+
"<|reserved_token_397|>",
|
413 |
+
"<|reserved_token_398|>",
|
414 |
+
"<|reserved_token_399|>",
|
415 |
+
"<|reserved_token_400|>",
|
416 |
+
"<|reserved_token_401|>",
|
417 |
+
"<|reserved_token_402|>",
|
418 |
+
"<|reserved_token_403|>",
|
419 |
+
"<|reserved_token_404|>"
|
420 |
+
],
|
421 |
+
"bos_token": "[BOS]",
|
422 |
+
"eos_token": "[EOS]",
|
423 |
+
"pad_token": "<|reserved_token_406|>",
|
424 |
+
"unk_token": "<|reserved_token_405|>"
|
425 |
+
}
|
tiktoken.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b2b1b8dfb5cc5f024bafc373121c6aba3f66f9a5a0269e243470a1de16a33186
|
3 |
+
size 2561218
|
tokenization_kimia.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
|
3 |
+
"""Megatron tokenizers."""
|
4 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
5 |
+
from typing import Union
|
6 |
+
from typing import (
|
7 |
+
AbstractSet,
|
8 |
+
cast,
|
9 |
+
Collection,
|
10 |
+
Dict,
|
11 |
+
Iterator,
|
12 |
+
List,
|
13 |
+
Literal,
|
14 |
+
Sequence,
|
15 |
+
Union,
|
16 |
+
Optional,
|
17 |
+
)
|
18 |
+
from tiktoken.load import load_tiktoken_bpe
|
19 |
+
import tiktoken
|
20 |
+
from pathlib import Path
|
21 |
+
import os
|
22 |
+
import logging
|
23 |
+
from tokenizers import AddedToken
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
|
27 |
+
|
28 |
+
|
29 |
+
class TikTokenTokenizer(PreTrainedTokenizer):
|
30 |
+
"""
|
31 |
+
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
32 |
+
"""
|
33 |
+
|
34 |
+
special_tokens: Dict[str, int]
|
35 |
+
|
36 |
+
num_reserved_special_tokens = 293 + 128
|
37 |
+
|
38 |
+
pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
39 |
+
|
40 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
vocab_file,
|
45 |
+
bos_token: Union[str, AddedToken] = "[BOS]",
|
46 |
+
eos_token: Union[str, AddedToken] = "[EOS]",
|
47 |
+
unk_token: Union[str, AddedToken] = "[UNK]",
|
48 |
+
pad_token: Union[str, AddedToken] = "[PAD]",
|
49 |
+
additional_special_tokens: Optional[List[str]] = None,
|
50 |
+
added_tokens_decoder: Optional[dict] = None,
|
51 |
+
**kwargs,
|
52 |
+
):
|
53 |
+
"""
|
54 |
+
Initializes the Tokenizer with a Tiktoken model.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
model_path (str): The path to the Tiktoken model file.
|
58 |
+
"""
|
59 |
+
assert os.path.isfile(vocab_file), vocab_file
|
60 |
+
|
61 |
+
mergeable_ranks = load_tiktoken_bpe(vocab_file)
|
62 |
+
num_base_tokens = len(mergeable_ranks)
|
63 |
+
|
64 |
+
used_special_tokens = [
|
65 |
+
"[BOS]",
|
66 |
+
"[EOS]",
|
67 |
+
"<|im_msg_end|>", # 0
|
68 |
+
"<|im_user_msg_start|>", # 1
|
69 |
+
"<|im_assistant_msg_start|>", # 2
|
70 |
+
"<|reserved_token_0|>", # 3
|
71 |
+
"<|reserved_token_1|>",
|
72 |
+
"<|reserved_token_2|>",
|
73 |
+
"<|reserved_token_3|>", # 4
|
74 |
+
"[EOT]",
|
75 |
+
"<|reserved_token_4|>", # 5
|
76 |
+
"<|reserved_token_5|>", # 6
|
77 |
+
"<|reserved_token_6|>", # 7
|
78 |
+
"<|reserved_token_7|>", # 8
|
79 |
+
"<|reserved_token_8|>", # 9
|
80 |
+
"<|reserved_token_9|>", # 10
|
81 |
+
"<|reserved_token_10|>", # 11
|
82 |
+
"<|reserved_token_11|>", # 12
|
83 |
+
"<|im_media_begin|>", # 13
|
84 |
+
"<|reserved_token_12|>", # 14
|
85 |
+
"<|im_media_end|>", # 15
|
86 |
+
"<|reserved_token_13|>", # 16
|
87 |
+
"<|reserved_token_14|>", # 17
|
88 |
+
"<|im_kimia_text_blank|>", # 18
|
89 |
+
"<|im_kimia_text_eos|>", # 19
|
90 |
+
"<|reserved_token_15|>", # 20
|
91 |
+
"<|reserved_token_16|>", # 21
|
92 |
+
"<|im_kimia_user_msg_start|>", # 22
|
93 |
+
"<|im_kimia_assistant_msg_start|>", # 23
|
94 |
+
"<|reserved_token_17|>", # 24
|
95 |
+
"<|reserved_token_18|>", # 25
|
96 |
+
"<|reserved_token_19|>", # 26
|
97 |
+
"<|im_kimia_speech_ct_id|>", # 27
|
98 |
+
"<|im_kimia_speech_ctd_id|>", # 28
|
99 |
+
]
|
100 |
+
autoset_special_tokens = [
|
101 |
+
f"<|reserved_token_{i}|>"
|
102 |
+
for i in range(
|
103 |
+
20, self.num_reserved_special_tokens - len(used_special_tokens) + 20
|
104 |
+
)
|
105 |
+
]
|
106 |
+
special_tokens = used_special_tokens + autoset_special_tokens
|
107 |
+
self.special_tokens = {
|
108 |
+
token: num_base_tokens + i for i, token in enumerate(special_tokens)
|
109 |
+
}
|
110 |
+
self.model = tiktoken.Encoding(
|
111 |
+
name=Path(vocab_file).name,
|
112 |
+
pat_str=self.pat_str,
|
113 |
+
mergeable_ranks=mergeable_ranks,
|
114 |
+
special_tokens=self.special_tokens,
|
115 |
+
)
|
116 |
+
logger.info(f"Reloaded tiktoken model from {vocab_file}")
|
117 |
+
|
118 |
+
self.n_words: int = self.model.n_vocab
|
119 |
+
# BOS / EOS token IDs
|
120 |
+
self.bos_token = "[BOS]"
|
121 |
+
self.bos_id: int = self.special_tokens["[BOS]"]
|
122 |
+
self.eos_token = "[EOS]"
|
123 |
+
self.eos_id: int = self.special_tokens["[EOS]"]
|
124 |
+
|
125 |
+
# use last speical token as pad token, the last - 1 is unk_token
|
126 |
+
self.pad_token: str = special_tokens[-1]
|
127 |
+
self.pad_id: int = self.special_tokens[self.pad_token]
|
128 |
+
|
129 |
+
self.unk_token: str = special_tokens[-2]
|
130 |
+
self.unk_id: int = self.special_tokens[self.pad_token]
|
131 |
+
|
132 |
+
self.stop_tokens = {
|
133 |
+
self.special_tokens["[EOS]"],
|
134 |
+
self.special_tokens["[EOT]"],
|
135 |
+
}
|
136 |
+
|
137 |
+
logger.info(
|
138 |
+
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
139 |
+
)
|
140 |
+
|
141 |
+
def encode(
|
142 |
+
self,
|
143 |
+
s: str,
|
144 |
+
*,
|
145 |
+
bos: bool,
|
146 |
+
eos: bool,
|
147 |
+
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
148 |
+
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
149 |
+
) -> List[int]:
|
150 |
+
"""
|
151 |
+
Encodes a string into a list of token IDs.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
s (str): The input string to be encoded.
|
155 |
+
bos (bool): Whether to prepend the beginning-of-sequence token.
|
156 |
+
eos (bool): Whether to append the end-of-sequence token.
|
157 |
+
allowed_tokens ("all"|set[str]): allowed special tokens in string
|
158 |
+
disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
list[int]: A list of token IDs.
|
162 |
+
|
163 |
+
By default, setting disallowed_special=() encodes a string by ignoring
|
164 |
+
special tokens. Specifically:
|
165 |
+
- Setting `disallowed_special` to () will cause all text corresponding
|
166 |
+
to special tokens to be encoded as natural text (insteading of raising
|
167 |
+
an error).
|
168 |
+
- Setting `allowed_special` to "all" will treat all text corresponding
|
169 |
+
to special tokens to be encoded as special tokens.
|
170 |
+
"""
|
171 |
+
assert type(s) is str
|
172 |
+
|
173 |
+
# The tiktoken tokenizer can handle <=400k chars without
|
174 |
+
# pyo3_runtime.PanicException.
|
175 |
+
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
176 |
+
|
177 |
+
# https://github.com/openai/tiktoken/issues/195
|
178 |
+
# Here we iterate over subsequences and split if we exceed the limit
|
179 |
+
# of max consecutive non-whitespace or whitespace characters.
|
180 |
+
MAX_NO_WHITESPACES_CHARS = 25_000
|
181 |
+
|
182 |
+
substrs = (
|
183 |
+
substr
|
184 |
+
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
185 |
+
for substr in self._split_whitespaces_or_nonwhitespaces(
|
186 |
+
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
187 |
+
)
|
188 |
+
)
|
189 |
+
t: List[int] = []
|
190 |
+
for substr in substrs:
|
191 |
+
t.extend(
|
192 |
+
self.model.encode(
|
193 |
+
substr,
|
194 |
+
allowed_special=allowed_special,
|
195 |
+
disallowed_special=disallowed_special,
|
196 |
+
)
|
197 |
+
)
|
198 |
+
if bos:
|
199 |
+
t.insert(0, self.bos_id)
|
200 |
+
if eos:
|
201 |
+
t.append(self.eos_id)
|
202 |
+
return t
|
203 |
+
|
204 |
+
def decode(self, t: Sequence[int]) -> str:
|
205 |
+
"""
|
206 |
+
Decodes a list of token IDs into a string.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
t (List[int]): The list of token IDs to be decoded.
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
str: The decoded string.
|
213 |
+
"""
|
214 |
+
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
215 |
+
return self.model.decode(cast(List[int], t))
|
216 |
+
|
217 |
+
@staticmethod
|
218 |
+
def _split_whitespaces_or_nonwhitespaces(
|
219 |
+
s: str, max_consecutive_slice_len: int
|
220 |
+
) -> Iterator[str]:
|
221 |
+
"""
|
222 |
+
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
223 |
+
consecutive whitespaces or consecutive non-whitespaces.
|
224 |
+
"""
|
225 |
+
current_slice_len = 0
|
226 |
+
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
227 |
+
slice_start = 0
|
228 |
+
|
229 |
+
for i in range(len(s)):
|
230 |
+
is_now_space = s[i].isspace()
|
231 |
+
|
232 |
+
if current_slice_is_space ^ is_now_space:
|
233 |
+
current_slice_len = 1
|
234 |
+
current_slice_is_space = is_now_space
|
235 |
+
else:
|
236 |
+
current_slice_len += 1
|
237 |
+
if current_slice_len > max_consecutive_slice_len:
|
238 |
+
yield s[slice_start:i]
|
239 |
+
slice_start = i
|
240 |
+
current_slice_len = 1
|
241 |
+
yield s[slice_start:]
|
242 |
+
|
243 |
+
""" ----- Below are the abstract methods required by megatron ----- """
|
244 |
+
|
245 |
+
@property
|
246 |
+
def vocab_size(self):
|
247 |
+
return self.n_words
|
248 |
+
|
249 |
+
@property
|
250 |
+
def vocab(self):
|
251 |
+
if hasattr(self, "str_vocab"):
|
252 |
+
return self.str_vocab
|
253 |
+
self.str_vocab = {}
|
254 |
+
|
255 |
+
# convert mergeable_ranks from bytes to string
|
256 |
+
utf8_num, unicode_num = 0, 0
|
257 |
+
for byte_key, index in self.model._mergeable_ranks.items():
|
258 |
+
try:
|
259 |
+
str_key = byte_key.decode("utf-8")
|
260 |
+
utf8_num += 1
|
261 |
+
except UnicodeDecodeError:
|
262 |
+
# use backslashreplace so we can get num vocab different tokens
|
263 |
+
# see: https://docs.python.org/3/howto/unicode.html
|
264 |
+
# this vocab is only used for offline processing, so this is fine
|
265 |
+
str_key = byte_key.decode("utf-8", "backslashreplace") + "_unicode_"
|
266 |
+
unicode_num += 1
|
267 |
+
|
268 |
+
self.str_vocab[str_key] = index
|
269 |
+
logger.info(f"num utf8: {utf8_num}, num unicode: {unicode_num}")
|
270 |
+
|
271 |
+
# add all special tokens to the dictionary
|
272 |
+
self.str_vocab.update(self.model._special_tokens)
|
273 |
+
|
274 |
+
assert len(self.str_vocab) == self.vocab_size
|
275 |
+
return self.str_vocab
|
276 |
+
|
277 |
+
@property
|
278 |
+
def inv_vocab(self):
|
279 |
+
return {v: k for k, v in self.vocab.items()}
|
280 |
+
|
281 |
+
def tokenize(self, text, eos=True):
|
282 |
+
# BOS: always add bos token
|
283 |
+
# EOS:
|
284 |
+
# Most cases should be true when we are tokenizing a full sequence
|
285 |
+
# Only setting to false when we are running a inference
|
286 |
+
return self.encode(text, bos=True, eos=eos)
|
287 |
+
|
288 |
+
def detokenize(self, tokens):
|
289 |
+
# convert tensor to list if needed...
|
290 |
+
if not isinstance(tokens, list):
|
291 |
+
tokens = tokens.tolist()
|
292 |
+
return self.decode(tokens)
|
293 |
+
|
294 |
+
@property
|
295 |
+
def eod(self):
|
296 |
+
return self.eos_id
|
297 |
+
|
298 |
+
def bod(self):
|
299 |
+
return self.bos_id
|
300 |
+
|
301 |
+
@property
|
302 |
+
def msk_start_id(self):
|
303 |
+
return self.msk_start
|
304 |
+
|
305 |
+
@property
|
306 |
+
def msk_end_id(self):
|
307 |
+
return self.msk_end
|
308 |
+
|
309 |
+
def _get_index_2_bytes(self):
|
310 |
+
if hasattr(self, "index_2_bytes"):
|
311 |
+
return self.index_2_bytes
|
312 |
+
|
313 |
+
# use array rather than dict for faster access
|
314 |
+
self.index_2_bytes = [0] * self.model.n_vocab
|
315 |
+
for byte_key, index in self.model._mergeable_ranks.items():
|
316 |
+
self.index_2_bytes[index] = len(byte_key)
|
317 |
+
|
318 |
+
for _, index in self.model._special_tokens.items():
|
319 |
+
# in total we have 256 special tokens, 2^8 = 256
|
320 |
+
# so the num of bytes of each token is only 1
|
321 |
+
self.index_2_bytes[index] = 1
|
322 |
+
|
323 |
+
return self.index_2_bytes
|
324 |
+
|
325 |
+
def get_array_bytes(self, array):
|
326 |
+
index_2_bytes = self._get_index_2_bytes()
|
327 |
+
return sum(index_2_bytes[i] for i in array)
|
328 |
+
|
329 |
+
@property
|
330 |
+
def eos_token_id(self):
|
331 |
+
return self.eos_id
|
332 |
+
|
333 |
+
@property
|
334 |
+
def pad_token_id(self):
|
335 |
+
return self.pad_id
|
tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|