import sys
sys.path.append('../')

from typing import Optional
from copy import deepcopy

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor, WhisperFeatureExtractor, WhisperModel
# from .modeling_whisper import WhisperModel

from my_laion_clap.CLAP.src.laion_clap.clap_module.htsat import create_htsat_model

import torch
import torch.nn.functional as F  # needed by get_audio_features (F.pad)
import torchaudio
import torchaudio.transforms as T
import numpy as np
from torch import nn
import torchvision.transforms
from contextlib import suppress

try:
    from .flamingo import Flamingo
    from .flamingo_lm import FlamingoLMMixin
    from .utils import extend_instance
except ImportError:
    from flamingo import Flamingo
    from flamingo_lm import FlamingoLMMixin
    from utils import extend_instance

def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


def int16_to_float32_torch(x):
    return (x / 32767.0).type(torch.float32)


def float32_to_int16_torch(x):
    x = torch.clamp(x, min=-1., max=1.)
    return (x * 32767.).type(torch.int16)
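
# Note: float32_to_int16_torch followed by int16_to_float32_torch quantizes a waveform to
# 16-bit precision and back, matching the int16 round trip used in LAION-CLAP-style audio
# preprocessing (see load_audio below); the numpy variants are the equivalent array helpers.
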
class CLAPAudioCfp:
    model_type: str = "HTSAT"
    model_name: str = "large"
    sample_rate: int = 16000
    audio_length: int = 1024
    window_size: int = 1024
    hop_size: int = 160
    fmin: int = 50
    fmax: int = 14000
    class_num: int = 527
    mel_bins: int = 64
    clip_samples: int = 160000
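
# CLAPAudioCfp holds the HTS-AT audio settings passed to create_htsat_model: 16 kHz audio,
# 10-second clips (clip_samples = 160000), 64 mel bins, a 1024-sample window, and a
# 160-sample hop (100 mel frames per second).
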
class CLAP(nn.Module):
    def __init__(self, clap_config):
        super(CLAP, self).__init__()
        self.clap_config = clap_config
        self.method = clap_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in clap_config) and clap_config['finetune']:
            self.finetune = True
            print('Finetuning CLAP encoder as well!')
        else:
            self.finetune = False

        audio_cfg = CLAPAudioCfp()
        enable_fusion = True
        fusion_type = "aff_2d"

        self.nvclap = create_htsat_model(audio_cfg, enable_fusion, fusion_type)
        clap_state_dict = torch.load(clap_config["checkpoint"], map_location='cpu')
        clap_state_dict_copy = clap_state_dict['state_dict'].copy()
        for key in list(clap_state_dict['state_dict'].keys()):
            if 'audio' in key:
                clap_state_dict_copy[key.replace('module.audio_branch.', '')] = clap_state_dict_copy[key]
                del clap_state_dict_copy[key]
            else:
                del clap_state_dict_copy[key]
        self.nvclap.load_state_dict(clap_state_dict_copy, strict=False)
        self.nvclap = self.nvclap.to(device_id)

        for param in self.nvclap.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.nvclap.train()
        else:
            self.nvclap.eval()

        print('loaded NVCLAP model: {}'.format(clap_config["checkpoint"]))
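
    # get_mel uses the same front-end settings as CLAPAudioCfp above (1024-point FFT,
    # hop 160, 64 mel bins, 50-14000 Hz) and returns a log-mel spectrogram with time as
    # the leading dimension.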
    def get_mel(self, audio_data):
        # mel shape: (n_mels, T)
        mel_tf = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=1024,
            win_length=1024,
            hop_length=160,
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm=None,
            onesided=True,
            n_mels=64,
            f_min=50,
            f_max=14000
        ).to(audio_data.device)

        mel = mel_tf(audio_data)

        # we use log mel spectrogram as input
        mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)

        return mel.T  # (T, n_mels)
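
    # get_audio_features follows the LAION-CLAP "fusion" recipe: clips longer than max_len
    # are represented by a 4-channel mel_fusion stack (one globally resized mel plus three
    # randomly placed chunks), while shorter clips are filled to max_len and the same mel is
    # replicated four times; the "longer" flag records which case applied.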
    def get_audio_features(self, sample, audio_data, max_len, data_truncating, data_filling, require_grad=False):
        grad_fn = suppress if require_grad else torch.no_grad
        with grad_fn():
            if len(audio_data) > max_len:
                if data_truncating == "rand_trunc":
                    longer = torch.tensor([True])
                elif data_truncating == "fusion":
                    # fusion
                    mel = self.get_mel(audio_data)
                    # split to three parts
                    chunk_frames = max_len // 160 + 1  # the +1 related to how the spectrogram is computed
                    total_frames = mel.shape[0]
                    if chunk_frames == total_frames:
                        # there is a corner case where the audio length is
                        # larger than max_len but smaller than max_len+hop_size.
                        # In this case, we just use the whole audio.
                        mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                        sample["mel_fusion"] = mel_fusion
                        longer = torch.tensor([False])
                    else:
                        ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
                        if len(ranges[1]) == 0:
                            # if the audio is too short, we just use the first chunk
                            ranges[1] = [0]
                        if len(ranges[2]) == 0:
                            # if the audio is too short, we just use the first chunk
                            ranges[2] = [0]
                        # randomly choose index for each part
                        idx_front = np.random.choice(ranges[0])
                        idx_middle = np.random.choice(ranges[1])
                        idx_back = np.random.choice(ranges[2])
                        # select mel
                        mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
                        mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
                        mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
                        # shrink the mel
                        mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(mel[None])[0]
                        # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")

                        # stack
                        mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
                        sample["mel_fusion"] = mel_fusion
                        longer = torch.tensor([True])
                else:
                    raise NotImplementedError(
                        f"data_truncating {data_truncating} not implemented"
                    )
                # random crop to max_len (for compatibility)
                overflow = len(audio_data) - max_len
                idx = np.random.randint(0, overflow + 1)
                audio_data = audio_data[idx: idx + max_len]

            else:  # padding if too short
                if len(audio_data) < max_len:  # do nothing if equal
                    if data_filling == "repeatpad":
                        n_repeat = int(max_len / len(audio_data))
                        audio_data = audio_data.repeat(n_repeat)
                        # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
                        # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
                        audio_data = F.pad(
                            audio_data,
                            (0, max_len - len(audio_data)),
                            mode="constant",
                            value=0,
                        )
                    elif data_filling == "pad":
                        audio_data = F.pad(
                            audio_data,
                            (0, max_len - len(audio_data)),
                            mode="constant",
                            value=0,
                        )
                    elif data_filling == "repeat":
                        n_repeat = int(max_len / len(audio_data))
                        audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                    else:
                        raise NotImplementedError(
                            f"data_filling {data_filling} not implemented"
                        )
                if data_truncating == 'fusion':
                    mel = self.get_mel(audio_data)
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                longer = torch.tensor([False])

        sample["longer"] = longer
        sample["waveform"] = audio_data

        return sample
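
    # load_audio quantizes each clip to int16 precision and back, extracts features with a
    # fixed 160000-sample window (10 s at 16 kHz) using "fusion" truncation and "repeatpad"
    # filling, then stacks the per-clip dicts into batched tensors.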
    def load_audio(self, clips):

        # waveform, sr = torchaudio.load(filename)
        # waveform = torchaudio.functional.resample(waveform, orig_freq=self.clap_config['sampling_rate'], new_freq=16000)

        processed_clips = []
        for clip in clips:
            audio_data = int16_to_float32_torch(float32_to_int16_torch(clip))
            sample = self.get_audio_features({}, audio_data, 160000, "fusion", "repeatpad")
            processed_clips.append(sample)

        waveforms = {}
        waveforms["mel_fusion"] = torch.stack([item["mel_fusion"] for item in processed_clips], dim=0)
        waveforms["longer"] = torch.stack([item["longer"] for item in processed_clips], dim=0)
        waveforms["waveform"] = torch.stack([item["waveform"] for item in processed_clips], dim=0)

        return waveforms

    def forward(self, audio_clips):

        # It will handle various segments, 1 audio will have various segments [B X n_segments X time]
        # expand batch dimension during inference
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip)
            audio_embed = self.nvclap(audio)  # .reshape(-1, self.clap_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # audio_embeds.requires_grad = self.finetune

        return audio_embeds
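
# Whisper and MERT below are alternative audio encoders exposing the same interface as CLAP:
# forward() accepts [batch x n_segments x time] waveforms and returns per-segment embeddings.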
class Whisper(nn.Module):
    def __init__(self, whisper_config):
        super(Whisper, self).__init__()
        self.whisper_config = whisper_config
        self.method = self.whisper_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in self.whisper_config) and self.whisper_config['finetune']:
            self.finetune = True
            print('Finetuning Whisper encoder as well!')
        else:
            self.finetune = False

        self.whisper = WhisperModel.from_pretrained(self.whisper_config['path']).encoder
        self.whisper = self.whisper.to(device_id)
        self.wav_processor = WhisperFeatureExtractor.from_pretrained(self.whisper_config['path'])

        for param in self.whisper.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.whisper.train()
        else:
            self.whisper.eval()

        print('loaded Whisper model: {}'.format(self.whisper_config['path']))

    def load_audio(self, clips):
        device_id = f'cuda:{torch.cuda.current_device()}'
        sample = self.wav_processor(clips.cpu().numpy(), sampling_rate=self.whisper_config['sampling_rate'], return_tensors="pt")["input_features"].to(device_id)
        return sample

    def forward(self, audio_clips):

        # It will handle various segments, 1 audio will have various segments [batch X n_segments X time]
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip)
            audio_embed = self.whisper(audio).last_hidden_state  # .reshape(-1, self.whisper_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # audio_embeds.requires_grad = self.finetune

        return audio_embeds
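
# MERT is a self-supervised music encoder loaded with trust_remote_code; its load_audio
# resamples waveforms from 16 kHz to mert_config['sampling_rate'] before Wav2Vec2-style
# feature extraction.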
class MERT(nn.Module):
    def __init__(self, mert_config):
        super(MERT, self).__init__()
        self.mert_config = mert_config
        self.method = mert_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in mert_config) and mert_config['finetune']:
            self.finetune = True
            print('Finetuning MERT encoder as well!')
        else:
            self.finetune = False

        self.mert = AutoModel.from_pretrained(mert_config['path'], trust_remote_code=True)
        self.mert = self.mert.to(device_id)
        self.resampler = T.Resample(16000, mert_config['sampling_rate']).to(device_id)
        self.wav_processor = Wav2Vec2FeatureExtractor.from_pretrained(mert_config['path'], trust_remote_code=True)

        for param in self.mert.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.mert.train()
        else:
            self.mert.eval()

        print('loaded MERT model: {}'.format(mert_config['path']))

    def load_audio(self, clips):
        device_id = f'cuda:{torch.cuda.current_device()}'
        clips = self.resampler(clips.float()).float()
        sample = self.wav_processor(clips, sampling_rate=self.mert_config['sampling_rate'], return_tensors="pt")["input_values"]
        if len(sample.shape) == 1:
            sample = sample.unsqueeze(0)
        return sample.to(device_id)

    def forward(self, audio_clips):

        # It will handle various segments, 1 audio will have various segments [batch X n_segments X time]
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip).to(torch.bfloat16)  # all processing happens in float
            if len(audio.shape) > 2:
                audio = audio.squeeze(0)
            audio_embed = self.mert(audio, output_hidden_states=True).last_hidden_state  # .reshape(-1, self.mert_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        audio_embeds.requires_grad = self.finetune

        return audio_embeds
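
# A minimal sketch of the kind of clap_config this factory expects. The keys are inferred
# from the code above; the values shown are illustrative assumptions, not project defaults:
#
#   clap_config = {
#       "method": "nvclap",                 # assumed value; stored but not branched on here
#       "checkpoint": "/path/to/clap.pt",   # HTS-AT audio-branch checkpoint (module.audio_branch.*)
#       "audio_embed_dim": 2048,            # assumed value; passed to Flamingo as clap_embed_dim
#       "finetune": False,                  # True also unfreezes the CLAP encoder
#   }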
def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>", "<|PAD_TOKEN|>"]}
    )
    text_tokenizer.pad_token = None
    text_tokenizer.pad_token_id = None
    text_tokenizer.pad_token = "<|PAD_TOKEN|>"
    text_tokenizer.pad_token_id = text_tokenizer.encode("<|PAD_TOKEN|>")[-1]
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    if ('finetune' in clap_config) and clap_config['finetune']:
        unfreeze_clap = True
    else:
        unfreeze_clap = False

    model = Flamingo(
        clap,
        unfreeze_clap,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<audio>")[-1],
        text_tokenizer.sep_token_id,
        clap_embed_dim=clap_config["audio_embed_dim"],
        audio_transformer_kwargs=audio_transformer_kwargs,
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    model.audio_transformer_clap.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers_sound.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
    if unfreeze_full_lm:
        model.lang_encoder.requires_grad_(True)
    if unfreeze_clap:
        model.clap.requires_grad_(True)

    print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad),
        sum(p.numel() for p in model.audio_transformer_clap.parameters() if p.requires_grad),
        sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad),
    ))

    return model, text_tokenizer

def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
    "qwen": "model.layers",
}
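
# Smoke test: build the model from a YAML config, iterate a few training batches, and
# reproduce the label masking used for training, where only tokens between <SEP> and
# <|endofchunk|> (plus EOS tokens) contribute to the loss.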
if __name__ == '__main__':
    import torch
    torch.set_printoptions(profile="full")  # only in debug mode

    import sys
    sys.path.append('../')
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    import yaml
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../configs/config.yaml', help='yaml config path')
    args = parser.parse_args()

    config = yaml.load(open(args.config), Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config["model_config"]
    clap_config = config["clap_config"]

    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=True,
        freeze_lm_embeddings=True
    )
    model = model.cuda()

    from data.data import AudioTextData, DataCollator
    from torch.utils.data import DataLoader

    batch_size = 8
    trainset = AudioTextData(
        **data_config, clap_config=clap_config, tokenizer=tokenizer,
        epoch=1, force_reblend=True
    )
    data_collator = DataCollator(tokenizer)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, num_workers=4)

    for step, batch in enumerate(trainloader):
        audio_clips = batch["audio_clips"].cuda()
        audio_embed_mask = batch["audio_embed_mask"].cuda()
        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()
        print('batch {}:'.format(step + 1), audio_clips.shape, audio_embed_mask.shape, input_ids.shape, attention_mask.shape)

        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, :2] = -100
        labels[labels == tokenizer.encode("<audio>")[-1]] = -100

        sep_locations = labels == tokenizer.sep_token_id
        endofchunk_token_id = tokenizer.encode("<|endofchunk|>")[-1]
        eoc_locations = labels == endofchunk_token_id

        if not all(sep_locations.sum(dim=1) == eoc_locations.sum(dim=1)):
            print("Warning: sep loc {} but eoc loc {}".format(sep_locations.sum(dim=1), eoc_locations.sum(dim=1)))
            for input_id in labels:
                input_id[input_id == -100] = tokenizer.encode("-")[-1]
                print(input_id, '\n', tokenizer.decode(input_id))

        for i in range(labels.shape[0]):
            shouldmask = True
            for j in range(labels.shape[1]):
                if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
                    masked_value = -100
                else:
                    masked_value = labels[i][j]

                if labels[i][j] == tokenizer.sep_token_id:
                    shouldmask = False
                elif labels[i][j] == endofchunk_token_id:
                    shouldmask = True

                labels[i][j] = masked_value

            if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
                debug_masked_labels_in_the_end = []
                for j in range(labels.shape[1] - 1, -1, -1):
                    if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
                        debug_masked_labels_in_the_end.insert(0, deepcopy(labels[i][j].item()))
                        labels[i][j] = -100
                    else:
                        break
                print('hit max_token and masking ids from the end:',
                      tokenizer.decode(torch.LongTensor(debug_masked_labels_in_the_end).to(labels.device)))

        if step == 50:
            break