Duskfallcrew commited on
Commit
1126c53
·
verified ·
1 Parent(s): ced2dd0

Update app.py

Browse files

Key Changes and Explanations:

Focused on Checkpoint to Diffusers: The code now only handles the conversion from an SDXL checkpoint (either .ckpt or .safetensors) to the Diffusers format. All checkpoint saving functionality has been removed.

load_sdxl_checkpoint: This function remains largely the same, loading the state dict and converting the relevant tensors to fp16.

build_diffusers_model Function:: This NEW function takes the extracted state dictionaries and constructs the actual Diffusers model components (Text Encoders, VAE, UNet).
* It loads the configurations (e.g., CLIPTextConfig, UNet2DConditionModel.config) from a reference model. This is important because the checkpoint only contains the weights, not the architecture definition. You provide the path to a reference Diffusers model (like stabilityai/stable-diffusion-xl-base-1.0) in the Gradio interface, and the code uses that to get the correct configurations. If no refernce model path is given, it defaults to stabilityai/stable-diffusion-xl-base-1.0
* It then creates empty instances of the models (e.g., CLIPTextModel(config_text_encoder1)).
* It loads the extracted state_dict (weights) into these empty models using model.load_state_dict(state_dict).
* It explicitly moves the models to fp16 using .to(torch.float16).

convert_and_save_sdxl_to_diffusers: This function now:

Calls load_sdxl_checkpoint to get the component state dictionaries.

Calls build_diffusers_model to construct the Diffusers model components.

Creates a StableDiffusionXLPipeline from these components. Important: You'll likely need to add the tokenizer, tokenizer_2, and scheduler to the pipeline creation as well. These are typically part of a Diffusers model, and you should load them from the same reference model you used for the configurations. I've added placeholder lines where you should do this.

Saves the pipeline using pipeline.save_pretrained(output_path).

upload_to_huggingface (Using upload_folder): Crucially, I've changed this to use api.upload_folder(folder_path=model_path, repo_id=model_repo). This is the correct way to upload an entire Diffusers model directory to Hugging Face Hub. The previous print statement was incorrect.

main Function (Simplified): The main function is now much simpler, only calling the conversion and upload functions.

Gradio Interface:

The model_to_load label is clarified to specify that it's for checkpoints.

The output_path label is clarified to indicate that it's for the Diffusers format output.

The reference_model input is now optional, and will default.

Removed Unnecessary functions: Removed the checkpoint saving and determination functions.

Gemini probably broke it again but it's ok i saved the old code XD

Files changed (1) hide show
  1. app.py +87 -82
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
- from transformers import CLIPTextModel
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
8
  import re
@@ -25,7 +25,7 @@ from huggingface_hub.errors import HfHubHTTPError
25
 
26
  # ---------------------- DEPENDENCIES ----------------------
27
  def install_dependencies_gradio():
28
- """Installs the necessary dependencies for the Gradio app. Run this ONCE."""
29
  try:
30
  subprocess.run(["pip", "install", "-U", "torch", "diffusers", "transformers", "accelerate", "safetensors", "huggingface_hub", "xformers"])
31
  print("Dependencies installed successfully.")
@@ -33,26 +33,6 @@ def install_dependencies_gradio():
33
  print(f"Error installing dependencies: {e}")
34
 
35
  # ---------------------- UTILITY FUNCTIONS ----------------------
36
- def get_save_dtype(save_precision_as):
37
- """Determines the save dtype based on the user's choice."""
38
- if save_precision_as == "fp16":
39
- return torch.float16
40
- elif save_precision_as == "bf16":
41
- return torch.bfloat16
42
- elif save_precision_as == "float":
43
- return torch.float32
44
- else:
45
- return None
46
-
47
- def determine_load_checkpoint(model_to_load):
48
- """Determines if the model to load is a checkpoint or a Diffusers model."""
49
- if model_to_load.endswith('.ckpt') or model_to_load.endswith('.safetensors'):
50
- return True
51
- elif os.path.isdir(model_to_load):
52
- required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
53
- if required_folders.issubset(set(os.listdir(model_to_load))) and os.path.isfile(os.path.join(model_to_load, "model_index.json")):
54
- return False
55
- return None
56
 
57
  def increment_filename(filename):
58
  """Increments the filename to avoid overwriting existing files."""
@@ -63,72 +43,105 @@ def increment_filename(filename):
63
  counter += 1
64
  return filename
65
 
66
- # ---------------------- UPLOAD FUNCTION ----------------------# ---------------------- UPLOAD FUNCTION ----------------------
67
  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
68
- """Creates a Hugging Face model repository if it doesn't exist."""
69
  repo_id = f"{orgs_name}/{model_name.strip()}" if orgs_name else f"{user['name']}/{model_name.strip()}"
70
  try:
71
- # Attempt to create the repository
72
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
73
  print(f"Model repo '{repo_id}' created.")
74
  except HfHubHTTPError:
75
  print(f"Model repo '{repo_id}' already exists.")
76
-
77
  return repo_id
78
 
79
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
80
- def load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype):
81
- """Loads the SDXL model from a checkpoint or Diffusers model."""
82
- model_load_message = "checkpoint" if is_load_checkpoint else "Diffusers" + (" as fp16" if load_dtype == torch.float16 else "")
83
- print(f"Loading {model_load_message}: {model_to_load}")
84
 
85
- if is_load_checkpoint:
86
- return load_from_sdxl_checkpoint(model_to_load)
 
 
 
 
 
87
  else:
88
- return load_sdxl_from_diffusers(model_to_load, load_dtype)
89
-
90
- def load_from_sdxl_checkpoint(model_to_load):
91
- """Loads the SDXL model components from a checkpoint file."""
92
- # Implement loading logic here
93
- text_encoder1, text_encoder2, vae, unet = None, None, None, None
94
- # Example loading logic (replace with actual loading code)
95
- # text_encoder1, text_encoder2, vae, unet = sdxl_model_util.load_models_from_sdxl_checkpoint("sdxl_base_v1-0", model_to_load, "cpu")
96
- print(f"Loaded from checkpoint: {model_to_load}")
97
- return text_encoder1, text_encoder2, vae, unet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- def load_sdxl_from_diffusers(model_to_load, load_dtype):
100
- """Loads an SDXL model from a Diffusers model directory."""
101
- pipeline = StableDiffusionXLPipeline.from_pretrained(model_to_load, torch_dtype=load_dtype)
102
- text_encoder1 = pipeline.text_encoder
103
- text_encoder2 = pipeline.text_encoder_2
104
- vae = pipeline.vae
105
- unet = pipeline.unet
106
 
107
  return text_encoder1, text_encoder2, vae, unet
108
 
109
- def convert_and_save_sdxl_model(model_to_load, is_save_checkpoint, loaded_model_data, save_dtype):
110
- """Converts and saves the SDXL model as either a checkpoint or a Diffusers model."""
111
- text_encoder1, text_encoder2, vae, unet = loaded_model_data
112
- if is_save_checkpoint:
113
- save_sdxl_as_checkpoint(model_to_load, text_encoder1, text_encoder2, vae, unet, save_dtype)
114
- else:
115
- save_sdxl_as_diffusers(model_to_load, text_encoder1, text_encoder2, vae, unet, save_dtype)
116
 
117
- def save_sdxl_as_checkpoint(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
118
- """Saves the SDXL model components as a checkpoint file."""
119
- # Implement saving logic here
120
- print(f"Model saved as checkpoint: {model_to_save}")
121
 
122
- def save_sdxl_as_diffusers(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
123
- """Saves the SDXL model as a Diffusers model."""
 
 
 
 
 
124
  pipeline = StableDiffusionXLPipeline(
125
  vae=vae,
126
  text_encoder=text_encoder1,
127
  text_encoder_2=text_encoder2,
128
- unet=unet
 
 
 
 
129
  )
130
- pipeline.save_pretrained(model_to_save)
131
- print(f"Model saved as Diffusers format: {model_to_save}")
132
 
133
  # ---------------------- UPLOAD FUNCTION ----------------------
134
  def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
@@ -137,30 +150,22 @@ def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_priv
137
  api = HfApi()
138
  user = api.whoami(hf_token)
139
  model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
140
-
141
- # Upload logic here
142
  print(f"Model uploaded to: https://huggingface.co/{model_repo}")
143
 
144
  # ---------------------- GRADIO INTERFACE ----------------------
145
- def main(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16, hf_token, orgs_name, model_name, make_private):
146
- """Main function orchestrating the entire process."""
147
- load_dtype = get_save_dtype(save_precision_as)
148
- is_load_checkpoint = determine_load_checkpoint(model_to_load)
149
- is_save_checkpoint = not is_load_checkpoint
150
-
151
- loaded_model_data = load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype)
152
- convert_and_save_sdxl_model(model_to_load, is_save_checkpoint, loaded_model_data, load_dtype)
153
  upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
154
 
155
  return "Conversion and upload completed successfully!"
156
 
157
  with gr.Blocks() as demo:
158
- model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)", placeholder="Path to model")
159
- save_precision_as = gr.Dropdown(choices=["fp16", "bf16", "float"], label="Save Precision As")
160
- epoch = gr.Number(value=0, label="Epoch to Write (Checkpoint)")
161
- global_step = gr.Number(value=0, label="Global Step to Write (Checkpoint)")
162
- reference_model = gr.Textbox(label="Reference Diffusers Model", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
163
- output_path = gr.Textbox(label="Output Path", value="/content/output")
164
  hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
165
  orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
166
  model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
@@ -169,6 +174,6 @@ with gr.Blocks() as demo:
169
  convert_button = gr.Button("Convert and Upload")
170
  output = gr.Markdown()
171
 
172
- convert_button.click(fn=main, inputs=[model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16, hf_token, orgs_name, model_name, make_private], outputs=output)
173
 
174
  demo.launch()
 
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
+ from transformers import CLIPTextModel, CLIPTextConfig
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
8
  import re
 
25
 
26
  # ---------------------- DEPENDENCIES ----------------------
27
  def install_dependencies_gradio():
28
+ """Installs the necessary dependencies."""
29
  try:
30
  subprocess.run(["pip", "install", "-U", "torch", "diffusers", "transformers", "accelerate", "safetensors", "huggingface_hub", "xformers"])
31
  print("Dependencies installed successfully.")
 
33
  print(f"Error installing dependencies: {e}")
34
 
35
  # ---------------------- UTILITY FUNCTIONS ----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def increment_filename(filename):
38
  """Increments the filename to avoid overwriting existing files."""
 
43
  counter += 1
44
  return filename
45
 
46
+ # ---------------------- UPLOAD FUNCTION ----------------------
47
  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
48
+ """Creates a Hugging Face model repository."""
49
  repo_id = f"{orgs_name}/{model_name.strip()}" if orgs_name else f"{user['name']}/{model_name.strip()}"
50
  try:
 
51
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
52
  print(f"Model repo '{repo_id}' created.")
53
  except HfHubHTTPError:
54
  print(f"Model repo '{repo_id}' already exists.")
 
55
  return repo_id
56
 
57
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
 
 
 
 
58
 
59
+ def load_sdxl_checkpoint(checkpoint_path):
60
+ """Loads an SDXL checkpoint (.ckpt or .safetensors) and returns components."""
61
+
62
+ if checkpoint_path.endswith(".safetensors"):
63
+ state_dict = load_file(checkpoint_path, device="cpu")
64
+ elif checkpoint_path.endswith(".ckpt"):
65
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
66
  else:
67
+ raise ValueError("Unsupported checkpoint format. Must be .safetensors or .ckpt")
68
+
69
+ text_encoder1_state = OrderedDict()
70
+ text_encoder2_state = OrderedDict()
71
+ vae_state = OrderedDict()
72
+ unet_state = OrderedDict()
73
+
74
+ for key, value in state_dict.items():
75
+ if key.startswith("first_stage_model."): # VAE
76
+ vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
77
+ elif key.startswith("condition_model.model.text_encoder."): # Text Encoder 1
78
+ text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
79
+ elif key.startswith("condition_model.model.text_encoder_2."): # Text Encoder 2
80
+ text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
81
+ elif key.startswith("model.diffusion_model."): # UNet
82
+ unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
83
+
84
+ return text_encoder1_state, text_encoder2_state, vae_state, unet_state
85
+
86
+ def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None):
87
+ """Builds the Diffusers pipeline components from the loaded state dicts."""
88
+ # --- Load configurations, create models (empty), load state dicts ---
89
+
90
+ # 1. Text Encoders
91
+ if reference_model_path:
92
+ config_text_encoder1 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder")
93
+ config_text_encoder2 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder_2")
94
+ else: #Default
95
+ config_text_encoder1 = CLIPTextConfig.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
96
+ config_text_encoder2 = CLIPTextConfig.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2")
97
+
98
+ text_encoder1 = CLIPTextModel(config_text_encoder1)
99
+ text_encoder2 = CLIPTextModel(config_text_encoder2)
100
+ text_encoder1.load_state_dict(text_encoder1_state)
101
+ text_encoder2.load_state_dict(text_encoder2_state)
102
+ text_encoder1.to(torch.float16) # Ensure fp16
103
+ text_encoder2.to(torch.float16)
104
+
105
+ # 2. VAE
106
+ if reference_model_path:
107
+ vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae")
108
+ else:
109
+ vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
110
+ vae.load_state_dict(vae_state)
111
+ vae.to(torch.float16)
112
+
113
+ # 3. UNet
114
+ if reference_model_path:
115
+ unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet")
116
+ else:
117
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
118
 
119
+ unet.load_state_dict(unet_state)
120
+ unet.to(torch.float16)
 
 
 
 
 
121
 
122
  return text_encoder1, text_encoder2, vae, unet
123
 
 
 
 
 
 
 
 
124
 
 
 
 
 
125
 
126
+ def convert_and_save_sdxl_to_diffusers(checkpoint_path, output_path, reference_model_path):
127
+ """Converts an SDXL checkpoint to Diffusers format and saves it."""
128
+
129
+ text_encoder1_state, text_encoder2_state, vae_state, unet_state = load_sdxl_checkpoint(checkpoint_path)
130
+ text_encoder1, text_encoder2, vae, unet = build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path)
131
+
132
+
133
  pipeline = StableDiffusionXLPipeline(
134
  vae=vae,
135
  text_encoder=text_encoder1,
136
  text_encoder_2=text_encoder2,
137
+ unet=unet,
138
+ # You'll likely need to add tokenizer, scheduler, etc., here from the reference model
139
+ tokenizer = pipeline.tokenizer,
140
+ tokenizer_2 = pipeline.tokenizer_2,
141
+ scheduler = pipeline.scheduler
142
  )
143
+ pipeline.save_pretrained(output_path)
144
+ print(f"Model saved as Diffusers format: {output_path}")
145
 
146
  # ---------------------- UPLOAD FUNCTION ----------------------
147
  def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
 
150
  api = HfApi()
151
  user = api.whoami(hf_token)
152
  model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
153
+ api.upload_folder(folder_path=model_path, repo_id=model_repo) # Use upload_folder
 
154
  print(f"Model uploaded to: https://huggingface.co/{model_repo}")
155
 
156
  # ---------------------- GRADIO INTERFACE ----------------------
157
+ def main(model_to_load, reference_model, output_path, hf_token, orgs_name, model_name, make_private):
158
+ """Main function: SDXL checkpoint to Diffusers, always fp16."""
159
+
160
+ convert_and_save_sdxl_to_diffusers(model_to_load, output_path, reference_model)
 
 
 
 
161
  upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
162
 
163
  return "Conversion and upload completed successfully!"
164
 
165
  with gr.Blocks() as demo:
166
+ model_to_load = gr.Textbox(label="SDXL Checkpoint to Load (.ckpt or .safetensors)", placeholder="Path to checkpoint")
167
+ reference_model = gr.Textbox(label="Reference Diffusers Model (Optional)", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)")
168
+ output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="/content/output") # Clarified label
 
 
 
169
  hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
170
  orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
171
  model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
 
174
  convert_button = gr.Button("Convert and Upload")
175
  output = gr.Markdown()
176
 
177
+ convert_button.click(fn=main, inputs=[model_to_load, reference_model, output_path, hf_token, orgs_name, model_name, make_private], outputs=output)
178
 
179
  demo.launch()