p1atdev commited on
Commit
7c61bf1
·
1 Parent(s): 2f0a1c2

chore: debug

Browse files
Files changed (1) hide show
  1. tkg.py +4 -0
tkg.py CHANGED
@@ -12,6 +12,8 @@ def get_mean_shifted_latents(
12
  ) -> torch.Tensor:
13
  shifted_latents = latents.clone()
14
 
 
 
15
  for idx, sign in enumerate(channels):
16
  if sign == 0:
17
  # skip
@@ -20,6 +22,8 @@ def get_mean_shifted_latents(
20
  latent_channel = shifted_latents[:, idx, :, :]
21
 
22
  positive_ratio = (latent_channel > 0).float().mean()
 
 
23
  target_ratio = positive_ratio + shift * sign
24
 
25
  # gradually shift latent_channel
 
12
  ) -> torch.Tensor:
13
  shifted_latents = latents.clone()
14
 
15
+ print("channels", channels)
16
+
17
  for idx, sign in enumerate(channels):
18
  if sign == 0:
19
  # skip
 
22
  latent_channel = shifted_latents[:, idx, :, :]
23
 
24
  positive_ratio = (latent_channel > 0).float().mean()
25
+ print("positive_ratio", positive_ratio)
26
+ print("shift", shift, "sign", sign)
27
  target_ratio = positive_ratio + shift * sign
28
 
29
  # gradually shift latent_channel