import torch


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
""" Get pipeline embeds for prompts bigger than the maxlength of the pipe
:param pipeline:
:param prompt:
:param negative_prompt:
:param device:
:return:
"""
    max_length = pipeline.tokenizer.model_max_length

    # Tokenize the prompt without truncation so prompts longer than
    # `model_max_length` keep all of their tokens; keep the full tokenizer
    # output so the attention mask can be sliced alongside the input ids.
    tokens = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="longest")
    input_ids = tokens.input_ids.to(device)
    prompt_attention_mask = tokens.attention_mask.to(device)
    shape_max_length = input_ids.shape[-1]
    if negative_prompt is not None:
        # Pad (or truncate) the negative prompt to the same token length as the
        # prompt so both embeddings end up with matching sequence lengths.
        negative_tokens = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
                                             max_length=shape_max_length, return_tensors="pt")
        negative_ids = negative_tokens.input_ids.to(device)
        negative_attention_mask = negative_tokens.attention_mask.to(device)
    # Some text encoders expect an attention mask; only pass one when the
    # encoder's config asks for it.
    use_mask = getattr(pipeline.text_encoder.config, "use_attention_mask", False)

    # Encode in chunks of `max_length` tokens and collect the embeddings.
    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        attention_mask = prompt_attention_mask[:, i: i + max_length] if use_mask else None
        concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
                                                   attention_mask=attention_mask)[0])
        if negative_prompt is not None:
            attention_mask = negative_attention_mask[:, i: i + max_length] if use_mask else None
            neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
                                                    attention_mask=attention_mask)[0])
    # Concatenate the chunk embeddings along the sequence dimension.
    concat_embeds = torch.cat(concat_embeds, dim=1)
    if negative_prompt is not None:
        neg_embeds = torch.cat(neg_embeds, dim=1)
    else:
        neg_embeds = None
    return concat_embeds, neg_embeds
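

# Usage sketch (assumptions: the diffusers StableDiffusionPipeline and the
# "runwayml/stable-diffusion-v1-5" checkpoint; swap in your own model, device,
# and prompts as needed).
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

    # Repeat the text to make the prompt deliberately longer than the
    # tokenizer's 77-token limit.
    prompt = "a photo of an astronaut riding a horse on mars, highly detailed " * 10
    negative_prompt = "blurry, low quality, deformed " * 10

    prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(pipe, prompt, negative_prompt, device)

    # Passing precomputed embeddings instead of raw strings bypasses the
    # pipeline's built-in truncation at `model_max_length`.
    image = pipe(prompt_embeds=prompt_embeds,
                 negative_prompt_embeds=negative_prompt_embeds).images[0]
    image.save("long_prompt.png")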