Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -2,7 +2,6 @@ | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            import random
         | 
| 4 | 
             
            import uuid
         | 
| 5 | 
            -
            import json
         | 
| 6 | 
             
            import gradio as gr
         | 
| 7 | 
             
            import numpy as np
         | 
| 8 | 
             
            from PIL import Image
         | 
| @@ -49,7 +48,7 @@ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) | |
| 49 |  | 
| 50 | 
             
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         | 
| 51 |  | 
| 52 | 
            -
            def  | 
| 53 | 
             
                pipe = StableDiffusionXLPipeline.from_pretrained(
         | 
| 54 | 
             
                    model_id,
         | 
| 55 | 
             
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
         | 
| @@ -66,8 +65,8 @@ def load_model(model_id): | |
| 66 |  | 
| 67 | 
             
                return pipe
         | 
| 68 |  | 
| 69 | 
            -
             | 
| 70 | 
            -
             | 
| 71 |  | 
| 72 | 
             
            MAX_SEED = np.iinfo(np.int32).max
         | 
| 73 |  | 
| @@ -97,9 +96,8 @@ def generate( | |
| 97 | 
             
                num_images: int = 1,  
         | 
| 98 | 
             
                progress=gr.Progress(track_tqdm=True),
         | 
| 99 | 
             
            ):
         | 
| 100 | 
            -
                global  | 
| 101 | 
            -
                 | 
| 102 | 
            -
                    pipe = load_model(MODEL_OPTIONS[model_choice])
         | 
| 103 |  | 
| 104 | 
             
                seed = int(randomize_seed_fn(seed, randomize_seed))
         | 
| 105 | 
             
                generator = torch.Generator(device=device).manual_seed(seed)
         | 
| @@ -131,14 +129,7 @@ def generate( | |
| 131 |  | 
| 132 | 
             
            with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
         | 
| 133 | 
             
                gr.Markdown(DESCRIPTIONx)
         | 
| 134 | 
            -
             | 
| 135 | 
            -
                with gr.Group():
         | 
| 136 | 
             
                    with gr.Row():
         | 
| 137 | 
            -
                        model_choice = gr.Dropdown(
         | 
| 138 | 
            -
                            label="Model",
         | 
| 139 | 
            -
                            choices=list(MODEL_OPTIONS.keys()),
         | 
| 140 | 
            -
                            value="RealVisXL_V4.0_Lightning"
         | 
| 141 | 
            -
                        )
         | 
| 142 | 
             
                        prompt = gr.Text(
         | 
| 143 | 
             
                            label="Prompt",
         | 
| 144 | 
             
                            show_label=False,
         | 
| @@ -149,6 +140,13 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo: | |
| 149 | 
             
                        run_button = gr.Button("Run", scale=0)
         | 
| 150 | 
             
                    result = gr.Gallery(label="Result", columns=1, show_label=False) 
         | 
| 151 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 152 | 
             
                with gr.Accordion("Advanced options", open=False, visible=False):
         | 
| 153 | 
             
                    num_images = gr.Slider(
         | 
| 154 | 
             
                        label="Number of Images",
         | 
| @@ -250,4 +248,4 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo: | |
| 250 | 
             
                gr.Markdown("⚠️ users are accountable for the content they generate and are responsible for ensuring it meets appropriate ethical standards.")
         | 
| 251 |  | 
| 252 | 
             
            if __name__ == "__main__":
         | 
| 253 | 
            -
                demo.queue(max_size=40).launch()
         | 
|  | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            import random
         | 
| 4 | 
             
            import uuid
         | 
|  | |
| 5 | 
             
            import gradio as gr
         | 
| 6 | 
             
            import numpy as np
         | 
| 7 | 
             
            from PIL import Image
         | 
|  | |
| 48 |  | 
| 49 | 
             
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         | 
| 50 |  | 
| 51 | 
            +
            def load_and_prepare_model(model_id):
         | 
| 52 | 
             
                pipe = StableDiffusionXLPipeline.from_pretrained(
         | 
| 53 | 
             
                    model_id,
         | 
| 54 | 
             
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
         | 
|  | |
| 65 |  | 
| 66 | 
             
                return pipe
         | 
| 67 |  | 
| 68 | 
            +
            # Preload and compile both models
         | 
| 69 | 
            +
            models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
         | 
| 70 |  | 
| 71 | 
             
            MAX_SEED = np.iinfo(np.int32).max
         | 
| 72 |  | 
|  | |
| 96 | 
             
                num_images: int = 1,  
         | 
| 97 | 
             
                progress=gr.Progress(track_tqdm=True),
         | 
| 98 | 
             
            ):
         | 
| 99 | 
            +
                global models
         | 
| 100 | 
            +
                pipe = models[model_choice]
         | 
|  | |
| 101 |  | 
| 102 | 
             
                seed = int(randomize_seed_fn(seed, randomize_seed))
         | 
| 103 | 
             
                generator = torch.Generator(device=device).manual_seed(seed)
         | 
|  | |
| 129 |  | 
| 130 | 
             
            with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
         | 
| 131 | 
             
                gr.Markdown(DESCRIPTIONx)
         | 
|  | |
|  | |
| 132 | 
             
                    with gr.Row():
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 133 | 
             
                        prompt = gr.Text(
         | 
| 134 | 
             
                            label="Prompt",
         | 
| 135 | 
             
                            show_label=False,
         | 
|  | |
| 140 | 
             
                        run_button = gr.Button("Run", scale=0)
         | 
| 141 | 
             
                    result = gr.Gallery(label="Result", columns=1, show_label=False) 
         | 
| 142 |  | 
| 143 | 
            +
                    with gr.Row():
         | 
| 144 | 
            +
                        model_choice = gr.Dropdown(
         | 
| 145 | 
            +
                            label="Model",
         | 
| 146 | 
            +
                            choices=list(MODEL_OPTIONS.keys()),
         | 
| 147 | 
            +
                            value="RealVisXL_V4.0_Lightning"
         | 
| 148 | 
            +
                        )
         | 
| 149 | 
            +
             | 
| 150 | 
             
                with gr.Accordion("Advanced options", open=False, visible=False):
         | 
| 151 | 
             
                    num_images = gr.Slider(
         | 
| 152 | 
             
                        label="Number of Images",
         | 
|  | |
| 248 | 
             
                gr.Markdown("⚠️ users are accountable for the content they generate and are responsible for ensuring it meets appropriate ethical standards.")
         | 
| 249 |  | 
| 250 | 
             
            if __name__ == "__main__":
         | 
| 251 | 
            +
                demo.queue(max_size=40).launch()
         | 
