#!/usr/bin/env python3
"""
Model Processing Script

Validates a recovered model, optionally quantizes it, and pushes it to the
Hugging Face Hub.
"""
import os
import sys
import json
import logging
import subprocess
from pathlib import Path
from typing import Dict, Any

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ModelProcessor:
    """Process a recovered model: validate, quantize, and push to the Hub."""

    def __init__(self, model_path: str = "recovered_model"):
        self.model_path = Path(model_path)
        self.hf_token = os.getenv('HF_TOKEN')

    def validate_model(self) -> bool:
        """Validate that the model can be loaded."""
        try:
            logger.info("Validating model loading...")
            # Load the model in a subprocess so a failure cannot take down
            # this process; use the configured model path rather than a
            # hardcoded directory name.
            cmd = [
                sys.executable, "-c",
                "from transformers import AutoModelForCausalLM; "
                f"model = AutoModelForCausalLM.from_pretrained('{self.model_path}', "
                "torch_dtype='auto', device_map='auto'); "
                "print('Model loaded successfully')"
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                logger.info("Model validation successful")
                return True
            else:
                logger.error(f"Model validation failed: {result.stderr}")
                return False
        except Exception as e:
            logger.error(f"Model validation error: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Collect basic information about the model."""
        try:
            # Load the model config if present
            config_path = self.model_path / "config.json"
            if config_path.exists():
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}

            # Sum the sizes of all files under the model directory
            total_size = 0
            for file in self.model_path.rglob("*"):
                if file.is_file():
                    total_size += file.stat().st_size

            model_info = {
                "model_type": config.get("model_type", "smollm3"),
                "architectures": config.get("architectures", ["SmolLM3ForCausalLM"]),
                "model_size_gb": total_size / (1024**3),
                "vocab_size": config.get("vocab_size", 32000),
                "hidden_size": config.get("hidden_size", 2048),
                "num_attention_heads": config.get("num_attention_heads", 16),
                "num_hidden_layers": config.get("num_hidden_layers", 24),
                "max_position_embeddings": config.get("max_position_embeddings", 8192)
            }
            logger.info(f"Model info: {model_info}")
            return model_info
        except Exception as e:
            logger.error(f"Failed to get model info: {e}")
            return {}

    def run_quantization(self, repo_name: str, quant_type: str = "int8_weight_only") -> bool:
        """Run quantization on the model."""
        try:
            logger.info(f"Running quantization: {quant_type}")

            # Check that the quantization script exists
            quantize_script = Path("scripts/model_tonic/quantize_model.py")
            if not quantize_script.exists():
                logger.error(f"Quantization script not found: {quantize_script}")
                return False

            # Run quantization
            cmd = [
                sys.executable, str(quantize_script),
                str(self.model_path),
                repo_name,
                "--quant-type", quant_type,
                "--device", "auto"
            ]
            if self.hf_token:
                cmd.extend(["--token", self.hf_token])

            logger.info(f"Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)  # 30 min timeout
            if result.returncode == 0:
                logger.info("Quantization completed successfully")
                logger.info(result.stdout)
                return True
            else:
                logger.error("Quantization failed")
                logger.error(result.stderr)
                return False
        except subprocess.TimeoutExpired:
            logger.error("Quantization timed out")
            return False
        except Exception as e:
            logger.error(f"Failed to run quantization: {e}")
            return False

    def run_model_push(self, repo_name: str) -> bool:
        """Push the model to the HF Hub."""
        try:
            logger.info(f"Pushing model to: {repo_name}")

            # Check that the push script exists
            push_script = Path("scripts/model_tonic/push_to_huggingface.py")
            if not push_script.exists():
                logger.error(f"Push script not found: {push_script}")
                return False

            # Run push
            cmd = [
                sys.executable, str(push_script),
                str(self.model_path),
                repo_name
            ]
            if self.hf_token:
                cmd.extend(["--token", self.hf_token])

            logger.info(f"Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)  # 30 min timeout
            if result.returncode == 0:
                logger.info("Model push completed successfully")
                logger.info(result.stdout)
                return True
            else:
                logger.error("Model push failed")
                logger.error(result.stderr)
                return False
        except subprocess.TimeoutExpired:
            logger.error("Model push timed out")
            return False
        except Exception as e:
            logger.error(f"Failed to push model: {e}")
            return False

    def process_model(self, repo_name: str, quantize: bool = True, push: bool = True,
                      quant_type: str = "int8_weight_only") -> bool:
        """Complete model processing workflow."""
        logger.info("Starting model processing...")

        # Step 1: Validate model
        if not self.validate_model():
            logger.error("Model validation failed")
            return False

        # Step 2: Get model info (logged for reference)
        model_info = self.get_model_info()

        # Step 3: Quantize if requested
        if quantize:
            if not self.run_quantization(repo_name, quant_type):
                logger.error("Quantization failed")
                return False

        # Step 4: Push if requested
        if push:
            if not self.run_model_push(repo_name):
                logger.error("Model push failed")
                return False

        logger.info("Model processing completed successfully!")
        logger.info(f"View your model at: https://huggingface.co/{repo_name}")
        return True
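
    # Minimal programmatic usage sketch (the repo name below is a
    # placeholder, not a repository from this project):
    #
    #   processor = ModelProcessor("recovered_model")
    #   processor.process_model("username/my-model", quantize=True, push=False,
    #                           quant_type="int4_weight_only")
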
def main():
    """Main entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="Process a recovered model")
    parser.add_argument("repo_name", help="Hugging Face repository name (username/model-name)")
    parser.add_argument("--model-path", default="recovered_model", help="Path to recovered model")
    parser.add_argument("--no-quantize", action="store_true", help="Skip quantization")
    parser.add_argument("--no-push", action="store_true", help="Skip pushing to HF Hub")
    parser.add_argument("--quant-type", default="int8_weight_only",
                        choices=["int8_weight_only", "int4_weight_only", "int8_dynamic"],
                        help="Quantization type")
    args = parser.parse_args()

    # Initialize processor
    processor = ModelProcessor(args.model_path)

    # Process model, forwarding the parsed quantization type
    success = processor.process_model(
        repo_name=args.repo_name,
        quantize=not args.no_quantize,
        push=not args.no_push,
        quant_type=args.quant_type
    )
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())
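
# Example CLI invocations (the script filename and repo name are illustrative;
# assumes HF_TOKEN is exported for authenticated pushes):
#
#   export HF_TOKEN=hf_...
#   python process_model.py username/my-model
#   python process_model.py username/my-model --quant-type int4_weight_only
#   python process_model.py username/my-model --no-quantize   # push only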