import ast
import logging
import os
import os.path as op
import shutil
import sys
from argparse import Namespace

import numpy as np
import torch
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from omegaconf import DictConfig
# Helper: plot an EOS-probability curve or attention weights and save the figure.
def _plot_and_save(array, figname, figsize=(6, 4), dpi=150):
    """Plot a 1D, 2D, or 4D array and save the figure to ``figname``.

    Supported layouts:

    * 1D -- a probability curve (e.g. EOS probability per frame), drawn as a
      line with the y-axis fixed to [0, 1].
    * 2D -- a single attention map of shape (out_length, in_length), e.g.
      Tacotron 2 attention weights.
    * 4D -- attention maps of shape (#layers, #heads, out_length, in_length),
      e.g. Transformer attention weights, drawn as a grid with one row per
      layer and one column per head.

    Args:
        array: array-like exposing ``.shape`` (e.g. a numpy array).
        figname: output image path; its parent directory is created if it
            does not exist yet.
        figsize: (width, height) in inches of a single plot / grid cell.
        dpi: resolution of the figure.

    Raises:
        NotImplementedError: if ``array`` is not 1D, 2D, or 4D.
    """
    shape = array.shape
    ndim = len(shape)
    # Validate before importing matplotlib so unsupported input fails fast.
    if ndim not in (1, 2, 4):
        raise NotImplementedError("Support only 1D, 2D, or 4D arrays.")

    import matplotlib.pyplot as plt

    if ndim == 1:
        plt.figure(figsize=figsize, dpi=dpi)
        plt.plot(array)
        plt.xlabel("Frame")
        plt.ylabel("Probability")
        plt.ylim([0, 1])
    elif ndim == 2:
        plt.figure(figsize=figsize, dpi=dpi)
        plt.imshow(array, aspect="auto")
        plt.xlabel("Input")
        plt.ylabel("Output")
    else:
        n_rows, n_cols = shape[0], shape[1]
        # The grid has one row per layer and one column per head, so the
        # figure width scales with the column count and the height with the
        # row count.
        plt.figure(figsize=(figsize[0] * n_cols, figsize[1] * n_rows), dpi=dpi)
        for row, layer_maps in enumerate(array):
            for col, head_map in enumerate(layer_maps, 1):
                plt.subplot(n_rows, n_cols, row * n_cols + col)
                plt.imshow(head_map, aspect="auto")
                plt.xlabel("Input")
                plt.ylabel("Output")
    plt.tight_layout()

    out_dir = op.dirname(figname)
    # A bare file name has no directory component; os.makedirs("") would fail.
    if out_dir:
        # NOTE: exist_ok=True is needed for parallel process decoding.
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(figname)
    plt.close()
# Helper: compute the attention focus rate
# (see section 3.3 in https://arxiv.org/abs/1905.09263).
def _calculate_focus_rete(att_ws):
    """Return the attention focus rate as a Python float.

    For each output step the largest attention weight over the input axis is
    taken, and these peaks are averaged over output steps. Assuming each row
    holds normalized attention weights, values close to 1.0 mean each output
    step concentrates on a single input position.

    Args:
        att_ws: ``None`` (no attention, e.g. FastSpeech) or a torch tensor of
            shape (L, T) (Tacotron 2) or (#layers, #heads, L, T) (Transformer).

    Returns:
        1.0 when ``att_ws`` is None. For 2D input, the mean of the per-row
        maxima. For 4D input, the largest per-(layer, head) focus rate.

    Raises:
        ValueError: if ``att_ws`` is neither 2- nor 4-dimensional.
    """
    # NOTE: the misspelled name ("rete") is kept because callers use it.
    if att_ws is None:
        return 1.0

    rank = len(att_ws.shape)
    if rank not in (2, 4):
        raise ValueError("att_ws should be 2 or 4 dimensional tensor.")

    # Peak weight of every output step (reduces away the input axis T).
    peaks = att_ws.max(dim=-1).values
    if rank == 2:
        return float(peaks.mean())
    # Average over output steps per (layer, head), then keep the best head.
    per_head = peaks.mean(dim=-1)
    return float(per_head.max())
def main(cfg: DictConfig):
    """Validate the generation config and run speech generation.

    Args:
        cfg: an OmegaConf ``DictConfig``, or an argparse ``Namespace`` that is
            converted to one first.

    Returns:
        Whatever ``_main`` returns; its log output goes to ``sys.stdout``.

    Raises:
        AssertionError: if no model path is given, or ``--replace-unk`` is
            used with a non-raw dataset.
    """
    config = convert_namespace_to_omegaconf(cfg) if isinstance(cfg, Namespace) else cfg

    assert config.common_eval.path is not None, "--path required for generation!"
    assert (
        config.generation.replace_unk is None
        or config.dataset.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    results_dir = config.common_eval.results_path
    if results_dir is not None:
        os.makedirs(results_dir, exist_ok=True)

    return _main(config, sys.stdout)
# Number of leading batches (batch size is 1) eligible for demo artifacts.
_NUM_DEMO_SAMPLES = 6


def _save_demo_outputs(demo_dir, audio_root, sample, audio_name, outs, attn):
    """Write demo artifacts for one generated sample under ``demo_dir``.

    Copies the sample's audio file from ``audio_root`` (using ``tgt_name``
    when the sample has it, otherwise ``name``) into ``demo_dir/audio``. Also
    saves a plot of the attention weights into ``demo_dir/att_ws``, and plots
    of the generated and target features into ``demo_dir/spec``.

    Args:
        demo_dir: root directory for the demo artifacts.
        audio_root: directory the sample's audio name is resolved against.
        sample: batch dict from the iterator (batch size 1).
        audio_name: base file name used for every artifact of this sample.
        outs: generated features as a numpy array (plotted transposed).
        attn: attention-weight tensor returned by ``task.generate_speech``.
    """
    audio_dir = op.join(demo_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    src_name = sample["tgt_name"][0] if "tgt_name" in sample else sample["name"][0]
    shutil.copy(op.join(audio_root, src_name), op.join(audio_dir, audio_name))

    att_dir = op.join(demo_dir, "att_ws")
    _plot_and_save(attn.cpu().numpy(), op.join(att_dir, f"{audio_name}_att_ws.png"))

    spec_dir = op.join(demo_dir, "spec")
    _plot_and_save(outs.T, op.join(spec_dir, f"{audio_name}_gen.png"))
    _plot_and_save(
        sample["target"][0].cpu().numpy().T,
        op.join(spec_dir, f"{audio_name}_ori.png"),
    )


def _main(cfg: DictConfig, output_file):
    """Generate speech features for ``cfg.dataset.gen_subset`` and save them.

    For every batch that has a ``net_input`` entry, the generated features
    are saved with ``np.save`` to ``<results_path>/<name>``, where ``.wav`` in
    the name is replaced by ``-feats.npy``. A log line reports the source
    length, ``outs.shape[0]``, ``dec_target_lengths`` and the attention focus
    rate. Each of the first ``_NUM_DEMO_SAMPLES`` batches that also returns
    attention weights gets demo artifacts. These are written to a ``demo``
    directory that sits next to ``results_path``.

    Args:
        cfg: fairseq generation config.
        output_file: stream that log records are written to.

    Raises:
        AssertionError: if the batch size is not 1 or no results path is set.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("speecht5.generate_speech")

    utils.import_user_module(cfg.common)

    assert cfg.dataset.batch_size == 1, "only support batch size 1"
    # Every generated feature file is written under results_path. Fail here,
    # before loading any model, instead of crashing at the first np.save.
    assert (
        cfg.common_eval.results_path is not None
    ), "--results-path required for saving generated features"
    results_path = cfg.common_eval.results_path

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if not use_cuda:
        logger.info("generate speech on cpu")

    # Build task
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    logger.info("loading model(s) from %s", cfg.common_eval.path)
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # The dataset must be loaded after the checkpoint, so that it can be
    # given the saved task config.
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    dataset = task.dataset(cfg.dataset.gen_subset)
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=None,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)

    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # normpath drops a trailing separator, so "out/feats/" and "out/feats"
    # both place the demo directory at "out/demo".
    demo_dir = op.join(op.dirname(op.normpath(results_path)), "demo")

    for i, sample in enumerate(progress):
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        outs, _, attn = task.generate_speech(
            models,
            sample["net_input"],
        )
        focus_rate = _calculate_focus_rete(attn)
        outs = outs.cpu().numpy()

        audio_name = op.basename(sample["name"][0])
        np.save(
            op.join(results_path, audio_name.replace(".wav", "-feats.npy")),
            outs,
        )
        logger.info(
            "{} (size: {}->{} ({}), focus rate: {:.3f})".format(
                sample["name"][0],
                sample["src_lengths"][0].item(),
                outs.shape[0],
                sample["dec_target_lengths"][0].item(),
                focus_rate,
            )
        )

        if i < _NUM_DEMO_SAMPLES and attn is not None:
            _save_demo_outputs(
                demo_dir, dataset.audio_root, sample, audio_name, outs, attn
            )
def cli_main():
    """Parse generation command-line arguments and pass them to ``main``."""
    main(options.parse_args_and_arch(options.get_generation_parser()))


# Run generation only when executed as a script, not when imported.
if __name__ == "__main__":
    cli_main()