File size: 4,617 Bytes
b2cacbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
"""
Export Qwen 2.5 models with TorchScript tracing rather than a direct ONNX export.
Tracing may cope better with modern transformer architectures; the saved trace
can optionally be converted to ONNX afterwards (--export-onnx).
"""

import argparse
import torch
import torch.nn as nn
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class TracedQwenModel(nn.Module):
    """Trace-friendly wrapper around a causal LM that exposes only the logits.

    ``forward`` takes a single ``input_ids`` tensor, disables the KV cache and
    asks the wrapped model for tuple outputs, so ``torch.jit.trace`` sees a
    plain tensor-in / tensor-out callable.
    """

    def __init__(self, original_model):
        super().__init__()
        # Assigned as an attribute of an nn.Module, so it is registered as a
        # submodule and its parameters are owned by the wrapper.
        self.model = original_model

    def forward(self, input_ids):
        """Run the wrapped model without autograd and return its first output."""
        with torch.no_grad():
            model_outputs = self.model(
                input_ids=input_ids,
                use_cache=False,
                return_dict=False,
            )
        # With return_dict=False the outputs are a tuple; element 0 is the logits.
        return model_outputs[0]

def export_traced_model(model_path: str, output_path: Path, seq_len: int = 64) -> bool:
    """Load a causal LM on CPU, trace it with TorchScript and save the trace.

    Args:
        model_path: Local directory or Hugging Face Hub id, passed to
            ``from_pretrained`` for both the model and the tokenizer.
        output_path: File the traced module is written to with
            ``torch.jit.save``; missing parent directories are created.
        seq_len: Length of the dummy ``input_ids`` batch of shape
            ``(1, seq_len)`` used for tracing. Tracing records a single
            execution, so the saved graph may be specialized to this shape.

    Returns:
        True if tracing, saving and the post-save comparison all completed;
        False if any of those steps raised (the traceback is logged).
        Exceptions raised while loading the model or tokenizer are not caught.
    """
    logger.info(f"Loading model from {model_path}")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,  # float32 for broad operator support
        device_map="cpu",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        # The Hugging Face TorchScript export guide requires this flag: the
        # TorchScript format cannot export tied weights (e.g. input embeddings
        # shared with the LM head), and torchscript=True makes the model clone
        # them instead of sharing them.
        torchscript=True,
    )
    model.eval()

    # Only used to bound the dummy token ids by the tokenizer's vocabulary.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    traced_model = TracedQwenModel(model)
    traced_model.eval()

    # Dummy token ids are drawn from [1, min(vocab_size, 1000)).
    id_upper = min(tokenizer.vocab_size, 1000)
    sample_input = torch.randint(1, id_upper, (1, seq_len), dtype=torch.long)
    # Same shape, different values. Comparing on the tracing input alone cannot
    # reveal value-dependent behavior that the tracer froze into the graph.
    check_input = torch.randint(1, id_upper, (1, seq_len), dtype=torch.long)

    logger.info("Creating TorchScript traced model...")
    try:
        traced_script = torch.jit.trace(traced_model, sample_input)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        torch.jit.save(traced_script, str(output_path))
        logger.info(f"✅ Traced model saved to: {output_path}")

        logger.info("Testing traced model...")
        with torch.no_grad():
            original_output = traced_model(check_input)
            traced_output = traced_script(check_input)
        diff = torch.abs(original_output - traced_output).max().item()
        logger.info(f"Max difference between original and traced: {diff}")
    except Exception as e:
        # logger.exception also records the traceback, which is essential
        # when diagnosing tracing failures deep inside the model.
        logger.exception(f"Tracing failed: {e}")
        return False

    return True

def convert_traced_to_onnx(
    traced_path: Path,
    onnx_path: Path,
    seq_len: int = 64,
    opset_version: int = 11,
) -> bool:
    """Convert a saved TorchScript trace into a fixed-shape ONNX file.

    Args:
        traced_path: File previously written by ``torch.jit.save``.
        onnx_path: Destination ``.onnx`` file; missing parent directories are
            created.
        seq_len: Sequence length of the dummy ``input_ids`` of shape
            ``(1, seq_len)``. No dynamic axes are exported, so the ONNX graph is
            fixed to this shape. NOTE(review): presumably it must match the
            length used when tracing; ``main`` passes the same value.
        opset_version: ONNX opset to target. Defaults to 11, the previously
            hard-coded value. Newer attention operators (e.g.
            ``scaled_dot_product_attention``) may need a higher opset.

    Returns:
        True if the ONNX file was written, False if loading the trace or the
        export raised (the traceback is logged).
    """
    logger.info(f"Converting traced model to ONNX: {onnx_path}")

    try:
        traced_model = torch.jit.load(str(traced_path))
        traced_model.eval()

        # Dummy token ids in [1, 1000); assumes the vocabulary has >= 1000 ids.
        sample_input = torch.randint(1, 1000, (1, seq_len), dtype=torch.long)

        onnx_path.parent.mkdir(parents=True, exist_ok=True)
        torch.onnx.export(
            traced_model,
            sample_input,
            str(onnx_path),
            input_names=["input_ids"],
            output_names=["logits"],
            opset_version=opset_version,
            do_constant_folding=False,
            dynamic_axes=None,  # Fixed shapes
        )
    except Exception as e:
        # Keep the traceback; exporter errors are otherwise hard to locate.
        logger.exception(f"ONNX conversion failed: {e}")
        return False

    logger.info(f"✅ ONNX model saved to: {onnx_path}")
    return True

def main() -> int:
    """Command-line entry point: trace the model and optionally export ONNX.

    Returns:
        Process exit status: 0 when every requested step succeeded, 1 when
        tracing or the requested ONNX conversion failed. Invalid arguments
        make argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(description="Export Qwen 2.5 using TorchScript tracing")
    parser.add_argument("--model-path", type=str, required=True, help="Model path or HF model name")
    parser.add_argument("--output-dir", type=Path, required=True, help="Output directory")
    parser.add_argument("--seq-len", type=int, default=64, help="Sequence length for tracing")
    parser.add_argument("--export-onnx", action="store_true", help="Also export to ONNX")

    args = parser.parse_args()
    # A non-positive length would only fail later, deep inside model tracing.
    if args.seq_len <= 0:
        parser.error("--seq-len must be a positive integer")

    args.output_dir.mkdir(parents=True, exist_ok=True)

    traced_path = args.output_dir / "model_traced.pt"
    if not export_traced_model(args.model_path, traced_path, args.seq_len):
        return 1

    if args.export_onnx:
        onnx_path = args.output_dir / "model.onnx"
        if not convert_traced_to_onnx(traced_path, onnx_path, args.seq_len):
            return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell scripts / CI can detect failures.
    raise SystemExit(main())