Upload 340 files
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +2 -0
- app.py +44 -0
- audiocaps_test_16000_struct.tsv +0 -0
- configs/audiolcm.yaml +130 -0
- configs/autoencoder1d.yaml +74 -0
- configs/teacher.yaml +121 -0
- infer.sh +4 -0
- infer_api.sh +4 -0
- ldm/__pycache__/lr_scheduler.cpython-37.pyc +0 -0
- ldm/__pycache__/lr_scheduler.cpython-38.pyc +0 -0
- ldm/__pycache__/util.cpython-310.pyc +0 -0
- ldm/__pycache__/util.cpython-37.pyc +0 -0
- ldm/__pycache__/util.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_624.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_anylen.cpython-37.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct_anylen.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-37.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc +0 -0
- ldm/data/__pycache__/tsvdataset.cpython-38.pyc +0 -0
- ldm/data/joinaudiodataset_624.py +93 -0
- ldm/data/joinaudiodataset_anylen.py +331 -0
- ldm/data/joinaudiodataset_struct.py +95 -0
- ldm/data/joinaudiodataset_struct_anylen.py +336 -0
- ldm/data/joinaudiodataset_struct_sample.py +103 -0
- ldm/data/joinaudiodataset_struct_sample_anylen.py +230 -0
- ldm/data/preprocess/NAT_mel.py +131 -0
- ldm/data/preprocess/__pycache__/NAT_mel.cpython-38.pyc +0 -0
- ldm/data/preprocess/__pycache__/NAT_mel.cpython-39.pyc +0 -0
- ldm/data/preprocess/add_duration.py +45 -0
- ldm/data/preprocess/mel_spec.py +201 -0
- ldm/data/test.py +224 -0
- ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/MACS.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/adobe.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/audiostock.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv +3 -0
- ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv +0 -0
- ldm/data/tsv_dirs/full_data/clotho.tsv +0 -0
- ldm/data/tsvdataset.py +67 -0
- ldm/lr_scheduler.py +98 -0
- ldm/models/__pycache__/autoencoder.cpython-37.pyc +0 -0
- ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
- ldm/models/__pycache__/autoencoder.cpython-39.pyc +0 -0
- ldm/models/__pycache__/autoencoder1d.cpython-37.pyc +0 -0
- ldm/models/__pycache__/autoencoder1d.cpython-38.pyc +0 -0
- ldm/models/__pycache__/autoencoder_multi.cpython-38.pyc +0 -0
- ldm/models/autoencoder.py +504 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv filter=lfs diff=lfs merge=lfs -text
+vocoder/BigVGAN/LibriTTS/train-full.txt filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,44 @@
import gradio

def infer(prompt):
    config = OmegaConf.load("configs/audiolcm.yaml")

    # print("-------quick debug no load ckpt---------")
    # model = instantiate_from_config(config['model'])  # for quick debug
    model = load_model_from_config(config,
                                   "../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = LCMSampler(model)

    os.makedirs("results/test", exist_ok=True)

    vocoder = VocoderBigVGAN("../vocoder/logs/bigvnat16k93.5w", device)

    generator = GenSamples(sampler, model, "results/test", vocoder, save_mel=False, save_wav=True,
                           original_inference_steps=config.model.params.num_ddim_timesteps)
    csv_dicts = []

    with torch.no_grad():
        with model.ema_scope():
            wav_name = f'{prompt.strip().replace(" ", "-")}'
            generator.gen_test_sample(prompt, wav_name=wav_name)

    print(f"Your samples are ready and waiting for you here: \nresults/test \nEnjoy.")


def my_inference_function(prompt_oir):
    prompt = {'ori_caption': prompt_oir, 'struct_caption': prompt_oir}
    file_path = infer(prompt)
    return "test.wav"


gradio_interface = gradio.Interface(
    fn=my_inference_function,
    inputs="text",
    outputs="audio"
)
gradio_interface.launch()
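Note that, as uploaded, app.py only imports gradio while its body also uses os, torch, OmegaConf, load_model_from_config, LCMSampler, GenSamples, and VocoderBigVGAN. A minimal sketch of the header it appears to assume is shown below; apart from vocoder.bigvgan.models.VocoderBigVGAN (whose module path appears in configs/audiolcm.yaml), the import paths are guesses and are not part of this commit.

# Hypothetical import block for app.py -- module paths are assumptions, not taken from this diff.
import os
import torch
from omegaconf import OmegaConf

# load_model_from_config, GenSamples and LCMSampler are expected to come from the
# repository's own inference/sampling code; the paths below are illustrative only.
# from scripts.txt2audio_for_lcm import load_model_from_config, GenSamples
# from ldm.models.diffusion.lcm_sampler import LCMSampler
from vocoder.bigvgan.models import VocoderBigVGAN  # path confirmed by configs/audiolcm.yaml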
audiocaps_test_16000_struct.tsv
ADDED
The diff for this file is too large to render.
configs/audiolcm.yaml
ADDED
@@ -0,0 +1,130 @@
model:
  base_learning_rate: 3.0e-06
  target: ldm.models.diffusion.lcm_audio.LCM_audio
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    mel_dim: 20
    mel_length: 312
    channels: 0
    cond_stage_trainable: False
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_by_std: true
    use_lcm: True
    num_ddim_timesteps: 50
    w_min: 4
    w_max: 12
    ckpt_path: ../ckpt/maa2.ckpt

    use_ema: false
    scheduler_config:
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps:
        - 10000
        cycle_lengths:
        - 10000000000000
        f_start:
        - 1.0e-06
        f_max:
        - 1.0
        f_min:
        - 1.0
    unet_config:
      target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP
      params:
        in_channels: 20
        context_dim: 1024
        hidden_size: 576
        num_heads: 8
        depth: 4
        max_len: 1000
    first_stage_config:
      target: ldm.models.autoencoder1d.AutoencoderKL
      params:
        embed_dim: 20
        monitor: val/rec_loss
        ckpt_path: ./model/AutoencoderKL/epoch=000032.ckpt
        ddconfig:
          double_z: true
          in_channels: 80
          out_ch: 80
          z_channels: 20
          kernel_size: 5
          ch: 384
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_layers:
          - 3
          down_layers:
          - 0
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder
      params:
        weights_path: ./model/FrozenCLAPFLANEmbedder/CLAP_weights_2022.pth

lightning:
  callbacks:
    image_logger:
      target: main.AudioLogger
      params:
        sample_rate: 16000
        for_specs: true
        increase_log_steps: false
        batch_frequency: 5000
        max_images: 8
        melvmin: -5
        melvmax: 1.5
        vocoder_cfg:
          target: vocoder.bigvgan.models.VocoderBigVGAN
          params:
            ckpt_vocoder: ./vocoder/logs/bigvnat16k93.5w
  trainer:
    benchmark: True
    gradient_clip_val: 1.0
    replace_sampler_ddp: false
    max_epochs: 100
  modelcheckpoint:
    params:
      monitor: epoch
      mode: max
      # every_n_train_steps: 2000
      save_top_k: 100
      every_n_epochs: 3


data:
  target: main.SpectrogramDataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 32
    spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct'
    mel_num: 80
    train:
      target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsTrain
      params:
        specs_dataset_cfg:
    validation:
      target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsValidation
      params:
        specs_dataset_cfg:

    test_dataset:
      target: ldm.data.tsvdataset.TSVDatasetStruct
      params:
        tsv_path: audiocaps_test_16000_struct.tsv
        spec_crop_len: 624
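Configs of this shape are consumed by loading them with OmegaConf and instantiating the `target` class with its `params`, which is what app.py above does for the model. A minimal sketch of that pattern; `instantiate_from_config` is assumed to live in ldm.util, as the commented-out quick-debug line in app.py suggests, but its exact location is not confirmed by this diff.

from omegaconf import OmegaConf

config = OmegaConf.load("configs/audiolcm.yaml")
print(config.model.target)                     # ldm.models.diffusion.lcm_audio.LCM_audio
print(config.model.params.num_ddim_timesteps)  # 50, also passed to GenSamples in app.py

# Assumed repository helper (referenced but not imported in app.py):
# from ldm.util import instantiate_from_config
# model = instantiate_from_config(config.model)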
configs/autoencoder1d.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder1d.AutoencoderKL
  params:
    embed_dim: 20
    monitor: val/rec_loss
    ddconfig:
      double_z: true
      in_channels: 80
      out_ch: 80
      z_channels: 20
      kernel_size: 5
      ch: 384
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_layers:
      - 3
      down_layers:
      - 0
      dropout: 0.0
    lossconfig:
      target: ldm.modules.losses_audio.contperceptual.LPAPSWithDiscriminator
      params:
        disc_start: 80001
        perceptual_weight: 0.0
        kl_weight: 1.0e-06
        disc_weight: 0.5
        disc_in_channels: 1
        disc_loss: mse
        disc_factor: 2
        disc_conditional: false
        r1_reg_weight: 3

lightning:
  callbacks:
    image_logger:
      target: main.AudioLogger
      params:
        for_specs: true
        increase_log_steps: false
        batch_frequency: 5000
        max_images: 8
        rescale: false
        melvmin: -5
        melvmax: 1.5
        vocoder_cfg:
          target: vocoder.bigvgan.models.VocoderBigVGAN
          params:
            ckpt_vocoder: vocoder/logs/bigvnat16k93.5w
  trainer:
    sync_batchnorm: false # not working with r1_regularization
    strategy: ddp


data:
  target: main.SpectrogramDataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 16
    spec_dir_path: ldm/data/tsv_dirs/full_data/V1_new
    mel_num: 80
    spec_len: 624
    spec_crop_len: 624
    train:
      target: ldm.data.joinaudiodataset_624.JoinSpecsTrain
      params:
        specs_dataset_cfg: null
    validation:
      target: ldm.data.joinaudiodataset_624.JoinSpecsValidation
      params:
        specs_dataset_cfg: null
configs/teacher.yaml
ADDED
@@ -0,0 +1,121 @@
model:
  base_learning_rate: 3.0e-06
  target: ldm.models.diffusion.ddpm_audio.LatentDiffusion_audio
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    mel_dim: 20
    mel_length: 312
    channels: 0
    cond_stage_trainable: True
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_by_std: true
    use_ema: false
    scheduler_config:
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps:
        - 10000
        cycle_lengths:
        - 10000000000000
        f_start:
        - 1.0e-06
        f_max:
        - 1.0
        f_min:
        - 1.0
    unet_config:
      target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP
      params:
        in_channels: 20
        context_dim: 1024
        hidden_size: 576
        num_heads: 8
        depth: 4
        max_len: 1000
    first_stage_config:
      target: ldm.models.autoencoder1d.AutoencoderKL
      params:
        embed_dim: 20
        monitor: val/rec_loss
        ckpt_path: logs/trainae/ckpt/epoch=000032.ckpt
        ddconfig:
          double_z: true
          in_channels: 80
          out_ch: 80
          z_channels: 20
          kernel_size: 5
          ch: 384
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_layers:
          - 3
          down_layers:
          - 0
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder
      params:
        weights_path: useful_ckpts/CLAP/CLAP_weights_2022.pth

lightning:
  callbacks:
    image_logger:
      target: main.AudioLogger
      params:
        sample_rate: 16000
        for_specs: true
        increase_log_steps: false
        batch_frequency: 5000
        max_images: 8
        melvmin: -5
        melvmax: 1.5
        vocoder_cfg:
          target: vocoder.bigvgan.models.VocoderBigVGAN
          params:
            ckpt_vocoder: vocoder/logs/bigvnat16k93.5w
  trainer:
    benchmark: True
    gradient_clip_val: 1.0
    replace_sampler_ddp: false
  modelcheckpoint:
    params:
      monitor: epoch
      mode: max
      save_top_k: 10
      every_n_epochs: 5

data:
  target: main.SpectrogramDataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 32
    main_spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct'
    other_spec_dir_path: 'ldm/data/tsv_dirs/full_data/V2'
    mel_num: 80
    train:
      target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsTrain
      params:
        specs_dataset_cfg:
    validation:
      target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsValidation
      params:
        specs_dataset_cfg:

    test_dataset:
      target: ldm.data.tsvdataset.TSVDatasetStruct
      params:
        tsv_path: musiccap.tsv
        spec_crop_len: 624
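teacher.yaml mirrors audiolcm.yaml (same ConcatDiT2MLP backbone, AutoencoderKL first stage, and CLAP/FLAN conditioner) but targets the non-distilled LatentDiffusion_audio model, makes the conditioner trainable, and omits the LCM-specific keys (use_lcm, num_ddim_timesteps, w_min, w_max, ckpt_path). A small sketch for checking such differences between the two configs; it only assumes both YAML files are present at these paths.

from omegaconf import OmegaConf

teacher = OmegaConf.load("configs/teacher.yaml")
student = OmegaConf.load("configs/audiolcm.yaml")

teacher_keys = set(teacher.model.params.keys())
student_keys = set(student.model.params.keys())
print("only in audiolcm.yaml:", sorted(student_keys - teacher_keys))  # expect the LCM-specific keys
print("only in teacher.yaml: ", sorted(teacher_keys - student_keys))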
infer.sh
ADDED
@@ -0,0 +1,4 @@
CUDA_VISIBLE_DEVICES='1' python scripts/txt2audio_for_lcm.py --n_samples 1 --n_iter 1 --scale 5 --H 20 --W 312 \
--ddim_steps 2 -b configs/audiolcm.yaml \
--sample_rate 16000 --vocoder-ckpt ../vocoder/logs/bigvnat16k93.5w \
--outdir results/test --test-dataset audiocaps -r ../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt
infer_api.sh
ADDED
@@ -0,0 +1,4 @@
CUDA_VISIBLE_DEVICES='1' python scripts/txt2audio_for_lcm.py --n_samples 1 --n_iter 1 --scale 5 --H 20 --W 312 \
--ddim_steps 2 -b configs/audiolcm.yaml \
--sample_rate 16000 --vocoder-ckpt ../vocoder/logs/bigvnat16k93.5w \
--outdir results/test -r ../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt --prompt_txt ./prompt.txt
ldm/__pycache__/lr_scheduler.cpython-37.pyc
ADDED
Binary file (3.66 kB).

ldm/__pycache__/lr_scheduler.cpython-38.pyc
ADDED
Binary file (3.61 kB).

ldm/__pycache__/util.cpython-310.pyc
ADDED
Binary file (8.36 kB).

ldm/__pycache__/util.cpython-37.pyc
ADDED
Binary file (5.1 kB).

ldm/__pycache__/util.cpython-38.pyc
ADDED
Binary file (8.3 kB).

ldm/data/__pycache__/joinaudiodataset_624.cpython-38.pyc
ADDED
Binary file (3.62 kB).

ldm/data/__pycache__/joinaudiodataset_anylen.cpython-37.pyc
ADDED
Binary file (12.4 kB).

ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc
ADDED
Binary file (12.1 kB).

ldm/data/__pycache__/joinaudiodataset_struct.cpython-38.pyc
ADDED
Binary file (3.69 kB).

ldm/data/__pycache__/joinaudiodataset_struct_anylen.cpython-38.pyc
ADDED
Binary file (12.5 kB).

ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-37.pyc
ADDED
Binary file (8.29 kB).

ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc
ADDED
Binary file (8.09 kB).

ldm/data/__pycache__/tsvdataset.cpython-38.pyc
ADDED
Binary file (2.66 kB).
ldm/data/joinaudiodataset_624.py
ADDED
@@ -0,0 +1,93 @@
import sys
import numpy as np
import torch
import logging
import pandas as pd
import glob
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None, drop=0, **kwargs):
        super().__init__()
        self.split = split
        self.batch_max_length = spec_crop_len
        self.batch_min_length = 50
        self.mel_num = mel_num
        self.drop = drop
        manifest_files = []
        for dir_path in spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/**/*.tsv', recursive=True)
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        df = pd.concat(df_list, ignore_index=True)

        if split == 'train':
            self.dataset = df.iloc[100:]
        elif split == 'valid' or split == 'val':
            self.dataset = df.iloc[:100]
        elif split == 'test':
            df = self.add_name_num(df)
            self.dataset = df
        else:
            raise ValueError(f'Unknown split {split}')
        self.dataset.reset_index(inplace=True)
        print('dataset len:', len(self.dataset))

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = df.loc[t[0], 'name'] + f'_{t[1]}'
        return df

    def __getitem__(self, idx):
        data = self.dataset.iloc[idx]
        item = {}
        try:
            spec = np.load(data['mel_path'])  # mel spec [80, 624]
        except:
            mel_path = data['mel_path']
            print(f'corrupted:{mel_path}')
            spec = np.zeros((self.mel_num, self.batch_max_length)).astype(np.float32)

        if spec.shape[1] < self.batch_max_length:
            # spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])))  # [80, 624]
            spec = np.tile(spec, reps=(self.batch_max_length // spec.shape[1]) + 1)

        item['image'] = spec[:, :self.batch_max_length]
        p = np.random.uniform(0, 1)
        if p > self.drop:
            item["caption"] = data['caption']
        else:
            item["caption"] = ""
        if self.split == 'test':
            item['f_name'] = data['name']
        return item

    def __len__(self):
        return len(self.dataset)


class JoinSpecsTrain(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
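The JoinSpecs* wrappers are constructed from the `specs_dataset_cfg` dict that main.SpectrogramDataModuleFromConfig passes through (see the data section of configs/autoencoder1d.yaml). A hedged usage sketch is below; it assumes the TSV manifests and the referenced mel .npy files exist locally, and the field names shown here mirror the data.params block of that config (the TSVs are inferred from __getitem__ to carry at least `mel_path` and `caption` columns, plus `name` for the test split).

from ldm.data.joinaudiodataset_624 import JoinSpecsTrain

cfg = {
    "spec_dir_path": "ldm/data/tsv_dirs/full_data/V1_new",  # one or more TSV dirs, comma-separated
    "mel_num": 80,
    "spec_crop_len": 624,
}
dataset = JoinSpecsTrain(cfg)
item = dataset[0]
print(item["image"].shape)  # (80, 624) mel crop, tiled if the clip is shorter
print(item["caption"])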
ldm/data/joinaudiodataset_anylen.py
ADDED
@@ -0,0 +1,331 @@
import os
import sys
import math
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from torch.utils.data.distributed import DistributedSampler
import torch.distributed
from typing import TypeVar, Optional, Iterator, List
import logging
import pandas as pd
import glob
import torch.distributed as dist
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    def __init__(self, split, spec_dir_path, mel_num=80, spec_crop_len=1248, mode='pad', pad_value=-5, drop=0, **kwargs):
        super().__init__()
        self.split = split
        self.max_batch_len = spec_crop_len
        self.min_batch_len = 64
        self.mel_num = mel_num
        self.min_factor = 4
        self.drop = drop
        self.pad_value = pad_value
        assert mode in ['pad', 'tile']
        self.collate_mode = mode
        # print(f"################# self.collate_mode {self.collate_mode} ##################")

        manifest_files = []
        for dir_path in spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        df = pd.concat(df_list, ignore_index=True)

        if split == 'train':
            self.dataset = df.iloc[100:]
        elif split == 'valid' or split == 'val':
            self.dataset = df.iloc[:100]
        elif split == 'test':
            df = self.add_name_num(df)
            self.dataset = df
        else:
            raise ValueError(f'Unknown split {split}')
        self.dataset.reset_index(inplace=True)
        print('dataset len:', len(self.dataset))

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = df.loc[t[0], 'name'] + f'_{t[1]}'
        return df

    def ordered_indices(self):
        index2dur = self.dataset[['duration']]
        index2dur = index2dur.sort_values(by='duration')
        return list(index2dur.index)

    def __getitem__(self, idx):
        item = {}
        data = self.dataset.iloc[idx]
        try:
            spec = np.load(data['mel_path'])  # mel spec [80, 624]
        except:
            mel_path = data['mel_path']
            print(f'corrupted:{mel_path}')
            spec = np.ones((self.mel_num, self.min_batch_len)).astype(np.float32) * self.pad_value

        item['image'] = spec
        p = np.random.uniform(0, 1)
        if p > self.drop:
            item["caption"] = data['caption']
        else:
            item["caption"] = ""
        if self.split == 'test':
            item['f_name'] = data['name']
            # item['f_name'] = data['mel_path']
        return item

    def collater(self, inputs):
        to_dict = {}
        for l in inputs:
            for k, v in l.items():
                if k in to_dict:
                    to_dict[k].append(v)
                else:
                    to_dict[k] = [v]
        if self.collate_mode == 'pad':
            to_dict['image'] = collate_1d_or_2d(to_dict['image'], pad_idx=self.pad_value, min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        elif self.collate_mode == 'tile':
            to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'], min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        else:
            raise NotImplementedError

        return to_dict

    def __len__(self):
        return len(self.dataset)


class JoinSpecsTrain(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)

class JoinSpecsDebug(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)
        self.dataset = self.dataset.iloc[:37]

class DDPIndexBatchSampler(Sampler):  # group indices of similar-length audios into one batch to avoid excessive padding
    def __init__(self, indices, batch_size, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                print("Not in distributed mode")
                num_replicas = 1
            else:
                num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                rank = 0
            else:
                rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.indices = indices
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.batch_size = batch_size

        self.batches = self.build_batches()
        print(f"rank: {self.rank}, batches_num {len(self.batches)}")
        # If the dataset length is evenly divisible by replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.batches) % self.num_replicas != 0:
            self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
        if len(self.batches) > self.num_replicas:
            self.batches = self.batches[self.rank::self.num_replicas]
        else:  # may happen in sanity checking
            self.batches = [self.batches[0]]
        print(f"after split batches_num {len(self.batches)}")
        self.shuffle = shuffle
        if self.shuffle:
            self.batches = np.random.permutation(self.batches)
        self.seed = seed

    def set_epoch(self, epoch):
        self.epoch = epoch
        if self.shuffle:
            np.random.seed(self.seed + self.epoch)
            self.batches = np.random.permutation(self.batches)

    def build_batches(self):
        batches, batch = [], []
        for index in self.indices:
            batch.append(index)
            if len(batch) == self.batch_size:
                batches.append(batch)
                batch = []
        if not self.drop_last and len(batch) > 0:
            batches.append(batch)
        return batches

    def __iter__(self) -> Iterator[List[int]]:
        for batch in self.batches:
            yield batch

    def __len__(self) -> int:
        return len(self.batches)

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    if len(values[0].shape) == 1:
        return collate_1d(values, pad_idx, left_pad, shift_right, min_len, max_len, min_factor, shift_id)
    else:
        return collate_2d(values, pad_idx, left_pad, shift_right, min_len, max_len, min_factor)

def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
        if shift_right:
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res


def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None, max_len=None, min_factor=None):
    """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension.
    values[0] shape: (melbins,mel_length)
    """
    size = max(v.shape[1] for v in values)  # if max_len is None else max_len
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)

    if isinstance(values, np.ndarray):
        values = torch.FloatTensor(values)
    if isinstance(values, list):
        values = [torch.FloatTensor(v) for v in values]
    res = torch.ones(len(values), values[0].shape[0], size).to(dtype=torch.float32) * pad_idx

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v[:, :size], res[i][:, size - v.shape[1]:] if left_pad else res[i][:, :v.shape[1]])
    return res


def collate_1d_or_2d_tile(values, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    if len(values[0].shape) == 1:
        return collate_1d_tile(values, shift_right, min_len, max_len, min_factor, shift_id)
    else:
        return collate_2d_tile(values, shift_right, min_len, max_len, min_factor)

def collate_1d_tile(values, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)
    res = values[0].new(len(values), size)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
        if shift_right:
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        n_repeat = math.ceil((size + 1) / v.shape[0])
        v = torch.tile(v, dims=(1, n_repeat))[:size]
        copy_tensor(v, res[i])

    return res


def collate_2d_tile(values, shift_right=False, min_len=None, max_len=None, min_factor=None):
    """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension. """
    size = max(v.shape[1] for v in values)  # if max_len is None else max_len
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)

    if isinstance(values, np.ndarray):
        values = torch.FloatTensor(values)
    if isinstance(values, list):
        values = [torch.FloatTensor(v) for v in values]
    res = torch.zeros(len(values), values[0].shape[0], size).to(dtype=torch.float32)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        n_repeat = math.ceil((size + 1) / v.shape[1])
        v = torch.tile(v, dims=(1, n_repeat))[:, :size]
        copy_tensor(v, res[i])

    return res
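The collater above pads every spectrogram in a batch to a shared length that is clamped to [min_batch_len, max_batch_len] and then rounded up to a multiple of min_factor (4), so the padded mel length stays divisible after downsampling. A self-contained illustration of that size rule with toy tensors; this reimplements only the size computation for clarity, not the repository's full collate_2d.

import torch

def padded_size(lengths, min_len=64, max_len=1248, min_factor=4):
    # Mirrors the clamping and round-up logic used by collate_2d above.
    size = max(lengths)
    size = min(size, max_len)
    size = max(size, min_len)
    if size % min_factor != 0:
        size += min_factor - size % min_factor
    return size

specs = [torch.randn(80, n) for n in (130, 257, 311)]
size = padded_size([s.shape[1] for s in specs])
print(size)  # 312: the longest mel (311) rounded up to a multiple of 4
batch = torch.full((len(specs), 80, size), -5.0)  # -5 is the dataset's pad_value
for i, s in enumerate(specs):
    batch[i, :, :s.shape[1]] = s
print(batch.shape)  # torch.Size([3, 80, 312])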
ldm/data/joinaudiodataset_struct.py
ADDED
@@ -0,0 +1,95 @@
import sys
import numpy as np
import torch
import logging
import pandas as pd
import glob
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None, drop=0, **kwargs):
        super().__init__()
        self.split = split
        self.batch_max_length = spec_crop_len
        self.batch_min_length = 50
        self.drop = drop
        self.mel_num = mel_num

        manifest_files = []
        for dir_path in spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        df = pd.concat(df_list, ignore_index=True)

        if split == 'train':
            self.dataset = df.iloc[100:]
        elif split == 'valid' or split == 'val':
            self.dataset = df.iloc[:100]
        elif split == 'test':
            df = self.add_name_num(df)
            self.dataset = df
        else:
            raise ValueError(f'Unknown split {split}')
        self.dataset.reset_index(inplace=True)
        print('dataset len:', len(self.dataset))

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = df.loc[t[0], 'name'] + f'_{t[1]}'
        return df

    def __getitem__(self, idx):
        data = self.dataset.iloc[idx]
        item = {}
        try:
            spec = np.load(data['mel_path'])  # mel spec [80, 624]
        except:
            mel_path = data['mel_path']
            print(f'corrupted:{mel_path}')
            spec = np.zeros((self.mel_num, self.batch_max_length)).astype(np.float32)

        if spec.shape[1] <= self.batch_max_length:
            spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])))  # [80, 624]

        item['image'] = spec[:self.mel_num, :self.batch_max_length]
        p = np.random.uniform(0, 1)
        if p > self.drop:
            item["caption"] = {"ori_caption": data['ori_cap'], "struct_caption": data['caption']}
        else:
            item["caption"] = {"ori_caption": "", "struct_caption": ""}

        if self.split == 'test':
            item['f_name'] = data['name']
        return item

    def __len__(self):
        return len(self.dataset)


class JoinSpecsTrain(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
ldm/data/joinaudiodataset_struct_anylen.py
ADDED
@@ -0,0 +1,336 @@
import os
import sys
import math
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from torch.utils.data.distributed import DistributedSampler
import torch.distributed
from typing import TypeVar, Optional, Iterator, List
import logging
import pandas as pd
import glob
import torch.distributed as dist
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    def __init__(self, split, spec_dir_path, mel_num=80, spec_crop_len=1248, mode='pad', pad_value=-5, drop=0, **kwargs):
        super().__init__()
        self.split = split
        self.max_batch_len = spec_crop_len
        self.min_batch_len = 64
        self.mel_num = mel_num
        self.min_factor = 4
        self.drop = drop
        self.pad_value = pad_value
        assert mode in ['pad', 'tile']
        self.collate_mode = mode
        # print(f"################# self.collate_mode {self.collate_mode} ##################")

        manifest_files = []
        for dir_path in spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        df = pd.concat(df_list, ignore_index=True)

        if split == 'train':
            self.dataset = df.iloc[100:]
        elif split == 'valid' or split == 'val':
            self.dataset = df.iloc[:100]
        elif split == 'test':
            df = self.add_name_num(df)
            self.dataset = df
        else:
            raise ValueError(f'Unknown split {split}')
        self.dataset.reset_index(inplace=True)
        print('dataset len:', len(self.dataset))

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = df.loc[t[0], 'name'] + f'_{t[1]}'
        return df

    def ordered_indices(self):
        index2dur = self.dataset[['duration']]
        index2dur = index2dur.sort_values(by='duration')
        return list(index2dur.index)

    def __getitem__(self, idx):
        item = {}
        data = self.dataset.iloc[idx]
        try:
            spec = np.load(data['mel_path'])  # mel spec [80, 624]
        except:
            mel_path = data['mel_path']
            print(f'corrupted:{mel_path}')
            spec = np.ones((self.mel_num, self.min_batch_len)).astype(np.float32) * self.pad_value

        item['image'] = spec
        p = np.random.uniform(0, 1)
        if p > self.drop:
            ori_caption = data['caption']
            struct_caption = f'<{ori_caption}& all>'
        else:
            ori_caption = ""
            struct_caption = ""
        item["caption"] = {"ori_caption": ori_caption, "struct_caption": struct_caption}
        if self.split == 'test':
            item['f_name'] = data['name']
            # item['f_name'] = data['mel_path']
        return item

    def collater(self, inputs):
        to_dict = {}
        for l in inputs:
            for k, v in l.items():
                if k in to_dict:
                    to_dict[k].append(v)
                else:
                    to_dict[k] = [v]
        if self.collate_mode == 'pad':
            to_dict['image'] = collate_1d_or_2d(to_dict['image'], pad_idx=self.pad_value, min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        elif self.collate_mode == 'tile':
            to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'], min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        else:
            raise NotImplementedError
        to_dict['caption'] = {'ori_caption': [c['ori_caption'] for c in to_dict['caption']],
                              'struct_caption': [c['struct_caption'] for c in to_dict['caption']]}

        return to_dict

    def __len__(self):
        return len(self.dataset)


class JoinSpecsTrain(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)

class JoinSpecsDebug(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)
        self.dataset = self.dataset.iloc[:37]

class DDPIndexBatchSampler(Sampler):  # group indices of similar-length audios into one batch to avoid excessive padding
    def __init__(self, indices, batch_size, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                print("Not in distributed mode")
                num_replicas = 1
            else:
                num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                rank = 0
            else:
                rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.indices = indices
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.batch_size = batch_size

        self.batches = self.build_batches()
        print(f"rank: {self.rank}, batches_num {len(self.batches)}")
        # If the dataset length is evenly divisible by replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.batches) % self.num_replicas != 0:
            self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
        if len(self.batches) > self.num_replicas:
            self.batches = self.batches[self.rank::self.num_replicas]
        else:  # may happen in sanity checking
            self.batches = [self.batches[0]]
        print(f"after split batches_num {len(self.batches)}")
        self.shuffle = shuffle
        if self.shuffle:
            self.batches = np.random.permutation(self.batches)
        self.seed = seed

    def set_epoch(self, epoch):
        self.epoch = epoch
        if self.shuffle:
            np.random.seed(self.seed + self.epoch)
            self.batches = np.random.permutation(self.batches)

    def build_batches(self):
        batches, batch = [], []
        for index in self.indices:
            batch.append(index)
            if len(batch) == self.batch_size:
                batches.append(batch)
                batch = []
        if not self.drop_last and len(batch) > 0:
            batches.append(batch)
        return batches

    def __iter__(self) -> Iterator[List[int]]:
        for batch in self.batches:
            yield batch

    def __len__(self) -> int:
        return len(self.batches)

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    if len(values[0].shape) == 1:
        return collate_1d(values, pad_idx, left_pad, shift_right, min_len, max_len, min_factor, shift_id)
    else:
        return collate_2d(values, pad_idx, left_pad, shift_right, min_len, max_len, min_factor)

def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
        if shift_right:
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res


def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None, max_len=None, min_factor=None):
    """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension.
    values[0] shape: (melbins,mel_length)
    """
    size = max(v.shape[1] for v in values)  # if max_len is None else max_len
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)

    if isinstance(values, np.ndarray):
        values = torch.FloatTensor(values)
    if isinstance(values, list):
        values = [torch.FloatTensor(v) for v in values]
    res = torch.ones(len(values), values[0].shape[0], size).to(dtype=torch.float32) * pad_idx

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v[:, :size], res[i][:, size - v.shape[1]:] if left_pad else res[i][:, :v.shape[1]])
    return res


def collate_1d_or_2d_tile(values, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    if len(values[0].shape) == 1:
        return collate_1d_tile(values, shift_right, min_len, max_len, min_factor, shift_id)
    else:
        return collate_2d_tile(values, shift_right, min_len, max_len, min_factor)

def collate_1d_tile(values, shift_right=False, min_len=None, max_len=None, min_factor=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)
    res = values[0].new(len(values), size)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
        if shift_right:
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        n_repeat = math.ceil((size + 1) / v.shape[0])
        v = torch.tile(v, dims=(1, n_repeat))[:size]
        copy_tensor(v, res[i])

    return res


def collate_2d_tile(values, shift_right=False, min_len=None, max_len=None, min_factor=None):
    """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension. """
    size = max(v.shape[1] for v in values)  # if max_len is None else max_len
    if max_len:
        size = min(size, max_len)
    if min_len:
        size = max(size, min_len)
    if min_factor and (size % min_factor != 0):  # size must be the multiple of min_factor
        size += (min_factor - size % min_factor)

    if isinstance(values, np.ndarray):
        values = torch.FloatTensor(values)
    if isinstance(values, list):
        values = [torch.FloatTensor(v) for v in values]
    res = torch.zeros(len(values), values[0].shape[0], size).to(dtype=torch.float32)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)
|
| 331 |
+
for i, v in enumerate(values):
|
| 332 |
+
n_repeat = math.ceil((size + 1) / v.shape[1])
|
| 333 |
+
v = torch.tile(v,dims=(1,n_repeat))[:,:size]
|
| 334 |
+
copy_tensor(v, res[i])
|
| 335 |
+
|
| 336 |
+
return res
|
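A minimal usage sketch of the pad-mode collater above (not part of the upload; the tensor sizes and pad value are made up for illustration):

import torch
specs = [torch.randn(80, 600), torch.randn(80, 432)]   # two mel specs of different length
batch = collate_1d_or_2d(specs, pad_idx=-5, min_len=64, max_len=1248, min_factor=4)
print(batch.shape)  # torch.Size([2, 80, 600]); 600 already satisfies min_factor=4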
ldm/data/joinaudiodataset_struct_sample.py
ADDED
|
@@ -0,0 +1,103 @@
import sys
import numpy as np
import torch
import logging
import pandas as pd
import glob
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    def __init__(self, split, main_spec_dir_path, other_spec_dir_path, mel_num=None, spec_crop_len=None, pad_value=-5, **kwargs):
        super().__init__()
        self.main_prob = 0.5
        self.split = split
        self.batch_max_length = spec_crop_len
        self.batch_min_length = 50
        self.mel_num = mel_num
        self.pad_value = pad_value
        manifest_files = []
        for dir_path in main_spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        self.df_main = pd.concat(df_list, ignore_index=True)

        manifest_files = []
        for dir_path in other_spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        self.df_other = pd.concat(df_list, ignore_index=True)

        if split == 'train':
            self.dataset = self.df_main.iloc[100:]
        elif split == 'valid' or split == 'val':
            self.dataset = self.df_main.iloc[:100]
        elif split == 'test':
            self.df_main = self.add_name_num(self.df_main)
            self.dataset = self.df_main
        else:
            raise ValueError(f'Unknown split {split}')
        self.dataset.reset_index(inplace=True)
        print('dataset len:', len(self.dataset))

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = df.loc[t[0], 'name'] + f'_{t[1]}'
        return df

    def __getitem__(self, idx):
        if np.random.uniform(0, 1) < self.main_prob:
            data = self.dataset.iloc[idx]
            ori_caption = data['ori_cap']
            struct_caption = data['caption']
        else:
            randidx = np.random.randint(0, len(self.df_other))
            data = self.df_other.iloc[randidx]
            ori_caption = data['caption']
            struct_caption = f'<{ori_caption}, all>'
        item = {}
        try:
            spec = np.load(data['mel_path'])  # mel spec [80, 624]
        except:
            mel_path = data['mel_path']
            print(f'corrupted:{mel_path}')
            spec = np.ones((self.mel_num, self.batch_max_length)).astype(np.float32) * self.pad_value

        if spec.shape[1] <= self.batch_max_length:
            spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])), mode='constant', constant_values=(self.pad_value, self.pad_value))  # [80, 624]

        item['image'] = spec[:self.mel_num, :self.batch_max_length]
        item["caption"] = {"ori_caption": ori_caption, "struct_caption": struct_caption}
        if self.split == 'test':
            item['f_name'] = data['name']
        return item

    def __len__(self):
        return len(self.dataset)


class JoinSpecsTrain(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
ldm/data/joinaudiodataset_struct_sample_anylen.py
ADDED
|
@@ -0,0 +1,230 @@
import sys
import numpy as np
import torch
from typing import TypeVar, Optional, Iterator
import logging
import pandas as pd
from ldm.data.joinaudiodataset_anylen import *
import glob
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    def __init__(self, split, main_spec_dir_path, other_spec_dir_path, mel_num=80, mode='pad', spec_crop_len=1248, pad_value=-5, drop=0, **kwargs):
        super().__init__()
        self.split = split
        self.max_batch_len = spec_crop_len
        self.min_batch_len = 64
        self.min_factor = 4
        self.mel_num = mel_num
        self.drop = drop
        self.pad_value = pad_value
        assert mode in ['pad', 'tile']
        self.collate_mode = mode
        manifest_files = []

        for dir_path in main_spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        self.df_main = pd.concat(df_list, ignore_index=True)

        manifest_files = []
        for dir_path in other_spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        # import ipdb
        # ipdb.set_trace()
        self.df_other = pd.concat(df_list, ignore_index=True)
        self.df_other.reset_index(inplace=True)

        if split == 'train':
            self.dataset = self.df_main.iloc[100:]
        elif split == 'valid' or split == 'val':
            self.dataset = self.df_main.iloc[:100]
        elif split == 'test':
            self.df_main = self.add_name_num(self.df_main)
            self.dataset = self.df_main
        else:
            raise ValueError(f'Unknown split {split}')
        self.dataset.reset_index(inplace=True)
        print('dataset len:', len(self.dataset), "drop_rate", self.drop)

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = str(df.loc[t[0], 'name']) + f'_{t[1]}'
        return df

    def ordered_indices(self):
        index2dur = self.dataset[['duration']].sort_values(by='duration')
        index2dur_other = self.df_other[['duration']].sort_values(by='duration')
        other_indices = list(index2dur_other.index)
        offset = len(self.dataset)
        other_indices = [x + offset for x in other_indices]
        return list(index2dur.index), other_indices
        # return list(index2dur.index)

    def collater(self, inputs):
        to_dict = {}
        for l in inputs:
            for k, v in l.items():
                if k in to_dict:
                    to_dict[k].append(v)
                else:
                    to_dict[k] = [v]

        if self.collate_mode == 'pad':
            to_dict['image'] = collate_1d_or_2d(to_dict['image'], pad_idx=self.pad_value, min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        elif self.collate_mode == 'tile':
            to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'], min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        else:
            raise NotImplementedError
        to_dict['caption'] = {'ori_caption': [c['ori_caption'] for c in to_dict['caption']],
                              'struct_caption': [c['struct_caption'] for c in to_dict['caption']]}

        return to_dict

    def __getitem__(self, idx):
        if idx < len(self.dataset):
            data = self.dataset.iloc[idx]
            # p = np.random.uniform(0,1)
            # if p > self.drop:
            ori_caption = data['ori_cap']
            struct_caption = data['caption']
            # else:
            #     ori_caption = ""
            #     struct_caption = ""
        else:
            data = self.df_other.iloc[idx - len(self.dataset)]
            # p = np.random.uniform(0,1)
            # if p > self.drop:
            ori_caption = data['caption']
            struct_caption = f'<{ori_caption}& all>'
            # else:
            #     ori_caption = ""
            #     struct_caption = ""
        item = {}
        try:
            spec = np.load(data['mel_path'])  # mel spec [80, T]
            if spec.shape[1] > self.max_batch_len:
                spec = spec[:, :self.max_batch_len]
        except:
            mel_path = data['mel_path']
            print(f'corrupted:{mel_path}')
            spec = np.ones((self.mel_num, self.min_batch_len)).astype(np.float32) * self.pad_value

        item['image'] = spec
        item["caption"] = {"ori_caption": ori_caption, "struct_caption": struct_caption}
        if self.split == 'test':
            item['f_name'] = data['name']
        return item

    def __len__(self):
        return len(self.dataset) + len(self.df_other)
        # return len(self.dataset)


class JoinSpecsTrain(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)


class DDPIndexBatchSampler(Sampler):  # group indices of audios with similar length into one batch to avoid excessive padding
    def __init__(self, main_indices, other_indices, batch_size, num_replicas: Optional[int] = None,
    # def __init__(self, main_indices,batch_size, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                print("Not in distributed mode")
                num_replicas = 1
            else:
                num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                rank = 0
            else:
                rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.main_indices = main_indices
        self.other_indices = other_indices
        self.max_index = max(self.other_indices)
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batches = self.build_batches()
        self.seed = seed

    def set_epoch(self, epoch):
        # print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!")
        self.epoch = epoch
        if self.shuffle:
            np.random.seed(self.seed + self.epoch)
            self.batches = self.build_batches()

    def build_batches(self):
        batches, batch = [], []
        for index in self.main_indices:
            batch.append(index)
            if len(batch) == self.batch_size:
                batches.append(batch)
                batch = []
        if not self.drop_last and len(batch) > 0:
            batches.append(batch)
        selected_others = np.random.choice(len(self.other_indices), len(batches), replace=False)
        for index in selected_others:
            if index + self.batch_size > len(self.other_indices):
                index = len(self.other_indices) - self.batch_size
            batch = [self.other_indices[index + i] for i in range(self.batch_size)]
            batches.append(batch)
        self.batches = batches
        if self.shuffle:
            self.batches = np.random.permutation(self.batches)
        if self.rank == 0:
            print(f"rank: {self.rank}, batches_num {len(self.batches)}")

        if self.drop_last and len(self.batches) % self.num_replicas != 0:
            self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
        if len(self.batches) >= self.num_replicas:
            self.batches = self.batches[self.rank::self.num_replicas]
        else:  # may happen in sanity checking
            self.batches = [self.batches[0]]
        if self.rank == 0:
            print(f"after split batches_num {len(self.batches)}")

        return self.batches

    def __iter__(self) -> Iterator[List[int]]:
        print(f"len(self.batches):{len(self.batches)}")
        for batch in self.batches:
            yield batch

    def __len__(self) -> int:
        return len(self.batches)
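A minimal sketch (not part of the upload) of wiring the duration-sorted indices, the batch sampler, and the dataset's collater into a DataLoader; the directory paths and batch size here are illustrative placeholders:

from torch.utils.data import DataLoader

cfg = dict(main_spec_dir_path='ldm/data/tsv_dirs/full_data/caps_struct',  # illustrative
           other_spec_dir_path='ldm/data/tsv_dirs/full_data/V2',          # illustrative
           mode='pad', spec_crop_len=1248)
train_set = JoinSpecsTrain(cfg)
main_idx, other_idx = train_set.ordered_indices()   # duration-sorted index lists
sampler = DDPIndexBatchSampler(main_idx, other_idx, batch_size=8)
loader = DataLoader(train_set, batch_sampler=sampler, collate_fn=train_set.collater)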
ldm/data/preprocess/NAT_mel.py
ADDED
|
@@ -0,0 +1,131 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read
import torch
import torch.nn as nn

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log10(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log10(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output

class MelNet(nn.Module):
    def __init__(self, hparams, device='cpu') -> None:
        super().__init__()
        self.n_fft = hparams['fft_size']
        self.num_mels = hparams['audio_num_mel_bins']
        self.sampling_rate = hparams['audio_sample_rate']
        self.hop_size = hparams['hop_size']
        self.win_size = hparams['win_size']
        self.fmin = hparams['fmin']
        self.fmax = hparams['fmax']
        self.device = device

        mel = librosa_mel_fn(self.sampling_rate, self.n_fft, self.num_mels, self.fmin, self.fmax)
        self.mel_basis = torch.from_numpy(mel).float().to(self.device)
        self.hann_window = torch.hann_window(self.win_size).to(self.device)

    def to(self, device, **kwagrs):
        super().to(device=device, **kwagrs)
        self.mel_basis = self.mel_basis.to(device)
        self.hann_window = self.hann_window.to(device)
        self.device = device

    def forward(self, y, center=False, complex=False):
        if isinstance(y, np.ndarray):
            y = torch.FloatTensor(y)
        if len(y.shape) == 1:
            y = y.unsqueeze(0)
        y = y.clamp(min=-1., max=1.).to(self.device)

        y = torch.nn.functional.pad(y.unsqueeze(1), [int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)],
                                    mode='reflect')
        y = y.squeeze(1)

        spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
                          center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=complex)

        if not complex:
            spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
            spec = torch.matmul(self.mel_basis, spec)
            spec = spectral_normalize_torch(spec)
        else:
            B, C, T, _ = spec.shape
            spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
        return spec

## below can be used in one gpu, but not ddp
mel_basis = {}
hann_window = {}


def mel_spectrogram(y, hparams, center=False, complex=False):  # y should be a tensor with shape (b,wav_len)
    # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
    # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
    # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    # fmax: 10000 # To be increased/reduced depending on data.
    # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
    # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
    n_fft = hparams['fft_size']
    num_mels = hparams['audio_num_mel_bins']
    sampling_rate = hparams['audio_sample_rate']
    hop_size = hparams['hop_size']
    win_size = hparams['win_size']
    fmin = hparams['fmin']
    fmax = hparams['fmax']
    if isinstance(y, np.ndarray):
        y = torch.FloatTensor(y)
    if len(y.shape) == 1:
        y = y.unsqueeze(0)
    y = y.clamp(min=-1., max=1.)
    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), [int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)],
                                mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=complex)

    if not complex:
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
        spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
        spec = spectral_normalize_torch(spec)
    else:
        B, C, T, _ = spec.shape
        spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
    return spec
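A short sketch (not part of the upload) of computing a log-mel with MelNet, reusing the hparams that mel_spec.py below passes to it; the input here is just one second of silence:

import torch
hparams = {'fft_size': 1024, 'win_size': 1024, 'hop_size': 256, 'audio_num_mel_bins': 80,
           'audio_sample_rate': 16000, 'fmin': 0, 'fmax': 8000}
mel_net = MelNet(hparams)            # CPU by default; call mel_net.to(device) for GPU
wav = torch.zeros(1, 16000)          # (batch, wav_len), values expected in [-1, 1]
mel = mel_net(wav)                   # -> (1, 80, T) log10-mel, T ~ wav_len / hop_size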
ldm/data/preprocess/__pycache__/NAT_mel.cpython-38.pyc
ADDED
Binary file (4.25 kB). View file

ldm/data/preprocess/__pycache__/NAT_mel.cpython-39.pyc
ADDED
Binary file (4.23 kB). View file
ldm/data/preprocess/add_duration.py
ADDED
|
@@ -0,0 +1,45 @@
import pandas as pd
import audioread
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

def map_duration(tsv_withdur, tsv_toadd):  # tsv_withdur and tsv_toadd share the same 'name' column and tsv_withdur carries duration info; copy the duration onto the matching rows of tsv_toadd
    df1 = pd.read_csv(tsv_withdur, sep='\t')
    df2 = pd.read_csv(tsv_toadd, sep='\t')

    df = df2.merge(df1, on=['name'], suffixes=['', '_y'])
    dropset = list(set(df.columns) - set(df1.columns))
    df = df.drop(dropset, axis=1)
    df.to_csv(tsv_toadd, sep='\t', index=False)
    return df

def add_duration(args):
    index, audiopath = args
    try:
        with audioread.audio_open(audiopath) as f:
            totalsec = f.duration
    except:
        totalsec = -1
    return (index, totalsec)

def add_dur2tsv(tsv_path, save_path):
    df = pd.read_csv(tsv_path, sep='\t')
    item_list = []
    for item in tqdm(df.itertuples()):
        item_list.append((item[0], getattr(item, 'audio_path')))

    r = process_map(add_duration, item_list, max_workers=16, chunksize=32)
    index2dur = {}
    for index, dur in r:
        if dur == -1:
            bad_wav = df.loc[index, 'audio_path']
            print(f'bad wav:{bad_wav}')
        index2dur[index] = dur

    df['duration'] = df.index.map(index2dur)
    df.to_csv(save_path, sep='\t', index=False)

if __name__ == '__main__':
    add_dur2tsv('/root/autodl-tmp/liuhuadai/AudioLCM/now.tsv', '/root/autodl-tmp/liuhuadai/AudioLCM/now_duration.tsv')
    # map_duration(tsv_withdur='tsv_maker/filter_audioset.tsv',
    #              tsv_toadd='MAA1 Dataset tsvs/V3/refilter_audioset.tsv')
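Usage sketch (not part of the upload; the manifest paths are hypothetical): annotate a tsv that has an 'audio_path' column with the 'duration' column that the duration-sorted samplers above rely on:

add_dur2tsv('ldm/data/tsv_dirs/my_manifest.tsv',           # hypothetical input tsv
            'ldm/data/tsv_dirs/my_manifest_duration.tsv')  # same rows plus a 'duration' column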
ldm/data/preprocess/mel_spec.py
ADDED
|
@@ -0,0 +1,201 @@
from ldm.data.preprocess.NAT_mel import MelNet
import os
from tqdm import tqdm
from glob import glob
import math
import pandas as pd
import logging
import math
import audioread
from tqdm.contrib.concurrent import process_map
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from torch.distributed import init_process_group
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch.multiprocessing as mp
from argparse import Namespace
from multiprocessing import Pool
import json


class tsv_dataset(Dataset):
    def __init__(self, tsv_path, sr, mode='none', hop_size=None, target_mel_length=None) -> None:
        super().__init__()
        if os.path.isdir(tsv_path):
            files = glob(os.path.join(tsv_path, '*.tsv'))
            df = pd.concat([pd.read_csv(file, sep='\t') for file in files])
        else:
            df = pd.read_csv(tsv_path, sep='\t')
        self.audio_paths = []
        self.sr = sr
        self.mode = mode
        self.target_mel_length = target_mel_length
        self.hop_size = hop_size
        for t in tqdm(df.itertuples()):
            self.audio_paths.append(getattr(t, 'audio_path'))

    def __len__(self):
        return len(self.audio_paths)

    def pad_wav(self, wav):
        # wav should be in shape (1, wav_len)
        wav_length = wav.shape[-1]
        assert wav_length > 100, "wav is too short, %s" % wav_length
        segment_length = (self.target_mel_length + 1) * self.hop_size  # final mel will crop the last frame, mel = mel[:,:-1]
        if segment_length is None or wav_length == segment_length:
            return wav
        elif wav_length > segment_length:
            return wav[:, :segment_length]
        elif wav_length < segment_length:
            temp_wav = torch.zeros((1, segment_length), dtype=torch.float32)
            temp_wav[:, :wav_length] = wav
            return temp_wav

    def __getitem__(self, index):
        audio_path = self.audio_paths[index]
        wav, orisr = torchaudio.load(audio_path)
        if wav.shape[0] != 1:  # stereo to mono (2,wav_len) -> (1,wav_len)
            wav = wav.mean(0, keepdim=True)
        wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr)
        if self.mode == 'pad':
            assert self.target_mel_length is not None
            wav = self.pad_wav(wav)
        return audio_path, wav

def process_audio_by_tsv(rank, args):
    if args.num_gpus > 1:
        init_process_group(backend=args.dist_config['dist_backend'], init_method=args.dist_config['dist_url'],
                           world_size=args.dist_config['world_size'] * args.num_gpus, rank=rank)

    sr = args.audio_sample_rate
    dataset = tsv_dataset(args.tsv_path, sr=sr, mode=args.mode, hop_size=args.hop_size, target_mel_length=args.batch_max_length)
    sampler = DistributedSampler(dataset, shuffle=False) if args.num_gpus > 1 else None
    # batch_size must == 1, since wav_len is not equal
    loader = DataLoader(dataset, sampler=sampler, batch_size=1, num_workers=16, drop_last=False)

    device = torch.device('cuda:{:d}'.format(rank))

    mel_net = MelNet(args.__dict__)
    mel_net.to(device)
    # if args.num_gpus > 1: # RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
    #     mel_net = DistributedDataParallel(mel_net, device_ids=[rank]).to(device)

    loader = tqdm(loader) if rank == 0 else loader
    for batch in loader:
        audio_paths, wavs = batch
        wavs = wavs.to(device)
        if args.save_resample:
            for audio_path, wav in zip(audio_paths, wavs):
                psplits = audio_path.split('/')
                root, wav_name = psplits[0], psplits[-1]
                # save resample
                resample_root, resample_name = root + f'_{sr}', wav_name[:-4] + '_audio.npy'
                resample_dir_name = os.path.join(resample_root, *psplits[1:-1])
                resample_path = os.path.join(resample_dir_name, resample_name)
                os.makedirs(resample_dir_name, exist_ok=True)
                np.save(resample_path, wav.cpu().numpy().squeeze(0))

        if args.save_mel:
            mode = args.mode
            batch_max_length = args.batch_max_length

            for audio_path, wav in zip(audio_paths, wavs):
                psplits = audio_path.split('/')
                root, wav_name = psplits[0], psplits[-1]
                mel_root, mel_name = root + f'_mel{mode}{sr}nfft{args.fft_size}', wav_name[:-4] + '_mel.npy'
                mel_dir_name = os.path.join(mel_root, *psplits[1:-1])
                mel_path = os.path.join(mel_dir_name, mel_name)
                if not os.path.exists(mel_path):
                    mel_spec = mel_net(wav).cpu().numpy().squeeze(0)  # (mel_bins, mel_len)
                    if mel_spec.shape[1] <= batch_max_length:
                        if mode == 'tile':  # pad is done in dataset as pad wav
                            n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1])
                            mel_spec = np.tile(mel_spec, reps=(1, n_repeat))
                        elif mode == 'none' or mode == 'pad':
                            pass
                        else:
                            raise ValueError(f'mode:{mode} is not supported')
                    mel_spec = mel_spec[:, :batch_max_length]
                    os.makedirs(mel_dir_name, exist_ok=True)
                    np.save(mel_path, mel_spec)


def split_list(i_list, num):
    each_num = math.ceil(len(i_list) / num)
    result = []
    for i in range(num):
        s = each_num * i
        e = (each_num * (i + 1))
        result.append(i_list[s:e])
    return result


def drop_bad_wav(item):
    index, path = item
    try:
        with audioread.audio_open(path) as f:
            totalsec = f.duration
            if totalsec < 0.1:
                return index  # index
    except:
        print(f"corrupted wav:{path}")
        return index
    return False

def drop_bad_wavs(tsv_path):  # 'audioset.csv'
    df = pd.read_csv(tsv_path, sep='\t')
    item_list = []
    for item in tqdm(df.itertuples()):
        item_list.append((item[0], getattr(item, 'audio_path')))

    r = process_map(drop_bad_wav, item_list, max_workers=16, chunksize=16)
    bad_indices = list(filter(lambda x: x != False, r))

    print(bad_indices)
    with open('bad_wavs.json', 'w') as f:
        x = [item_list[i] for i in bad_indices]
        json.dump(x, f)
    df = df.drop(bad_indices, axis=0)
    df.to_csv(tsv_path, sep='\t', index=False)

if __name__ == '__main__':
    logging.basicConfig(filename='example.log', level=logging.INFO,
                        format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
    tsv_path = './musiccap.tsv'
    if os.path.isdir(tsv_path):
        files = glob(os.path.join(tsv_path, '*.tsv'))
        for file in files:
            drop_bad_wavs(file)
    else:
        drop_bad_wavs(tsv_path)
    num_gpus = 1
    args = {
        'audio_sample_rate': 16000,
        'audio_num_mel_bins': 80,
        'fft_size': 1024,  # 4000:512 ,16000:1024,
        'win_size': 1024,
        'hop_size': 256,
        'fmin': 0,
        'fmax': 8000,
        'batch_max_length': 1560,  # 4000:312 (nfft = 512,hoplen=128,mellen = 313), 16000:624 , 22050:848 #
        'tsv_path': tsv_path,
        'num_gpus': num_gpus,
        'mode': 'none',
        'save_resample': False,
        'save_mel': True
    }
    args = Namespace(**args)
    args.dist_config = {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54189",
        "world_size": 1
    }
    if args.num_gpus > 1:
        mp.spawn(process_audio_by_tsv, nprocs=args.num_gpus, args=(args,))
    else:
        process_audio_by_tsv(0, args=args)
    print("done")
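For reference, a sketch (not part of the upload) of the mel path that process_audio_by_tsv derives from an audio_path under the default args above; the example path is made up:

audio_path = 'data/musiccap/abc123.wav'                  # hypothetical entry from the tsv
sr, mode, fft_size = 16000, 'none', 1024
psplits = audio_path.split('/')
mel_root = psplits[0] + f'_mel{mode}{sr}nfft{fft_size}'  # 'data_melnone16000nfft1024'
mel_path = '/'.join([mel_root, *psplits[1:-1], psplits[-1][:-4] + '_mel.npy'])
# -> 'data_melnone16000nfft1024/musiccap/abc123_mel.npy'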
ldm/data/test.py
ADDED
|
@@ -0,0 +1,224 @@
import sys
import numpy as np
import torch
from typing import TypeVar, Optional, Iterator
import logging
import pandas as pd
from ldm.data.joinaudiodataset_anylen import *
import glob
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    def __init__(self, split, main_spec_dir_path, other_spec_dir_path, mel_num=80, mode='pad', spec_crop_len=1248, pad_value=-5, drop=0, **kwargs):
        super().__init__()
        self.split = split
        self.max_batch_len = spec_crop_len
        self.min_batch_len = 64
        self.min_factor = 4
        self.mel_num = mel_num
        self.drop = drop
        self.pad_value = pad_value
        assert mode in ['pad', 'tile']
        self.collate_mode = mode
        manifest_files = []
        for dir_path in main_spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        self.df_main = pd.concat(df_list, ignore_index=True)

        manifest_files = []
        for dir_path in other_spec_dir_path.split(','):
            manifest_files += glob.glob(f'{dir_path}/*.tsv')
        df_list = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        self.df_other = pd.concat(df_list, ignore_index=True)
        self.df_other.reset_index(inplace=True)

        if split == 'train':
            self.dataset = self.df_main.iloc[100:]
        elif split == 'valid' or split == 'val':
            self.dataset = self.df_main.iloc[:100]
        elif split == 'test':
            self.df_main = self.add_name_num(self.df_main)
            self.dataset = self.df_main
        else:
            raise ValueError(f'Unknown split {split}')
        self.dataset.reset_index(inplace=True)
        print('dataset len:', len(self.dataset), "drop_rate", self.drop)

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = str(df.loc[t[0], 'name']) + f'_{t[1]}'
        return df

    def ordered_indices(self):
        index2dur = self.dataset[['duration']].sort_values(by='duration')
        index2dur_other = self.df_other[['duration']].sort_values(by='duration')
        other_indices = list(index2dur_other.index)
        offset = len(self.dataset)
        other_indices = [x + offset for x in other_indices]
        return list(index2dur.index), other_indices

    def collater(self, inputs):
        to_dict = {}
        for l in inputs:
            for k, v in l.items():
                if k in to_dict:
                    to_dict[k].append(v)
                else:
                    to_dict[k] = [v]

        if self.collate_mode == 'pad':
            to_dict['image'] = collate_1d_or_2d(to_dict['image'], pad_idx=self.pad_value, min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        elif self.collate_mode == 'tile':
            to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'], min_len=self.min_batch_len, max_len=self.max_batch_len, min_factor=self.min_factor)
        else:
            raise NotImplementedError
        to_dict['caption'] = {'ori_caption': [c['ori_caption'] for c in to_dict['caption']],
                              'struct_caption': [c['struct_caption'] for c in to_dict['caption']]}

        return to_dict

    def __getitem__(self, idx):
        if idx < len(self.dataset):
            data = self.dataset.iloc[idx]
            p = np.random.uniform(0, 1)
            if p > self.drop:
                ori_caption = data['ori_cap']
                struct_caption = data['caption']
            else:
                ori_caption = ""
                struct_caption = ""
        else:
            data = self.df_other.iloc[idx - len(self.dataset)]
            p = np.random.uniform(0, 1)
            if p > self.drop:
                ori_caption = data['caption']
                struct_caption = f'<{ori_caption}& all>'
            else:
                ori_caption = ""
                struct_caption = ""
        item = {}
        try:
            spec = np.load(data['mel_path'])  # mel spec [80, T]
            if spec.shape[1] > self.max_batch_len:
                spec = spec[:, :self.max_batch_len]
        except:
            mel_path = data['mel_path']
            print(f'corrupted:{mel_path}')
            spec = np.ones((self.mel_num, self.min_batch_len)).astype(np.float32) * self.pad_value

        item['image'] = spec
        item["caption"] = {"ori_caption": ori_caption, "struct_caption": struct_caption}
        if self.split == 'test':
            item['f_name'] = data['name']
        return item

    def __len__(self):
        return len(self.dataset) + len(self.df_other)


class JoinSpecsTrain(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)


class DDPIndexBatchSampler(Sampler):  # group indices of audios with similar length into one batch to avoid excessive padding
    def __init__(self, main_indices, other_indices, batch_size, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                print("Not in distributed mode")
                num_replicas = 1
            else:
                num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_initialized():
                # raise RuntimeError("Requires distributed package to be available")
                rank = 0
            else:
                rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.main_indices = main_indices
        self.other_indices = other_indices
        self.max_index = max(self.other_indices)
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batches = self.build_batches()
        self.seed = seed

    def set_epoch(self, epoch):
        # print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!")
        self.epoch = epoch
        if self.shuffle:
            np.random.seed(self.seed + self.epoch)
            self.batches = self.build_batches()

    def build_batches(self):
        batches, batch = [], []
        for index in self.main_indices:
            batch.append(index)
            if len(batch) == self.batch_size:
                batches.append(batch)
                batch = []
        if not self.drop_last and len(batch) > 0:
            batches.append(batch)
        selected_others = np.random.choice(len(self.other_indices), len(batches), replace=False)
        for index in selected_others:
            if index + self.batch_size > len(self.other_indices):
                index = len(self.other_indices) - self.batch_size
            batch = [self.other_indices[index + i] for i in range(self.batch_size)]
            batches.append(batch)
        self.batches = batches
        if self.shuffle:
            self.batches = np.random.permutation(self.batches)
        if self.rank == 0:
            print(f"rank: {self.rank}, batches_num {len(self.batches)}")

        if self.drop_last and len(self.batches) % self.num_replicas != 0:
            self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
        if len(self.batches) >= self.num_replicas:
            self.batches = self.batches[self.rank::self.num_replicas]
        else:  # may happen in sanity checking
            self.batches = [self.batches[0]]
        if self.rank == 0:
            print(f"after split batches_num {len(self.batches)}")

        return self.batches

    def __iter__(self) -> Iterator[List[int]]:
        print(f"len(self.batches):{len(self.batches)}")
        for batch in self.batches:
            yield batch

    def __len__(self) -> int:
        return len(self.batches)
ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv
ADDED
The diff for this file is too large to render. See raw diff.

ldm/data/tsv_dirs/full_data/V2/MACS.tsv
ADDED
The diff for this file is too large to render. See raw diff.

ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv
ADDED
The diff for this file is too large to render. See raw diff.

ldm/data/tsv_dirs/full_data/V2/adobe.tsv
ADDED
The diff for this file is too large to render. See raw diff.

ldm/data/tsv_dirs/full_data/V2/audiostock.tsv
ADDED
The diff for this file is too large to render. See raw diff.

ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dc67e42c9defa98edfc2c6b23c731fafa4a22307fddfd1fb95ccfc00d0168951
size 15062608

ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv
ADDED
The diff for this file is too large to render. See raw diff.

ldm/data/tsv_dirs/full_data/clotho.tsv
ADDED
The diff for this file is too large to render. See raw diff.
ldm/data/tsvdataset.py
ADDED
|
@@ -0,0 +1,67 @@
from glob import glob
from torch.utils.data import Dataset
import numpy as np
import pandas as pd

class TSVDataset(Dataset):
    def __init__(self, tsv_path, spec_crop_len=None):
        super().__init__()
        self.batch_max_length = spec_crop_len
        self.batch_min_length = 50
        df = pd.read_csv(tsv_path, sep='\t')
        df = self.add_name_num(df)
        self.dataset = df
        print('dataset len:', len(self.dataset))

    def add_name_num(self, df):
        """each file may have different caption, we add num to filename to identify each audio-caption pair"""
        name_count_dict = {}
        change = []
        for t in df.itertuples():
            name = getattr(t, 'name')
            if name in name_count_dict:
                name_count_dict[name] += 1
            else:
                name_count_dict[name] = 0
            change.append((t[0], name_count_dict[name]))
        for t in change:
            df.loc[t[0], 'name'] = df.loc[t[0], 'name'] + f'_{t[1]}'
        return df


    def __getitem__(self, idx):
        data = self.dataset.iloc[idx]
        item = {}
        spec = np.load(data['mel_path'])  # mel spec [80, 624]
        if spec.shape[1] <= self.batch_max_length:
            spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])))  # [80, 624]

        item['image'] = spec
        item["caption"] = data['caption']
        item["f_name"] = data['name']
        return item

    def __len__(self):
        return len(self.dataset)

class TSVDatasetStruct(TSVDataset):
    def __getitem__(self, idx):
        data = self.dataset.iloc[idx]
        item = {}
        spec = np.load(data['mel_path'])  # mel spec [80, 624]
        if spec.shape[1] <= self.batch_max_length:
            spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])))  # [80, 624]

        item['image'] = spec[:, :self.batch_max_length]
        item["caption"] = {'ori_caption': data['ori_cap'], 'struct_caption': data['caption']}
        item["f_name"] = data['name']
        return item

class TSVDatasetTestFake(TSVDataset):
    def __init__(self, specs_dataset_cfg):
        super().__init__(phase='test', **specs_dataset_cfg)
        self.dataset = [self.dataset[0]]
ldm/lr_scheduler.py
ADDED
|
@@ -0,0 +1,98 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+    """
+    note: use with a base_lr of 1.0
+    """
+    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+        self.lr_warm_up_steps = warm_up_steps
+        self.lr_start = lr_start
+        self.lr_min = lr_min
+        self.lr_max = lr_max
+        self.lr_max_decay_steps = max_decay_steps
+        self.last_lr = 0.
+        self.verbosity_interval = verbosity_interval
+
+    def schedule(self, n, **kwargs):
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if n < self.lr_warm_up_steps:
+            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+            self.last_lr = lr
+            return lr
+        else:
+            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+            t = min(t, 1.0)
+            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+                    1 + np.cos(t * np.pi))
+            self.last_lr = lr
+            return lr
+
+    def __call__(self, n, **kwargs):
+        return self.schedule(n, **kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+    """
+    supports repeated iterations, configurable via lists
+    note: use with a base_lr of 1.0.
+    """
+    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+        self.lr_warm_up_steps = warm_up_steps
+        self.f_start = f_start
+        self.f_min = f_min
+        self.f_max = f_max
+        self.cycle_lengths = cycle_lengths
+        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+        self.last_f = 0.
+        self.verbosity_interval = verbosity_interval
+
+    def find_in_interval(self, n):
+        interval = 0
+        for cl in self.cum_cycles[1:]:
+            if n <= cl:
+                return interval
+            interval += 1
+
+    def schedule(self, n, **kwargs):
+        cycle = self.find_in_interval(n)
+        n = n - self.cum_cycles[cycle]
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                                                       f"current cycle {cycle}")
+        if n < self.lr_warm_up_steps[cycle]:
+            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+            self.last_f = f
+            return f
+        else:
+            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+            t = min(t, 1.0)
+            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+                    1 + np.cos(t * np.pi))
+            self.last_f = f
+            return f
+
+    def __call__(self, n, **kwargs):
+        return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+    def schedule(self, n, **kwargs):
+        cycle = self.find_in_interval(n)
+        n = n - self.cum_cycles[cycle]
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                                                       f"current cycle {cycle}")
+
+        if n < self.lr_warm_up_steps[cycle]:
+            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+            self.last_f = f
+            return f
+        else:
+            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+            self.last_f = f
+            return f
+
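These classes are step-indexed multiplier functions, not PyTorch schedulers themselves; `configure_optimizers` in `ldm/models/autoencoder.py` below wraps them in `torch.optim.lr_scheduler.LambdaLR`. A minimal standalone sketch of that wiring, assuming the warm-up length, cycle length, and rates are illustrative values (they are not taken from this repo's configs):

import torch
from ldm.lr_scheduler import LambdaWarmUpCosineScheduler2

# one 10k-step cycle: 1k-step linear warm-up from 0 to 1e-4, then cosine decay to 1e-6
sched_fn = LambdaWarmUpCosineScheduler2(
    warm_up_steps=[1000], f_min=[1e-6], f_max=[1e-4], f_start=[0.0],
    cycle_lengths=[10000])

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1.0)  # base_lr of 1.0, as the docstrings advise
lr_sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn.schedule)

for step in range(2000):
    opt.step()
    lr_sched.step()  # effective lr = 1.0 * sched_fn(step)

With base_lr fixed at 1.0, the schedule's f_max/f_min values carry the actual learning rate, which is why the docstrings insist on that convention.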
ldm/models/__pycache__/autoencoder.cpython-37.pyc
ADDED
Binary file (15.6 kB)

ldm/models/__pycache__/autoencoder.cpython-38.pyc
ADDED
Binary file (15.5 kB)

ldm/models/__pycache__/autoencoder.cpython-39.pyc
ADDED
Binary file (14.9 kB)

ldm/models/__pycache__/autoencoder1d.cpython-37.pyc
ADDED
Binary file (13.5 kB)

ldm/models/__pycache__/autoencoder1d.cpython-38.pyc
ADDED
Binary file (13.4 kB)

ldm/models/__pycache__/autoencoder_multi.cpython-38.pyc
ADDED
Binary file (14.8 kB)
ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,504 @@
+import os
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from packaging import version
+import numpy as np
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from torch.optim.lr_scheduler import LambdaLR
+from ldm.util import instantiate_from_config
+from icecream import ic
+
+class VQModel(pl.LightningModule):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 n_embed,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 batch_resize_range=None,
+                 scheduler_config=None,
+                 lr_g_factor=1.0,
+                 remap=None,
+                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
+                 use_ema=False
+                 ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.n_embed = n_embed
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+                                        remap=remap,
+                                        sane_index_shape=sane_index_shape)
+        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels) == int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+        self.batch_resize_range = batch_resize_range
+        if self.batch_resize_range is not None:
+            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
+
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self)  # NOTE: LitEma is not imported in this file (stock latent-diffusion imports it from ldm.modules.ema); only reached when use_ema=True
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        self.scheduler_config = scheduler_config
+        self.lr_g_factor = lr_g_factor
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        missing, unexpected = self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+            print(f"Unexpected Keys: {unexpected}")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+        quant, emb_loss, info = self.quantize(h)
+        return quant, emb_loss, info
+
+    def encode_to_prequant(self, x):
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+        return h
+
+    def decode(self, quant):
+        quant = self.post_quant_conv(quant)
+        dec = self.decoder(quant)
+        return dec
+
+    def decode_code(self, code_b):
+        quant_b = self.quantize.embed_code(code_b)
+        dec = self.decode(quant_b)
+        return dec
+
+    def forward(self, input, return_pred_indices=False):
+        quant, diff, (_, _, ind) = self.encode(input)
+        dec = self.decode(quant)
+        if return_pred_indices:
+            return dec, diff, ind
+        return dec, diff
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        if self.batch_resize_range is not None:
+            lower_size = self.batch_resize_range[0]
+            upper_size = self.batch_resize_range[1]
+            if self.global_step <= 4:
+                # do the first few batches with max size to avoid later oom
+                new_resize = upper_size
+            else:
+                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+            if new_resize != x.shape[2]:
+                x = F.interpolate(x, size=new_resize, mode="bicubic")
+            x = x.detach()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        # https://github.com/pytorch/pytorch/issues/37142
+        # try not to fool the heuristics
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss, ind = self(x, return_pred_indices=True)
+
+        if optimizer_idx == 0:
+            # autoencode
+            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train",
+                                            predicted_indices=ind)
+
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # discriminator
+            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+                                                last_layer=self.get_last_layer(), split="train")
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, suffix=""):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss, ind = self(x, return_pred_indices=True)
+        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
+                                        self.global_step,
+                                        last_layer=self.get_last_layer(),
+                                        split="val"+suffix,
+                                        predicted_indices=ind
+                                        )
+
+        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
+                                            self.global_step,
+                                            last_layer=self.get_last_layer(),
+                                            split="val"+suffix,
+                                            predicted_indices=ind
+                                            )
+        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
+        self.log(f"val{suffix}/rec_loss", rec_loss,
+                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+        self.log(f"val{suffix}/aeloss", aeloss,
+                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+        if version.parse(pl.__version__) >= version.parse('1.4.0'):
+            del log_dict_ae[f"val{suffix}/rec_loss"]
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def test_step(self, batch, batch_idx):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss, ind = self(x, return_pred_indices=True)
+        reconstructions = (xrec + 1)/2  # to mel scale
+        test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
+        savedir = os.path.join(self.trainer.log_dir, f'output_imgs_{test_ckpt_path}', 'fake_class')
+        if not os.path.exists(savedir):
+            os.makedirs(savedir)
+
+        file_names = batch['f_name']
+        # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
+        reconstructions = reconstructions.cpu().numpy().squeeze(1)  # squeeze channel dim
+        for b in range(reconstructions.shape[0]):
+            vname_num_split_index = file_names[b].rfind('_')  # file_names[b]:video_name+'_'+num
+            v_n, num = file_names[b][:vname_num_split_index], file_names[b][vname_num_split_index+1:]
+            save_img_path = os.path.join(savedir, f'{v_n}_sample_{num}.npy')
+            np.save(save_img_path, reconstructions[b])
+
+        return None
+
+    def configure_optimizers(self):
+        lr_d = self.learning_rate
+        lr_g = self.lr_g_factor*self.learning_rate
+        print("lr_d", lr_d)
+        print("lr_g", lr_g)
+        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+                                  list(self.decoder.parameters())+
+                                  list(self.quantize.parameters())+
+                                  list(self.quant_conv.parameters())+
+                                  list(self.post_quant_conv.parameters()),
+                                  lr=lr_g, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr_d, betas=(0.5, 0.9))
+
+        if self.scheduler_config is not None:
+            scheduler = instantiate_from_config(self.scheduler_config)
+
+            print("Setting up LambdaLR scheduler...")
+            scheduler = [
+                {
+                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                },
+                {
+                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                },
+            ]
+            return [opt_ae, opt_disc], scheduler
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if only_inputs:
+            log["inputs"] = x
+            return log
+        xrec, _ = self(x)
+        if x.shape[1] > 3:
+            # colorize with random projection
+            assert xrec.shape[1] > 3
+            x = self.to_rgb(x)
+            xrec = self.to_rgb(xrec)
+        log["inputs"] = x
+        log["reconstructions"] = xrec
+        if plot_ema:
+            with self.ema_scope():
+                xrec_ema, _ = self(x)
+                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+                log["reconstructions_ema"] = xrec_ema
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class VQModelInterface(VQModel):
+    def __init__(self, embed_dim, *args, **kwargs):
+        super().__init__(embed_dim=embed_dim, *args, **kwargs)
+        self.embed_dim = embed_dim
+
+    def encode(self, x):  # VQModel applies quantization inside encode(); VQModelInterface defers it to decode()
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+        return h
+
+    def decode(self, h, force_not_quantize=False):
+        # also go through quantization layer
+        if not force_not_quantize:
+            quant, emb_loss, info = self.quantize(h)
+        else:
+            quant = h
+        quant = self.post_quant_conv(quant)
+        dec = self.decoder(quant)
+        return dec
+
+
+class AutoencoderKL(pl.LightningModule):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 ):
+        super().__init__()
+        self.to_1d = False
+        print(f"to_1d is {self.to_1d} in AUTOENCODER")
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels) == int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        # self.automatic_optimization = False # hjw for debug
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    def encode(self, x):
+        if self.to_1d and len(x.shape) == 3:
+            x = x.unsqueeze(1)
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        if self.to_1d:
+            b, c, h, w = moments.shape
+            moments = moments.reshape(b, c*h, w)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        if self.to_1d:
+            b, c_h, w = z.shape
+            c = self.post_quant_conv.in_channels
+            z = z.reshape(b, c, -1, w)
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            # print(optimizer_idx,log_dict_ae)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                last_layer=self.get_last_layer(), split="train")
+
+            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            # print(optimizer_idx,log_dict_disc)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                        last_layer=self.get_last_layer(), split="val")
+
+        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                            last_layer=self.get_last_layer(), split="val")
+
+        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def test_step(self, batch, batch_idx):
+        inputs = self.get_input(batch, self.image_key)  # inputs shape:(b,mel_len,T)
+        reconstructions, posterior = self(inputs)  # reconstructions:(b,mel_len,T)
+        mse_loss = torch.nn.functional.mse_loss(reconstructions, inputs)
+        self.log('test/mse_loss', mse_loss)
+
+        test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
+        savedir = os.path.join(self.trainer.log_dir, f'output_imgs_{test_ckpt_path}', 'fake_class')
+        if batch_idx == 0:
+            print(f"save_path is: {savedir}")
+        if not os.path.exists(savedir):
+            os.makedirs(savedir)
+            print(f"save_path is: {savedir}")
+
+        file_names = batch['f_name']
+        # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
+        # reconstructions = (reconstructions + 1)/2 # to mel scale
+        reconstructions = reconstructions.cpu().numpy().squeeze(1)  # squeeze channel dim
+        for b in range(reconstructions.shape[0]):
+            vname_num_split_index = file_names[b].rfind('_')  # file_names[b]:video_name+'_'+num
+            v_n, num = file_names[b][:vname_num_split_index], file_names[b][vname_num_split_index+1:]
+            save_img_path = os.path.join(savedir, f'{v_n}.npy')  # f'{v_n}_sample_{num}.npy' f'{v_n}.npy'
+            np.save(save_img_path, reconstructions[b])
+
+        return None
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+                                  list(self.decoder.parameters())+
+                                  list(self.quant_conv.parameters())+
+                                  list(self.post_quant_conv.parameters()),
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, save_dir='mel_result_ae13_26_debug/fake_class', **kwargs):  # called from on_validation_batch_end in main.py
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
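`AutoencoderKL.forward` above encodes the input mel into a `DiagonalGaussianDistribution`, samples (or takes the mode of) a latent, and decodes it back. A minimal inference sketch, assuming `cfg_path` points at a YAML whose `model` block has the usual `target`/`params` layout for this autoencoder and that a matching checkpoint exists; both paths are placeholders, not files named in this commit:

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.load("configs/your_autoencoder.yaml")   # placeholder config path
model = instantiate_from_config(cfg.model)              # builds AutoencoderKL from ddconfig/lossconfig
model.init_from_ckpt("path/to/autoencoder.ckpt")        # placeholder checkpoint path
model.eval()

mel = torch.randn(1, 1, 80, 624)  # (batch, channel, n_mels, frames), shapes as in the dataset above
with torch.no_grad():
    posterior = model.encode(mel)
    z = posterior.mode()          # deterministic latent; posterior.sample() adds noise
    recon = model.decode(z)       # mel-shaped reconstruction, exact size depends on ddconfig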