Update app.py
app.py CHANGED
```diff
@@ -254,6 +254,7 @@ def generate_30(
     pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
     prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
 
+
     # 3. Concatenate the embeddings
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
@@ -271,7 +272,7 @@ def generate_30(
     print('catted shape2: ', prompt_embeds2.shape)
     pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
     print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
-    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
+    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
     print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
     print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
@@ -383,6 +384,7 @@ def generate_60(
     pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
     prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
 
+
     # 3. Concatenate the embeddings
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
@@ -400,7 +402,7 @@ def generate_60(
     print('catted shape2: ', prompt_embeds2.shape)
     pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
     print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
-    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
+    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
     print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
     print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
@@ -512,6 +514,7 @@ def generate_90(
     pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
     prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
 
+
     # 3. Concatenate the embeddings
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
@@ -529,7 +532,7 @@ def generate_90(
     print('catted shape2: ', prompt_embeds2.shape)
     pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
     print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
-    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
+    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
     print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
     print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
```
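The only functional change, repeated in `generate_30`, `generate_60`, and `generate_90`, is the added `keepdim=True` on the `torch.mean` call. A minimal sketch of why that matters for the `torch.cat(..., dim=1)` on the following line (the tensor widths here are illustrative assumptions, not values from app.py):

```python
import torch

# Stand-in for the two pooled vectors stacked by
# torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2]);
# the width 1280 is an assumed example, not taken from app.py.
pooled2 = torch.randn(2, 1280)
pooled = torch.randn(1, 1280)  # stand-in for pooled_prompt_embeds

no_keepdim = torch.mean(pooled2, dim=0)             # shape (1280,): leading dim is dropped
keepdim = torch.mean(pooled2, dim=0, keepdim=True)  # shape (1, 1280): still 2-D

# torch.cat requires the inputs to have the same number of dimensions,
# so the old line fails one statement later:
# torch.cat([pooled, no_keepdim], dim=1)  # RuntimeError: 2-D vs 1-D
combined = torch.cat([pooled, keepdim], dim=1)  # shape (1, 2560)
print(combined.shape)
```

With `keepdim=True` the averaged vector keeps its batch dimension, so the subsequent concatenation along `dim=1` yields a single combined pooled embedding instead of raising a dimension-mismatch error.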
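For reference, the pooled-output / penultimate-hidden-state pair at the top of each hunk is the usual pattern for an SDXL-style second text encoder called with `output_hidden_states=True`. A hedged sketch of that pattern, assuming a transformers `CLIPTextModelWithProjection` (the checkpoint name and prompt are assumptions; the Space's actual encoder setup is not part of this diff):

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

# Assumed SDXL base checkpoint; app.py's real pipeline is not shown here.
repo = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer_2 = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

ids = tokenizer_2(
    "a photo of an astronaut",  # hypothetical prompt
    padding="max_length",
    max_length=tokenizer_2.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    out = text_encoder_2(ids, output_hidden_states=True)

pooled_prompt_embeds_b2 = out[0]          # Pooled (projected) output from encoder 2
prompt_embeds_b2 = out.hidden_states[-2]  # Penultimate hidden state from encoder 2
```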