#!/usr/bin/env python3
"""
Export Qwen 2.5 models using TorchScript tracing as an intermediate step,
instead of exporting directly to ONNX.

This approach may work better for modern transformer architectures.
"""

import argparse
import logging
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class TracedQwenModel(nn.Module):
    """Wrapper around a Qwen model that is friendly to torch.jit.trace."""

    def __init__(self, original_model):
        super().__init__()
        self.model = original_model

    def forward(self, input_ids):
        # Plain forward pass with caching and dict outputs disabled, so the
        # traced graph has a single tensor input and a single tensor output
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, use_cache=False, return_dict=False)
        return outputs[0]  # Return only the logits


def export_traced_model(model_path: str, output_path: Path, seq_len: int = 64):
    """Export the model using TorchScript tracing."""
    logger.info(f"Loading model from {model_path}")

    # Load the model with conservative settings
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,  # Use float32 for better compatibility
        device_map="cpu",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model.eval()

    # Load the tokenizer (only needed here for the vocabulary size)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Create the trace-friendly wrapper
    traced_model = TracedQwenModel(model)
    traced_model.eval()

    # Sample input of random token IDs, capped so they stay inside the vocabulary
    sample_input = torch.randint(
        1, min(tokenizer.vocab_size, 1000), (1, seq_len), dtype=torch.long
    )

    logger.info("Creating TorchScript traced model...")
    try:
        # Use tracing instead of scripting
        traced_script = torch.jit.trace(traced_model, sample_input)

        # Save the traced model
        output_path.parent.mkdir(parents=True, exist_ok=True)
        torch.jit.save(traced_script, str(output_path))
        logger.info(f"✅ Traced model saved to: {output_path}")

        # Sanity-check the traced model against the eager wrapper
        logger.info("Testing traced model...")
        with torch.no_grad():
            original_output = traced_model(sample_input)
            traced_output = traced_script(sample_input)
            diff = torch.abs(original_output - traced_output).max().item()
            logger.info(f"Max difference between original and traced: {diff}")

        return True
    except Exception as e:
        logger.error(f"Tracing failed: {e}")
        return False


def convert_traced_to_onnx(traced_path: Path, onnx_path: Path, seq_len: int = 64):
    """Convert the traced model to ONNX."""
    logger.info(f"Converting traced model to ONNX: {onnx_path}")
    try:
        # Load the traced model
        traced_model = torch.jit.load(str(traced_path))
        traced_model.eval()

        # Sample input with the same fixed shape used for tracing
        sample_input = torch.randint(1, 1000, (1, seq_len), dtype=torch.long)

        # Export to ONNX. Opset 14 or newer is generally needed here: ops such
        # as scaled-dot-product attention and triu/tril, which commonly appear
        # in traced transformer graphs, are not available in opset 11.
        torch.onnx.export(
            traced_model,
            sample_input,
            str(onnx_path),
            input_names=["input_ids"],
            output_names=["logits"],
            opset_version=14,
            do_constant_folding=False,
            dynamic_axes=None,  # Fixed shapes
        )

        logger.info(f"✅ ONNX model saved to: {onnx_path}")
        return True
    except Exception as e:
        logger.error(f"ONNX conversion failed: {e}")
        return False
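
# Optional sanity check (a sketch, not part of the original flow): compare the
# ONNX Runtime output against the traced TorchScript model, mirroring the
# trace-vs-eager check above. Assumes the optional `onnxruntime` and `numpy`
# packages are installed; the input/output names ("input_ids", "logits") match
# those passed to torch.onnx.export in convert_traced_to_onnx.
def verify_onnx_against_traced(traced_path: Path, onnx_path: Path, seq_len: int = 64) -> float:
    """Return the max absolute difference between traced and ONNX logits."""
    import numpy as np
    import onnxruntime as ort  # optional dependency, imported lazily

    traced_model = torch.jit.load(str(traced_path))
    traced_model.eval()

    # Same fixed input shape the models were exported with
    sample_input = torch.randint(1, 1000, (1, seq_len), dtype=torch.long)

    with torch.no_grad():
        torch_logits = traced_model(sample_input).numpy()

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    (onnx_logits,) = session.run(["logits"], {"input_ids": sample_input.numpy()})

    diff = float(np.abs(torch_logits - onnx_logits).max())
    logger.info(f"Max difference between traced and ONNX outputs: {diff}")
    return diff
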
def main():
    parser = argparse.ArgumentParser(description="Export Qwen 2.5 using TorchScript tracing")
    parser.add_argument("--model-path", type=str, required=True, help="Model path or HF model name")
    parser.add_argument("--output-dir", type=Path, required=True, help="Output directory")
    parser.add_argument("--seq-len", type=int, default=64, help="Sequence length for tracing")
    parser.add_argument("--export-onnx", action="store_true", help="Also export to ONNX")
    args = parser.parse_args()

    # Create the output directory
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Export the traced model
    traced_path = args.output_dir / "model_traced.pt"
    success = export_traced_model(args.model_path, traced_path, args.seq_len)

    if success and args.export_onnx:
        onnx_path = args.output_dir / "model.onnx"
        convert_traced_to_onnx(traced_path, onnx_path, args.seq_len)


if __name__ == "__main__":
    main()
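
# Example invocation (a sketch: the script filename and checkpoint name are
# illustrative; any local path or Hugging Face Qwen 2.5 checkpoint should work):
#
#   python export_traced_qwen.py \
#       --model-path Qwen/Qwen2.5-0.5B-Instruct \
#       --output-dir ./exported \
#       --seq-len 64 \
#       --export-onnx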