Update app.py

Key Changes and Explanations:
Focused on Checkpoint to Diffusers: The code now only handles the conversion from an SDXL checkpoint (either .ckpt or .safetensors) to the Diffusers format. All checkpoint saving functionality has been removed.
load_sdxl_checkpoint: This function remains largely the same, loading the state dict and converting the relevant tensors to fp16.
build_diffusers_model: This NEW function takes the extracted state dictionaries and constructs the actual Diffusers model components (text encoders, VAE, UNet). The core pattern is sketched after the list below.
* It loads the configurations (e.g., CLIPTextConfig, UNet2DConditionModel.config) from a reference model. This is important because the checkpoint only contains the weights, not the architecture definition. You provide the path to a reference Diffusers model (like stabilityai/stable-diffusion-xl-base-1.0) in the Gradio interface, and the code uses it to get the correct configurations. If no reference model path is given, it defaults to stabilityai/stable-diffusion-xl-base-1.0.
* It then creates empty instances of the models (e.g., CLIPTextModel(config_text_encoder1)).
* It loads the extracted state_dict (weights) into these empty models using model.load_state_dict(state_dict).
* It explicitly moves the models to fp16 using .to(torch.float16).
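In brief, the per-component pattern looks like this (a minimal sketch for the first text encoder; `text_encoder1_state` stands for the state dict extracted by load_sdxl_checkpoint):

```python
import torch
from transformers import CLIPTextModel, CLIPTextConfig

# Architecture comes from the reference model's config; weights come from the checkpoint.
config = CLIPTextConfig.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder"
)
text_encoder1 = CLIPTextModel(config)               # empty model with the right architecture
text_encoder1.load_state_dict(text_encoder1_state)  # raises if keys mismatch (strict=True by default)
text_encoder1.to(torch.float16)                     # keep the weights in fp16
```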
convert_and_save_sdxl_to_diffusers: This function now:
* Calls load_sdxl_checkpoint to get the component state dictionaries.
* Calls build_diffusers_model to construct the Diffusers model components.
* Creates a StableDiffusionXLPipeline from these components. Important: the tokenizer, tokenizer_2, and scheduler are also part of a Diffusers model, so they are loaded from the same reference model used for the configurations and passed to the pipeline as well (see the example after this list).
* Saves the pipeline using pipeline.save_pretrained(output_path).
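A hypothetical call, with placeholder paths:

```python
# Convert a local SDXL checkpoint to a Diffusers directory, using the
# default SDXL base repo for configs, tokenizers, and scheduler.
convert_and_save_sdxl_to_diffusers(
    checkpoint_path="/content/my_sdxl.safetensors",  # hypothetical input checkpoint
    output_path="/content/output",
    reference_model_path=None,  # None -> stabilityai/stable-diffusion-xl-base-1.0
)
```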
upload_to_huggingface (using upload_folder): Crucially, I've changed this to use api.upload_folder(folder_path=model_path, repo_id=model_repo). This is the correct way to upload an entire Diffusers model directory to the Hugging Face Hub. The previous version only printed the repo URL without actually uploading anything.
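For reference, the upload boils down to this (a sketch; `HfApi` also accepts the token at construction time, and `upload_folder` takes optional `repo_type`/`commit_message` arguments):

```python
from huggingface_hub import HfApi

api = HfApi(token=hf_token)  # or rely on a cached `huggingface-cli login`
api.upload_folder(
    folder_path=model_path,  # the saved Diffusers directory
    repo_id=model_repo,      # e.g. "username/my-sdxl-model"
)
```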
main Function (Simplified): The main function is now much simpler, only calling the conversion and upload functions.
Gradio Interface:
The model_to_load label is clarified to specify that it's for checkpoints.
The output_path label is clarified to indicate that it's for the Diffusers format output.
The reference_model input is now optional and defaults to stabilityai/stable-diffusion-xl-base-1.0.
Removed unnecessary functions: the checkpoint-saving and checkpoint/dtype-detection helpers (save_sdxl_as_checkpoint, get_save_dtype, determine_load_checkpoint) are gone.
Gemini probably broke it again but it's OK, I saved the old code XD
```diff
@@ -2,7 +2,7 @@ import os
 import gradio as gr
 import torch
-from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
-from transformers import CLIPTextModel
+from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL, EulerDiscreteScheduler
+from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer
 from safetensors.torch import load_file
 from collections import OrderedDict
 import re
@@ -25,7 +25,7 @@ from huggingface_hub.errors import HfHubHTTPError
 
 # ---------------------- DEPENDENCIES ----------------------
 def install_dependencies_gradio():
-    """Installs the necessary dependencies
+    """Installs the necessary dependencies."""
     try:
         subprocess.run(["pip", "install", "-U", "torch", "diffusers", "transformers", "accelerate", "safetensors", "huggingface_hub", "xformers"])
         print("Dependencies installed successfully.")
@@ -33,26 +33,6 @@ def install_dependencies_gradio():
         print(f"Error installing dependencies: {e}")
 
 # ---------------------- UTILITY FUNCTIONS ----------------------
-def get_save_dtype(save_precision_as):
-    """Determines the save dtype based on the user's choice."""
-    if save_precision_as == "fp16":
-        return torch.float16
-    elif save_precision_as == "bf16":
-        return torch.bfloat16
-    elif save_precision_as == "float":
-        return torch.float32
-    else:
-        return None
-
-def determine_load_checkpoint(model_to_load):
-    """Determines if the model to load is a checkpoint or a Diffusers model."""
-    if model_to_load.endswith('.ckpt') or model_to_load.endswith('.safetensors'):
-        return True
-    elif os.path.isdir(model_to_load):
-        required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
-        if required_folders.issubset(set(os.listdir(model_to_load))) and os.path.isfile(os.path.join(model_to_load, "model_index.json")):
-            return False
-    return None
 
 def increment_filename(filename):
     """Increments the filename to avoid overwriting existing files."""
@@ -63,72 +43,105 @@ def increment_filename(filename):
         counter += 1
     return filename
 
-# ---------------------- UPLOAD FUNCTION
+# ---------------------- UPLOAD FUNCTION ----------------------
 def create_model_repo(api, user, orgs_name, model_name, make_private=False):
-    """Creates a Hugging Face model repository
+    """Creates a Hugging Face model repository."""
     repo_id = f"{orgs_name}/{model_name.strip()}" if orgs_name else f"{user['name']}/{model_name.strip()}"
     try:
-        # Attempt to create the repository
         api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
         print(f"Model repo '{repo_id}' created.")
     except HfHubHTTPError:
         print(f"Model repo '{repo_id}' already exists.")
-
     return repo_id
 
 # ---------------------- MODEL LOADING AND CONVERSION ----------------------
-def load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype):
-    """Loads the SDXL model from a checkpoint or Diffusers model."""
-    model_load_message = "checkpoint" if is_load_checkpoint else "Diffusers" + (" as fp16" if load_dtype == torch.float16 else "")
-    print(f"Loading {model_load_message}: {model_to_load}")
-
-    ...  # (checkpoint-loading branch not recoverable from the page extraction)
-    else:
-        pipeline = StableDiffusionXLPipeline.from_pretrained(model_to_load, torch_dtype=load_dtype)
-        text_encoder1 = pipeline.text_encoder
-        text_encoder2 = pipeline.text_encoder_2
-        vae = pipeline.vae
-        unet = pipeline.unet
-
-    return text_encoder1, text_encoder2, vae, unet
-
-def convert_and_save_sdxl_model(model_to_load, is_save_checkpoint, loaded_model_data, save_dtype):
-    """Converts and saves the SDXL model as either a checkpoint or a Diffusers model."""
-    text_encoder1, text_encoder2, vae, unet = loaded_model_data
-    if is_save_checkpoint:
-        save_sdxl_as_checkpoint(model_to_load, text_encoder1, text_encoder2, vae, unet, save_dtype)
-    else:
-        save_sdxl_as_diffusers(model_to_load, text_encoder1, text_encoder2, vae, unet, save_dtype)
-
-def save_sdxl_as_checkpoint(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
-    """Saves the SDXL model components as a checkpoint file."""
-    # Implement saving logic here
-    print(f"Model saved as checkpoint: {model_to_save}")
-
-def save_sdxl_as_diffusers(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
-    """Saves the SDXL model components as a Diffusers model."""
-    pipeline = StableDiffusionXLPipeline(
-        vae=vae,
-        text_encoder=text_encoder1,
-        text_encoder_2=text_encoder2,
-        unet=unet
-    )
-    pipeline.save_pretrained(model_to_save)
-    print(f"Model saved as Diffusers format: {model_to_save}")
+def load_sdxl_checkpoint(checkpoint_path):
+    """Loads an SDXL checkpoint (.ckpt or .safetensors) and returns component state dicts."""
+    if checkpoint_path.endswith(".safetensors"):
+        state_dict = load_file(checkpoint_path, device="cpu")
+    elif checkpoint_path.endswith(".ckpt"):
+        state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
+    else:
+        raise ValueError("Unsupported checkpoint format. Must be .safetensors or .ckpt")
+
+    text_encoder1_state = OrderedDict()
+    text_encoder2_state = OrderedDict()
+    vae_state = OrderedDict()
+    unet_state = OrderedDict()
+
+    # Route each tensor to its component by key prefix, converting to fp16 as we go.
+    for key, value in state_dict.items():
+        if key.startswith("first_stage_model."):  # VAE
+            vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
+        elif key.startswith("condition_model.model.text_encoder."):  # Text Encoder 1
+            text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
+        elif key.startswith("condition_model.model.text_encoder_2."):  # Text Encoder 2
+            text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
+        elif key.startswith("model.diffusion_model."):  # UNet
+            unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
+
+    return text_encoder1_state, text_encoder2_state, vae_state, unet_state
+
+def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None):
+    """Builds the Diffusers pipeline components from the loaded state dicts."""
+    # --- Load configurations, create empty models, load state dicts ---
+    if not reference_model_path:  # Default
+        reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+
+    # 1. Text encoders: architecture from the reference config, weights from the checkpoint
+    config_text_encoder1 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder")
+    config_text_encoder2 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder_2")
+
+    text_encoder1 = CLIPTextModel(config_text_encoder1)
+    text_encoder2 = CLIPTextModel(config_text_encoder2)
+    text_encoder1.load_state_dict(text_encoder1_state)
+    text_encoder2.load_state_dict(text_encoder2_state)
+    text_encoder1.to(torch.float16)  # Ensure fp16
+    text_encoder2.to(torch.float16)
+
+    # 2. VAE
+    vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae")
+    vae.load_state_dict(vae_state)
+    vae.to(torch.float16)
+
+    # 3. UNet
+    unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet")
+    unet.load_state_dict(unet_state)
+    unet.to(torch.float16)
+
+    return text_encoder1, text_encoder2, vae, unet
+
+def convert_and_save_sdxl_to_diffusers(checkpoint_path, output_path, reference_model_path):
+    """Converts an SDXL checkpoint to Diffusers format and saves it."""
+    text_encoder1_state, text_encoder2_state, vae_state, unet_state = load_sdxl_checkpoint(checkpoint_path)
+    text_encoder1, text_encoder2, vae, unet = build_diffusers_model(
+        text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path
+    )
+
+    # Tokenizers and scheduler carry no trained weights, so take them from the reference model.
+    ref = reference_model_path if reference_model_path else "stabilityai/stable-diffusion-xl-base-1.0"
+    tokenizer = CLIPTokenizer.from_pretrained(ref, subfolder="tokenizer")
+    tokenizer_2 = CLIPTokenizer.from_pretrained(ref, subfolder="tokenizer_2")
+    scheduler = EulerDiscreteScheduler.from_pretrained(ref, subfolder="scheduler")
+
+    pipeline = StableDiffusionXLPipeline(
+        vae=vae,
+        text_encoder=text_encoder1,
+        text_encoder_2=text_encoder2,
+        unet=unet,
+        tokenizer=tokenizer,
+        tokenizer_2=tokenizer_2,
+        scheduler=scheduler,
+    )
+    pipeline.save_pretrained(output_path)
+    print(f"Model saved as Diffusers format: {output_path}")
 
 # ---------------------- UPLOAD FUNCTION ----------------------
 def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
@@ -137,30 +150,22 @@ def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
     api = HfApi()
     user = api.whoami(hf_token)
     model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
-
-    # Upload logic here
+    api.upload_folder(folder_path=model_path, repo_id=model_repo)  # Use upload_folder
     print(f"Model uploaded to: https://huggingface.co/{model_repo}")
 
 # ---------------------- GRADIO INTERFACE ----------------------
-def main(model_to_load,
-    """Main function
-
-    is_save_checkpoint = not is_load_checkpoint
-
-    loaded_model_data = load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype)
-    convert_and_save_sdxl_model(model_to_load, is_save_checkpoint, loaded_model_data, load_dtype)
+def main(model_to_load, reference_model, output_path, hf_token, orgs_name, model_name, make_private):
+    """Main function: SDXL checkpoint to Diffusers, always fp16."""
+
+    convert_and_save_sdxl_to_diffusers(model_to_load, output_path, reference_model)
     upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
 
     return "Conversion and upload completed successfully!"
 
 with gr.Blocks() as demo:
-    model_to_load = gr.Textbox(label="
-
-    global_step = gr.Number(value=0, label="Global Step to Write (Checkpoint)")
-    reference_model = gr.Textbox(label="Reference Diffusers Model", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
-    output_path = gr.Textbox(label="Output Path", value="/content/output")
+    model_to_load = gr.Textbox(label="SDXL Checkpoint to Load (.ckpt or .safetensors)", placeholder="Path to checkpoint")
+    reference_model = gr.Textbox(label="Reference Diffusers Model (Optional)", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)")
+    output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="/content/output")  # Clarified label
     hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
     orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
     model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
@@ -169,6 +174,6 @@ with gr.Blocks() as demo:
     convert_button = gr.Button("Convert and Upload")
     output = gr.Markdown()
 
-    convert_button.click(fn=main, inputs=[model_to_load,
+    convert_button.click(fn=main, inputs=[model_to_load, reference_model, output_path, hf_token, orgs_name, model_name, make_private], outputs=output)
 
 demo.launch()
```