# Source: uploaded by AngelBottomless ("Upload 9 files", commit a7ab59e, verified).
import torch
import torchvision.models as models
from model_code import InitialOnlyImageTagger # Assume model_code.py classes are accessible
from safetensors.torch import load_file
# Export the trained "Initial-only" tagger checkpoint to an ONNX file.
#
# Paths are relative to the current working directory; adjust them to your setup.
safetensors_path = 'model_initial.safetensors'
onnx_path = "camie_tagger_initial_v15.onnx"

# Load the trained weights onto the CPU so the export works without a GPU.
state_dict = load_file(safetensors_path, device='cpu')

# Instantiate the model with the same parameters as training. total_tags must
# match the checkpoint, otherwise load_state_dict (strict by default) raises.
# NOTE(review): pretrained=True presumably fetches backbone weights that the
# load_state_dict call below overwrites anyway — confirm in model_code whether
# pretrained=False builds the same architecture and would avoid the download.
model = InitialOnlyImageTagger(total_tags=70527, dataset=None, pretrained=True)  # dataset not needed for forward
model.load_state_dict(state_dict)
model.eval()  # switch layers such as dropout/batch-norm to inference behaviour before tracing

# Example input used to trace the graph: one image of the expected shape (1, 3, 512, 512).
dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)

# model.forward returns two outputs (identical for InitialOnly).
input_names = ["input"]
output_names = ["initial_logits", "refined_logits"]

# Mark the batch dimension dynamic on the input AND on every output. Dims not
# listed here keep the traced (static) shape, so without the output entries the
# graph would declare a fixed output batch of 1 even when fed a larger batch.
dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}

torch.onnx.export(
    model, dummy_input, onnx_path,
    export_params=True,        # store the trained parameter weights in the model file
    opset_version=13,          # ONNX opset version (13 is widely supported)
    do_constant_folding=True,  # fold constant sub-expressions at export time
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,  # allow variable batch size on inputs and outputs
)
print(f"ONNX model saved to: {onnx_path}")