Fix fp8_cast_bf16.py: https://github.com/deepseek-ai/DeepSeek-V3/commit/8f1c9488b53068992f9525fab03b1868e6f7c8c1
Changed file: inference/fp8_cast_bf16.py (+37 additions, −11 deletions)
    	
        inference/fp8_cast_bf16.py
    CHANGED
    
    | @@ -16,32 +16,58 @@ def main(fp8_path, bf16_path): | |
| 16 | 
             
                with open(model_index_file, "r") as f:
         | 
| 17 | 
             
                    model_index = json.load(f)
         | 
| 18 | 
             
                weight_map = model_index["weight_map"]
         | 
| 19 | 
            -
                fp8_weight_names = []
         | 
| 20 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 21 | 
             
                safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
         | 
|  | |
| 22 | 
             
                for safetensor_file in tqdm(safetensor_files):
         | 
| 23 | 
             
                    file_name = os.path.basename(safetensor_file)
         | 
| 24 | 
            -
                     | 
|  | |
|  | |
| 25 | 
             
                    new_state_dict = {}
         | 
| 26 | 
            -
                    for weight_name, weight in  | 
| 27 | 
             
                        if weight_name.endswith("_scale_inv"):
         | 
| 28 | 
             
                            continue
         | 
| 29 | 
            -
                        elif weight.element_size() == 1:
         | 
| 30 | 
             
                            scale_inv_name = f"{weight_name}_scale_inv"
         | 
| 31 | 
            -
                             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 35 | 
             
                        else:
         | 
| 36 | 
             
                            new_state_dict[weight_name] = weight
         | 
|  | |
| 37 | 
             
                    new_safetensor_file = os.path.join(bf16_path, file_name)
         | 
| 38 | 
             
                    save_file(new_state_dict, new_safetensor_file)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 39 |  | 
|  | |
| 40 | 
             
                new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
         | 
| 41 | 
             
                for weight_name in fp8_weight_names:
         | 
| 42 | 
             
                    scale_inv_name = f"{weight_name}_scale_inv"
         | 
| 43 | 
            -
                     | 
| 44 | 
            -
             | 
| 45 | 
             
                with open(new_model_index_file, "w") as f:
         | 
| 46 | 
             
                    json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
         | 
| 47 |  | 
| @@ -52,4 +78,4 @@ if __name__ == "__main__": | |
| 52 | 
             
                parser.add_argument("--output-bf16-hf-path", type=str, required=True)
         | 
| 53 | 
             
                args = parser.parse_args()
         | 
| 54 | 
             
                main(args.input_fp8_hf_path, args.output_bf16_hf_path)
         | 
| 55 | 
            -
             | 
|  | |
| 16 | 
             
                with open(model_index_file, "r") as f:
         | 
| 17 | 
             
                    model_index = json.load(f)
         | 
| 18 | 
             
                weight_map = model_index["weight_map"]
         | 
|  | |
| 19 |  | 
| 20 | 
            +
                # Cache for loaded safetensor files
         | 
| 21 | 
            +
                loaded_files = {}
         | 
| 22 | 
            +
                fp8_weight_names = []
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                # Helper function to get tensor from the correct file
         | 
| 25 | 
            +
                def get_tensor(tensor_name):
         | 
| 26 | 
            +
                    file_name = weight_map[tensor_name]
         | 
| 27 | 
            +
                    if file_name not in loaded_files:
         | 
| 28 | 
            +
                        file_path = os.path.join(fp8_path, file_name)
         | 
| 29 | 
            +
                        loaded_files[file_name] = load_file(file_path, device="cuda")
         | 
| 30 | 
            +
                    return loaded_files[file_name][tensor_name]
         | 
| 31 | 
            +
             | 
| 32 | 
             
                safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
         | 
| 33 | 
            +
                safetensor_files.sort()
         | 
| 34 | 
             
                for safetensor_file in tqdm(safetensor_files):
         | 
| 35 | 
             
                    file_name = os.path.basename(safetensor_file)
         | 
| 36 | 
            +
                    current_state_dict = load_file(safetensor_file, device="cuda")
         | 
| 37 | 
            +
                    loaded_files[file_name] = current_state_dict
         | 
| 38 | 
            +
                    
         | 
| 39 | 
             
                    new_state_dict = {}
         | 
| 40 | 
            +
                    for weight_name, weight in current_state_dict.items():
         | 
| 41 | 
             
                        if weight_name.endswith("_scale_inv"):
         | 
| 42 | 
             
                            continue
         | 
| 43 | 
            +
                        elif weight.element_size() == 1:  # FP8 weight
         | 
| 44 | 
             
                            scale_inv_name = f"{weight_name}_scale_inv"
         | 
| 45 | 
            +
                            try:
         | 
| 46 | 
            +
                                # Get scale_inv from the correct file
         | 
| 47 | 
            +
                                scale_inv = get_tensor(scale_inv_name)
         | 
| 48 | 
            +
                                fp8_weight_names.append(weight_name)
         | 
| 49 | 
            +
                                new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
         | 
| 50 | 
            +
                            except KeyError:
         | 
| 51 | 
            +
                                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
         | 
| 52 | 
            +
                                new_state_dict[weight_name] = weight
         | 
| 53 | 
             
                        else:
         | 
| 54 | 
             
                            new_state_dict[weight_name] = weight
         | 
| 55 | 
            +
                            
         | 
| 56 | 
             
                    new_safetensor_file = os.path.join(bf16_path, file_name)
         | 
| 57 | 
             
                    save_file(new_state_dict, new_safetensor_file)
         | 
| 58 | 
            +
                    
         | 
| 59 | 
            +
                    # Memory management: keep only the 2 most recently used files
         | 
| 60 | 
            +
                    if len(loaded_files) > 2:
         | 
| 61 | 
            +
                        oldest_file = next(iter(loaded_files))
         | 
| 62 | 
            +
                        del loaded_files[oldest_file]
         | 
| 63 | 
            +
                        torch.cuda.empty_cache()
         | 
| 64 |  | 
| 65 | 
            +
                # Update model index
         | 
| 66 | 
             
                new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
         | 
| 67 | 
             
                for weight_name in fp8_weight_names:
         | 
| 68 | 
             
                    scale_inv_name = f"{weight_name}_scale_inv"
         | 
| 69 | 
            +
                    if scale_inv_name in weight_map:
         | 
| 70 | 
            +
                        weight_map.pop(scale_inv_name)
         | 
| 71 | 
             
                with open(new_model_index_file, "w") as f:
         | 
| 72 | 
             
                    json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
         | 
| 73 |  | 
|  | |
| 78 | 
             
                parser.add_argument("--output-bf16-hf-path", type=str, required=True)
         | 
| 79 | 
             
                args = parser.parse_args()
         | 
| 80 | 
             
                main(args.input_fp8_hf_path, args.output_bf16_hf_path)
         | 
| 81 | 
            +
             | 
