import torch


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """Get pipeline embeds for prompts longer than the max length of the pipe.

    The text encoder only accepts `model_max_length` tokens at a time, so the
    prompt is tokenized without truncation, encoded in chunks of that size, and
    the chunk embeddings are concatenated along the sequence dimension.

    :param pipeline: a diffusers pipeline exposing `tokenizer` and `text_encoder`
    :param prompt: the (possibly very long) prompt string
    :param negative_prompt: the negative prompt string, or None
    :param device: device to move the token and embedding tensors to
    :return: tuple of (prompt_embeds, negative_prompt_embeds); the second element
        is None when no negative prompt is given
    """
    max_length = pipeline.tokenizer.model_max_length
    # Tokenize the full prompt without truncation; keep the whole BatchEncoding
    # so its attention mask can be sliced alongside the input ids below.
    text_inputs = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="longest")
    input_ids = text_inputs.input_ids.to(device)
    prompt_attention_mask = text_inputs.attention_mask.to(device)

    shape_max_length = input_ids.shape[-1]

    if negative_prompt is not None:
        # Pad (or truncate) the negative prompt to the same token length as the prompt.
        negative_inputs = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
                                             max_length=shape_max_length, return_tensors="pt")
        negative_ids = negative_inputs.input_ids.to(device)
        negative_attention_mask = negative_inputs.attention_mask.to(device)

    # Encode the token ids in chunks of max_length and collect the embeddings.
    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
            # Slice the mask captured at tokenization time; a sliced id tensor
            # has no `attention_mask` attribute of its own.
            attention_mask = prompt_attention_mask[:, i: i + max_length]
        else:
            attention_mask = None
        concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
                                                   attention_mask=attention_mask)[0])

        if negative_prompt is not None:
            if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
                attention_mask = negative_attention_mask[:, i: i + max_length]
            else:
                attention_mask = None
            neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
                                                    attention_mask=attention_mask)[0])

    # Concatenate the chunk embeddings along the sequence dimension.
    concat_embeds = torch.cat(concat_embeds, dim=1)

    if negative_prompt is not None:
        neg_embeds = torch.cat(neg_embeds, dim=1)
    else:
        neg_embeds = None

    return concat_embeds, neg_embeds
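

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of wiring the helper into a pipeline run.
# The checkpoint name, prompts, and output path are assumptions, not part of
# the snippet above; any diffusers pipeline that accepts `prompt_embeds` /
# `negative_prompt_embeds` should work the same way.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Deliberately longer than the CLIP tokenizer's 77-token window.
    prompt = "a photo of an astronaut riding a horse on mars, highly detailed, " * 10
    negative_prompt = "blurry, low quality, watermark"

    prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
        pipe, prompt, negative_prompt, "cuda"
    )

    # Passing embeddings instead of raw strings bypasses the usual truncation.
    image = pipe(prompt_embeds=prompt_embeds,
                 negative_prompt_embeds=negative_prompt_embeds).images[0]
    image.save("long_prompt.png")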