SmolFactory / process_model.py
Tonic's picture
Fix model recovery and deployment scripts - add safetensors support and Windows compatibility
d0d19b2 verified
raw
history blame
8.46 kB
#!/usr/bin/env python3
"""
Model Processing Script
Processes recovered model with quantization and pushing to HF Hub
"""
import os
import sys
import json
import logging
import subprocess
from pathlib import Path
from typing import Dict, Any, Optional
# Configure root logging once at import time: INFO and above, with every
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger used by everything defined in this file.
logger = logging.getLogger(__name__)
class ModelProcessor:
    """Validate, quantize and publish a recovered model.

    The heavy lifting is delegated to child Python processes: an inline
    snippet that loads the model with ``transformers`` for validation, and
    two helper scripts (see ``QUANTIZE_SCRIPT`` and ``PUSH_SCRIPT``) for
    quantization and upload.
    """

    # Helper scripts; relative paths, so resolved against the current
    # working directory at call time.
    QUANTIZE_SCRIPT = Path("scripts/model_tonic/quantize_model.py")
    PUSH_SCRIPT = Path("scripts/model_tonic/push_to_huggingface.py")

    # Upper bounds (seconds) for the child processes.
    VALIDATION_TIMEOUT_SECONDS = 300
    SCRIPT_TIMEOUT_SECONDS = 1800  # 30 min timeout

    # Code executed by the validation child process. The model location is
    # read from sys.argv[1] rather than spliced into the source text, so
    # paths containing quotes or backslashes (Windows) need no escaping.
    _VALIDATION_CODE = (
        "import sys; "
        "from transformers import AutoModelForCausalLM; "
        "model = AutoModelForCausalLM.from_pretrained(sys.argv[1], "
        "torch_dtype='auto', device_map='auto'); "
        "print('βœ… Model loaded successfully')"
    )

    def __init__(self, model_path: str = "recovered_model"):
        """Remember the model directory and pick up ``HF_TOKEN`` if set.

        Args:
            model_path: Directory containing the recovered model files.
        """
        self.model_path = Path(model_path)
        # None when the environment variable is not defined.
        self.hf_token: Optional[str] = os.getenv('HF_TOKEN')

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_validation_command(self) -> list:
        """Return the argv that loads ``self.model_path`` in a child process."""
        # With ``python -c CODE ARG`` the child sees sys.argv == ['-c', ARG].
        return [sys.executable, "-c", self._VALIDATION_CODE, str(self.model_path)]

    def _with_token(self, cmd: list) -> list:
        """Return ``cmd`` with ``--token <HF_TOKEN>`` appended when a token is set."""
        # NOTE(review): the token is still handed over on the child's command
        # line because that is the interface this file has always used;
        # confirm whether the helper scripts can read HF_TOKEN from the
        # environment instead.
        if self.hf_token:
            return cmd + ["--token", self.hf_token]
        return cmd

    @staticmethod
    def _format_command_for_log(cmd: list) -> str:
        """Join ``cmd`` for logging, masking the value that follows ``--token``."""
        shown = []
        mask_next = False
        for part in cmd:
            shown.append("***" if mask_next else part)
            mask_next = part == "--token"
        return " ".join(shown)

    @staticmethod
    def _run(cmd: list, timeout: int) -> "subprocess.CompletedProcess":
        """Run ``cmd`` capturing stdout/stderr as text.

        ``errors="replace"`` keeps undecodable bytes in the child's output
        from raising while the output is being collected.

        Raises:
            subprocess.TimeoutExpired: If the child exceeds ``timeout`` seconds.
        """
        return subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors="replace",
            timeout=timeout,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def validate_model(self) -> bool:
        """Check that the model at ``self.model_path`` can be loaded.

        Returns:
            True when the child process that loads the model exits with
            status 0; False when the directory is missing, loading fails,
            or any error (including a timeout) occurs.
        """
        try:
            logger.info("πŸ” Validating model loading...")

            # Fail fast with a clear message instead of letting the loader
            # interpret a missing local path as something else.
            if not self.model_path.is_dir():
                logger.error(f"❌ Model directory not found: {self.model_path}")
                return False

            result = self._run(
                self._build_validation_command(),
                self.VALIDATION_TIMEOUT_SECONDS,
            )
            if result.returncode == 0:
                logger.info("βœ… Model validation successful")
                return True

            logger.error(f"❌ Model validation failed: {result.stderr}")
            return False
        except Exception as e:
            logger.error(f"❌ Model validation error: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Collect basic facts about the model from ``config.json`` and disk usage.

        Keys missing from the config fall back to the defaults written below.

        Returns:
            A dict with ``model_type``, ``architectures``, ``model_size_gb``,
            ``vocab_size``, ``hidden_size``, ``num_attention_heads``,
            ``num_hidden_layers`` and ``max_position_embeddings``; an empty
            dict if anything goes wrong (best effort, the error is logged).
        """
        try:
            # Load config
            config_path = self.model_path / "config.json"
            if config_path.exists():
                # Explicit encoding: the platform default is not UTF-8 everywhere.
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
            else:
                config = {}

            # Total size in bytes of every regular file under the model directory.
            total_size = sum(
                entry.stat().st_size
                for entry in self.model_path.rglob("*")
                if entry.is_file()
            )

            model_info = {
                "model_type": config.get("model_type", "smollm3"),
                "architectures": config.get("architectures", ["SmolLM3ForCausalLM"]),
                "model_size_gb": total_size / (1024**3),
                "vocab_size": config.get("vocab_size", 32000),
                "hidden_size": config.get("hidden_size", 2048),
                "num_attention_heads": config.get("num_attention_heads", 16),
                "num_hidden_layers": config.get("num_hidden_layers", 24),
                "max_position_embeddings": config.get("max_position_embeddings", 8192),
            }
            logger.info(f"πŸ“Š Model info: {model_info}")
            return model_info
        except Exception as e:
            logger.error(f"❌ Failed to get model info: {e}")
            return {}

    def run_quantization(self, repo_name: str, quant_type: str = "int8_weight_only") -> bool:
        """Run the quantization helper script on the model.

        Args:
            repo_name: Repository name forwarded to the script.
            quant_type: Value forwarded as ``--quant-type``.

        Returns:
            True if the script exits with status 0; False if the script is
            missing, fails, times out, or cannot be started.
        """
        try:
            logger.info(f"πŸ”„ Running quantization: {quant_type}")

            # Check if quantization script exists
            quantize_script = self.QUANTIZE_SCRIPT
            if not quantize_script.exists():
                logger.error(f"❌ Quantization script not found: {quantize_script}")
                return False

            cmd = self._with_token([
                sys.executable, str(quantize_script),
                str(self.model_path),
                repo_name,
                "--quant-type", quant_type,
                "--device", "auto",
            ])

            # Log a redacted form so the HF token never reaches the logs.
            logger.info(f"πŸš€ Running: {self._format_command_for_log(cmd)}")
            result = self._run(cmd, self.SCRIPT_TIMEOUT_SECONDS)

            if result.returncode == 0:
                logger.info("βœ… Quantization completed successfully")
                logger.info(result.stdout)
                return True

            logger.error("❌ Quantization failed")
            logger.error(result.stderr)
            return False
        except subprocess.TimeoutExpired:
            logger.error("❌ Quantization timed out")
            return False
        except Exception as e:
            logger.error(f"❌ Failed to run quantization: {e}")
            return False

    def run_model_push(self, repo_name: str) -> bool:
        """Run the upload helper script to push the model to the HF Hub.

        Args:
            repo_name: Repository name forwarded to the script.

        Returns:
            True if the script exits with status 0; False if the script is
            missing, fails, times out, or cannot be started.
        """
        try:
            logger.info(f"πŸ”„ Pushing model to: {repo_name}")

            # Check if push script exists
            push_script = self.PUSH_SCRIPT
            if not push_script.exists():
                logger.error(f"❌ Push script not found: {push_script}")
                return False

            cmd = self._with_token([
                sys.executable, str(push_script),
                str(self.model_path),
                repo_name,
            ])

            # Log a redacted form so the HF token never reaches the logs.
            logger.info(f"πŸš€ Running: {self._format_command_for_log(cmd)}")
            result = self._run(cmd, self.SCRIPT_TIMEOUT_SECONDS)

            if result.returncode == 0:
                logger.info("βœ… Model push completed successfully")
                logger.info(result.stdout)
                return True

            logger.error("❌ Model push failed")
            logger.error(result.stderr)
            return False
        except subprocess.TimeoutExpired:
            logger.error("❌ Model push timed out")
            return False
        except Exception as e:
            logger.error(f"❌ Failed to push model: {e}")
            return False

    def process_model(
        self,
        repo_name: str,
        quantize: bool = True,
        push: bool = True,
        quant_type: str = "int8_weight_only",
    ) -> bool:
        """Run the full workflow: validate, report info, quantize, push.

        Args:
            repo_name: Hugging Face repository name (username/model-name).
            quantize: Run the quantization step when True.
            push: Run the push step when True.
            quant_type: Quantization type forwarded to ``run_quantization``.

        Returns:
            True if every requested step succeeded; False as soon as one fails.
        """
        logger.info("πŸš€ Starting model processing...")

        # Step 1: Validate model
        if not self.validate_model():
            logger.error("❌ Model validation failed")
            return False

        # Step 2: Get model info (called for the summary it logs)
        self.get_model_info()

        # Step 3: Quantize if requested
        if quantize and not self.run_quantization(repo_name, quant_type):
            logger.error("❌ Quantization failed")
            return False

        # Step 4: Push if requested
        if push and not self.run_model_push(repo_name):
            logger.error("❌ Model push failed")
            return False

        logger.info("πŸŽ‰ Model processing completed successfully!")
        logger.info(f"🌐 View your model at: https://huggingface.co/{repo_name}")
        return True


def main():
    """Parse command-line arguments and run the processing workflow.

    Returns:
        0 if processing succeeded, 1 otherwise (suitable as an exit status).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Process recovered model")
    parser.add_argument("repo_name", help="Hugging Face repository name (username/model-name)")
    parser.add_argument("--model-path", default="recovered_model", help="Path to recovered model")
    parser.add_argument("--no-quantize", action="store_true", help="Skip quantization")
    parser.add_argument("--no-push", action="store_true", help="Skip pushing to HF Hub")
    parser.add_argument("--quant-type", default="int8_weight_only",
                        choices=["int8_weight_only", "int4_weight_only", "int8_dynamic"],
                        help="Quantization type")
    args = parser.parse_args()

    # Initialize processor
    processor = ModelProcessor(args.model_path)

    # Process model; --quant-type is forwarded so the user's choice takes effect.
    success = processor.process_model(
        repo_name=args.repo_name,
        quantize=not args.no_quantize,
        push=not args.no_push,
        quant_type=args.quant_type,
    )
    return 0 if success else 1
if __name__ == "__main__":
    # Use sys.exit rather than the bare `exit` helper: `exit` is injected by
    # the `site` module for interactive sessions and is not available when
    # Python runs with -S or in some frozen/embedded builds.
    sys.exit(main())