Spaces: Running on Zero
Commit: "Update app.py" (Browse files)
File changed: app.py
@@ -211,19 +211,24 @@ def generate_30(
|
|
211 |
return_tensors="pt",
|
212 |
)
|
213 |
text_input_ids2 = text_inputs2.input_ids
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
227 |
options = {
|
228 |
#"prompt": prompt,
|
229 |
"prompt_embeds": prompt_embeds,
|
@@ -288,19 +293,24 @@ def generate_60(
|
|
288 |
return_tensors="pt",
|
289 |
)
|
290 |
text_input_ids2 = text_inputs2.input_ids
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
304 |
options = {
|
305 |
#"prompt": prompt,
|
306 |
"prompt_embeds": prompt_embeds,
|
@@ -365,19 +375,24 @@ def generate_90(
|
|
365 |
return_tensors="pt",
|
366 |
)
|
367 |
text_input_ids2 = text_inputs2.input_ids
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
|
|
|
|
|
|
|
|
|
|
381 |
options = {
|
382 |
#"prompt": prompt,
|
383 |
"prompt_embeds": prompt_embeds,
|
|
|
211 |
return_tensors="pt",
|
212 |
)
|
213 |
text_input_ids2 = text_inputs2.input_ids
|
214 |
+
|
215 |
+
# 2. Encode with the two text encoders
|
216 |
+
prompt_embeds_a = pipe.text_encoder(text_input_ids1, output_hidden_states=True)
|
217 |
+
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
218 |
+
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
219 |
+
|
220 |
+
prompt_embeds_b = pipe.text_encoder_2(text_input_ids2, output_hidden_states=True)
|
221 |
+
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
222 |
+
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
223 |
+
|
224 |
+
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
225 |
+
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
226 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
227 |
+
|
228 |
+
# 4. (Optional) Average the pooled embeddings
|
229 |
+
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
230 |
+
|
231 |
+
|
232 |
options = {
|
233 |
#"prompt": prompt,
|
234 |
"prompt_embeds": prompt_embeds,
|
|
|
293 |
return_tensors="pt",
|
294 |
)
|
295 |
text_input_ids2 = text_inputs2.input_ids
|
296 |
+
|
297 |
+
# 2. Encode with the two text encoders
|
298 |
+
prompt_embeds_a = pipe.text_encoder(text_input_ids1, output_hidden_states=True)
|
299 |
+
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
300 |
+
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
301 |
+
|
302 |
+
prompt_embeds_b = pipe.text_encoder_2(text_input_ids2, output_hidden_states=True)
|
303 |
+
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
304 |
+
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
305 |
+
|
306 |
+
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
307 |
+
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
308 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
309 |
+
|
310 |
+
# 4. (Optional) Average the pooled embeddings
|
311 |
+
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
312 |
+
|
313 |
+
|
314 |
options = {
|
315 |
#"prompt": prompt,
|
316 |
"prompt_embeds": prompt_embeds,
|
|
|
375 |
return_tensors="pt",
|
376 |
)
|
377 |
text_input_ids2 = text_inputs2.input_ids
|
378 |
+
|
379 |
+
# 2. Encode with the two text encoders
|
380 |
+
prompt_embeds_a = pipe.text_encoder(text_input_ids1, output_hidden_states=True)
|
381 |
+
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
382 |
+
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
383 |
+
|
384 |
+
prompt_embeds_b = pipe.text_encoder_2(text_input_ids2, output_hidden_states=True)
|
385 |
+
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
386 |
+
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
387 |
+
|
388 |
+
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
389 |
+
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
390 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
391 |
+
|
392 |
+
# 4. (Optional) Average the pooled embeddings
|
393 |
+
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
394 |
+
|
395 |
+
|
396 |
options = {
|
397 |
#"prompt": prompt,
|
398 |
"prompt_embeds": prompt_embeds,
|