1inkusFace commited on
Commit
024ab2e
·
verified ·
1 Parent(s): c8fa40a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -259,7 +259,7 @@ def generate_30(
259
  print('catted shape: ', prompt_embeds.shape)
260
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
261
  print('catted pooled shape: ', pooled_prompt_embeds.shape)
262
- pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
263
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
264
 
265
  # 4. (Optional) Average the pooled embeddings
@@ -271,9 +271,9 @@ def generate_30(
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=2)
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
278
 
279
  options = {
@@ -388,7 +388,7 @@ def generate_60(
388
  print('catted shape: ', prompt_embeds.shape)
389
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
390
  print('catted pooled shape: ', pooled_prompt_embeds.shape)
391
- pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
392
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
393
 
394
  # 4. (Optional) Average the pooled embeddings
@@ -400,9 +400,9 @@ def generate_60(
400
  print('catted shape2: ', prompt_embeds2.shape)
401
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
402
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
403
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1)
404
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
405
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=2)
406
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
407
 
408
  options = {
@@ -517,7 +517,7 @@ def generate_90(
517
  print('catted shape: ', prompt_embeds.shape)
518
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
519
  print('catted pooled shape: ', pooled_prompt_embeds.shape)
520
- pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
521
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
522
 
523
  # 4. (Optional) Average the pooled embeddings
@@ -529,9 +529,9 @@ def generate_90(
529
  print('catted shape2: ', prompt_embeds2.shape)
530
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
531
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
532
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1)
533
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
534
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=2)
535
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
536
 
537
  options = {
 
259
  print('catted shape: ', prompt_embeds.shape)
260
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
261
  print('catted pooled shape: ', pooled_prompt_embeds.shape)
262
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
263
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
264
 
265
  # 4. (Optional) Average the pooled embeddings
 
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1,keepdim=True)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
278
 
279
  options = {
 
388
  print('catted shape: ', prompt_embeds.shape)
389
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
390
  print('catted pooled shape: ', pooled_prompt_embeds.shape)
391
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
392
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
393
 
394
  # 4. (Optional) Average the pooled embeddings
 
400
  print('catted shape2: ', prompt_embeds2.shape)
401
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
402
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
403
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1,keepdim=True)
404
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
405
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
406
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
407
 
408
  options = {
 
517
  print('catted shape: ', prompt_embeds.shape)
518
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
519
  print('catted pooled shape: ', pooled_prompt_embeds.shape)
520
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
521
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
522
 
523
  # 4. (Optional) Average the pooled embeddings
 
529
  print('catted shape2: ', prompt_embeds2.shape)
530
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
531
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
532
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1,keepdim=True)
533
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
534
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
535
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
536
 
537
  options = {