|
import torch
|
|
import torchvision.models as models
|
|
from model_code import InitialOnlyImageTagger
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
|
|
# --- Load the trained weights -------------------------------------------------
# Weights for the initial-only tagger, stored in safetensors format.
safetensors_path = 'model_initial.safetensors'

# Load onto the CPU so the export does not require a GPU.
state_dict = load_file(safetensors_path, device='cpu')

# --- Build the model and apply the weights ------------------------------------
# NOTE(review): total_tags must match the checkpoint's tag count; otherwise
# load_state_dict raises on the parameter shape mismatch.
# NOTE(review): pretrained=True presumably initialises the backbone from
# pretrained weights that load_state_dict then overwrites. Confirm whether
# pretrained=False would suffice (it would avoid a possible download).
model = InitialOnlyImageTagger(total_tags=70527, dataset=None, pretrained=True)
model.load_state_dict(state_dict)

# Switch to eval mode so layers such as dropout / batch-norm (if present)
# behave deterministically while the graph is traced.
model.eval()

# --- Example input used for tracing -------------------------------------------
# NOTE(review): assumed layout is (batch, channels, height, width) with a fixed
# 512x512 resolution. Only the batch axis is exported as dynamic below.
dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)

# --- Export to ONNX -----------------------------------------------------------
onnx_path = "camie_tagger_initial_v15.onnx"

# NOTE(review): an "initial-only" model exposing a "refined_logits" output looks
# suspicious. Confirm what the model's second return value actually is.
output_names = ["initial_logits", "refined_logits"]

# Mark axis 0 (batch) as dynamic for the input AND every output. If only the
# input is dynamic, the exported graph may still declare the outputs with the
# traced batch size of 1. That disagrees with the real output shape whenever
# the model is run with batch_size > 1.
dynamic_axes = {name: {0: "batch_size"} for name in ["input", *output_names]}

# Tracing does not need autograd; disabling it avoids recording a backward graph.
with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,        # embed the trained weights in the .onnx file
        opset_version=13,
        do_constant_folding=True,  # pre-compute constant subgraphs at export time
        input_names=["input"],
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )

print(f"ONNX model saved to: {onnx_path}")
|
|
|