rkfg committed on
Commit
e52102d
·
verified ·
1 Parent(s): d9e7be1

Upload quantize.py

Browse files
Files changed (1) hide show
  1. quantize.py +26 -0
quantize.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from safetensors.torch import load_file, save_file
4
+ from torch import float8_e4m3fn, float8_e5m2
5
+
6
+ if len(sys.argv) != 4:
7
+ print("Provide input/output file names and fp8 type as either e4m3 or e5m2")
8
+ exit(1)
9
+
10
+ if sys.argv[3] == "e4m3":
11
+ dt = float8_e4m3fn
12
+ elif sys.argv[3] == "e5m2":
13
+ dt = float8_e5m2
14
+ else:
15
+ print("Invalid quantization type, should be either e4m3 or e5m2")
16
+ exit(1)
17
+
18
+ state_dict = load_file(sys.argv[1])
19
+
20
+ for k in state_dict:
21
+ if "norm" in k or "bias" in "k":
22
+ state_dict[k] = state_dict[k].bfloat16()
23
+ else:
24
+ state_dict[k] = state_dict[k].to(dt)
25
+
26
+ save_file(state_dict, sys.argv[2])