---
license: mit
language:
- en
- de
- fr
- it
- pt
- hi
- es
- th
base_model:
- GSAI-ML/LLaDA-8B-Base
pipeline_tag: text-generation
tags:
- gptqmodel
- FunAGI
- llada
- int4
---

This is a 4-bit GPTQ quantization of [LLaDA-8B-Base](https://huggingface.co/GSAI-ML/LLaDA-8B-Base), produced with [GPTQModel](https://github.com/ModelCloud/GPTQModel) using the following settings:

- **bits**: 4
- **dynamic**: null
- **group_size**: 128
- **desc_act**: true
- **static_groups**: false
- **sym**: false
- **lm_head**: false
- **true_sequential**: true
- **quant_method**: "gptq"
- **checkpoint_format**: "gptq"
- **meta**:
  - **quantizer**: gptqmodel:1.1.0
  - **uri**: https://github.com/modelcloud/gptqmodel
- **damp_percent**: 0.1
- **damp_auto_increment**: 0.0015
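
For reference, a comparable quantization run with GPTQModel might look like the minimal sketch below. This is not the exact script used to produce this checkpoint: the calibration texts and the output path are placeholders, and the call pattern (`GPTQModel.load` with a `QuantizeConfig`, then `model.quantize` and `model.save`) is assumed from the GPTQModel documentation.

```python
# Hypothetical reproduction sketch -- not the exact script used for this checkpoint.
from gptqmodel import GPTQModel, QuantizeConfig

# Map the main settings listed above onto a QuantizeConfig (GPTQModel 1.x API assumed).
quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=True,
    sym=False,
    damp_percent=0.1,
)

# Placeholder calibration data; in practice a few hundred representative samples are typical.
calibration_data = [
    "LLaDA is an 8B-parameter diffusion language model.",
    "GPTQ quantizes linear-layer weights to 4 bits using a calibration set.",
]

# Load the FP16 base model, run GPTQ calibration, and save the quantized checkpoint.
model = GPTQModel.load("GSAI-ML/LLaDA-8B-Base", quant_config, trust_remote_code=True)
model.quantize(calibration_data)
model.save("LLaDA-8B-Base-gptqmodel-4bit")
```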

## Benchmark

### Performance of Quantized Models

| Dataset       | GPTQ-4bit | FP16    |
|---------------|-----------|---------|
| mmlu          | TODO      | 65.9(5) |
| cmmlu         | TODO      | 69.9(5) |
| arc_challenge | 45.48     | 47.9(0) |

## Example

```python
import torch
import torch.nn.functional as F
import numpy as np
from gptqmodel import GPTQModel
from transformers import AutoTokenizer


def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel-Max trick is a method for sampling from categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
    Thus, we use float64.
    '''
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    '''
    In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
    Because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
    the expected number of tokens transitioned at each step should be consistent.

    This function precomputes the number of tokens to transition at each step.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

    # Spread the remainder over the first few steps so the per-step counts sum to mask_num.
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1

    return num_transfer_tokens
96
+
97
+
98
+
99
+
100
+
101
+
102
+ @ torch.no_grad()
103
+ def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
104
+ cfg_scale=0., remasking='low_confidence', mask_id=126336):
105
+ '''
106
+ Args:
107
+ model: Mask predictor.
108
+ prompt: A tensor of shape (1, l).
109
+ steps: Sampling steps, less than or equal to gen_length.
110
+ gen_length: Generated answer length.
111
+ block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
112
+ temperature: Categorical distribution sampling temperature.
113
+ cfg_scale: Unsupervised classifier-free guidance scale.
114
+ remasking: Remasking strategy. 'low_confidence' or 'random'.
115
+ mask_id: The toke id of [MASK] is 126336.
116
+ '''
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Do not unmask tokens beyond the current block.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Unmask the num_transfer_tokens[j, i] highest-confidence positions at this step.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x


def main():
    quantized_model_id = "FunAGI/LLaDA-8B-Instruct-gptqmodel-4bit"
    tokenizer = AutoTokenizer.from_pretrained(quantized_model_id, use_fast=False)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    prompt = "Paul is at a train station and is waiting for his train. He isn't sure how long he needs to wait, but he knows that the fourth train scheduled to arrive at the station is the one he needs to get on. The first train is scheduled to arrive in 10 minutes, and this train will stay in the station for 20 minutes. The second train is to arrive half an hour after the first train leaves the station, and this second train will stay in the station for a quarter of the amount of time that the first train stayed in the station. The third train is to arrive an hour after the second train leaves the station, and this third train is to leave the station immediately after it arrives. The fourth train will arrive 20 minutes after the third train leaves, and this is the train Paul will board. In total, how long, in minutes, will Paul wait for his train?"

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    m = [{"role": "user", "content": prompt}, ]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    model = GPTQModel.load(quantized_model_id, device=device, trust_remote_code=True)

    steps = 256
    out = generate(model, input_ids, steps=steps, gen_length=256, block_length=8, temperature=0., cfg_scale=0., remasking='low_confidence')
    print("*" * 30 + f"Steps {steps}" + "*" * 30)
    print(input_ids.shape)
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    main()
```
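
A note on the sampling settings in this example: with `gen_length=256` and `block_length=8`, generation proceeds over 32 semi-autoregressive blocks, and `steps=256` leaves 8 diffusion steps per block, so exactly one masked token is unmasked per step inside each block (see `get_num_transfer_tokens`). Setting `block_length` equal to `gen_length` would use a single block, i.e. no semi-autoregressive remasking.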