Commit 5b16c4f by OrlandoHugBot (verified) · 1 Parent(s): 1116bb3

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,319 @@
1
+ # Skywork-R1V
2
+ <div align="center">
3
+ <img src="logo.jpeg" alt="Introduction Image" width="400" height="400">
4
+ <br>
5
+ <a href="README_ZH.md">中文</a> | <a href="https://github.com/SkyworkAI/Skywork-R1V">📂 GitHub</a>
6
+ </div>
7
+
8
+
9
+
10
+ ## 1. Introduction
11
+
12
+ We introduce Skywork-R1V, a multimodal reasoning model that extends the R1-series text models to visual modalities through a near-lossless transfer method. Using a lightweight visual projector, Skywork-R1V enables seamless multimodal adaptation without requiring retraining of either the base language model or vision encoder. To enhance visual-text alignment, we developed a hybrid optimization strategy combining Iterative Supervised Fine-Tuning (SFT) with Group Relative Policy Optimization (GRPO), significantly improving cross-modal integration. Additionally, we created an adaptive-length Chain-of-Thought distillation approach for generating reasoning data, which dynamically optimizes reasoning chain lengths to improve inference efficiency and prevent overthinking. The model achieves state-of-the-art performance on key multimodal reasoning benchmarks, scoring 68.1 on MMMU and 71.0 on MathVista, comparable to leading closed-source models like Gemini 2.0 and Kimi-k1.5. It also maintains strong textual reasoning capabilities, achieving impressive scores of 72.6 on AIME and 94.3 on MATH500.
13
+
14
+
15
+ ## 2. Model Summary
16
+
17
+ **Architecture:**
18
+ Skywork-R1V employs a modular architecture that efficiently combines vision and language capabilities:
19
+ - Vision Encoder: Uses a Vision Transformer (ViT) as the visual backbone to process image inputs.
20
+ - Visual Projector: A lightweight MLP (multilayer perceptron) adapter that serves as the bridge between the vision and language components.
21
+ - Language Model: Uses DeepSeek-R1-Distill-Qwen-32B as the reasoning-capable language model backbone.
22
+
23
+ The model follows a connection pattern of Vision Encoder → MLP Adapter → Language Model, where the MLP adapter aligns the output space of the vision encoder with the input space of the language model. This design allows for efficient transfer of reasoning capabilities from text to multimodal domains without requiring extensive retraining of either the vision encoder or language model.
24
+
25
+ **Key Designs**
26
+ - **Advanced Multimodal Reasoning**
27
+ Excels in complex reasoning across textual and visual modalities.
28
+ - **Iterative Training Strategies**
29
+ Employs iterative supervised fine-tuning (SFT) and GRPO to refine model alignment and performance.
30
+ - **Adaptive-Length Chain-of-Thought**
31
+ Dynamically adjusts reasoning length to enhance inference efficiency and accuracy.
32
+ - **Scalable Performance**
33
+ Benchmarked to rival proprietary models across mathematics, coding, and multimodal tasks.
34
+
35
+ ## 3. Evaluation
36
+
37
+ <div align="center">
38
+ <img src="eval.jpeg" width="600" height="200" alt="skywork_r1v_eval" />
39
+ </div>
40
+
41
+ <div align="center">
42
+ <b>Evaluation results of state-of-the-art LLMs and VLMs</b>
43
+ </div>
44
+ <table>
45
+ <thead>
46
+ <tr>
47
+ <th></th>
48
+ <th align="center"><strong>Vision</strong></th>
49
+ <th align="center" colspan="3"><strong>Reasoning</strong></th>
50
+ <th align="center" colspan="3"><strong>Vision</strong></th>
51
+ </tr>
52
+ <tr>
53
+ <th></th>
54
+ <th></th>
55
+ <th align="center"><strong>MATH-500</strong></th>
56
+ <th align="center"><strong>AIME 2024</strong></th>
57
+ <th align="center"><strong>GPQA</strong></th>
58
+ <th align="center"><strong>MathVista(mini)</strong></th>
59
+ <th align="center"><strong>MMMU(Val)</strong></th>
60
+ <th align="center"><strong>CSVQA</strong></th>
61
+ </tr>
62
+ <tr>
63
+ <th></th>
64
+ <th></th>
65
+ <th align="center">pass@1</th>
66
+ <th align="center">pass@1</th>
67
+ <th align="center">pass@1</th>
68
+ <th align="center">pass@1</th>
69
+ <th align="center">pass@1</th>
70
+ <th align="center">pass@1</th>
71
+ </tr>
72
+ </thead>
73
+ <tbody>
74
+ <tr>
75
+ <td>Qwen2.5-72B-Instruct</td>
76
+ <td align="center">❌</td>
77
+ <td align="center">82.6</td>
78
+ <td align="center">23.3</td>
79
+ <td align="center">49.0</td>
80
+ <td align="center">-</td>
81
+ <td align="center">-</td>
82
+ <td align="center">-</td>
83
+ </tr>
84
+ <tr>
85
+ <td>Deepseek V3</td>
86
+ <td align="center">❌</td>
87
+ <td align="center">90.2</td>
88
+ <td align="center">39.2</td>
89
+ <td align="center">59.1</td>
90
+ <td align="center">-</td>
91
+ <td align="center">-</td>
92
+ <td align="center">-</td>
93
+ </tr>
94
+ <tr>
95
+ <td>Deepseek R1</td>
96
+ <td align="center">❌</td>
97
+ <td align="center">97.3</td>
98
+ <td align="center">79.8</td>
99
+ <td align="center">71.5</td>
100
+ <td align="center">-</td>
101
+ <td align="center">-</td>
102
+ <td align="center">-</td>
103
+ </tr>
104
+ <tr>
105
+ <td>Claude 3.5 Sonnet</td>
106
+ <td align="center">✅</td>
107
+ <td align="center">78.3</td>
108
+ <td align="center">16.0</td>
109
+ <td align="center">65.0</td>
110
+ <td align="center">67.7</td>
111
+ <td align="center">68.3</td>
112
+ <td align="center">-</td>
113
+ </tr>
114
+ <tr>
115
+ <td>GPT-4o</td>
116
+ <td align="center">✅</td>
117
+ <td align="center">76.6</td>
118
+ <td align="center">9.3</td>
119
+ <td align="center">53.6</td>
120
+ <td align="center">63.8</td>
121
+ <td align="center">69.1</td>
122
+ <td align="center">-</td>
123
+ </tr>
124
+ <tr>
125
+ <td>Kimi k1.5</td>
126
+ <td align="center">✅</td>
127
+ <td align="center">96.2</td>
128
+ <td align="center">77.5</td>
129
+ <td align="center">-</td>
130
+ <td align="center">74.9</td>
131
+ <td align="center">70.0</td>
132
+ <td align="center">-</td>
133
+ </tr>
134
+ <tr>
135
+ <td>Qwen2.5-VL-72B-Instruct</td>
136
+ <td align="center">✅</td>
137
+ <td align="center">-</td>
138
+ <td align="center">-</td>
139
+ <td align="center">-</td>
140
+ <td align="center">74.8</td>
141
+ <td align="center">70.2</td>
142
+ <td align="center">-</td>
143
+ </tr>
144
+ <tr>
145
+ <td>LLaVA-Onevision-72B</td>
146
+ <td align="center">✅</td>
147
+ <td align="center">-</td>
148
+ <td align="center">-</td>
149
+ <td align="center">-</td>
150
+ <td align="center">67.5</td>
151
+ <td align="center">56.8</td>
152
+ <td align="center">-</td>
153
+ </tr>
154
+ <tr>
155
+ <td>InternVL2-Llama3-76B</td>
156
+ <td align="center">✅</td>
157
+ <td align="center">-</td>
158
+ <td align="center">-</td>
159
+ <td align="center">-</td>
160
+ <td align="center">65.5</td>
161
+ <td align="center">58.3</td>
162
+ <td align="center">-</td>
163
+ </tr>
164
+ <tr>
165
+ <td>InternVL2.5-78B</td>
166
+ <td align="center">✅</td>
167
+ <td align="center">-</td>
168
+ <td align="center">-</td>
169
+ <td align="center">-</td>
170
+ <td align="center">72.3</td>
171
+ <td align="center">70.1</td>
172
+ <td align="center">-</td>
173
+ </tr>
174
+ <tr>
175
+ <td>Skywork-R1V-38B</td>
176
+ <td align="center">✅</td>
177
+ <td align="center">94.0</td>
178
+ <td align="center">72.0</td>
179
+ <td align="center">61.6</td>
180
+ <td align="center">71.0</td>
181
+ <td align="center">68.1</td>
182
+ <td align="center">XXX</td>
183
+ </tr>
184
+ </tbody>
185
+ </table>
186
+
187
+ <div align="center">
188
+ <b>Comparison with Larger-Scale Open-Source and Closed-Source Models</b>
189
+ </div>
190
+
191
+ <table align="center">
192
+ <thead>
193
+ <tr>
194
+ <th></th>
195
+ <th align="center"><strong>Benchmark</strong></th>
196
+ <th align="center"><strong>LLM</strong></th>
197
+ <th align="center" colspan="4"><strong>VLM</strong></th>
198
+ </tr>
199
+ <tr>
200
+ <th></th>
201
+ <th></th>
202
+ <th align="center"><strong>QwQ-32B-Preview</strong></th>
203
+ <th align="center"><strong>InternVL-2.5-38B</strong></th>
204
+ <th align="center"><strong>VILA 1.5-40B</strong></th>
205
+ <th align="center"><strong>InternVL2-40B</strong></th>
206
+ <th align="center"><strong>Skywork-R1V-38B</strong></th>
207
+ </tr>
208
+ </thead>
209
+ <tbody>
210
+ <tr>
211
+ <td rowspan="3">Reasoning</td>
212
+ <td>MATH-500</td>
213
+ <td align="center">90.6</td>
214
+ <td align="center">-</td>
215
+ <td align="center">-</td>
216
+ <td align="center">-</td>
217
+ <td align="center"><strong>94.0</strong></td>
218
+ </tr>
219
+ <tr>
220
+ <td>AIME 2024</td>
221
+ <td align="center">50.0</td>
222
+ <td align="center">-</td>
223
+ <td align="center">-</td>
224
+ <td align="center">-</td>
225
+ <td align="center"><strong>72.0</strong></td>
226
+ </tr>
227
+ <tr>
228
+ <td>GPQA</td>
229
+ <td align="center">65.2</td>
230
+ <td align="center">-</td>
231
+ <td align="center">-</td>
232
+ <td align="center">-</td>
233
+ <td align="center">61.6</td>
234
+ </tr>
235
+ <tr>
236
+ <td rowspan="3">Vision</td>
237
+ <td>MathVista(mini)</td>
238
+ <td align="center">-</td>
239
+ <td align="center">71.9</td>
240
+ <td align="center">49.5</td>
241
+ <td align="center">63.7</td>
242
+ <td align="center">71.0</td>
243
+ </tr>
244
+ <tr>
245
+ <td>MMMU(Val)</td>
246
+ <td align="center">-</td>
247
+ <td align="center">63.9</td>
248
+ <td align="center">55.1</td>
249
+ <td align="center">55.2</td>
250
+ <td align="center">68.1</td>
251
+ </tr>
252
+ <tr>
253
+ <td>CSVQA</td>
254
+ <td align="center">-</td>
255
+ <td align="center"></td>
256
+ <td align="center"></td>
257
+ <td align="center"></td>
258
+ <td align="center"></td>
259
+ </tr>
260
+ </tbody>
261
+ </table>
262
+
263
+
264
+ ## 4. Skywork-R1V Family
265
+
266
+ | Model Name | Vision Encoder | Language Model | HF Link |
267
+ | ---------------------- | -------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------- | ------------ |
268
+ | Skywork-R1V-38B | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | [🤗 Link](#) |
269
+ | Skywork-R1V-38B-qwq | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | - |
270
+
271
+ ---
272
+
273
+
274
+ ## 5. Quick Start
275
+
276
+ This section describes how to quickly install, configure, and run the Skywork-R1V model.
277
+
278
+ **Example Steps:**
279
+
280
+ 1. **Clone GitHub repository**
281
+ ```bash
282
+ git clone https://github.com/SkyworkAI/Skywork-R1V
283
+ ```
284
+
285
+ 2. **Install dependencies**
286
+ ```bash
287
+ cd Skywork-R1V
288
+ pip install -r requirements.txt
289
+ ```
290
+
291
+ 3. **Run example code**
292
+ ```bash
293
+ python demo.py
294
+ ```
295
+
296
+ ---
297
+
298
+ ## 6. Additional Resources
299
+
300
+ - [📂 GitHub Repository](https://github.com/SkyworkAI/Skywork-R1V)
301
+ - [🗨️ Chat Demo](#)
302
+ - [🚀 Quick Start](#5-quick-start)
303
+ - [📖 Full Documentation](#)
304
+
305
+
306
+ ## 7. Citation
307
+ If you use Skywork-R1V in your research, please cite:
308
+
309
+ ```
310
+ @article{skywork2025r1v,
311
+ title = {Skywork-R1V: Bridging Vision and Language for Advanced Multimodal Reasoning},
312
+ author = {SkyworkVL Team},
313
+ year = {2025},
314
+ journal = {arXiv preprint arXiv:XXXX.XXXXX},
315
+ url = {https://github.com/skywork-ai/Skywork-R1V}
316
+ }
317
+ ```
318
+
319
+ *This project is released under an open-source license.*
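
The Vision Encoder → MLP Adapter → Language Model pattern described in the README's Model Summary can be sketched with toy dimensions. This is only an illustration, not the Skywork-R1V implementation: the two-layer shape, the GELU activation, the toy sizes, and the names `gelu`, `linear`, `make_layer`, and `mlp_adapter` are all assumptions made for the example.

```python
import math
import random


def gelu(x: float) -> float:
    """Exact GELU activation using the Gaussian error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def linear(vec, weight, bias):
    """Compute weight @ vec + bias, with the weight stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def make_layer(rng, in_dim, out_dim):
    """Random (weight, bias) pair; a stand-in for trained adapter parameters."""
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [0.0] * out_dim
    return weight, bias


def mlp_adapter(vision_token, layer1, layer2):
    """Map one vision-encoder token into the language model's embedding space."""
    hidden = [gelu(h) for h in linear(vision_token, *layer1)]
    return linear(hidden, *layer2)


# Toy sizes (hypothetical); the real widths come from the model configuration.
VISION_DIM, LLM_DIM = 8, 12
rng = random.Random(0)
layer1 = make_layer(rng, VISION_DIM, LLM_DIM)
layer2 = make_layer(rng, LLM_DIM, LLM_DIM)

# Four fake vision tokens go in; four LLM-width embeddings come out.
vision_tokens = [[rng.gauss(0.0, 1.0) for _ in range(VISION_DIM)] for _ in range(4)]
llm_inputs = [mlp_adapter(tok, layer1, layer2) for tok in vision_tokens]
print(len(llm_inputs), len(llm_inputs[0]))
```

In this sketch only the adapter holds new parameters, which mirrors the README's point that the vision encoder and language model are connected through the adapter rather than retrained.
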
README_ZH.md ADDED
@@ -0,0 +1,317 @@
1
+ # Skywork-R1V
2
+
3
+ <div align="center">
4
+ <img src="logo.jpeg" alt="Introduction Image" width="400" height="400">
5
+ <br>
6
+ <a href="README.md">English</a> | <a href="https://github.com/SkyworkAI/Skywork-R1V">📂 GitHub</a>
7
+ </div>
8
+
9
+
10
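+ ## 1. Introduction
11
+
12
+ We introduce Skywork-R1V, a multimodal reasoning model that extends the R1-series text models to visual modalities through a near-lossless transfer method. Using a lightweight visual projector, Skywork-R1V achieves seamless multimodal adaptation without retraining either the base language model or the vision encoder. To improve visual-text alignment, we developed a hybrid optimization strategy that combines Iterative Supervised Fine-Tuning (SFT) with Group Relative Policy Optimization (GRPO), significantly improving cross-modal integration. In addition, we created an adaptive-length Chain-of-Thought distillation approach for generating reasoning data, which dynamically optimizes reasoning chain lengths to improve inference efficiency and avoid overthinking. The model achieves state-of-the-art results on key multimodal reasoning benchmarks, scoring 68.1 on MMMU and 71.0 on MathVista, comparable to leading closed-source models such as Gemini 2.0 and Kimi-k1.5. It also maintains strong textual reasoning capabilities, scoring 72.6 on AIME and 94.3 on MATH500.
13
+
14
+ ## 2. Model Summary
15
+
16
+ **Architecture:**
17
+
18
+ Skywork-R1V employs a modular architecture that efficiently combines vision and language capabilities:
19
+ - **Vision Encoder:** Uses a Vision Transformer (ViT) as the visual backbone to process image inputs.
20
+ - **Visual Projector:** A lightweight MLP adapter that serves as the bridge between the vision and language components.
21
+ - **Language Model:** Uses DeepSeek-R1-Distill-Qwen-32B as the reasoning-capable language model backbone.
22
+
23
+ The model follows a connection pattern of Vision Encoder → MLP Adapter → Language Model, where the MLP adapter aligns the output space of the vision encoder with the input space of the language model. This design efficiently transfers reasoning capabilities from text to multimodal domains without extensive retraining of either the vision encoder or the language model.
24
+
25
+ **Key Designs**
26
+ - **Advanced Multimodal Reasoning**
27
+ Excels in complex reasoning across textual and visual modalities.
28
+ - **Iterative Training Strategies**
29
+ Employs iterative supervision and GRPO to refine model alignment and performance.
30
+ - **Adaptive-Length Chain-of-Thought**
31
+ Dynamically adjusts reasoning length to enhance inference efficiency and accuracy.
32
+ - **Scalable Performance**
33
+ Rivals proprietary models across mathematics, coding, and multimodal tasks.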
34
+
35
+ ## 3. Evaluation
36
+
37
+ <div align="center">
38
+ <img src="eval.jpeg" width="600" height="200" alt="skywork_r1v_eval" />
39
+ </div>
40
+
41
+ <div align="center">
42
+ <b>Evaluation results of state-of-the-art LLMs and VLMs</b>
43
+ </div>
44
+ <table>
45
+ <thead>
46
+ <tr>
47
+ <th></th>
48
+ <th align="center"><strong>Vision</strong></th>
49
+ <th align="center" colspan="3"><strong>Reasoning</strong></th>
50
+ <th align="center" colspan="3"><strong>Vision</strong></th>
51
+ </tr>
52
+ <tr>
53
+ <th></th>
54
+ <th></th>
55
+ <th align="center"><strong>MATH-500</strong></th>
56
+ <th align="center"><strong>AIME 2024</strong></th>
57
+ <th align="center"><strong>GPQA</strong></th>
58
+ <th align="center"><strong>MathVista(mini)</strong></th>
59
+ <th align="center"><strong>MMMU(Val)</strong></th>
60
+ <th align="center"><strong>CSVQA</strong></th>
61
+ </tr>
62
+ <tr>
63
+ <th></th>
64
+ <th></th>
65
+ <th align="center">pass@1</th>
66
+ <th align="center">pass@1</th>
67
+ <th align="center">pass@1</th>
68
+ <th align="center">pass@1</th>
69
+ <th align="center">pass@1</th>
70
+ <th align="center">pass@1</th>
71
+ </tr>
72
+ </thead>
73
+ <tbody>
74
+ <tr>
75
+ <td>Qwen2.5-72B-Instruct</td>
76
+ <td align="center">❌</td>
77
+ <td align="center">82.6</td>
78
+ <td align="center">23.3</td>
79
+ <td align="center">49.0</td>
80
+ <td align="center">-</td>
81
+ <td align="center">-</td>
82
+ <td align="center">-</td>
83
+ </tr>
84
+ <tr>
85
+ <td>Deepseek V3</td>
86
+ <td align="center">❌</td>
87
+ <td align="center">90.2</td>
88
+ <td align="center">39.2</td>
89
+ <td align="center">59.1</td>
90
+ <td align="center">-</td>
91
+ <td align="center">-</td>
92
+ <td align="center">-</td>
93
+ </tr>
94
+ <tr>
95
+ <td>Deepseek R1</td>
96
+ <td align="center">❌</td>
97
+ <td align="center">97.3</td>
98
+ <td align="center">79.8</td>
99
+ <td align="center">71.5</td>
100
+ <td align="center">-</td>
101
+ <td align="center">-</td>
102
+ <td align="center">-</td>
103
+ </tr>
104
+ <tr>
105
+ <td>Claude 3.5 Sonnet</td>
106
+ <td align="center">✅</td>
107
+ <td align="center">78.3</td>
108
+ <td align="center">16.0</td>
109
+ <td align="center">65.0</td>
110
+ <td align="center">67.7</td>
111
+ <td align="center">68.3</td>
112
+ <td align="center">-</td>
113
+ </tr>
114
+ <tr>
115
+ <td>GPT-4o</td>
116
+ <td align="center">✅</td>
117
+ <td align="center">76.6</td>
118
+ <td align="center">9.3</td>
119
+ <td align="center">53.6</td>
120
+ <td align="center">63.8</td>
121
+ <td align="center">69.1</td>
122
+ <td align="center">-</td>
123
+ </tr>
124
+ <tr>
125
+ <td>Kimi k1.5</td>
126
+ <td align="center">✅</td>
127
+ <td align="center">96.2</td>
128
+ <td align="center">77.5</td>
129
+ <td align="center">-</td>
130
+ <td align="center">74.9</td>
131
+ <td align="center">70.0</td>
132
+ <td align="center">-</td>
133
+ </tr>
134
+ <tr>
135
+ <td>Qwen2.5-VL-72B-Instruct</td>
136
+ <td align="center">✅</td>
137
+ <td align="center">-</td>
138
+ <td align="center">-</td>
139
+ <td align="center">-</td>
140
+ <td align="center">74.8</td>
141
+ <td align="center">70.2</td>
142
+ <td align="center">-</td>
143
+ </tr>
144
+ <tr>
145
+ <td>LLaVA-Onevision-72B</td>
146
+ <td align="center">✅</td>
147
+ <td align="center">-</td>
148
+ <td align="center">-</td>
149
+ <td align="center">-</td>
150
+ <td align="center">67.5</td>
151
+ <td align="center">56.8</td>
152
+ <td align="center">-</td>
153
+ </tr>
154
+ <tr>
155
+ <td>InternVL2-Llama3-76B</td>
156
+ <td align="center">✅</td>
157
+ <td align="center">-</td>
158
+ <td align="center">-</td>
159
+ <td align="center">-</td>
160
+ <td align="center">65.5</td>
161
+ <td align="center">58.3</td>
162
+ <td align="center">-</td>
163
+ </tr>
164
+ <tr>
165
+ <td>InternVL2.5-78B</td>
166
+ <td align="center">✅</td>
167
+ <td align="center">-</td>
168
+ <td align="center">-</td>
169
+ <td align="center">-</td>
170
+ <td align="center">72.3</td>
171
+ <td align="center">70.1</td>
172
+ <td align="center">-</td>
173
+ </tr>
174
+ <tr>
175
+ <td>Skywork-R1V-38B</td>
176
+ <td align="center">✅</td>
177
+ <td align="center">94.0</td>
178
+ <td align="center">72.0</td>
179
+ <td align="center">61.6</td>
180
+ <td align="center">71.0</td>
181
+ <td align="center">68.1</td>
182
+ <td align="center">XXX</td>
183
+ </tr>
184
+ </tbody>
185
+ </table>
186
+
187
+ <div align="center">
188
+ <b>Comparison with Larger-Scale Open-Source and Closed-Source Models</b>
189
+ </div>
190
+
191
+ <table align="center">
192
+ <thead>
193
+ <tr>
194
+ <th></th>
195
+ <th align="center"><strong>Benchmark</strong></th>
196
+ <th align="center"><strong>LLM</strong></th>
197
+ <th align="center" colspan="4"><strong>VLM</strong></th>
198
+ </tr>
199
+ <tr>
200
+ <th></th>
201
+ <th></th>
202
+ <th align="center"><strong>QwQ-32B-Preview</strong></th>
203
+ <th align="center"><strong>InternVL-2.5-38B</strong></th>
204
+ <th align="center"><strong>VILA 1.5-40B</strong></th>
205
+ <th align="center"><strong>InternVL2-40B</strong></th>
206
+ <th align="center"><strong>Skywork-R1V-38B</strong></th>
207
+ </tr>
208
+ </thead>
209
+ <tbody>
210
+ <tr>
211
+ <td rowspan="3">Reasoning</td>
212
+ <td>MATH-500</td>
213
+ <td align="center">90.6</td>
214
+ <td align="center">-</td>
215
+ <td align="center">-</td>
216
+ <td align="center">-</td>
217
+ <td align="center"><strong>94.0</strong></td>
218
+ </tr>
219
+ <tr>
220
+ <td>AIME 2024</td>
221
+ <td align="center">50.0</td>
222
+ <td align="center">-</td>
223
+ <td align="center">-</td>
224
+ <td align="center">-</td>
225
+ <td align="center"><strong>72.0</strong></td>
226
+ </tr>
227
+ <tr>
228
+ <td>GPQA</td>
229
+ <td align="center">65.2</td>
230
+ <td align="center">-</td>
231
+ <td align="center">-</td>
232
+ <td align="center">-</td>
233
+ <td align="center">61.6</td>
234
+ </tr>
235
+ <tr>
236
+ <td rowspan="3">Vision</td>
237
+ <td>MathVista(mini)</td>
238
+ <td align="center">-</td>
239
+ <td align="center">71.9</td>
240
+ <td align="center">49.5</td>
241
+ <td align="center">63.7</td>
242
+ <td align="center">71.0</td>
243
+ </tr>
244
+ <tr>
245
+ <td>MMMU(Val)</td>
246
+ <td align="center">-</td>
247
+ <td align="center">63.9</td>
248
+ <td align="center">55.1</td>
249
+ <td align="center">55.2</td>
250
+ <td align="center">68.1</td>
251
+ </tr>
252
+ <tr>
253
+ <td>CSVQA</td>
254
+ <td align="center">-</td>
255
+ <td align="center"></td>
256
+ <td align="center"></td>
257
+ <td align="center"></td>
258
+ <td align="center"></td>
259
+ </tr>
260
+ </tbody>
261
+ </table>
262
+
263
+ ## 4. Skywork-R1V Family
264
+
265
+ | Model Name | Vision Encoder | Language Model | HF Link |
266
+ | ---------------------- | -------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------- | ------------ |
267
+ | Skywork-R1V-38B | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | [🤗 Link](#) |
268
+ | Skywork-R1V-38B-qwq | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | - |
269
+
270
+ ---
271
+
272
+ ## 5. Quick Start
273
+
274
+ **Example Steps:**
275
+
276
+ 1. **Clone the GitHub repository**
277
+ ```bash
278
+ git clone https://github.com/SkyworkAI/Skywork-R1V
279
+ ```
280
+
281
+ 2. **Install dependencies**
282
+ ```bash
283
+ cd Skywork-R1V
284
+ pip install -r requirements.txt
285
+ ```
286
+
287
+ 3. **Run the example code**
288
+ ```bash
289
+ python demo.py
290
+ ```
291
+
292
+ ---
293
+
294
+ ## 6. Additional Resources
295
+
296
+ - [📂 GitHub Repository](https://github.com/SkyworkAI/Skywork-R1V)
297
+ - [🗨️ Chat Demo](#)
298
+ - [🚀 Quick Start](#5-quick-start)
299
+ - [📖 Full Documentation](#)
300
+
301
+ ---
302
+
303
+ ## 7. Citation
304
+ If you use Skywork-R1V in your research, please cite:
305
+
306
+ ```
307
+ @article{skywork2025r1v,
308
+ title = {Skywork-R1V: Bridging Vision and Language for Advanced Multimodal Reasoning},
309
+ author = {SkyworkVL Team},
310
+ year = {2025},
311
+ journal = {arXiv preprint arXiv:XXXX.XXXXX},
312
+ url = {https://github.com/skywork-ai/Skywork-R1V}
313
+ }
314
+ ```
315
+
316
+ *This project is released under an open-source license.*
317
+
added_tokens.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "</box>": 92552,
3
+ "</img>": 92545,
4
+ "</quad>": 92548,
5
+ "</ref>": 92550,
6
+ "<IMG_CONTEXT>": 92546,
7
+ "<box>": 92551,
8
+ "<img>": 92544,
9
+ "<quad>": 92547,
10
+ "<ref>": 92549
11
+ }
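
The mapping in `added_tokens.json` assigns nine special tokens to the IDs 92544 through 92552. The sketch below shows one way to read such a mapping, invert it, and confirm that the IDs form a contiguous block. The JSON text is embedded so the snippet is self-contained; the helper names `invert_token_map` and `is_contiguous` are hypothetical and do not come from the commit.

```python
import json

# Contents of added_tokens.json as shown in the diff above.
ADDED_TOKENS_JSON = """
{
  "</box>": 92552,
  "</img>": 92545,
  "</quad>": 92548,
  "</ref>": 92550,
  "<IMG_CONTEXT>": 92546,
  "<box>": 92551,
  "<img>": 92544,
  "<quad>": 92547,
  "<ref>": 92549
}
"""


def invert_token_map(token_to_id):
    """Build an id -> token lookup, refusing duplicate IDs."""
    id_to_token = {}
    for token, token_id in token_to_id.items():
        if token_id in id_to_token:
            raise ValueError(f"duplicate id {token_id} for {token!r}")
        id_to_token[token_id] = token
    return id_to_token


def is_contiguous(ids):
    """True when the IDs cover an unbroken integer range."""
    ordered = sorted(ids)
    return ordered == list(range(ordered[0], ordered[-1] + 1))


token_to_id = json.loads(ADDED_TOKENS_JSON)
id_to_token = invert_token_map(token_to_id)

print(id_to_token[92546])
print(is_contiguous(id_to_token))
```

A check like this is a cheap way to catch a mistyped ID before the file is used to extend a tokenizer.
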
config.json ADDED
@@ -0,0 +1,134 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "",
4
+ "auto_map": {
5
+ "AutoConfig": "configuration_skywork_chat.SkyworkChatConfig",
6
+ "AutoModel": "modeling_skywork_chat.SkyworkChatModel",
7
+ "AutoModelForCausalLM": "modeling_skywork_chat.SkyworkChatModel"
8
+ },
9
+ "downsample_ratio": 0.5,
10
+ "dynamic_image_size": true,
11
+ "force_image_size": 448,
12
+ "hidden_size": 5120,
13
+ "llm_config": {
14
+ "_name_or_path": "",
15
+ "add_cross_attention": false,
16
+ "architectures": [
17
+ "Qwen2ForCausalLM"
18
+ ],
19
+ "attention_dropout": 0.0,
20
+ "bad_words_ids": null,
21
+ "begin_suppress_tokens": null,
22
+ "bos_token_id": 151643,
23
+ "chunk_size_feed_forward": 0,
24
+ "cross_attention_hidden_size": null,
25
+ "decoder_start_token_id": null,
26
+ "diversity_penalty": 0.0,
27
+ "do_sample": false,
28
+ "early_stopping": false,
29
+ "encoder_no_repeat_ngram_size": 0,
30
+ "eos_token_id": 151643,
31
+ "exponential_decay_length_penalty": null,
32
+ "finetuning_task": null,
33
+ "forced_bos_token_id": null,
34
+ "forced_eos_token_id": null,
35
+ "hidden_act": "silu",
36
+ "hidden_size": 5120,
37
+ "id2label": {
38
+ "0": "LABEL_0",
39
+ "1": "LABEL_1"
40
+ },
41
+ "initializer_range": 0.02,
42
+ "intermediate_size": 27648,
43
+ "is_decoder": false,
44
+ "is_encoder_decoder": false,
45
+ "label2id": {
46
+ "LABEL_0": 0,
47
+ "LABEL_1": 1
48
+ },
49
+ "length_penalty": 1.0,
50
+ "max_length": 20,
51
+ "max_position_embeddings": 131072,
52
+ "max_window_layers": 64,
53
+ "min_length": 0,
54
+ "model_type": "qwen2",
55
+ "no_repeat_ngram_size": 0,
56
+ "num_attention_heads": 40,
57
+ "num_beam_groups": 1,
58
+ "num_beams": 1,
59
+ "num_hidden_layers": 64,
60
+ "num_key_value_heads": 8,
61
+ "num_return_sequences": 1,
62
+ "output_attentions": false,
63
+ "output_hidden_states": false,
64
+ "output_scores": false,
65
+ "pad_token_id": null,
66
+ "prefix": null,
67
+ "problem_type": null,
68
+ "pruned_heads": {},
69
+ "remove_invalid_values": false,
70
+ "repetition_penalty": 1.0,
71
+ "return_dict": true,
72
+ "return_dict_in_generate": false,
73
+ "rms_norm_eps": 1e-05,
74
+ "rope_theta": 1000000.0,
75
+ "sep_token_id": null,
76
+ "sliding_window": 131072,
77
+ "suppress_tokens": null,
78
+ "task_specific_params": null,
79
+ "temperature": 1.0,
80
+ "tf_legacy_loss": false,
81
+ "tie_encoder_decoder": false,
82
+ "tie_word_embeddings": false,
83
+ "tokenizer_class": null,
84
+ "top_k": 50,
85
+ "top_p": 1.0,
86
+ "torch_dtype": "bfloat16",
87
+ "torchscript": false,
88
+ "transformers_version": "4.37.2",
89
+ "typical_p": 1.0,
90
+ "use_bfloat16": false,
91
+ "use_cache": false,
92
+ "use_sliding_window": false,
93
+ "vocab_size": 152064
94
+ },
95
+ "max_dynamic_patch": 6,
96
+ "min_dynamic_patch": 1,
97
+ "model_type": "skywork_chat",
98
+ "pad2square": false,
99
+ "ps_version": "v2",
100
+ "select_layer": -1,
101
+ "template": "skywork-r1v-chat",
102
+ "tie_word_embeddings": false,
103
+ "torch_dtype": "bfloat16",
104
+ "transformers_version": null,
105
+ "use_backbone_lora": 0,
106
+ "use_llm_lora": 0,
107
+ "use_thumbnail": true,
108
+ "vision_config": {
109
+ "attention_dropout": 0.0,
110
+ "drop_path_rate": 0.0,
111
+ "dropout": 0.0,
112
+ "hidden_act": "gelu",
113
+ "hidden_size": 3200,
114
+ "image_size": 448,
115
+ "initializer_factor": 0.1,
116
+ "initializer_range": 1e-10,
117
+ "intermediate_size": 12800,
118
+ "layer_norm_eps": 1e-06,
119
+ "norm_type": "rms_norm",
120
+ "num_attention_heads": 25,
121
+ "num_channels": 3,
122
+ "num_hidden_layers": 45,
123
+ "output_attentions": false,
124
+ "output_hidden_states": false,
125
+ "patch_size": 14,
126
+ "qk_normalization": true,
127
+ "qkv_bias": false,
128
+ "return_dict": true,
129
+ "torch_dtype": "bfloat16",
130
+ "transformers_version": "4.37.2",
131
+ "use_bfloat16": true,
132
+ "use_flash_attn": true
133
+ }
134
+ }
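
Several values in `config.json` jointly determine how many visual tokens an image contributes to the language model's input. The sketch below works through that arithmetic using the values shown above (`image_size` 448, `patch_size` 14, `downsample_ratio` 0.5, `max_dynamic_patch` 6, `use_thumbnail` true, vision `hidden_size` 3200). The formulas follow the convention commonly used by ViT-plus-pixel-shuffle models; the modeling code is not part of this excerpt, so treating them as what Skywork-R1V does is an assumption, as are the function names and the rule that a thumbnail tile is appended only when more than one tile is used.

```python
def tokens_per_tile(image_size: int, patch_size: int, downsample_ratio: float) -> int:
    """Visual tokens for one square tile after pixel-shuffle downsampling (assumed formula)."""
    patches_per_side = image_size // patch_size          # 448 // 14 = 32
    return int(patches_per_side ** 2 * downsample_ratio ** 2)


def max_visual_tokens(per_tile: int, max_dynamic_patch: int, use_thumbnail: bool) -> int:
    """Upper bound on visual tokens per image when dynamic tiling is enabled (assumed rule)."""
    tiles = max_dynamic_patch
    if use_thumbnail and max_dynamic_patch > 1:
        tiles += 1                                       # one extra low-resolution overview tile
    return per_tile * tiles


def projector_input_dim(vision_hidden_size: int, downsample_ratio: float) -> int:
    """Channel width after pixel shuffle folds neighbouring patches together (assumed formula)."""
    return vision_hidden_size * int(1 / downsample_ratio) ** 2


per_tile = tokens_per_tile(image_size=448, patch_size=14, downsample_ratio=0.5)
print(per_tile)
print(max_visual_tokens(per_tile, max_dynamic_patch=6, use_thumbnail=True))
print(projector_input_dim(vision_hidden_size=3200, downsample_ratio=0.5))
```

Under these assumptions, raising `max_dynamic_patch` increases the visual token budget linearly, which is worth keeping in mind when choosing a context length.
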
configuration.json ADDED
@@ -0,0 +1 @@
1
+ {"framework": "pytorch", "task": "image-text-to-text", "allow_remote": true}
configuration_skywork_chat.py ADDED
@@ -0,0 +1,92 @@
+ import copy
+
+ from transformers import AutoConfig, LlamaConfig
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ from .configuration_skywork_vit import SkyworkVisionConfig
+ from .configuration_skywork_lm2 import SkyworkLM2Config
+ from transformers import Qwen2Config, Qwen2ForCausalLM
+
+ logger = logging.get_logger(__name__)
+
+
+ class SkyworkChatConfig(PretrainedConfig):
+     model_type = 'skywork_chat'
+     is_composition = True  # this config wraps a vision config and an LLM config
+
+     def __init__(
+             self,
+             vision_config=None,
+             llm_config=None,
+             use_backbone_lora=0,
+             use_llm_lora=0,
+             select_layer=-1,
+             force_image_size=None,
+             downsample_ratio=0.5,
+             template=None,
+             dynamic_image_size=False,
+             use_thumbnail=False,
+             ps_version='v1',
+             min_dynamic_patch=1,
+             max_dynamic_patch=6,
+             **kwargs):
+         super().__init__(**kwargs)
+         if vision_config is None:
+             vision_config = {'architectures': ['SkyworkVisionModel']}
+             logger.info('vision_config is None. Initializing the SkyworkVisionConfig with default values.')
+
+         if llm_config is None:
+             llm_config = {'architectures': ['Qwen2ForCausalLM']}
+             logger.info('llm_config is None. Initializing the Qwen2Config with default values.')
+
+         self.vision_config = SkyworkVisionConfig(**vision_config)
+         if llm_config.get('architectures')[0] == 'LlamaForCausalLM':
+             self.llm_config = LlamaConfig(**llm_config)
+         elif llm_config.get('architectures')[0] == 'Qwen2ForCausalLM':
+             self.llm_config = Qwen2Config(**llm_config)
+         else:
+             raise ValueError('Unsupported architecture: {}'.format(llm_config.get('architectures')[0]))
+
+
+         self.use_backbone_lora = use_backbone_lora
+         self.use_llm_lora = use_llm_lora
+         self.select_layer = select_layer
+         self.force_image_size = force_image_size
+         self.downsample_ratio = downsample_ratio
+         self.template = template
+         self.dynamic_image_size = dynamic_image_size
+         self.use_thumbnail = use_thumbnail
+         self.ps_version = ps_version  # pixel shuffle version
+         self.min_dynamic_patch = min_dynamic_patch
+         self.max_dynamic_patch = max_dynamic_patch
+
+         logger.info(f'vision_select_layer: {self.select_layer}')
+         logger.info(f'ps_version: {self.ps_version}')
+         logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
+         logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+         Returns:
+             `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output['vision_config'] = self.vision_config.to_dict()  # nested configs become plain dicts
+         output['llm_config'] = self.llm_config.to_dict()
+         output['model_type'] = self.__class__.model_type
+         output['use_backbone_lora'] = self.use_backbone_lora
+         output['use_llm_lora'] = self.use_llm_lora
+         output['select_layer'] = self.select_layer
+         output['force_image_size'] = self.force_image_size
+         output['downsample_ratio'] = self.downsample_ratio
+         output['template'] = self.template
+         output['dynamic_image_size'] = self.dynamic_image_size
+         output['use_thumbnail'] = self.use_thumbnail
+         output['ps_version'] = self.ps_version
+         output['min_dynamic_patch'] = self.min_dynamic_patch
+         output['max_dynamic_patch'] = self.max_dynamic_patch
+
+         return output
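
The architecture check in `SkyworkChatConfig.__init__` could also be written as a lookup table, which keeps the set of supported language models in one place. The sketch below is a possible alternative, not the code in this commit. `StubLlamaConfig` and `StubQwen2Config` are hypothetical stand-ins so the snippet runs without `transformers` installed, and `LLM_CONFIG_CLASSES` and `build_llm_config` are illustrative names.

```python
from dataclasses import dataclass, field


@dataclass
class StubLlamaConfig:
    """Stand-in for transformers.LlamaConfig (hypothetical, for illustration only)."""
    kwargs: dict = field(default_factory=dict)


@dataclass
class StubQwen2Config:
    """Stand-in for transformers.Qwen2Config (hypothetical, for illustration only)."""
    kwargs: dict = field(default_factory=dict)


# Single place that lists every supported language-model architecture.
LLM_CONFIG_CLASSES = {
    'LlamaForCausalLM': StubLlamaConfig,
    'Qwen2ForCausalLM': StubQwen2Config,
}


def build_llm_config(llm_config=None):
    """Pick the config class from the first entry of `architectures`."""
    if llm_config is None:
        # Same default as the commit: fall back to the Qwen2 architecture.
        llm_config = {'architectures': ['Qwen2ForCausalLM']}
    architectures = llm_config.get('architectures') or [None]
    architecture = architectures[0]
    try:
        config_cls = LLM_CONFIG_CLASSES[architecture]
    except KeyError:
        raise ValueError(f'Unsupported architecture: {architecture}') from None
    return config_cls(kwargs=dict(llm_config))


print(type(build_llm_config()).__name__)
print(type(build_llm_config({'architectures': ['LlamaForCausalLM']})).__name__)
```

Unlike the `if`/`elif` chain, this version also reports a clear `ValueError` when the `architectures` key is missing, rather than failing while indexing `None`.
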
configuration_skywork_lm2.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copyright (c) The Skywork team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ SkyworkLM2 model configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ # Modified from transformers.models.llama.configuration_llama.LlamaConfig
25
+ class SkyworkLM2Config(PretrainedConfig):
26
+ r"""
27
+ Args:
28
+ vocab_size (`int`, *optional*, defaults to 103168):
29
+ Vocabulary size of the SkyworkLM2 model. Defines the number of different tokens that can be represented by the
30
+ `input_ids` passed when calling [`SkyworkLM2Model`]
31
+ hidden_size (`int`, *optional*, defaults to 4096):
32
+ Dimension of the hidden representations.
33
+ intermediate_size (`int`, *optional*, defaults to 11008):
34
+ Dimension of the MLP representations.
35
+ num_hidden_layers (`int`, *optional*, defaults to 32):
36
+ Number of hidden layers in the Transformer decoder.
37
+ num_attention_heads (`int`, *optional*, defaults to 32):
38
+ Number of attention heads for each attention layer in the Transformer decoder.
39
+ num_key_value_heads (`int`, *optional*):
40
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
41
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
42
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
43
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
44
+ by mean-pooling all the original heads within that group. For more details, check out [this
45
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
46
+ `num_attention_heads`.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
48
+ The non-linear activation function (function or string) in the decoder.
49
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
50
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
51
+ just in case (e.g., 512 or 1024 or 2048).
52
+ initializer_range (`float`, *optional*, defaults to 0.02):
53
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
54
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
55
+ The epsilon used by the rms normalization layers.
56
+ use_cache (`bool`, *optional*, defaults to `True`):
57
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
58
+ relevant if `config.is_decoder=True`.
59
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
60
+ Whether to tie the input and output word embeddings.
61
+ Example:
62
+
63
+ """
64
+ _auto_class = 'AutoConfig'
65
+
66
+ def __init__(
67
+ self,
68
+ vocab_size=103168,
69
+ hidden_size=4096,
70
+ intermediate_size=11008,
71
+ num_hidden_layers=32,
72
+ num_attention_heads=32,
73
+ num_key_value_heads=None,
74
+ hidden_act='silu',
75
+ max_position_embeddings=2048,
76
+ initializer_range=0.02,
77
+ rms_norm_eps=1e-6,
78
+ use_cache=True,
79
+ pad_token_id=0,
80
+ bos_token_id=1,
81
+ eos_token_id=2,
82
+ tie_word_embeddings=False,
83
+ bias=True,
84
+ rope_theta=10000,
85
+ rope_scaling=None,
86
+ attn_implementation='eager',
87
+ **kwargs,
88
+ ):
89
+ self.vocab_size = vocab_size
90
+ self.max_position_embeddings = max_position_embeddings
91
+ self.hidden_size = hidden_size
92
+ self.intermediate_size = intermediate_size
93
+ self.num_hidden_layers = num_hidden_layers
94
+ self.num_attention_heads = num_attention_heads
95
+ self.bias = bias
96
+
97
+ if num_key_value_heads is None:
98
+ num_key_value_heads = num_attention_heads
99
+ self.num_key_value_heads = num_key_value_heads
100
+
101
+ self.hidden_act = hidden_act
102
+ self.initializer_range = initializer_range
103
+ self.rms_norm_eps = rms_norm_eps
104
+ self.use_cache = use_cache
105
+ self.rope_theta = rope_theta
106
+ self.rope_scaling = rope_scaling
107
+ self._rope_scaling_validation()
108
+
109
+ self.attn_implementation = attn_implementation
110
+ if self.attn_implementation is None:
111
+ self.attn_implementation = 'eager'
112
+ super().__init__(
113
+ pad_token_id=pad_token_id,
114
+ bos_token_id=bos_token_id,
115
+ eos_token_id=eos_token_id,
116
+ tie_word_embeddings=tie_word_embeddings,
117
+ **kwargs,
118
+ )
119
+
120
+ def _rope_scaling_validation(self):
121
+ """
122
+ Validate the `rope_scaling` configuration.
123
+ """
124
+ if self.rope_scaling is None:
125
+ return
126
+
127
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
128
+ raise ValueError(
129
+ '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
130
+ f'got {self.rope_scaling}'
131
+ )
132
+ rope_scaling_type = self.rope_scaling.get('type', None)
133
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
134
+ if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
135
+ raise ValueError(
136
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
137
+ )
138
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
139
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
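The `rope_scaling` checks above can be exercised in isolation. The following is a minimal sketch, assuming a standalone helper named `validate_rope_scaling` (a hypothetical name, not part of this commit) that mirrors the same three checks without depending on `transformers`. One consequence worth noticing is that an integer `factor` such as `2` is rejected, because the check requires a `float`.

```python
def validate_rope_scaling(rope_scaling):
    """Standalone mirror of the checks in `_rope_scaling_validation` (illustrative only)."""
    if rope_scaling is None:
        return
    # Must be a dict with exactly two entries.
    if not isinstance(rope_scaling, dict) or len(rope_scaling) != 2:
        raise ValueError(f'`rope_scaling` must be a dict with `type` and `factor`, got {rope_scaling}')
    scaling_type = rope_scaling.get('type')
    factor = rope_scaling.get('factor')
    # Only two scaling strategies are accepted.
    if scaling_type not in ('linear', 'dynamic'):
        raise ValueError(f"`type` must be 'linear' or 'dynamic', got {scaling_type}")
    # The factor must be a float (not an int) and at least 1.0.
    if not isinstance(factor, float) or factor < 1.0:
        raise ValueError(f'`factor` must be a float >= 1, got {factor}')


def is_valid(rope_scaling):
    """Hypothetical convenience wrapper: True if validation passes."""
    try:
        validate_rope_scaling(rope_scaling)
    except ValueError:
        return False
    return True
```

In practice this means a config should be written as `{'type': 'linear', 'factor': 2.0}` rather than with `'factor': 2`.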
configuration_skywork_vit.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ from typing import Union
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class SkyworkVisionConfig(PretrainedConfig):
11
+ r"""
12
+ Args:
13
+ num_channels (`int`, *optional*, defaults to 3):
14
+ Number of color channels in the input images (e.g., 3 for RGB).
15
+ patch_size (`int`, *optional*, defaults to 14):
16
+ The size (resolution) of each patch.
17
+ image_size (`int`, *optional*, defaults to 224):
18
+ The size (resolution) of each image.
19
+ qkv_bias (`bool`, *optional*, defaults to `False`):
20
+ Whether to add a bias to the queries and values in the self-attention layers.
21
+ hidden_size (`int`, *optional*, defaults to 3200):
22
+ Dimensionality of the encoder layers and the pooler layer.
23
+ num_attention_heads (`int`, *optional*, defaults to 25):
24
+ Number of attention heads for each attention layer in the Transformer encoder.
25
+ intermediate_size (`int`, *optional*, defaults to 12800):
26
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
27
+ qk_normalization (`bool`, *optional*, defaults to `True`):
28
+ Whether to normalize the queries and keys in the self-attention layers.
29
+ num_hidden_layers (`int`, *optional*, defaults to 48):
30
+ Number of hidden layers in the Transformer encoder.
31
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
32
+ Whether to use flash attention mechanism.
33
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
34
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
35
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
36
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
37
+ The epsilon used by the layer normalization layers.
38
+ dropout (`float`, *optional*, defaults to 0.0):
39
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
40
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
41
+ Dropout rate for stochastic depth.
42
+ attention_dropout (`float`, *optional*, defaults to 0.0):
43
+ The dropout ratio for the attention probabilities.
44
+ initializer_range (`float`, *optional*, defaults to 0.02):
45
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
46
+ initializer_factor (`float`, *optional*, defaults to 0.1):
47
+ A factor for layer scale.
48
+ """
49
+
50
+
51
+ def __init__(
52
+ self,
53
+ num_channels=3,
54
+ patch_size=14,
55
+ image_size=224,
56
+ qkv_bias=False,
57
+ hidden_size=3200,
58
+ num_attention_heads=25,
59
+ intermediate_size=12800,
60
+ qk_normalization=True,
61
+ num_hidden_layers=48,
62
+ use_flash_attn=True,
63
+ hidden_act='gelu',
64
+ norm_type='rms_norm',
65
+ layer_norm_eps=1e-6,
66
+ dropout=0.0,
67
+ drop_path_rate=0.0,
68
+ attention_dropout=0.0,
69
+ initializer_range=0.02,
70
+ initializer_factor=0.1,
71
+ **kwargs,
72
+ ):
73
+ super().__init__(**kwargs)
74
+
75
+ self.hidden_size = hidden_size
76
+ self.intermediate_size = intermediate_size
77
+ self.dropout = dropout
78
+ self.drop_path_rate = drop_path_rate
79
+ self.num_hidden_layers = num_hidden_layers
80
+ self.num_attention_heads = num_attention_heads
81
+ self.num_channels = num_channels
82
+ self.patch_size = patch_size
83
+ self.image_size = image_size
84
+ self.initializer_range = initializer_range
85
+ self.initializer_factor = initializer_factor
86
+ self.attention_dropout = attention_dropout
87
+ self.layer_norm_eps = layer_norm_eps
88
+ self.hidden_act = hidden_act
89
+ self.norm_type = norm_type
90
+ self.qkv_bias = qkv_bias
91
+ self.qk_normalization = qk_normalization
92
+ self.use_flash_attn = use_flash_attn
93
+
94
+ @classmethod
95
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
96
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
97
+
98
+ if 'vision_config' in config_dict:
99
+ config_dict = config_dict['vision_config']
100
+
101
+ return cls.from_dict(config_dict, **kwargs)
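To make the default vision settings concrete, here is a small arithmetic sketch. The helper names (`vit_patch_grid`, `head_dim`, `num_image_tokens`) are hypothetical and introduced only for illustration. The last function reuses the expression that `SkyworkChatModel.__init__` in this commit applies to compute `num_image_token`. The `downsample_ratio` of 0.5 used in the note below is an assumption taken from the default `scale_factor` of `pixel_shuffle`, not from a config file.

```python
def vit_patch_grid(image_size=224, patch_size=14):
    """Return (patches_per_side, total_patches) for a square input image."""
    per_side = image_size // patch_size
    return per_side, per_side ** 2


def head_dim(hidden_size=3200, num_attention_heads=25):
    """Per-head width; the hidden size must divide evenly across heads."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError('hidden_size must be divisible by num_attention_heads')
    return hidden_size // num_attention_heads


def num_image_tokens(image_size, patch_size, downsample_ratio):
    """Same expression as `num_image_token` in the chat model's constructor."""
    return int((image_size // patch_size) ** 2 * (downsample_ratio ** 2))
```

With the defaults, a 224x224 image yields a 16x16 grid of 256 patches, each attention head is 128 wide, and an assumed `downsample_ratio` of 0.5 reduces the 256 patches to 64 image tokens per tile.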
conversation.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+ """
7
+
8
+ import dataclasses
9
+ from enum import IntEnum, auto
10
+ from typing import Any, Dict, List, Tuple, Union
11
+
12
+
13
+ class SeparatorStyle(IntEnum):
14
+ """Separator styles."""
15
+
16
+ ADD_COLON_SINGLE = auto()
17
+ ADD_COLON_TWO = auto()
18
+ ADD_COLON_SPACE_SINGLE = auto()
19
+ NO_COLON_SINGLE = auto()
20
+ NO_COLON_TWO = auto()
21
+ ADD_NEW_LINE_SINGLE = auto()
22
+ LLAMA2 = auto()
23
+ CHATGLM = auto()
24
+ CHATML = auto()
25
+ DOLLY = auto()
26
+ RWKV = auto()
27
+ PHOENIX = auto()
28
+ ROBIN = auto()
29
+ FALCON_CHAT = auto()
30
+ CHATGLM3 = auto()
31
+ MPT = auto()
32
+
33
+
34
+ @dataclasses.dataclass
35
+ class Conversation:
36
+ """A class that manages prompt templates and keeps all conversation history."""
37
+
38
+ # The name of this template
39
+ name: str
40
+ # The template of the system prompt
41
+ system_template: str = '{system_message}'
42
+ # The system message
43
+ system_message: str = ''
44
+ # The names of two roles
45
+ roles: Tuple[str, str] = ('USER', 'ASSISTANT')
46
+ # All messages. Each item is (role, message).
47
+ messages: List[List[str]] = dataclasses.field(default_factory=list)
48
+ # The number of few shot examples
49
+ offset: int = 0
50
+ # The separator style and configurations
51
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
52
+ sep: str = '\n'
53
+ sep2: str = None
54
+ # Stop criteria (the default one is EOS token)
55
+ stop_str: Union[str, List[str]] = None
56
+ # Stops generation if meeting any token in this list
57
+ stop_token_ids: List[int] = None
58
+
59
+ def get_prompt(self) -> str:
60
+ """Get the prompt for generation."""
61
+ system_prompt = self.system_template.format(system_message=self.system_message)
62
+ ret = system_prompt
63
+ for role, message in self.messages:
64
+ if message:
65
+ if type(message) is tuple:
66
+ message, _, _ = message
67
+ ret += role + message
68
+ else:
69
+ ret += role
70
+
71
+ return ret
72
+
73
+ def set_system_message(self, system_message: str):
74
+ """Set the system message."""
75
+ self.system_message = system_message
76
+
77
+ def append_message(self, role: str, message: str):
78
+ """Append a new message."""
79
+ self.messages.append([role, message])
80
+
81
+ def update_last_message(self, message: str):
82
+ """Update the last output.
83
+
84
+ The last message is typically set to be None when constructing the prompt,
85
+ so we need to update it in-place after getting the response from a model.
86
+ """
87
+ self.messages[-1][1] = message
88
+
89
+ def to_gradio_chatbot(self):
90
+ """Convert the conversation to gradio chatbot format."""
91
+ ret = []
92
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
93
+ if i % 2 == 0:
94
+ ret.append([msg, None])
95
+ else:
96
+ ret[-1][-1] = msg
97
+ return ret
98
+
99
+ def to_openai_api_messages(self):
100
+ """Convert the conversation to OpenAI chat completion format."""
101
+ ret = [{'role': 'system', 'content': self.system_message}]
102
+
103
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
104
+ if i % 2 == 0:
105
+ ret.append({'role': 'user', 'content': msg})
106
+ else:
107
+ if msg is not None:
108
+ ret.append({'role': 'assistant', 'content': msg})
109
+ return ret
110
+
111
+ def copy(self):
112
+ return Conversation(
113
+ name=self.name,
114
+ system_template=self.system_template,
115
+ system_message=self.system_message,
116
+ roles=self.roles,
117
+ messages=[[x, y] for x, y in self.messages],
118
+ offset=self.offset,
119
+ sep_style=self.sep_style,
120
+ sep=self.sep,
121
+ sep2=self.sep2,
122
+ stop_str=self.stop_str,
123
+ stop_token_ids=self.stop_token_ids,
124
+ )
125
+
126
+ def dict(self):
127
+ return {
128
+ 'template_name': self.name,
129
+ 'system_message': self.system_message,
130
+ 'roles': self.roles,
131
+ 'messages': self.messages,
132
+ 'offset': self.offset,
133
+ }
134
+
135
+
136
+ # A global registry for all conversation templates
137
+ conv_templates: Dict[str, Conversation] = {}
138
+
139
+
140
+ def register_conv_template(template: Conversation, override: bool = False):
141
+ """Register a new conversation template."""
142
+ if not override:
143
+ assert (
144
+ template.name not in conv_templates
145
+ ), f'{template.name} has been registered.'
146
+
147
+ conv_templates[template.name] = template
148
+
149
+
150
+ def get_conv_template(name: str) -> Conversation:
151
+ """Get a conversation template."""
152
+ return conv_templates[name].copy()
153
+
154
+ register_conv_template(
155
+ Conversation(
156
+ name='skywork-r1v-chat',
157
+ system_template='<|begin▁of▁sentence|>{system_message}',
158
+ system_message='',
159
+ roles=('<|User|>\n', '<|Assistant|><think>\n'),
160
+ sep_style=SeparatorStyle.MPT,
161
+ sep='<|end▁of▁sentence|>',
162
+ )
163
+ )
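The prompt that `get_prompt` produces for the `skywork-r1v-chat` template can be reproduced without the `Conversation` class. This is a minimal sketch, assuming a hypothetical helper `build_prompt` that follows the same loop: the system prefix first, then each role string followed by its message when one is present.

```python
# Strings copied from the 'skywork-r1v-chat' template registration.
BOS = '<|begin▁of▁sentence|>'
USER = '<|User|>\n'
ASSISTANT = '<|Assistant|><think>\n'


def build_prompt(messages, system_message=''):
    """Concatenate the system prefix and (role, message) pairs, as get_prompt does."""
    prompt = BOS + system_message
    for role, message in messages:
        prompt += role
        if message:  # A None message leaves the role tag open for generation.
            prompt += message
    return prompt


# One user turn followed by an open assistant turn.
prompt = build_prompt([(USER, 'What is 2 + 2?'), (ASSISTANT, None)])
```

Note that `get_prompt` never inserts `sep` between turns. In this commit the separator is only used by `chat` and `batch_chat`, which convert it to the EOS token id and split the decoded response on it.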
generation_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 151646,
4
+ "eos_token_id": 151643,
5
+ "do_sample": true,
6
+ "temperature": 0.6,
7
+ "top_p": 0.95,
8
+ "transformers_version": "4.39.3"
9
+ }
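The `temperature` and `top_p` values above control sampling. As a rough illustration of what they do, the sketch below implements temperature scaling followed by nucleus (top-p) filtering in plain Python. It is not the `transformers` implementation, and `sample_token` is a hypothetical helper. Only the defaults 0.6 and 0.95 come from the config.

```python
import math
import random


def sample_token(logits, temperature=0.6, top_p=0.95, rng=random):
    """Sample one index using temperature scaling and nucleus (top-p) filtering."""
    # Temperature below 1 sharpens the distribution.
    scaled = [x / temperature for x in logits]
    # Numerically stable softmax.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the most probable tokens until their cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # random.choices renormalizes the weights of the kept tokens.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

When one logit dominates, the nucleus shrinks to that single token and sampling becomes effectively greedy. With flat logits every token stays in the candidate set.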
model-00001-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2373921d7d9a490e820d98dd62776fab4edff4ed655125c7a064c4b3ac72fd17
3
+ size 4988569440
model-00002-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bb2779f566d99d96508a46b5d9a45273ddc5a7eea53f2cc424041e1ed4f1a43
3
+ size 4937253584
model-00003-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20faf8e5c3fc1ae0a02ab4cbe3d0407135970dfdc72f16ba44ea739ad5d5f21a
3
+ size 4780642896
model-00004-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04fb02e6c393cb7443fc2d63942de331e92278ae118850d9fb3641a99daf808b
3
+ size 4876060256
model-00005-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1e59db52268bfc5e83a61144f840630a8b130a068bf016867ddabbce961814f
3
+ size 4876060272
model-00006-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14bbe28e02c34a3a1929cce1449bc3abbe7f5760d702aef19e2b68d2509b15fd
3
+ size 4876060312
model-00007-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:acc431255af7fda8193a91117cdfb61c70945a687f6a0bddd2494716f51f6469
3
+ size 4876060312
model-00008-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0913c56f6d93bc03698e96b4cf24b6c58bfd6a0b9d9e2f6d6bc5ba4f7f959292
3
+ size 4876060312
model-00009-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9be9378569371d0fc815bd51ade809522ff061427910b684103d978fd6865534
3
+ size 4876060312
model-00010-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7aafb8b7b718b0523ce9cfdcfa443864e163db3bd8ad9fa79c9d3ef0f792bbc8
3
+ size 4876060312
model-00011-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec336bb125194e259262b2cd60982435376fa2430c7f4db791355bb120ea0bd3
3
+ size 4876060312
model-00012-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc0942288d5161ada1d130d44d5d194f3b74f76c2637541a2e9267556b26feb6
3
+ size 4876060312
model-00013-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e0faa8ba5caffc27a55b3d3612cdab22a71b782b7507e3938e152d615b51d55
3
+ size 4876060312
model-00014-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1aa651e0e0bd996e9b856615a58113f29654db50271f6d7751c605aecb99282d
3
+ size 4876060312
model-00015-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f443e511bddf529426b4f305c032e63d747826898b4731741d14e0a87debf6df
3
+ size 4876060312
model-00016-of-00016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5c349b706b0711b33aed44b63d6d652370fe85f54f213af561e270db4e4ebf1
3
+ size 3565298616
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_skywork_chat.py ADDED
@@ -0,0 +1,354 @@
1
+ import warnings
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch.utils.checkpoint
5
+ import transformers
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
9
+ LlamaTokenizer)
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.utils import ModelOutput, logging
13
+
14
+ from .configuration_skywork_chat import SkyworkChatConfig
15
+ from .conversation import get_conv_template
16
+ from .modeling_skywork_vit import SkyworkVisionModel, has_flash_attn
17
+ from .modeling_skywork_lm2 import SkyworkLM2ForCausalLM
18
+
19
+ from transformers import Qwen2Config, Qwen2ForCausalLM
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ def version_cmp(v1, v2, op='eq'):
25
+ import operator
26
+
27
+ from packaging import version
28
+ op_func = getattr(operator, op)
29
+ return op_func(version.parse(v1), version.parse(v2))
30
+
31
+
32
+ class SkyworkChatModel(PreTrainedModel):
33
+ config_class = SkyworkChatConfig
34
+ main_input_name = 'pixel_values'
35
+ base_model_prefix = 'language_model'
36
+ _supports_flash_attn_2 = True
37
+ _no_split_modules = ['SkyworkVisionModel', 'LlamaDecoderLayer', 'SkyworkLM2DecoderLayer']
38
+
39
+ def __init__(self, config: SkyworkChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
40
+ super().__init__(config)
41
+
42
+ assert version_cmp(transformers.__version__, '4.36.2', 'ge')
43
+ image_size = config.force_image_size or config.vision_config.image_size
44
+ patch_size = config.vision_config.patch_size
45
+ self.patch_size = patch_size
46
+ self.select_layer = config.select_layer
47
+ self.template = config.template
48
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
49
+ self.downsample_ratio = config.downsample_ratio
50
+ self.ps_version = config.ps_version
51
+ use_flash_attn = use_flash_attn if has_flash_attn else False
52
+ config.vision_config.use_flash_attn = True if use_flash_attn else False
53
+ config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
54
+
55
+ logger.info(f'num_image_token: {self.num_image_token}')
56
+ logger.info(f'ps_version: {self.ps_version}')
57
+ if vision_model is not None:
58
+ self.vision_model = vision_model
59
+ else:
60
+ self.vision_model = SkyworkVisionModel(config.vision_config)
61
+ if language_model is not None:
62
+ self.language_model = language_model
63
+ else:
64
+ if config.llm_config.architectures[0] == 'LlamaForCausalLM':
65
+ self.language_model = LlamaForCausalLM(config.llm_config)
66
+ elif config.llm_config.architectures[0] == 'SkyworkLM2ForCausalLM':
67
+ self.language_model = SkyworkLM2ForCausalLM(config.llm_config)
68
+ elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
69
+ self.language_model = Qwen2ForCausalLM(config.llm_config)
70
+ else:
71
+ raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
72
+
73
+ vit_hidden_size = config.vision_config.hidden_size
74
+ llm_hidden_size = config.llm_config.hidden_size
75
+
76
+ self.mlp1 = nn.Sequential(
77
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
78
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
79
+ nn.GELU(),
80
+ nn.Linear(llm_hidden_size, llm_hidden_size)
81
+ )
82
+
83
+ self.img_context_token_id = None
84
+ self.conv_template = get_conv_template(self.template)
85
+ self.system_message = self.conv_template.system_message
86
+
87
+ def forward(
88
+ self,
89
+ pixel_values: torch.FloatTensor,
90
+ input_ids: torch.LongTensor = None,
91
+ attention_mask: Optional[torch.Tensor] = None,
92
+ position_ids: Optional[torch.LongTensor] = None,
93
+ image_flags: Optional[torch.LongTensor] = None,
94
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
95
+ labels: Optional[torch.LongTensor] = None,
96
+ use_cache: Optional[bool] = None,
97
+ output_attentions: Optional[bool] = None,
98
+ output_hidden_states: Optional[bool] = None,
99
+ return_dict: Optional[bool] = None,
100
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
101
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
102
+
103
+ image_flags = image_flags.squeeze(-1)
104
+ input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
105
+
106
+ vit_embeds = self.extract_feature(pixel_values)
107
+ vit_embeds = vit_embeds[image_flags == 1]
108
+ vit_batch_size = pixel_values.shape[0]
109
+
110
+ B, N, C = input_embeds.shape
111
+ input_embeds = input_embeds.reshape(B * N, C)
112
+
113
+ if not (torch.distributed.is_available() and torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0:
114
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
115
+
116
+ input_ids = input_ids.reshape(B * N)
117
+ selected = (input_ids == self.img_context_token_id)
118
+ try:
119
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
120
+ except Exception as e:
121
+ vit_embeds = vit_embeds.reshape(-1, C)
122
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
123
+ f'vit_embeds.shape={vit_embeds.shape}')
124
+ n_token = selected.sum()
125
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
126
+
127
+ input_embeds = input_embeds.reshape(B, N, C)
128
+
129
+ outputs = self.language_model(
130
+ inputs_embeds=input_embeds,
131
+ attention_mask=attention_mask,
132
+ position_ids=position_ids,
133
+ past_key_values=past_key_values,
134
+ use_cache=use_cache,
135
+ output_attentions=output_attentions,
136
+ output_hidden_states=output_hidden_states,
137
+ return_dict=return_dict,
138
+ )
139
+ logits = outputs.logits
140
+
141
+ loss = None
142
+ if labels is not None:
143
+ # Shift so that tokens < n predict n
144
+ shift_logits = logits[..., :-1, :].contiguous()
145
+ shift_labels = labels[..., 1:].contiguous()
146
+ # Flatten the tokens
147
+ loss_fct = CrossEntropyLoss()
148
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
149
+ shift_labels = shift_labels.view(-1)
150
+ # Enable model parallelism
151
+ shift_labels = shift_labels.to(shift_logits.device)
152
+ loss = loss_fct(shift_logits, shift_labels)
153
+
154
+ if not return_dict:
155
+ output = (logits,) + outputs[1:]
156
+ return (loss,) + output if loss is not None else output
157
+
158
+ return CausalLMOutputWithPast(
159
+ loss=loss,
160
+ logits=logits,
161
+ past_key_values=outputs.past_key_values,
162
+ hidden_states=outputs.hidden_states,
163
+ attentions=outputs.attentions,
164
+ )
165
+
166
+ def pixel_shuffle(self, x, scale_factor=0.5):
167
+ n, w, h, c = x.size()
168
+ # N, W, H, C --> N, W, H * scale, C // scale
169
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
170
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
171
+ x = x.permute(0, 2, 1, 3).contiguous()
172
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
173
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
174
+ int(c / (scale_factor * scale_factor)))
175
+ if self.ps_version == 'v1':
176
+ warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
177
+ 'which results in a transposed image.')
178
+ else:
179
+ x = x.permute(0, 2, 1, 3).contiguous()
180
+ return x
181
+
182
+ def extract_feature(self, pixel_values):
183
+ if self.select_layer == -1:
184
+ vit_embeds = self.vision_model(
185
+ pixel_values=pixel_values,
186
+ output_hidden_states=False,
187
+ return_dict=True).last_hidden_state
188
+ else:
189
+ vit_embeds = self.vision_model(
190
+ pixel_values=pixel_values,
191
+ output_hidden_states=True,
192
+ return_dict=True).hidden_states[self.select_layer]
193
+ vit_embeds = vit_embeds[:, 1:, :]
194
+
195
+ h = w = int(vit_embeds.shape[1] ** 0.5)
196
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
197
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
198
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
199
+ vit_embeds = self.mlp1(vit_embeds)
200
+ return vit_embeds
201
+
202
+ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
203
+ history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
204
+ IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
205
+ if history is not None or return_history:
206
+ print('Multi-turn chat is not yet supported in batch_chat.')
207
+ raise NotImplementedError
208
+
209
+ if image_counts is not None:
210
+ num_patches_list = image_counts
211
+ print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
212
+
213
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
214
+ self.img_context_token_id = img_context_token_id
215
+
216
+
217
+ if verbose and pixel_values is not None:
218
+ image_bs = pixel_values.shape[0]
219
+ print(f'dynamic ViT batch size: {image_bs}')
220
+
221
+ queries = []
222
+ for idx, num_patches in enumerate(num_patches_list):
223
+ question = questions[idx]
224
+ if pixel_values is not None and '<image>' not in question:
225
+ question = '<image>\n' + question
226
+ template = get_conv_template(self.template)
227
+ template.system_message = self.system_message
228
+ template.append_message(template.roles[0], question)
229
+ template.append_message(template.roles[1], None)
230
+ query = template.get_prompt()
231
+
232
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
233
+ query = query.replace('<image>', image_tokens, 1)
234
+ queries.append(query)
235
+
236
+ tokenizer.padding_side = 'left'
237
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
238
+ input_ids = model_inputs['input_ids'].to(self.device)
239
+ attention_mask = model_inputs['attention_mask'].to(self.device)
240
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
241
+ generation_config['eos_token_id'] = eos_token_id
242
+ generation_output = self.generate(
243
+ pixel_values=pixel_values,
244
+ input_ids=input_ids,
245
+ attention_mask=attention_mask,
246
+ **generation_config
247
+ )
248
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
249
+ responses = [response.split(template.sep.strip())[0].strip() for response in responses]
250
+ return responses
251
+
252
+ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
253
+ num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
254
+ verbose=False):
255
+
256
+ if history is None and pixel_values is not None and '<image>' not in question:
257
+ question = '<image>\n' + question
258
+
259
+ if num_patches_list is None:
260
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
261
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
262
+
263
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
264
+ self.img_context_token_id = img_context_token_id
265
+
266
+ template = get_conv_template(self.template)
267
+ template.system_message = self.system_message
268
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
269
+
270
+
271
+ history = [] if history is None else history
272
+ for (old_question, old_answer) in history:
273
+ template.append_message(template.roles[0], old_question)
274
+ template.append_message(template.roles[1], old_answer)
275
+ template.append_message(template.roles[0], question)
276
+ template.append_message(template.roles[1], None)
277
+ query = template.get_prompt()
278
+
279
+
280
+ if verbose and pixel_values is not None:
281
+ image_bs = pixel_values.shape[0]
282
+ print(f'dynamic ViT batch size: {image_bs}')
283
+
284
+ for num_patches in num_patches_list:
285
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
286
+ query = query.replace('<image>', image_tokens, 1)
287
+
288
+
289
+ model_inputs = tokenizer(query, return_tensors='pt')
290
+ input_ids = model_inputs['input_ids'].to(self.device)
291
+ attention_mask = model_inputs['attention_mask'].to(self.device)
292
+ generation_config['eos_token_id'] = eos_token_id
293
+ generation_output = self.generate(
294
+ pixel_values=pixel_values,
295
+ input_ids=input_ids,
296
+ attention_mask=attention_mask,
297
+ **generation_config
298
+ )
299
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
300
+ response = response.split(template.sep.strip())[0].strip()
301
+ history.append((question, response))
302
+
303
+ if return_history:
304
+ return response, history
305
+ else:
306
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
307
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
308
+ if verbose:
309
+ print(query_to_print, response)
310
+ return response
311
+
312
+ @torch.no_grad()
313
+ def generate(
314
+ self,
315
+ pixel_values: Optional[torch.FloatTensor] = None,
316
+ input_ids: Optional[torch.LongTensor] = None,
317
+ attention_mask: Optional[torch.LongTensor] = None,
318
+ visual_features: Optional[torch.FloatTensor] = None,
319
+ generation_config: Optional[GenerationConfig] = None,
320
+ output_hidden_states: Optional[bool] = None,
321
+ **generate_kwargs,
322
+ ) -> torch.LongTensor:
323
+
324
+ assert self.img_context_token_id is not None
325
+ if pixel_values is not None:
326
+ if visual_features is not None:
327
+ vit_embeds = visual_features
328
+ else:
329
+ vit_embeds = self.extract_feature(pixel_values)
330
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
331
+ B, N, C = input_embeds.shape
332
+ input_embeds = input_embeds.reshape(B * N, C)
333
+
334
+ input_ids = input_ids.reshape(B * N)
335
+ selected = (input_ids == self.img_context_token_id)
336
+
337
+ assert selected.sum() != 0
338
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
339
+
340
+ input_embeds = input_embeds.reshape(B, N, C)
341
+ else:
342
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
343
+
344
+
345
+ outputs = self.language_model.generate(
346
+ inputs_embeds=input_embeds,
347
+ attention_mask=attention_mask,
348
+ generation_config=generation_config,
349
+ output_hidden_states=output_hidden_states,
350
+ use_cache=True,
351
+ **generate_kwargs,
352
+ )
353
+
354
+ return outputs
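The `chat` method above replaces each `<image>` placeholder in the prompt with `<img>`, a run of `<IMG_CONTEXT>` tokens, and `</img>`. `generate` then overwrites the embeddings at those `<IMG_CONTEXT>` positions with the vision features. A minimal sketch of the string-expansion step follows. The helper name `expand_image_placeholders` and the small `num_image_token` value are illustrative assumptions, not part of the committed code.

```python
def expand_image_placeholders(query, num_patches_list, num_image_token,
                              img_start="<img>", img_end="</img>",
                              img_context="<IMG_CONTEXT>"):
    """Hypothetical helper mirroring the placeholder loop in `chat`."""
    for num_patches in num_patches_list:
        # One context token per visual feature vector: tokens-per-patch times patch count.
        image_tokens = img_start + img_context * (num_image_token * num_patches) + img_end
        # Replace only the first remaining placeholder so images are consumed in order.
        query = query.replace("<image>", image_tokens, 1)
    return query


# Toy values: one image split into 2 patches, 3 context tokens per patch.
prompt = expand_image_placeholders("<image>\nDescribe the picture.", [2], num_image_token=3)
```

The number of `<IMG_CONTEXT>` tokens in the prompt has to equal the number of feature vectors produced by the vision encoder, because `generate` assigns `vit_embeds.reshape(-1, C)` to exactly those positions.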
modeling_skywork_vit.py ADDED
@@ -0,0 +1,424 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from einops import rearrange
7
+ from timm.models.layers import DropPath
8
+ from torch import nn
9
+ from transformers.activations import ACT2FN
10
+ from transformers.modeling_outputs import (BaseModelOutput,
11
+ BaseModelOutputWithPooling)
12
+ from transformers.modeling_utils import PreTrainedModel
13
+ from transformers.utils import logging
14
+
15
+ from .configuration_skywork_vit import SkyworkVisionConfig
16
+
17
+ try:
18
+ from flash_attn.bert_padding import pad_input, unpad_input
19
+ from flash_attn.flash_attn_interface import \
20
+ flash_attn_varlen_qkvpacked_func
21
+ has_flash_attn = True
22
+ except Exception:
23
+ print('FlashAttention2 is not installed.')
24
+ has_flash_attn = False
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class FlashAttention(nn.Module):
30
+ """Implement the scaled dot product attention with softmax.
31
+ Arguments
32
+ ---------
33
+ softmax_scale: The temperature to use for the softmax attention.
34
+ (default: 1/sqrt(d_keys) where d_keys is computed at
35
+ runtime)
36
+ attention_dropout: The dropout rate to apply to the attention
37
+ (default: 0.0)
38
+ """
39
+
40
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
41
+ super().__init__()
42
+ self.softmax_scale = softmax_scale
43
+ self.dropout_p = attention_dropout
44
+
45
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
46
+ max_s=None, need_weights=False):
47
+ """Implements the multihead softmax attention.
48
+ Arguments
49
+ ---------
50
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
51
+ if unpadded: (nnz, 3, h, d)
52
+ key_padding_mask: a bool tensor of shape (B, S)
53
+ """
54
+ assert not need_weights
55
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
56
+ assert qkv.is_cuda
57
+
58
+ if cu_seqlens is None:
59
+ batch_size = qkv.shape[0]
60
+ seqlen = qkv.shape[1]
61
+ if key_padding_mask is None:
62
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
63
+ max_s = seqlen
64
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
65
+ device=qkv.device)
66
+ output = flash_attn_varlen_qkvpacked_func(
67
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
68
+ softmax_scale=self.softmax_scale, causal=causal
69
+ )
70
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
71
+ else:
72
+ nheads = qkv.shape[-2]
73
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
74
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
75
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
76
+ output_unpad = flash_attn_varlen_qkvpacked_func(
77
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
78
+ softmax_scale=self.softmax_scale, causal=causal
79
+ )
80
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
81
+ indices, batch_size, seqlen),
82
+ 'b s (h d) -> b s h d', h=nheads)
83
+ else:
84
+ assert max_s is not None
85
+ output = flash_attn_varlen_qkvpacked_func(
86
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
87
+ softmax_scale=self.softmax_scale, causal=causal
88
+ )
89
+
90
+ return output, None
91
+
92
+
93
+ class SkyworkRMSNorm(nn.Module):
94
+ def __init__(self, hidden_size, eps=1e-6):
95
+ super().__init__()
96
+ self.weight = nn.Parameter(torch.ones(hidden_size))
97
+ self.variance_epsilon = eps
98
+
99
+ def forward(self, hidden_states):
100
+ input_dtype = hidden_states.dtype
101
+ hidden_states = hidden_states.to(torch.float32)
102
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
103
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
104
+ return self.weight * hidden_states.to(input_dtype)
105
+
106
+
107
+ try:
108
+ from apex.normalization import FusedRMSNorm
109
+
110
+ SkyworkRMSNorm = FusedRMSNorm # noqa
111
+
112
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of SkyworkRMSNorm')
113
+ except ImportError:
114
+ # using the normal SkyworkRMSNorm
115
+ pass
116
+ except Exception:
117
+ logger.warning('discovered apex but it failed to load, falling back to SkyworkRMSNorm')
118
+ pass
119
+
120
+
121
+ NORM2FN = {
122
+ 'rms_norm': SkyworkRMSNorm,
123
+ 'layer_norm': nn.LayerNorm,
124
+ }
125
+
126
+
127
+ class SkyworkVisionEmbeddings(nn.Module):
128
+ def __init__(self, config: SkyworkVisionConfig):
129
+ super().__init__()
130
+ self.config = config
131
+ self.embed_dim = config.hidden_size
132
+ self.image_size = config.image_size
133
+ self.patch_size = config.patch_size
134
+
135
+ self.class_embedding = nn.Parameter(
136
+ torch.randn(1, 1, self.embed_dim),
137
+ )
138
+
139
+ self.patch_embedding = nn.Conv2d(
140
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
141
+ )
142
+
143
+ self.num_patches = (self.image_size // self.patch_size) ** 2
144
+ self.num_positions = self.num_patches + 1
145
+
146
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
147
+
148
+ def _get_pos_embed(self, pos_embed, H, W):
149
+ target_dtype = pos_embed.dtype
150
+ pos_embed = pos_embed.float().reshape(
151
+ 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
152
+ pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
153
+ reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
154
+ return pos_embed
155
+
156
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
157
+ target_dtype = self.patch_embedding.weight.dtype
158
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, height, width]
159
+ batch_size, _, height, width = patch_embeds.shape
160
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
161
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
162
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
163
+ position_embedding = torch.cat([
164
+ self.position_embedding[:, :1, :],
165
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
166
+ ], dim=1)
167
+ embeddings = embeddings + position_embedding.to(target_dtype)
168
+ return embeddings
169
+
170
+
171
+ class SkyworkAttention(nn.Module):
172
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
173
+
174
+ def __init__(self, config: SkyworkVisionConfig):
175
+ super().__init__()
176
+ self.config = config
177
+ self.embed_dim = config.hidden_size
178
+ self.num_heads = config.num_attention_heads
179
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
180
+ if config.use_flash_attn and not has_flash_attn:
181
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
182
+ self.head_dim = self.embed_dim // self.num_heads
183
+ if self.head_dim * self.num_heads != self.embed_dim:
184
+ raise ValueError(
185
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
186
+ f' {self.num_heads}).'
187
+ )
188
+
189
+ self.scale = self.head_dim ** -0.5
190
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
191
+ self.attn_drop = nn.Dropout(config.attention_dropout)
192
+ self.proj_drop = nn.Dropout(config.dropout)
193
+
194
+ self.qk_normalization = config.qk_normalization
195
+
196
+ if self.qk_normalization:
197
+ self.q_norm = SkyworkRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
198
+ self.k_norm = SkyworkRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
199
+
200
+ if self.use_flash_attn:
201
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
202
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
203
+
204
+ def _naive_attn(self, x):
205
+ B, N, C = x.shape
206
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
207
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
208
+
209
+ if self.qk_normalization:
210
+ B_, H_, N_, D_ = q.shape
211
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
212
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
213
+
214
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
215
+ attn = attn.softmax(dim=-1)
216
+ attn = self.attn_drop(attn)
217
+
218
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
219
+ x = self.proj(x)
220
+ x = self.proj_drop(x)
221
+ return x
222
+
223
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
224
+ qkv = self.qkv(x)
225
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
226
+
227
+ if self.qk_normalization:
228
+ q, k, v = qkv.unbind(2)
229
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
230
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
231
+ qkv = torch.stack([q, k, v], dim=2)
232
+
233
+ context, _ = self.inner_attn(
234
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
235
+ )
236
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
237
+ outs = self.proj_drop(outs)
238
+ return outs
239
+
240
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
241
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
242
+ return x
243
+
244
+
245
+ class SkyworkMLP(nn.Module):
246
+ def __init__(self, config: SkyworkVisionConfig):
247
+ super().__init__()
248
+ self.config = config
249
+ self.act = ACT2FN[config.hidden_act]
250
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
251
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
252
+
253
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
254
+ hidden_states = self.fc1(hidden_states)
255
+ hidden_states = self.act(hidden_states)
256
+ hidden_states = self.fc2(hidden_states)
257
+ return hidden_states
258
+
259
+
260
+ class SkyworkVisionEncoderLayer(nn.Module):
261
+ def __init__(self, config: SkyworkVisionConfig, drop_path_rate: float):
262
+ super().__init__()
263
+ self.embed_dim = config.hidden_size
264
+ self.intermediate_size = config.intermediate_size
265
+ self.norm_type = config.norm_type
266
+
267
+ self.attn = SkyworkAttention(config)
268
+ self.mlp = SkyworkMLP(config)
269
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
270
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
271
+
272
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
273
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
274
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
275
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
276
+
277
+ def forward(
278
+ self,
279
+ hidden_states: torch.Tensor,
280
+ ) -> torch.Tensor:
281
+ """
282
+ Args:
283
+ hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
284
+ """
285
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
286
+
287
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
288
+
289
+ return hidden_states
290
+
291
+
292
+ class SkyworkVisionEncoder(nn.Module):
293
+ """
294
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
295
+ [`SkyworkVisionEncoderLayer`].
296
+
297
+ Args:
298
+ config (`SkyworkVisionConfig`):
299
+ The corresponding vision configuration for the `SkyworkVisionEncoder`.
300
+ """
301
+
302
+ def __init__(self, config: SkyworkVisionConfig):
303
+ super().__init__()
304
+ self.config = config
305
+ # stochastic depth decay rule
306
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
307
+ self.layers = nn.ModuleList([
308
+ SkyworkVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
309
+ self.gradient_checkpointing = True
310
+
311
+ def forward(
312
+ self,
313
+ inputs_embeds,
314
+ output_hidden_states: Optional[bool] = None,
315
+ return_dict: Optional[bool] = None,
316
+ ) -> Union[Tuple, BaseModelOutput]:
317
+ r"""
318
+ Args:
319
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
320
+ Embedded representation of the inputs. Should be float, not int tokens.
321
+ output_hidden_states (`bool`, *optional*):
322
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
323
+ for more detail.
324
+ return_dict (`bool`, *optional*):
325
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
326
+ """
327
+ output_hidden_states = (
328
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
329
+ )
330
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
331
+
332
+ encoder_states = () if output_hidden_states else None
333
+ hidden_states = inputs_embeds
334
+
335
+ for idx, encoder_layer in enumerate(self.layers):
336
+ if output_hidden_states:
337
+ encoder_states = encoder_states + (hidden_states,)
338
+ if self.gradient_checkpointing and self.training:
339
+ layer_outputs = torch.utils.checkpoint.checkpoint(
340
+ encoder_layer,
341
+ hidden_states)
342
+ else:
343
+ layer_outputs = encoder_layer(
344
+ hidden_states,
345
+ )
346
+ hidden_states = layer_outputs
347
+
348
+ if output_hidden_states:
349
+ encoder_states = encoder_states + (hidden_states,)
350
+
351
+ if not return_dict:
352
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
353
+ return BaseModelOutput(
354
+ last_hidden_state=hidden_states, hidden_states=encoder_states
355
+ )
356
+
357
+
358
+ class SkyworkVisionModel(PreTrainedModel):
359
+ main_input_name = 'pixel_values'
360
+ _supports_flash_attn_2 = True
361
+ config_class = SkyworkVisionConfig
362
+ _no_split_modules = ['SkyworkVisionEncoderLayer']
363
+
364
+ def __init__(self, config: SkyworkVisionConfig):
365
+ super().__init__(config)
366
+ self.config = config
367
+
368
+ self.embeddings = SkyworkVisionEmbeddings(config)
369
+ self.encoder = SkyworkVisionEncoder(config)
370
+
371
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
372
+ pos_emb = self.embeddings.position_embedding
373
+ _, num_positions, embed_dim = pos_emb.shape
374
+ cls_emb = pos_emb[:, :1, :]
375
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
376
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
377
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
378
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
379
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
380
+ self.embeddings.image_size = new_size
381
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
382
+
383
+ def get_input_embeddings(self):
384
+ return self.embeddings
385
+
386
+ def forward(
387
+ self,
388
+ pixel_values: Optional[torch.FloatTensor] = None,
389
+ output_hidden_states: Optional[bool] = None,
390
+ return_dict: Optional[bool] = None,
391
+ pixel_embeds: Optional[torch.FloatTensor] = None,
392
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
393
+ output_hidden_states = (
394
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
395
+ )
396
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
397
+
398
+ if pixel_values is None and pixel_embeds is None:
399
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
400
+
401
+ if pixel_embeds is not None:
402
+ hidden_states = pixel_embeds
403
+ else:
404
+ if len(pixel_values.shape) == 4:
405
+ hidden_states = self.embeddings(pixel_values)
406
+ else:
407
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
408
+ encoder_outputs = self.encoder(
409
+ inputs_embeds=hidden_states,
410
+ output_hidden_states=output_hidden_states,
411
+ return_dict=return_dict,
412
+ )
413
+ last_hidden_state = encoder_outputs.last_hidden_state
414
+ pooled_output = last_hidden_state[:, 0, :]
415
+
416
+ if not return_dict:
417
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
418
+
419
+ return BaseModelOutputWithPooling(
420
+ last_hidden_state=last_hidden_state,
421
+ pooler_output=pooled_output,
422
+ hidden_states=encoder_outputs.hidden_states,
423
+ attentions=encoder_outputs.attentions,
424
+ )
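`SkyworkRMSNorm` in the file above scales each feature vector by the reciprocal of its root mean square and then applies a learned per-channel weight. A dependency-free sketch of the same arithmetic on a single vector follows. The function name `rms_norm` and the plain-list inputs are assumptions for illustration; the committed module operates on `torch` tensors.

```python
import math


def rms_norm(values, weight, eps=1e-6):
    """Hypothetical scalar version of the RMSNorm forward pass."""
    # Mean of squares over the feature dimension (the "variance" in the module).
    variance = sum(v * v for v in values) / len(values)
    # Equivalent to torch.rsqrt(variance + eps).
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * v * inv_rms for v, w in zip(values, weight)]


# With unit weights the output has a mean square of roughly 1.
normed = rms_norm([3.0, 4.0], [1.0, 1.0])
```

Unlike LayerNorm, no mean is subtracted, so the direction of the vector is preserved and only its scale changes.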
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "crop_size": 448,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "feature_extractor_type": "CLIPFeatureExtractor",
7
+ "image_mean": [
8
+ 0.485,
9
+ 0.456,
10
+ 0.406
11
+ ],
12
+ "image_std": [
13
+ 0.229,
14
+ 0.224,
15
+ 0.225
16
+ ],
17
+ "resample": 3,
18
+ "size": 448
19
+ }
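The `image_mean` and `image_std` values in this config are the usual ImageNet channel statistics. A minimal sketch of per-pixel normalization with them follows, assuming 8-bit RGB input is first rescaled to `[0, 1]`. That rescaling step and the helper name `normalize_pixel` are assumptions; the config itself only states that normalization is enabled.

```python
IMAGE_MEAN = (0.485, 0.456, 0.406)  # from preprocessor_config.json
IMAGE_STD = (0.229, 0.224, 0.225)   # from preprocessor_config.json


def normalize_pixel(rgb):
    """Hypothetical helper: map one 8-bit RGB pixel to normalized floats."""
    # Assumed pipeline: scale to [0, 1], then subtract the mean and divide by the std.
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, IMAGE_MEAN, IMAGE_STD))


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
```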
special_tokens_map.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|object_ref_start|>",
4
+ "<|object_ref_end|>",
5
+ "<|box_start|>",
6
+ "<|box_end|>",
7
+ "<|quad_start|>",
8
+ "<|quad_end|>",
9
+ "<|vision_start|>",
10
+ "<|vision_end|>",
11
+ "<|vision_pad|>",
12
+ "<|image_pad|>",
13
+ "<|video_pad|>",
14
+ "<img>",
15
+ "</img>",
16
+ "<IMG_CONTEXT>",
17
+ "<quad>",
18
+ "</quad>",
19
+ "<ref>",
20
+ "</ref>",
21
+ "<box>",
22
+ "</box>",
23
+ "<|begin▁of▁sentence|>",
24
+ "<|end▁of▁sentence|>"
25
+ ],
26
+ "bos_token": {
27
+ "content": "<|begin▁of▁sentence|>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ },
33
+ "eos_token": {
34
+ "content": "<|end▁of▁sentence|>",
35
+ "lstrip": false,
36
+ "normalized": false,
37
+ "rstrip": false,
38
+ "single_word": false
39
+ },
40
+ "pad_token": {
41
+ "content": "<|endoftext|>",
42
+ "lstrip": false,
43
+ "normalized": false,
44
+ "rstrip": false,
45
+ "single_word": false
46
+ }
47
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|begin▁of▁sentence|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|end▁of▁sentence|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 16384,
23
+ "pad_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<|end▁of▁sentence|>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "sp_model_kwargs": {},
32
+ "unk_token": null,
33
+ "tokenizer_class": "LlamaTokenizerFast",
34
+ "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool 
%}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}"
35
+ }
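For plain user and assistant turns, the `chat_template` above emits the BOS token, the system prompt, `<|User|>` and `<|Assistant|>` markers, and drops everything up to `</think>` from earlier assistant messages. A simplified sketch of that behavior in plain Python follows. It deliberately ignores the tool-call branches, and the function name `render_chat` is an assumption for illustration, not an API of the tokenizer.

```python
BOS = "<|begin▁of▁sentence|>"
EOS = "<|end▁of▁sentence|>"


def render_chat(messages, add_generation_prompt=False):
    """Hypothetical re-implementation of the template's non-tool branches."""
    system_prompt = ""
    for message in messages:
        if message["role"] == "system":
            system_prompt = message["content"]  # the last system message wins
    out = BOS + system_prompt
    for message in messages:
        if message["role"] == "user":
            out += "<|User|>" + message["content"]
        elif message["role"] == "assistant" and message["content"] is not None:
            content = message["content"]
            if "</think>" in content:
                # Keep only the final answer; earlier reasoning is not replayed.
                content = content.split("</think>")[-1]
            out += "<|Assistant|>" + content + EOS
    if add_generation_prompt:
        out += "<|Assistant|><think>\n"
    return out


history = [
    {"role": "user", "content": "2+2?"},
    {"role": "assistant", "content": "<think>add</think>4"},
    {"role": "user", "content": "And 3+3?"},
]
rendered = render_chat(history, add_generation_prompt=True)
```

Because the generation prompt ends with `<think>\n`, the model starts every new reply inside a reasoning block.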
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
zero_to_fp32.py ADDED
@@ -0,0 +1,604 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from ZeRO stage 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example: python zero_to_fp32.py . pytorch_model.bin
14
+
15
+ import argparse
16
+ import torch
17
+ import glob
18
+ import math
19
+ import os
20
+ import re
21
+ from collections import OrderedDict
22
+ from dataclasses import dataclass
23
+
24
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
25
+ # DeepSpeed data structures it has to be available in the current python environment.
26
+ from deepspeed.utils import logger
27
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
28
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
29
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
30
+
31
+
32
+ @dataclass
33
+ class zero_model_state:
34
+ buffers: dict
35
+ param_shapes: dict
36
+ shared_params: list
37
+ ds_version: int
38
+ frozen_param_shapes: dict
39
+ frozen_param_fragments: dict
40
+
41
+
42
+ debug = 0
43
+
44
+ # load to cpu
45
+ device = torch.device('cpu')
46
+
47
+
48
+ def atoi(text):
49
+ return int(text) if text.isdigit() else text
50
+
51
+
52
+ def natural_keys(text):
53
+ '''
54
+ alist.sort(key=natural_keys) sorts in human order
55
+ http://nedbatchelder.com/blog/200712/human_sorting.html
56
+ (See Toothy's implementation in the comments)
57
+ '''
58
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
59
+
60
+
61
+ def get_model_state_file(checkpoint_dir, zero_stage):
62
+ if not os.path.isdir(checkpoint_dir):
63
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
64
+
65
+ # there should be only one file
66
+ if zero_stage <= 2:
67
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
68
+ elif zero_stage == 3:
69
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
70
+
71
+ if not os.path.exists(file):
72
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
73
+
74
+ return file
75
+
76
+
77
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
78
+ # XXX: need to test that this simple glob rule works for multi-node setup too
79
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
80
+
81
+ if len(ckpt_files) == 0:
82
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
83
+
84
+ return ckpt_files
85
+
86
+
87
+ def get_optim_files(checkpoint_dir):
88
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
89
+
90
+
91
+ def get_model_state_files(checkpoint_dir):
92
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
93
+
94
+
95
+ def parse_model_states(files):
96
+ zero_model_states = []
97
+ for file in files:
98
+ state_dict = torch.load(file, map_location=device)
99
+
100
+ if BUFFER_NAMES not in state_dict:
101
+ raise ValueError(f"{file} is not a model state checkpoint")
102
+ buffer_names = state_dict[BUFFER_NAMES]
103
+ if debug:
104
+ print("Found buffers:", buffer_names)
105
+
106
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
107
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
108
+ param_shapes = state_dict[PARAM_SHAPES]
109
+
110
+ # collect parameters that are included in param_shapes
111
+ param_names = []
112
+ for s in param_shapes:
113
+ for name in s.keys():
114
+ param_names.append(name)
115
+
116
+ # update with frozen parameters
117
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
118
+ if frozen_param_shapes is not None:
119
+ if debug:
120
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
121
+ param_names += list(frozen_param_shapes.keys())
122
+
123
+ # handle shared params
124
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
125
+
126
+ ds_version = state_dict.get(DS_VERSION, None)
127
+
128
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
129
+
130
+ z_model_state = zero_model_state(buffers=buffers,
131
+ param_shapes=param_shapes,
132
+ shared_params=shared_params,
133
+ ds_version=ds_version,
134
+ frozen_param_shapes=frozen_param_shapes,
135
+ frozen_param_fragments=frozen_param_fragments)
136
+ zero_model_states.append(z_model_state)
137
+
138
+ return zero_model_states
139
+
140
+
141
+ def parse_optim_states(files, ds_checkpoint_dir):
142
+
143
+ total_files = len(files)
144
+ state_dicts = []
145
+ for f in files:
146
+ state_dict = torch.load(f, map_location=device)
147
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
148
+ # and also handle the case where it was already removed by another helper script
149
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
150
+ state_dicts.append(state_dict)
151
+
152
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
153
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
154
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
155
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
156
+
157
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
158
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
159
+ # use the max of the partition_count to get the dp world_size.
160
+
161
+ if isinstance(world_size, list):
162
+ world_size = max(world_size)
163
+
164
+ if world_size != total_files:
165
+ raise ValueError(
166
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
167
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
168
+ )
169
+
170
+ # the groups are named differently in each stage
171
+ if zero_stage <= 2:
172
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
173
+ elif zero_stage == 3:
174
+ fp32_groups_key = FP32_FLAT_GROUPS
175
+ else:
176
+ raise ValueError(f"unknown zero stage {zero_stage}")
177
+
178
+ if zero_stage <= 2:
179
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
180
+ elif zero_stage == 3:
181
+ # if there is more than one param group, there will be multiple flattened tensors - one
182
+ # flattened tensor per group - for simplicity merge them into a single tensor
183
+ #
184
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
185
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
186
+
187
+ fp32_flat_groups = [
188
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
189
+ ]
190
+
191
+ return zero_stage, world_size, fp32_flat_groups
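The world-size check in `parse_optim_states` can be tried in isolation. The sketch below is illustrative only: `resolve_world_size` is a hypothetical helper, not part of the script, and it takes plain integers instead of loaded optimizer state dicts.

```python
def resolve_world_size(partition_count, total_files):
    """Hypothetical helper mirroring the check in `parse_optim_states`."""
    # ZeRO-2 may report one partition count per param group; take the largest.
    if isinstance(partition_count, list):
        world_size = max(partition_count)
    else:
        world_size = partition_count
    # One '*_optim_states.pt' file is expected per data-parallel rank.
    if world_size != total_files:
        raise ValueError(f"Expected {world_size} optimizer state files but found {total_files}")
    return world_size


print(resolve_world_size([4, 2], total_files=4))
print(resolve_world_size(8, total_files=8))
```

A mismatch between the two numbers raises `ValueError`, which is how the script reports a partially saved or overwritten checkpoint.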
+
+
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
+     """
+     Returns fp32 state_dict reconstructed from ds checkpoint
+
+     Args:
+         - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+         - ``exclude_frozen_parameters``: exclude frozen parameters
+
+     """
+     print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+     optim_files = get_optim_files(ds_checkpoint_dir)
+     zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+     print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+     model_files = get_model_state_files(ds_checkpoint_dir)
+
+     zero_model_states = parse_model_states(model_files)
+     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+     if zero_stage <= 2:
+         return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                           exclude_frozen_parameters)
+     elif zero_stage == 3:
+         return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                           exclude_frozen_parameters)
+
+
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
+     if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+         return
+
+     frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+     frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+     if debug:
+         num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+         print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+         wanted_params = len(frozen_param_shapes)
+         wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+         avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+         print(f'Frozen params: Have {avail_numel} numels to process.')
+         print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+     total_params = 0
+     total_numel = 0
+     for name, shape in frozen_param_shapes.items():
+         total_params += 1
+         unpartitioned_numel = shape.numel()
+         total_numel += unpartitioned_numel
+
+         state_dict[name] = frozen_param_fragments[name]
+
+         if debug:
+             print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+ def _has_callable(obj, fn):
+     attr = getattr(obj, fn, None)
+     return callable(attr)
+
+
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+     param_shapes = zero_model_states[0].param_shapes
+
+     # Reconstruction protocol:
+     #
+     # XXX: document this
+
+     if debug:
+         for i in range(world_size):
+             for j in range(len(fp32_flat_groups[0])):
+                 print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+     # XXX: memory usage doubles here (zero2)
+     num_param_groups = len(fp32_flat_groups[0])
+     merged_single_partition_of_fp32_groups = []
+     for i in range(num_param_groups):
+         merged_partitions = [sd[i] for sd in fp32_flat_groups]
+         full_single_fp32_vector = torch.cat(merged_partitions, 0)
+         merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+     avail_numel = sum(
+         [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+     if debug:
+         wanted_params = sum([len(shapes) for shapes in param_shapes])
+         wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+         # not asserting if there is a mismatch due to possible padding
+         print(f"Have {avail_numel} numels to process.")
+         print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+     # params
+     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+     # out-of-core computing solution
+     total_numel = 0
+     total_params = 0
+     for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+         offset = 0
+         avail_numel = full_single_fp32_vector.numel()
+         for name, shape in shapes.items():
+
+             unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+             total_numel += unpartitioned_numel
+             total_params += 1
+
+             if debug:
+                 print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+             state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+             offset += unpartitioned_numel
+
+         # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+         # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+         # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+         # live optimizer object, so we are checking that the numbers are within the right range
+         align_to = 2 * world_size
+
+         def zero2_align(x):
+             return align_to * math.ceil(x / align_to)
+
+         if debug:
+             print(f"original offset={offset}, avail_numel={avail_numel}")
+
+         offset = zero2_align(offset)
+         avail_numel = zero2_align(avail_numel)
+
+         if debug:
+             print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+         # Sanity check
+         if offset != avail_numel:
+             raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
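The ZeRO-2 merge above concatenates each param group's per-rank partitions, slices parameters out by a running offset, and only compares the consumed and available counts after rounding both up to a multiple of `2 * world_size`. A minimal sketch of that idea follows, using plain Python lists as stand-ins for tensors. `merge_group` and its arguments are hypothetical names chosen for illustration, and `zero2_align` here takes `world_size` explicitly instead of closing over it.

```python
import math


def zero2_align(x, world_size):
    """Round x up to the next multiple of 2 * world_size."""
    align_to = 2 * world_size
    return align_to * math.ceil(x / align_to)


def merge_group(partitions, shapes):
    """Concatenate one param group's rank partitions and slice params by offset.

    `partitions` is a list of flat lists (one per rank); `shapes` maps a
    parameter name to its element count. Both are toy stand-ins for tensors.
    """
    flat = [x for part in partitions for x in part]
    params, offset = {}, 0
    for name, numel in shapes.items():
        params[name] = flat[offset:offset + numel]
        offset += numel
    world_size = len(partitions)
    # Compare only after alignment, since the flat buffer may carry padding.
    if zero2_align(offset, world_size) != zero2_align(len(flat), world_size):
        raise ValueError(f"consumed {offset} numels out of {len(flat)}")
    return params


# Two ranks, 5 real elements plus 3 padding zeros (8 is a multiple of 2 * 2).
partitions = [[1.0, 2.0, 3.0, 4.0], [5.0, 0.0, 0.0, 0.0]]
merged = merge_group(partitions, {"w": 3, "b": 2})
print(merged)
```

Comparing aligned values rather than raw counts is what lets the check tolerate trailing padding without knowing its exact size.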
+
+
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                exclude_frozen_parameters):
+     state_dict = OrderedDict()
+
+     # buffers
+     buffers = zero_model_states[0].buffers
+     state_dict.update(buffers)
+     if debug:
+         print(f"added {len(buffers)} buffers")
+
+     if not exclude_frozen_parameters:
+         _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+     _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+     # recover shared parameters
+     for pair in zero_model_states[0].shared_params:
+         if pair[1] in state_dict:
+             state_dict[pair[0]] = state_dict[pair[1]]
+
+     return state_dict
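The shared-parameter recovery at the end of this function can be illustrated with a toy dictionary. The parameter names below are made up for the example; only the `[alias, source]` pair layout and the membership test come from the code above.

```python
from collections import OrderedDict

state_dict = OrderedDict({"embed.weight": [0.1, 0.2]})
# Hypothetical pairs in the same [alias, source] layout as `shared_params`.
shared_params = [["lm_head.weight", "embed.weight"], ["missing.alias", "not.saved"]]

for alias, source in shared_params:
    if source in state_dict:
        # The alias points at the same object; no copy is made.
        state_dict[alias] = state_dict[source]

print(list(state_dict.keys()))
```

Pairs whose source was never reconstructed are skipped silently, so an alias can be absent from the final state dict.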
+
+
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+     remainder = unpartitioned_numel % world_size
+     padding_numel = (world_size - remainder) if remainder else 0
+     partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+     return partitioned_numel, padding_numel
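A few concrete numbers make the partition arithmetic easier to follow. The snippet below repeats the function so it runs on its own; the input values are arbitrary examples, not taken from any real checkpoint.

```python
import math


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    # Same arithmetic as the function above, repeated so the sketch runs alone.
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


for numel, world_size in [(10, 4), (12, 4), (3, 8)]:
    per_rank, padding = zero3_partitioned_param_info(numel, world_size)
    # Invariant: the per-rank slices together hold the real elements plus padding.
    assert per_rank * world_size == numel + padding
    print(f"numel={numel} world_size={world_size} -> per_rank={per_rank} padding={padding}")
```

The invariant in the loop is why the merge code can concatenate the per-rank slices and then trim the result back to `unpartitioned_numel`.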
+
+
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+     if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+         return
+
+     if debug:
+         for i in range(world_size):
+             num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+             print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+         frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+         wanted_params = len(frozen_param_shapes)
+         wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+         avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+         print(f'Frozen params: Have {avail_numel} numels to process.')
+         print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+     total_params = 0
+     total_numel = 0
+     for name, shape in zero_model_states[0].frozen_param_shapes.items():
+         total_params += 1
+         unpartitioned_numel = shape.numel()
+         total_numel += unpartitioned_numel
+
+         param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+         state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+         if debug:
+             print(
+                 f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+             )
+
+     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+     param_shapes = zero_model_states[0].param_shapes
+     avail_numel = fp32_flat_groups[0].numel() * world_size
+     # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+     # param, re-consolidating each param, while dealing with padding if any
+
+     # merge list of dicts, preserving order
+     param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+     if debug:
+         for i in range(world_size):
+             print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+         wanted_params = len(param_shapes)
+         wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+         # not asserting if there is a mismatch due to possible padding
+         avail_numel = fp32_flat_groups[0].numel() * world_size
+         print(f"Trainable params: Have {avail_numel} numels to process.")
+         print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+     # params
+     # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+     # out-of-core computing solution
+     offset = 0
+     total_numel = 0
+     total_params = 0
+     for name, shape in param_shapes.items():
+
+         unpartitioned_numel = shape.numel()
+         total_numel += unpartitioned_numel
+         total_params += 1
+
+         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+         if debug:
+             print(
+                 f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+             )
+
+         # XXX: memory usage doubles here
+         state_dict[name] = torch.cat(
+             tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
+             0).narrow(0, 0, unpartitioned_numel).view(shape)
+         offset += partitioned_numel
+
+     offset *= world_size
+
+     # Sanity check
+     if offset != avail_numel:
+         raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+     print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
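The ZeRO-3 "zip the partitions together" step can be sketched without PyTorch. The example below is a simplified illustration that uses Python lists in place of flat fp32 tensors. `merge_zero3` is a hypothetical name, and the consumption check assumes every rank buffer has the same length, as the script does when it multiplies rank 0's size by `world_size`.

```python
import math


def merge_zero3(rank_buffers, shapes):
    """Rebuild each parameter from per-rank flat buffers (toy lists, not tensors).

    Every rank stores ceil(numel / world_size) elements per parameter, back to
    back, with padding at the end of the last rank's slice.
    """
    world_size = len(rank_buffers)
    params, offset = {}, 0
    for name, numel in shapes.items():
        per_rank = math.ceil(numel / world_size)
        # Take the same [offset, offset + per_rank) window from every rank ...
        joined = [x for buf in rank_buffers for x in buf[offset:offset + per_rank]]
        # ... and drop the trailing padding.
        params[name] = joined[:numel]
        offset += per_rank
    # Analogous to the script's sanity check: all buffer elements were consumed.
    if offset * world_size != sum(len(buf) for buf in rank_buffers):
        raise ValueError("buffers were not fully consumed")
    return params


# Param "a" has 3 elements, "b" has 2; world_size is 2, and 0 marks padding.
rank0 = [1, 2, 4]  # a[0:2], then b[0:1]
rank1 = [3, 0, 5]  # a[2:3] plus one pad, then b[1:2]
rebuilt = merge_zero3([rank0, rank1], {"a": 3, "b": 2})
print(rebuilt)
```

The key difference from ZeRO-2 is that the offset advances by the per-rank slice size, not by the full parameter size, because each parameter is split across all ranks.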
+
+
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+                                                exclude_frozen_parameters):
+     state_dict = OrderedDict()
+
+     # buffers
+     buffers = zero_model_states[0].buffers
+     state_dict.update(buffers)
+     if debug:
+         print(f"added {len(buffers)} buffers")
+
+     if not exclude_frozen_parameters:
+         _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+     _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+     # recover shared parameters
+     for pair in zero_model_states[0].shared_params:
+         if pair[1] in state_dict:
+             state_dict[pair[0]] = state_dict[pair[1]]
+
+     return state_dict
+
+
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
+     """
+     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+     via a model hub.
+
+     Args:
+         - ``checkpoint_dir``: path to the desired checkpoint folder
+         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+         - ``exclude_frozen_parameters``: exclude frozen parameters
+
+     Returns:
+         - pytorch ``state_dict``
+
+     Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+     you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+     the checkpoint.
+
+     A typical usage might be ::
+
+         from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+         # do the training and checkpoint saving
+         state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+         model = model.cpu() # move to cpu
+         model.load_state_dict(state_dict)
+         # submit to model hub or save the model to share with others
+
+     In this example the ``model`` will no longer be usable in the deepspeed context of the same
+     application. i.e. you will need to re-initialize the deepspeed engine, since
+     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+     If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+     """
+     if tag is None:
+         latest_path = os.path.join(checkpoint_dir, 'latest')
+         if os.path.isfile(latest_path):
+             with open(latest_path, 'r') as fd:
+                 tag = fd.read().strip()
+         else:
+             raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+     ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+     if not os.path.isdir(ds_checkpoint_dir):
+         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
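The tag lookup at the top of this function is easy to exercise against a throwaway directory. In the sketch below, `resolve_tag` is a hypothetical helper that reproduces only the path logic, and the folder layout is created in a temporary directory for illustration. No checkpoint files are read.

```python
import os
import tempfile


def resolve_tag(checkpoint_dir, tag=None):
    """Return the checkpoint sub-folder, reading the tag from 'latest' if needed."""
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, "latest")
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, "r") as fd:
            tag = fd.read().strip()
    return os.path.join(checkpoint_dir, tag)


with tempfile.TemporaryDirectory() as root:
    # Mimic the layout described in the docstring: a tag folder plus a 'latest' file.
    os.mkdir(os.path.join(root, "global_step14"))
    with open(os.path.join(root, "latest"), "w") as fd:
        fd.write("global_step14\n")
    resolved = resolve_tag(root)
    print(os.path.basename(resolved))
```

Passing `tag` explicitly skips the `latest` file entirely, which is useful when converting an older checkpoint from the same folder.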
+
+
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+     """
+     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+     Args:
+         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+         - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+         - ``exclude_frozen_parameters``: exclude frozen parameters
+     """
+
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
+     print(f"Saving fp32 state dict to {output_file}")
+     torch.save(state_dict, output_file)
+
+
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+     """
+     1. Put the provided model to cpu
+     2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+     3. Load it into the provided model
+
+     Args:
+         - ``model``: the model object to update
+         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+     Returns:
+         - ``model``: modified model
+
+     Make sure you have plenty of CPU memory available before you call this function. If you don't
+     have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+     conveniently placed for you in the checkpoint folder.
+
+     A typical usage might be ::
+
+         from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+         model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+         # submit to model hub or save the model to share with others
+
+     Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
+     of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+     ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+     """
+     logger.info("Extracting fp32 weights")
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+     logger.info("Overwriting model with fp32 weights")
+     model = model.cpu()
+     model.load_state_dict(state_dict, strict=False)
+
+     return model
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("checkpoint_dir",
+                         type=str,
+                         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+     parser.add_argument(
+         "output_file",
+         type=str,
+         help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+     parser.add_argument("-t",
+                         "--tag",
+                         type=str,
+                         default=None,
+                         help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+     parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
+     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+     args = parser.parse_args()
+
+     debug = args.debug
+
+     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
+                                                args.output_file,
+                                                tag=args.tag,
+                                                exclude_frozen_parameters=args.exclude_frozen_parameters)