math-zhu committed on
Commit fa3338f · verified · 1 Parent(s): a676db9

Init Program

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
README.md CHANGED
@@ -1,3 +1,38 @@
- ---
- license: mit
- ---
+ # Introduction
+ This project uses the MiniMind2 model parameters.
+ - GitHub project: <a href="https://github.com/jingyaogong/minimind">Github Link</a>
+ - Hugging Face model: <a href="https://huggingface.co/jingyaogong/MiniMind2">Hugging Face</a>
+ # Quick Start
+ Install the dependencies:
+ ```bash
+ pip install torch transformers
+ ```
+
+ Run the model:
+ ```bash
+ python model_cognilite.py
+ ```
+
+
+ # Common Issues
+ In streaming output, decoding each token_id into a string and printing it as soon as it is produced causes garbled Chinese output, whereas collecting the token_ids in a list and decoding them together does not.
+
+ Technical description: **decoding errors caused by misaligned token boundaries**
+
+ - The tokenizer uses subword tokenization (e.g. BPE, SentencePiece), so a single Chinese character or word may be split into several tokens.
+ - When a single token_id is decoded on its own, tokenizer.decode() treats that token as a complete unit, but it may only be a "fragment" or a few bytes of a character, so the output is garbled or invisible.
+ - Only when a group of token_ids (i.e. a complete token sequence) is decoded together can the tokenizer correctly reassemble the original Chinese characters.
+
+
+ Original code:
+ ```python
+ new_token_str = tokenizer.decode(next_token_id.item(), skip_special_tokens=False)
+ print(new_token_str, end='', flush=True)
+ ```
+
+ After the change:
+ ```python
+ prev_decoded = tokenizer.decode(token_list[:-1], skip_special_tokens=False)
+ curr_decoded = tokenizer.decode(token_list, skip_special_tokens=False)
+ print(curr_decoded[len(prev_decoded):], end='', flush=True)
+ ```
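
The issue described above is easy to reproduce in isolation. Below is a minimal sketch (not part of the commit) contrasting per-token decoding with decoding the whole list; it assumes the byte-level subword tokenizer shipped in `./tokenizer/`, and whether a given character actually splits into several tokens depends on that vocabulary:

```python
from transformers import AutoTokenizer

# Illustration only: reproduces the token-boundary effect described above.
tokenizer = AutoTokenizer.from_pretrained("./tokenizer/")

token_ids = tokenizer.encode("你好,世界")  # a single Chinese character may span several subword tokens

# Decoding one token_id at a time can emit partial UTF-8 sequences (often shown as '�')
print("per-token:", "".join(tokenizer.decode([tid], skip_special_tokens=False) for tid in token_ids))

# Decoding the whole list lets the tokenizer reassemble complete characters
print("full list:", tokenizer.decode(token_ids, skip_special_tokens=False))
```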
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ac5213cee7e73410aaf2f422589537fe47e920c1bf3dd4e2aced5a4b5410442
+ size 217908728
model_cognilite.py ADDED
@@ -0,0 +1,420 @@
+ from transformers import PretrainedConfig
+
+ # Defines the model hyperparameters and configuration
+ class CogniLiteConfig(PretrainedConfig):
+     model_type = "minimind"
+
+     def __init__(
+             self,
+             dropout: float = 0.0,
+             bos_token_id: int = 1,
+             eos_token_id: int = 2,
+             hidden_act: str = 'silu',
+             hidden_size: int = 768,
+             intermediate_size: int = None,
+             max_position_embeddings: int = 32768,
+             num_attention_heads: int = 8,
+             num_hidden_layers: int = 16,
+             num_key_value_heads: int = 2,
+             vocab_size: int = 6400,
+             rms_norm_eps: float = 1e-05,
+             rope_theta: float = 1000000.0,
+             **kwargs
+     ):
+         super().__init__(**kwargs)
+         # Model hyperparameters
+         self.dropout = dropout
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.hidden_act = hidden_act
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.max_position_embeddings = max_position_embeddings
+         self.num_attention_heads = num_attention_heads
+         self.num_hidden_layers = num_hidden_layers
+         self.num_key_value_heads = num_key_value_heads
+         self.vocab_size = vocab_size
+         self.rms_norm_eps = rms_norm_eps
+         self.rope_theta = rope_theta
+
+ import math
+ import torch
+ from torch import nn
+ from transformers.activations import ACT2FN
+ from typing import Optional, Tuple, List, Union
+ import torch.nn.functional as F
+
+ # RMSNorm layer (Root Mean Square Layer Normalization)
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         # Normalization
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         # Apply normalization and scaling
+         return self.weight * self._norm(x.float()).type_as(x)
+
+ # Precompute the rotary position embedding frequencies
+ def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
+     # Generate the cos and sin tables needed for rotary position embeddings
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs).float()
+     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
+     freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
+     return freqs_cos, freqs_sin
+
+ # Apply rotary position embeddings to Q and K
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     def rotate_half(x):
+         # Split the vector in half, negate the second half and swap the two halves
+         return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
+
+     q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
+     k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
+     return q_embed, k_embed
+
+ # Repeat the KV heads so they match the number of attention heads
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+     bs, slen, num_key_value_heads, head_dim = x.shape
+     if n_rep == 1:
+         return x
+     return (
+         x[:, :, :, None, :]
+         .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
+         .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
+     )
+
+ # Attention implementation
+ class Attention(nn.Module):
+     def __init__(self, args: CogniLiteConfig):
+         super().__init__()
+         # Resolve the number of KV heads
+         self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
+         assert args.num_attention_heads % self.num_key_value_heads == 0
+         self.n_local_heads = args.num_attention_heads
+         self.n_local_kv_heads = self.num_key_value_heads
+         self.n_rep = self.n_local_heads // self.n_local_kv_heads
+         self.head_dim = args.hidden_size // args.num_attention_heads
+         # QKV projections
+         self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
+         self.attn_dropout = nn.Dropout(args.dropout)
+         self.resid_dropout = nn.Dropout(args.dropout)
+         self.dropout = args.dropout
+         # Whether scaled_dot_product_attention (flash attention) is available
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+
+     def forward(self,
+                 x: torch.Tensor,
+                 position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # cos and sin
+                 past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+                 use_cache=False,
+                 attention_mask: Optional[torch.Tensor] = None):
+         bsz, seq_len, _ = x.shape
+         # QKV projection and reshape
+         xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+         xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+         xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+
+         cos, sin = position_embeddings
+         # Apply rotary position embeddings
+         xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len])
+
+         # Concatenate the KV cache
+         if past_key_value is not None:
+             xk = torch.cat([past_key_value[0], xk], dim=1)
+             xv = torch.cat([past_key_value[1], xv], dim=1)
+         past_kv = (xk, xv) if use_cache else None
+
+         # Expand the KV heads to all attention heads
+         xq, xk, xv = (
+             xq.transpose(1, 2),
+             repeat_kv(xk, self.n_rep).transpose(1, 2),
+             repeat_kv(xv, self.n_rep).transpose(1, 2)
+         )
+
+         # Use flash attention or the regular attention path
+         if self.flash and seq_len != 1:
+             dropout_p = self.dropout if self.training else 0.0
+             attn_mask = None
+             if attention_mask is not None:
+                 attn_mask = attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1)
+                 attn_mask = attn_mask.bool() if attention_mask is not None else None
+
+             output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True)
+         else:
+             # Compute the attention scores
+             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
+             scores = scores + torch.triu(
+                 torch.full((seq_len, seq_len), float("-inf"), device=scores.device),
+                 diagonal=1
+             ).unsqueeze(0).unsqueeze(0)  # upper-triangular causal mask
+
+             if attention_mask is not None:
+                 extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                 extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
+                 scores = scores + extended_attention_mask
+
+             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+             scores = self.attn_dropout(scores)
+             output = scores @ xv
+
+         # Restore the shape and project the output
+         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
+         output = self.resid_dropout(self.o_proj(output))
+         return output, past_kv
+
+ # Feed-forward network
+ class FeedForward(nn.Module):
+     def __init__(self, config: CogniLiteConfig):
+         super().__init__()
+         # Infer the intermediate size automatically
+         if config.intermediate_size is None:
+             intermediate_size = int(config.hidden_size * 8 / 3)
+             config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
+         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.dropout = nn.Dropout(config.dropout)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         # SwiGLU activation
+         return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
+
+ # Transformer Block
+ class TransformerBlock(nn.Module):
+     def __init__(self, layer_id: int, config: CogniLiteConfig):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size = config.hidden_size
+         self.head_dim = config.hidden_size // config.num_attention_heads
+         self.self_attn = Attention(config)
+
+         self.layer_id = layer_id
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.mlp = FeedForward(config)
+
+     def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
+         # Residual connection + attention + feed-forward
+         residual = hidden_states
+         hidden_states, present_key_value = self.self_attn(
+             self.input_layernorm(hidden_states), position_embeddings,
+             past_key_value, use_cache, attention_mask
+         )
+         hidden_states += residual
+         hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
+         return hidden_states, present_key_value
+
+ # CogniLite model backbone
+ class CogniLiteModel(nn.Module):
+     def __init__(self, config: CogniLiteConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.dropout)
+         self.layers = nn.ModuleList([TransformerBlock(l, config) for l in range(self.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         # Register the rotary cos/sin buffers
+         freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
+                                                     end=config.max_position_embeddings, theta=config.rope_theta)
+         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+
+     def forward(self,
+                 input_ids: Optional[torch.Tensor] = None,
+                 attention_mask: Optional[torch.Tensor] = None,
+                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+                 use_cache: bool = False,
+                 **kwargs):
+         # input_ids: (batch, seq)
+         _, seq_length = input_ids.shape
+         past_key_values = past_key_values or [None] * len(self.layers)
+         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
+
+         # Token embeddings
+         hidden_states = self.dropout(self.embed_tokens(input_ids))
+
+         # Slice out the cos/sin for the current positions
+         position_embeddings = (
+             self.freqs_cos[start_pos:start_pos + seq_length],
+             self.freqs_sin[start_pos:start_pos + seq_length]
+         )
+
+         presents = []
+         for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
+             hidden_states, present = layer(
+                 hidden_states,
+                 position_embeddings,
+                 past_key_value=past_key_value,
+                 use_cache=use_cache,
+                 attention_mask=attention_mask
+             )
+             presents.append(present)
+
+         hidden_states = self.norm(hidden_states)
+
+         return hidden_states, presents, 0
+
+ class CogniLiteForCausalLM(nn.Module):
+     def __init__(self, config: CogniLiteConfig = None):
+         super().__init__()
+         self.config = config or CogniLiteConfig()
+         self.model = CogniLiteModel(self.config)
+         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
+         # Weight tying between the embedding and the output head
+         self.lm_head.weight = self.model.embed_tokens.weight
+
+     def forward(self,
+                 input_ids: Optional[torch.Tensor] = None,
+                 attention_mask: Optional[torch.Tensor] = None,
+                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+                 use_cache: bool = False,
+                 logits_to_keep: Union[int, torch.Tensor] = 0,
+                 **args):
+         h, past_kvs, aux_loss = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             **args
+         )
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
+         logits = self.lm_head(h[:, slice_indices, :])
+         return {
+             "last_hidden_state": h,
+             "logits": logits,
+             "aux_loss": aux_loss,
+             "past_key_values": past_kvs
+         }
+
+ import safetensors.torch
+ from transformers import AutoTokenizer
+
+ def init_cognilite_model():
+     print("start loading CogniLite model...")
+
+     # CogniLite total parameters: 104M
+     # structure: (hidden_size=768, num_hidden_layers=16)
+     args = {
+         "device": "cuda" if torch.cuda.is_available() else "cpu",
+         "hidden_size": 768,
+         "num_hidden_layers": 16,
+     }
+     tokenizer = AutoTokenizer.from_pretrained('./tokenizer/')
+
+     state_dict = safetensors.torch.load_file("model.safetensors", device=args["device"])
+
+     model = CogniLiteForCausalLM(CogniLiteConfig())
+
+     # Load the model weights
+     model.load_state_dict(state_dict, strict=True)
+
+     print(f'Number of model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
+     return model.eval().to(args["device"]), tokenizer
+
+ import random
+ import numpy as np
+ def setup_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def communicate_with_model(random_seed):
+     model, tokenizer = init_cognilite_model()
+
+     print("Random seed:", random_seed)
+     setup_seed(random_seed)
+
+     prompt = input("You: ")
+
+     messages = [{"role": "user", "content": prompt}]
+     new_prompt = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     inputs = tokenizer(
+         new_prompt,
+         return_tensors="pt",
+         truncation=True
+     ).to(device)
+
+     # shape: [seq_len]
+     input_ids = inputs["input_ids"][0]
+     attention_mask = inputs.get("attention_mask", None)
+     max_new_tokens = 128
+     eos_token_id = tokenizer.eos_token_id
+
+     exit_reason = None
+
+     token_list = []
+
+     print("Model token output: [", end=' ')
+
+     for _ in range(max_new_tokens):
+         with torch.no_grad():
+             outputs = model(
+                 input_ids=input_ids.unsqueeze(0),
+                 attention_mask=attention_mask
+             )
+         logits = outputs["logits"]
+
+         next_token_id = torch.argmax(logits[0, -1], dim=-1).unsqueeze(0)
+         if next_token_id.item() == eos_token_id:
+             exit_reason = "EOS token detected"
+             break
+
+         token_list.append(next_token_id.item())
+
+         print(next_token_id.item(), end=' ', flush=True)
+
+         # Append the new token to the input
+         input_ids = torch.cat([input_ids, next_token_id], dim=0)
+
+         # The attention_mask has to grow with it
+         if attention_mask is not None:
+             attention_mask = torch.cat([attention_mask[0], torch.ones(1, device=device, dtype=attention_mask.dtype)], dim=0).unsqueeze(0)
+
+     print("]\nModel text output: " + tokenizer.decode(token_list, skip_special_tokens=False))
+
+     if exit_reason is None:
+         print("\n Conversation ended: maximum number of new tokens reached.")
+
+     elif exit_reason == "EOS token detected":
+         print("\n Conversation ended: EOS token detected.")
+
+ if __name__ == "__main__":
+     random_type = input("Enter a random seed (integer): ")
+     try:
+         random_seed = int(random_type)
+         if random_seed <= 0:
+             print("The random seed must be a positive integer; using a random value")
+             random_seed = random.randint(0, 10000)
+     except ValueError:
+         print("Invalid random seed; using a random value")
+         random_seed = random.randint(0, 10000)
+     communicate_with_model(random_seed)
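
Note that `communicate_with_model` above re-runs the full, growing sequence through the model at every step, even though `CogniLiteModel.forward` already accepts `past_key_values` and `use_cache`. A minimal sketch of greedy decoding that reuses the cache (illustrative only; `model`, `input_ids` of shape `[1, seq_len]` and `eos_token_id` are assumed to be prepared exactly as in `communicate_with_model`):

```python
import torch

@torch.no_grad()
def generate_with_cache(model, input_ids, eos_token_id, max_new_tokens=128):
    # input_ids: shape [1, seq_len], already on the model's device
    past_key_values = None
    generated = []
    for _ in range(max_new_tokens):
        out = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = out["past_key_values"]
        next_id = torch.argmax(out["logits"][0, -1], dim=-1).item()
        if next_id == eos_token_id:
            break
        generated.append(next_id)
        # After the first step only the newest token is fed in;
        # the keys/values of earlier positions come from the cache.
        input_ids = torch.tensor([[next_id]], device=input_ids.device)
    return generated
```

With the per-step re-encoding in the original loop, the cost of each new token grows with the prompt length; the cached variant keeps per-step work roughly constant.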
model_lora.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ from torch import nn
+
+
+ # Defines the LoRA adapter structure
+ class LoRA(nn.Module):
+     def __init__(self, in_features, out_features, rank):
+         super().__init__()
+         self.rank = rank  # LoRA rank, controls the size of the low-rank matrices
+         self.A = nn.Linear(in_features, rank, bias=False)  # low-rank matrix A
+         self.B = nn.Linear(rank, out_features, bias=False)  # low-rank matrix B
+         # Matrix A: Gaussian initialization
+         self.A.weight.data.normal_(mean=0.0, std=0.02)
+         # Matrix B: initialized to all zeros
+         self.B.weight.data.zero_()
+
+     def forward(self, x):
+         return self.B(self.A(x))
+
+
+ def apply_lora(model, rank=8):
+     for name, module in model.named_modules():
+         if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
+             # use the device of the model's parameters (the module itself has no .device attribute)
+             lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(next(model.parameters()).device)
+             setattr(module, "lora", lora)
+             original_forward = module.forward
+
+             # Bind explicitly via default arguments so each layer keeps its own forward/LoRA pair
+             def forward_with_lora(x, layer1=original_forward, layer2=lora):
+                 return layer1(x) + layer2(x)
+
+             module.forward = forward_with_lora
+
+
+ def load_lora(model, path):
+     state_dict = torch.load(path, map_location=next(model.parameters()).device)
+     for name, module in model.named_modules():
+         if hasattr(module, 'lora'):
+             lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
+             module.lora.load_state_dict(lora_state)
+
+
+ def save_lora(model, path):
+     state_dict = {}
+     for name, module in model.named_modules():
+         if hasattr(module, 'lora'):
+             lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
+             state_dict.update(lora_state)
+     torch.save(state_dict, path)
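
For orientation, here is a minimal sketch of how these helpers fit together; it mirrors the flow of `train_lora.py` further below, and the checkpoint name `lora_demo.pth` is just a placeholder:

```python
import torch
from model_cognilite import CogniLiteConfig, CogniLiteForCausalLM
from model_lora import apply_lora, save_lora, load_lora

model = CogniLiteForCausalLM(CogniLiteConfig())
apply_lora(model, rank=8)                     # attach a LoRA adapter to every square nn.Linear

# Train only the adapters: freeze every parameter whose name does not contain 'lora'.
for name, param in model.named_parameters():
    param.requires_grad = 'lora' in name
lora_params = [p for n, p in model.named_parameters() if 'lora' in n]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

# ... training loop updates only lora_params ...

save_lora(model, "lora_demo.pth")             # persists just the adapter weights
load_lora(model, "lora_demo.pth")             # restores them into a LoRA-augmented model
```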
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<|im_start|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "add_bos_token": false,
+   "add_eos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [],
+   "bos_token": "<|im_start|>",
+   "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "extra_special_tokens": {},
+   "legacy": true,
+   "model_max_length": 32768,
+   "pad_token": "<|endoftext|>",
+   "sp_model_kwargs": {},
+   "spaces_between_special_tokens": false,
+   "tokenizer_class": "PreTrainedTokenizer",
+   "unk_token": "<|endoftext|>"
+ }
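
For reference, the `chat_template` above wraps a conversation in ChatML-style `<|im_start|>`/`<|im_end|>` markers and inserts a default system prompt when none is given. A small sketch of what `apply_chat_template` (as used in `model_cognilite.py`) is expected to produce for a single user turn; the user text is a placeholder:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./tokenizer/")
messages = [{"role": "user", "content": "你好"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Expected output, given the template above:
# <|im_start|>system
# You are a helpful assistant<|im_end|>
# <|im_start|>user
# 你好<|im_end|>
# <|im_start|>assistant
```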
train_lora.py ADDED
@@ -0,0 +1,222 @@
+ import os
+ import sys
+
+ __package__ = "trainer"
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+ import argparse
+ import time
+ import math
+ import warnings
+ import torch
+ from torch import optim, nn
+ import torch.distributed as dist
+ from contextlib import nullcontext
+ from torch.utils.data import DataLoader, DistributedSampler
+ from transformers import AutoTokenizer
+ from model_cognilite import CogniLiteConfig, CogniLiteForCausalLM
+ from dataset.lm_dataset import SFTDataset
+ from model_lora import load_lora, save_lora, apply_lora
+
+ warnings.filterwarnings('ignore')
+
+
+ # Logger function
+ def Logger(content):
+     if not ddp or dist.get_rank() == 0:
+         print(content)
+
+
+ def get_lr(current_step, total_steps, lr):
+     return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
+
+
+ # The code is almost identical to full_sft
+ def train_epoch(epoch, wandb):
+     loss_fct = nn.CrossEntropyLoss(reduction='none')
+     start_time = time.time()
+     for step, (X, Y, loss_mask) in enumerate(train_loader):
+         X = X.to(args.device)
+         Y = Y.to(args.device)
+         loss_mask = loss_mask.to(args.device)
+         lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+         for param_group in optimizer.param_groups:
+             param_group['lr'] = lr
+
+         with ctx:
+             res = model(X)
+             # the model returns a dict, so index it rather than using attribute access
+             loss = loss_fct(
+                 res["logits"].view(-1, res["logits"].size(-1)),
+                 Y.view(-1)
+             ).view(Y.size())
+             loss = (loss * loss_mask).sum() / loss_mask.sum()
+             loss += res["aux_loss"]
+             loss = loss / args.accumulation_steps
+
+         scaler.scale(loss).backward()
+
+         if (step + 1) % args.accumulation_steps == 0:
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
+
+             scaler.step(optimizer)
+             scaler.update()
+
+             optimizer.zero_grad(set_to_none=True)
+
+         if step % args.log_interval == 0:
+             spend_time = time.time() - start_time
+             Logger(
+                 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
+                     epoch + 1,
+                     args.epochs,
+                     step,
+                     iter_per_epoch,
+                     loss.item() * args.accumulation_steps,
+                     optimizer.param_groups[-1]['lr'],
+                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+
+             if (wandb is not None) and (not ddp or dist.get_rank() == 0):
+                 wandb.log({"loss": loss * args.accumulation_steps,
+                            "lr": optimizer.param_groups[-1]['lr'],
+                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
+
+         if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
+             model.eval()
+             lora_save_path = f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth'
+             os.makedirs(os.path.dirname(lora_save_path), exist_ok=True)
+             # [Difference 1] Only the LoRA weights need to be saved
+             save_lora(model, lora_save_path)
+             model.train()
+
+
+ def init_model(lm_config):
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     model_path = os.path.join(current_dir, '..', 'model')
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     model = CogniLiteForCausalLM(lm_config)
+     if args.minimind2:
+         model_data_path = os.path.join(current_dir, '..', 'MiniMind2')
+         model.from_pretrained(model_data_path)
+         return model.to(args.device), tokenizer
+     moe_path = '_moe' if lm_config.use_moe else ''
+     ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
+     state_dict = torch.load(ckp, map_location=args.device)
+     model.load_state_dict(state_dict, strict=False)
+     return model.to(args.device), tokenizer
+
+
+ def init_distributed_mode():
+     if not ddp: return
+     global ddp_local_rank, DEVICE
+
+     dist.init_process_group(backend="nccl")
+     ddp_local_rank = int(os.environ["LOCAL_RANK"])
+     DEVICE = f"cuda:{ddp_local_rank}"
+     torch.cuda.set_device(DEVICE)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
+     parser.add_argument("--out_dir", type=str, default="../out")
+     parser.add_argument("--epochs", type=int, default=10)
+     parser.add_argument("--batch_size", type=int, default=32)
+     parser.add_argument("--learning_rate", type=float, default=1e-4)
+     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
+     parser.add_argument("--dtype", type=str, default="bfloat16")
+     parser.add_argument("--use_wandb", action="store_true")
+     parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
+     parser.add_argument("--num_workers", type=int, default=1)
+     parser.add_argument("--ddp", action="store_true")
+     parser.add_argument("--accumulation_steps", type=int, default=1)
+     parser.add_argument("--grad_clip", type=float, default=1.0)
+     parser.add_argument("--warmup_iters", type=int, default=0)
+     parser.add_argument("--log_interval", type=int, default=100)
+     parser.add_argument("--save_interval", type=int, default=100)
+     parser.add_argument('--local_rank', type=int, default=-1)
+     parser.add_argument('--hidden_size', default=512, type=int)
+     parser.add_argument('--num_hidden_layers', default=8, type=int)
+     parser.add_argument('--max_seq_len', default=512, type=int)
+     parser.add_argument('--use_moe', default=False, type=bool)
+     parser.add_argument("--data_path", type=str, default="../dataset/lora_medical.jsonl")
+     parser.add_argument("--lora_name", type=str, default="lora_medical", help="Save as lora_<task> according to the task (English/medical/psychology...)")
+     parser.add_argument("--minimind2", type=bool, default=True, help="Whether to use the MiniMind2 model downloaded from Hugging Face")
+     args = parser.parse_args()
+
+     if args.minimind2:
+         args.hidden_size = 768
+         args.num_hidden_layers = 16
+         current_dir = os.path.dirname(os.path.abspath(__file__))
+         args.data_path = os.path.join(current_dir, "../dataset/lora_medical.jsonl")
+
+
+     lm_config = CogniLiteConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
+                                 use_moe=args.use_moe)
+     args.save_dir = os.path.join(args.out_dir)
+     os.makedirs(args.save_dir, exist_ok=True)
+     os.makedirs(args.out_dir, exist_ok=True)
+     tokens_per_iter = args.batch_size * args.max_seq_len
+     device_type = "cuda" if "cuda" in args.device else "cpu"
+
+     ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
+     ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
+     ddp_local_rank, DEVICE = 0, "cuda:0"
+     base_seed = 1337
+     torch.manual_seed(base_seed)
+     torch.cuda.manual_seed(base_seed)
+
+     if ddp:
+         init_distributed_mode()
+         args.device = torch.device(DEVICE)
+         rank = dist.get_rank()
+         torch.manual_seed(base_seed + rank)
+         # Also set the CUDA random seed
+         torch.cuda.manual_seed(base_seed + rank)
+
+     args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+     if args.use_wandb and (not ddp or ddp_local_rank == 0):
+         import wandb
+
+         wandb.init(project=args.wandb_project, name=args.wandb_run_name)
+     else:
+         wandb = None
+
+     model, tokenizer = init_model(lm_config)
+     apply_lora(model)
+
+     total_params = sum(p.numel() for p in model.parameters())  # total number of parameters
+     lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)  # number of LoRA parameters
+     if not ddp or dist.get_rank() == 0:
+         print(f"Total LLM parameters: {total_params}")
+         print(f"LoRA parameters: {lora_params_count}")
+         print(f"LoRA parameter share: {lora_params_count / total_params * 100:.2f}%")
+
+     for name, param in model.named_parameters():
+         if 'lora' not in name:
+             param.requires_grad = False
+     lora_params = []
+     for name, param in model.named_parameters():
+         if 'lora' in name:
+             lora_params.append(param)
+
+     # Optimize only the LoRA parameters
+     optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
+     train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
+     train_sampler = DistributedSampler(train_ds) if ddp else None
+     train_loader = DataLoader(
+         train_ds,
+         batch_size=args.batch_size,
+         pin_memory=True,
+         drop_last=False,
+         shuffle=False,
+         num_workers=args.num_workers,
+         sampler=train_sampler
+     )
+
+     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+     iter_per_epoch = len(train_loader)
+
+     for epoch in range(args.epochs):
+         train_epoch(epoch, wandb)