Duskfallcrew commited on
Commit
a882b42
·
verified ·
1 Parent(s): 16e8869

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -13
app.py CHANGED
@@ -8,7 +8,7 @@ from collections import OrderedDict
8
  import re
9
  import json
10
  import gdown
11
- #import requests # Removed
12
  import subprocess
13
  from urllib.parse import urlparse, unquote
14
  from pathlib import Path
@@ -20,7 +20,7 @@ import shutil
20
  import hashlib
21
  from datetime import datetime
22
  from typing import Dict, List, Optional
23
- from huggingface_hub import login, HfApi, hf_hub_download, get_from_cache
24
  from huggingface_hub.utils import validate_repo_id, HFValidationError
25
  from huggingface_hub.errors import HfHubHTTPError
26
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE # Import HUGGINGFACE_HUB_CACHE
@@ -80,16 +80,52 @@ def create_model_repo(api, user, orgs_name, model_name, make_private=False):
80
 
81
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
82
  def download_model(model_path_or_url):
83
- """Downloads a model from a Hugging Face Hub repository."""
84
  try:
85
- # Check if it's a valid Hugging Face repo ID (and potentially a file within)
86
  try:
87
  validate_repo_id(model_path_or_url)
88
  # It's a valid repo ID; use hf_hub_download (it handles caching)
89
  local_path = hf_hub_download(repo_id=model_path_or_url)
90
  return local_path
91
  except HFValidationError:
92
- # Might be a repo ID + filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  try:
94
  parts = model_path_or_url.split("/", 1)
95
  if len(parts) == 2:
@@ -98,12 +134,13 @@ def download_model(model_path_or_url):
98
  local_path = hf_hub_download(repo_id=repo_id, filename=filename)
99
  return local_path
100
  else:
101
- raise ValueError("Invalid Hugging Face repository format.")
 
102
  except HFValidationError:
103
- raise ValueError(f"Invalid Hugging Face repository ID or path: {model_path_or_url}")
104
 
105
  except Exception as e:
106
- raise ValueError(f"Error downloading model: {e}")
107
 
108
 
109
  def load_sdxl_checkpoint(checkpoint_path):
@@ -269,8 +306,9 @@ with gr.Blocks(css=css) as demo:
269
  Convert SDXL checkpoints to Diffusers format (FP16, CPU-only).
270
 
271
  ### 📥 Input Sources Supported:
272
- - Only Hugging Face model repositories (e.g., 'my-org/my-model' or 'my-org/my-model/file.safetensors')
273
- - Non HF repository Gradio edition will eventually pop up on the Ktiseos Nyx github.
 
274
 
275
  ### ℹ️ Important Notes:
276
  - This tool runs on **CPU**, conversion might be slower than on GPU.
@@ -281,7 +319,6 @@ with gr.Blocks(css=css) as demo:
281
  - This space is configured for **FP16** precision to reduce memory usage.
282
  - Close other applications during conversion.
283
  - For large models, ensure you have at least 16GB of RAM.
284
-
285
 
286
  ### 💻 Source Code:
287
  - [GitHub Repository](https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers)
@@ -293,8 +330,8 @@ with gr.Blocks(css=css) as demo:
293
 
294
  with gr.Column(elem_id="main-container"): # Use a Column for layout
295
  model_to_load = gr.Textbox(
296
- label="SDXL Checkpoint (HF Repo)", # More specific label
297
- placeholder="Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
298
  )
299
  reference_model = gr.Textbox(
300
  label="Reference Diffusers Model (Optional)",
 
8
  import re
9
  import json
10
  import gdown
11
+ import requests # Re-added for URL handling
12
  import subprocess
13
  from urllib.parse import urlparse, unquote
14
  from pathlib import Path
 
20
  import hashlib
21
  from datetime import datetime
22
  from typing import Dict, List, Optional
23
+ from huggingface_hub import login, HfApi, hf_hub_download, get_from_cache # Corrected import
24
  from huggingface_hub.utils import validate_repo_id, HFValidationError
25
  from huggingface_hub.errors import HfHubHTTPError
26
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE # Import HUGGINGFACE_HUB_CACHE
 
80
 
81
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
82
  def download_model(model_path_or_url):
83
+ """Downloads a model, handling URLs, HF repos, and local paths, caching appropriately."""
84
  try:
85
+ # 1. Check if it's a valid Hugging Face repo ID (and potentially a file within)
86
  try:
87
  validate_repo_id(model_path_or_url)
88
  # It's a valid repo ID; use hf_hub_download (it handles caching)
89
  local_path = hf_hub_download(repo_id=model_path_or_url)
90
  return local_path
91
  except HFValidationError:
92
+ pass # Not a simple repo ID. Might be repo ID + filename, or a URL.
93
+
94
+ # 2. Check if it's a URL
95
+ if model_path_or_url.startswith("http://") or model_path_or_url.startswith(
96
+ "https://"
97
+ ):
98
+ # Check if it's already in the cache
99
+ cache_path = get_from_cache(model_path_or_url) # Use get_from_cache
100
+ if cache_path is not None:
101
+ return cache_path
102
+
103
+ # It's a URL and not in cache: download manually and put into HF cache
104
+ response = requests.get(model_path_or_url, stream=True)
105
+ response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
106
+
107
+ # Get filename from URL, or use a hash if we can't determine it
108
+ parsed_url = urlparse(model_path_or_url)
109
+ filename = os.path.basename(unquote(parsed_url.path))
110
+ if not filename:
111
+ filename = hashlib.sha256(model_path_or_url.encode()).hexdigest()
112
+
113
+ # Construct the cache path (using HF_HUB_CACHE + "downloads")
114
+ cache_dir = os.path.join(HUGGINGFACE_HUB_CACHE, "downloads")
115
+ os.makedirs(cache_dir, exist_ok=True) # Ensure cache directory exists
116
+ local_path = os.path.join(cache_dir, filename)
117
+
118
+ with open(local_path, "wb") as f:
119
+ for chunk in response.iter_content(chunk_size=8192):
120
+ f.write(chunk)
121
+ return local_path
122
+
123
+ # 3. Check if it's a local file
124
+ elif os.path.isfile(model_path_or_url):
125
+ return model_path_or_url
126
+
127
+ # 4. Handle Hugging Face repo with a specific file
128
+ else:
129
  try:
130
  parts = model_path_or_url.split("/", 1)
131
  if len(parts) == 2:
 
134
  local_path = hf_hub_download(repo_id=repo_id, filename=filename)
135
  return local_path
136
  else:
137
+ raise ValueError("Invalid input format.")
138
+
139
  except HFValidationError:
140
+ raise ValueError(f"Invalid model path or URL: {model_path_or_url}")
141
 
142
  except Exception as e:
143
+ raise ValueError(f"Error downloading or accessing model: {e}")
144
 
145
 
146
  def load_sdxl_checkpoint(checkpoint_path):
 
306
  Convert SDXL checkpoints to Diffusers format (FP16, CPU-only).
307
 
308
  ### 📥 Input Sources Supported:
309
+ - Local model files (.safetensors, .ckpt)
310
+ - Direct URLs to model files
311
+ - Hugging Face model repositories (e.g., 'my-org/my-model' or 'my-org/my-model/file.safetensors')
312
 
313
  ### ℹ️ Important Notes:
314
  - This tool runs on **CPU**, conversion might be slower than on GPU.
 
319
  - This space is configured for **FP16** precision to reduce memory usage.
320
  - Close other applications during conversion.
321
  - For large models, ensure you have at least 16GB of RAM.
 
322
 
323
  ### 💻 Source Code:
324
  - [GitHub Repository](https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers)
 
330
 
331
  with gr.Column(elem_id="main-container"): # Use a Column for layout
332
  model_to_load = gr.Textbox(
333
+ label="SDXL Checkpoint (Path, URL, or HF Repo)",
334
+ placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
335
  )
336
  reference_model = gr.Textbox(
337
  label="Reference Diffusers Model (Optional)",