Update app.py
app.py CHANGED
@@ -15,12 +15,12 @@ from huggingface_hub import login, HfApi, hf_hub_download
 from huggingface_hub.utils import validate_repo_id, HFValidationError
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import HfHubHTTPError
-from accelerate import Accelerator
+from accelerate import Accelerator
 
 
 # ---------------------- DEPENDENCIES ----------------------
 def install_dependencies_gradio():
-    """Installs the necessary dependencies
+    """Installs the necessary dependencies."""
     try:
         subprocess.run(
             [
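The hunk above cuts off inside the `subprocess.run(` call in `install_dependencies_gradio()`. For orientation only, a helper like this typically shells out to pip from the running interpreter; the sketch below is hypothetical and its package list is illustrative, not the actual list in app.py.

```python
# Hypothetical sketch of a runtime dependency installer; package names are illustrative only.
import subprocess
import sys

def install_dependencies_gradio():
    """Installs the necessary dependencies."""
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "diffusers", "safetensors", "accelerate"],
            check=True,  # raise CalledProcessError if pip fails
        )
        print("Dependencies installed.")
    except subprocess.CalledProcessError as e:
        print(f"Error installing dependencies: {e}")
```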
@@ -99,33 +99,39 @@ def download_model(model_path_or_url):
 
 
 def create_model_repo(api, user, orgs_name, model_name, make_private=False):
-    """Creates a Hugging Face model repository, handling missing inputs."""
+    """Creates a Hugging Face model repository, handling missing inputs and sanitizing the username."""
 
-    print("---- create_model_repo Called ----")
-    print(f" user: {user}")
-    print(f" orgs_name: {orgs_name}")
-    print(f" model_name: {model_name}")
+    print("---- create_model_repo Called ----")
+    print(f" user: {user}")
+    print(f" orgs_name: {orgs_name}")
+    print(f" model_name: {model_name}")
 
     if not model_name:
         model_name = f"converted-model-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        print(f" Using default model_name: {model_name}")
+        print(f" Using default model_name: {model_name}")
 
     if orgs_name:
         repo_id = f"{orgs_name}/{model_name.strip()}"
     elif user:
-
+        # --- Sanitize the username ---
+        sanitized_username = re.sub(r"[^a-zA-Z0-9._-]", "-", user['name'])  # Replace invalid chars with hyphens
+        print(f" Original Username: {user['name']}")  # Debugging
+        print(f" Sanitized Username: {sanitized_username}")  # Debugging
+        repo_id = f"{sanitized_username}/{model_name.strip()}"
     else:
-        raise ValueError(
+        raise ValueError(
+            "Must provide either an organization name or be logged in."
+        )
 
-    print(f" repo_id: {repo_id}")
+    print(f" repo_id: {repo_id}")
 
     try:
         api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
         print(f"Model repo '{repo_id}' created.")
+        return repo_id
     except Exception as e:
         print(f"Error creating repo: {e}")
         raise
-    return repo_id
 
 def load_sdxl_checkpoint(checkpoint_path):
     """Loads checkpoint and extracts state dicts."""
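The sanitization added above exists because a Hub repo id is a `<namespace>/<name>` pair whose parts are expected to contain only letters, digits, `.`, `_`, and `-`. A quick illustration of the `re.sub` pattern used in `create_model_repo` (the sample usernames are made up):

```python
import re

def sanitize(name: str) -> str:
    # Same pattern as in create_model_repo: any character outside the allowed set becomes "-"
    return re.sub(r"[^a-zA-Z0-9._-]", "-", name)

print(sanitize("Ktiseos Nyx"))        # -> "Ktiseos-Nyx"
print(sanitize("user@example.com"))   # -> "user-example.com"
```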
@@ -207,22 +213,6 @@ def convert_and_save_sdxl_to_diffusers(checkpoint_path_or_url, output_path, refe
     pipeline.save_pretrained(output_path)
     print(f"Model saved as Diffusers format: {output_path}")
 
-def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
-    """Uploads a model to the Hugging Face Hub."""
-    print("---- upload_to_huggingface Called ----") # Debug Print
-    print(f" hf_token: {hf_token}") # Debug Print
-    print(f" orgs_name: {orgs_name}") # Debug Print
-    print(f" model_name: {model_name}") # Debug Print
-    api = HfApi()
-    # --- CRUCIAL: Log in with the token FIRST ---
-    login(token=hf_token, add_to_git_credential=True)
-    user = api.whoami() # Get the logged-in user *without* the token
-    print(f" Logged-in user: {user}") # Debug Print
-
-    model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
-    api.upload_folder(folder_path=model_path, repo_id=model_repo)
-    print(f"Model uploaded to: https://huggingface.co/{model_repo}")
-
 # ---------------------- MAIN FUNCTION (with Debugging Prints) ----------------------
 
 def main(
@@ -234,7 +224,7 @@ def main(
     model_name,
     make_private,
 ):
-    """Main function: SDXL checkpoint to Diffusers,
+    """Main function: SDXL checkpoint to Diffusers, always fp16."""
 
     print("---- Main Function Called ----")
     print(f" model_to_load: {model_to_load}")
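The docstring change above notes that the conversion always targets fp16. `convert_and_save_sdxl_to_diffusers()` itself is not part of this hunk; as a point of reference, a generic diffusers-only sketch of loading a single-file SDXL checkpoint in fp16 and saving it in Diffusers layout (not the app's exact implementation) might look like this:

```python
# Generic sketch, not the app's convert_and_save_sdxl_to_diffusers(); paths are placeholders.
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_single_file(
    "checkpoint.safetensors",   # placeholder path
    torch_dtype=torch.float16,  # fp16 to keep memory usage down
)
pipeline.save_pretrained("output")  # writes the Diffusers folder layout
```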
@@ -245,22 +235,63 @@ def main(
     print(f" model_name: {model_name}")
     print(f" make_private: {make_private}")
 
+    # --- Force Login at the Beginning of main() ---
     try:
-
-
-        )
-
-
-
+        login(token=hf_token, add_to_git_credential=True)
+        api = HfApi()
+        user = api.whoami() # Get logged-in user info
+        print(f" Logged-in user: {user}")
+    except Exception as e:
+        error_message = f"Error during login: {e} Ensure a valid WRITE token is provided."
+        print(f"---- Main Function Error: {error_message} ----")
+        return error_message
+
+    # --- Strip Whitespace from Inputs ---
+    model_to_load = model_to_load.strip()
+    reference_model = reference_model.strip()
+    output_path = output_path.strip()
+    hf_token = hf_token.strip() # Even though it's a password field, good practice
+    orgs_name = orgs_name.strip() if orgs_name else "" # Handle empty
+    model_name = model_name.strip() if model_name else "" # Handle empty
+
+    try:
+        convert_and_save_sdxl_to_diffusers(model_to_load, output_path, reference_model)
+
+        # --- Create Repo and Upload (Simplified) ---
+        if not model_name:
+            model_name = f"converted-model-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+            print(f"Using default model_name: {model_name}")
+
+        if orgs_name:
+            repo_id = f"{orgs_name}/{model_name.strip()}"
+        elif user:
+            # Sanitize username here as well:
+            sanitized_username = re.sub(r"[^a-zA-Z0-9._-]", "-", user['name'])
+            repo_id = f"{sanitized_username}/{model_name.strip()}"
+
+        else: # Should never happen because of login, but good practice
+            raise ValueError("Must provide either an organization name or be logged in.")
+        print(f"repo_id = {repo_id}")
+        try:
+            api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
+            print(f"Model repo '{repo_id}' created.")
+        except Exception as e:
+            print(f"Error in creating model repo: {e}")
+            raise
+
+        api.upload_folder(folder_path=output_path, repo_id=repo_id)
+        print(f"Model uploaded to: https://huggingface.co/{repo_id}")
+
         result = "Conversion and upload completed successfully!"
         print(f"---- Main Function Successful: {result} ----")
         return result
+
     except Exception as e:
         error_message = f"An error occurred: {e}"
         print(f"---- Main Function Error: {error_message} ----")
         return error_message
 
-# ---------------------- GRADIO INTERFACE
+# ---------------------- GRADIO INTERFACE ----------------------
 
 css = """
 #main-container {
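With `upload_to_huggingface()` removed, `main()` now performs the login, repo creation, and upload inline. A minimal standalone sketch of that same `huggingface_hub` flow, assuming a WRITE token in the `HF_TOKEN` environment variable and a hypothetical repo name:

```python
import os
from huggingface_hub import HfApi, login

# Assumptions: HF_TOKEN holds a WRITE token and ./output contains a saved Diffusers pipeline.
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

api = HfApi()
user = api.whoami()                               # info about the logged-in account
repo_id = f"{user['name']}/converted-model-demo"  # hypothetical repo name

api.create_repo(repo_id=repo_id, repo_type="model", private=False)
api.upload_folder(folder_path="output", repo_id=repo_id)
print(f"Uploaded to https://huggingface.co/{repo_id}")
```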
@@ -286,15 +317,18 @@ with gr.Blocks(css=css) as demo:
     - Direct URLs to model files
     - Hugging Face model repositories (e.g., 'my-org/my-model' or 'my-org/my-model/file.safetensors')
 
+    ### How To Use
+    - Insert URL or Repository Information into Field 1 SDXL Checkpoint (Path, URL, or HF Repo)
+    - Optional: Insert Reference Diffusers Model (Optional)
+    - Optional: Output Path (Diffusers Format)
+    - Insert your HF Token WRITE: Get your HF token here: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
+    - Organization Name (Optional)
+    - Create a Model Name for Output Purposes
+
     ### ℹ️ Important Notes:
     - This tool runs on **CPU**, conversion might be slower than on GPU.
     - For Hugging Face uploads, you need a **WRITE** token (not a read token).
-    - Get your HF token here: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
-
-    ### 💾 Memory Usage:
     - This space is configured for **FP16** precision to reduce memory usage.
-    - Close other applications during conversion.
-    - For large models, ensure you have at least 16GB of RAM.
 
     ### 💻 Source Code:
     - [GitHub Repository](https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers)
@@ -307,11 +341,11 @@ with gr.Blocks(css=css) as demo:
     with gr.Row():
         with gr.Column():
             model_to_load = gr.Textbox(
-                label="SDXL Checkpoint (Path, URL, or HF Repo)",
+                label="SDXL Checkpoint (Path, URL, or HF Repo)",  # Corrected Label
                 placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
             )
             reference_model = gr.Textbox(
-                label="Reference Diffusers Model (Optional)",
+                label="Reference Diffusers Model (Optional)",  # Corrected Label
                 placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)",
             )
             output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="output")
@@ -324,7 +358,6 @@ with gr.Blocks(css=css) as demo:
         with gr.Column(variant="panel"):
             output = gr.Markdown(container=True)
 
-    # --- CORRECT BUTTON CLICK PLACEMENT ---
     convert_button.click(
         fn=main,
         inputs=[
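The final hunk ends inside the `inputs=[` list, so the exact wiring is not visible here. Presumably it passes the components defined above to `main()` in the order of its parameters; the sketch below is a hypothetical reconstruction, and any component name not shown in the hunks above is an assumption.

```python
# Hypothetical wiring sketch; the actual inputs list is cut off in the hunk above.
convert_button.click(
    fn=main,
    inputs=[
        model_to_load,    # defined above
        reference_model,  # defined above
        output_path,      # defined above
        hf_token,         # assumed: a password-type gr.Textbox defined earlier
        orgs_name,        # assumed
        model_name,       # assumed
        make_private,     # assumed: a gr.Checkbox
    ],
    outputs=output,
)
```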