"""Export the Initial-only Camie image tagger checkpoint to ONNX.

Loads trained weights from a safetensors file, rebuilds the model with the
same constructor arguments used for training, and exports it to ONNX with a
dynamic batch dimension on the input and on both outputs.
"""
import torch
import torchvision.models as models
from model_code import InitialOnlyImageTagger  # Assume model_code.py classes are accessible
from safetensors.torch import load_file

# Load the trained weights (Initial-only model). Adjust path to your weights file.
# To load a plain PyTorch checkpoint instead, replace the load_file call with:
#     state_dict = torch.load("model_initial_only.pt", map_location="cpu")
safetensors_path = 'model_initial.safetensors'
state_dict = load_file(safetensors_path, device='cpu')

# Instantiate the model with the same parameters as training.
# dataset is not needed for forward.
# NOTE(review): pretrained=True presumably only initialises the backbone before
# load_state_dict below overwrites it — confirm whether pretrained=False is safe
# here (it may avoid an unnecessary weight download).
model = InitialOnlyImageTagger(total_tags=70527, dataset=None, pretrained=True)
model.load_state_dict(state_dict)
model.eval()  # set to evaluation mode (inference behaviour for dropout/norm layers)

# Example input used only to trace the graph: shape (1, 3, 512, 512), float32.
dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)

input_names = ["input"]
# model.forward returns two outputs (identical for InitialOnly).
output_names = ["initial_logits", "refined_logits"]
# Mark dim 0 as a dynamic batch axis on the input AND on every output. If only
# the input is marked, the exporter may record the outputs with the traced
# batch size of 1, which makes runtimes warn about or reject other batch sizes.
dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}

# Export to ONNX
onnx_path = "camie_tagger_initial_v15.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,  # store the trained parameter weights in the model file
    opset_version=13,  # ONNX opset version (13 is widely supported)
    do_constant_folding=True,  # fold constant subexpressions at export time
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,  # variable batch size on input and outputs
)
print(f"ONNX model saved to: {onnx_path}")