import torch


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """Build prompt embeddings for prompts longer than the tokenizer max length.

    Tokenizes the prompt without truncation, then runs the text encoder over
    windows of ``model_max_length`` tokens and concatenates the per-window
    embeddings along the sequence dimension.

    :param pipeline: diffusers-style pipeline exposing ``tokenizer`` and
        ``text_encoder``.
    :param prompt: positive prompt string (or list of strings).
    :param negative_prompt: negative prompt string, or None to skip it.
    :param device: torch device the returned tensors live on.
    :return: ``(prompt_embeds, negative_embeds)``; ``negative_embeds`` is
        None when ``negative_prompt`` is None.
    """
    max_length = pipeline.tokenizer.model_max_length

    # Tokenize the full prompt without truncation; keep the whole tokenizer
    # output so the matching attention mask is available for encoders that
    # require one (config.use_attention_mask).
    prompt_enc = pipeline.tokenizer(
        prompt, return_tensors="pt", truncation=False, padding="longest"
    )
    input_ids = prompt_enc.input_ids.to(device)
    prompt_mask = prompt_enc.attention_mask.to(device)
    shape_max_length = input_ids.shape[-1]

    negative_ids = negative_mask = None
    if negative_prompt is not None:
        # Pad/truncate the negative prompt to exactly the positive prompt's
        # token length so both embedding tensors line up for classifier-free
        # guidance.
        neg_enc = pipeline.tokenizer(
            negative_prompt,
            truncation=True,
            padding="max_length",
            max_length=shape_max_length,
            return_tensors="pt",
        )
        negative_ids = neg_enc.input_ids.to(device)
        negative_mask = neg_enc.attention_mask.to(device)

    # Hoisted out of the loop: whether this encoder consumes an attention mask.
    use_mask = bool(getattr(pipeline.text_encoder.config, "use_attention_mask", False))

    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        # BUG FIX: the original sliced the input_ids *tensor* and then read
        # ``.attention_mask`` from the slice, which raises AttributeError
        # (torch.Tensor has no attention_mask). The mask must come from the
        # tokenizer output and be sliced with the same window.
        attention_mask = prompt_mask[:, i : i + max_length] if use_mask else None
        concat_embeds.append(
            pipeline.text_encoder(
                input_ids[:, i : i + max_length], attention_mask=attention_mask
            )[0]
        )
        if negative_ids is not None:
            attention_mask = (
                negative_mask[:, i : i + max_length] if use_mask else None
            )
            neg_embeds.append(
                pipeline.text_encoder(
                    negative_ids[:, i : i + max_length],
                    attention_mask=attention_mask,
                )[0]
            )

    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_embeds = (
        torch.cat(neg_embeds, dim=1) if negative_ids is not None else None
    )
    return prompt_embeds, negative_embeds