|
import argparse |
|
import onnx |
|
import os |
|
import requests |
|
import shutil |
|
import subprocess |
|
import sys |
|
import torch |
|
|
|
from onnxruntime_genai.models.builder import create_model |
|
from PIL import Image |
|
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM |
|
|
|
|
|
def _download_vision_image(url, timeout=60):
    """Download *url* and return it as a fully decoded PIL image.

    The whole response body is read through ``response.content`` before PIL
    parses it. That way the HTTP connection is released, any content-encoding
    (e.g. gzip) applied by the server is undone, and a failed request raises
    instead of silently handing a broken stream to PIL.

    Args:
        url: HTTP(S) URL of the image.
        timeout: seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    import io

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image.load()  # decode now; PIL otherwise reads the pixel data lazily
    return image


def build_vision(args):
    """Export, optimize and 4-bit quantize the Phi-3.5 vision encoder to ONNX.

    This function reads the module-level globals ``user_prompt``,
    ``prompt_suffix``, ``assistant_prompt``, ``processor`` and ``model``,
    which are set up in this script's ``__main__`` section.

    Pipeline (each intermediate stage uses its own folder under ``args.output``):
      1. ``torch.onnx.export`` of ``model.model.vision_embed_tokens`` into
         ``vision_init_export``.
      2. Check and shape-infer the model, then re-save it with all weights in
         one external data file into ``vision_after_export``.
      3. Run the ONNX Runtime transformer optimizer (``clip`` model type) into
         ``vision_after_opt``.
      4. Run the MatMul 4-bit quantizer, which writes the final model directly
         into ``args.output``.
    Each intermediate folder is removed only after the next stage has
    consumed it successfully.

    Args:
        args: parsed CLI namespace from ``get_args``. Uses ``output``,
            ``precision`` (a torch dtype) and ``execution_provider``.

    Raises:
        requests.HTTPError: if downloading one of the sample images fails.
        subprocess.CalledProcessError: if the optimizer or the quantizer
            exits with a non-zero status. Intermediate files are then left
            on disk for inspection.
    """
    # The "dml" provider is mapped to the "cuda" torch device for tracing.
    device = args.execution_provider.replace("dml", "cuda")

    # Four sample images, so that the traced graph sees a multi-image input.
    prompt = f"{user_prompt}<|image_1|>\n <|image_2|>\n <|image_3|>\n <|image_4|>\n What is shown in these four images?{prompt_suffix}{assistant_prompt}"
    urls = [
        "https://www.ilankelman.org/stopsigns/australia.jpg",
        "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000",
        "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain",
        "https://wallpaper.dog/large/10809054.jpg",
    ]
    images = [_download_vision_image(url) for url in urls]

    inputs = processor(prompt, images, return_tensors="pt").to(device)
    inputs["pixel_values"] = inputs["pixel_values"].to(args.precision)

    dummy_inputs = (
        inputs["pixel_values"],
        inputs["image_sizes"],
    )
    dynamic_axes = {
        "pixel_values": {0: "num_images", 1: "max_num_crops", 3: "height", 4: "width"},
        "image_sizes": {0: "num_images"},
        "image_features": {0: "num_image_tokens"},
    }
    filename = "phi-3.5-v-instruct-vision.onnx"

    # Stage 1: trace the vision tower to ONNX.
    temp_folder_1 = os.path.join(args.output, "vision_init_export")
    os.makedirs(temp_folder_1, exist_ok=True)
    fpath_1 = os.path.join(temp_folder_1, filename)
    torch.onnx.export(
        model.model.vision_embed_tokens,
        args=dummy_inputs,
        f=fpath_1,
        export_params=True,
        input_names=["pixel_values", "image_sizes"],
        output_names=["image_features"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )

    # Path-based checker / shape inference (infer_shapes_path rewrites fpath_1 in place).
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    # Stage 2: re-save with every tensor in a single external data file.
    temp_folder_2 = os.path.join(args.output, "vision_after_export")
    os.makedirs(temp_folder_2, exist_ok=True)
    fpath_2 = os.path.join(temp_folder_2, filename)
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )
    shutil.rmtree(temp_folder_1)

    # Stage 3: ONNX Runtime transformer optimizer.
    # num_heads / hidden_size presumably describe the CLIP vision tower's
    # attention layers. TODO confirm against the model config.
    temp_folder_3 = os.path.join(args.output, "vision_after_opt")
    os.makedirs(temp_folder_3, exist_ok=True)
    fpath_3 = os.path.join(temp_folder_3, filename)
    subprocess.run(
        [
            f"{sys.executable}", "-m", "onnxruntime.transformers.optimizer",
            "--input", fpath_2,
            "--output", fpath_3,
            "--model_type", "clip",
            "--num_heads", str(16),
            "--hidden_size", str(1024),
            "--use_external_data_format",
            "--opt_level", str(0),
            "--disable_shape_inference",
        ],
        check=True,  # fail loudly instead of deleting the input of a failed run
    )
    shutil.rmtree(temp_folder_2)

    # Stage 4: 4-bit MatMul quantization; the output goes straight to args.output.
    fpath_4 = os.path.join(args.output, filename)
    cmd = [
        f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer",
        "--input_model", fpath_3,
        "--output_model", fpath_4,
        "--block_size", str(32),
    ]
    if args.precision == torch.float32:
        # Same accuracy level that build_text requests for fp32 via int4_accuracy_level.
        cmd.extend(["--accuracy_level", str(4)])
    subprocess.run(cmd, check=True)
    shutil.rmtree(temp_folder_3)
|
|
|
|
|
def build_embedding(args):
    """Export the Phi-3.5 vision combined (text + image) embedding layer to ONNX.

    The function traces ``model.model.combined_embed`` with small random dummy
    inputs. It then checks and shape-infers the result and saves it to
    ``args.output``, with all weights in a single external data file. It reads
    the module-level globals ``config`` and ``model``, which are set up in this
    script's ``__main__`` section.

    Args:
        args: parsed CLI namespace from ``get_args``. Uses ``output``,
            ``precision`` (a torch dtype) and ``execution_provider``.
    """
    # The "dml" provider is mapped to the "cuda" torch device for tracing.
    device = args.execution_provider.replace("dml", "cuda")

    batch_size, sequence_length, num_img_tokens = 2, 8, 2
    input_ids = torch.randint(
        low=0,
        high=config.vocab_size,
        size=(batch_size, sequence_length),
        device=device,
        dtype=torch.int64,
    )
    image_features = torch.randn(
        num_img_tokens, config.hidden_size, device=device, dtype=args.precision
    )
    # Put one -1 id per row of image_features into the first sequence.
    # Presumably negative ids mark where combined_embed splices in image
    # features. TODO confirm against the model's remote code.
    input_ids[0, :num_img_tokens] = -1

    dummy_inputs = (input_ids, image_features)
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "image_features": {0: "num_image_tokens"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    }
    filename = "phi-3.5-v-instruct-embedding.onnx"

    # Trace into a temporary folder first.
    temp_folder_1 = os.path.join(args.output, "embedding_init_export")
    os.makedirs(temp_folder_1, exist_ok=True)
    fpath_1 = os.path.join(temp_folder_1, filename)
    torch.onnx.export(
        model.model.combined_embed,
        args=dummy_inputs,
        f=fpath_1,
        export_params=True,
        input_names=["input_ids", "image_features"],
        output_names=["inputs_embeds"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )

    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    # Final model: every tensor in one external data file next to the .onnx.
    fpath_2 = os.path.join(args.output, filename)
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )
    shutil.rmtree(temp_folder_1)
|
|
|
|
|
def build_text(args):
    """Build the int4 Phi-3.5 vision text decoder with the GenAI model builder.

    The input embeddings are excluded from this model (``exclude_embeds``),
    because they are exported separately by the embedding component. For an
    fp32 run, ``int4_accuracy_level`` 4 is additionally requested.

    Args:
        args: parsed CLI namespace from ``get_args``. Uses ``input``,
            ``output``, ``precision``, ``execution_provider`` and ``cache_dir``.
    """
    builder_options = {
        "exclude_embeds": "true",
        "filename": "phi-3.5-v-instruct-text.onnx",
    }
    if args.precision == torch.float32:
        builder_options["int4_accuracy_level"] = 4

    create_model(
        None,          # model name: not used, weights come from args.input
        args.input,
        args.output,
        "int4",        # precision of the generated text model
        args.execution_provider,
        args.cache_dir,
        **builder_options,
    )
|
|
|
|
|
def get_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``precision``,
        ``execution_provider`` and ``cache_dir``. ``precision`` is converted
        from its CLI spelling ("fp16" / "fp32") into the matching torch dtype.
    """
    cli = argparse.ArgumentParser()

    cli.add_argument(
        "-i", "--input", required=True,
        help="Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.",
    )
    cli.add_argument(
        "-o", "--output", required=True,
        help="Path to folder to store ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
    )
    cli.add_argument(
        "-p", "--precision", required=True, choices=["fp16", "fp32"],
        help="Precision to export PyTorch components with",
    )
    cli.add_argument(
        "-e", "--execution_provider", required=True, choices=["cpu", "cuda", "dml"],
        help="Execution provider for Phi-3.5 vision components",
    )
    cli.add_argument(
        "-c", "--cache_dir", required=False, default=os.path.join('.', 'cache_dir'),
        help="Cache directory for Hugging Face files and temporary ONNX external data files",
    )

    parsed = cli.parse_args()
    # choices= restricts precision to exactly these two spellings.
    dtype_by_name = {"fp16": torch.float16, "fp32": torch.float32}
    parsed.precision = dtype_by_name[parsed.precision]
    return parsed
|
|
|
if __name__ == "__main__":
    # Chat-template pieces. build_vision reads them as module globals.
    user_prompt = '<|user|>\n'
    assistant_prompt = '<|assistant|>\n'
    prompt_suffix = "<|end|>\n"

    args = get_args()

    # Hugging Face artifacts shared with the build_* steps as module globals.
    config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
    # The "dml" provider is mapped to the "cuda" torch device.
    torch_device = args.execution_provider.replace("dml", "cuda")
    model = AutoModelForCausalLM.from_pretrained(
        args.input,
        trust_remote_code=True,
        torch_dtype=args.precision,
    ).to(torch_device)

    # Build the three ONNX components in order: vision, embedding, text.
    for build_step in (build_vision, build_embedding, build_text):
        build_step(args)