Upload folder using huggingface_hub
- README.md +319 -0
- README_ZH.md +317 -0
- added_tokens.json +11 -0
- config.json +134 -0
- configuration.json +1 -0
- configuration_skywork_chat.py +92 -0
- configuration_skywork_lm2.py +139 -0
- configuration_skywork_vit.py +101 -0
- conversation.py +163 -0
- generation_config.json +9 -0
- model-00001-of-00016.safetensors +3 -0
- model-00002-of-00016.safetensors +3 -0
- model-00003-of-00016.safetensors +3 -0
- model-00004-of-00016.safetensors +3 -0
- model-00005-of-00016.safetensors +3 -0
- model-00006-of-00016.safetensors +3 -0
- model-00007-of-00016.safetensors +3 -0
- model-00008-of-00016.safetensors +3 -0
- model-00009-of-00016.safetensors +3 -0
- model-00010-of-00016.safetensors +3 -0
- model-00011-of-00016.safetensors +3 -0
- model-00012-of-00016.safetensors +3 -0
- model-00013-of-00016.safetensors +3 -0
- model-00014-of-00016.safetensors +3 -0
- model-00015-of-00016.safetensors +3 -0
- model-00016-of-00016.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_skywork_chat.py +354 -0
- modeling_skywork_vit.py +424 -0
- preprocessor_config.json +19 -0
- special_tokens_map.json +47 -0
- tokenizer.json +0 -0
- tokenizer_config.json +35 -0
- vocab.json +0 -0
- zero_to_fp32.py +604 -0
README.md
ADDED
@@ -0,0 +1,319 @@
# Skywork-R1V

<div align="center">
<img src="logo.jpeg" alt="Introduction Image" width="400" height="400">
<br>
<a href="README_ZH.md">中文</a> | <a href="https://github.com/SkyworkAI/Skywork-R1V">📂 GitHub</a>
</div>


## 1. Introduction

We introduce Skywork-R1V, a multimodal reasoning model that extends the R1-series text models to visual modalities through a near-lossless transfer method. Using a lightweight visual projector, Skywork-R1V enables seamless multimodal adaptation without requiring retraining of either the base language model or the vision encoder. To enhance visual-text alignment, we developed a hybrid optimization strategy combining Iterative Supervised Fine-Tuning (SFT) with Group Relative Policy Optimization (GRPO), significantly improving cross-modal integration. Additionally, we created an adaptive-length Chain-of-Thought distillation approach for generating reasoning data, which dynamically optimizes reasoning-chain lengths to improve inference efficiency and prevent overthinking. The model achieves state-of-the-art performance on key multimodal reasoning benchmarks, scoring 68.1 on MMMU and 71.0 on MathVista, comparable to leading closed-source models such as Gemini 2.0 and Kimi-k1.5. It also maintains strong textual reasoning capabilities, achieving impressive scores of 72.6 on AIME and 94.3 on MATH500.


## 2. Model Summary

**Architecture:**
Skywork-R1V employs a modular architecture that efficiently combines vision and language capabilities:
- Vision Encoder: Uses a Vision Transformer (ViT) as the visual backbone to process image inputs.
- Visual Projector: A lightweight MLP (multilayer perceptron) adapter that serves as the bridge between the vision and language components.
- Language Model: Utilizes R1-distilled-Qwen-32B as the reasoning-capable language model backbone.

The model follows a connection pattern of Vision Encoder → MLP Adapter → Language Model, where the MLP adapter aligns the output space of the vision encoder with the input space of the language model. This design allows for efficient transfer of reasoning capabilities from text to multimodal domains without requiring extensive retraining of either the vision encoder or the language model.
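
For intuition, here is a minimal sketch of such a pixel-shuffle-plus-MLP projector. The shapes follow this repository's config.json (vision hidden size 3200, LLM hidden size 5120, downsample_ratio 0.5, 448px tiles with 14px patches); the module itself is an illustrative assumption, not the released implementation (see modeling_skywork_chat.py for that).

```python
import torch
import torch.nn as nn

class MLPProjector(nn.Module):
    """Illustrative visual projector: pixel-shuffle downsampling, then an MLP.
    Shapes follow config.json; the exact layers are an assumption."""
    def __init__(self, vit_dim=3200, llm_dim=5120, downsample=0.5):
        super().__init__()
        self.scale = int(1 / downsample)            # 2
        in_dim = vit_dim * self.scale * self.scale  # 12800 after pixel shuffle
        self.mlp = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, x):                           # x: (B, H, W, C) ViT patch grid
        b, h, w, c = x.shape
        # Pixel shuffle: fold each 2x2 patch neighborhood into the channel
        # dimension, quartering the visual token count before projection.
        x = x.reshape(b, h // self.scale, self.scale, w // self.scale, self.scale, c)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(b, -1, c * self.scale * self.scale)
        return self.mlp(x)                          # (B, H*W/4, llm_dim)

feats = torch.randn(1, 32, 32, 3200)                # one 448px tile -> 32x32 patches
print(MLPProjector()(feats).shape)                  # torch.Size([1, 256, 5120])
```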

**Key Designs**
- **Advanced Multimodal Reasoning**
Excels at complex reasoning across textual and visual modalities.
- **Iterative Training Strategies**
Employs iterative supervised fine-tuning and GRPO to refine model alignment and performance.
- **Adaptive-Length Chain-of-Thought**
Dynamically adjusts reasoning length to enhance inference efficiency and accuracy.
- **Scalable Performance**
Benchmarked to rival proprietary models across mathematics, coding, and multimodal tasks.

## 3. Evaluation

<div align="center">
<img src="eval.jpeg" width="600" height="200" alt="skywork_r1v_eval" />
</div>

<div align="center">
<b>Evaluation results of state-of-the-art LLMs and VLMs</b>
</div>
<table>
<thead>
<tr>
  <th></th>
  <th align="center"><strong>Vision Support</strong></th>
  <th align="center" colspan="3"><strong>Reasoning</strong></th>
  <th align="center" colspan="3"><strong>Vision</strong></th>
</tr>
<tr>
  <th></th><th></th>
  <th align="center"><strong>MATH-500</strong></th>
  <th align="center"><strong>AIME 2024</strong></th>
  <th align="center"><strong>GPQA</strong></th>
  <th align="center"><strong>MathVista (mini)</strong></th>
  <th align="center"><strong>MMMU (Val)</strong></th>
  <th align="center"><strong>CSVQA</strong></th>
</tr>
<tr>
  <th></th><th></th>
  <th align="center">pass@1</th><th align="center">pass@1</th><th align="center">pass@1</th>
  <th align="center">pass@1</th><th align="center">pass@1</th><th align="center">pass@1</th>
</tr>
</thead>
<tbody>
<tr><td>Qwen2.5-72B-Instruct</td><td align="center">❌</td><td align="center">82.6</td><td align="center">23.3</td><td align="center">49.0</td><td align="center">-</td><td align="center">-</td><td align="center">-</td></tr>
<tr><td>DeepSeek V3</td><td align="center">❌</td><td align="center">90.2</td><td align="center">39.2</td><td align="center">59.1</td><td align="center">-</td><td align="center">-</td><td align="center">-</td></tr>
<tr><td>DeepSeek R1</td><td align="center">❌</td><td align="center">97.3</td><td align="center">79.8</td><td align="center">71.5</td><td align="center">-</td><td align="center">-</td><td align="center">-</td></tr>
<tr><td>Claude 3.5 Sonnet</td><td align="center">✅</td><td align="center">78.3</td><td align="center">16.0</td><td align="center">65.0</td><td align="center">67.7</td><td align="center">68.3</td><td align="center">-</td></tr>
<tr><td>GPT-4o</td><td align="center">✅</td><td align="center">76.6</td><td align="center">9.3</td><td align="center">53.6</td><td align="center">63.8</td><td align="center">69.1</td><td align="center">-</td></tr>
<tr><td>Kimi k1.5</td><td align="center">✅</td><td align="center">96.2</td><td align="center">77.5</td><td align="center">-</td><td align="center">74.9</td><td align="center">70.0</td><td align="center">-</td></tr>
<tr><td>Qwen2.5-VL-72B-Instruct</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">74.8</td><td align="center">70.2</td><td align="center">-</td></tr>
<tr><td>LLaVA-Onevision-72B</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">67.5</td><td align="center">56.8</td><td align="center">-</td></tr>
<tr><td>InternVL2-Llama3-76B</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">65.5</td><td align="center">58.3</td><td align="center">-</td></tr>
<tr><td>InternVL2.5-78B</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">72.3</td><td align="center">70.1</td><td align="center">-</td></tr>
<tr><td>Skywork-R1V-38B</td><td align="center">✅</td><td align="center">94.0</td><td align="center">72.0</td><td align="center">61.6</td><td align="center">71.0</td><td align="center">68.1</td><td align="center">XXX</td></tr>
</tbody>
</table>

<div align="center">
<b>Comparison with Larger-Scale Open-Source and Closed-Source Models</b>
</div>

<table align="center">
<thead>
<tr>
  <th></th>
  <th align="center"><strong>Benchmark</strong></th>
  <th align="center"><strong>LLM</strong></th>
  <th align="center" colspan="4"><strong>VLM</strong></th>
</tr>
<tr>
  <th></th><th></th>
  <th align="center"><strong>QwQ-32B-Preview</strong></th>
  <th align="center"><strong>InternVL-2.5-38B</strong></th>
  <th align="center"><strong>VILA 1.5-40B</strong></th>
  <th align="center"><strong>InternVL2-40B</strong></th>
  <th align="center"><strong>Skywork-R1V-38B</strong></th>
</tr>
</thead>
<tbody>
<tr><td rowspan="3">Reasoning</td><td>MATH-500</td><td align="center">90.6</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center"><strong>94.0</strong></td></tr>
<tr><td>AIME 2024</td><td align="center">50.0</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center"><strong>72.0</strong></td></tr>
<tr><td>GPQA</td><td align="center">65.2</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">61.6</td></tr>
<tr><td rowspan="3">Vision</td><td>MathVista (mini)</td><td align="center">-</td><td align="center">71.9</td><td align="center">49.5</td><td align="center">63.7</td><td align="center">71.0</td></tr>
<tr><td>MMMU (Val)</td><td align="center">-</td><td align="center">63.9</td><td align="center">55.1</td><td align="center">55.2</td><td align="center">68.1</td></tr>
<tr><td>CSVQA</td><td align="center">-</td><td align="center"></td><td align="center"></td><td align="center"></td><td align="center"></td></tr>
</tbody>
</table>


## 4. Skywork-R1V Family

| Model Name | Vision Encoder | Language Model | HF Link |
| ---------- | -------------- | -------------- | ------- |
| Skywork-R1V-38B | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | [🤗 Link](#) |
| Skywork-R1V-38B-qwq | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | - |

---


## 5. Quick Start

This section describes how to quickly install, configure, and run the Skywork-R1V model.

**Example Steps:**

1. **Clone the GitHub repository**
```bash
git clone https://github.com/your-repo
```

2. **Install dependencies**
```bash
cd your-repo
pip install -r requirements.txt
```

3. **Run the example code**
```bash
python demo.py
```

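Beyond `demo.py`, a minimal single-image inference sketch with 🤗 Transformers is shown below. The repo id, the dummy preprocessing, and the InternVL-style `chat()` helper are assumptions, not the confirmed interface (the custom code lives in `modeling_skywork_chat.py`, loaded via `trust_remote_code=True`); consult the released `demo.py` for the authoritative call.

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = "Skywork/Skywork-R1V-38B"   # assumed repo id
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,    # matches torch_dtype in config.json
    trust_remote_code=True,        # config.json maps AutoModel to SkyworkChatModel
    device_map="auto",
).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# `pixel_values` should be 448x448 tiles from the dynamic-patch preprocessing
# (max_dynamic_patch=6 plus a thumbnail, per config.json); a zero tensor
# stands in here for illustration only.
pixel_values = torch.zeros(1, 3, 448, 448, dtype=torch.bfloat16).cuda()

# Assumed InternVL-style chat helper; check demo.py for the real signature.
response = model.chat(
    tokenizer, pixel_values, "<image>\nDescribe this image.",
    generation_config=dict(max_new_tokens=512, do_sample=True,
                           temperature=0.6, top_p=0.95),
)
print(response)
```
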
---

## 6. Additional Resources

- [📂 GitHub Repository](https://github.com/your-repo)
- [🗨️ Chat Demo](#)
- [🚀 Quick Start](#5-quick-start)
- [📖 Full Documentation](#)


## 7. Citation
If you use Skywork-R1V in your research, please cite:

```
@article{skywork2025r1v,
  title   = {Skywork-R1V: Bridging Vision and Language for Advanced Multimodal Reasoning},
  author  = {SkyworkVL Team},
  year    = {2025},
  journal = {arXiv preprint arXiv:XXXX.XXXXX},
  url     = {https://github.com/skywork-ai/Skywork-R1V}
}
```

*This project is released under an open-source license.*
README_ZH.md
ADDED
@@ -0,0 +1,317 @@
# Skywork-R1V

<div align="center">
<img src="logo.jpeg" alt="Introduction Image" width="400" height="400">
<br>
<a href="README.md">English</a> | <a href="https://github.com/SkyworkAI/Skywork-R1V">📂 GitHub</a>
</div>


## 1. Introduction

We introduce Skywork-R1V, a multimodal reasoning model that extends the R1-series text models to visual modalities through a near-lossless transfer method. Using a lightweight visual projector, Skywork-R1V achieves seamless multimodal adaptation without retraining the base language model or the vision encoder. To improve visual-text alignment, we developed a hybrid optimization strategy combining Iterative Supervised Fine-Tuning (SFT) with Group Relative Policy Optimization (GRPO), significantly strengthening cross-modal integration. In addition, we created an adaptive-length Chain-of-Thought distillation approach for generating reasoning data, dynamically optimizing reasoning-chain lengths to improve inference efficiency and avoid overthinking. The model reaches state-of-the-art results on key multimodal reasoning benchmarks, scoring 68.1 on MMMU and 71.0 on MathVista, on par with leading closed-source models such as Gemini 2.0 and Kimi-k1.5. It also retains excellent textual reasoning ability, scoring 72.6 on AIME and 94.3 on MATH500.

## 2. Model Summary

**Architecture:**

Skywork-R1V adopts a modular architecture that effectively combines vision and language capabilities:
- **Vision Encoder:** Uses a Vision Transformer (ViT) as the visual backbone to process image inputs.
- **Visual Projector:** A lightweight MLP adapter that bridges the vision and language components.
- **Language Model:** Uses R1-distilled-Qwen-32B as the reasoning-capable language model backbone.

The model follows the connection pattern Vision Encoder → MLP Adapter → Language Model, where the MLP adapter aligns the output space of the vision encoder with the input space of the language model. This design efficiently transfers textual reasoning capabilities to the multimodal domain without large-scale retraining of either the vision encoder or the language model.

**Key Designs**
- **Advanced Multimodal Reasoning**
Excels at complex reasoning across textual and visual modalities.
- **Iterative Training Strategies**
Employs iterative supervision and GRPO to refine model alignment and performance.
- **Adaptive-Length Chain-of-Thought**
Dynamically adjusts reasoning length to improve inference efficiency and accuracy.
- **Scalable Performance**
Rivals proprietary models across mathematics, coding, and multimodal tasks.

## 3. Evaluation

<div align="center">
<img src="eval.jpeg" width="600" height="200" alt="skywork_r1v_eval" />
</div>

<div align="center">
<b>Evaluation results of state-of-the-art LLMs and VLMs</b>
</div>
<table>
<thead>
<tr>
  <th></th>
  <th align="center"><strong>Vision Support</strong></th>
  <th align="center" colspan="3"><strong>Reasoning</strong></th>
  <th align="center" colspan="3"><strong>Vision</strong></th>
</tr>
<tr>
  <th></th><th></th>
  <th align="center"><strong>MATH-500</strong></th>
  <th align="center"><strong>AIME 2024</strong></th>
  <th align="center"><strong>GPQA</strong></th>
  <th align="center"><strong>MathVista (mini)</strong></th>
  <th align="center"><strong>MMMU (Val)</strong></th>
  <th align="center"><strong>CSVQA</strong></th>
</tr>
<tr>
  <th></th><th></th>
  <th align="center">pass@1</th><th align="center">pass@1</th><th align="center">pass@1</th>
  <th align="center">pass@1</th><th align="center">pass@1</th><th align="center">pass@1</th>
</tr>
</thead>
<tbody>
<tr><td>Qwen2.5-72B-Instruct</td><td align="center">❌</td><td align="center">82.6</td><td align="center">23.3</td><td align="center">49.0</td><td align="center">-</td><td align="center">-</td><td align="center">-</td></tr>
<tr><td>DeepSeek V3</td><td align="center">❌</td><td align="center">90.2</td><td align="center">39.2</td><td align="center">59.1</td><td align="center">-</td><td align="center">-</td><td align="center">-</td></tr>
<tr><td>DeepSeek R1</td><td align="center">❌</td><td align="center">97.3</td><td align="center">79.8</td><td align="center">71.5</td><td align="center">-</td><td align="center">-</td><td align="center">-</td></tr>
<tr><td>Claude 3.5 Sonnet</td><td align="center">✅</td><td align="center">78.3</td><td align="center">16.0</td><td align="center">65.0</td><td align="center">67.7</td><td align="center">68.3</td><td align="center">-</td></tr>
<tr><td>GPT-4o</td><td align="center">✅</td><td align="center">76.6</td><td align="center">9.3</td><td align="center">53.6</td><td align="center">63.8</td><td align="center">69.1</td><td align="center">-</td></tr>
<tr><td>Kimi k1.5</td><td align="center">✅</td><td align="center">96.2</td><td align="center">77.5</td><td align="center">-</td><td align="center">74.9</td><td align="center">70.0</td><td align="center">-</td></tr>
<tr><td>Qwen2.5-VL-72B-Instruct</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">74.8</td><td align="center">70.2</td><td align="center">-</td></tr>
<tr><td>LLaVA-Onevision-72B</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">67.5</td><td align="center">56.8</td><td align="center">-</td></tr>
<tr><td>InternVL2-Llama3-76B</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">65.5</td><td align="center">58.3</td><td align="center">-</td></tr>
<tr><td>InternVL2.5-78B</td><td align="center">✅</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">72.3</td><td align="center">70.1</td><td align="center">-</td></tr>
<tr><td>Skywork-R1V-38B</td><td align="center">✅</td><td align="center">94.0</td><td align="center">72.0</td><td align="center">61.6</td><td align="center">71.0</td><td align="center">68.1</td><td align="center">XXX</td></tr>
</tbody>
</table>

<div align="center">
<b>Comparison with Larger-Scale Open-Source and Closed-Source Models</b>
</div>

<table align="center">
<thead>
<tr>
  <th></th>
  <th align="center"><strong>Benchmark</strong></th>
  <th align="center"><strong>LLM</strong></th>
  <th align="center" colspan="4"><strong>VLM</strong></th>
</tr>
<tr>
  <th></th><th></th>
  <th align="center"><strong>QwQ-32B-Preview</strong></th>
  <th align="center"><strong>InternVL-2.5-38B</strong></th>
  <th align="center"><strong>VILA 1.5-40B</strong></th>
  <th align="center"><strong>InternVL2-40B</strong></th>
  <th align="center"><strong>Skywork-R1V-38B</strong></th>
</tr>
</thead>
<tbody>
<tr><td rowspan="3">Reasoning</td><td>MATH-500</td><td align="center">90.6</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center"><strong>94.0</strong></td></tr>
<tr><td>AIME 2024</td><td align="center">50.0</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center"><strong>72.0</strong></td></tr>
<tr><td>GPQA</td><td align="center">65.2</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">61.6</td></tr>
<tr><td rowspan="3">Vision</td><td>MathVista (mini)</td><td align="center">-</td><td align="center">71.9</td><td align="center">49.5</td><td align="center">63.7</td><td align="center">71.0</td></tr>
<tr><td>MMMU (Val)</td><td align="center">-</td><td align="center">63.9</td><td align="center">55.1</td><td align="center">55.2</td><td align="center">68.1</td></tr>
<tr><td>CSVQA</td><td align="center">-</td><td align="center"></td><td align="center"></td><td align="center"></td><td align="center"></td></tr>
</tbody>
</table>

## 4. Skywork-R1V Family

| Model Name | Vision Encoder | Language Model | HF Link |
| ---------- | -------------- | -------------- | ------- |
| Skywork-R1V-38B | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | [🤗 Link](#) |
| Skywork-R1V-38B-qwq | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | - |

---

## 5. Quick Start

**Example Steps:**

1. **Clone the GitHub repository**
```bash
git clone https://github.com/your-repo
```

2. **Install dependencies**
```bash
cd your-repo
pip install -r requirements.txt
```

3. **Run the example code**
```bash
python demo.py
```

---

## 6. Additional Resources

- [📂 GitHub Repository](https://github.com/your-repo)
- [🗨️ Chat Demo](#)
- [🚀 Quick Start](#5-quick-start)
- [📖 Full Documentation](#)

---

## 7. Citation
If you use Skywork-R1V in your research, please cite:

```
@article{skywork2025r1v,
  title   = {Skywork-R1V: Bridging Vision and Language for Advanced Multimodal Reasoning},
  author  = {SkyworkVL Team},
  year    = {2025},
  journal = {arXiv preprint arXiv:XXXX.XXXXX},
  url     = {https://github.com/skywork-ai/Skywork-R1V}
}
```

*This project is released under an open-source license.*
added_tokens.json
ADDED
@@ -0,0 +1,11 @@
{
  "</box>": 92552,
  "</img>": 92545,
  "</quad>": 92548,
  "</ref>": 92550,
  "<IMG_CONTEXT>": 92546,
  "<box>": 92551,
  "<img>": 92544,
  "<quad>": 92547,
  "<ref>": 92549
}
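
These entries register the vision-delimiter tokens used in prompts. In InternVL-style pipelines, an `<image>` placeholder is expanded into `<img>` plus one `<IMG_CONTEXT>` slot per visual token plus `</img>`; the helper below is an illustrative sketch of that expansion (the 256 tokens-per-tile figure follows from config.json, while the function itself is hypothetical, not part of this repo).

```python
def expand_image_placeholder(prompt: str, num_tiles: int,
                             tokens_per_tile: int = 256) -> str:
    """Replace a single <image> marker with the vision-token span the model
    expects: <img> followed by one <IMG_CONTEXT> slot per visual token.
    Illustrative sketch only; the shipped modeling code does this internally."""
    span = "<img>" + "<IMG_CONTEXT>" * (num_tiles * tokens_per_tile) + "</img>"
    return prompt.replace("<image>", span, 1)

print(expand_image_placeholder("<image>\nWhat is in the picture?", num_tiles=1)[:40])
```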
config.json
ADDED
@@ -0,0 +1,134 @@
{
  "_commit_hash": null,
  "_name_or_path": "",
  "auto_map": {
    "AutoConfig": "configuration_skywork_chat.SkyworkChatConfig",
    "AutoModel": "modeling_skywork_chat.SkyworkChatModel",
    "AutoModelForCausalLM": "modeling_skywork_chat.SkyworkChatModel"
  },
  "downsample_ratio": 0.5,
  "dynamic_image_size": true,
  "force_image_size": 448,
  "hidden_size": 5120,
  "llm_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": ["Qwen2ForCausalLM"],
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 151643,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 151643,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "silu",
    "hidden_size": 5120,
    "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
    "initializer_range": 0.02,
    "intermediate_size": 27648,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {"LABEL_0": 0, "LABEL_1": 1},
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 131072,
    "max_window_layers": 64,
    "min_length": 0,
    "model_type": "qwen2",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 40,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 64,
    "num_key_value_heads": 8,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "rms_norm_eps": 1e-05,
    "rope_theta": 1000000.0,
    "sep_token_id": null,
    "sliding_window": 131072,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": false,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "bfloat16",
    "torchscript": false,
    "transformers_version": "4.37.2",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": false,
    "use_sliding_window": false,
    "vocab_size": 152064
  },
  "max_dynamic_patch": 6,
  "min_dynamic_patch": 1,
  "model_type": "skywork_chat",
  "pad2square": false,
  "ps_version": "v2",
  "select_layer": -1,
  "template": "skywork-r1v-chat",
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": null,
  "use_backbone_lora": 0,
  "use_llm_lora": 0,
  "use_thumbnail": true,
  "vision_config": {
    "attention_dropout": 0.0,
    "drop_path_rate": 0.0,
    "dropout": 0.0,
    "hidden_act": "gelu",
    "hidden_size": 3200,
    "image_size": 448,
    "initializer_factor": 0.1,
    "initializer_range": 1e-10,
    "intermediate_size": 12800,
    "layer_norm_eps": 1e-06,
    "norm_type": "rms_norm",
    "num_attention_heads": 25,
    "num_channels": 3,
    "num_hidden_layers": 45,
    "output_attentions": false,
    "output_hidden_states": false,
    "patch_size": 14,
    "qk_normalization": true,
    "qkv_bias": false,
    "return_dict": true,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.37.2",
    "use_bfloat16": true,
    "use_flash_attn": true
  }
}
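
The vision fields above fix the visual token budget: each 448x448 tile yields `(force_image_size / patch_size)²` patches, reduced by `downsample_ratio²` through pixel shuffle. A quick check:

```python
force_image_size, patch_size, downsample_ratio = 448, 14, 0.5
patches = (force_image_size // patch_size) ** 2          # 32 * 32 = 1024 ViT patches
tokens_per_tile = int(patches * downsample_ratio ** 2)   # 1024 / 4 = 256 LLM tokens
print(tokens_per_tile)  # 256; max_dynamic_patch=6 plus a thumbnail -> 7 * 256 = 1792 tokens
```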
configuration.json
ADDED
@@ -0,0 +1 @@
{"framework": "pytorch", "task": "image-text-to-text", "allow_remote": true}
configuration_skywork_chat.py
ADDED
@@ -0,0 +1,92 @@
import copy

from transformers import LlamaConfig, Qwen2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from .configuration_skywork_vit import SkyworkVisionConfig

logger = logging.get_logger(__name__)


class SkyworkChatConfig(PretrainedConfig):
    model_type = 'skywork_chat'
    is_composition = True

    def __init__(
            self,
            vision_config=None,
            llm_config=None,
            use_backbone_lora=0,
            use_llm_lora=0,
            select_layer=-1,
            force_image_size=None,
            downsample_ratio=0.5,
            template=None,
            dynamic_image_size=False,
            use_thumbnail=False,
            ps_version='v1',
            min_dynamic_patch=1,
            max_dynamic_patch=6,
            **kwargs):
        super().__init__(**kwargs)
        if vision_config is None:
            vision_config = {'architectures': ['SkyworkVisionModel']}
            logger.info('vision_config is None. Initializing the SkyworkVisionConfig with default values.')

        if llm_config is None:
            llm_config = {'architectures': ['Qwen2ForCausalLM']}
            logger.info('llm_config is None. Initializing the llm config with default values (`Qwen2Config`).')

        self.vision_config = SkyworkVisionConfig(**vision_config)
        if llm_config.get('architectures')[0] == 'LlamaForCausalLM':
            self.llm_config = LlamaConfig(**llm_config)
        elif llm_config.get('architectures')[0] == 'Qwen2ForCausalLM':
            self.llm_config = Qwen2Config(**llm_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(llm_config.get('architectures')[0]))

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        logger.info(f'vision_select_layer: {self.select_layer}')
        logger.info(f'ps_version: {self.ps_version}')
        logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
        logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output['vision_config'] = self.vision_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        output['model_type'] = self.__class__.model_type
        output['use_backbone_lora'] = self.use_backbone_lora
        output['use_llm_lora'] = self.use_llm_lora
        output['select_layer'] = self.select_layer
        output['force_image_size'] = self.force_image_size
        output['downsample_ratio'] = self.downsample_ratio
        output['template'] = self.template
        output['dynamic_image_size'] = self.dynamic_image_size
        output['use_thumbnail'] = self.use_thumbnail
        output['ps_version'] = self.ps_version
        output['min_dynamic_patch'] = self.min_dynamic_patch
        output['max_dynamic_patch'] = self.max_dynamic_patch

        return output
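
Because config.json maps `AutoConfig` to this class, the composed config can be loaded through 🤗 Transformers; a minimal sketch (the repo id is an assumption for illustration):

```python
from transformers import AutoConfig

# trust_remote_code pulls in configuration_skywork_chat.py via the auto_map.
cfg = AutoConfig.from_pretrained("Skywork/Skywork-R1V-38B", trust_remote_code=True)
print(cfg.model_type, cfg.llm_config.model_type, cfg.vision_config.hidden_size)
# skywork_chat qwen2 3200
```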
configuration_skywork_lm2.py
ADDED
@@ -0,0 +1,139 @@
# Copyright (c) The Skywork team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SkyworkLM2 model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


# Modified from transformers.models.llama.configuration_llama.LlamaConfig
class SkyworkLM2Config(PretrainedConfig):
    r"""
    Args:
        vocab_size (`int`, *optional*, defaults to 103168):
            Vocabulary size of the SkyworkLM2 model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`SkyworkLM2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*):
            The number of key/value heads used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model uses Multi-Head Attention (MHA); if
            `num_key_value_heads=1`, the model uses Multi-Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be
            constructed by mean-pooling all the original heads within that group. For more details, see
            [this paper](https://arxiv.org/pdf/2305.13245.pdf). Defaults to `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to
            something large just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/value attentions (not used by all models).
            Only relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
    """
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act='silu',
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation='eager',
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.bias = bias

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        self.attn_implementation = attn_implementation
        if self.attn_implementation is None:
            self.attn_implementation = 'eager'
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """Validate the `rope_scaling` configuration."""
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
                f'got {self.rope_scaling}'
            )
        rope_scaling_type = self.rope_scaling.get('type', None)
        rope_scaling_factor = self.rope_scaling.get('factor', None)
        if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
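
A quick sketch of what the validator above accepts and rejects (run from the checkpoint directory so the module imports):

```python
from configuration_skywork_lm2 import SkyworkLM2Config

# Exactly two fields, a supported type, and a float factor >= 1.0 pass.
ok = SkyworkLM2Config(rope_scaling={'type': 'dynamic', 'factor': 2.0})

try:
    SkyworkLM2Config(rope_scaling={'type': 'yarn', 'factor': 2.0})  # unsupported type
except ValueError as e:
    print(e)  # `rope_scaling`'s type field must be one of ['linear', 'dynamic'], got yarn
```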
configuration_skywork_vit.py
ADDED
@@ -0,0 +1,101 @@
import os
from typing import Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class SkyworkVisionConfig(PretrainedConfig):
    r"""
    Args:
        num_channels (`int`, *optional*, defaults to 3):
            Number of color channels in the input images (e.g., 3 for RGB).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries and values in the self-attention layers.
        hidden_size (`int`, *optional*, defaults to 3200):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 25):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 12800):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        qk_normalization (`bool`, *optional*, defaults to `True`):
            Whether to normalize the queries and keys in the self-attention layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use the flash attention mechanism.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Dropout rate for stochastic depth.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 0.1):
            A factor for layer scale.
    """

    def __init__(
        self,
        num_channels=3,
        patch_size=14,
        image_size=224,
        qkv_bias=False,
        hidden_size=3200,
        num_attention_heads=25,
        intermediate_size=12800,
        qk_normalization=True,
        num_hidden_layers=48,
        use_flash_attn=True,
        hidden_act='gelu',
        norm_type='rms_norm',
        layer_norm_eps=1e-6,
        dropout=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.norm_type = norm_type
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if 'vision_config' in config_dict:
            config_dict = config_dict['vision_config']

        return cls.from_dict(config_dict, **kwargs)
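
Because of the `from_pretrained` override above, the vision settings can be pulled straight out of the composite checkpoint's config.json; a sketch, with the repo id assumed for illustration:

```python
from configuration_skywork_vit import SkyworkVisionConfig

# The override extracts the nested "vision_config" block automatically.
vision_cfg = SkyworkVisionConfig.from_pretrained("Skywork/Skywork-R1V-38B")  # assumed repo id
print(vision_cfg.hidden_size, vision_cfg.num_hidden_layers)  # 3200 45, per this config.json
```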
conversation.py
ADDED
@@ -0,0 +1,163 @@
"""
Conversation prompt templates.

We kindly request that you import fastchat instead of copying this file if you wish to use it.
If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
"""

import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List, Tuple, Union


class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    FALCON_CHAT = auto()
    CHATGLM3 = auto()
    MPT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = '{system_message}'
    # The system message
    system_message: str = ''
    # The names of the two roles
    roles: Tuple[str, str] = ('USER', 'ASSISTANT')
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few-shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    sep2: str = None
    # Stop criteria (the default one is the EOS token)
    stop_str: Union[str, List[str]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        Note: this trimmed copy simply concatenates each role string with its
        message; `sep_style` and `sep` are kept as template metadata only.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = system_prompt
        for role, message in self.messages:
            if message:
                if type(message) is tuple:
                    message, _, _ = message
                ret += role + message
            else:
                ret += role

        return ret

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        ret = [{'role': 'system', 'content': self.system_message}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            else:
                if msg is not None:
                    ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f'{template.name} has been registered.'

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


register_conv_template(
    Conversation(
        name='skywork-r1v-chat',
        system_template='<|begin▁of▁sentence|>{system_message}',
        system_message='',
        roles=('<|User|>\n', '<|Assistant|><think>\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|end▁of▁sentence|>',
    )
)
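
A short sketch of assembling a prompt with the registered template (run from the checkpoint directory); the output follows directly from `get_prompt()` above:

```python
from conversation import get_conv_template

conv = get_conv_template('skywork-r1v-chat')
conv.append_message(conv.roles[0], '<image>\nSolve the problem in the picture.')
conv.append_message(conv.roles[1], None)  # leave the assistant slot open
print(conv.get_prompt())
# <|begin▁of▁sentence|><|User|>
# <image>
# Solve the problem in the picture.<|Assistant|><think>
```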
generation_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "_from_model_config": true,
  "bos_token_id": 151646,
  "eos_token_id": 151643,
  "do_sample": true,
  "temperature": 0.6,
  "top_p": 0.95,
  "transformers_version": "4.39.3"
}
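
These sampling defaults (temperature 0.6, top-p 0.95) are picked up automatically by `generate()` when the checkpoint is loaded; to override them per call, a sketch:

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig(do_sample=True, temperature=0.6, top_p=0.95,
                           max_new_tokens=2048, eos_token_id=151643)
# outputs = model.generate(**inputs, generation_config=gen_cfg)
```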
model-00001-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2373921d7d9a490e820d98dd62776fab4edff4ed655125c7a064c4b3ac72fd17
size 4988569440
model-00002-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8bb2779f566d99d96508a46b5d9a45273ddc5a7eea53f2cc424041e1ed4f1a43
size 4937253584
model-00003-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:20faf8e5c3fc1ae0a02ab4cbe3d0407135970dfdc72f16ba44ea739ad5d5f21a
size 4780642896
model-00004-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04fb02e6c393cb7443fc2d63942de331e92278ae118850d9fb3641a99daf808b
size 4876060256
model-00005-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1e59db52268bfc5e83a61144f840630a8b130a068bf016867ddabbce961814f
size 4876060272
model-00006-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:14bbe28e02c34a3a1929cce1449bc3abbe7f5760d702aef19e2b68d2509b15fd
size 4876060312
model-00007-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:acc431255af7fda8193a91117cdfb61c70945a687f6a0bddd2494716f51f6469
size 4876060312
model-00008-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0913c56f6d93bc03698e96b4cf24b6c58bfd6a0b9d9e2f6d6bc5ba4f7f959292
size 4876060312
model-00009-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9be9378569371d0fc815bd51ade809522ff061427910b684103d978fd6865534
size 4876060312
model-00010-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7aafb8b7b718b0523ce9cfdcfa443864e163db3bd8ad9fa79c9d3ef0f792bbc8
size 4876060312
model-00011-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec336bb125194e259262b2cd60982435376fa2430c7f4db791355bb120ea0bd3
size 4876060312
model-00012-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dc0942288d5161ada1d130d44d5d194f3b74f76c2637541a2e9267556b26feb6
size 4876060312
model-00013-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0e0faa8ba5caffc27a55b3d3612cdab22a71b782b7507e3938e152d615b51d55
size 4876060312
model-00014-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1aa651e0e0bd996e9b856615a58113f29654db50271f6d7751c605aecb99282d
size 4876060312
model-00015-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f443e511bddf529426b4f305c032e63d747826898b4731741d14e0a87debf6df
size 4876060312
model-00016-of-00016.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5c349b706b0711b33aed44b63d6d652370fe85f54f213af561e270db4e4ebf1
size 3565298616
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
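Note (not part of the upload): each `.safetensors` entry above is a Git LFS pointer, not the weights themselves. A downloaded shard can be checked against the recorded digest with a sketch like this; the local path is illustrative.

    import hashlib

    def sha256_of(path, chunk_size=1 << 20):
        # stream the file so multi-GB shards do not need to fit in memory
        h = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                h.update(chunk)
        return h.hexdigest()

    # expected digest copied from the pointer for shard 1 of 16 above
    expected = '2373921d7d9a490e820d98dd62776fab4edff4ed655125c7a064c4b3ac72fd17'
    assert sha256_of('model-00001-of-00016.safetensors') == expected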
modeling_skywork_chat.py
ADDED
@@ -0,0 +1,354 @@
import warnings
from typing import List, Optional, Tuple, Union

import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_skywork_chat import SkyworkChatConfig
from .conversation import get_conv_template
from .modeling_skywork_vit import SkyworkVisionModel, has_flash_attn
from .modeling_skywork_lm2 import SkyworkLM2ForCausalLM

from transformers import Qwen2Config, Qwen2ForCausalLM

logger = logging.get_logger(__name__)


def version_cmp(v1, v2, op='eq'):
    import operator

    from packaging import version
    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))


class SkyworkChatModel(PreTrainedModel):
    config_class = SkyworkChatConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _supports_flash_attn_2 = True
    _no_split_modules = ['SkyworkVisionModel', 'LlamaDecoderLayer', 'SkyworkLM2DecoderLayer']

    def __init__(self, config: SkyworkChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.36.2', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        use_flash_attn = use_flash_attn if has_flash_attn else False
        config.vision_config.use_flash_attn = True if use_flash_attn else False
        config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = SkyworkVisionModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'SkyworkLM2ForCausalLM':
                self.language_model = SkyworkLM2ForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values)
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep.strip())[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False):

        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep.strip())[0].strip()
        history.append((question, response))

        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            if verbose:
                print(query_to_print, response)
            return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:

        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)

            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
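Note (not part of the file): a minimal usage sketch for `SkyworkChatModel.chat`. The hub id and the random image tensor are placeholders, and `generation_config` is passed as a dict because `chat()` assigns `eos_token_id` into it before forwarding.

    import torch
    from transformers import AutoModel, AutoTokenizer

    path = 'Skywork/Skywork-R1V-38B'  # hypothetical hub id for this upload
    model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                      trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    # placeholder input: one 448x448 tile; real use applies the preprocessing
    # described in preprocessor_config.json below
    pixel_values = torch.rand(1, 3, 448, 448, dtype=torch.bfloat16).cuda()
    response = model.chat(tokenizer, pixel_values, '<image>\nDescribe the image.',
                          generation_config=dict(max_new_tokens=1024, do_sample=True,
                                                 temperature=0.6, top_p=0.95))
    print(response)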
modeling_skywork_vit.py
ADDED
@@ -0,0 +1,424 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_skywork_vit import SkyworkVisionConfig

try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import \
        flash_attn_varlen_qkvpacked_func
    has_flash_attn = True
except:
    print('FlashAttention2 is not installed.')
    has_flash_attn = False

logger = logging.get_logger(__name__)


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_varlen_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_varlen_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None


class SkyworkRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


try:
    from apex.normalization import FusedRMSNorm

    SkyworkRMSNorm = FusedRMSNorm  # noqa

    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of SkyworkRMSNorm')
except ImportError:
    # using the normal SkyworkRMSNorm
    pass
except Exception:
    logger.warning('discovered apex but it failed to load, falling back to SkyworkRMSNorm')
    pass


NORM2FN = {
    'rms_norm': SkyworkRMSNorm,
    'layer_norm': nn.LayerNorm,
}


class SkyworkVisionEmbeddings(nn.Module):
    def __init__(self, config: SkyworkVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed, H, W):
        target_dtype = pos_embed.dtype
        pos_embed = pos_embed.float().reshape(
            1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
            reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
        return pos_embed

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, width, height]
        batch_size, _, height, width = patch_embeds.shape
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = torch.cat([
            self.position_embedding[:, :1, :],
            self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
        ], dim=1)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


class SkyworkAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: SkyworkVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = SkyworkRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = SkyworkRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
        return x


class SkyworkMLP(nn.Module):
    def __init__(self, config: SkyworkVisionConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class SkyworkVisionEncoderLayer(nn.Module):
    def __init__(self, config: SkyworkVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = SkyworkAttention(config)
        self.mlp = SkyworkMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
        """
        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)

        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)

        return hidden_states


class SkyworkVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SkyworkEncoderLayer`].

    Args:
        config (`SkyworkConfig`):
            The corresponding vision configuration for the `SkyworkEncoder`.
    """

    def __init__(self, config: SkyworkVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList([
            SkyworkVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
        self.gradient_checkpointing = True

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    encoder_layer,
                    hidden_states)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )


class SkyworkVisionModel(PreTrainedModel):
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    config_class = SkyworkVisionConfig
    _no_split_modules = ['SkyworkVisionEncoderLayer']

    def __init__(self, config: SkyworkVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = SkyworkVisionEmbeddings(config)
        self.encoder = SkyworkVisionEncoder(config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        pos_emb = self.embeddings.position_embedding
        _, num_positions, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
        logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
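Note (not part of the file): a self-contained sketch of the bicubic position-embedding resampling performed in `SkyworkVisionEmbeddings._get_pos_embed`; all sizes here are assumed for illustration.

    import torch
    import torch.nn.functional as F

    C, side = 64, 32                      # hidden size, original patch-grid side
    pos = torch.randn(1, side * side, C)  # stand-in for the learned table
    H, W = 16, 24                         # target patch grid

    grid = pos.reshape(1, side, side, C).permute(0, 3, 1, 2)   # -> 1, C, 32, 32
    grid = F.interpolate(grid, size=(H, W), mode='bicubic', align_corners=False)
    pos_new = grid.reshape(1, C, H * W).permute(0, 2, 1)       # -> 1, H*W, C
    print(pos_new.shape)  # torch.Size([1, 384, 64])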
preprocessor_config.json
ADDED
@@ -0,0 +1,19 @@
{
  "crop_size": 448,
  "do_center_crop": true,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "size": 448
}
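Note (not part of the file): an approximately equivalent torchvision pipeline for the settings above (`resample: 3` is PIL's bicubic filter, and the mean/std are the ImageNet statistics); the model's dynamic multi-tile cropping is handled separately and is not reproduced here.

    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    preprocess = transforms.Compose([
        transforms.Resize(448, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(448),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])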
special_tokens_map.json
ADDED
@@ -0,0 +1,47 @@
{
  "additional_special_tokens": [
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>",
    "<img>",
    "</img>",
    "<IMG_CONTEXT>",
    "<quad>",
    "</quad>",
    "<ref>",
    "</ref>",
    "<box>",
    "</box>",
    "<|begin▁of▁sentence|>",
    "<|end▁of▁sentence|>"
  ],
  "bos_token": {
    "content": "<|begin▁of▁sentence|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|end▁of▁sentence|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer_config.json
ADDED
@@ -0,0 +1,35 @@
{
  "add_bos_token": true,
  "add_eos_token": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<|begin▁of▁sentence|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "clean_up_tokenization_spaces": false,
  "eos_token": {
    "__type": "AddedToken",
    "content": "<|end▁of▁sentence|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "legacy": true,
  "model_max_length": 16384,
  "pad_token": {
    "__type": "AddedToken",
    "content": "<|end▁of▁sentence|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "sp_model_kwargs": {},
  "unk_token": null,
  "tokenizer_class": "LlamaTokenizerFast",
  "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}"
}
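Note (not part of the file): a minimal sketch of the chat template above in action; the hub id is an assumption. With `add_generation_prompt=True` the template appends `<|Assistant|><think>\n`, so generation starts inside the reasoning block.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('Skywork/Skywork-R1V-38B',  # hypothetical id
                                        trust_remote_code=True)
    messages = [{'role': 'user', 'content': 'What is 2 + 2?'}]
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # '<|begin▁of▁sentence|><|User|>What is 2 + 2?<|Assistant|><think>\n'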
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
zero_to_fp32.py
ADDED
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright (c) Microsoft Corporation.
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
# DeepSpeed Team
|
7 |
+
|
8 |
+
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
|
9 |
+
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
10 |
+
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
11 |
+
# application.
|
12 |
+
#
|
13 |
+
# example: python zero_to_fp32.py . pytorch_model.bin
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import torch
|
17 |
+
import glob
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from collections import OrderedDict
|
22 |
+
from dataclasses import dataclass
|
23 |
+
|
24 |
+
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
25 |
+
# DeepSpeed data structures it has to be available in the current python environment.
|
26 |
+
from deepspeed.utils import logger
|
27 |
+
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
28 |
+
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
29 |
+
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class zero_model_state:
|
34 |
+
buffers: dict()
|
35 |
+
param_shapes: dict()
|
36 |
+
shared_params: list
|
37 |
+
ds_version: int
|
38 |
+
frozen_param_shapes: dict()
|
39 |
+
frozen_param_fragments: dict()
|
40 |
+
|
41 |
+
|
42 |
+
debug = 0
|
43 |
+
|
44 |
+
# load to cpu
|
45 |
+
device = torch.device('cpu')
|
46 |
+
|
47 |
+
|
48 |
+
def atoi(text):
|
49 |
+
return int(text) if text.isdigit() else text
|
50 |
+
|
51 |
+
|
52 |
+
def natural_keys(text):
|
53 |
+
'''
|
54 |
+
alist.sort(key=natural_keys) sorts in human order
|
55 |
+
http://nedbatchelder.com/blog/200712/human_sorting.html
|
56 |
+
(See Toothy's implementation in the comments)
|
57 |
+
'''
|
58 |
+
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
59 |
+
|
60 |
+
|
61 |
+
def get_model_state_file(checkpoint_dir, zero_stage):
|
62 |
+
if not os.path.isdir(checkpoint_dir):
|
63 |
+
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
64 |
+
|
65 |
+
# there should be only one file
|
66 |
+
if zero_stage <= 2:
|
67 |
+
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
68 |
+
elif zero_stage == 3:
|
69 |
+
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
70 |
+
|
71 |
+
if not os.path.exists(file):
|
72 |
+
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
73 |
+
|
74 |
+
return file
|
75 |
+
|
76 |
+
|
77 |
+
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
78 |
+
# XXX: need to test that this simple glob rule works for multi-node setup too
|
79 |
+
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
80 |
+
|
81 |
+
if len(ckpt_files) == 0:
|
82 |
+
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
83 |
+
|
84 |
+
return ckpt_files
|
85 |
+
|
86 |
+
|
87 |
+
def get_optim_files(checkpoint_dir):
|
88 |
+
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
89 |
+
|
90 |
+
|
91 |
+
def get_model_state_files(checkpoint_dir):
|
92 |
+
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
93 |
+
|
94 |
+
|
95 |
+
def parse_model_states(files):
|
96 |
+
zero_model_states = []
|
97 |
+
for file in files:
|
98 |
+
state_dict = torch.load(file, map_location=device)
|
99 |
+
|
100 |
+
if BUFFER_NAMES not in state_dict:
|
101 |
+
raise ValueError(f"{file} is not a model state checkpoint")
|
102 |
+
buffer_names = state_dict[BUFFER_NAMES]
|
103 |
+
if debug:
|
104 |
+
print("Found buffers:", buffer_names)
|
105 |
+
|
106 |
+
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
107 |
+
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
108 |
+
param_shapes = state_dict[PARAM_SHAPES]
|
109 |
+
|
110 |
+
# collect parameters that are included in param_shapes
|
111 |
+
param_names = []
|
112 |
+
for s in param_shapes:
|
113 |
+
for name in s.keys():
|
114 |
+
param_names.append(name)
|
115 |
+
|
116 |
+
# update with frozen parameters
|
117 |
+
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
118 |
+
if frozen_param_shapes is not None:
|
119 |
+
if debug:
|
120 |
+
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
121 |
+
param_names += list(frozen_param_shapes.keys())
|
122 |
+
|
123 |
+
# handle shared params
|
124 |
+
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
|
125 |
+
|
126 |
+
ds_version = state_dict.get(DS_VERSION, None)
|
127 |
+
|
128 |
+
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
129 |
+
|
130 |
+
z_model_state = zero_model_state(buffers=buffers,
|
131 |
+
param_shapes=param_shapes,
|
132 |
+
shared_params=shared_params,
|
133 |
+
ds_version=ds_version,
|
134 |
+
frozen_param_shapes=frozen_param_shapes,
|
135 |
+
frozen_param_fragments=frozen_param_fragments)
|
136 |
+
zero_model_states.append(z_model_state)
|
137 |
+
|
138 |
+
return zero_model_states
|
139 |
+
|
140 |
+
|
141 |
+
def parse_optim_states(files, ds_checkpoint_dir):

    total_files = len(files)
    state_dicts = []
    for f in files:
        state_dict = torch.load(f, map_location=device)
        # immediately discard the two potentially huge optimizer states, as we only care about the
        # fp32 master weights; also handle the case where they were already removed by another helper script
        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
        state_dicts.append(state_dict)

    if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]

    # For ZeRO-2 each param group can have a different partition_count, as data parallelism for expert
    # parameters can differ from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.

    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    if zero_stage <= 2:
        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
    elif zero_stage == 3:
        # if there is more than one param group, there will be multiple flattened tensors - one
        # flattened tensor per group - for simplicity merge them into a single tensor
        #
        # XXX: could make the script more memory efficient for when there are multiple groups - it
        # will require matching the sub-lists of param_shapes for each param group flattened tensor

        fp32_flat_groups = [
            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
        ]

    return zero_stage, world_size, fp32_flat_groups


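# Illustrative sketch (editorial; the sizes are hypothetical): for a ZeRO-3
# checkpoint saved by 4 ranks, the call below would yield zero_stage == 3,
# world_size == 4, and fp32_flat_groups as a list of 4 flat fp32 tensors, one
# per rank, each holding that rank's partition of all parameters:
#
#   zero_stage, world_size, fp32_flat_groups = parse_optim_states(
#       get_optim_files(ds_checkpoint_dir), ds_checkpoint_dir)

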
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """
    Returns fp32 state_dict reconstructed from ds checkpoint

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)

    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    model_files = get_model_state_files(ds_checkpoint_dir)

    zero_model_states = parse_model_states(model_files)
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    if zero_stage <= 2:
        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)
    elif zero_stage == 3:
        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)


def _zero2_merge_frozen_params(state_dict, zero_model_states):
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return

    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    frozen_param_fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        state_dict[name] = frozen_param_fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


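# Editorial note: under ZeRO-2 each rank keeps frozen parameters whole, so the
# merge above simply copies rank 0's fragments into the state dict as-is;
# contrast this with `_zero3_merge_frozen_params` further below, where the
# fragments are partitioned and must be concatenated across all ranks before
# the original shape can be restored.

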
def _has_callable(obj, fn):
    attr = getattr(obj, fn, None)
    return callable(attr)


def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
    # For each param group, concatenate the per-rank flat fp32 partitions in rank order into one
    # contiguous vector, then carve consecutive slices out of it according to the recorded param
    # shapes, accounting for alignment padding at the end of each group.

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = []
    for i in range(num_param_groups):
        merged_partitions = [sd[i] for sd in fp32_flat_groups]
        full_single_fp32_vector = torch.cat(merged_partitions, 0)
        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
    avail_numel = sum(
        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])

    if debug:
        wanted_params = sum([len(shapes) for shapes in param_shapes])
        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = full_single_fp32_vector.numel()
        for name, shape in shapes.items():

            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")


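# Hedged worked example (editorial) for the ZeRO-2 alignment check above: with
# world_size=4, align_to = 8, so a consumed offset of 1000001 rounds up to
# zero2_align(1000001) == 8 * math.ceil(1000001 / 8) == 1000008. Both offset
# and avail_numel are rounded this way before being compared, which absorbs
# the 0..2*world_size padding slack described in the comment.

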
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


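# Worked example (editorial): with unpartitioned_numel=10 and world_size=4,
# remainder = 10 % 4 = 2, so padding_numel = 4 - 2 = 2 and
# partitioned_numel = math.ceil(10 / 4) = 3; each of the 4 ranks stores 3
# elements, and 4 * 3 = 12 = 10 + 2 accounts for the padding exactly.

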
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return

    if debug:
        for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in zero_model_states[0].frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = fp32_flat_groups[0].numel() * world_size
    # Reconstruction protocol: for zero3 we need to zip the partitions together at the boundary of
    # each param, re-consolidating each param while dealing with any padding

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    avail_numel = fp32_flat_groups[0].numel() * world_size
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    for name, shape in param_shapes.items():

        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # XXX: memory usage doubles here
        state_dict[name] = torch.cat(
            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
            0).narrow(0, 0, unpartitioned_numel).view(shape)
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")


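# Illustrative example (editorial; sizes are hypothetical): with world_size=2
# and a parameter of 5 elements, each rank's flat partition holds 3 elements
# (the last one padding, per zero3_partitioned_param_info(5, 2)); the torch.cat
# of the two narrow(0, offset, 3) slices above yields 6 elements, narrow(0, 0, 5)
# trims the padding, and view(shape) restores the original parameter shape.

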
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict


def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
    """
    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for the checkpoint, e.g., ``global_step14``. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder.
        - ``exclude_frozen_parameters``: exclude frozen parameters

    Returns:
        - pytorch ``state_dict``

    Note: this approach may not work if your application doesn't have sufficient free CPU memory, in
    which case you may need to use the offline approach via the ``zero_to_fp32.py`` script that is
    saved with the checkpoint.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application, i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    """
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if os.path.isfile(latest_path):
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)


def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
    """
    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        - ``tag``: checkpoint tag used as a unique identifier for the checkpoint, e.g., ``global_step14``. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder.
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """

    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
    print(f"Saving fp32 state dict to {output_file}")
    torch.save(state_dict, output_file)


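# Example usage (editorial; the paths are hypothetical, mirroring the argparse
# help strings at the bottom of this script): consolidate a saved ZeRO
# checkpoint into one fp32 file that plain PyTorch can load:
#
#   convert_zero_checkpoint_to_fp32_state_dict("path/checkpoint-12",
#                                              "path/checkpoint-12/pytorch_model.bin")

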
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for the checkpoint, e.g., ``global_step14``. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder.

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
    of the same application, i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    """
    logger.info(f"Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info(f"Overwriting model with fp32 weights")
    model = model.cpu()
    model.load_state_dict(state_dict, strict=False)

    return model


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument(
        "output_file",
        type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
    parser.add_argument("-t",
                        "--tag",
                        type=str,
                        default=None,
                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_file,
                                               tag=args.tag,
                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
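
# Example command line (editorial; the paths are hypothetical, taken from the
# help strings above):
#
#   python zero_to_fp32.py path/checkpoint-12 path/checkpoint-12/pytorch_model.bin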