Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Model Processing Script | |
| Processes recovered model with quantization and pushing to HF Hub | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import logging | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional | |
# Configure root logging for the whole script: INFO level with
# timestamped "time - LEVEL - message" lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
class ModelProcessor:
    """Process a recovered model: validate it, optionally quantize it, and
    optionally push it to the Hugging Face Hub.

    The heavy steps (loading, quantizing, pushing) are run in subprocesses so
    a crash (OOM, segfault) cannot take down this driver process.
    """

    def __init__(self, model_path: str = "recovered_model"):
        """
        Args:
            model_path: Directory containing the recovered model files.
        """
        self.model_path = Path(model_path)
        # Optional; when present it is forwarded to the helper scripts.
        self.hf_token = os.getenv('HF_TOKEN')

    @staticmethod
    def _log_command(cmd: list) -> None:
        """Log a command line with any --token value redacted.

        Fix: the original logged the raw command, leaking HF_TOKEN into logs.
        """
        safe = list(cmd)
        for i, arg in enumerate(safe[:-1]):
            if arg == "--token":
                safe[i + 1] = "***"
        logger.info(f"π Running: {' '.join(safe)}")

    def validate_model(self) -> bool:
        """Validate that the model can be loaded by transformers.

        Returns:
            True if a subprocess loads the model successfully, else False.
        """
        try:
            logger.info("π Validating model loading...")
            # Fix: use the configured model path instead of the hard-coded
            # 'recovered_model' directory the original embedded in the snippet.
            cmd = [
                sys.executable, "-c",
                "from transformers import AutoModelForCausalLM; "
                f"model = AutoModelForCausalLM.from_pretrained({str(self.model_path)!r}, "
                "torch_dtype='auto', device_map='auto'); "
                "print('β Model loaded successfully')"
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                logger.info("β Model validation successful")
                return True
            logger.error(f"β Model validation failed: {result.stderr}")
            return False
        except Exception as e:
            # Covers TimeoutExpired and any spawn failure; validation is a
            # go/no-go gate, so report and return False rather than raise.
            logger.error(f"β Model validation error: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Summarize the model from its config.json and on-disk size.

        Returns:
            Dict with architecture fields (falling back to SmolLM3 defaults
            when absent from config.json) and total size in GiB; empty dict
            on error.
        """
        try:
            config_path = self.model_path / "config.json"
            if config_path.exists():
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}
            # Total size of every file under the model directory, recursively.
            total_size = sum(
                f.stat().st_size for f in self.model_path.rglob("*") if f.is_file()
            )
            model_info = {
                "model_type": config.get("model_type", "smollm3"),
                "architectures": config.get("architectures", ["SmolLM3ForCausalLM"]),
                "model_size_gb": total_size / (1024**3),
                "vocab_size": config.get("vocab_size", 32000),
                "hidden_size": config.get("hidden_size", 2048),
                "num_attention_heads": config.get("num_attention_heads", 16),
                "num_hidden_layers": config.get("num_hidden_layers", 24),
                "max_position_embeddings": config.get("max_position_embeddings", 8192)
            }
            logger.info(f"π Model info: {model_info}")
            return model_info
        except Exception as e:
            logger.error(f"β Failed to get model info: {e}")
            return {}

    def run_quantization(self, repo_name: str, quant_type: str = "int8_weight_only") -> bool:
        """Quantize the model via scripts/model_tonic/quantize_model.py.

        Args:
            repo_name: Target HF repository (username/model-name).
            quant_type: Quantization scheme passed to the script.

        Returns:
            True on success, False on missing script, failure, or timeout.
        """
        try:
            logger.info(f"π Running quantization: {quant_type}")
            quantize_script = Path("scripts/model_tonic/quantize_model.py")
            if not quantize_script.exists():
                logger.error(f"β Quantization script not found: {quantize_script}")
                return False
            cmd = [
                sys.executable, str(quantize_script),
                str(self.model_path),
                repo_name,
                "--quant-type", quant_type,
                "--device", "auto"
            ]
            if self.hf_token:
                cmd.extend(["--token", self.hf_token])
            self._log_command(cmd)
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)  # 30 min timeout
            if result.returncode == 0:
                logger.info("β Quantization completed successfully")
                logger.info(result.stdout)
                return True
            logger.error("β Quantization failed")
            logger.error(result.stderr)
            return False
        except subprocess.TimeoutExpired:
            logger.error("β Quantization timed out")
            return False
        except Exception as e:
            logger.error(f"β Failed to run quantization: {e}")
            return False

    def run_model_push(self, repo_name: str) -> bool:
        """Push the model via scripts/model_tonic/push_to_huggingface.py.

        Args:
            repo_name: Target HF repository (username/model-name).

        Returns:
            True on success, False on missing script, failure, or timeout.
        """
        try:
            logger.info(f"π Pushing model to: {repo_name}")
            push_script = Path("scripts/model_tonic/push_to_huggingface.py")
            if not push_script.exists():
                logger.error(f"β Push script not found: {push_script}")
                return False
            cmd = [
                sys.executable, str(push_script),
                str(self.model_path),
                repo_name
            ]
            if self.hf_token:
                cmd.extend(["--token", self.hf_token])
            self._log_command(cmd)
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)  # 30 min timeout
            if result.returncode == 0:
                logger.info("β Model push completed successfully")
                logger.info(result.stdout)
                return True
            logger.error("β Model push failed")
            logger.error(result.stderr)
            return False
        except subprocess.TimeoutExpired:
            logger.error("β Model push timed out")
            return False
        except Exception as e:
            logger.error(f"β Failed to push model: {e}")
            return False

    def process_model(self, repo_name: str, quantize: bool = True, push: bool = True,
                      quant_type: str = "int8_weight_only") -> bool:
        """Run the full workflow: validate, inspect, quantize, push.

        Args:
            repo_name: Target HF repository (username/model-name).
            quantize: Run the quantization step.
            push: Run the push step.
            quant_type: Quantization scheme (new, backward-compatible param;
                the original always used the default and ignored --quant-type).

        Returns:
            True if every requested step succeeded.
        """
        logger.info("π Starting model processing...")
        # Step 1: Validate model
        if not self.validate_model():
            logger.error("β Model validation failed")
            return False
        # Step 2: Get model info (informational; an empty dict is not fatal)
        model_info = self.get_model_info()
        # Step 3: Quantize if requested
        if quantize and not self.run_quantization(repo_name, quant_type):
            logger.error("β Quantization failed")
            return False
        # Step 4: Push if requested
        if push and not self.run_model_push(repo_name):
            logger.error("β Model push failed")
            return False
        logger.info("π Model processing completed successfully!")
        logger.info(f"π View your model at: https://huggingface.co/{repo_name}")
        return True


def main():
    """CLI entry point. Returns a process exit code (0 = success)."""
    import argparse
    parser = argparse.ArgumentParser(description="Process recovered model")
    parser.add_argument("repo_name", help="Hugging Face repository name (username/model-name)")
    parser.add_argument("--model-path", default="recovered_model", help="Path to recovered model")
    parser.add_argument("--no-quantize", action="store_true", help="Skip quantization")
    parser.add_argument("--no-push", action="store_true", help="Skip pushing to HF Hub")
    parser.add_argument("--quant-type", default="int8_weight_only",
                        choices=["int8_weight_only", "int4_weight_only", "int8_dynamic"],
                        help="Quantization type")
    args = parser.parse_args()
    processor = ModelProcessor(args.model_path)
    # Fix: forward --quant-type; the original parsed it but never used it.
    success = processor.process_model(
        repo_name=args.repo_name,
        quantize=not args.no_quantize,
        push=not args.no_push,
        quant_type=args.quant_type
    )
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())