#!/usr/bin/env python3
"""
Model Processing Script

Validates a recovered model, optionally quantizes it, and pushes it to the
Hugging Face Hub.
"""

import os
import sys
import json
import logging
import subprocess
from pathlib import Path
from typing import Dict, Any

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ModelProcessor:
    """Process a recovered model: validate, quantize, and push to the Hub."""

    def __init__(self, model_path: str = "recovered_model"):
        self.model_path = Path(model_path)
        self.hf_token = os.getenv('HF_TOKEN')
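        # HF_TOKEN is read from the environment; when it is set, it is forwarded
        # to the quantization and push helper scripts below via their --token flag.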

    def validate_model(self) -> bool:
        """Validate that the model can be loaded"""
        try:
            logger.info("Validating model loading...")
            # Load the model in a subprocess so the heavyweight import and any
            # GPU allocation stay out of this orchestrating process. Use the
            # configured model path rather than a hardcoded directory.
            cmd = [
                sys.executable, "-c",
                "from transformers import AutoModelForCausalLM; "
                f"model = AutoModelForCausalLM.from_pretrained('{self.model_path}', "
                "torch_dtype='auto', device_map='auto'); "
                "print('Model loaded successfully')"
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                logger.info("Model validation successful")
                return True
            else:
                logger.error(f"Model validation failed: {result.stderr}")
                return False
        except Exception as e:
            logger.error(f"Model validation error: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the model"""
        try:
            # Load config
            config_path = self.model_path / "config.json"
            if config_path.exists():
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}
            # Calculate total on-disk model size
            total_size = 0
            for file in self.model_path.rglob("*"):
                if file.is_file():
                    total_size += file.stat().st_size
            model_info = {
                "model_type": config.get("model_type", "smollm3"),
                "architectures": config.get("architectures", ["SmolLM3ForCausalLM"]),
                "model_size_gb": total_size / (1024**3),
                "vocab_size": config.get("vocab_size", 32000),
                "hidden_size": config.get("hidden_size", 2048),
                "num_attention_heads": config.get("num_attention_heads", 16),
                "num_hidden_layers": config.get("num_hidden_layers", 24),
                "max_position_embeddings": config.get("max_position_embeddings", 8192)
            }
            logger.info(f"Model info: {model_info}")
            return model_info
        except Exception as e:
            logger.error(f"Failed to get model info: {e}")
            return {}
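
    # Illustrative get_model_info() output (numbers are hypothetical; the real
    # values come from config.json and the files on disk):
    #   {"model_type": "smollm3", "architectures": ["SmolLM3ForCausalLM"],
    #    "model_size_gb": 6.2, "vocab_size": 32000, "hidden_size": 2048,
    #    "num_attention_heads": 16, "num_hidden_layers": 24,
    #    "max_position_embeddings": 8192}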

    def run_quantization(self, repo_name: str, quant_type: str = "int8_weight_only") -> bool:
        """Run quantization on the model"""
        try:
            logger.info(f"Running quantization: {quant_type}")
            # Check if quantization script exists
            quantize_script = Path("scripts/model_tonic/quantize_model.py")
            if not quantize_script.exists():
                logger.error(f"Quantization script not found: {quantize_script}")
                return False
            # Run quantization
            cmd = [
                sys.executable, str(quantize_script),
                str(self.model_path),
                repo_name,
                "--quant-type", quant_type,
                "--device", "auto"
            ]
            if self.hf_token:
                cmd.extend(["--token", self.hf_token])
            # Redact the token so it never appears in the logs
            printable = ' '.join("***" if part == self.hf_token else part for part in cmd)
            logger.info(f"Running: {printable}")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)  # 30 min timeout
            if result.returncode == 0:
                logger.info("Quantization completed successfully")
                logger.info(result.stdout)
                return True
            else:
                logger.error("Quantization failed")
                logger.error(result.stderr)
                return False
        except subprocess.TimeoutExpired:
            logger.error("Quantization timed out")
            return False
        except Exception as e:
            logger.error(f"Failed to run quantization: {e}")
            return False

    def run_model_push(self, repo_name: str) -> bool:
        """Push the model to HF Hub"""
        try:
            logger.info(f"Pushing model to: {repo_name}")
            # Check if push script exists
            push_script = Path("scripts/model_tonic/push_to_huggingface.py")
            if not push_script.exists():
                logger.error(f"Push script not found: {push_script}")
                return False
            # Run push
            cmd = [
                sys.executable, str(push_script),
                str(self.model_path),
                repo_name
            ]
            if self.hf_token:
                cmd.extend(["--token", self.hf_token])
            # Redact the token so it never appears in the logs
            printable = ' '.join("***" if part == self.hf_token else part for part in cmd)
            logger.info(f"Running: {printable}")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)  # 30 min timeout
            if result.returncode == 0:
                logger.info("Model push completed successfully")
                logger.info(result.stdout)
                return True
            else:
                logger.error("Model push failed")
                logger.error(result.stderr)
                return False
        except subprocess.TimeoutExpired:
            logger.error("Model push timed out")
            return False
        except Exception as e:
            logger.error(f"Failed to push model: {e}")
            return False

    def process_model(self, repo_name: str, quantize: bool = True, push: bool = True,
                      quant_type: str = "int8_weight_only") -> bool:
        """Complete model processing workflow"""
        logger.info("Starting model processing...")
        # Step 1: Validate model
        if not self.validate_model():
            logger.error("Model validation failed")
            return False
        # Step 2: Get model info (logged for the record)
        model_info = self.get_model_info()
        # Step 3: Quantize if requested
        if quantize:
            if not self.run_quantization(repo_name, quant_type):
                logger.error("Quantization failed")
                return False
        # Step 4: Push if requested
        if push:
            if not self.run_model_push(repo_name):
                logger.error("Model push failed")
                return False
        logger.info("Model processing completed successfully!")
        logger.info(f"View your model at: https://huggingface.co/{repo_name}")
        return True
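
# Example programmatic use (hypothetical repo name; substitute one you can write to):
#   processor = ModelProcessor("recovered_model")
#   processor.process_model("username/model-name", quantize=True, push=True)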


def main():
    """Main function"""
    import argparse

    parser = argparse.ArgumentParser(description="Process recovered model")
    parser.add_argument("repo_name", help="Hugging Face repository name (username/model-name)")
    parser.add_argument("--model-path", default="recovered_model", help="Path to recovered model")
    parser.add_argument("--no-quantize", action="store_true", help="Skip quantization")
    parser.add_argument("--no-push", action="store_true", help="Skip pushing to HF Hub")
    parser.add_argument("--quant-type", default="int8_weight_only",
                        choices=["int8_weight_only", "int4_weight_only", "int8_dynamic"],
                        help="Quantization type")
    args = parser.parse_args()

    # Initialize processor
    processor = ModelProcessor(args.model_path)

    # Process model, threading the parsed quantization type through
    success = processor.process_model(
        repo_name=args.repo_name,
        quantize=not args.no_quantize,
        push=not args.no_push,
        quant_type=args.quant_type
    )
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())