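"""Run inference with a model whose TEQ-quantized checkpoint was saved as
quantized_weight.pt plus qconfig.json.

Example invocation (the script name and model id below are illustrative only;
substitute your own):

    python run_quantized_inference.py \
        --base_model facebook/opt-125m \
        --model_dir ./saved_results \
        --prompt "Once upon a time, a little girl"
"""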
import os
import json
import torch
import argparse
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)

def debug_weights_structure(weights_path):
    """Examine the structure of the weights file to help debug loading issues"""
    weights = torch.load(weights_path, map_location="cpu")
    logger.info(f"Type of loaded weights: {type(weights)}")
    if isinstance(weights, dict):
        logger.info(f"Top-level keys: {list(weights.keys())}")
        # Print a few sample keys to understand the structure
        sample_keys = list(weights.keys())[:5]
        for key in sample_keys:
            logger.info(f"Sample key structure: {key} -> {type(weights[key])}")
    return weights

def main():
    parser = argparse.ArgumentParser(description="Run inference with a TEQ-quantized model")
    parser.add_argument("--model_dir", type=str, default=".",
                        help="Directory containing quantized model files")
    parser.add_argument("--weights_file", type=str, default="quantized_weight.pt",
                        help="Name of the quantized weights file")
    parser.add_argument("--config_file", type=str, default="qconfig.json",
                        help="Name of the quantization config file")
    parser.add_argument("--base_model", type=str, required=True,
                        help="Original model name or path (for tokenizer and model architecture)")
    parser.add_argument("--prompt", type=str, default="Once upon a time, a little girl",
                        help="Text prompt for inference")
    parser.add_argument("--max_new_tokens", type=int, default=100,
                        help="Maximum number of new tokens to generate")
    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda", "xpu"],
                        help="Device to run inference on")
    parser.add_argument("--output_file", type=str, default=None,
                        help="File to save the generated text to (optional)")
    parser.add_argument("--debug", action="store_true",
                        help="Print additional debug information")
    args = parser.parse_args()

    # Set up paths
    weights_path = os.path.join(args.model_dir, args.weights_file)
    config_path = os.path.join(args.model_dir, args.config_file)
    
    # Check if files exist
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Quantized weights file not found: {weights_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Quantization config file not found: {config_path}")
    
    # Load tokenizer
    logger.info(f"Loading tokenizer from {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    
    # Examine the structure of the weights file
    logger.info(f"Analyzing weights structure from {weights_path}...")
    weights = debug_weights_structure(weights_path)
    
    # Load the base model directly (bypassing TEQ quantization)
    logger.info(f"Loading base model from {args.base_model}...")
    model = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)
    
    # Print model's state_dict keys for debugging
    if args.debug:
        model_keys = list(model.state_dict().keys())
        logger.info(f"Model has {len(model_keys)} keys in state_dict")
        logger.info(f"Sample model keys: {model_keys[:5]}")
    
    # Some checkpoints wrap the tensors in a 'state_dict' entry; unwrap it if present
    if isinstance(weights, dict) and 'state_dict' in weights:
        logger.info("Found 'state_dict' key in weights file, extracting it...")
        weights = weights['state_dict']
    
    # Try to match the weights to the model structure. Checkpoints saved from a
    # DataParallel/DDP-wrapped model carry a 'module.' prefix that a plain model's
    # state_dict does not, so strip it up front. Note that load_state_dict with
    # strict=False does not raise on key-name mismatches; it reports them via the
    # returned missing/unexpected key lists, which are checked below.
    if isinstance(weights, dict):
        weights = {
            (key[len("module."):] if key.startswith("module.") else key): value
            for key, value in weights.items()
        }

    logger.info("Loading weights into the model (strict=False)...")
    try:
        missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
    except RuntimeError as e:
        # Typically a tensor shape mismatch, which key renaming cannot fix
        logger.error(f"Error loading weights: {e}")
        raise

    if missing_keys:
        logger.warning(f"Missing {len(missing_keys)} keys in state_dict")
        if args.debug:
            logger.warning(f"Sample missing keys: {missing_keys[:5]}")

    if unexpected_keys:
        logger.warning(f"Found {len(unexpected_keys)} unexpected keys in state_dict")
        if args.debug:
            logger.warning(f"Sample unexpected keys: {unexpected_keys[:5]}")

    # Flag a likely loading failure if most of the model's parameters were not matched
    if len(missing_keys) > len(model.state_dict()) * 0.5:
        logger.error("Too many missing keys! Weight loading may have failed")
    
    # Put model in evaluation mode
    model.eval()
    
    # Move model to specified device
    device = args.device
    logger.info(f"Moving model to {device}...")
    model = model.to(device)
    
    # Optimize with IPEX if using Intel hardware
    if device == "xpu":
        try:
            import intel_extension_for_pytorch as ipex
            logger.info("Optimizing model with IPEX...")
            model = ipex.optimize(model, dtype=torch.float16)
        except ImportError:
            logger.warning("IPEX not available, skipping optimization")
    
    # Run inference
    logger.info(f"Generating text for prompt: '{args.prompt}'")
    inputs = tokenizer(args.prompt, return_tensors="pt").to(device)
    
    # Generate text (pass the attention mask along with the input ids so padding
    # is handled correctly and no warning is emitted)
    with torch.no_grad():
        output_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    
    # Decode the generated tokens
    generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    logger.info("\nGenerated text:")
    logger.info("-" * 50)
    logger.info(generated_text)
    logger.info("-" * 50)
    
    # Save to file if specified
    if args.output_file:
        with open(args.output_file, "w", encoding="utf-8") as f:
            f.write(generated_text)
        logger.info(f"Generated text saved to {args.output_file}")

if __name__ == "__main__":
    main()