#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
import copy
import argparse

import torch

from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path


def export(args):
    # Load model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, args.model_base, model_name, device="cpu")

    # Save extra metadata that is not saved during LLaVA training,
    # required by HF for auto-loading the model and for mlx-vlm preprocessing.

    # Save image processing config
    setattr(image_processor, "processor_class", "LlavaProcessor")
    output_path = os.path.join(model_path, "preprocessor_config.json")
    image_processor.to_json_file(output_path)

    # Create processor config
    processor_config = dict()
    processor_config["image_token"] = "<image>"
    processor_config["num_additional_image_tokens"] = 0
    processor_config["processor_class"] = "LlavaProcessor"
    processor_config["patch_size"] = 64
    output_path = os.path.join(model_path, "processor_config.json")
    json.dump(processor_config, open(output_path, "w"), indent=2)

    # Modify tokenizer to include the <image> special token.
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    tokenizer_config = json.load(open(tokenizer_config_path, 'r'))
    token_ids = list()
    image_token_is_present = False
    for k, v in tokenizer_config['added_tokens_decoder'].items():
        token_ids.append(int(k))
        if v["content"] == "<image>":
            image_token_is_present = True
            token_ids.pop()

    # Append only if token is not present
    if not image_token_is_present:
        tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}'] = copy.deepcopy(
            tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}'])
        tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}']["content"] = "<image>"
        json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2)

    # Modify config to contain the token id for <image>
    config_path = os.path.join(model_path, "config.json")
    model_config = json.load(open(config_path, 'r'))
    model_config["image_token_index"] = max(token_ids) + 1
    json.dump(model_config, open(config_path, 'w'), indent=2)

    # Export the vision encoder to ONNX
    image_res = image_processor.to_dict()['size']['shortest_edge']
    dummy_vision_input = torch.rand(1, 3, image_res, image_res).float()  # Dummy input tensor
    vision_model = model.get_vision_tower()
    # Ensure model is on CPU, in float precision, and in evaluation mode for ONNX export
    vision_model = vision_model.cpu().float().eval()
    onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")
    print(f"Exporting vision model to {onnx_vision_model_path}...")
    torch.onnx.export(
        vision_model,
        dummy_vision_input,  # Pass the dummy input tensor
        onnx_vision_model_path,
        input_names=['pixel_values'],        # Name of the input node in the ONNX graph
        output_names=['last_hidden_state'],  # Name of the output node in the ONNX graph
        # dynamic_axes={
        #     'pixel_values': {0: 'batch_size'},      # Dim 0 of 'pixel_values' is a dynamic batch size
        #     'last_hidden_state': {0: 'batch_size'}  # Dim 0 of 'last_hidden_state' is a dynamic batch size
        # },
        opset_version=17,         # ONNX opset version
        export_params=True,       # Store the trained parameter weights in the model file
        do_constant_folding=True  # Apply constant-folding optimization
    )
    print(f"Vision model ONNX export complete: {onnx_vision_model_path}")

    # Generate dummy input for mm_projector by passing dummy_vision_input through vision_model.
    # This ensures the mm_projector receives input with the correct shape and characteristics.
    with torch.no_grad():
        dummy_mm_projector_input = vision_model(dummy_vision_input)
    # Ensure the input is on CPU and in float32 precision for the projector
    dummy_mm_projector_input = dummy_mm_projector_input.cpu().float()

    # Export the mm_projector to ONNX.
    # model.get_model() gives the underlying base model (e.g., LlavaLlamaModel),
    # which contains the mm_projector attribute.
    mm_projector = model.get_model().mm_projector
    mm_projector = mm_projector.cpu().float().eval()
    onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")
    print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
    torch.onnx.export(
        mm_projector,
        dummy_mm_projector_input,
        onnx_mm_projector_path,
        input_names=['last_hidden_state'],
        output_names=['projected_image_features'],
        opset_version=17,
        export_params=True,
        do_constant_folding=True
    )
    print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")

    # Removed CoreML-specific code and intermediate .pt file handling;
    # no need for os.remove(pt_name) as pt_name is no longer created.


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="qwen_2")
    args = parser.parse_args()

    export(args)
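
# ----------------------------------------------------------------------
# Example (sketch, not part of the export flow above): sanity-checking the
# exported graphs with onnxruntime. Assumes `onnxruntime` and `numpy` are
# installed and that the file names match those written by export(). The
# input resolution must match the processor's `shortest_edge`; 1024 below
# is an illustrative assumption.
#
#   import numpy as np
#   import onnxruntime as ort
#
#   vision_sess = ort.InferenceSession("fastvithd.onnx")
#   pixel_values = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
#   (hidden_state,) = vision_sess.run(None, {"pixel_values": pixel_values})
#
#   projector_sess = ort.InferenceSession("mm_projector.onnx")
#   (image_features,) = projector_sess.run(
#       None, {"last_hidden_state": hidden_state})
#   print(hidden_state.shape, image_features.shape)
# ----------------------------------------------------------------------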