| import argparse | |
| import json | |
| import os | |
| import shutil | |
| from diffusers.pipelines.stable_diffusion import safety_checker | |
| import torch | |
| from tempfile import TemporaryDirectory | |
| from typing import List, Optional | |
| from diffusers import StableDiffusionPipeline, ControlNetModel | |
| from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download | |
| from huggingface_hub.file_download import repo_folder_name | |
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Convert a Flax diffusers model on the Hub to PyTorch weights and open a PR.

    The repo is treated as a full ``StableDiffusionPipeline`` when its file
    listing contains ``model_index.json``, and as a single ``ControlNetModel``
    otherwise. The weights are loaded with ``from_flax=True`` and saved into a
    temporary folder four times:
    - fp32, default serialization;
    - fp32, ``safetensors``;
    - fp16, default serialization, as ``variant="fp16"``;
    - fp16, ``safetensors``, as ``variant="fp16"``.

    The folder is then uploaded to the same repo with ``create_pr=True``.

    Args:
        api: ``HfApi`` client used both to read the repo's metadata and to
            upload the converted files.
        model_id: Repo id of the model on the Hub.
        force: Currently unused; kept so existing callers keep working.

    Returns:
        The value returned by ``api.upload_folder`` for the created PR (a
        ``CommitInfo`` in recent ``huggingface_hub`` releases).
    """
    info = api.model_info(model_id)
    # `ModelInfo.siblings` is Optional; treat a missing listing as "no files".
    filenames = {sibling.rfilename for sibling in (info.siblings or [])}
    is_sd = "model_index.json" in filenames

    if is_sd:
        # Passing safety_checker=None overrides that pipeline component so it
        # is not loaded (and therefore not re-saved).
        model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True, safety_checker=None)
    else:
        model = ControlNetModel.from_pretrained(model_id, from_flax=True)

    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)

        # Full-precision weights in both serialization formats.
        model.save_pretrained(folder)
        model.save_pretrained(folder, safe_serialization=True)

        # Half-precision copies, written as the "fp16" variant alongside fp32.
        # Pipelines are cast via `.to(torch_dtype=...)`; a single model via `.half()`.
        if is_sd:
            model.to(torch_dtype=torch.float16)
        else:
            model.half()
        model.save_pretrained(folder, variant="fp16")
        model.save_pretrained(folder, safe_serialization=True, variant="fp16")

        # Upload must happen before the temporary directory is removed.
        commit_info = api.upload_folder(
            folder_path=folder,
            repo_id=model_id,
            repo_type="model",
            create_pr=True,
        )

    print(model_id)
    return commit_info
if __name__ == "__main__":
    # Command-line entry point: convert a single Hub repo given by its id.
    DESCRIPTION = """
Simple utility tool to convert Flax diffusers weights on the hub to PyTorch.
It loads the Flax weights of a Stable Diffusion pipeline (repo with `model_index.json`)
or of a ControlNet model, saves them locally as PyTorch and `safetensors` files
(fp32 plus an `fp16` variant), and uploads them back as a PR on the hub.
"""
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help=(
            "The repo id on the hub of the diffusers model to convert "
            "(a Stable Diffusion pipeline or a ControlNet model with Flax weights). "
            "E.g. `<user>/<model-name>`"
        ),
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    convert(api, model_id)