1inkusFace committed
Commit f69d021 · verified · 1 Parent(s): a532195

Update app.py

Files changed (1)
  1. app.py +9 -13
app.py CHANGED

@@ -247,7 +247,7 @@ def generate_30(
     prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
     prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
-    print('encoder shape: ', prompt_embeds_a2.shape)
+    print('encoder shape2: ', prompt_embeds_a2.shape)
     prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
     prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
@@ -259,16 +259,14 @@ def generate_30(
     # 4. (Optional) Average the pooled embeddings
     prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
     print('averaged shape: ', prompt_embeds.shape)
-    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
-    print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
     # 3. Concatenate the text_encoder_2 embeddings
     prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
     print('catted shape2: ', prompt_embeds.shape)
     pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
-    pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2], dim=2)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
     # 4. (Optional) Average the pooled embeddings
-    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds2,dim=0)
+    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
     print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
     options = {
@@ -376,6 +374,7 @@ def generate_60(
     pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
     prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
 
+
     # 3. Concatenate the embeddings
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
@@ -383,16 +382,14 @@ def generate_60(
     # 4. (Optional) Average the pooled embeddings
     prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
     print('averaged shape: ', prompt_embeds.shape)
-    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
-    print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
     # 3. Concatenate the text_encoder_2 embeddings
     prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
     print('catted shape2: ', prompt_embeds.shape)
     pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
-    pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2], dim=2)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
     # 4. (Optional) Average the pooled embeddings
-    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds2,dim=0)
+    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
     print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
     options = {
@@ -500,6 +497,7 @@ def generate_90(
     pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
     prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
 
+
     # 3. Concatenate the embeddings
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
@@ -507,16 +505,14 @@ def generate_90(
     # 4. (Optional) Average the pooled embeddings
     prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
     print('averaged shape: ', prompt_embeds.shape)
-    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
-    print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
     # 3. Concatenate the text_encoder_2 embeddings
     prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
     print('catted shape2: ', prompt_embeds.shape)
     pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
-    pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2], dim=2)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
     # 4. (Optional) Average the pooled embeddings
-    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds2,dim=0)
+    pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
     print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
     options = {
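For context, the change blends two prompt variants into a single embedding per encoder: the penultimate hidden states are concatenated and mean-reduced, and the pooled outputs are now concatenated along the batch dimension (instead of dim=2) before averaging. Below is a minimal sketch of that blending pattern for one SDXL text encoder; it is not the app's exact code, and the names blend_prompt_embeds, prompt_a, and prompt_b are placeholders. It assumes a diffusers StableDiffusionXLPipeline-style pipe exposing tokenizer_2 and text_encoder_2.

```python
import torch

@torch.no_grad()
def blend_prompt_embeds(pipe, prompt_a: str, prompt_b: str, device: str = "cuda"):
    # Tokenize both prompt variants as one batch of two.
    tokens = pipe.tokenizer_2(
        [prompt_a, prompt_b],
        padding="max_length",
        max_length=pipe.tokenizer_2.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    out = pipe.text_encoder_2(tokens.input_ids.to(device), output_hidden_states=True)
    pooled = out[0]                 # pooled embeddings, shape [2, dim]
    hidden = out.hidden_states[-2]  # penultimate hidden states, shape [2, seq_len, dim]

    # Average across the two prompts, keeping a batch dimension of 1 (sketch choice;
    # the committed code drops keepdim for the pooled tensor).
    prompt_embeds = hidden.mean(dim=0, keepdim=True)         # [1, seq_len, dim]
    pooled_prompt_embeds = pooled.mean(dim=0, keepdim=True)  # [1, dim]
    return prompt_embeds, pooled_prompt_embeds
```

The resulting pair would typically be passed to the pipeline through its prompt_embeds and pooled_prompt_embeds arguments, which appears to be what the options dict assembled after these lines is for.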