1inkusFace committed
Commit 14948a8 · verified · 1 Parent(s): dcc9828

Update app.py

Files changed (1): app.py (+54 −39)
app.py CHANGED
@@ -211,19 +211,24 @@ def generate_30(
         return_tensors="pt",
     )
     text_input_ids2 = text_inputs2.input_ids
-    prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
-    pooled_prompt_embeds_list.append(prompt_embedsa[0])
-    prompt_embedsa = prompt_embedsa.hidden_states[-2]
-    print('text_encoder shape: ',prompt_embedsa.shape)
-    prompt_embeds_list.append(prompt_embedsa)
-    prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
-    pooled_prompt_embeds_list.append(prompt_embedsb[0])
-    prompt_embedsb = prompt_embedsb.hidden_states[-2]
-    prompt_embeds_list.append(prompt_embedsb)
-    prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
-    print('catted shape: ',prompt_embeds.shape)
-    pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
-
+
+    # 2. Encode with the two text encoders
+    prompt_embeds_a = pipe.text_encoder(text_input_ids1, output_hidden_states=True)
+    pooled_prompt_embeds_a = prompt_embeds_a[0]  # Pooled output from encoder 1
+    prompt_embeds_a = prompt_embeds_a.hidden_states[-2]  # Penultimate hidden state from encoder 1
+
+    prompt_embeds_b = pipe.text_encoder_2(text_input_ids2, output_hidden_states=True)
+    pooled_prompt_embeds_b = prompt_embeds_b[0]  # Pooled output from encoder 2
+    prompt_embeds_b = prompt_embeds_b.hidden_states[-2]  # Penultimate hidden state from encoder 2
+
+    # 3. Concatenate the embeddings along the sequence dimension (dim=1)
+    prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
+
+    # 4. (Optional) Average the pooled embeddings
+    pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
+
+
     options = {
         #"prompt": prompt,
         "prompt_embeds": prompt_embeds,
@@ -288,19 +293,24 @@ def generate_60(
         return_tensors="pt",
     )
     text_input_ids2 = text_inputs2.input_ids
-    prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
-    pooled_prompt_embeds_list.append(prompt_embedsa[0])
-    prompt_embedsa = prompt_embedsa.hidden_states[-2]
-    print('text_encoder shape: ',prompt_embedsa.shape)
-    prompt_embeds_list.append(prompt_embedsa)
-    prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
-    pooled_prompt_embeds_list.append(prompt_embedsb[0])
-    prompt_embedsb = prompt_embedsb.hidden_states[-2]
-    prompt_embeds_list.append(prompt_embedsb)
-    prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
-    print('catted shape: ',prompt_embeds.shape)
-    pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
-
+
+    # 2. Encode with the two text encoders
+    prompt_embeds_a = pipe.text_encoder(text_input_ids1, output_hidden_states=True)
+    pooled_prompt_embeds_a = prompt_embeds_a[0]  # Pooled output from encoder 1
+    prompt_embeds_a = prompt_embeds_a.hidden_states[-2]  # Penultimate hidden state from encoder 1
+
+    prompt_embeds_b = pipe.text_encoder_2(text_input_ids2, output_hidden_states=True)
+    pooled_prompt_embeds_b = prompt_embeds_b[0]  # Pooled output from encoder 2
+    prompt_embeds_b = prompt_embeds_b.hidden_states[-2]  # Penultimate hidden state from encoder 2
+
+    # 3. Concatenate the embeddings along the sequence dimension (dim=1)
+    prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
+
+    # 4. (Optional) Average the pooled embeddings
+    pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
+
+
     options = {
         #"prompt": prompt,
         "prompt_embeds": prompt_embeds,
@@ -365,19 +375,24 @@ def generate_90(
         return_tensors="pt",
     )
     text_input_ids2 = text_inputs2.input_ids
-    prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
-    pooled_prompt_embeds_list.append(prompt_embedsa[0])
-    prompt_embedsa = prompt_embedsa.hidden_states[-2]
-    print('text_encoder shape: ',prompt_embedsa.shape)
-    prompt_embeds_list.append(prompt_embedsa)
-    prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
-    pooled_prompt_embeds_list.append(prompt_embedsb[0])
-    prompt_embedsb = prompt_embedsb.hidden_states[-2]
-    prompt_embeds_list.append(prompt_embedsb)
-    prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
-    print('catted shape: ',prompt_embeds.shape)
-    pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
-
+
+    # 2. Encode with the two text encoders
+    prompt_embeds_a = pipe.text_encoder(text_input_ids1, output_hidden_states=True)
+    pooled_prompt_embeds_a = prompt_embeds_a[0]  # Pooled output from encoder 1
+    prompt_embeds_a = prompt_embeds_a.hidden_states[-2]  # Penultimate hidden state from encoder 1
+
+    prompt_embeds_b = pipe.text_encoder_2(text_input_ids2, output_hidden_states=True)
+    pooled_prompt_embeds_b = prompt_embeds_b[0]  # Pooled output from encoder 2
+    prompt_embeds_b = prompt_embeds_b.hidden_states[-2]  # Penultimate hidden state from encoder 2
+
+    # 3. Concatenate the embeddings along the sequence dimension (dim=1)
+    prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
+
+    # 4. (Optional) Average the pooled embeddings
+    pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
+
+
     options = {
         #"prompt": prompt,
         "prompt_embeds": prompt_embeds,
 