Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update app.py (#18)
Browse files- Update app.py (ed734c18d4564c9f50b0ffd4168fc692c2ee5239)
    	
        app.py
    CHANGED
    
    | 
         @@ -15,8 +15,9 @@ import hashlib 
     | 
|
| 15 | 
         
             
            from datetime import datetime
         
     | 
| 16 | 
         
             
            from typing import Dict, List, Optional
         
     | 
| 17 | 
         
             
            from huggingface_hub import login, HfApi, hf_hub_download
         
     | 
| 18 | 
         
            -
            from huggingface_hub.utils import validate_repo_id, HFValidationError  
     | 
| 19 | 
         
             
            from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
         
     | 
| 
         | 
|
| 20 | 
         | 
| 21 | 
         | 
| 22 | 
         
             
            # ---------------------- DEPENDENCIES ----------------------
         
     | 
| 
         @@ -66,7 +67,7 @@ def create_model_repo(api, user, orgs_name, model_name, make_private=False): 
     | 
|
| 66 | 
         
             
                try:
         
     | 
| 67 | 
         
             
                    api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
         
     | 
| 68 | 
         
             
                    print(f"Model repo '{repo_id}' created.")
         
     | 
| 69 | 
         
            -
                except HfHubHTTPError:
         
     | 
| 70 | 
         
             
                    print(f"Model repo '{repo_id}' already exists.")
         
     | 
| 71 | 
         
             
                return repo_id
         
     | 
| 72 | 
         | 
| 
         @@ -82,7 +83,7 @@ def download_model(model_path_or_url): 
     | 
|
| 82 | 
         
             
                        local_path = hf_hub_download(repo_id=model_path_or_url)
         
     | 
| 83 | 
         
             
                        return local_path
         
     | 
| 84 | 
         
             
                    except HFValidationError:
         
     | 
| 85 | 
         
            -
                        pass  # Not a simple repo ID. 
     | 
| 86 | 
         | 
| 87 | 
         
             
                    # 2. Check if it's a URL
         
     | 
| 88 | 
         
             
                    if model_path_or_url.startswith("http://") or model_path_or_url.startswith(
         
     | 
| 
         @@ -177,7 +178,11 @@ def load_sdxl_checkpoint(checkpoint_path): 
     | 
|
| 177 | 
         | 
| 178 | 
         | 
| 179 | 
         
             
            def build_diffusers_model(
         
     | 
| 180 | 
         
            -
                text_encoder1_state, 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 181 | 
         
             
            ):
         
     | 
| 182 | 
         
             
                """Builds the Diffusers pipeline components from the loaded state dicts."""
         
     | 
| 183 | 
         | 
| 
         @@ -253,21 +258,33 @@ def convert_and_save_sdxl_to_diffusers( 
     | 
|
| 253 | 
         
             
            # ---------------------- UPLOAD FUNCTION ----------------------
         
     | 
| 254 | 
         
             
            def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
         
     | 
| 255 | 
         
             
                """Uploads a model to the Hugging Face Hub."""
         
     | 
| 256 | 
         
            -
                login(hf_token, add_to_git_credential=True)
         
     | 
| 257 | 
         
             
                api = HfApi()
         
     | 
| 258 | 
         
            -
                user = api.whoami(hf_token)
         
     | 
| 259 | 
         
             
                model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
         
     | 
| 260 | 
         
             
                api.upload_folder(folder_path=model_path, repo_id=model_repo)
         
     | 
| 261 | 
         
             
                print(f"Model uploaded to: https://huggingface.co/{model_repo}")
         
     | 
| 262 | 
         | 
| 263 | 
         | 
| 264 | 
         
             
            # ---------------------- GRADIO INTERFACE ----------------------
         
     | 
| 265 | 
         
            -
            def main( 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 266 | 
         
             
                """Main function: SDXL checkpoint to Diffusers, always fp16."""
         
     | 
| 267 | 
         | 
| 268 | 
         
             
                try:
         
     | 
| 269 | 
         
            -
                    convert_and_save_sdxl_to_diffusers( 
     | 
| 270 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 271 | 
         
             
                    return "Conversion and upload completed successfully!"
         
     | 
| 272 | 
         
             
                except Exception as e:
         
     | 
| 273 | 
         
             
                    return f"An error occurred: {e}"  # Return the error message
         
     | 
| 
         @@ -277,14 +294,14 @@ css = """ 
     | 
|
| 277 | 
         
             
            #main-container {
         
     | 
| 278 | 
         
             
                display: flex;
         
     | 
| 279 | 
         
             
                flex-direction: column;
         
     | 
| 280 | 
         
            -
                height: 100vh;
         
     | 
| 281 | 
         
            -
                justify-content: space-between;
         
     | 
| 282 | 
         
             
                font-family: 'Arial', sans-serif;
         
     | 
| 283 | 
         
             
                font-size: 16px;
         
     | 
| 284 | 
         
             
                color: #333;
         
     | 
| 285 | 
         
             
            }
         
     | 
| 286 | 
         
             
            #convert-button {
         
     | 
| 287 | 
         
            -
                margin-top:  
     | 
| 288 | 
         
             
            }
         
     | 
| 289 | 
         
             
            """
         
     | 
| 290 | 
         | 
| 
         @@ -317,44 +334,48 @@ with gr.Blocks(css=css) as demo: 
     | 
|
| 317 | 
         
             
                """
         
     | 
| 318 | 
         
             
                )
         
     | 
| 319 | 
         | 
| 320 | 
         
            -
                with gr. 
     | 
| 321 | 
         
            -
                     
     | 
| 322 | 
         
            -
             
     | 
| 323 | 
         
            -
                         
     | 
| 324 | 
         
            -
             
     | 
| 325 | 
         
            -
             
     | 
| 326 | 
         
            -
                         
     | 
| 327 | 
         
            -
                         
     | 
| 328 | 
         
            -
             
     | 
| 329 | 
         
            -
             
     | 
| 330 | 
         
            -
                         
     | 
| 331 | 
         
            -
             
     | 
| 332 | 
         
            -
             
     | 
| 333 | 
         
            -
                         
     | 
| 334 | 
         
            -
             
     | 
| 335 | 
         
            -
             
     | 
| 336 | 
         
            -
                         
     | 
| 337 | 
         
            -
             
     | 
| 338 | 
         
            -
             
     | 
| 339 | 
         
            -
                         
     | 
| 340 | 
         
            -
             
     | 
| 341 | 
         
            -
             
     | 
| 342 | 
         
            -
             
     | 
| 343 | 
         
            -
             
     | 
| 344 | 
         
            -
             
     | 
| 345 | 
         
            -
             
     | 
| 346 | 
         
            -
             
     | 
| 347 | 
         
            -
             
     | 
| 348 | 
         
            -
                         
     | 
| 349 | 
         
            -
             
     | 
| 350 | 
         
            -
             
     | 
| 351 | 
         
            -
             
     | 
| 352 | 
         
            -
             
     | 
| 353 | 
         
            -
             
     | 
| 354 | 
         
            -
             
     | 
| 355 | 
         
            -
             
     | 
| 356 | 
         
            -
                         
     | 
| 357 | 
         
            -
                         
     | 
| 358 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 359 | 
         | 
| 360 | 
         
             
            demo.launch()
         
     | 
| 
         | 
|
| 15 | 
         
             
            from datetime import datetime
         
     | 
| 16 | 
         
             
            from typing import Dict, List, Optional
         
     | 
| 17 | 
         
             
            from huggingface_hub import login, HfApi, hf_hub_download
         
     | 
| 18 | 
         
            +
            from huggingface_hub.utils import validate_repo_id, HFValidationError  # Removed get_from_cache
         
     | 
| 19 | 
         
             
            from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
         
     | 
| 20 | 
         
            +
            from huggingface_hub.utils import HfHubHTTPError
         
     | 
| 21 | 
         | 
| 22 | 
         | 
| 23 | 
         
             
            # ---------------------- DEPENDENCIES ----------------------
         
     | 
| 
         | 
|
| 67 | 
         
             
                try:
         
     | 
| 68 | 
         
             
                    api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
         
     | 
| 69 | 
         
             
                    print(f"Model repo '{repo_id}' created.")
         
     | 
| 70 | 
         
            +
                except HfHubHTTPError:  # Corrected the exception name
         
     | 
| 71 | 
         
             
                    print(f"Model repo '{repo_id}' already exists.")
         
     | 
| 72 | 
         
             
                return repo_id
         
     | 
| 73 | 
         | 
| 
         | 
|
| 83 | 
         
             
                        local_path = hf_hub_download(repo_id=model_path_or_url)
         
     | 
| 84 | 
         
             
                        return local_path
         
     | 
| 85 | 
         
             
                    except HFValidationError:
         
     | 
| 86 | 
         
            +
                        pass  # Not a simple repo ID.  Might be repo ID + filename, or a URL.
         
     | 
| 87 | 
         | 
| 88 | 
         
             
                    # 2. Check if it's a URL
         
     | 
| 89 | 
         
             
                    if model_path_or_url.startswith("http://") or model_path_or_url.startswith(
         
     | 
| 
         | 
|
| 178 | 
         | 
| 179 | 
         | 
| 180 | 
         
             
            def build_diffusers_model(
         
     | 
| 181 | 
         
            +
                text_encoder1_state,
         
     | 
| 182 | 
         
            +
                text_encoder2_state,
         
     | 
| 183 | 
         
            +
                vae_state,
         
     | 
| 184 | 
         
            +
                unet_state,
         
     | 
| 185 | 
         
            +
                reference_model_path=None,
         
     | 
| 186 | 
         
             
            ):
         
     | 
| 187 | 
         
             
                """Builds the Diffusers pipeline components from the loaded state dicts."""
         
     | 
| 188 | 
         | 
| 
         | 
|
| 258 | 
         
             
            # ---------------------- UPLOAD FUNCTION ----------------------
         
     | 
| 259 | 
         
             
            def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
         
     | 
| 260 | 
         
             
                """Uploads a model to the Hugging Face Hub."""
         
     | 
| 261 | 
         
            +
                login(token=hf_token, add_to_git_credential=True)
         
     | 
| 262 | 
         
             
                api = HfApi()
         
     | 
| 263 | 
         
            +
                user = api.whoami(token=hf_token)
         
     | 
| 264 | 
         
             
                model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
         
     | 
| 265 | 
         
             
                api.upload_folder(folder_path=model_path, repo_id=model_repo)
         
     | 
| 266 | 
         
             
                print(f"Model uploaded to: https://huggingface.co/{model_repo}")
         
     | 
| 267 | 
         | 
| 268 | 
         | 
| 269 | 
         
             
            # ---------------------- GRADIO INTERFACE ----------------------
         
     | 
| 270 | 
         
            +
            def main(
         
     | 
| 271 | 
         
            +
                model_to_load,
         
     | 
| 272 | 
         
            +
                reference_model,
         
     | 
| 273 | 
         
            +
                output_path,
         
     | 
| 274 | 
         
            +
                hf_token,
         
     | 
| 275 | 
         
            +
                orgs_name,
         
     | 
| 276 | 
         
            +
                model_name,
         
     | 
| 277 | 
         
            +
                make_private,
         
     | 
| 278 | 
         
            +
            ):
         
     | 
| 279 | 
         
             
                """Main function: SDXL checkpoint to Diffusers, always fp16."""
         
     | 
| 280 | 
         | 
| 281 | 
         
             
                try:
         
     | 
| 282 | 
         
            +
                    convert_and_save_sdxl_to_diffusers(
         
     | 
| 283 | 
         
            +
                        model_to_load, output_path, reference_model
         
     | 
| 284 | 
         
            +
                    )
         
     | 
| 285 | 
         
            +
                    upload_to_huggingface(
         
     | 
| 286 | 
         
            +
                        output_path, hf_token, orgs_name, model_name, make_private
         
     | 
| 287 | 
         
            +
                    )
         
     | 
| 288 | 
         
             
                    return "Conversion and upload completed successfully!"
         
     | 
| 289 | 
         
             
                except Exception as e:
         
     | 
| 290 | 
         
             
                    return f"An error occurred: {e}"  # Return the error message
         
     | 
| 
         | 
|
| 294 | 
         
             
            #main-container {
         
     | 
| 295 | 
         
             
                display: flex;
         
     | 
| 296 | 
         
             
                flex-direction: column;
         
     | 
| 297 | 
         
            +
                /* Removed height: 100vh; */
         
     | 
| 298 | 
         
            +
                /* Removed justify-content: space-between; */
         
     | 
| 299 | 
         
             
                font-family: 'Arial', sans-serif;
         
     | 
| 300 | 
         
             
                font-size: 16px;
         
     | 
| 301 | 
         
             
                color: #333;
         
     | 
| 302 | 
         
             
            }
         
     | 
| 303 | 
         
             
            #convert-button {
         
     | 
| 304 | 
         
            +
                margin-top: 1em; /* Adds some space above the button */
         
     | 
| 305 | 
         
             
            }
         
     | 
| 306 | 
         
             
            """
         
     | 
| 307 | 
         | 
| 
         | 
|
| 334 | 
         
             
                """
         
     | 
| 335 | 
         
             
                )
         
     | 
| 336 | 
         | 
| 337 | 
         
            +
                with gr.Row():  # Use gr.Row for horizontal layout
         
     | 
| 338 | 
         
            +
                    with gr.Column():  # Group input components in a Column
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
                        model_to_load = gr.Textbox(
         
     | 
| 341 | 
         
            +
                            label="SDXL Checkpoint (Path, URL, or HF Repo)",
         
     | 
| 342 | 
         
            +
                            placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)",
         
     | 
| 343 | 
         
            +
                        )
         
     | 
| 344 | 
         
            +
                        reference_model = gr.Textbox(
         
     | 
| 345 | 
         
            +
                            label="Reference Diffusers Model (Optional)",
         
     | 
| 346 | 
         
            +
                            placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)",
         
     | 
| 347 | 
         
            +
                        )
         
     | 
| 348 | 
         
            +
                        output_path = gr.Textbox(
         
     | 
| 349 | 
         
            +
                            label="Output Path (Diffusers Format)", value="output"
         
     | 
| 350 | 
         
            +
                        )  # Default changed to "output"
         
     | 
| 351 | 
         
            +
                        hf_token = gr.Textbox(
         
     | 
| 352 | 
         
            +
                            label="Hugging Face Token", placeholder="Your Hugging Face write token", type="password"
         
     | 
| 353 | 
         
            +
                        )
         
     | 
| 354 | 
         
            +
                        orgs_name = gr.Textbox(
         
     | 
| 355 | 
         
            +
                            label="Organization Name (Optional)", placeholder="Your organization name"
         
     | 
| 356 | 
         
            +
                        )
         
     | 
| 357 | 
         
            +
                        model_name = gr.Textbox(
         
     | 
| 358 | 
         
            +
                            label="Model Name", placeholder="The name of your model on Hugging Face"
         
     | 
| 359 | 
         
            +
                        )
         
     | 
| 360 | 
         
            +
                        make_private = gr.Checkbox(label="Make Repository Private", value=False)
         
     | 
| 361 | 
         
            +
             
     | 
| 362 | 
         
            +
                        convert_button = gr.Button("Convert and Upload")
         
     | 
| 363 | 
         
            +
             
     | 
| 364 | 
         
            +
                    with gr.Column():
         
     | 
| 365 | 
         
            +
                        output = gr.Markdown() #Output is in its own column
         
     | 
| 366 | 
         
            +
             
     | 
| 367 | 
         
            +
                convert_button.click(
         
     | 
| 368 | 
         
            +
                    fn=main,
         
     | 
| 369 | 
         
            +
                    inputs=[
         
     | 
| 370 | 
         
            +
                        model_to_load,
         
     | 
| 371 | 
         
            +
                        reference_model,
         
     | 
| 372 | 
         
            +
                        output_path,
         
     | 
| 373 | 
         
            +
                        hf_token,
         
     | 
| 374 | 
         
            +
                        orgs_name,
         
     | 
| 375 | 
         
            +
                        model_name,
         
     | 
| 376 | 
         
            +
                        make_private,
         
     | 
| 377 | 
         
            +
                    ],
         
     | 
| 378 | 
         
            +
                    outputs=output,
         
     | 
| 379 | 
         
            +
                )
         
     | 
| 380 | 
         | 
| 381 | 
         
             
            demo.launch()
         
     |