# Author: kvaishnavi
# Commit bd10177: "Upload Phi-3.5-vision-instruct scripts to make ONNX models"
import argparse
import onnx
import os
import requests
import shutil
import subprocess
import sys
import torch
from onnxruntime_genai.models.builder import create_model
from PIL import Image
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM
def _fetch_image(url, timeout=60):
    """Download ``url`` and open it as a PIL image.

    Raises ``requests.HTTPError`` on a non-2xx response, so a dead link fails
    with a clear message instead of an opaque image-decode error.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    response.raise_for_status()
    return Image.open(response.raw)


def _run_module(cmd):
    """Run an external command (a list of argv strings); raise if it exits non-zero."""
    # check=True: previously a failing optimizer/quantizer was silently ignored
    # and the intermediate folders were deleted regardless.
    subprocess.run(cmd, check=True)


def build_vision(args):
    """Export the vision encoder to a 4-bit-quantized ONNX model.

    Pipeline (each stage writes into its own temp folder under ``args.output``,
    which is removed once the next stage has consumed it):
      1. ``torch.onnx.export`` of ``model.model.vision_embed_tokens``.
      2. Re-save with all weights in a single external-data file.
      3. ORT transformer optimizer (``--model_type clip``).
      4. ORT MatMul 4-bit quantizer -> ``args.output/phi-3.5-v-instruct-vision.onnx``.

    Relies on the module globals ``processor``, ``model``, ``user_prompt``,
    ``prompt_suffix`` and ``assistant_prompt`` set in the ``__main__`` block.

    Raises:
        requests.HTTPError: if a sample image cannot be downloaded.
        subprocess.CalledProcessError: if the optimizer or quantizer fails.
    """
    # Tracing for DirectML is done on CUDA (the "dml" string is mapped to "cuda").
    device = args.execution_provider.replace("dml", "cuda")

    # Many images: a four-image prompt gives the tracer realistic multi-image inputs.
    prompt = f"{user_prompt}<|image_1|>\n <|image_2|>\n <|image_3|>\n <|image_4|>\n What is shown in these four images?{prompt_suffix}{assistant_prompt}"
    urls = [
        "https://www.ilankelman.org/stopsigns/australia.jpg",
        "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000",
        "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain",
        "https://wallpaper.dog/large/10809054.jpg",
    ]
    images = [_fetch_image(url) for url in urls]
    inputs = processor(prompt, images, return_tensors="pt").to(device)
    inputs["pixel_values"] = inputs["pixel_values"].to(args.precision)

    # TorchScript export
    dummy_inputs = (
        inputs["pixel_values"],  # pixel_values
        inputs["image_sizes"],   # image_sizes (dtype not checked here -- TODO confirm)
    )
    dynamic_axes = {
        "pixel_values": {0: "num_images", 1: "max_num_crops", 3: "height", 4: "width"},
        "image_sizes": {0: "num_images"},
        "image_features": {0: "num_image_tokens"},
    }
    filename = "phi-3.5-v-instruct-vision.onnx"

    # Stage 1: raw export + validation + in-place shape inference.
    temp_folder_1 = os.path.join(args.output, "vision_init_export")
    os.makedirs(temp_folder_1, exist_ok=True)
    fpath_1 = os.path.join(temp_folder_1, filename)
    torch.onnx.export(
        model.model.vision_embed_tokens,
        args=dummy_inputs,
        f=fpath_1,
        export_params=True,
        input_names=["pixel_values", "image_sizes"],
        output_names=["image_features"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    # Stage 2: re-save with every tensor (size_threshold=0) in one external-data file.
    temp_folder_2 = os.path.join(args.output, "vision_after_export")
    os.makedirs(temp_folder_2, exist_ok=True)
    fpath_2 = os.path.join(temp_folder_2, filename)
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )
    shutil.rmtree(temp_folder_1)

    # Stage 3: ORT transformer optimizer.
    # num_heads / hidden_size are hard-coded for the vision tower -- presumably
    # they must match the checkpoint's vision config; verify if the model changes.
    temp_folder_3 = os.path.join(args.output, "vision_after_opt")
    os.makedirs(temp_folder_3, exist_ok=True)
    fpath_3 = os.path.join(temp_folder_3, filename)
    _run_module(
        [
            f"{sys.executable}", "-m", "onnxruntime.transformers.optimizer",
            "--input", fpath_2,
            "--output", fpath_3,
            "--model_type", "clip",
            "--num_heads", str(16),
            "--hidden_size", str(1024),
            "--use_external_data_format",
            "--opt_level", str(0),
            "--disable_shape_inference",
        ]
    )
    shutil.rmtree(temp_folder_2)

    # Stage 4: ORT 4-bit MatMul quantizer writes the final model into args.output.
    fpath_4 = os.path.join(args.output, filename)
    cmd = [
        f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer",
        "--input_model", fpath_3,
        "--output_model", fpath_4,
        "--block_size", str(32),
    ]
    # accuracy_level 4 is only requested for fp32 exports.
    if args.precision == torch.float32:
        cmd.extend(["--accuracy_level", str(4)])
    _run_module(cmd)
    shutil.rmtree(temp_folder_3)
def build_embedding(args):
    """Export the combined text/image embedding layer to ONNX.

    Exports ``model.model.combined_embed`` with inputs ``input_ids`` and
    ``image_features`` and output ``inputs_embeds``. The final model is
    ``args.output/phi-3.5-v-instruct-embedding.onnx``, with all weights in a
    single ``.data`` external-data file next to it.

    Relies on the module globals ``config`` and ``model`` set in the
    ``__main__`` block.
    """
    # Tracing for DirectML is done on CUDA (the "dml" string is mapped to "cuda").
    device = args.execution_provider.replace("dml", "cuda")

    # TorchScript export
    batch_size, sequence_length, num_img_tokens = 2, 8, 2
    input_ids = torch.randint(
        low=0,
        high=config.vocab_size,
        size=(batch_size, sequence_length),
        device=device,
        dtype=torch.int64,
    )
    # Mark exactly one -1 placeholder per dummy image-feature row so the traced
    # graph exercises the image-merge path.
    # NOTE(review): presumably combined_embed treats negative ids as image-token
    # slots; verify against the model's remote code.
    input_ids[0, :num_img_tokens] = -1
    image_features = torch.randn(
        num_img_tokens, config.hidden_size, device=device, dtype=args.precision
    )
    dummy_inputs = (
        input_ids,       # input_ids: torch.LongTensor
        image_features,  # image_features: Optional[torch.FloatTensor] = None
    )
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "image_features": {0: "num_image_tokens"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    }
    filename = "phi-3.5-v-instruct-embedding.onnx"

    # Raw export into a temp folder, then validation + in-place shape inference.
    temp_folder_1 = os.path.join(args.output, "embedding_init_export")
    os.makedirs(temp_folder_1, exist_ok=True)
    fpath_1 = os.path.join(temp_folder_1, filename)
    torch.onnx.export(
        model.model.combined_embed,
        args=dummy_inputs,
        f=fpath_1,
        export_params=True,
        input_names=["input_ids", "image_features"],
        output_names=["inputs_embeds"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    # Re-save into args.output with every tensor (size_threshold=0) in one data file.
    fpath_2 = os.path.join(args.output, filename)
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )
    shutil.rmtree(temp_folder_1)
def build_text(args):
    """Build the text decoder ONNX model with onnxruntime-genai's model builder.

    The text model is always built at ``int4`` precision, whatever
    ``args.precision`` is. ``exclude_embeds`` is set because the embedding
    layer is exported separately by ``build_embedding``. The output file is
    ``phi-3.5-v-instruct-text.onnx`` in ``args.output``.
    """
    builder_options = {
        "exclude_embeds": "true",
        "filename": "phi-3.5-v-instruct-text.onnx",
    }
    # An int4 accuracy level is only requested for fp32 builds.
    if args.precision == torch.float32:
        builder_options["int4_accuracy_level"] = 4

    # model_name is None; presumably the builder then loads the checkpoint
    # from args.input -- verify against create_model's signature.
    create_model(
        None,
        args.input,
        args.output,
        "int4",
        args.execution_provider,
        args.cache_dir,
        **builder_options,
    )
def get_args():
    """Parse command-line options for the Phi-3.5 vision ONNX export.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``execution_provider``,
        ``cache_dir`` and ``precision``. ``precision`` is converted from the
        CLI string ("fp16"/"fp32") to the matching ``torch.dtype``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input",
        required=True,
        help="Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.",
    )
    parser.add_argument(
        "-o", "--output",
        required=True,
        help="Path to folder to store ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
    )
    parser.add_argument(
        "-p", "--precision",
        required=True,
        choices=["fp16", "fp32"],
        help="Precision to export PyTorch components with",
    )
    parser.add_argument(
        "-e", "--execution_provider",
        required=True,
        choices=["cpu", "cuda", "dml"],
        help="Execution provider for Phi-3.5 vision components",
    )
    parser.add_argument(
        "-c", "--cache_dir",
        required=False,
        default=os.path.join(".", "cache_dir"),
        help="Cache directory for Hugging Face files and temporary ONNX external data files",
    )

    parsed = parser.parse_args()
    # `choices` guarantees only these two keys can appear.
    dtype_by_name = {"fp16": torch.float16, "fp32": torch.float32}
    parsed.precision = dtype_by_name[parsed.precision]
    return parsed
if __name__ == "__main__":
    # Chat-template markers; build_vision reads these as module globals.
    user_prompt = "<|user|>\n"
    assistant_prompt = "<|assistant|>\n"
    prompt_suffix = "<|end|>\n"

    args = get_args()
    # The "dml" execution provider is traced on CUDA.
    torch_device = args.execution_provider.replace("dml", "cuda")

    # config / processor / model must stay module-level globals: the build_*
    # functions read them directly.
    config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.input,
        trust_remote_code=True,
        torch_dtype=args.precision,
    ).to(torch_device)

    # Build model components, in order: vision encoder, embedding, text decoder.
    for build_step in (build_vision, build_embedding, build_text):
        build_step(args)