#!/usr/bin/env python3
"""
Export Qwen 2.5 models using TorchScript tracing instead of ONNX
This approach may work better for modern transformer architectures
"""
import argparse
import logging
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class TracedQwenModel(nn.Module):
    """Wrapper around a Qwen model that is friendly to torch.jit.trace."""

    def __init__(self, original_model):
        super().__init__()
        self.model = original_model

    def forward(self, input_ids):
        # Plain forward pass: no KV cache and tuple outputs, so the traced
        # graph contains no dynamic control flow or dict-like structures.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, use_cache=False, return_dict=False)
        return outputs[0]  # return only the logits
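

# Quick sanity check for the wrapper (a minimal sketch; the model name is an
# assumption, any local Qwen 2.5 checkpoint works the same way):
#
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#   wrapped = TracedQwenModel(base).eval()
#   logits = wrapped(torch.randint(0, 1000, (1, 8)))  # -> (1, 8, vocab_size)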
def export_traced_model(model_path: str, output_path: Path, seq_len: int = 64):
    """Export the model with TorchScript tracing; returns True on success."""
    logger.info(f"Loading model from {model_path}")

    # Load the model with conservative settings for CPU export.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,  # float32 traces more reliably on CPU than fp16/bf16
        device_map="cpu",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model.eval()

    # The tokenizer is only needed here to bound the dummy token IDs.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Wrap the model so tracing sees a cache-free, tensor-in/tensor-out module.
    wrapped_model = TracedQwenModel(model)
    wrapped_model.eval()

    # Dummy input; IDs are capped so they always stay inside the vocabulary.
    sample_input = torch.randint(1, min(tokenizer.vocab_size, 1000), (1, seq_len), dtype=torch.long)

    logger.info("Creating TorchScript traced model...")
    try:
        # Trace rather than script: tracing records one concrete execution
        # path and sidesteps Python control flow the scripter cannot compile.
        traced_script = torch.jit.trace(wrapped_model, sample_input)

        # Save the traced model.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        torch.jit.save(traced_script, str(output_path))
        logger.info(f"✅ Traced model saved to: {output_path}")

        # Sanity check: the traced graph should reproduce the eager outputs.
        logger.info("Testing traced model...")
        with torch.no_grad():
            original_output = wrapped_model(sample_input)
            traced_output = traced_script(sample_input)
            diff = torch.abs(original_output - traced_output).max().item()
            logger.info(f"Max difference between original and traced: {diff}")
        return True
    except Exception as e:
        logger.error(f"Tracing failed: {e}")
        return False
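

# The traced artifact has a fixed (1, seq_len) input shape, so any generation
# loop must keep its window at exactly seq_len tokens. The sketch below is
# illustrative only and is not called anywhere in this script.
def greedy_decode_fixed_len(traced, tokenizer, prompt: str, seq_len: int = 64, max_new_tokens: int = 20):
    """Sketch: greedy decoding against the fixed-shape traced model.

    Assumes left-padding with the tokenizer's pad/eos token so the input
    always holds exactly ``seq_len`` IDs. Since the traced graph takes only
    input_ids (no attention mask), pad tokens are attended to and quality
    may suffer on short prompts.
    """
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    ids = tokenizer(prompt, return_tensors="pt").input_ids[0].tolist()
    for _ in range(max_new_tokens):
        window = ids[-seq_len:]  # keep only the most recent seq_len tokens
        padded = [pad_id] * (seq_len - len(window)) + window
        with torch.no_grad():
            logits = traced(torch.tensor([padded], dtype=torch.long))
        next_id = int(logits[0, -1].argmax())  # last position = last real token
        ids.append(next_id)
        if next_id == tokenizer.eos_token_id:
            break
    return tokenizer.decode(ids)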
def convert_traced_to_onnx(traced_path: Path, onnx_path: Path, seq_len: int = 64):
    """Convert the saved traced model to ONNX; returns True on success."""
    logger.info(f"Converting traced model to ONNX: {onnx_path}")
    try:
        # Load the traced model back from disk.
        traced_model = torch.jit.load(str(traced_path))
        traced_model.eval()

        # Dummy input with the same fixed shape used for tracing.
        sample_input = torch.randint(1, 1000, (1, seq_len), dtype=torch.long)

        # Export to ONNX. Opset 14 is the practical floor for modern
        # transformer graphs (e.g. aten::triu is unsupported before 14).
        torch.onnx.export(
            traced_model,
            sample_input,
            str(onnx_path),
            input_names=["input_ids"],
            output_names=["logits"],
            opset_version=14,
            do_constant_folding=False,
            dynamic_axes=None,  # fixed shapes; re-trace to change seq_len
        )
        logger.info(f"✅ ONNX model saved to: {onnx_path}")
        return True
    except Exception as e:
        logger.error(f"ONNX conversion failed: {e}")
        return False
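

# Optional: verify that the exported ONNX file actually loads and runs. This
# is a sketch that assumes onnxruntime is installed; it is not called by main().
def verify_onnx_model(onnx_path: Path, seq_len: int = 64) -> bool:
    """Load the ONNX model with onnxruntime and run one dummy batch."""
    try:
        import numpy as np
        import onnxruntime as ort

        session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
        dummy = np.random.randint(1, 1000, size=(1, seq_len)).astype(np.int64)
        (logits,) = session.run(None, {"input_ids": dummy})
        logger.info(f"ONNX check passed, logits shape: {logits.shape}")
        return True
    except Exception as e:
        logger.error(f"ONNX verification failed: {e}")
        return False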
def main():
    parser = argparse.ArgumentParser(description="Export Qwen 2.5 using TorchScript tracing")
    parser.add_argument("--model-path", type=str, required=True, help="Model path or HF model name")
    parser.add_argument("--output-dir", type=Path, required=True, help="Output directory")
    parser.add_argument("--seq-len", type=int, default=64, help="Sequence length for tracing")
    parser.add_argument("--export-onnx", action="store_true", help="Also export to ONNX")
    args = parser.parse_args()

    # Create the output directory.
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Export the traced model, then optionally convert it to ONNX.
    traced_path = args.output_dir / "model_traced.pt"
    success = export_traced_model(args.model_path, traced_path, args.seq_len)

    if success and args.export_onnx:
        onnx_path = args.output_dir / "model.onnx"
        convert_traced_to_onnx(traced_path, onnx_path, args.seq_len)


if __name__ == "__main__":
    main()
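
# Example invocation (hypothetical script name, model, and paths):
#
#   python export_traced.py \
#       --model-path Qwen/Qwen2.5-0.5B-Instruct \
#       --output-dir exports/qwen2.5-0.5b \
#       --seq-len 64 \
#       --export-onnx
#
# This writes model_traced.pt and, with --export-onnx, model.onnx into the
# output directory.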