"""
Export Qwen 2.5 models using TorchScript tracing instead of direct ONNX export.

This approach may work better for modern transformer architectures; the traced
module can optionally be converted to ONNX afterwards.
"""
import argparse
import logging
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class TracedQwenModel(nn.Module):
    """Wrapper that makes the Qwen model friendly to torch.jit.trace.

    Tracing records a single tensor-in/tensor-out graph, so forward() disables
    the KV cache and the dict-style output, neither of which is a plain tensor.
    """

    def __init__(self, original_model):
        super().__init__()
        self.model = original_model

    def forward(self, input_ids):
        with torch.no_grad():
            # use_cache=False avoids past-key-value tuples in the output;
            # return_dict=False makes the model return a tuple whose first
            # element is the logits tensor.
            outputs = self.model(input_ids=input_ids, use_cache=False, return_dict=False)
            return outputs[0]


def export_traced_model(model_path: str, output_path: Path, seq_len: int = 64) -> bool:
    """Export the model using TorchScript tracing."""
    logger.info(f"Loading model from {model_path}")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    traced_model = TracedQwenModel(model)
    traced_model.eval()

    # Tracing only needs a representative input; the token values are arbitrary
    # as long as they are valid vocabulary ids (0 is skipped in case it is a
    # special token such as padding).
    sample_input = torch.randint(1, min(tokenizer.vocab_size, 1000), (1, seq_len), dtype=torch.long)

    logger.info("Creating TorchScript traced model...")
    try:
        traced_script = torch.jit.trace(traced_model, sample_input)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        torch.jit.save(traced_script, str(output_path))
        logger.info(f"✅ Traced model saved to: {output_path}")

        # Sanity check: the traced graph should reproduce the wrapper's output
        # on the same input; any difference should be at numerical-noise level.
        logger.info("Testing traced model...")
        with torch.no_grad():
            original_output = traced_model(sample_input)
            traced_output = traced_script(sample_input)
            diff = torch.abs(original_output - traced_output).max().item()
            logger.info(f"Max difference between original and traced: {diff}")

        return True

    except Exception as e:
        logger.error(f"Tracing failed: {e}")
        return False
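

# Consuming the traced artifact later needs only torch, not transformers.
# A minimal sketch (hypothetical snippet, not run by this script); note that
# the trace was recorded with a fixed (1, seq_len) input shape, so inputs of
# that same shape are the safest:
#
#   model = torch.jit.load("model_traced.pt")
#   ids = torch.randint(1, 1000, (1, 64), dtype=torch.long)
#   logits = model(ids)  # shape: (1, 64, vocab_size)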


def convert_traced_to_onnx(traced_path: Path, onnx_path: Path, seq_len: int = 64) -> bool:
    """Convert the traced model to ONNX."""
    logger.info(f"Converting traced model to ONNX: {onnx_path}")

    try:
        traced_model = torch.jit.load(str(traced_path))
        traced_model.eval()

        # The dummy input must match the shape the model was traced with.
        sample_input = torch.randint(1, 1000, (1, seq_len), dtype=torch.long)

        torch.onnx.export(
            traced_model,
            sample_input,
            str(onnx_path),
            input_names=["input_ids"],
            output_names=["logits"],
            # Opset 14+ is required for aten::scaled_dot_product_attention,
            # which modern transformer attention implementations may emit.
            opset_version=14,
            do_constant_folding=False,
            # Static shapes: the exported graph only accepts (1, seq_len) inputs.
            dynamic_axes=None,
        )

        logger.info(f"✅ ONNX model saved to: {onnx_path}")
        return True

    except Exception as e:
        logger.error(f"ONNX conversion failed: {e}")
        return False
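

# Optional sanity check for the exported ONNX file. A minimal sketch, not
# invoked by main(): it assumes the optional `onnxruntime` package is
# installed, and reuses the input/output names passed to torch.onnx.export.
def verify_onnx_model(onnx_path: Path, seq_len: int = 64) -> bool:
    """Run the exported ONNX model once and log the logits shape."""
    try:
        import onnxruntime as ort

        session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
        sample = torch.randint(1, 1000, (1, seq_len), dtype=torch.long).numpy()
        (logits,) = session.run(["logits"], {"input_ids": sample})
        logger.info(f"ONNX check passed, logits shape: {logits.shape}")
        return True
    except Exception as e:
        logger.error(f"ONNX verification failed: {e}")
        return False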


def main():
    parser = argparse.ArgumentParser(description="Export Qwen 2.5 using TorchScript tracing")
    parser.add_argument("--model-path", type=str, required=True, help="Model path or HF model name")
    parser.add_argument("--output-dir", type=Path, required=True, help="Output directory")
    parser.add_argument("--seq-len", type=int, default=64, help="Sequence length for tracing")
    parser.add_argument("--export-onnx", action="store_true", help="Also export to ONNX")
    args = parser.parse_args()

    args.output_dir.mkdir(parents=True, exist_ok=True)

    traced_path = args.output_dir / "model_traced.pt"
    success = export_traced_model(args.model_path, traced_path, args.seq_len)

    if success and args.export_onnx:
        onnx_path = args.output_dir / "model.onnx"
        convert_traced_to_onnx(traced_path, onnx_path, args.seq_len)


if __name__ == "__main__":
    main()
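

# Example invocation (hypothetical script name; Qwen/Qwen2.5-0.5B is one of
# the smaller Qwen 2.5 checkpoints on the Hugging Face Hub):
#
#   python export_traced_qwen.py \
#       --model-path Qwen/Qwen2.5-0.5B \
#       --output-dir exported/ \
#       --export-onnx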