import torch
from transformers import AutoTokenizer, AutoConfig, DistilBertForQuestionAnswering  # Correct import
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import logging
from typing import Optional, Dict, Any

class ONNXModelConverter:
    """Export a DistilBERT question-answering model to ONNX and quantize it.

    The constructor loads the tokenizer, config and model for ``model_name``
    and creates ``output_dir``. :meth:`convert` then exports the model to
    ``<output_dir>/model.onnx``, runs the ONNX checker, writes one dynamically
    quantized copy per weight type and saves the tokenizer to
    ``<output_dir>/tokenizer``.
    """

    def __init__(self, model_name: str, output_dir: str, trust_remote_code: bool = True):
        """Load tokenizer, config and model, and create ``output_dir``.

        Args:
            model_name: Hugging Face Hub id or local path of the model.
            output_dir: Directory receiving the ONNX files and the tokenizer.
            trust_remote_code: Forwarded to every ``from_pretrained`` call.
                When True, custom Python code shipped with the model
                repository may be executed; pass False for repositories you
                do not trust.

        Raises:
            Exception: Whatever ``from_pretrained`` raises when neither the
                PyTorch nor the TensorFlow weights can be loaded.
        """
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading tokenizer {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)

        self.logger.info(f"Loading model config {model_name}...")
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)

        self.logger.info(f"Loading model {model_name}...")
        try:
            self.model = DistilBertForQuestionAnswering.from_pretrained(
                model_name,
                config=config,
                trust_remote_code=trust_remote_code,
                torch_dtype=torch.float32,  # export in full precision; quantization happens later
            )
        except Exception as pt_error:
            # Commonly hit when the repository ships only TensorFlow weights.
            # Keep the reason visible instead of silently discarding it.
            self.logger.warning(f"Could not load PyTorch weights ({pt_error}); trying TensorFlow weights")
            try:
                self.model = DistilBertForQuestionAnswering.from_pretrained(
                    model_name,
                    config=config,
                    trust_remote_code=trust_remote_code,
                    from_tf=True,  # convert the TensorFlow checkpoint on load
                )
            except Exception as e:
                self.logger.error(f"Failed to load the model: {e}")
                raise

        # Inference mode (disables dropout) so the exported graph is deterministic.
        self.model.eval()

    def setup_logging(self):
        """Set ``self.logger`` to this module's logger and ensure it has a console handler.

        ``logging.getLogger(__name__)`` returns the same object for every
        instance, so the handler is only added when the logger has none.
        Otherwise each additional converter would duplicate every log line.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
            self.logger.addHandler(handler)

    def prepare_dummy_inputs(self) -> Dict[str, Any]:
        """Tokenize a fixed sample sentence to use as tracing input for the export.

        Returns:
            Dict with the PyTorch tensors ``'input_ids'`` and
            ``'attention_mask'`` produced by the tokenizer (at most 128 tokens).
            ``token_type_ids`` is dropped because the exported wrapper does not
            take it.
        """
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        dummy_input.pop('token_type_ids', None)
        return {
            'input_ids': dummy_input['input_ids'],
            'attention_mask': dummy_input['attention_mask'],
        }

    def export_to_onnx(self) -> str:
        """Export the model to ``<output_dir>/model.onnx`` (opset 14).

        Batch and sequence dimensions of all inputs and outputs are declared
        dynamic, so the exported graph is not tied to the dummy input's shape.

        Returns:
            Path of the written ONNX file.

        Raises:
            Exception: Re-raised from ``torch.onnx.export`` after logging.
        """
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'start_logits': {0: 'batch_size', 1: 'sequence_length'},
            'end_logits': {0: 'batch_size', 1: 'sequence_length'},
        }

        class ModelWrapper(torch.nn.Module):
            """Expose the QA model as a plain tensor-in / tuple-out module.

            The ONNX exporter needs positional inputs and a tuple of outputs,
            not the model's keyword arguments and output object.
            """

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                return outputs.start_logits, outputs.end_logits

        wrapped_model = ModelWrapper(self.model)

        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask']),
                output_path,
                export_params=True,
                opset_version=14,
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask'],
                output_names=['start_logits', 'end_logits'],
                dynamic_axes=dynamic_axes,
                verbose=False
            )
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise
        self.logger.info(f"Model exported to {output_path}")
        return output_path

    def verify_model(self, model_path: str) -> bool:
        """Run the ONNX checker on ``model_path``.

        Returns:
            True if the checker passes; False (after logging the reason) if
            it raises.
        """
        try:
            # Checking by path rather than a loaded ModelProto also works for
            # models larger than the 2 GB protobuf limit.
            onnx.checker.check_model(model_path)
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False
        self.logger.info("ONNX model verification successful")
        return True

    def quantize_model(self, model_path: str, weight_types: Optional[list] = None) -> list:
        """Dynamically quantize ``model_path`` once per requested weight type.

        Args:
            model_path: Path of the float ONNX model to quantize.
            weight_types: Names of the formats to produce, a subset of
                ``'int4', 'int8', 'uint4', 'uint8', 'uint16', 'int16'``.
                ``None`` (default) requests all of them. In that case, formats
                whose ``QuantType`` member is missing from the installed
                onnxruntime are skipped with a warning.

        Returns:
            Paths ``<output_dir>/model_<type>.onnx`` that were written, in
            request order.

        Raises:
            ValueError: If an unknown format is requested, or an explicitly
                requested format is unsupported by the installed onnxruntime.
            Exception: Re-raised from ``quantize_dynamic`` on the first
                failing format.
        """
        # Format name -> QuantType member name. The members are resolved
        # lazily because the 4- and 16-bit ones exist only in newer
        # onnxruntime releases.
        member_names = {
            'int4': 'QInt4',
            'int8': 'QInt8',
            'uint4': 'QUInt4',
            'uint8': 'QUInt8',
            'uint16': 'QUInt16',
            'int16': 'QInt16',
        }
        requested = list(member_names) if weight_types is None else list(weight_types)
        unknown = [name for name in requested if name not in member_names]
        if unknown:
            raise ValueError(f"Unknown weight type(s) {unknown}; expected a subset of {list(member_names)}")

        all_quantized_paths = []
        for name in requested:
            quant_type = getattr(QuantType, member_names[name], None)
            if quant_type is None:
                message = f"QuantType.{member_names[name]} is not available in the installed onnxruntime"
                if weight_types is not None:
                    raise ValueError(message)
                self.logger.warning(f"{message}; skipping {name}")
                continue

            quantized_path = os.path.join(self.output_dir, f"model_{name}.onnx")
            try:
                quantize_dynamic(model_path, quantized_path, weight_type=quant_type)
            except Exception as e:
                self.logger.error(f"Quantization ({name}) failed: {str(e)}")
                raise
            self.logger.info(f"Model quantized ({name}) and saved to {quantized_path}")
            all_quantized_paths.append(quantized_path)

        return all_quantized_paths

    def convert(self) -> Dict[str, Any]:
        """Run the full pipeline: export, verify, quantize, then save the tokenizer.

        Returns:
            Dict with ``'onnx_model'`` (path of the float model),
            ``'quantized_models'`` (list of quantized model paths) and
            ``'tokenizer'`` (tokenizer directory).

        Raises:
            RuntimeError: If the exported model fails the ONNX checker.
            Exception: Any error from the export, quantization or save steps.
        """
        try:
            onnx_path = self.export_to_onnx()
            if not self.verify_model(onnx_path):
                raise RuntimeError("Model verification failed")

            quantized_paths = self.quantize_model(onnx_path)

            tokenizer_path = os.path.join(self.output_dir, "tokenizer")
            self.tokenizer.save_pretrained(tokenizer_path)
            self.logger.info(f"Tokenizer saved to {tokenizer_path}")
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise

        return {
            'onnx_model': onnx_path,
            'quantized_models': quantized_paths,
            'tokenizer': tokenizer_path
        }

if __name__ == "__main__":
    # Script entry point: convert a Hub QA model and write all artifacts to ./onnx.
    MODEL_NAME = "Docty/question_and_answer"  # Or any other suitable model
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()
    except Exception as e:
        print(f"Conversion failed: {str(e)}")
        # Exit non-zero so shells and CI pipelines can detect the failure.
        raise SystemExit(1) from None

    print("\nConversion completed successfully!")
    print(f"ONNX model path: {results['onnx_model']}")
    print(f"Quantized model paths: {results['quantized_models']}")
    print(f"Tokenizer path: {results['tokenizer']}")