Upload 6 files
- audiocraft/metrics/chroma_cosinesim.py +72 -0
- audiocraft/metrics/clap_consistency.py +84 -0
- audiocraft/metrics/fad.py +329 -0
- audiocraft/metrics/kld.py +220 -0
- audiocraft/metrics/rvm.py +110 -0
- audiocraft/metrics/visqol.py +216 -0
audiocraft/metrics/chroma_cosinesim.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchmetrics

from ..data.audio_utils import convert_audio
from ..modules.chroma import ChromaExtractor


class ChromaCosineSimilarityMetric(torchmetrics.Metric):
    """Chroma cosine similarity metric.

    This metric extracts a chromagram for a reference waveform and
    a generated waveform and compares each frame using the cosine similarity
    function. The output is the mean cosine similarity.

    Args:
        sample_rate (int): Sample rate used by the chroma extractor.
        n_chroma (int): Number of chroma used by the chroma extractor.
        radix2_exp (int): Exponent for the chroma extractor.
        argmax (bool): Whether the chroma extractor uses argmax.
        eps (float): Epsilon for cosine similarity computation.
    """
    def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
        super().__init__()
        self.chroma_sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.eps = eps
        self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
                                                radix2_exp=radix2_exp, argmax=argmax)
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
        if preds.size(0) == 0:
            return

        assert preds.shape == targets.shape, (
            f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
        assert preds.size(0) == sizes.size(0), (
            f"Number of items in preds ({preds.shape}) mismatch "
            f"with sizes ({sizes.shape})")
        assert preds.size(0) == sample_rates.size(0), (
            f"Number of items in preds ({preds.shape}) mismatch "
            f"with sample_rates ({sample_rates.shape})")
        assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"

        device = self.weight.device
        preds, targets = preds.to(device), targets.to(device)  # type: ignore
        sample_rate = sample_rates[0].item()
        preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
        targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
        gt_chroma = self.chroma_extractor(targets)
        gen_chroma = self.chroma_extractor(preds)
        chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
        for i in range(len(gt_chroma)):
            t = int(chroma_lens[i].item())
            cosine_sim = torch.nn.functional.cosine_similarity(
                gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
            self.cosine_sum += cosine_sim.sum(dim=0)  # type: ignore
            self.weight += torch.tensor(t)  # type: ignore

    def compute(self) -> float:
        """Computes the average cosine similarity across all generated/target chromagram pairs."""
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        return (self.cosine_sum / self.weight).item()  # type: ignore
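Usage sketch (not part of the uploaded file): a minimal example of driving the metric above, assuming the audiocraft package is importable; the extractor settings (`n_chroma=12`, `radix2_exp=12`) are illustrative choices, not values fixed by this file.
```
import torch
from audiocraft.metrics.chroma_cosinesim import ChromaCosineSimilarityMetric

# Illustrative extractor settings (assumption): 12 chroma bins, 2**12 FFT at 32 kHz.
metric = ChromaCosineSimilarityMetric(sample_rate=32_000, n_chroma=12, radix2_exp=12, argmax=False)

preds = torch.randn(2, 1, 32_000)              # generated audio [B, C, T]
targets = torch.randn(2, 1, 32_000)            # reference audio [B, C, T]
sizes = torch.tensor([32_000, 32_000])         # valid length of each item
sample_rates = torch.tensor([32_000, 32_000])  # one sample rate per item (must all match)

metric.update(preds, targets, sizes, sample_rates)
print(metric.compute())  # mean chroma cosine similarity over all valid frames
```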
audiocraft/metrics/clap_consistency.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
import typing as tp

import torch
import torchmetrics
from transformers import RobertaTokenizer  # type: ignore

from ..data.audio_utils import convert_audio
from ..environment import AudioCraftEnvironment
from ..utils.utils import load_clap_state_dict

try:
    import laion_clap  # type: ignore
except ImportError:
    laion_clap = None


class TextConsistencyMetric(torchmetrics.Metric):
    """Text consistency metric measuring consistency between audio and text pairs."""

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        raise NotImplementedError("implement how to update the metric from the audio and text pairs.")

    def compute(self):
        raise NotImplementedError("implement how to compute the final metric score.")


class CLAPTextConsistencyMetric(TextConsistencyMetric):
    """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).

    This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
    or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).

    As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
    similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
    well as the generated audio based on them, and define the MCC metric as the average cosine similarity
    between these embeddings.

    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
    """
    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
        super().__init__()
        if laion_clap is None:
            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
        self._initialize_model(model_path, model_arch, enable_fusion)

    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
        model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
        self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
        self.model_sample_rate = 48_000
        load_clap_state_dict(self.model, model_path)
        self.model.eval()

    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
        # we use the default params from the CLAP module here as well
        return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
        assert audio.size(0) == len(text), "Number of audio and text samples should match"
        assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
        sample_rate = int(sample_rates[0].item())
        # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
        audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
        audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
        text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
        # cosine similarity between the text and the audio embedding
        cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
        self.cosine_sum += cosine_sim.sum(dim=0)
        self.weight += torch.tensor(cosine_sim.size(0))

    def compute(self):
        """Computes the average cosine similarity across all audio/text pairs."""
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        return (self.cosine_sum / self.weight).item()  # type: ignore
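Usage sketch (not part of the uploaded file), assuming `laion_clap` and `transformers` are installed; the checkpoint path below is a placeholder for any LAION-CLAP checkpoint resolvable by `AudioCraftEnvironment`.
```
import torch
from audiocraft.metrics.clap_consistency import CLAPTextConsistencyMetric

# Placeholder checkpoint path (assumption), to be replaced with a real CLAP checkpoint.
metric = CLAPTextConsistencyMetric(model_path="/checkpoints/clap_music.pt")

audio = torch.randn(2, 1, 48_000)          # generated audio [B, C, T]
text = ["happy rock", "sad piano melody"]  # the prompts used to generate it
sizes = torch.tensor([48_000, 48_000])
sample_rates = torch.tensor([48_000, 48_000])

metric.update(audio, text, sizes, sample_rates)
print(metric.compute())  # average CLAP cosine similarity between audio and text embeddings
```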
audiocraft/metrics/fad.py
ADDED
@@ -0,0 +1,329 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
import os
import subprocess
import tempfile
import typing as tp

from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
import flashy
import torch
import torchmetrics

from ..environment import AudioCraftEnvironment


logger = logging.getLogger(__name__)

VGGISH_SAMPLE_RATE = 16_000
VGGISH_CHANNELS = 1


class FrechetAudioDistanceMetric(torchmetrics.Metric):
    """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.

    From: D.C. Dowson & B.V. Landau, The Fréchet distance between
    multivariate normal distributions
    https://doi.org/10.1016/0047-259X(82)90077-X
    The Fréchet distance between two multivariate gaussians,
    `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
        d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
            = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
                - 2 * Tr(sqrt(sigma_x*sigma_y))

    To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
    from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
    We provide the below instructions as reference but we do not guarantee further support
    for the frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.

    We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).

    1. Get the code and models following the repository instructions. We used the steps below:
            git clone [email protected]:google-research/google-research.git
            git clone [email protected]:tensorflow/models.git
            mkdir google-research/tensorflow_models
            touch google-research/tensorflow_models/__init__.py
            cp -r models/research/audioset google-research/tensorflow_models/
            touch google-research/tensorflow_models/audioset/__init__.py
            echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
                google-research/tensorflow_models/audioset/__init__.py
            # we can now remove the tensorflow models repository
            # rm -r models
            cd google-research
       Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
       assumes it is placed in the AudioCraft reference dir.

       Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
       - Update xrange for range in:
         https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
       - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
         `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
         https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
       - Update `import vggish_params as params` to `from . import vggish_params as params` in:
         https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
       - Add flag to provide a given batch size for running the AudioSet model in:
         https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
         ```
         flags.DEFINE_integer('batch_size', 64,
                              'Number of samples in the batch for AudioSet model.')
         ```
         Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
         `batch_size=FLAGS.batch_size` to the provided parameters.

    2. Follow instructions for the library installation and a valid TensorFlow installation
       ```
       # e.g. instructions from: https://www.tensorflow.org/install/pip
       conda install -c conda-forge cudatoolkit=11.8.0
       python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
       mkdir -p $CONDA_PREFIX/etc/conda/activate.d
       echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
           >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
       echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
           >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
       source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
       # Verify install: on a machine with GPU device
       python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
       ```

       Now install frechet_audio_distance required dependencies:
       ```
       # We assume we already have TensorFlow installed from the above steps
       pip install apache-beam numpy scipy tf_slim
       ```

       Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
       (you may want to specify the --model_ckpt flag pointing to the model's path).

    3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
       and Tensorflow library path from the above installation steps:
        export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
        export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"

        e.g. assuming we have installed everything in a dedicated conda env
        with python 3.10 that is currently active:
        export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
        export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"

        Finally you may want to export the following variable:
        export TF_FORCE_GPU_ALLOW_GROWTH=true
        See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth

        You can save those environment variables in your training conda env, when currently active:
        `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
        e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
        and the training conda env is named audiocraft:
        ```
        # activate training env
        conda activate audiocraft
        # get path to all envs
        CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
        # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
        touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
        echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
            $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
        echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
            $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
        # optionally:
        echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
        # you may need to reactivate the audiocraft env for this to take effect
        ```

    Args:
        bin (Path or str): Path to installed frechet audio distance code.
        model_path (Path or str): Path to Tensorflow checkpoint for the model
            used to compute statistics over the embedding beams.
        format (str): Audio format used to save files.
        log_folder (Path or str, optional): Path where to write process logs.
    """
    def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
                 format: str = "wav", batch_size: tp.Optional[int] = None,
                 log_folder: tp.Optional[tp.Union[Path, str]] = None):
        super().__init__()
        self.model_sample_rate = VGGISH_SAMPLE_RATE
        self.model_channels = VGGISH_CHANNELS
        self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
        self.format = format
        self.batch_size = batch_size
        self.bin = bin
        self.tf_env = {"PYTHONPATH": str(self.bin)}
        self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
        logger.info("Python exe for TF is %s", self.python_path)
        if 'TF_LIBRARY_PATH' in os.environ:
            self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
        if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
            self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
        logger.info("Env for TF is %r", self.tf_env)
        self.reset(log_folder)
        self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")

    def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
        """Reset torchmetrics.Metrics state."""
        log_folder = Path(log_folder or tempfile.mkdtemp())
        self.tmp_dir = log_folder / 'fad'
        self.tmp_dir.mkdir(exist_ok=True)
        self.samples_tests_dir = self.tmp_dir / 'tests'
        self.samples_tests_dir.mkdir(exist_ok=True)
        self.samples_background_dir = self.tmp_dir / 'background'
        self.samples_background_dir.mkdir(exist_ok=True)
        self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
        self.manifest_background = self.tmp_dir / 'files_background.cvs'
        self.stats_tests_dir = self.tmp_dir / 'stats_tests'
        self.stats_background_dir = self.tmp_dir / 'stats_background'
        self.counter = 0

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor,
               stems: tp.Optional[tp.List[str]] = None):
        """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
        assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
        num_samples = preds.shape[0]
        assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
        assert stems is None or num_samples == len(set(stems))
        for i in range(num_samples):
            self.total_files += 1  # type: ignore
            self.counter += 1
            wav_len = int(sizes[i].item())
            sample_rate = int(sample_rates[i].item())
            pred_wav = preds[i]
            target_wav = targets[i]
            pred_wav = pred_wav[..., :wav_len]
            target_wav = target_wav[..., :wav_len]
            stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
            # dump audio files
            try:
                pred_wav = convert_audio(
                    pred_wav.unsqueeze(0), from_rate=sample_rate,
                    to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
                audio_write(
                    self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
                    format=self.format, strategy="peak")
            except Exception as e:
                logger.error(f"Exception occurred when saving tests files for FAD computation: {repr(e)} - {e}")
            try:
                # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
                # the original audio when writing it
                target_wav = convert_audio(
                    target_wav.unsqueeze(0), from_rate=sample_rate,
                    to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
                audio_write(
                    self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
                    format=self.format, strategy="peak")
            except Exception as e:
                logger.error(f"Exception occurred when saving background files for FAD computation: {repr(e)} - {e}")

    def _get_samples_name(self, is_background: bool):
        return 'background' if is_background else 'tests'

    def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
        if is_background:
            input_samples_dir = self.samples_background_dir
            input_filename = self.manifest_background
            stats_name = self.stats_background_dir
        else:
            input_samples_dir = self.samples_tests_dir
            input_filename = self.manifest_tests
            stats_name = self.stats_tests_dir
        beams_name = self._get_samples_name(is_background)
        log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'

        logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
        with open(input_filename, "w") as fout:
            for path in Path(input_samples_dir).glob(f"*.{self.format}"):
                fout.write(f"{str(path)}\n")

        cmd = [
            self.python_path, "-m",
            "frechet_audio_distance.create_embeddings_main",
            "--model_ckpt", f"{self.model_path}",
            "--input_files", f"{str(input_filename)}",
            "--stats", f"{str(stats_name)}",
        ]
        if self.batch_size is not None:
            cmd += ["--batch_size", str(self.batch_size)]
        logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
        env = os.environ
        if gpu_index is not None:
            env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
        process = subprocess.Popen(
            cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
        return process, log_file

    def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
        cmd = [
            self.python_path, "-m", "frechet_audio_distance.compute_fad",
            "--test_stats", f"{str(self.stats_tests_dir)}",
            "--background_stats", f"{str(self.stats_background_dir)}",
        ]
        logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
        env = os.environ
        if gpu_index is not None:
            env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
        result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
        if result.returncode:
            logger.error(
                "Error with FAD computation from stats: \n %s \n %s",
                result.stdout.decode(), result.stderr.decode()
            )
            raise RuntimeError("Error while executing FAD computation from stats")
        try:
            # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
            fad_score = float(result.stdout[4:])
            return fad_score
        except Exception as e:
            raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")

    def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
        beams_name = self._get_samples_name(is_background)
        if returncode:
            with open(log_file, "r") as f:
                error_log = f.read()
                logger.error(error_log)
            os._exit(1)
        else:
            logger.info(f"Successfully computed embedding beams on {beams_name} samples.")

    def _parallel_create_embedding_beams(self, num_of_gpus: int):
        assert num_of_gpus > 0
        logger.info("Creating embeddings beams in a parallel manner on different GPUs")
        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
        tests_beams_code = tests_beams_process.wait()
        bg_beams_code = bg_beams_process.wait()
        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)

    def _sequential_create_embedding_beams(self):
        logger.info("Creating embeddings beams in a sequential manner")
        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
        tests_beams_code = tests_beams_process.wait()
        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
        bg_beams_code = bg_beams_process.wait()
        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)

    @flashy.distrib.rank_zero_only
    def _local_compute_frechet_audio_distance(self):
        """Compute Frechet Audio Distance score calling TensorFlow API."""
        num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        if num_of_gpus > 1:
            self._parallel_create_embedding_beams(num_of_gpus)
        else:
            self._sequential_create_embedding_beams()
        fad_score = self._compute_fad_score(gpu_index=0)
        return fad_score

    def compute(self) -> float:
        """Compute metrics."""
        assert self.total_files.item() > 0, "No files dumped for FAD computation!"  # type: ignore
        fad_score = self._local_compute_frechet_audio_distance()
        logger.warning(f"FAD score = {fad_score}")
        fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
        return fad_score
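Usage sketch (not part of the uploaded file). Both paths below are placeholders, and `compute()` only works once the frechet_audio_distance environment described in the class docstring (TF_PYTHON_EXE, TF_LIBRARY_PATH) is set up; `update()` alone simply dumps resampled audio files and the manifest.
```
import torch
from audiocraft.metrics.fad import FrechetAudioDistanceMetric

# Placeholder paths (assumptions): a google-research checkout and a VGGish checkpoint.
fad = FrechetAudioDistanceMetric(
    bin="/opt/google-research",
    model_path="/checkpoints/vggish_model.ckpt",
    format="wav")

preds = torch.randn(4, 1, 16_000)    # generated audio [B, C, T]
targets = torch.randn(4, 1, 16_000)  # reference audio [B, C, T]
sizes = torch.full((4,), 16_000)
sample_rates = torch.full((4,), 16_000)

fad.update(preds, targets, sizes, sample_rates)
# Shells out to the TensorFlow tool; requires the setup described in the docstring.
print(fad.compute())
```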
audiocraft/metrics/kld.py
ADDED
@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from functools import partial
import logging
import os
import typing as tp

import torch
import torchmetrics

from ..data.audio_utils import convert_audio


logger = logging.getLogger(__name__)


class _patch_passt_stft:
    """Decorator to patch torch.stft in PaSST."""
    def __init__(self):
        self.old_stft = torch.stft

    def __enter__(self):
        # return_complex is a mandatory parameter in latest torch versions
        # torch is throwing RuntimeErrors when not set
        torch.stft = partial(torch.stft, return_complex=False)

    def __exit__(self, *exc):
        torch.stft = self.old_stft


def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Computes the elementwise KL-Divergence loss between probability distributions
    from generated samples and target samples.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Epsilon value.
    Returns:
        kld (torch.Tensor): KLD loss between each generated sample and target pair.
    """
    kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
    return kl_div.sum(-1)


class KLDivergenceMetric(torchmetrics.Metric):
    """Base implementation for KL Divergence metric.

    The KL divergence is measured between probability distributions
    of class predictions returned by a pre-trained audio classification model.
    When the KL-divergence is low, the generated audio is expected to
    have acoustic characteristics similar to the reference audio,
    according to the classifier.
    """
    def __init__(self):
        super().__init__()
        self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
        """
        raise NotImplementedError("implement method to extract label distributions from the model.")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Calculates running KL-Divergence loss between batches of audio
        preds (generated) and targets (ground-truth).

        Args:
            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        """
        assert preds.shape == targets.shape
        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
        if preds_probs is not None and targets_probs is not None:
            assert preds_probs.shape == targets_probs.shape
            kld_scores = kl_divergence(preds_probs, targets_probs)
            assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
            self.kld_pq_sum += torch.sum(kld_scores)
            kld_qp_scores = kl_divergence(targets_probs, preds_probs)
            self.kld_qp_sum += torch.sum(kld_qp_scores)
            self.weight += torch.tensor(kld_scores.size(0))

    def compute(self) -> dict:
        """Computes KL-Divergence across all evaluated pred/target pairs."""
        weight: float = float(self.weight.item())  # type: ignore
        assert weight > 0, "Unable to compute with total number of comparisons <= 0"
        logger.info(f"Computing KL divergence on a total of {weight} samples")
        kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
        kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
        kld_both = kld_pq + kld_qp
        return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}


class PasstKLDivergenceMetric(KLDivergenceMetric):
    """KL-Divergence metric based on pre-trained PaSST classifier on AudioSet.

    From: PaSST: Efficient Training of Audio Transformers with Patchout
    Paper: https://arxiv.org/abs/2110.05069
    Implementation: https://github.com/kkoutini/PaSST

    Follow instructions from the github repo:
    ```
    pip install 'git+https://github.com/kkoutini/[email protected]#egg=hear21passt'
    ```

    Args:
        pretrained_length (float, optional): Audio duration used for the pretrained model.
    """
    def __init__(self, pretrained_length: tp.Optional[float] = None):
        super().__init__()
        self._initialize_model(pretrained_length)

    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
        """Initialize underlying PaSST audio classifier."""
        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
        self.min_input_frames = min_frames
        self.max_input_frames = max_frames
        self.model_sample_rate = sr
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def _load_base_model(self, pretrained_length: tp.Optional[float]):
        """Load pretrained model from PaSST."""
        try:
            if pretrained_length == 30:
                from hear21passt.base30sec import get_basic_model  # type: ignore
                max_duration = 30
            elif pretrained_length == 20:
                from hear21passt.base20sec import get_basic_model  # type: ignore
                max_duration = 20
            else:
                from hear21passt.base import get_basic_model  # type: ignore
                # Original PaSST was trained on AudioSet with 10s-long audio samples
                max_duration = 10
            min_duration = 0.15
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install hear21passt to compute KL divergence: "
                "pip install 'git+https://github.com/kkoutini/[email protected]#egg=hear21passt'"
            )
        model_sample_rate = 32_000
        max_input_frames = int(max_duration * model_sample_rate)
        min_input_frames = int(min_duration * model_sample_rate)
        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
            model = get_basic_model(mode='logits')
        return model, model_sample_rate, max_input_frames, min_input_frames

    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
        """Process audio to feed to the pretrained model."""
        wav = wav.unsqueeze(0)
        wav = wav[..., :wav_len]
        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
        wav = wav.squeeze(0)
        # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
        segments = torch.split(wav, self.max_input_frames, dim=-1)
        valid_segments = []
        for s in segments:
            # ignoring too small segments that are breaking the model inference
            if s.size(-1) > self.min_input_frames:
                valid_segments.append(s)
        return [s[None] for s in valid_segments]

    def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the pretrained model and get the predictions."""
        assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
        wav = wav.mean(dim=1)
        # PaSST is printing a lot of garbage that we are not interested in
        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
            with torch.no_grad(), _patch_passt_stft():
                logits = self.model(wav.to(self.device))
                probs = torch.softmax(logits, dim=-1)
                return probs

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
        """
        all_probs: tp.List[torch.Tensor] = []
        for i, wav in enumerate(x):
            sample_rate = int(sample_rates[i].item())
            wav_len = int(sizes[i].item())
            wav_segments = self._process_audio(wav, sample_rate, wav_len)
            for segment in wav_segments:
                probs = self._get_model_preds(segment).mean(dim=0)
                all_probs.append(probs)
        if len(all_probs) > 0:
            return torch.stack(all_probs, dim=0)
        else:
            return None
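Usage sketch (not part of the uploaded file), assuming the optional `hear21passt` dependency is installed so the PaSST classifier can be loaded.
```
import torch
from audiocraft.metrics.kld import PasstKLDivergenceMetric

# Loads the 10s PaSST AudioSet classifier (requires hear21passt).
kld = PasstKLDivergenceMetric()

preds = torch.randn(2, 1, 32_000)    # generated audio [B, C, T]
targets = torch.randn(2, 1, 32_000)  # reference audio [B, C, T]
sizes = torch.tensor([32_000, 32_000])
sample_rates = torch.tensor([32_000, 32_000])

kld.update(preds, targets, sizes, sample_rates)
scores = kld.compute()
# 'kld' / 'kld_pq', 'kld_qp' and 'kld_both' are the averaged KL scores between
# the classifier label distributions of generated and reference audio.
print(scores['kld'], scores['kld_both'])
```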
audiocraft/metrics/rvm.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
import torch
from torch import nn
import torchaudio


def db_to_scale(volume: tp.Union[float, torch.Tensor]):
    return 10 ** (volume / 20)


def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
    min_scale = db_to_scale(min_volume)
    return 20 * torch.log10(scale.clamp(min=min_scale))


class RelativeVolumeMel(nn.Module):
    """Relative volume melspectrogram measure.

    Computes a measure of distance over two mel spectrograms that is interpretable in terms
    of decibels. Given `x_ref` and `x_est`, two waveforms of shape `[*, T]`, it will
    first renormalize both by the volume of the ground truth `x_ref`.

    ..Warning:: This class returns the volume of the distortion at the spectrogram level,
        e.g. low negative values reflect lower distortion levels. For an SNR (like reported
        in the MultiBandDiffusion paper), just take `-rvm`.

    Then it computes the mel spectrograms `z_ref` and `z_est` and computes the volume of the difference
    relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
    clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
    with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
    Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
    average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
    good (for a neural network output, although sound engineers typically aim for much lower attenuations).
    Similarly, anything above +30 dB would just be completely missing the target, and there is no point
    in measuring by exactly how much it missed it. [-25, 25] is a more conservative range, but also more
    in line with what neural nets currently can achieve.

    For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
    the target and reference mel-spec is 10 dB lower than the reference mel-spec value.

    The metric can be aggregated over a given frequency band in order to have different insights for
    different regions of the spectrum. `num_aggregated_bands` controls the number of bands.

    ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
        is numerically stable when computing its gradient. We thus advise against using it as a training loss.

    Args:
        sample_rate (int): Sample rate of the input audio.
        n_mels (int): Number of mel bands to use.
        n_fft (int): Number of frequency bins for the STFT.
        hop_length (int): Hop length of the STFT and the mel-spectrogram.
        min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
            the volume of `z_ref`. If the error is smaller than -25 dB of `z_ref`, then it is clamped.
        max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
        max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
            to that amount, to avoid rescaling near silence. Given in dB.
        min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
            bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
            and anything below that will be considered equally.
        num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
            For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
    """
    def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
                 hop_length: int = 128, min_relative_volume: float = -25,
                 max_relative_volume: float = 25, max_initial_gain: float = 25,
                 min_activity_volume: float = -25,
                 num_aggregated_bands: int = 4) -> None:
        super().__init__()
        self.melspec = torchaudio.transforms.MelSpectrogram(
            n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
            normalized=True, sample_rate=sample_rate, power=2)
        self.min_relative_volume = min_relative_volume
        self.max_relative_volume = max_relative_volume
        self.max_initial_gain = max_initial_gain
        self.min_activity_volume = min_activity_volume
        self.num_aggregated_bands = num_aggregated_bands

    def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
        """Compute RVM metric between estimate and reference samples.

        Args:
            estimate (torch.Tensor): Estimate sample.
            ground_truth (torch.Tensor): Reference sample.

        Returns:
            dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
            for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
        """
        min_scale = db_to_scale(-self.max_initial_gain)
        std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
        z_gt = self.melspec(ground_truth / std).sqrt()
        z_est = self.melspec(estimate / std).sqrt()

        delta = z_gt - z_est
        ref_db = scale_to_db(z_gt, self.min_activity_volume)
        delta_db = scale_to_db(delta.abs(), min_volume=-120)
        relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
        dims = list(range(relative_db.dim()))
        dims.remove(dims[-2])
        losses_per_band = relative_db.mean(dim=dims)
        aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
        metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
        metrics['rvm'] = losses_per_band.mean()
        return metrics
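Usage sketch (not part of the uploaded file): `RelativeVolumeMel` is a plain `nn.Module` rather than a `torchmetrics.Metric`, so it is called directly on an estimate/reference pair.
```
import torch
from audiocraft.metrics.rvm import RelativeVolumeMel

rvm = RelativeVolumeMel(sample_rate=24000)

ground_truth = torch.randn(1, 1, 24000)                           # reference waveform [B, C, T]
estimate = ground_truth + 0.01 * torch.randn_like(ground_truth)   # slightly distorted copy
metrics = rvm(estimate, ground_truth)
print(metrics['rvm'])    # overall relative volume of the error, in dB (lower is better)
print(metrics['rvm_0'])  # per-band scores, from low to high frequencies
```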
audiocraft/metrics/visqol.py
ADDED
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import csv
import json
import logging
from pathlib import Path
import tempfile
import typing as tp
import subprocess
import shutil

import torch
import torchaudio

logger = logging.getLogger(__name__)


class ViSQOL:
    """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.

    To learn more about ViSQOL and how to build the ViSQOL binary using bazel, please refer to the
    instructions available in the open source repository: https://github.com/google/visqol

    ViSQOL is capable of running in two modes:

    Audio Mode:
        When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
        Audio mode uses support vector regression, with the maximum range at ~4.75.

    Speech Mode:
        When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
        Input should be resampled to 16kHz.
        As part of the speech mode processing, a root mean square implementation for voice activity detection
        is performed on the reference signal to determine what parts of the signal have voice activity and
        should therefore be included in the comparison. The signal is normalized before performing the voice
        activity detection.
        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
        Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.

    For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input

    Args:
        bin (Path or str): Path to the ViSQOL binary install folder.
        mode (str): ViSQOL computation mode, expecting "audio" or "speech".
        model (str): Name of the model to use for the similarity to quality model.
        debug (bool): Whether to also get debug metrics from ViSQOL or not.
    """
    SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
    ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())

    def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
                 model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
        assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
        self.visqol_bin = str(bin)
        self.visqol_mode = mode
        self.target_sr = self._get_target_sr(self.visqol_mode)
        self.model = model
        self.debug = debug
        assert Path(self.visqol_model).exists(), \
            f"Could not find the specified model in ViSQOL install: {self.visqol_model}"

    def _get_target_sr(self, mode: str) -> int:
        # returns target sampling rate for the corresponding ViSQOL mode.
        if mode not in ViSQOL.SAMPLE_RATES_MODES:
            raise ValueError(
                f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
            )
        return ViSQOL.SAMPLE_RATES_MODES[mode]

    def _prepare_files(
        self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
    ):
        # prepare files for ViSQOL evaluation.
        assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
        assert len(ref_sig) == len(deg_sig), (
            "Expects same number of ref and degraded inputs"
            f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
        )
        # resample audio if needed
        if sr != target_sr:
            transform = torchaudio.transforms.Resample(sr, target_sr)
            pad = int(0.5 * target_sr)
            rs_ref = []
            rs_deg = []
            for i in range(len(ref_sig)):
                rs_ref_i = transform(ref_sig[i])
                rs_deg_i = transform(deg_sig[i])
                if pad_with_silence:
                    rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
                    rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
                rs_ref.append(rs_ref_i)
                rs_deg.append(rs_deg_i)
            ref_sig = torch.stack(rs_ref)
            deg_sig = torch.stack(rs_deg)
        # save audio chunks to tmp dir and create csv
        tmp_dir = Path(tempfile.mkdtemp())
        try:
            tmp_input_csv_path = tmp_dir / "input.csv"
            tmp_results_csv_path = tmp_dir / "results.csv"
            tmp_debug_json_path = tmp_dir / "debug.json"
            with open(tmp_input_csv_path, "w") as csv_file:
                csv_writer = csv.writer(csv_file)
                csv_writer.writerow(["reference", "degraded"])
                for i in range(len(ref_sig)):
                    tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
                    tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
                    torchaudio.save(
                        tmp_ref_filename,
                        torch.clamp(ref_sig[i], min=-0.99, max=0.99),
                        sample_rate=target_sr,
                        bits_per_sample=16,
                        encoding="PCM_S"
                    )
                    torchaudio.save(
                        tmp_deg_filename,
                        torch.clamp(deg_sig[i], min=-0.99, max=0.99),
                        sample_rate=target_sr,
                        bits_per_sample=16,
                        encoding="PCM_S"
                    )
                    csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
            return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
        except Exception as e:
            logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
            return tmp_dir, None, None, None

    def _flush_files(self, tmp_dir: tp.Union[Path, str]):
        # flush tmp files used to compute ViSQOL.
        shutil.rmtree(str(tmp_dir))

    def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
        # collect results for each evaluated pair and return averaged moslqo score.
        with open(results_csv_path, "r") as csv_file:
            reader = csv.DictReader(csv_file)
            moslqo_scores = [float(row["moslqo"]) for row in reader]
            if len(moslqo_scores) > 0:
                return sum(moslqo_scores) / len(moslqo_scores)
            else:
                return 0.0

    def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
        # collect debug data for the visqol inference.
        with open(debug_json_path, "r") as f:
            data = json.load(f)
            return data

    @property
    def visqol_model(self):
        return f'{self.visqol_bin}/model/{self.model}'

    def _run_visqol(
        self,
        input_csv_path: tp.Union[Path, str],
        results_csv_path: tp.Union[Path, str],
        debug_csv_path: tp.Optional[tp.Union[Path, str]],
    ):
        input_csv_path = str(input_csv_path)
        results_csv_path = str(results_csv_path)
        debug_csv_path = str(debug_csv_path)
        cmd = [
            f'{self.visqol_bin}/bazel-bin/visqol',
            '--batch_input_csv', f'{input_csv_path}',
            '--results_csv', f'{results_csv_path}'
        ]
        if debug_csv_path is not None:
            cmd += ['--output_debug', f'{debug_csv_path}']
        if self.visqol_mode == "speech":
            cmd += ['--use_speech_mode']
        cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode:
            logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
            raise RuntimeError("Error while executing visqol")
        result.check_returncode()

    def __call__(
        self,
        ref_sig: torch.Tensor,
        deg_sig: torch.Tensor,
        sr: int,
        pad_with_silence: bool = False,
    ):
        """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.

        Args:
            ref_sig (torch.Tensor): Reference signals as [B, C, T].
            deg_sig (torch.Tensor): Degraded signals as [B, C, T].
            sr (int): Sample rate of the two audio signals.
            pad_with_silence (bool): Whether to pad the file with silences as recommended
                in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
        Returns:
            float: The ViSQOL score or mean score for the batch.
        """
        logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
        tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
            ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
        )
        try:
            if input_csv and results_csv:
                self._run_visqol(
                    input_csv,
                    results_csv,
                    debug_json if self.debug else None,
                )
                mosqol = self._collect_moslqo_score(results_csv)
                return mosqol
            else:
                raise RuntimeError("Something unexpected happened when running VISQOL!")
        except Exception as e:
            logger.error("Exception occurred when running ViSQOL: %s", e)
        finally:
            self._flush_files(tmp_dir)
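Usage sketch (not part of the uploaded file); the binary path is a placeholder and must point to a bazel-built ViSQOL checkout containing `bazel-bin/visqol` and the `model/` folder.
```
import torch
from audiocraft.metrics.visqol import ViSQOL

# Placeholder install path (assumption): a ViSQOL checkout built with bazel.
visqol = ViSQOL(bin="/opt/visqol", mode="audio")

ref = torch.randn(2, 1, 48_000)           # reference signals [B, C, T]
deg = ref + 0.01 * torch.randn_like(ref)  # degraded signals [B, C, T]
score = visqol(ref, deg, sr=48_000, pad_with_silence=True)
print(score)  # mean MOS-LQO over the batch
```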