Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Upload wrapper.py
Browse files
    	
        diffusers_helper/k_diffusion/wrapper.py
    ADDED
    
    | 
         @@ -0,0 +1,51 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            def append_dims(x, target_dims):
                """Return ``x`` indexed so that it gains trailing size-1 dimensions.

                One ``None`` (new axis) is appended for every dimension that ``x`` is
                short of ``target_dims``. When ``x.ndim`` is already greater than or
                equal to ``target_dims`` no axis is added and ``x`` is indexed with a
                bare ``...`` only.
                """
                # A non-positive count yields an empty tuple, so nothing is appended.
                extra_axes = (None,) * (target_dims - x.ndim)
                selector = (Ellipsis,) + extra_axes
                return x[selector]
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
                """Blend ``noise_cfg`` with a copy rescaled to the std of ``noise_pred_text``.

                The standard deviation of each tensor is taken over every dimension
                except the first (``keepdim=True``). ``noise_cfg`` is multiplied by the
                ratio ``std(noise_pred_text) / std(noise_cfg)`` and the result is mixed
                with the unscaled ``noise_cfg`` using ``guidance_rescale`` as the weight
                of the rescaled term.

                Args:
                    noise_cfg: Tensor to be rescaled.
                    noise_pred_text: Tensor whose standard deviation is the target.
                    guidance_rescale: Weight of the rescaled tensor in the blend. A
                        value equal to ``0`` returns ``noise_cfg`` itself, untouched.

                Returns:
                    The blended tensor, or the original ``noise_cfg`` object when
                    ``guidance_rescale == 0``.
                """
                if guidance_rescale == 0:
                    return noise_cfg

                # Reduce over all axes but the leading one, for each tensor separately.
                text_axes = list(range(1, noise_pred_text.ndim))
                cfg_axes = list(range(1, noise_cfg.ndim))
                std_ratio = (
                    noise_pred_text.std(dim=text_axes, keepdim=True)
                    / noise_cfg.std(dim=cfg_axes, keepdim=True)
                )

                rescaled = noise_cfg * std_ratio
                return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            def fm_wrapper(transformer, t_scale=1000.0):
                """Build a ``k_model(x, sigma, **extra_args)`` callable around ``transformer``.

                The returned closure reads these keys from ``extra_args``:
                ``'dtype'``, ``'cfg_scale'``, ``'cfg_rescale'``, ``'concat_latent'``,
                ``'positive'`` and, only when ``cfg_scale != 1.0``, ``'negative'``.
                ``'positive'`` / ``'negative'`` are expanded as keyword arguments into
                the ``transformer`` call.

                Args:
                    transformer: Callable invoked with ``hidden_states``, ``timestep``,
                        ``return_dict=False`` plus the conditioning kwargs; element
                        ``[0]`` of its result is used as the prediction.
                    t_scale: Factor multiplied into ``sigma`` to form ``timestep``.

                Returns:
                    The ``k_model`` closure, which returns
                    ``x - prediction * sigma`` cast back to the dtype ``x`` had on entry.
                """

                def k_model(x, sigma, **extra_args):
                    compute_dtype = extra_args['dtype']
                    guidance = extra_args['cfg_scale']
                    rescale = extra_args['cfg_rescale']
                    extra_latent = extra_args['concat_latent']

                    # Remember the caller's dtype so the result can be cast back.
                    input_dtype = x.dtype
                    sigma = sigma.float()

                    x = x.to(compute_dtype)
                    timestep = (sigma * t_scale).to(compute_dtype)

                    # Optionally concatenate the extra latent along dim 1.
                    model_input = x if extra_latent is None else torch.cat([x, extra_latent.to(x)], dim=1)

                    def predict(branch):
                        # Run the transformer with the kwargs stored under `branch`.
                        outputs = transformer(
                            hidden_states=model_input,
                            timestep=timestep,
                            return_dict=False,
                            **extra_args[branch],
                        )
                        return outputs[0].float()

                    cond = predict('positive')
                    # With a scale of exactly 1.0 the second transformer call is skipped
                    # and zeros stand in for the negative prediction.
                    uncond = torch.zeros_like(cond) if guidance == 1.0 else predict('negative')

                    guided = uncond + guidance * (cond - uncond)
                    guided = rescale_noise_cfg(guided, cond, guidance_rescale=rescale)

                    # sigma gets trailing singleton dims so it broadcasts against x.
                    denoised = x.float() - guided.float() * append_dims(sigma, x.ndim)
                    return denoised.to(dtype=input_dtype)

                return k_model
         
     |