1inkusFace commited on
Commit
75cd14f
·
verified ·
1 Parent(s): 1fdea20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -200,8 +200,7 @@ def generate_30(
200
  truncation=True,
201
  return_tensors="pt",
202
  )
203
- text_input_ids1 = text_inputs1.input_ids
204
-
205
  text_inputs2 = pipe.tokenizer(
206
  prompt2,
207
  padding="max_length",
@@ -209,8 +208,7 @@ def generate_30(
209
  truncation=True,
210
  return_tensors="pt",
211
  )
212
- text_input_ids2 = text_inputs2.input_ids
213
-
214
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device, dtype=torch.bfloat16), output_hidden_states=True)
215
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device, dtype=torch.bfloat16), output_hidden_states=True)
216
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
@@ -267,6 +265,7 @@ def generate_60(
267
  truncation=True,
268
  return_tensors="pt",
269
  )
 
270
  text_inputs2 = pipe.tokenizer(
271
  prompt2,
272
  padding="max_length",
@@ -274,8 +273,9 @@ def generate_60(
274
  truncation=True,
275
  return_tensors="pt",
276
  )
277
- prompt_embedsa = pipe.text_encoder(text_inputs1.to(device), output_hidden_states=True)
278
- prompt_embedsb = pipe.text_encoder(text_inputs2.to(device), output_hidden_states=True)
 
279
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
280
 
281
  options = {
@@ -330,6 +330,7 @@ def generate_90(
330
  truncation=True,
331
  return_tensors="pt",
332
  )
 
333
  text_inputs2 = pipe.tokenizer(
334
  prompt2,
335
  padding="max_length",
@@ -337,8 +338,9 @@ def generate_90(
337
  truncation=True,
338
  return_tensors="pt",
339
  )
340
- prompt_embedsa = pipe.text_encoder(text_inputs1.to(device), output_hidden_states=True)
341
- prompt_embedsb = pipe.text_encoder(text_inputs2.to(device), output_hidden_states=True)
 
342
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
343
 
344
  options = {
 
200
  truncation=True,
201
  return_tensors="pt",
202
  )
203
+ text_input_ids1 = text_inputs1.input_ids
 
204
  text_inputs2 = pipe.tokenizer(
205
  prompt2,
206
  padding="max_length",
 
208
  truncation=True,
209
  return_tensors="pt",
210
  )
211
+ text_input_ids2 = text_inputs2.input_ids
 
212
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device, dtype=torch.bfloat16), output_hidden_states=True)
213
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device, dtype=torch.bfloat16), output_hidden_states=True)
214
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
 
265
  truncation=True,
266
  return_tensors="pt",
267
  )
268
+ text_input_ids1 = text_inputs1.input_ids
269
  text_inputs2 = pipe.tokenizer(
270
  prompt2,
271
  padding="max_length",
 
273
  truncation=True,
274
  return_tensors="pt",
275
  )
276
+ text_input_ids2 = text_inputs2.input_ids
277
+ prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device, dtype=torch.bfloat16), output_hidden_states=True)
278
+ prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device, dtype=torch.bfloat16), output_hidden_states=True)
279
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
280
 
281
  options = {
 
330
  truncation=True,
331
  return_tensors="pt",
332
  )
333
+ text_input_ids1 = text_inputs1.input_ids
334
  text_inputs2 = pipe.tokenizer(
335
  prompt2,
336
  padding="max_length",
 
338
  truncation=True,
339
  return_tensors="pt",
340
  )
341
+ text_input_ids2 = text_inputs2.input_ids
342
+ prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device, dtype=torch.bfloat16), output_hidden_states=True)
343
+ prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device, dtype=torch.bfloat16), output_hidden_states=True)
344
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
345
 
346
  options = {