import torch


def load_text_encoders(args, class_one, class_two):
    text_encoder_one = class_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    text_encoder_two = class_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    return text_encoder_one, text_encoder_two
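
# A minimal usage sketch (assumes `args` carries the usual training-script fields and
# that the checkpoint follows the Flux layout of one CLIP and one T5 text encoder):
#
#   from transformers import CLIPTextModel, T5EncoderModel
#
#   text_encoder_one, text_encoder_two = load_text_encoders(args, CLIPTextModel, T5EncoderModel)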


def tokenize_prompt(tokenizer, prompt, max_sequence_length):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids
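
# Sketch: padding/truncating to a fixed length lets batched prompts share one shape
# (the tokenizer and subfolder choice here are assumptions for illustration):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained(
#       args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
#   )
#   ids = tokenize_prompt(tokenizer, ["a photo of sks dog"], max_sequence_length=512)
#   # ids: LongTensor of shape (batch_size, max_sequence_length)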


def tokenize_prompt_clip(tokenizer, prompt):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids
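
# Note: 77 is the fixed context length of the CLIP text encoder, so this wrapper
# pins max_length instead of taking it as an argument.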


def tokenize_prompt_t5(tokenizer, prompt):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids
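
# Note: 512 matches the default max_sequence_length used for the T5 encoder in this
# module. The two wrappers pair naturally to precompute the token ids consumed by
# encode_token_ids below (tokenizer names here are placeholders):
#
#   tokens_clip = tokenize_prompt_clip(clip_tokenizer, prompt)  # (B, 77)
#   tokens_t5 = tokenize_prompt_t5(t5_tokenizer, prompt)        # (B, 512)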


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
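
# Sketch of the T5 path (placeholder names). When a tokenizer is passed, any
# text_input_ids argument is recomputed from the prompt; pass precomputed ids only
# together with tokenizer=None:
#
#   embeds = _encode_prompt_with_t5(
#       text_encoder=t5_encoder, tokenizer=t5_tokenizer, prompt="a photo of sks dog",
#       num_images_per_prompt=2, device=device,
#   )
#   # embeds: (batch_size * num_images_per_prompt, max_sequence_length, hidden_dim)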


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use only the pooled output of the CLIP text encoder.
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds
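
# Sketch of the CLIP path (placeholder names): only the pooled embedding is kept,
# which Flux consumes as the global text conditioning vector:
#
#   pooled = _encode_prompt_with_clip(
#       text_encoder=clip_encoder, tokenizer=clip_tokenizer, prompt="a photo of sks dog",
#       device=device,
#   )
#   # pooled: (batch_size * num_images_per_prompt, pooled_dim)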


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    dtype = text_encoders[0].dtype

    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )

    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )

    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
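
# Sketch of the combined entry point (placeholder names). Order matters: both lists
# are [CLIP, T5]. text_ids is a (seq_len, 3) zero tensor of position ids for the
# text stream of the Flux transformer:
#
#   prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
#       text_encoders=[clip_encoder, t5_encoder],
#       tokenizers=[clip_tokenizer, t5_tokenizer],
#       prompt="a photo of sks dog",
#       max_sequence_length=512,
#       device=device,
#   )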


def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
    text_encoder_clip = text_encoders[0]
    text_encoder_t5 = text_encoders[1]
    tokens_clip, tokens_t5 = tokens[0], tokens[1]
    batch_size = tokens_clip.shape[0]

    # Run on the CPU only if explicitly requested; otherwise use the accelerator's device.
    if device != "cpu":
        device = accelerator.device

    # CLIP: keep only the pooled output as the global conditioning vector.
    prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)

    # duplicate pooled embeddings for each generation per prompt
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    # T5: keep the full sequence of hidden states.
    prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
    dtype = text_encoder_t5.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
    _, seq_len, _ = prompt_embeds.shape

    # duplicate sequence embeddings for each generation per prompt
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
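
# Sketch: encoding from precomputed token ids under Accelerate (placeholder names),
# e.g. when the dataloader tokenizes prompts ahead of time:
#
#   tokens = [tokenize_prompt_clip(clip_tokenizer, prompt), tokenize_prompt_t5(t5_tokenizer, prompt)]
#   prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(
#       [clip_encoder, t5_encoder], tokens, accelerator
#   )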