File size: 2,850 Bytes
119e1fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch


def _encode_in_windows(pipeline, tokens, window, device):
    """Run ``pipeline.text_encoder`` over consecutive ``window``-token slices and concatenate the outputs on dim 1."""
    input_ids = tokens.input_ids.to(device)

    # The mask must come from the tokenizer output: a sliced id tensor has no
    # ``attention_mask`` attribute.
    mask = None
    if getattr(pipeline.text_encoder.config, "use_attention_mask", False):
        mask = tokens.attention_mask.to(device)

    chunks = []
    for start in range(0, input_ids.shape[-1], window):
        stop = start + window
        chunk_mask = None if mask is None else mask[:, start:stop]
        # NOTE(review): windows after the first do not start with the
        # tokenizer's start token; confirm this is acceptable for the encoder.
        chunks.append(pipeline.text_encoder(input_ids[:, start:stop], attention_mask=chunk_mask)[0])
    return torch.cat(chunks, dim=1)


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """Encode prompts that may be longer than the pipeline tokenizer's max length.

    The prompt (and the negative prompt, when given) is tokenized without
    truncation. When both prompts are present, each is padded to the longer
    one's token count, so neither prompt loses tokens and the two outputs have
    the same length along dim 1. The ids are then split into windows of
    ``pipeline.tokenizer.model_max_length`` tokens. Each window is passed
    through ``pipeline.text_encoder`` (with its attention-mask slice when
    ``text_encoder.config.use_attention_mask`` is truthy), and the first
    element of each encoder output is concatenated along dim 1.

    :param pipeline: object exposing ``tokenizer`` (callable, with
        ``model_max_length``) and ``text_encoder`` (callable, with ``config``).
    :param prompt: positive prompt text.
    :param negative_prompt: negative prompt text, or ``None`` to skip it.
    :param device: device the token ids / attention masks are moved to.
    :return: ``(prompt_embeds, negative_embeds)``; ``negative_embeds`` is
        ``None`` when ``negative_prompt`` is ``None``.
    """
    tokenizer = pipeline.tokenizer
    max_length = tokenizer.model_max_length

    prompt_tokens = tokenizer(prompt, return_tensors="pt", truncation=False, padding="longest")
    negative_tokens = None

    if negative_prompt is not None:
        negative_tokens = tokenizer(negative_prompt, return_tensors="pt", truncation=False, padding="longest")
        # Pad both prompts to the longer one's length instead of truncating the
        # negative prompt to the positive prompt's length.
        total_length = max(prompt_tokens.input_ids.shape[-1], negative_tokens.input_ids.shape[-1])
        pad_kwargs = dict(return_tensors="pt", truncation=False, padding="max_length", max_length=total_length)
        prompt_tokens = tokenizer(prompt, **pad_kwargs)
        negative_tokens = tokenizer(negative_prompt, **pad_kwargs)

    prompt_embeds = _encode_in_windows(pipeline, prompt_tokens, max_length, device)
    negative_embeds = None
    if negative_tokens is not None:
        negative_embeds = _encode_in_windows(pipeline, negative_tokens, max_length, device)

    return prompt_embeds, negative_embeds