Update model.py
model.py CHANGED
@@ -78,7 +78,7 @@ class StreamMultiDiffusion(nn.Module):
         self.default_mask_strength = default_mask_strength
         self.default_prompt_strength = default_prompt_strength
         self.register_buffer('bootstrap_steps', (
-            bootstrap_steps > torch.arange(len(t_index_list))).to(dtype=self.dtype, device=self.device))
+            bootstrap_steps > torch.arange(len(t_index_list))).float().to(dtype=self.dtype, device=self.device))
         self.bootstrap_mix_steps = bootstrap_mix_steps
         self.register_buffer('bootstrap_mix_ratios', (
             bootstrap_mix_steps - torch.arange(len(t_index_list), device=self.device)).clip_(0, 1).to(self.dtype))
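The functional change is in this first hunk: the boolean tensor produced by bootstrap_steps > torch.arange(len(t_index_list)) is now cast to float explicitly before being moved to the model dtype and device. A minimal sketch of that dtype behavior, with assumed values for bootstrap_steps and t_index_list (not taken from this Space's configuration):

import torch

# Sketch only: bootstrap_steps=1 and a 4-entry t_index_list are assumed values.
bootstrap_steps = 1
t_index_list = [0, 16, 32, 45]
dtype = torch.float16

mask = bootstrap_steps > torch.arange(len(t_index_list))
print(mask.dtype)   # torch.bool -- the comparison yields a boolean tensor
buf = mask.float().to(dtype=dtype)
print(buf)          # tensor([1., 0., 0., 0.], dtype=torch.float16)

The remaining hunks below only drop leftover debug prints from the denoising path.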
@@ -1091,8 +1091,6 @@ class StreamMultiDiffusion(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         p = self.num_layers
         x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
-        print('111111111111111111111')
-
         if self.bootstrap_steps[0] > 0:
             # Background bootstrapping.
             bootstrap_latent = self.scheduler.add_noise(
@@ -1100,7 +1098,6 @@ class StreamMultiDiffusion(nn.Module):
                 self.stock_noise,
                 torch.tensor(self.sub_timesteps_tensor, device=self.device),
             )
-            print('111111111111111111111', self.bootstrap_steps)

             x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
             bootstrap_mask = (
@@ -1109,11 +1106,9 @@ class StreamMultiDiffusion(nn.Module):
             ) # (p, t, c, h, w)
             x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
             x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')
-            print('222222222222222222222')

             # Centering.
             x_t_latent = shift_to_mask_bbox_center(x_t_latent, rearrange(self.masks, 'p t c h w -> (t p) c h w'), reverse=True)
-            print('333333333333333333333')

         t_list = self.sub_timesteps_tensor_ # (T * p,)
         if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
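For reference, a shape-level sketch of the background-bootstrapping mix that surrounds the removed prints. The layer count p, timestep count t, and latent resolution below are assumed values for illustration, not the Space's real settings:

import torch
from einops import rearrange

# Assumed sizes: p layers, t timesteps, 4-channel latents.
p, t, c, h, w = 3, 4, 4, 8, 8
x_t_latent = torch.randn(t * p, c, h, w)    # (t*p, c, h, w): p layer copies per timestep
bootstrap_latent = torch.randn(t, c, h, w)  # one background latent per timestep
bootstrap_mask = torch.rand(p, t, c, h, w)  # soft per-layer mask, matching the (p, t, c, h, w) comment

x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
# Convex mix: masked regions keep each layer's latent, the rest falls back to the background latent.
x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')
print(x_t_latent.shape)  # torch.Size([12, 4, 8, 8])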