1inkusFace commited on
Commit
ba6244d
·
verified ·
1 Parent(s): 024ab2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -254,6 +254,7 @@ def generate_30(
254
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
255
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
256
 
 
257
  # 3. Concatenate the embeddings
258
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
259
  print('catted shape: ', prompt_embeds.shape)
@@ -271,7 +272,7 @@ def generate_30(
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1,keepdim=True)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
@@ -383,6 +384,7 @@ def generate_60(
383
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
384
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
385
 
 
386
  # 3. Concatenate the embeddings
387
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
388
  print('catted shape: ', prompt_embeds.shape)
@@ -400,7 +402,7 @@ def generate_60(
400
  print('catted shape2: ', prompt_embeds2.shape)
401
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
402
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
403
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1,keepdim=True)
404
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
405
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
406
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
@@ -512,6 +514,7 @@ def generate_90(
512
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
513
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
514
 
 
515
  # 3. Concatenate the embeddings
516
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
517
  print('catted shape: ', prompt_embeds.shape)
@@ -529,7 +532,7 @@ def generate_90(
529
  print('catted shape2: ', prompt_embeds2.shape)
530
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
531
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
532
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1,keepdim=True)
533
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
534
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
535
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
 
254
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
255
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
256
 
257
+
258
  # 3. Concatenate the embeddings
259
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
260
  print('catted shape: ', prompt_embeds.shape)
 
272
  print('catted shape2: ', prompt_embeds2.shape)
273
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
274
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
275
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
276
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
277
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
278
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
 
384
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
385
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
386
 
387
+
388
  # 3. Concatenate the embeddings
389
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
390
  print('catted shape: ', prompt_embeds.shape)
 
402
  print('catted shape2: ', prompt_embeds2.shape)
403
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
404
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
405
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
406
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
407
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
408
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
 
514
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
515
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
516
 
517
+
518
  # 3. Concatenate the embeddings
519
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
520
  print('catted shape: ', prompt_embeds.shape)
 
532
  print('catted shape2: ', prompt_embeds2.shape)
533
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
534
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
535
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
536
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
537
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2],dim=1)
538
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)