import torch


def load_text_encoders(args, class_one, class_two):
    # Instantiate both text encoders from subfolders of the same pretrained checkpoint.
    text_encoder_one = class_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    text_encoder_two = class_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    return text_encoder_one, text_encoder_two
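
# Usage sketch (hypothetical; assumes a Flux-style checkpoint layout with a CLIP
# encoder in "text_encoder" and a T5 encoder in "text_encoder_2"):
#
#   from transformers import CLIPTextModel, T5EncoderModel
#   text_encoder_one, text_encoder_two = load_text_encoders(args, CLIPTextModel, T5EncoderModel)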


def tokenize_prompt(tokenizer, prompt, max_sequence_length):
    # Tokenize to a fixed length, truncating anything longer than max_sequence_length.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def tokenize_prompt_clip(tokenizer, prompt):
    # CLIP text encoders have a fixed context length of 77 tokens.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def tokenize_prompt_t5(tokenizer, prompt):
    # The T5 branch uses a fixed 512-token sequence length.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # Duplicate the sequence embeddings for each image generated per prompt.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use the pooled output of the CLIP text encoder.
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate the pooled embedding for each image generated per prompt.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    dtype = text_encoders[0].dtype

    # CLIP supplies the pooled embedding; T5 supplies the per-token sequence embeddings.
    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )

    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )

    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
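
# Usage sketch (hypothetical tokenizers/encoders; shapes assume a Flux-style setup):
#
#   prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
#       text_encoders=[clip_text_encoder, t5_text_encoder],
#       tokenizers=[clip_tokenizer, t5_tokenizer],
#       prompt="a photo of a cat",
#       max_sequence_length=512,
#   )
#   # prompt_embeds:        (num_images_per_prompt, max_sequence_length, t5_hidden_size)
#   # pooled_prompt_embeds: (num_images_per_prompt, clip_pooled_dim)
#   # text_ids:             (max_sequence_length, 3), zero-filled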


def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
    text_encoder_clip = text_encoders[0]
    text_encoder_t5 = text_encoders[1]
    tokens_clip, tokens_t5 = tokens[0], tokens[1]
    batch_size = tokens_clip.shape[0]

    # Run on the accelerator device unless CPU was explicitly requested.
    if device != "cpu":
        device = accelerator.device

    # CLIP branch: pooled embedding.
    prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)

    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)

    # T5 branch: per-token sequence embeddings.
    prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
    dtype = text_encoder_t5.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
    _, seq_len, _ = prompt_embeds.shape

    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
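
# Usage sketch (hypothetical; pairs the tokenize_* helpers above with precomputed ids):
#
#   tokens_clip = tokenize_prompt_clip(clip_tokenizer, "a photo of a cat")
#   tokens_t5 = tokenize_prompt_t5(t5_tokenizer, "a photo of a cat")
#   prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(
#       [clip_text_encoder, t5_text_encoder], [tokens_clip, tokens_t5], accelerator
#   )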