Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -247,7 +247,7 @@ def generate_30(
|
|
247 |
prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
|
248 |
pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
|
249 |
prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
|
250 |
-
print('encoder
|
251 |
prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
|
252 |
pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
|
253 |
prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
|
@@ -259,16 +259,14 @@ def generate_30(
|
|
259 |
# 4. (Optional) Average the pooled embeddings
|
260 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
261 |
print('averaged shape: ', prompt_embeds.shape)
|
262 |
-
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
263 |
-
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
264 |
|
265 |
# 3. Concatenate the text_encoder_2 embeddings
|
266 |
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
267 |
print('catted shape2: ', prompt_embeds.shape)
|
268 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
269 |
-
|
270 |
# 4. (Optional) Average the pooled embeddings
|
271 |
-
pooled_prompt_embeds = torch.mean(
|
272 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
273 |
|
274 |
options = {
|
@@ -376,6 +374,7 @@ def generate_60(
|
|
376 |
pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
|
377 |
prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
|
378 |
|
|
|
379 |
# 3. Concatenate the embeddings
|
380 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
381 |
print('catted shape: ', prompt_embeds.shape)
|
@@ -383,16 +382,14 @@ def generate_60(
|
|
383 |
# 4. (Optional) Average the pooled embeddings
|
384 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
385 |
print('averaged shape: ', prompt_embeds.shape)
|
386 |
-
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
387 |
-
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
388 |
|
389 |
# 3. Concatenate the text_encoder_2 embeddings
|
390 |
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
391 |
print('catted shape2: ', prompt_embeds.shape)
|
392 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
393 |
-
|
394 |
# 4. (Optional) Average the pooled embeddings
|
395 |
-
pooled_prompt_embeds = torch.mean(
|
396 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
397 |
|
398 |
options = {
|
@@ -500,6 +497,7 @@ def generate_90(
|
|
500 |
pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
|
501 |
prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
|
502 |
|
|
|
503 |
# 3. Concatenate the embeddings
|
504 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
505 |
print('catted shape: ', prompt_embeds.shape)
|
@@ -507,16 +505,14 @@ def generate_90(
|
|
507 |
# 4. (Optional) Average the pooled embeddings
|
508 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
509 |
print('averaged shape: ', prompt_embeds.shape)
|
510 |
-
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
511 |
-
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
512 |
|
513 |
# 3. Concatenate the text_encoder_2 embeddings
|
514 |
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
515 |
print('catted shape2: ', prompt_embeds.shape)
|
516 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
517 |
-
|
518 |
# 4. (Optional) Average the pooled embeddings
|
519 |
-
pooled_prompt_embeds = torch.mean(
|
520 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
521 |
|
522 |
options = {
|
|
|
247 |
prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
|
248 |
pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
|
249 |
prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
|
250 |
+
print('encoder shape2: ', prompt_embeds_a2.shape)
|
251 |
prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
|
252 |
pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
|
253 |
prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
|
|
|
259 |
# 4. (Optional) Average the pooled embeddings
|
260 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
261 |
print('averaged shape: ', prompt_embeds.shape)
|
|
|
|
|
262 |
|
263 |
# 3. Concatenate the text_encoder_2 embeddings
|
264 |
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
265 |
print('catted shape2: ', prompt_embeds.shape)
|
266 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
267 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
268 |
# 4. (Optional) Average the pooled embeddings
|
269 |
+
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
270 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
271 |
|
272 |
options = {
|
|
|
374 |
pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
|
375 |
prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
|
376 |
|
377 |
+
|
378 |
# 3. Concatenate the embeddings
|
379 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
380 |
print('catted shape: ', prompt_embeds.shape)
|
|
|
382 |
# 4. (Optional) Average the pooled embeddings
|
383 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
384 |
print('averaged shape: ', prompt_embeds.shape)
|
|
|
|
|
385 |
|
386 |
# 3. Concatenate the text_encoder_2 embeddings
|
387 |
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
388 |
print('catted shape2: ', prompt_embeds.shape)
|
389 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
390 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
391 |
# 4. (Optional) Average the pooled embeddings
|
392 |
+
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
393 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
394 |
|
395 |
options = {
|
|
|
497 |
pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
|
498 |
prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
|
499 |
|
500 |
+
|
501 |
# 3. Concatenate the embeddings
|
502 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
503 |
print('catted shape: ', prompt_embeds.shape)
|
|
|
505 |
# 4. (Optional) Average the pooled embeddings
|
506 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
507 |
print('averaged shape: ', prompt_embeds.shape)
|
|
|
|
|
508 |
|
509 |
# 3. Concatenate the text_encoder_2 embeddings
|
510 |
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
511 |
print('catted shape2: ', prompt_embeds.shape)
|
512 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
513 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
514 |
# 4. (Optional) Average the pooled embeddings
|
515 |
+
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
516 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
517 |
|
518 |
options = {
|