import logging
import shutil
import subprocess
import sys
import tarfile
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from urllib.request import urlopen, urlretrieve

import streamlit as st
from huggingface_hub import HfApi, whoami

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Config:
    """Application configuration."""

    hf_token: str
    hf_username: str
    transformers_version: str = "3.0.0"
    hf_base_url: str = "https://huggingface.co"
    transformers_base_url: str = (
        "https://github.com/xenova/transformers.js/archive/refs"
    )
    repo_path: Path = Path("./transformers.js")
    @classmethod
    def from_env(cls) -> "Config":
        """Create config from environment variables and secrets."""
        system_token = st.secrets.get("HF_TOKEN")
        user_token = st.session_state.get("user_hf_token", "")
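        # Prefer the account behind the user-supplied token, then the Space
        # author, then the account behind the system token. The `if user_token`
        # guard is an editorial fix: whoami() with an empty token may fail.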
        hf_username = (
            (whoami(token=user_token)["name"] if user_token else None)
            or st.secrets.get("SPACE_AUTHOR_NAME")
            or whoami(token=system_token)["name"]
        )
        hf_token = user_token or system_token
        if not hf_token:
            raise ValueError("HF_TOKEN must be set")
        return cls(hf_token=hf_token, hf_username=hf_username)

class ModelConverter:
    """Handles model conversion and upload operations."""

    def __init__(self, config: Config):
        self.config = config
        self.api = HfApi(token=config.hf_token)

    def _get_ref_type(self) -> str:
        """Determine the reference type for the transformers repository."""
        url = f"{self.config.transformers_base_url}/tags/{self.config.transformers_version}.tar.gz"
        try:
            return "tags" if urlopen(url).getcode() == 200 else "heads"
        except Exception as e:
            logger.warning(f"Failed to check tags, defaulting to heads: {e}")
            return "heads"

    def setup_repository(self) -> None:
        """Download and set up the transformers repository if needed."""
        if self.config.repo_path.exists():
            return

        ref_type = self._get_ref_type()
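        # Archive URL shape: .../archive/refs/{tags|heads}/<version>.tar.gz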
        archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
        archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")

        try:
            urlretrieve(archive_url, archive_path)
            self._extract_archive(archive_path)
            logger.info("Repository downloaded and extracted successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to set up repository: {e}") from e
        finally:
            archive_path.unlink(missing_ok=True)

    def _extract_archive(self, archive_path: Path) -> None:
        """Extract the downloaded archive."""
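        # The archive contains a single top-level folder (<repo>-<ref>);
        # extract to a temp dir, then move that folder into repo_path.
        # Note: Path.rename assumes tmp_dir and repo_path share a filesystem;
        # shutil.move would be the safer choice if that does not hold.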
        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(tmp_dir)
            extracted_folder = next(Path(tmp_dir).iterdir())
            extracted_folder.rename(self.config.repo_path)

    def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
        """Convert the model to ONNX format."""
        try:
            result = subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "scripts.convert",
                    "--quantize",
                    "--model_id",
                    input_model_id,
                ],
                cwd=self.config.repo_path,
                capture_output=True,
                text=True,
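                # Run with an empty environment, presumably so the Space's
                # secrets (e.g. HF_TOKEN) never reach the conversion subprocess.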
                env={},
            )
            if result.returncode != 0:
                return False, result.stderr
            self._rename_model_files(input_model_id)
            return True, result.stderr
        except Exception as e:
            return False, str(e)

    def _rename_model_files(self, input_model_id: str) -> None:
        """Rename the converted model files."""
        model_path = self.config.repo_path / "models" / input_model_id / "onnx"
        renames = [
            ("model.onnx", "decoder_model_merged.onnx"),
            ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
        ]
        for old_name, new_name in renames:
            (model_path / old_name).rename(model_path / new_name)

    def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
        """Upload the converted model to Hugging Face."""
        model_folder_path = self.config.repo_path / "models" / input_model_id
        try:
            self.api.create_repo(output_model_id, exist_ok=True, private=False)
            self.api.upload_folder(
                folder_path=str(model_folder_path), repo_id=output_model_id
            )
            return None
        except Exception as e:
            return str(e)
        finally:
            shutil.rmtree(model_folder_path, ignore_errors=True)

def main():
    """Main application entry point."""
    st.write("## Convert a Hugging Face model to ONNX")
    try:
        config = Config.from_env()
        converter = ModelConverter(config)
        converter.setup_repository()

        input_model_id = st.text_input(
            "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
        )
        if not input_model_id:
            return
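        # Streamlit reruns the script on each interaction; the token entered
        # here persists in st.session_state under "user_hf_token", which
        # Config.from_env reads on the next rerun.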
        st.text_input(
            f"Optional: Your Hugging Face write token. Leave empty to upload under {config.hf_username}'s account.",
            type="password",
            key="user_hf_token",
        )
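        # Derive the output repo name: strip a pasted URL prefix, flatten
        # "org/model" into "org-model", and drop a duplicated username prefix.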
        model_name = (
            input_model_id.replace(f"{config.hf_base_url}/", "")
            .replace("/", "-")
            .replace(f"{config.hf_username}-", "")
            .strip()
        )
        output_model_id = f"{config.hf_username}/{model_name}-ONNX"
        output_model_url = f"{config.hf_base_url}/{output_model_id}"
        if converter.api.repo_exists(output_model_id):
            st.write("This model has already been converted! 🎉")
            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
            return

        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")

        if not st.button(label="Proceed", type="primary"):
            return
        with st.spinner("Converting model..."):
            success, stderr = converter.convert_model(input_model_id)
            if not success:
                st.error(f"Conversion failed: {stderr}")
                return
            st.success("Conversion successful!")
            st.code(stderr)

        with st.spinner("Uploading model..."):
            error = converter.upload_model(input_model_id, output_model_id)
            if error:
                st.error(f"Upload failed: {error}")
                return
            st.success("Upload successful!")
            st.write("You can now go and view the model on Hugging Face!")
            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    except Exception as e:
        logger.exception("Application error")
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()