Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| from pathlib import Path | |
| import os | |
| import subprocess | |
| import tempfile | |
| import typing as tp | |
| from audiocraft.data.audio import audio_write | |
| from audiocraft.data.audio_utils import convert_audio | |
| import flashy | |
| import torch | |
| import torchmetrics | |
| from ..environment import AudioCraftEnvironment | |
logger = logging.getLogger(__name__)

# Target format of the audio dumped for the VGGish-based FAD tool: every sample is
# resampled to this rate and downmixed to this number of channels before being written.
VGGISH_SAMPLE_RATE = 16000
VGGISH_CHANNELS = 1
class FrechetAudioDistanceMetric(torchmetrics.Metric):
    r"""Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.

    From: D.C. Dowson & B.V. Landau The Fréchet distance between
    multivariate normal distributions
    https://doi.org/10.1016/0047-259X(82)90077-X
    The Fréchet distance between two multivariate gaussians,
    `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`:
    d^2 = ||mu_x - mu_y||^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
        = ||mu_x - mu_y||^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr(sqrt(sigma_x*sigma_y))

    The score is computed out of process: `update` dumps the predicted ("tests") and reference
    ("background") audio to disk, and `compute` runs the frechet_audio_distance scripts with a
    separate Python interpreter to build embedding statistics and then the FAD score.

    To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
    from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
    We provide the instructions below as a reference, but we do not provide further support
    for the frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.

        We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).

        1. Get the code and models following the repository instructions. We used the steps below:
                git clone git@github.com:google-research/google-research.git
                git clone git@github.com:tensorflow/models.git
                mkdir google-research/tensorflow_models
                touch google-research/tensorflow_models/__init__.py
                cp -r models/research/audioset google-research/tensorflow_models/
                touch google-research/tensorflow_models/audioset/__init__.py
                echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
                    google-research/tensorflow_models/audioset/__init__.py
                # we can now remove the tensorflow models repository
                # rm -r models
                cd google-research
           Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
           assumes it is placed in the AudioCraft reference dir.

           Note that we made the following changes for the code to work with TensorFlow 2.X and python 3:
           - Replace xrange with range in:
             https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
           - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
             `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
             https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
           - Update `import vggish_params as params` to `from . import vggish_params as params` in:
             https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
           - Add a flag to provide a given batch size for running the AudioSet model in:
             https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
             ```
             flags.DEFINE_integer('batch_size', 64,
                                  'Number of samples in the batch for AudioSet model.')
             ```
             Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
             `batch_size=FLAGS.batch_size` to the provided parameters.

        2. Follow the instructions for the library installation and a valid TensorFlow installation
           ```
           # e.g. instructions from: https://www.tensorflow.org/install/pip
           conda install -c conda-forge cudatoolkit=11.8.0
           python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
           mkdir -p $CONDA_PREFIX/etc/conda/activate.d
           echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
             >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
           echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
             >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
           source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
           # Verify install: on a machine with GPU device
           python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
           ```

           Now install the frechet_audio_distance required dependencies:
           ```
           # We assume we already have TensorFlow installed from the above steps
           pip install apache-beam numpy scipy tf_slim
           ```

           Finally, follow the remaining library instructions to ensure you have a working
           frechet_audio_distance setup (you may want to specify the --model_ckpt flag pointing
           to the model's path).

        3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python
           executable and the TensorFlow library path from the above installation steps:
            export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
            export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"

            e.g. assuming we have installed everything in a dedicated conda env
            with python 3.10 that is currently active:
            export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
            export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"

            Finally you may want to export the following variable:
            export TF_FORCE_GPU_ALLOW_GROWTH=true
            See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth

            You can save those environment variables in your training conda env, when currently active:
            `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
            e.g. assuming the env with TensorFlow and frechet_audio_distance installed is named ac_eval,
            and the training conda env is named audiocraft:
            ```
            # activate training env
            conda activate audiocraft
            # get path to all envs
            CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
            # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
            touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
            echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
                $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
            echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
                $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
            # optionally:
            echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
            # you may need to reactivate the audiocraft env for this to take effect
            ```

    Args:
        bin (Path or str): Path to installed frechet audio distance code.
        model_path (Path or str): Path to Tensorflow checkpoint for the model
            used to compute statistics over the embedding beams.
        format (str): Audio format used to save files.
        batch_size (int, optional): Batch size forwarded to the embeddings creation script through
            its `--batch_size` flag (requires the patch described in step 1). Not passed when None.
        log_folder (Path or str, optional): Folder where audio files, manifests and process logs
            are written. A temporary directory is created when not provided.
    """
    def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
                 format: str = "wav", batch_size: tp.Optional[int] = None,
                 log_folder: tp.Optional[tp.Union[Path, str]] = None):
        super().__init__()
        self.model_sample_rate = VGGISH_SAMPLE_RATE
        self.model_channels = VGGISH_CHANNELS
        self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
        self.format = format
        self.batch_size = batch_size
        self.bin = bin
        # Variables layered on top of the current environment for every TensorFlow subprocess:
        # PYTHONPATH makes the frechet_audio_distance package importable by `self.python_path`.
        self.tf_env = {"PYTHONPATH": str(self.bin)}
        self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
        logger.info("Python exe for TF is %s", self.python_path)
        if 'TF_LIBRARY_PATH' in os.environ:
            self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
        if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
            self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
        logger.info("Env for TF is %r", self.tf_env)
        self.reset(log_folder)
        self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")

    def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
        """Reset torchmetrics.Metrics state and set up a fresh folder layout for the dumped audio.

        Args:
            log_folder (Path or str, optional): Folder under which a `fad` subfolder is created
                (parents included). A new temporary directory is used when not provided.
        """
        # Reset the registered torchmetrics states (e.g. `total_files`) so that a reset metric
        # does not report files that belong to a previous, discarded folder.
        super().reset()
        log_folder = Path(log_folder or tempfile.mkdtemp())
        self.tmp_dir = log_folder / 'fad'
        self.tmp_dir.mkdir(parents=True, exist_ok=True)
        self.samples_tests_dir = self.tmp_dir / 'tests'
        self.samples_tests_dir.mkdir(exist_ok=True)
        self.samples_background_dir = self.tmp_dir / 'background'
        self.samples_background_dir.mkdir(exist_ok=True)
        self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
        self.manifest_background = self.tmp_dir / 'files_background.cvs'
        self.stats_tests_dir = self.tmp_dir / 'stats_tests'
        self.stats_background_dir = self.tmp_dir / 'stats_background'
        # Local counter used to build unique default file stems on this rank.
        self.counter = 0

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor,
               stems: tp.Optional[tp.List[str]] = None):
        """Update torchmetrics.Metrics by dumping the predicted and target audio to disk.

        Each pair is trimmed to its valid length, converted to the model sample rate and
        channel count, and written to the tests (preds) / background (targets) folders.
        Errors while writing a file are logged and do not interrupt the update.

        Args:
            preds (torch.Tensor): Batch of predicted audio, same shape as `targets`.
            targets (torch.Tensor): Batch of reference audio.
            sizes (torch.Tensor): Valid length (along the last dimension) of each batch item.
            sample_rates (torch.Tensor): Sample rate of each batch item.
            stems (list of str, optional): Unique file stem per batch item. Defaults to a
                name built from a local counter and the distributed rank.
        """
        assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
        num_samples = preds.shape[0]
        assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
        assert stems is None or num_samples == len(set(stems))
        for i in range(num_samples):
            self.total_files += 1  # type: ignore
            self.counter += 1
            wav_len = int(sizes[i].item())
            sample_rate = int(sample_rates[i].item())
            pred_wav = preds[i][..., :wav_len]
            target_wav = targets[i][..., :wav_len]
            stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
            # dump audio files
            try:
                pred_wav = convert_audio(
                    pred_wav.unsqueeze(0), from_rate=sample_rate,
                    to_rate=self.model_sample_rate, to_channels=self.model_channels).squeeze(0)
                audio_write(
                    self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
                    format=self.format, strategy="peak")
            except Exception as e:
                logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
            try:
                # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
                # the original audio when writing it
                target_wav = convert_audio(
                    target_wav.unsqueeze(0), from_rate=sample_rate,
                    to_rate=self.model_sample_rate, to_channels=self.model_channels).squeeze(0)
                audio_write(
                    self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
                    format=self.format, strategy="peak")
            except Exception as e:
                logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")

    def _get_samples_name(self, is_background: bool):
        """Return the human-readable name of a samples set ('background' or 'tests')."""
        return 'background' if is_background else 'tests'

    def _get_subprocess_env(self, gpu_index: tp.Optional[int] = None) -> tp.Dict[str, str]:
        """Build the environment of a TensorFlow subprocess.

        A copy of `os.environ` is used so that pinning `CUDA_VISIBLE_DEVICES` for the child
        never leaks into the current process (and into every subprocess it later spawns).
        `self.tf_env` entries take precedence over the inherited variables.
        """
        env = dict(os.environ)
        if gpu_index is not None:
            env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
        env.update(self.tf_env)
        return env

    def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
        """Write the manifest of dumped files and launch the embeddings/statistics subprocess.

        Args:
            is_background (bool): Process the background (targets) samples if True,
                the tests (preds) samples otherwise.
            gpu_index (int, optional): GPU made visible to the subprocess, if any.
        Returns:
            tuple: The running `subprocess.Popen` and the path of its combined stdout/stderr log.
        """
        if is_background:
            input_samples_dir = self.samples_background_dir
            input_filename = self.manifest_background
            stats_name = self.stats_background_dir
        else:
            input_samples_dir = self.samples_tests_dir
            input_filename = self.manifest_tests
            stats_name = self.stats_tests_dir
        beams_name = self._get_samples_name(is_background)
        log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'

        logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
        with open(input_filename, "w") as fout:
            # Sorted for a deterministic manifest across runs.
            for path in sorted(Path(input_samples_dir).glob(f"*.{self.format}")):
                fout.write(f"{path}\n")

        cmd = [
            self.python_path, "-m",
            "frechet_audio_distance.create_embeddings_main",
            "--model_ckpt", f"{self.model_path}",
            "--input_files", f"{str(input_filename)}",
            "--stats", f"{str(stats_name)}",
        ]
        if self.batch_size is not None:
            cmd += ["--batch_size", str(self.batch_size)]
        logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
        env = self._get_subprocess_env(gpu_index)
        # The child inherits its own copy of the log descriptor, so the parent's handle can be
        # closed right after the launch instead of being leaked.
        with open(log_file, "w") as log_handle:
            process = subprocess.Popen(cmd, stdout=log_handle, env=env, stderr=subprocess.STDOUT)
        return process, log_file

    @staticmethod
    def _parse_fad_score(stdout: str) -> float:
        """Extract the score from the `compute_fad` output, which contains a `FAD: <value>` line.

        Raises:
            RuntimeError: If no parsable `FAD:` line is found.
        """
        prefix = "FAD:"
        try:
            # Scan from the end: the score line is the last thing of interest printed.
            for line in reversed(stdout.splitlines()):
                line = line.strip()
                if line.startswith(prefix):
                    return float(line[len(prefix):])
            raise ValueError(f"no '{prefix}' line found in output: {stdout!r}")
        except ValueError as e:
            raise RuntimeError(f"Error parsing FAD score from command stdout: {e}") from e

    def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
        """Run `compute_fad` on the tests/background statistics and return the FAD score.

        Raises:
            RuntimeError: If the subprocess fails or its output cannot be parsed.
        """
        cmd = [
            self.python_path, "-m", "frechet_audio_distance.compute_fad",
            "--test_stats", f"{str(self.stats_tests_dir)}",
            "--background_stats", f"{str(self.stats_background_dir)}",
        ]
        logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
        result = subprocess.run(cmd, env=self._get_subprocess_env(gpu_index), capture_output=True)
        if result.returncode:
            logger.error(
                "Error with FAD computation from stats: \n %s \n %s",
                result.stdout.decode(errors="replace"), result.stderr.decode(errors="replace")
            )
            raise RuntimeError("Error while executing FAD computation from stats")
        return self._parse_fad_score(result.stdout.decode(errors="replace"))

    def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
        """Report the outcome of an embedding beams subprocess.

        On a non-zero return code, the subprocess log is written to the logger and the current
        process is terminated immediately with status 1 via `os._exit` (no cleanup handlers run).
        """
        beams_name = self._get_samples_name(is_background)
        if returncode:
            logger.error("Failed to compute embedding beams on %s samples (return code %d), log file: %s",
                         beams_name, returncode, log_file)
            with open(log_file, "r") as f:
                error_log = f.read()
            logger.error(error_log)
            os._exit(1)
        else:
            logger.info(f"Successfully computed embedding beams on {beams_name} samples.")

    def _parallel_create_embedding_beams(self, num_of_gpus: int):
        """Create tests and background beams concurrently, on GPU 0 and GPU 1 respectively."""
        # Two distinct GPU indices (0 and 1) are used below, hence at least two GPUs are required.
        assert num_of_gpus > 1
        logger.info("Creating embeddings beams in a parallel manner on different GPUs")
        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
        tests_beams_code = tests_beams_process.wait()
        bg_beams_code = bg_beams_process.wait()
        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)

    def _sequential_create_embedding_beams(self):
        """Create tests then background beams one after the other, with the default device visibility."""
        logger.info("Creating embeddings beams in a sequential manner")
        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
        tests_beams_code = tests_beams_process.wait()
        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
        bg_beams_code = bg_beams_process.wait()
        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)

    def _local_compute_frechet_audio_distance(self):
        """Compute Frechet Audio Distance score calling TensorFlow API."""
        num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        if num_of_gpus > 1:
            self._parallel_create_embedding_beams(num_of_gpus)
        else:
            self._sequential_create_embedding_beams()
        fad_score = self._compute_fad_score(gpu_index=0)
        return fad_score

    def compute(self) -> float:
        """Compute the FAD score locally and broadcast rank 0's value to all ranks."""
        assert self.total_files.item() > 0, "No files dumped for FAD computation!"  # type: ignore
        fad_score = self._local_compute_frechet_audio_distance()
        logger.warning(f"FAD score = {fad_score}")
        fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
        return fad_score