Spaces:
				
			
			
	
			
			
		Configuration error
		
	
	
	
			
			
	
	
	
	
		
		
		Configuration error
		
	
		Erwann Millon
		
	commited on
		
		
					Commit 
							
							·
						
						28c5269
	
1
								Parent(s):
							
							e37b9e5
								
refactoring and change default path
Browse files- app.py +2 -6
- loaders.py +1 -1
    	
        app.py
    CHANGED
    
    | @@ -5,7 +5,7 @@ import sys | |
| 5 | 
             
            import wandb
         | 
| 6 | 
             
            import torch
         | 
| 7 |  | 
| 8 | 
            -
            from presets import  | 
| 9 |  | 
| 10 | 
             
            sys.path.append("taming-transformers")
         | 
| 11 |  | 
| @@ -36,7 +36,7 @@ def set_img_from_example(state, img): | |
| 36 | 
             
            def get_cleared_mask():
         | 
| 37 | 
             
                return gr.Image.update(value=None)
         | 
| 38 | 
             
            class StateWrapper:
         | 
| 39 | 
            -
                """This extremely ugly code is a hacky fix to allow  | 
| 40 | 
             
                def create_gif(state, *args, **kwargs):
         | 
| 41 | 
             
                    return state, state[0].create_gif(*args, **kwargs)
         | 
| 42 | 
             
                def apply_asian_vector(state, *args, **kwargs):
         | 
| @@ -191,15 +191,11 @@ with gr.Blocks(css="styles.css") as demo: | |
| 191 | 
             
                clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
         | 
| 192 | 
             
                asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
         | 
| 193 | 
             
                lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
         | 
| 194 | 
            -
                # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
         | 
| 195 | 
             
                blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
         | 
| 196 | 
             
                blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
         | 
| 197 | 
             
                # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
         | 
| 198 | 
             
                base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
         | 
| 199 | 
             
                blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
         | 
| 200 | 
            -
                # small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
         | 
| 201 | 
            -
                # major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
         | 
| 202 | 
            -
                # major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
         | 
| 203 | 
             
                apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
         | 
| 204 | 
             
                rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
         | 
| 205 | 
             
                set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
         | 
|  | |
| 5 | 
             
            import wandb
         | 
| 6 | 
             
            import torch
         | 
| 7 |  | 
| 8 | 
            +
            from presets import set_preset 
         | 
| 9 |  | 
| 10 | 
             
            sys.path.append("taming-transformers")
         | 
| 11 |  | 
|  | |
| 36 | 
             
            def get_cleared_mask():
         | 
| 37 | 
             
                return gr.Image.update(value=None)
         | 
| 38 | 
             
            class StateWrapper:
         | 
| 39 | 
            +
                """This extremely ugly code is a hacky fix to allow concurrent users on HF Spaces without instantiating new models for each user."""
         | 
| 40 | 
             
                def create_gif(state, *args, **kwargs):
         | 
| 41 | 
             
                    return state, state[0].create_gif(*args, **kwargs)
         | 
| 42 | 
             
                def apply_asian_vector(state, *args, **kwargs):
         | 
|  | |
| 191 | 
             
                clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
         | 
| 192 | 
             
                asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
         | 
| 193 | 
             
                lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
         | 
|  | |
| 194 | 
             
                blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
         | 
| 195 | 
             
                blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
         | 
| 196 | 
             
                # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
         | 
| 197 | 
             
                base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
         | 
| 198 | 
             
                blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
         | 
|  | |
|  | |
|  | |
| 199 | 
             
                apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
         | 
| 200 | 
             
                rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
         | 
| 201 | 
             
                set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
         | 
    	
        loaders.py
    CHANGED
    
    | @@ -17,7 +17,7 @@ def load_config(config_path, display=False): | |
| 17 |  | 
| 18 |  | 
| 19 | 
             
            def load_default(device):
         | 
| 20 | 
            -
                conf_path = "./celeba_vqgan/ | 
| 21 | 
             
                config = load_config(conf_path, display=False)
         | 
| 22 | 
             
                model = taming.models.vqgan.VQModel(**config.model.params)
         | 
| 23 | 
             
                sd = torch.load("./celeba_vqgan/vqgan_only.pt", map_location=device)
         | 
|  | |
| 17 |  | 
| 18 |  | 
| 19 | 
             
            def load_default(device):
         | 
| 20 | 
            +
                conf_path = "./celeba_vqgan/vqgan_only.yaml"
         | 
| 21 | 
             
                config = load_config(conf_path, display=False)
         | 
| 22 | 
             
                model = taming.models.vqgan.VQModel(**config.model.params)
         | 
| 23 | 
             
                sd = torch.load("./celeba_vqgan/vqgan_only.pt", map_location=device)
         |