""" |
|
Convert Qwen 2.5 models to ONNX format for QNN compatibility |
|
""" |
import argparse
import gc
import json
import logging
import sys
import warnings
from pathlib import Path
from typing import Dict

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import onnx

warnings.filterwarnings("ignore", category=UserWarning)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class SimpleQwenModel(nn.Module):
    """Simple wrapper for Qwen model that avoids cache-related issues."""

    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def forward(self, input_ids):
        # Run without the KV cache and with tuple outputs so the traced graph
        # stays free of cache objects the ONNX exporter cannot handle.
        try:
            outputs = self.original_model(
                input_ids=input_ids,
                use_cache=False,
                return_dict=False,
            )
            if isinstance(outputs, tuple):
                return outputs[0]
            return outputs.logits
        except Exception as e:
            # Fall back to a plain forward pass if the model rejects the
            # keyword arguments above.
            logger.warning(f"Falling back to default forward pass: {e}")
            with torch.no_grad():
                outputs = self.original_model(input_ids)
                if hasattr(outputs, "logits"):
                    return outputs.logits
                return outputs[0] if isinstance(outputs, tuple) else outputs


class QwenONNXExporter:
    """ONNX exporter for Qwen 2.5 models with QNN-specific optimizations."""

    def __init__(self, model_path: Path):
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        self.config = None
        self.wrapped_model = None

    def load_model(self):
        """Load the Qwen model and tokenizer."""
        logger.info(f"Loading model from {self.model_path}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.config = AutoConfig.from_pretrained(self.model_path)

            # Load on CPU in fp16 with low memory usage; the export also runs on CPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )

            self.model.eval()
            logger.info("Model loaded successfully")

            # Release loading scratch memory before export.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Memory cleanup completed")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def prepare_model_for_onnx(self):
        """Prepare model for ONNX export by fixing dynamic shapes."""
        logger.info("Preparing model for ONNX export...")

        try:
            # Pin the generation length so no dynamic sequence logic leaks into the graph.
            if hasattr(self.model, "generation_config"):
                self.model.generation_config.max_length = 2048

            # Freeze all parameters; export is inference-only.
            for param in self.model.parameters():
                param.requires_grad = False

            self.model.eval()

            # Wrap the model so forward() returns plain logits without cache objects.
            self.wrapped_model = SimpleQwenModel(self.model)
            self.wrapped_model.eval()

            logger.info("Model preparation completed")

        except Exception as e:
            logger.warning(f"Model preparation encountered issues: {e}")
            logger.info("Continuing with basic model preparation")

    def _fix_attention_patterns(self):
        """Placeholder for rewriting attention patterns that are not ONNX/QNN friendly."""
        # No attention rewriting is currently applied.
        pass

    def create_sample_inputs(self, batch_size: int = 1, seq_len: int = 128) -> Dict:
        """Create sample inputs for ONNX export."""
        logger.info(
            f"Creating sample inputs (batch_size={batch_size}, seq_len={seq_len})"
        )

        # Clamp the sampling range so random token IDs stay well inside the vocabulary.
        vocab_size = min(self.tokenizer.vocab_size, 32000)
        input_ids = torch.randint(
            1, vocab_size - 1, (batch_size, seq_len), dtype=torch.long
        )

        # Only input_ids are exported; attention masks and cache tensors are omitted.
        inputs = {"input_ids": input_ids}

        return inputs

    def export_to_onnx(
        self,
        output_path: Path,
        batch_size: int = 1,
        seq_len: int = 128,
        opset_version: int = 17,
        optimize: bool = True,
    ) -> Dict:
        """Export model to ONNX format."""
        logger.info(f"Exporting to ONNX: {output_path}")
        logger.info(f"ONNX opset version: {opset_version}")

        try:
            sample_inputs = self.create_sample_inputs(batch_size, seq_len)

            input_names = list(sample_inputs.keys())
            output_names = ["logits"]

            logger.info("Using legacy ONNX exporter for better compatibility")

            # Export with input_ids only; shapes are fixed at (batch_size, seq_len).
            simplified_inputs = {
                "input_ids": sample_inputs["input_ids"],
            }

            with torch.no_grad():
                torch.onnx.export(
                    self.wrapped_model,
                    tuple(simplified_inputs.values()),
                    str(output_path),
                    input_names=input_names,
                    output_names=output_names,
                    opset_version=opset_version,
                    do_constant_folding=False,
                    verbose=True,
                    training=torch.onnx.TrainingMode.EVAL,
                    export_params=True,
                    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
                )

            # Sanity-check the exported graph.
            onnx_model = onnx.load(str(output_path))
            onnx.checker.check_model(onnx_model)

            logger.info("ONNX export successful")

            if optimize:
                logger.info("Optimizing ONNX model...")
                onnx_model = self._optimize_onnx_model(onnx_model)
                onnx.save(onnx_model, str(output_path))
                logger.info("ONNX optimization completed")

            export_info = {
                "model_path": str(self.model_path),
                "onnx_path": str(output_path),
                "batch_size": batch_size,
                "sequence_length": seq_len,
                "opset_version": opset_version,
                "input_names": input_names,
                "output_names": output_names,
                "model_size_mb": output_path.stat().st_size / (1024 * 1024),
                "vocab_size": self.tokenizer.vocab_size,
                "hidden_size": self.config.hidden_size,
                "num_layers": self.config.num_hidden_layers,
                "num_heads": self.config.num_attention_heads,
            }

            return export_info

        except Exception as e:
            logger.error(f"ONNX export failed: {e}")
            raise

    def _optimize_onnx_model(self, onnx_model):
        """Apply ONNX optimizations for better QNN compatibility.

        Currently a pass-through: the model is returned unchanged.
        """
        try:
            logger.info("Using basic ONNX model without additional optimizations")
            return onnx_model
        except Exception as e:
            logger.warning(f"ONNX optimization failed: {e}, using original model")
            return onnx_model
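
    # A possible extension (a sketch, not part of the current export flow): run
    # onnx-simplifier over the exported graph inside _optimize_onnx_model before
    # handing it to the QNN tools. This assumes the optional `onnxsim` package is
    # installed, so it is shown here commented out:
    #
    #   from onnxsim import simplify
    #   simplified_model, check_ok = simplify(onnx_model)
    #   if check_ok:
    #       onnx_model = simplified_model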
    def validate_onnx_export(self, onnx_path: Path, sample_inputs: Dict) -> bool:
        """Validate the exported ONNX model against the PyTorch reference."""
        logger.info("Validating ONNX export...")

        try:
            # Imported lazily so onnxruntime is only required when validating.
            import onnxruntime as ort

            ort_session = ort.InferenceSession(str(onnx_path))

            ort_inputs = {
                name: tensor.numpy() for name, tensor in sample_inputs.items()
            }
            ort_outputs = ort_session.run(None, ort_inputs)

            # Reference logits from the original PyTorch model.
            with torch.no_grad():
                torch_outputs = self.model(**sample_inputs)
                torch_logits = torch_outputs.logits.numpy()

            max_diff = abs(ort_outputs[0] - torch_logits).max()

            if max_diff < 1e-3:
                logger.info(f"ONNX validation successful (max_diff: {max_diff:.6f})")
                return True
            else:
                # A larger difference is logged but not treated as a failure,
                # since fp16 export can drift numerically.
                logger.warning(
                    f"ONNX validation warning: max difference = {max_diff:.6f}"
                )
                return True

        except Exception as e:
            logger.error(f"ONNX validation failed: {e}")
            return False


def main():
    parser = argparse.ArgumentParser(description="Convert Qwen 2.5 to ONNX")
    parser.add_argument(
        "--model-path", type=Path, required=True, help="Path to Qwen model"
    )
    parser.add_argument(
        "--output-path", type=Path, required=True, help="Output ONNX file path"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, help="Batch size for export"
    )
    parser.add_argument(
        "--seq-len", type=int, default=512, help="Sequence length for export"
    )
    parser.add_argument(
        "--opset-version", type=int, default=17, help="ONNX opset version"
    )
    parser.add_argument(
        "--no-optimize", action="store_true", help="Skip ONNX optimization"
    )
    parser.add_argument(
        "--optimize-for-mobile",
        action="store_true",
        help="Optimize for mobile deployment",
    )

    args = parser.parse_args()

    if not args.model_path.exists():
        logger.error(f"Model path does not exist: {args.model_path}")
        sys.exit(1)

    args.output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        exporter = QwenONNXExporter(args.model_path)

        exporter.load_model()
        exporter.prepare_model_for_onnx()

        export_info = exporter.export_to_onnx(
            output_path=args.output_path,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            opset_version=args.opset_version,
            optimize=not args.no_optimize,
        )
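
        # Optional sanity check (a sketch, not part of the original flow): compare
        # ONNX Runtime and PyTorch outputs on the same fixed-shape sample inputs.
        #
        #   sample_inputs = exporter.create_sample_inputs(args.batch_size, args.seq_len)
        #   exporter.validate_onnx_export(args.output_path, sample_inputs)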
        info_path = (
            args.output_path.parent / f"{args.output_path.stem}_export_info.json"
        )
        with open(info_path, "w") as f:
            json.dump(export_info, f, indent=2)

        logger.info("Export completed successfully!")
        logger.info(f"ONNX model: {args.output_path}")
        logger.info(f"Model size: {export_info['model_size_mb']:.1f} MB")
        logger.info(f"Export info: {info_path}")

    except Exception as e:
        logger.error(f"Conversion failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()