Spaces:
Running
Running
workaround for quantization and push
Browse files- QUANTIZATION_FIX_SUMMARY.md +165 -0
- requirements_quantization.txt +17 -0
- scripts/model_tonic/quantize_model.py +154 -22
- test_quantization_fix.py +149 -0
QUANTIZATION_FIX_SUMMARY.md
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Quantization Fix Summary
|
2 |
+
|
3 |
+
## Issues Identified
|
4 |
+
|
5 |
+
The quantization script was failing due to several compatibility issues:
|
6 |
+
|
7 |
+
1. **Int8 Quantization Error**:
|
8 |
+
- Error: `The model is quantized with QuantizationMethod.TORCHAO and is not serializable`
|
9 |
+
- Cause: Offloaded modules in the model cannot be quantized with torchao
|
10 |
+
- Solution: Added alternative save method and fallback to bitsandbytes
|
11 |
+
|
12 |
+
2. **Int4 Quantization Error**:
|
13 |
+
- Error: `Could not run 'aten::_convert_weight_to_int4pack_for_cpu' with arguments from the 'CUDA' backend`
|
14 |
+
- Cause: Int4 quantization requires CPU backend but was being attempted on CUDA
|
15 |
+
- Solution: Added proper device selection logic
|
16 |
+
|
17 |
+
3. **Monitoring Error**:
|
18 |
+
- Error: `'SmolLM3Monitor' object has no attribute 'log_event'`
|
19 |
+
- Cause: Incorrect monitoring API usage
|
20 |
+
- Solution: Added flexible monitoring method detection
|
21 |
+
|
22 |
+
## Fixes Implemented
|
23 |
+
|
24 |
+
### 1. Enhanced Device Management (`scripts/model_tonic/quantize_model.py`)
|
25 |
+
|
26 |
+
```python
|
27 |
+
def get_optimal_device(self, quant_type: str) -> str:
|
28 |
+
"""Get optimal device for quantization type"""
|
29 |
+
if quant_type == "int4_weight_only":
|
30 |
+
# Int4 quantization works better on CPU
|
31 |
+
return "cpu"
|
32 |
+
elif quant_type == "int8_weight_only":
|
33 |
+
# Int8 quantization works on GPU
|
34 |
+
if torch.cuda.is_available():
|
35 |
+
return "cuda"
|
36 |
+
else:
|
37 |
+
logger.warning("β οΈ CUDA not available, falling back to CPU for int8")
|
38 |
+
return "cpu"
|
39 |
+
else:
|
40 |
+
return "auto"
|
41 |
+
```
|
42 |
+
|
43 |
+
### 2. Alternative Quantization Method
|
44 |
+
|
45 |
+
Added `quantize_model_alternative()` method using bitsandbytes for better compatibility:
|
46 |
+
|
47 |
+
```python
|
48 |
+
def quantize_model_alternative(self, quant_type: str, device: str = "auto", group_size: int = 128, save_dir: Optional[str] = None) -> Optional[str]:
|
49 |
+
"""Alternative quantization using bitsandbytes for better compatibility"""
|
50 |
+
# Uses BitsAndBytesConfig instead of TorchAoConfig
|
51 |
+
# Handles serialization issues better
|
52 |
+
```
|
53 |
+
|
54 |
+
### 3. Improved Error Handling
|
55 |
+
|
56 |
+
- Added fallback from torchao to bitsandbytes
|
57 |
+
- Enhanced save method with alternative approaches
|
58 |
+
- Better device mapping for different quantization types
|
59 |
+
|
60 |
+
### 4. Fixed Monitoring Integration
|
61 |
+
|
62 |
+
```python
|
63 |
+
def log_to_trackio(self, action: str, details: Dict[str, Any]):
|
64 |
+
"""Log quantization events to Trackio"""
|
65 |
+
if self.monitor:
|
66 |
+
try:
|
67 |
+
# Use the correct monitoring method
|
68 |
+
if hasattr(self.monitor, 'log_event'):
|
69 |
+
self.monitor.log_event(action, details)
|
70 |
+
elif hasattr(self.monitor, 'log_metric'):
|
71 |
+
self.monitor.log_metric(action, details.get('value', 1.0))
|
72 |
+
elif hasattr(self.monitor, 'log'):
|
73 |
+
self.monitor.log(action, details)
|
74 |
+
else:
|
75 |
+
logger.info(f"π {action}: {details}")
|
76 |
+
except Exception as e:
|
77 |
+
logger.warning(f"β οΈ Failed to log to Trackio: {e}")
|
78 |
+
```
|
79 |
+
|
80 |
+
## Usage Instructions
|
81 |
+
|
82 |
+
### 1. Install Dependencies
|
83 |
+
|
84 |
+
```bash
|
85 |
+
pip install -r requirements_quantization.txt
|
86 |
+
```
|
87 |
+
|
88 |
+
### 2. Run Quantization
|
89 |
+
|
90 |
+
```bash
|
91 |
+
python3 quantize_and_push.py
|
92 |
+
```
|
93 |
+
|
94 |
+
### 3. Test Fixes
|
95 |
+
|
96 |
+
```bash
|
97 |
+
python3 test_quantization_fix.py
|
98 |
+
```
|
99 |
+
|
100 |
+
## Expected Behavior
|
101 |
+
|
102 |
+
### Successful Quantization
|
103 |
+
|
104 |
+
The script will now:
|
105 |
+
|
106 |
+
1. **Try torchao first** for each quantization type
|
107 |
+
2. **Fall back to bitsandbytes** if torchao fails
|
108 |
+
3. **Use appropriate devices** (CPU for int4, GPU for int8)
|
109 |
+
4. **Handle serialization issues** with alternative save methods
|
110 |
+
5. **Log progress** without monitoring errors
|
111 |
+
|
112 |
+
### Output
|
113 |
+
|
114 |
+
```
|
115 |
+
β
Model files validated
|
116 |
+
π Processing quantization type: int8_weight_only
|
117 |
+
π Using device: cuda
|
118 |
+
β
int8_weight_only quantization and push completed
|
119 |
+
π Processing quantization type: int4_weight_only
|
120 |
+
π Using device: cpu
|
121 |
+
β
int4_weight_only quantization and push completed
|
122 |
+
π Quantization summary: 2/2 successful
|
123 |
+
β
Quantization completed successfully!
|
124 |
+
```
|
125 |
+
|
126 |
+
## Troubleshooting
|
127 |
+
|
128 |
+
### If All Quantization Fails
|
129 |
+
|
130 |
+
1. **Install bitsandbytes**:
|
131 |
+
```bash
|
132 |
+
pip install bitsandbytes
|
133 |
+
```
|
134 |
+
|
135 |
+
2. **Check model path**:
|
136 |
+
```bash
|
137 |
+
ls -la /output-checkpoint
|
138 |
+
```
|
139 |
+
|
140 |
+
3. **Verify dependencies**:
|
141 |
+
```bash
|
142 |
+
python3 test_quantization_fix.py
|
143 |
+
```
|
144 |
+
|
145 |
+
### Common Issues
|
146 |
+
|
147 |
+
1. **Memory Issues**: Use CPU for int4 quantization
|
148 |
+
2. **Serialization Errors**: The script now handles these automatically
|
149 |
+
3. **Device Conflicts**: Automatic device selection based on quantization type
|
150 |
+
|
151 |
+
## Files Modified
|
152 |
+
|
153 |
+
1. `scripts/model_tonic/quantize_model.py` - Main quantization logic
|
154 |
+
2. `quantize_and_push.py` - Main script with better error handling
|
155 |
+
3. `test_quantization_fix.py` - Test script for verification
|
156 |
+
4. `requirements_quantization.txt` - Dependencies file
|
157 |
+
|
158 |
+
## Next Steps
|
159 |
+
|
160 |
+
1. Run the test script to verify fixes
|
161 |
+
2. Install bitsandbytes if not already installed
|
162 |
+
3. Run the quantization script
|
163 |
+
4. Check the Hugging Face repository for quantized models
|
164 |
+
|
165 |
+
The fixes ensure robust quantization with multiple fallback options and proper error handling.
|
requirements_quantization.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Quantization Dependencies
|
2 |
+
# Core quantization libraries
|
3 |
+
torchao>=0.1.0
|
4 |
+
bitsandbytes>=0.41.0
|
5 |
+
|
6 |
+
# Transformers with quantization support
|
7 |
+
transformers>=4.36.0
|
8 |
+
|
9 |
+
# Hugging Face Hub for model pushing
|
10 |
+
huggingface_hub>=0.19.0
|
11 |
+
|
12 |
+
# Optional: For better performance
|
13 |
+
accelerate>=0.24.0
|
14 |
+
safetensors>=0.4.0
|
15 |
+
|
16 |
+
# Optional: For monitoring
|
17 |
+
datasets>=2.14.0
|
scripts/model_tonic/quantize_model.py
CHANGED
@@ -101,27 +101,16 @@ class ModelQuantizer:
|
|
101 |
return False
|
102 |
|
103 |
# Check for essential model files
|
104 |
-
required_files = ['config.json']
|
105 |
optional_files = ['tokenizer.json', 'tokenizer_config.json']
|
106 |
|
107 |
-
|
108 |
-
model_files = [
|
109 |
-
"model.safetensors.index.json", # Safetensors format
|
110 |
-
"pytorch_model.bin" # PyTorch format
|
111 |
-
]
|
112 |
-
|
113 |
-
missing_files = []
|
114 |
for file in required_files:
|
115 |
if not (self.model_path / file).exists():
|
116 |
-
|
117 |
-
|
118 |
-
# Check if at least one model file exists
|
119 |
-
model_file_exists = any((self.model_path / file).exists() for file in model_files)
|
120 |
-
if not model_file_exists:
|
121 |
-
missing_files.extend(model_files)
|
122 |
|
123 |
-
if
|
124 |
-
logger.error(f"β Missing required model files: {
|
125 |
return False
|
126 |
|
127 |
logger.info(f"β
Model path validated: {self.model_path}")
|
@@ -144,6 +133,99 @@ class ModelQuantizer:
|
|
144 |
|
145 |
return TorchAoConfig(quant_type=quant_config)
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
def quantize_model(
|
148 |
self,
|
149 |
quant_type: str,
|
@@ -162,15 +244,32 @@ class ModelQuantizer:
|
|
162 |
logger.info(f"π Device: {device}")
|
163 |
logger.info(f"π Group size: {group_size}")
|
164 |
|
|
|
|
|
|
|
|
|
|
|
165 |
# Create quantization config
|
166 |
quantization_config = self.create_quantization_config(quant_type, group_size)
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
# Load and quantize the model
|
169 |
quantized_model = AutoModelForCausalLM.from_pretrained(
|
170 |
str(self.model_path),
|
171 |
-
torch_dtype=
|
172 |
-
device_map=
|
173 |
-
quantization_config=quantization_config
|
|
|
174 |
)
|
175 |
|
176 |
# Determine save directory
|
@@ -183,7 +282,24 @@ class ModelQuantizer:
|
|
183 |
|
184 |
# Save quantized model (don't use safetensors for torchao)
|
185 |
logger.info(f"πΎ Saving quantized model to: {save_path}")
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
# Copy tokenizer files if they exist
|
189 |
tokenizer_files = ['tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
|
@@ -198,7 +314,9 @@ class ModelQuantizer:
|
|
198 |
|
199 |
except Exception as e:
|
200 |
logger.error(f"β Quantization failed: {e}")
|
201 |
-
|
|
|
|
|
202 |
|
203 |
def create_quantized_model_card(self, quant_type: str, original_model: str, subdir: str) -> str:
|
204 |
"""Create a model card for the quantized model"""
|
@@ -470,10 +588,24 @@ For questions and support, please open an issue on the Hugging Face repository.
|
|
470 |
"""Log quantization events to Trackio"""
|
471 |
if self.monitor:
|
472 |
try:
|
473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
logger.info(f"π Logged to Trackio: {action}")
|
475 |
except Exception as e:
|
476 |
logger.warning(f"β οΈ Failed to log to Trackio: {e}")
|
|
|
|
|
|
|
477 |
|
478 |
def quantize_and_push(
|
479 |
self,
|
|
|
101 |
return False
|
102 |
|
103 |
# Check for essential model files
|
104 |
+
required_files = ['config.json', 'pytorch_model.bin']
|
105 |
optional_files = ['tokenizer.json', 'tokenizer_config.json']
|
106 |
|
107 |
+
missing_required = []
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
for file in required_files:
|
109 |
if not (self.model_path / file).exists():
|
110 |
+
missing_required.append(file)
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
+
if missing_required:
|
113 |
+
logger.error(f"β Missing required model files: {missing_required}")
|
114 |
return False
|
115 |
|
116 |
logger.info(f"β
Model path validated: {self.model_path}")
|
|
|
133 |
|
134 |
return TorchAoConfig(quant_type=quant_config)
|
135 |
|
136 |
+
def get_optimal_device(self, quant_type: str) -> str:
|
137 |
+
"""Get optimal device for quantization type"""
|
138 |
+
if quant_type == "int4_weight_only":
|
139 |
+
# Int4 quantization works better on CPU
|
140 |
+
return "cpu"
|
141 |
+
elif quant_type == "int8_weight_only":
|
142 |
+
# Int8 quantization works on GPU
|
143 |
+
if torch.cuda.is_available():
|
144 |
+
return "cuda"
|
145 |
+
else:
|
146 |
+
logger.warning("β οΈ CUDA not available, falling back to CPU for int8")
|
147 |
+
return "cpu"
|
148 |
+
else:
|
149 |
+
return "auto"
|
150 |
+
|
151 |
+
def quantize_model_alternative(
|
152 |
+
self,
|
153 |
+
quant_type: str,
|
154 |
+
device: str = "auto",
|
155 |
+
group_size: int = 128,
|
156 |
+
save_dir: Optional[str] = None
|
157 |
+
) -> Optional[str]:
|
158 |
+
"""Alternative quantization using bitsandbytes for better compatibility"""
|
159 |
+
try:
|
160 |
+
logger.info(f"π Attempting alternative quantization for: {quant_type}")
|
161 |
+
|
162 |
+
# Import bitsandbytes if available
|
163 |
+
try:
|
164 |
+
import bitsandbytes as bnb
|
165 |
+
from transformers import BitsAndBytesConfig
|
166 |
+
BNB_AVAILABLE = True
|
167 |
+
except ImportError:
|
168 |
+
BNB_AVAILABLE = False
|
169 |
+
logger.error("β bitsandbytes not available for alternative quantization")
|
170 |
+
return None
|
171 |
+
|
172 |
+
if not BNB_AVAILABLE:
|
173 |
+
return None
|
174 |
+
|
175 |
+
# Create bitsandbytes config
|
176 |
+
if quant_type == "int8_weight_only":
|
177 |
+
bnb_config = BitsAndBytesConfig(
|
178 |
+
load_in_8bit=True,
|
179 |
+
llm_int8_threshold=6.0,
|
180 |
+
llm_int8_has_fp16_weight=False
|
181 |
+
)
|
182 |
+
elif quant_type == "int4_weight_only":
|
183 |
+
bnb_config = BitsAndBytesConfig(
|
184 |
+
load_in_4bit=True,
|
185 |
+
bnb_4bit_compute_dtype=torch.float16,
|
186 |
+
bnb_4bit_use_double_quant=True,
|
187 |
+
bnb_4bit_quant_type="nf4"
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
logger.error(f"β Unsupported quantization type for alternative method: {quant_type}")
|
191 |
+
return None
|
192 |
+
|
193 |
+
# Load model with bitsandbytes quantization
|
194 |
+
quantized_model = AutoModelForCausalLM.from_pretrained(
|
195 |
+
str(self.model_path),
|
196 |
+
quantization_config=bnb_config,
|
197 |
+
device_map="auto",
|
198 |
+
torch_dtype=torch.bfloat16,
|
199 |
+
low_cpu_mem_usage=True
|
200 |
+
)
|
201 |
+
|
202 |
+
# Determine save directory
|
203 |
+
if save_dir is None:
|
204 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
205 |
+
save_dir = f"quantized_{quant_type}_bnb_{timestamp}"
|
206 |
+
|
207 |
+
save_path = Path(save_dir)
|
208 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
209 |
+
|
210 |
+
# Save quantized model
|
211 |
+
logger.info(f"πΎ Saving quantized model to: {save_path}")
|
212 |
+
quantized_model.save_pretrained(save_path, safe_serialization=False)
|
213 |
+
|
214 |
+
# Copy tokenizer files if they exist
|
215 |
+
tokenizer_files = ['tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
|
216 |
+
for file in tokenizer_files:
|
217 |
+
src_file = self.model_path / file
|
218 |
+
if src_file.exists():
|
219 |
+
shutil.copy2(src_file, save_path / file)
|
220 |
+
logger.info(f"π Copied {file}")
|
221 |
+
|
222 |
+
logger.info(f"β
Alternative quantization successful: {save_path}")
|
223 |
+
return str(save_path)
|
224 |
+
|
225 |
+
except Exception as e:
|
226 |
+
logger.error(f"β Alternative quantization failed: {e}")
|
227 |
+
return None
|
228 |
+
|
229 |
def quantize_model(
|
230 |
self,
|
231 |
quant_type: str,
|
|
|
244 |
logger.info(f"π Device: {device}")
|
245 |
logger.info(f"π Group size: {group_size}")
|
246 |
|
247 |
+
# Determine optimal device
|
248 |
+
if device == "auto":
|
249 |
+
device = self.get_optimal_device(quant_type)
|
250 |
+
logger.info(f"π Using device: {device}")
|
251 |
+
|
252 |
# Create quantization config
|
253 |
quantization_config = self.create_quantization_config(quant_type, group_size)
|
254 |
|
255 |
+
# Load model with appropriate device mapping
|
256 |
+
if device == "cpu":
|
257 |
+
device_map = "cpu"
|
258 |
+
torch_dtype = torch.float32
|
259 |
+
elif device == "cuda":
|
260 |
+
device_map = "auto"
|
261 |
+
torch_dtype = torch.bfloat16
|
262 |
+
else:
|
263 |
+
device_map = "auto"
|
264 |
+
torch_dtype = "auto"
|
265 |
+
|
266 |
# Load and quantize the model
|
267 |
quantized_model = AutoModelForCausalLM.from_pretrained(
|
268 |
str(self.model_path),
|
269 |
+
torch_dtype=torch_dtype,
|
270 |
+
device_map=device_map,
|
271 |
+
quantization_config=quantization_config,
|
272 |
+
low_cpu_mem_usage=True
|
273 |
)
|
274 |
|
275 |
# Determine save directory
|
|
|
282 |
|
283 |
# Save quantized model (don't use safetensors for torchao)
|
284 |
logger.info(f"πΎ Saving quantized model to: {save_path}")
|
285 |
+
|
286 |
+
# For torchao models, we need to handle serialization carefully
|
287 |
+
try:
|
288 |
+
quantized_model.save_pretrained(save_path, safe_serialization=False)
|
289 |
+
except Exception as save_error:
|
290 |
+
logger.warning(f"β οΈ Standard save failed: {save_error}")
|
291 |
+
logger.info("π Attempting alternative save method...")
|
292 |
+
|
293 |
+
# Try saving without quantization config
|
294 |
+
try:
|
295 |
+
# Remove quantization config temporarily
|
296 |
+
original_config = quantized_model.config.quantization_config
|
297 |
+
quantized_model.config.quantization_config = None
|
298 |
+
quantized_model.save_pretrained(save_path, safe_serialization=False)
|
299 |
+
quantized_model.config.quantization_config = original_config
|
300 |
+
except Exception as alt_save_error:
|
301 |
+
logger.error(f"β Alternative save also failed: {alt_save_error}")
|
302 |
+
return None
|
303 |
|
304 |
# Copy tokenizer files if they exist
|
305 |
tokenizer_files = ['tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
|
|
|
314 |
|
315 |
except Exception as e:
|
316 |
logger.error(f"β Quantization failed: {e}")
|
317 |
+
# Try alternative quantization method
|
318 |
+
logger.info("π Attempting alternative quantization method...")
|
319 |
+
return self.quantize_model_alternative(quant_type, device, group_size, save_dir)
|
320 |
|
321 |
def create_quantized_model_card(self, quant_type: str, original_model: str, subdir: str) -> str:
|
322 |
"""Create a model card for the quantized model"""
|
|
|
588 |
"""Log quantization events to Trackio"""
|
589 |
if self.monitor:
|
590 |
try:
|
591 |
+
# Use the correct monitoring method
|
592 |
+
if hasattr(self.monitor, 'log_event'):
|
593 |
+
self.monitor.log_event(action, details)
|
594 |
+
elif hasattr(self.monitor, 'log_metric'):
|
595 |
+
# Log as metric instead
|
596 |
+
self.monitor.log_metric(action, details.get('value', 1.0))
|
597 |
+
elif hasattr(self.monitor, 'log'):
|
598 |
+
# Use generic log method
|
599 |
+
self.monitor.log(action, details)
|
600 |
+
else:
|
601 |
+
# Just log locally if no monitoring method available
|
602 |
+
logger.info(f"π {action}: {details}")
|
603 |
logger.info(f"π Logged to Trackio: {action}")
|
604 |
except Exception as e:
|
605 |
logger.warning(f"β οΈ Failed to log to Trackio: {e}")
|
606 |
+
else:
|
607 |
+
# Log locally if no monitor available
|
608 |
+
logger.info(f"π {action}: {details}")
|
609 |
|
610 |
def quantize_and_push(
|
611 |
self,
|
test_quantization_fix.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Test script to verify quantization fixes
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
# Setup logging
|
12 |
+
logging.basicConfig(
|
13 |
+
level=logging.INFO,
|
14 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
15 |
+
)
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
def test_quantization_imports():
|
19 |
+
"""Test that all required imports work"""
|
20 |
+
try:
|
21 |
+
# Test torchao imports
|
22 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
23 |
+
from torchao.quantization import (
|
24 |
+
Int8WeightOnlyConfig,
|
25 |
+
Int4WeightOnlyConfig,
|
26 |
+
Int8DynamicActivationInt8WeightConfig
|
27 |
+
)
|
28 |
+
from torchao.dtypes import Int4CPULayout
|
29 |
+
logger.info("β
torchao imports successful")
|
30 |
+
|
31 |
+
# Test bitsandbytes imports
|
32 |
+
try:
|
33 |
+
import bitsandbytes as bnb
|
34 |
+
from transformers import BitsAndBytesConfig
|
35 |
+
logger.info("β
bitsandbytes imports successful")
|
36 |
+
except ImportError:
|
37 |
+
logger.warning("β οΈ bitsandbytes not available - alternative quantization disabled")
|
38 |
+
|
39 |
+
# Test HF imports
|
40 |
+
from huggingface_hub import HfApi
|
41 |
+
logger.info("β
huggingface_hub imports successful")
|
42 |
+
|
43 |
+
return True
|
44 |
+
|
45 |
+
except ImportError as e:
|
46 |
+
logger.error(f"β Import failed: {e}")
|
47 |
+
return False
|
48 |
+
|
49 |
+
def test_model_quantizer():
|
50 |
+
"""Test ModelQuantizer initialization"""
|
51 |
+
try:
|
52 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
53 |
+
|
54 |
+
# Test with dummy values
|
55 |
+
quantizer = ModelQuantizer(
|
56 |
+
model_path="/output-checkpoint",
|
57 |
+
repo_name="test/test-repo",
|
58 |
+
token="dummy_token"
|
59 |
+
)
|
60 |
+
|
61 |
+
logger.info("β
ModelQuantizer initialization successful")
|
62 |
+
return True
|
63 |
+
|
64 |
+
except Exception as e:
|
65 |
+
logger.error(f"β ModelQuantizer test failed: {e}")
|
66 |
+
return False
|
67 |
+
|
68 |
+
def test_quantization_configs():
|
69 |
+
"""Test quantization config creation"""
|
70 |
+
try:
|
71 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
72 |
+
|
73 |
+
quantizer = ModelQuantizer(
|
74 |
+
model_path="/output-checkpoint",
|
75 |
+
repo_name="test/test-repo",
|
76 |
+
token="dummy_token"
|
77 |
+
)
|
78 |
+
|
79 |
+
# Test int8 config
|
80 |
+
config = quantizer.create_quantization_config("int8_weight_only", 128)
|
81 |
+
logger.info("β
int8_weight_only config creation successful")
|
82 |
+
|
83 |
+
# Test int4 config
|
84 |
+
config = quantizer.create_quantization_config("int4_weight_only", 128)
|
85 |
+
logger.info("β
int4_weight_only config creation successful")
|
86 |
+
|
87 |
+
return True
|
88 |
+
|
89 |
+
except Exception as e:
|
90 |
+
logger.error(f"β Quantization config test failed: {e}")
|
91 |
+
return False
|
92 |
+
|
93 |
+
def test_device_selection():
|
94 |
+
"""Test optimal device selection"""
|
95 |
+
try:
|
96 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
97 |
+
|
98 |
+
quantizer = ModelQuantizer(
|
99 |
+
model_path="/output-checkpoint",
|
100 |
+
repo_name="test/test-repo",
|
101 |
+
token="dummy_token"
|
102 |
+
)
|
103 |
+
|
104 |
+
# Test device selection
|
105 |
+
device = quantizer.get_optimal_device("int8_weight_only")
|
106 |
+
logger.info(f"β
int8 device selection: {device}")
|
107 |
+
|
108 |
+
device = quantizer.get_optimal_device("int4_weight_only")
|
109 |
+
logger.info(f"β
int4 device selection: {device}")
|
110 |
+
|
111 |
+
return True
|
112 |
+
|
113 |
+
except Exception as e:
|
114 |
+
logger.error(f"β Device selection test failed: {e}")
|
115 |
+
return False
|
116 |
+
|
117 |
+
def main():
|
118 |
+
"""Run all tests"""
|
119 |
+
logger.info("π§ͺ Testing quantization fixes...")
|
120 |
+
|
121 |
+
tests = [
|
122 |
+
("Import Test", test_quantization_imports),
|
123 |
+
("ModelQuantizer Test", test_model_quantizer),
|
124 |
+
("Config Creation Test", test_quantization_configs),
|
125 |
+
("Device Selection Test", test_device_selection),
|
126 |
+
]
|
127 |
+
|
128 |
+
passed = 0
|
129 |
+
total = len(tests)
|
130 |
+
|
131 |
+
for test_name, test_func in tests:
|
132 |
+
logger.info(f"\nπ Running {test_name}...")
|
133 |
+
if test_func():
|
134 |
+
passed += 1
|
135 |
+
logger.info(f"β
{test_name} passed")
|
136 |
+
else:
|
137 |
+
logger.error(f"β {test_name} failed")
|
138 |
+
|
139 |
+
logger.info(f"\nπ Test Results: {passed}/{total} tests passed")
|
140 |
+
|
141 |
+
if passed == total:
|
142 |
+
logger.info("π All tests passed! Quantization fixes are working.")
|
143 |
+
return 0
|
144 |
+
else:
|
145 |
+
logger.error("β Some tests failed. Please check the errors above.")
|
146 |
+
return 1
|
147 |
+
|
148 |
+
if __name__ == "__main__":
|
149 |
+
exit(main())
|