import torch
import torch.nn as nn
import numpy as np
# from networks.layers import *
import torch.nn.functional as F
import clip
from einops import rearrange, repeat
import math
from random import random
from tqdm.auto import tqdm
from typing import Callable, Optional, List, Dict
from copy import deepcopy
from functools import partial
from models.mask_transformer.tools import *
from torch.distributions.categorical import Categorical
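# NOTE: the star-import above is assumed to provide the helpers used throughout this
# file (lengths_to_mask, cosine_schedule, uniform, get_mask_subset_prob, top_k,
# gumbel_sample, cal_performance, q_schedule); they are not defined here.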

class InputProcess(nn.Module):
    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        # [bs, ntokens, input_feats]
        x = x.permute((1, 0, 2))  # [seqlen, bs, input_feats]
        # print(x.shape)
        x = self.poseEmbedding(x)  # [seqlen, bs, d]
        return x

class PositionalEncoding(nn.Module):
    # Borrowed from MDM: standard sinusoidal positional encoding with dropout;
    # computing the frequencies in log space (via exp) may improve precision.
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)

class OutputProcess_Bert(nn.Module):
    def __init__(self, out_feats, latent_dim):
        super().__init__()
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # Bias!

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        output = self.poseFinal(hidden_states)  # [seqlen, bs, out_feats]
        output = output.permute(1, 2, 0)  # [bs, c, seqlen]
        return output


class OutputProcess(nn.Module):
    def __init__(self, out_feats, latent_dim):
        super().__init__()
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # Bias!

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        output = self.poseFinal(hidden_states)  # [seqlen, bs, out_feats]
        output = output.permute(1, 2, 0)  # [bs, e, seqlen]
        return output

class MaskTransformer(nn.Module):
    def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, clip_dim=512, cond_drop_prob=0.1,
                 clip_version=None, opt=None, **kargs):
        super(MaskTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.opt = opt

        self.cond_mode = cond_mode
        self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
        self.num_actions = kargs.get('num_actions', 1)

        '''
        Preparing Networks
        '''
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        self.encode_action = partial(F.one_hot, num_classes=self.num_actions)

        # if self.cond_mode != 'no_cond':
        if self.cond_mode == 'text':
            self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
        elif self.cond_mode == 'uncond':
            self.cond_emb = nn.Identity()
        else:
            raise KeyError("Unsupported condition mode!!!")

        _num_tokens = opt.num_tokens + 2  # two dummy tokens, one for masking, one for padding
        self.mask_id = opt.num_tokens
        self.pad_id = opt.num_tokens + 1

        self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)

        self.token_emb = nn.Embedding(_num_tokens, self.code_dim)

        self.apply(self.__init_weights)

        '''
        Preparing frozen weights
        '''
        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

        self.noise_schedule = cosine_schedule

    def load_and_freeze_token_emb(self, codebook):
        '''
        :param codebook: (c, d)
        :return:
        '''
        assert self.training, 'Only necessary in training mode'
        c, d = codebook.shape
        self.token_emb.weight = nn.Parameter(torch.cat([codebook, torch.zeros(size=(2, d), device=codebook.device)], dim=0))  # add two dummy tokens, 0 vectors
        self.token_emb.requires_grad_(False)
        # self.token_emb.weight.requires_grad = False
        # self.token_emb_ready = True
        print("Token embedding initialized!")

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def parameters_wo_clip(self):
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
                                                jit=False)  # Must set jit=False for training
        # Convert CLIP weights to fp16. This is needed because CLIP is loaded on CPU
        # (fp32) above; disable it if the model has to run on CPU.
        clip.model.convert_weights(clip_model)

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        text = clip.tokenize(raw_text, truncate=True).to(device)
        feat_clip_text = self.clip_model.encode_text(text).float()
        return feat_clip_text

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond
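
    # mask_cond above implements condition dropout for classifier-free guidance: during
    # training, each sample's condition vector is zeroed with probability cond_drop_prob
    # so the model also learns an unconditional distribution; force_mask=True zeroes
    # every condition and is used for the unconditional branch at inference time.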

    def trans_forward(self, motion_ids, cond, padding_mask, force_mask=False):
        '''
        :param motion_ids: (b, seqlen)
        :padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :param force_mask: boolean
        :return:
            -logits: (b, num_token, seqlen)
        '''
        cond = self.mask_cond(cond, force_mask=force_mask)

        # print(motion_ids.shape)
        x = self.token_emb(motion_ids)
        # print(x.shape)
        # (b, seqlen, d) -> (seqlen, b, latent_dim)
        x = self.input_process(x)

        cond = self.cond_emb(cond).unsqueeze(0)  # (1, b, latent_dim)

        x = self.position_enc(x)
        xseq = torch.cat([cond, x], dim=0)  # (seqlen+1, b, latent_dim)

        padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1)  # (b, seqlen+1)
        # print(xseq.shape, padding_mask.shape)

        output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[1:]  # (seqlen, b, e)
        logits = self.output_process(output)  # (seqlen, b, e) -> (b, ntoken, seqlen)
        return logits

    def forward(self, ids, y, m_lens):
        '''
        :param ids: (b, n)
        :param y: raw text for cond_mode=text, (b, ) for cond_mode=action
        :m_lens: (b,)
        :return:
        '''

        bs, ntokens = ids.shape
        device = ids.device

        # Positions that are PADDED are ALL FALSE
        non_pad_mask = lengths_to_mask(m_lens, ntokens)  # (b, n)
        ids = torch.where(non_pad_mask, ids, self.pad_id)

        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        '''
        Prepare mask
        '''
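        # A fraction of the non-padded tokens, given by the cosine noise schedule at a
        # random timestep, is selected for masking; those positions become the training
        # targets and are further corrupted BERT-style below before being fed back in.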
        rand_time = uniform((bs,), device=device)
        rand_mask_probs = self.noise_schedule(rand_time)
        num_token_masked = (ntokens * rand_mask_probs).round().clamp(min=1)

        batch_randperm = torch.rand((bs, ntokens), device=device).argsort(dim=-1)
        # Positions to be MASKED are ALL TRUE
        mask = batch_randperm < num_token_masked.unsqueeze(-1)

        # Positions to be MASKED must also be NON-PADDED
        mask &= non_pad_mask

        # Note this is our training target, not input
        labels = torch.where(mask, ids, self.mask_id)

        x_ids = ids.clone()

        # Further apply the BERT masking scheme
        # Step 1: 10% of the selected positions are replaced with a random (incorrect) token
        mask_rid = get_mask_subset_prob(mask, 0.1)
        rand_id = torch.randint_like(x_ids, high=self.opt.num_tokens)
        x_ids = torch.where(mask_rid, rand_id, x_ids)
        # Step 2: of the remaining 90%, 88% are replaced with the mask token and 12% keep the correct token
        mask_mid = get_mask_subset_prob(mask & ~mask_rid, 0.88)

        # mask_mid = mask

        x_ids = torch.where(mask_mid, self.mask_id, x_ids)

        logits = self.trans_forward(x_ids, cond_vector, ~non_pad_mask, force_mask)
        ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)

        return ce_loss, pred_id, acc

    def forward_with_cond_scale(self,
                                motion_ids,
                                cond_vector,
                                padding_mask,
                                cond_scale=3,
                                force_mask=False):
        # bs = motion_ids.shape[0]
        # if cond_scale == 1:
        if force_mask:
            return self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)

        logits = self.trans_forward(motion_ids, cond_vector, padding_mask)
        if cond_scale == 1:
            return logits

        aux_logits = self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits
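
    # Classifier-free guidance: aux_logits come from an unconditional pass (condition
    # zeroed via force_mask), so the returned logits are
    #     uncond + cond_scale * (cond - uncond),
    # i.e. plain conditional logits at cond_scale=1 and an extrapolation beyond them
    # for cond_scale > 1.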

    def generate(self,
                 conds,
                 m_lens,
                 timesteps: int,
                 cond_scale: int,
                 temperature=1,
                 topk_filter_thres=0.9,
                 gsample=False,
                 force_mask=False
                 ):
        # print(self.opt.num_quantizers)
        # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers

        device = next(self.parameters()).device
        seq_len = max(m_lens)
        batch_size = len(m_lens)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, seq_len)
        # print(padding_mask.shape, )

        # Start from all tokens being masked
        ids = torch.where(padding_mask, self.pad_id, self.mask_id)
        scores = torch.where(padding_mask, 1e5, 0.)
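
        # Iterative mask-predict decoding: `scores` tracks the model's confidence in
        # each kept token. Padded positions get a large score (1e5) so they are never
        # selected for re-masking; at every step the lowest-scoring tokens are re-masked
        # according to the noise schedule and re-predicted.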
        starting_temperature = temperature

        for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            # 0 <= timestep <= 1
            rand_mask_prob = self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # fix: the ratio regarding lengths, instead of seq_len
            num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1)  # (b, )

            # select num_token_masked tokens with lowest scores to be masked
            sorted_indices = scores.argsort(dim=1)  # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
            ranks = sorted_indices.argsort(dim=1)  # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
                                                  padding_mask=padding_mask,
                                                  cond_scale=cond_scale,
                                                  force_mask=force_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # print(logits.shape, self.opt.num_tokens)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # if force_mask:
            temperature = starting_temperature
            # else:
            #     temperature = starting_temperature * (steps_until_x0 / timesteps)
            #     temperature = max(temperature, 1e-4)
            # print(filtered_logits.shape)
            # temperature is annealed, gradually reducing temperature as well as randomness
            if gsample:  # use gumbel_softmax sampling
                # print("1111")
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                # print("2222")
                probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
                # print(temperature, starting_temperature, steps_until_x0, timesteps)
                # print(probs / temperature)
                pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)
                # print(pred_ids.max(), pred_ids.min())

            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)
            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~is_mask, 1e5)

        ids = torch.where(padding_mask, -1, ids)
        # print("Final", ids.max(), ids.min())
        return ids

    def edit(self,
             conds,
             tokens,
             m_lens,
             timesteps: int,
             cond_scale: int,
             temperature=1,
             topk_filter_thres=0.9,
             gsample=False,
             force_mask=False,
             edit_mask=None,
             padding_mask=None,
             ):

        if edit_mask is not None:
            assert edit_mask.shape == tokens.shape

        device = next(self.parameters()).device
        seq_len = tokens.shape[1]

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(1, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        if padding_mask is None:
            padding_mask = ~lengths_to_mask(m_lens, seq_len)
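
        # Two modes: if no edit_mask is given ("mask-free" editing), every non-padded
        # token is a regeneration candidate and a fixed 16% of them is re-masked per
        # step; otherwise only the positions marked in edit_mask are masked out and
        # regenerated, while the remaining tokens are kept fixed.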
        if edit_mask is None:
            mask_free = True
            ids = torch.where(padding_mask, self.pad_id, tokens)
            edit_mask = torch.ones_like(padding_mask)
            edit_mask = edit_mask & ~padding_mask
            edit_len = edit_mask.sum(dim=-1)
            scores = torch.where(edit_mask, 0., 1e5)
        else:
            mask_free = False
            edit_mask = edit_mask & ~padding_mask
            edit_len = edit_mask.sum(dim=-1)
            ids = torch.where(edit_mask, self.mask_id, tokens)
            scores = torch.where(edit_mask, 0., 1e5)
        starting_temperature = temperature

        for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            # 0 <= timestep <= 1
            rand_mask_prob = 0.16 if mask_free else self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # fix: the ratio regarding lengths, instead of seq_len
            num_token_masked = torch.round(rand_mask_prob * edit_len).clamp(min=1)  # (b, )

            # select num_token_masked tokens with lowest scores to be masked
            sorted_indices = scores.argsort(dim=1)  # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
            ranks = sorted_indices.argsort(dim=1)  # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            # is_mask = (torch.rand_like(scores) < 0.8) * ~padding_mask if mask_free else is_mask
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
                                                  padding_mask=padding_mask,
                                                  cond_scale=cond_scale,
                                                  force_mask=force_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # print(logits.shape, self.opt.num_tokens)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # if force_mask:
            temperature = starting_temperature
            # else:
            #     temperature = starting_temperature * (steps_until_x0 / timesteps)
            #     temperature = max(temperature, 1e-4)
            # print(filtered_logits.shape)
            # temperature is annealed, gradually reducing temperature as well as randomness
            if gsample:  # use gumbel_softmax sampling
                # print("1111")
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                # print("2222")
                probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
                # print(temperature, starting_temperature, steps_until_x0, timesteps)
                # print(probs / temperature)
                pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)
                # print(pred_ids.max(), pred_ids.min())

            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)
            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~edit_mask, 1e5) if mask_free else scores.masked_fill(~is_mask, 1e5)

        ids = torch.where(padding_mask, -1, ids)
        # print("Final", ids.max(), ids.min())
        return ids

    def edit_beta(self,
                  conds,
                  conds_og,
                  tokens,
                  m_lens,
                  cond_scale: int,
                  force_mask=False,
                  ):

        device = next(self.parameters()).device
        seq_len = tokens.shape[1]

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
                if conds_og is not None:
                    cond_vector_og = self.encode_text(conds_og)
                else:
                    cond_vector_og = None
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
            if conds_og is not None:
                cond_vector_og = self.encode_action(conds_og).to(device)
            else:
                cond_vector_og = None
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, seq_len)

        ids = torch.where(padding_mask, self.pad_id, tokens)  # Do not mask anything

        '''
        Preparing input
        '''
        # (b, num_token, seqlen)
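        # NOTE: forward_with_cond_scale as defined above does not take a
        # `cond_vector_neg` argument; this call assumes an extended variant that uses
        # the original-text embedding as a negative condition for the guidance.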
        logits = self.forward_with_cond_scale(ids,
                                              cond_vector=cond_vector,
                                              cond_vector_neg=cond_vector_og,
                                              padding_mask=padding_mask,
                                              cond_scale=cond_scale,
                                              force_mask=force_mask)

        logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)

        '''
        Updating scores
        '''
        probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
        tokens[tokens == -1] = 0  # just to get through an error when index = -1 using gather
        og_tokens_scores = probs_without_temperature.gather(2, tokens.unsqueeze(dim=-1))  # (b, seqlen, 1)
        og_tokens_scores = og_tokens_scores.squeeze(-1)  # (b, seqlen)
        return og_tokens_scores


class ResidualTransformer(nn.Module):
    def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8, cond_drop_prob=0.1,
                 num_heads=4, dropout=0.1, clip_dim=512, shared_codebook=False, share_weight=False,
                 clip_version=None, opt=None, **kargs):
        super(ResidualTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        # assert shared_codebook == True, "Only support shared codebook right now!"

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.opt = opt

        self.cond_mode = cond_mode
        # self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
        self.num_actions = kargs.get('num_actions', 1)
        self.cond_drop_prob = cond_drop_prob

        '''
        Preparing Networks
        '''
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        self.encode_quant = partial(F.one_hot, num_classes=self.opt.num_quantizers)
        self.encode_action = partial(F.one_hot, num_classes=self.num_actions)

        self.quant_emb = nn.Linear(self.opt.num_quantizers, self.latent_dim)
        # if self.cond_mode != 'no_cond':
        if self.cond_mode == 'text':
            self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
        else:
            raise KeyError("Unsupported condition mode!!!")

        _num_tokens = opt.num_tokens + 1  # one dummy token for padding
        self.pad_id = opt.num_tokens

        # self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
        self.output_process = OutputProcess(out_feats=code_dim, latent_dim=latent_dim)

        if shared_codebook:
            token_embed = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
            self.token_embed_weight = token_embed.expand(opt.num_quantizers - 1, _num_tokens, code_dim)
            if share_weight:
                self.output_proj_weight = self.token_embed_weight
                self.output_proj_bias = None
            else:
                output_proj = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
                output_bias = nn.Parameter(torch.zeros(size=(_num_tokens,)))
                # self.output_proj_bias = 0
                self.output_proj_weight = output_proj.expand(opt.num_quantizers - 1, _num_tokens, code_dim)
                self.output_proj_bias = output_bias.expand(opt.num_quantizers - 1, _num_tokens)
        else:
            if share_weight:
                self.embed_proj_shared_weight = nn.Parameter(torch.normal(mean=0, std=0.02, size=(opt.num_quantizers - 2, _num_tokens, code_dim)))
                self.token_embed_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
                self.output_proj_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
                self.output_proj_bias = None
                self.registered = False
            else:
                output_proj_weight = torch.normal(mean=0, std=0.02,
                                                  size=(opt.num_quantizers - 1, _num_tokens, code_dim))
                self.output_proj_weight = nn.Parameter(output_proj_weight)
                self.output_proj_bias = nn.Parameter(torch.zeros(size=(opt.num_quantizers, _num_tokens)))
                token_embed_weight = torch.normal(mean=0, std=0.02,
                                                  size=(opt.num_quantizers - 1, _num_tokens, code_dim))
                self.token_embed_weight = nn.Parameter(token_embed_weight)

        self.apply(self.__init_weights)
        self.shared_codebook = shared_codebook
        self.share_weight = share_weight

        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def parameters_wo_clip(self):
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
                                                jit=False)  # Must set jit=False for training
        # Convert CLIP weights to fp16. This is needed because CLIP is loaded on CPU
        # (fp32) above; disable it if the model has to run on CPU.
        clip.model.convert_weights(clip_model)

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        text = clip.tokenize(raw_text, truncate=True).to(device)
        feat_clip_text = self.clip_model.encode_text(text).float()
        return feat_clip_text

    def q_schedule(self, bs, low, high):
        noise = uniform((bs,), device=self.opt.device)
        schedule = 1 - cosine_schedule(noise)
        return torch.round(schedule * (high - low)) + low

    def process_embed_proj_weight(self):
        if self.share_weight and (not self.shared_codebook):
            # if not self.registered:
            self.output_proj_weight = torch.cat([self.embed_proj_shared_weight, self.output_proj_weight_], dim=0)
            self.token_embed_weight = torch.cat([self.token_embed_weight_, self.embed_proj_shared_weight], dim=0)
            # self.registered = True

    def output_project(self, logits, qids):
        '''
        :logits: (bs, code_dim, seqlen)
        :qids: (bs)
        :return:
            -logits (bs, ntoken, seqlen)
        '''
        # (num_qlayers-1, num_token, code_dim) -> (bs, ntoken, code_dim)
        output_proj_weight = self.output_proj_weight[qids]
        # (num_qlayers, ntoken) -> (bs, ntoken)
        output_proj_bias = None if self.output_proj_bias is None else self.output_proj_bias[qids]

        output = torch.einsum('bnc, bcs->bns', output_proj_weight, logits)
        if output_proj_bias is not None:
            output = output + output_proj_bias.unsqueeze(-1)
        return output
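
    # output_project maps the transformer's continuous output (code_dim per position) to
    # per-token logits by taking, for each batch item, the inner product with the
    # projection matrix of its active quantizer layer (qids), plus an optional bias.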

    def trans_forward(self, motion_codes, qids, cond, padding_mask, force_mask=False):
        '''
        :param motion_codes: (b, seqlen, d)
        :padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param qids: (b), quantizer layer ids
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :return:
            -logits: (b, num_token, seqlen)
        '''
        cond = self.mask_cond(cond, force_mask=force_mask)

        # (b, seqlen, d) -> (seqlen, b, latent_dim)
        x = self.input_process(motion_codes)

        # (b, num_quantizer)
        q_onehot = self.encode_quant(qids).float().to(x.device)

        q_emb = self.quant_emb(q_onehot).unsqueeze(0)  # (1, b, latent_dim)
        cond = self.cond_emb(cond).unsqueeze(0)  # (1, b, latent_dim)

        x = self.position_enc(x)
        xseq = torch.cat([cond, q_emb, x], dim=0)  # (seqlen+2, b, latent_dim)

        padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:2]), padding_mask], dim=1)  # (b, seqlen+2)
        output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[2:]  # (seqlen, b, e)
        logits = self.output_process(output)
        return logits

    def forward_with_cond_scale(self,
                                motion_codes,
                                q_id,
                                cond_vector,
                                padding_mask,
                                cond_scale=3,
                                force_mask=False):
        bs = motion_codes.shape[0]
        # if cond_scale == 1:
        qids = torch.full((bs,), q_id, dtype=torch.long, device=motion_codes.device)
        if force_mask:
            logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
            logits = self.output_project(logits, qids - 1)
            return logits

        logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask)
        logits = self.output_project(logits, qids - 1)
        if cond_scale == 1:
            return logits

        aux_logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
        aux_logits = self.output_project(aux_logits, qids - 1)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits

    def forward(self, all_indices, y, m_lens):
        '''
        :param all_indices: (b, n, q)
        :param y: raw text for cond_mode=text, (b, ) for cond_mode=action
        :m_lens: (b,)
        :return:
        '''
        self.process_embed_proj_weight()

        bs, ntokens, num_quant_layers = all_indices.shape
        device = all_indices.device

        # Positions that are PADDED are ALL FALSE
        non_pad_mask = lengths_to_mask(m_lens, ntokens)  # (b, n)
        q_non_pad_mask = repeat(non_pad_mask, 'b n -> b n q', q=num_quant_layers)
        all_indices = torch.where(q_non_pad_mask, all_indices, self.pad_id)  # (b, n, q)

        # randomly sample quantization layers to work on, [1, num_q)
        active_q_layers = q_schedule(bs, low=1, high=num_quant_layers, device=device)

        # print(self.token_embed_weight.shape, all_indices.shape)
        token_embed = repeat(self.token_embed_weight, 'q c d -> b c d q', b=bs)
        gather_indices = repeat(all_indices[..., :-1], 'b n q -> b n d q', d=token_embed.shape[2])
        # print(token_embed.shape, gather_indices.shape)
        all_codes = token_embed.gather(1, gather_indices)  # (b, n, d, q-1)

        cumsum_codes = torch.cumsum(all_codes, dim=-1)  # (b, n, d, q-1)

        active_indices = all_indices[torch.arange(bs), :, active_q_layers]  # (b, n)
        history_sum = cumsum_codes[torch.arange(bs), :, :, active_q_layers - 1]
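
        # The transformer input is the cumulative sum of the code embeddings from all
        # quantizer layers below the randomly sampled active layer; the training target
        # is the token indices of that active layer.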
        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        logits = self.trans_forward(history_sum, active_q_layers, cond_vector, ~non_pad_mask, force_mask)
        logits = self.output_project(logits, active_q_layers - 1)
        ce_loss, pred_id, acc = cal_performance(logits, active_indices, ignore_index=self.pad_id)

        return ce_loss, pred_id, acc

    def generate(self,
                 motion_ids,
                 conds,
                 m_lens,
                 temperature=1,
                 topk_filter_thres=0.9,
                 cond_scale=2,
                 num_res_layers=-1,  # If it's -1, use all.
                 ):
        # print(self.opt.num_quantizers)
        # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
        self.process_embed_proj_weight()

        device = next(self.parameters()).device
        seq_len = motion_ids.shape[1]
        batch_size = len(conds)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        # token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
        # gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
        # history_sum = token_embed.gather(1, gathered_ids)

        # print(pa, seq_len)
        padding_mask = ~lengths_to_mask(m_lens, seq_len)
        # print(padding_mask.shape, motion_ids.shape)
        motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
        all_indices = [motion_ids]
        history_sum = 0
        num_quant_layers = self.opt.num_quantizers if num_res_layers == -1 else num_res_layers + 1
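
        # Layer-by-layer residual decoding: for each quantizer layer i, the embeddings of
        # the tokens predicted so far are accumulated into history_sum, then layer i's
        # tokens are predicted in a single pass (no iterative re-masking is needed here).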
        for i in range(1, num_quant_layers):
            # print(f"--> Working on {i}-th quantizer")
            # qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
            token_embed = self.token_embed_weight[i - 1]
            token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
            gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
            history_sum += token_embed.gather(1, gathered_ids)

            logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
            # logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)

            # probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
            # # print(temperature, starting_temperature, steps_until_x0, timesteps)
            # # print(probs / temperature)
            # pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)

            ids = torch.where(padding_mask, self.pad_id, pred_ids)

            motion_ids = ids
            all_indices.append(ids)

        all_indices = torch.stack(all_indices, dim=-1)
        # padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
        # all_indices = torch.where(padding_mask, -1, all_indices)
        all_indices = torch.where(all_indices == self.pad_id, -1, all_indices)
        # all_indices = all_indices.masked_fill()
        return all_indices

    def edit(self,
             motion_ids,
             conds,
             m_lens,
             temperature=1,
             topk_filter_thres=0.9,
             cond_scale=2
             ):
        # print(self.opt.num_quantizers)
        # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
        self.process_embed_proj_weight()

        device = next(self.parameters()).device
        seq_len = motion_ids.shape[1]
        batch_size = len(conds)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        # token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
        # gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
        # history_sum = token_embed.gather(1, gathered_ids)

        # print(pa, seq_len)
        padding_mask = ~lengths_to_mask(m_lens, seq_len)
        # print(padding_mask.shape, motion_ids.shape)
        motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
        all_indices = [motion_ids]
        history_sum = 0

        for i in range(1, self.opt.num_quantizers):
            # print(f"--> Working on {i}-th quantizer")
            # qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
            token_embed = self.token_embed_weight[i - 1]
            token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
            gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
            history_sum += token_embed.gather(1, gathered_ids)

            logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
            # logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)

            # probs = F.softmax(filtered_logits, dim=-1)  # (b, seqlen, ntoken)
            # # print(temperature, starting_temperature, steps_until_x0, timesteps)
            # # print(probs / temperature)
            # pred_ids = Categorical(probs / temperature).sample()  # (b, seqlen)

            ids = torch.where(padding_mask, self.pad_id, pred_ids)

            motion_ids = ids
            all_indices.append(ids)

        all_indices = torch.stack(all_indices, dim=-1)
        # padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
        # all_indices = torch.where(padding_mask, -1, all_indices)
        all_indices = torch.where(all_indices == self.pad_id, -1, all_indices)
        # all_indices = all_indices.masked_fill()
        return all_indices
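

# Minimal smoke-test sketch: it assumes the helpers imported from
# models.mask_transformer.tools (lengths_to_mask, uniform, cosine_schedule,
# get_mask_subset_prob, cal_performance) are available and behave as their names
# suggest, and uses cond_mode='uncond' so no CLIP checkpoint is required. The `opt`
# namespace below is a hypothetical stand-in for the project's real option object.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(num_tokens=512)  # hypothetical codebook size
    model = MaskTransformer(code_dim=512, cond_mode='uncond', latent_dim=256,
                            ff_size=1024, num_layers=2, num_heads=4, opt=opt)

    bs, n = 4, 49
    ids = torch.randint(0, opt.num_tokens, (bs, n))    # token indices from the VQ stage
    m_lens = torch.randint(20, n + 1, (bs,))           # per-sample valid lengths

    model.train()
    ce_loss, pred_id, acc = model(ids, y=None, m_lens=m_lens)  # one masked-token training step
    print(ce_loss, acc)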