# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
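"""Merge per-rank FSDP model checkpoints into a single Hugging Face checkpoint.

The script expects shard files named ``model_world_size_{W}_rank_{R}.pt`` in
``--local_dir``, plus a ``huggingface/`` subdirectory holding the model config.
It loads every shard, concatenates DTensor shards along their shard dimension,
and saves the merged bf16 weights with ``save_pretrained``.

Example invocation (paths and repo id are illustrative):

    python model_merger.py --local_dir checkpoints/global_step_100/actor \
        --hf_upload_path your-org/your-merged-model
"""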
import argparse
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
PretrainedConfig,
PreTrainedModel,
)


def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
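    """Merge one parameter's per-rank shards according to its DTensor placement."""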
if placement.is_replicate():
return tensors[0]
elif placement.is_partial():
raise NotImplementedError("Partial placement is not supported yet")
elif placement.is_shard():
return torch.cat(tensors, dim=placement.dim).contiguous()
else:
raise ValueError(f"Unsupported placement: {placement}")


def upload_model_to_huggingface(local_path: str, remote_path: str):
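    """Create (or reuse) a Hugging Face model repo and upload the merged folder."""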
    # Import lazily so huggingface_hub is only required when actually uploading.
from huggingface_hub import HfApi
api = HfApi()
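    # Note: the repo is created public here; switch to private=True to keep it private.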
api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="Path to the directory containing the sharded model checkpoint files")
    parser.add_argument("--hf_upload_path", default=None, type=str, help="Hugging Face repo id to upload the merged model to (optional)")
args = parser.parse_args()
local_dir: str = args.local_dir
    assert not local_dir.endswith("huggingface"), "Pass the checkpoint directory itself, not its huggingface/ subdirectory."
    # Inspect the rank-0 shard to discover the world size and the device-mesh layout.
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            world_size = int(match.group(1))
            break
    assert world_size, f"No file matching model_world_size_*_rank_0.pt found in {local_dir}."
rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
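    # weights_only=False: shards may contain DTensor objects, not just plain tensors.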
state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
pivot_key = sorted(state_dict.keys())[0]
weight = state_dict[pivot_key]
if isinstance(weight, DTensor):
# get sharding info
device_mesh = weight.device_mesh
mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names
    else:
        # Plain tensors (no DTensor): assume pure FSDP sharding over world_size ranks.
        mesh = np.array([world_size], dtype=np.int64)
        mesh_dim_names = ("fsdp",)
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."
    if "tp" in mesh_dim_names:
        # fsdp * tp (unreachable under the assert above; kept as a stub for FSDP + TP)
total_shards = mesh.shape[-1] * mesh.shape[-2]
mesh_shape = (mesh.shape[-2], mesh.shape[-1])
else:
# fsdp
total_shards = mesh.shape[-1]
mesh_shape = (mesh.shape[-1],)
print(f"Processing {total_shards} model shards in total.")
    # Pre-size the list with placeholders; worker threads fill in the remaining ranks.
    model_state_dict_lst = [state_dict] + [None] * (total_shards - 1)

    def process_one_shard(rank, model_state_dict_lst):
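        """Load one rank's shard from disk into its slot of the shared list."""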
model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
model_state_dict_lst[rank] = state_dict
return state_dict
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count() or 1)) as executor:
        futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(1, total_shards)]
        for future in futures:
            future.result()  # surface any exception raised while loading a shard
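    # Group each parameter's per-rank shards by name and record DTensor placements.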
state_dict: Dict[str, List[torch.Tensor]] = {}
    param_placements: Dict[str, Tuple[Placement, ...]] = {}
keys = set(model_state_dict_lst[0].keys())
for key in keys:
state_dict[key] = []
        for rank, model_state_dict in enumerate(model_state_dict_lst):
            try:
                tensor = model_state_dict.pop(key)
            except KeyError:
                print(f"Cannot find key {key} in rank {rank}.")
                continue
if isinstance(tensor, DTensor):
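                # Keep only this rank's local shard, cast to bf16 to match the final model dtype.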
state_dict[key].append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements)
# replicated placement at ddp dimension can be discarded
if mesh_dim_names[0] == "ddp":
placements = placements[1:]
if key not in param_placements:
param_placements[key] = placements
else:
assert param_placements[key] == placements
else:
state_dict[key].append(tensor.bfloat16())
del model_state_dict_lst
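
    # Merge each parameter's list of shards back into a single full tensor.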
for key in sorted(state_dict):
if not isinstance(state_dict[key], list):
print(f"No need to merge key {key}")
continue
if key in param_placements:
# merge shards
            placements: Tuple[Shard, ...] = param_placements[key]
if len(mesh_shape) == 1:
# 1-D list, FSDP without TP
assert len(placements) == 1
shards = state_dict[key]
state_dict[key] = merge_by_placement(shards, placements[0])
else:
# 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet.")
else:
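            # Plain tensors are assumed to be sharded along dim 0 by FSDP; concatenate to rebuild.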
state_dict[key] = torch.cat(state_dict[key], dim=0)
print("Merge completed.")
hf_path = os.path.join(local_dir, "huggingface")
config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
    architectures: List[str] = getattr(config, "architectures", None) or ["Unknown"]
if "ForTokenClassification" in architectures[0]:
AutoClass = AutoModelForTokenClassification
elif "ForCausalLM" in architectures[0]:
AutoClass = AutoModelForCausalLM
elif "ForConditionalGeneration" in architectures[0]:
AutoClass = AutoModelForVision2Seq
else:
raise NotImplementedError(f"Unknown architecture {architectures}.")
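    # Build the model skeleton on the meta device so no real memory is allocated up front.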
with torch.device("meta"):
model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)
assert isinstance(model, PreTrainedModel)
model.to_empty(device="cpu")
print(f"Saving model to {hf_path}...")
model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict, model
if args.hf_upload_path:
upload_model_to_huggingface(hf_path, args.hf_upload_path)