Duskfallcrew commited on
Commit
54694b3
·
verified ·
1 Parent(s): f90283a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -44
app.py CHANGED
@@ -15,12 +15,12 @@ from huggingface_hub import login, HfApi, hf_hub_download
15
  from huggingface_hub.utils import validate_repo_id, HFValidationError
16
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
17
  from huggingface_hub.utils import HfHubHTTPError
18
- from accelerate import Accelerator # Import accelerate
19
 
20
 
21
  # ---------------------- DEPENDENCIES ----------------------
22
  def install_dependencies_gradio():
23
- """Installs the necessary dependencies, including accelerate."""
24
  try:
25
  subprocess.run(
26
  [
@@ -99,33 +99,39 @@ def download_model(model_path_or_url):
99
 
100
 
101
  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
102
- """Creates a Hugging Face model repository, handling missing inputs."""
103
 
104
- print("---- create_model_repo Called ----") # Debug Print
105
- print(f" user: {user}") # Debug Print
106
- print(f" orgs_name: {orgs_name}") # Debug Print
107
- print(f" model_name: {model_name}") # Debug Print
108
 
109
  if not model_name:
110
  model_name = f"converted-model-{datetime.now().strftime('%Y%m%d%H%M%S')}"
111
- print(f" Using default model_name: {model_name}") # Debug Print
112
 
113
  if orgs_name:
114
  repo_id = f"{orgs_name}/{model_name.strip()}"
115
  elif user:
116
- repo_id = f"{user['name']}/{model_name.strip()}"
 
 
 
 
117
  else:
118
- raise ValueError("Must provide either an organization name or be logged in.")
 
 
119
 
120
- print(f" repo_id: {repo_id}") # Debug Print
121
 
122
  try:
123
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
124
  print(f"Model repo '{repo_id}' created.")
 
125
  except Exception as e:
126
  print(f"Error creating repo: {e}")
127
  raise
128
- return repo_id
129
 
130
  def load_sdxl_checkpoint(checkpoint_path):
131
  """Loads checkpoint and extracts state dicts."""
@@ -207,22 +213,6 @@ def convert_and_save_sdxl_to_diffusers(checkpoint_path_or_url, output_path, refe
207
  pipeline.save_pretrained(output_path)
208
  print(f"Model saved as Diffusers format: {output_path}")
209
 
210
- def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
211
- """Uploads a model to the Hugging Face Hub."""
212
- print("---- upload_to_huggingface Called ----") # Debug Print
213
- print(f" hf_token: {hf_token}") # Debug Print
214
- print(f" orgs_name: {orgs_name}") # Debug Print
215
- print(f" model_name: {model_name}") # Debug Print
216
- api = HfApi()
217
- # --- CRUCIAL: Log in with the token FIRST ---
218
- login(token=hf_token, add_to_git_credential=True)
219
- user = api.whoami() # Get the logged-in user *without* the token
220
- print(f" Logged-in user: {user}") # Debug Print
221
-
222
- model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
223
- api.upload_folder(folder_path=model_path, repo_id=model_repo)
224
- print(f"Model uploaded to: https://huggingface.co/{model_repo}")
225
-
226
  # ---------------------- MAIN FUNCTION (with Debugging Prints) ----------------------
227
 
228
  def main(
@@ -234,7 +224,7 @@ def main(
234
  model_name,
235
  make_private,
236
  ):
237
- """Main function: SDXL checkpoint to Diffusers, with debugging prints."""
238
 
239
  print("---- Main Function Called ----")
240
  print(f" model_to_load: {model_to_load}")
@@ -245,22 +235,63 @@ def main(
245
  print(f" model_name: {model_name}")
246
  print(f" make_private: {make_private}")
247
 
 
248
  try:
249
- convert_and_save_sdxl_to_diffusers(
250
- model_to_load, output_path, reference_model
251
- )
252
- upload_to_huggingface(
253
- output_path, hf_token, orgs_name, model_name, make_private
254
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  result = "Conversion and upload completed successfully!"
256
  print(f"---- Main Function Successful: {result} ----")
257
  return result
 
258
  except Exception as e:
259
  error_message = f"An error occurred: {e}"
260
  print(f"---- Main Function Error: {error_message} ----")
261
  return error_message
262
 
263
- # ---------------------- GRADIO INTERFACE (Corrected Button Placement) ----------------------
264
 
265
  css = """
266
  #main-container {
@@ -286,15 +317,18 @@ with gr.Blocks(css=css) as demo:
286
  - Direct URLs to model files
287
  - Hugging Face model repositories (e.g., 'my-org/my-model' or 'my-org/my-model/file.safetensors')
288
 
 
 
 
 
 
 
 
 
289
  ### ℹ️ Important Notes:
290
  - This tool runs on **CPU**, conversion might be slower than on GPU.
291
  - For Hugging Face uploads, you need a **WRITE** token (not a read token).
292
- - Get your HF token here: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
293
-
294
- ### 💾 Memory Usage:
295
  - This space is configured for **FP16** precision to reduce memory usage.
296
- - Close other applications during conversion.
297
- - For large models, ensure you have at least 16GB of RAM.
298
 
299
  ### 💻 Source Code:
300
  - [GitHub Repository](https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers)
@@ -307,11 +341,11 @@ with gr.Blocks(css=css) as demo:
307
  with gr.Row():
308
  with gr.Column():
309
  model_to_load = gr.Textbox(
310
- label="SDXL Checkpoint (Path, URL, or HF Repo)",
311
  placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
312
  )
313
  reference_model = gr.Textbox(
314
- label="Reference Diffusers Model (Optional)",
315
  placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)",
316
  )
317
  output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="output")
@@ -324,7 +358,6 @@ with gr.Blocks(css=css) as demo:
324
  with gr.Column(variant="panel"):
325
  output = gr.Markdown(container=True)
326
 
327
- # --- CORRECT BUTTON CLICK PLACEMENT ---
328
  convert_button.click(
329
  fn=main,
330
  inputs=[
 
15
  from huggingface_hub.utils import validate_repo_id, HFValidationError
16
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
17
  from huggingface_hub.utils import HfHubHTTPError
18
+ from accelerate import Accelerator
19
 
20
 
21
  # ---------------------- DEPENDENCIES ----------------------
22
  def install_dependencies_gradio():
23
+ """Installs the necessary dependencies."""
24
  try:
25
  subprocess.run(
26
  [
 
99
 
100
 
101
  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
102
+ """Creates a Hugging Face model repository, handling missing inputs and sanitizing the username."""
103
 
104
+ print("---- create_model_repo Called ----")
105
+ print(f" user: {user}")
106
+ print(f" orgs_name: {orgs_name}")
107
+ print(f" model_name: {model_name}")
108
 
109
  if not model_name:
110
  model_name = f"converted-model-{datetime.now().strftime('%Y%m%d%H%M%S')}"
111
+ print(f" Using default model_name: {model_name}")
112
 
113
  if orgs_name:
114
  repo_id = f"{orgs_name}/{model_name.strip()}"
115
  elif user:
116
+ # --- Sanitize the username ---
117
+ sanitized_username = re.sub(r"[^a-zA-Z0-9._-]", "-", user['name']) # Replace invalid chars with hyphens
118
+ print(f" Original Username: {user['name']}") #Debugging
119
+ print(f" Sanitized Username: {sanitized_username}") #Debugging
120
+ repo_id = f"{sanitized_username}/{model_name.strip()}"
121
  else:
122
+ raise ValueError(
123
+ "Must provide either an organization name or be logged in."
124
+ )
125
 
126
+ print(f" repo_id: {repo_id}")
127
 
128
  try:
129
  api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
130
  print(f"Model repo '{repo_id}' created.")
131
+ return repo_id
132
  except Exception as e:
133
  print(f"Error creating repo: {e}")
134
  raise
 
135
 
136
  def load_sdxl_checkpoint(checkpoint_path):
137
  """Loads checkpoint and extracts state dicts."""
 
213
  pipeline.save_pretrained(output_path)
214
  print(f"Model saved as Diffusers format: {output_path}")
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  # ---------------------- MAIN FUNCTION (with Debugging Prints) ----------------------
217
 
218
  def main(
 
224
  model_name,
225
  make_private,
226
  ):
227
+ """Main function: SDXL checkpoint to Diffusers, always fp16."""
228
 
229
  print("---- Main Function Called ----")
230
  print(f" model_to_load: {model_to_load}")
 
235
  print(f" model_name: {model_name}")
236
  print(f" make_private: {make_private}")
237
 
238
+ # --- Force Login at the Beginning of main() ---
239
  try:
240
+ login(token=hf_token, add_to_git_credential=True)
241
+ api = HfApi()
242
+ user = api.whoami() # Get logged-in user info
243
+ print(f" Logged-in user: {user}")
244
+ except Exception as e:
245
+ error_message = f"Error during login: {e} Ensure a valid WRITE token is provided."
246
+ print(f"---- Main Function Error: {error_message} ----")
247
+ return error_message
248
+
249
+ # --- Strip Whitespace from Inputs ---
250
+ model_to_load = model_to_load.strip()
251
+ reference_model = reference_model.strip()
252
+ output_path = output_path.strip()
253
+ hf_token = hf_token.strip() # Even though it's a password field, good practice
254
+ orgs_name = orgs_name.strip() if orgs_name else "" #Handle empty
255
+ model_name = model_name.strip() if model_name else "" #Handle empty
256
+
257
+ try:
258
+ convert_and_save_sdxl_to_diffusers(model_to_load, output_path, reference_model)
259
+
260
+ # --- Create Repo and Upload (Simplified) ---
261
+ if not model_name:
262
+ model_name = f"converted-model-{datetime.now().strftime('%Y%m%d%H%M%S')}"
263
+ print(f"Using default model_name: {model_name}")
264
+
265
+ if orgs_name:
266
+ repo_id = f"{orgs_name}/{model_name.strip()}"
267
+ elif user:
268
+ # Sanitize username here as well:
269
+ sanitized_username = re.sub(r"[^a-zA-Z0-9._-]", "-", user['name'])
270
+ repo_id = f"{sanitized_username}/{model_name.strip()}"
271
+
272
+ else: # Should never happen because of login, but good practice
273
+ raise ValueError("Must provide either an organization name or be logged in.")
274
+ print(f"repo_id = {repo_id}")
275
+ try:
276
+ api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
277
+ print(f"Model repo '{repo_id}' created.")
278
+ except Exception as e:
279
+ print(f"Error in creating model repo: {e}")
280
+ raise
281
+
282
+ api.upload_folder(folder_path=output_path, repo_id=repo_id)
283
+ print(f"Model uploaded to: https://huggingface.co/{repo_id}")
284
+
285
  result = "Conversion and upload completed successfully!"
286
  print(f"---- Main Function Successful: {result} ----")
287
  return result
288
+
289
  except Exception as e:
290
  error_message = f"An error occurred: {e}"
291
  print(f"---- Main Function Error: {error_message} ----")
292
  return error_message
293
 
294
+ # ---------------------- GRADIO INTERFACE ----------------------
295
 
296
  css = """
297
  #main-container {
 
317
  - Direct URLs to model files
318
  - Hugging Face model repositories (e.g., 'my-org/my-model' or 'my-org/my-model/file.safetensors')
319
 
320
+ ### How To Use
321
+ - Insert URL or Repository Information into Field 1 SDXL Checkpoint (Path, URL, or HF Repo)
322
+ - Optional: Insert Reference Diffusers Model (Optional)
323
+ - Optional: Output Path (Diffusers Format)
324
+ - Insert your HF Token WRITE: Get your HF token here: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
325
+ - Organization Name (Optional)
326
+ - Create a Model Name for Output Purposes
327
+
328
  ### ℹ️ Important Notes:
329
  - This tool runs on **CPU**, conversion might be slower than on GPU.
330
  - For Hugging Face uploads, you need a **WRITE** token (not a read token).
 
 
 
331
  - This space is configured for **FP16** precision to reduce memory usage.
 
 
332
 
333
  ### 💻 Source Code:
334
  - [GitHub Repository](https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers)
 
341
  with gr.Row():
342
  with gr.Column():
343
  model_to_load = gr.Textbox(
344
+ label="SDXL Checkpoint (Path, URL, or HF Repo)", # Corrected Label
345
  placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
346
  )
347
  reference_model = gr.Textbox(
348
+ label="Reference Diffusers Model (Optional)", # Corrected Label
349
  placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)",
350
  )
351
  output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="output")
 
358
  with gr.Column(variant="panel"):
359
  output = gr.Markdown(container=True)
360
 
 
361
  convert_button.click(
362
  fn=main,
363
  inputs=[