Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from pathlib import Path | |
import typing as tp | |
import torch | |
import torchmetrics | |
from transformers import RobertaTokenizer # type: ignore | |
from ..data.audio_utils import convert_audio | |
from ..environment import AudioCraftEnvironment | |
from ..utils.utils import load_clap_state_dict | |
try: | |
import laion_clap # type: ignore | |
except ImportError: | |
laion_clap = None | |
class TextConsistencyMetric(torchmetrics.Metric):
    """Base class for metrics that score how consistent audio is with its paired text.

    Concrete metrics must override ``update`` (to accumulate statistics from a batch of
    audio/text pairs) and ``compute`` (to turn those statistics into the final score).
    """

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Accumulate statistics from a batch of audio/text pairs. Must be overridden."""
        message = "implement how to update the metric from the audio and text pairs."
        raise NotImplementedError(message)

    def compute(self):
        """Reduce the accumulated statistics into the final metric score. Must be overridden."""
        message = "implement how to compute the final metric score."
        raise NotImplementedError(message)
class CLAPTextConsistencyMetric(TextConsistencyMetric):
    """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).

    This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
    or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).

    As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
    similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
    well as the generated audio based on them, and define the MCC metric as the average cosine similarity
    between these embeddings.

    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP

    Args:
        model_path: Checkpoint location; resolved through ``AudioCraftEnvironment.resolve_reference_path``.
        model_arch: Audio encoder architecture, passed to ``laion_clap.CLAP_Module`` as ``amodel``.
        enable_fusion: Forwarded to ``laion_clap.CLAP_Module`` as ``enable_fusion``.

    Raises:
        ImportError: If the ``laion_clap`` package is not installed.
    """

    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
        super().__init__()
        if laion_clap is None:
            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
        # Running sum of per-pair cosine similarities and running count of pairs. Both use a
        # "sum" reduction, so they are added up across processes when the metric is synced.
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
        self._initialize_model(model_path, model_arch, enable_fusion)

    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
        """Create the RoBERTa tokenizer and load the pretrained CLAP model in eval mode."""
        model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
        self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
        # Rate at which audio is fed to CLAP; `update` resamples every batch to this rate.
        self.model_sample_rate = 48_000
        load_clap_state_dict(self.model, model_path)
        self.model.eval()

    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
        """Tokenize one or more texts into fixed-length (77-token) PyTorch tensors."""
        # we use the default params from CLAP module here as well
        return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.

        Args:
            audio: Audio batch shaped [B, C, T].
            text: One text description per audio item (``len(text) == B``).
            sizes: Not used by this implementation; the whole ``audio`` tensor is embedded.
            sample_rates: Per-item sample rates; all entries must be equal.

        Raises:
            AssertionError: If batch sizes mismatch or the sample rates differ within the batch.
        """
        assert audio.size(0) == len(text), "Number of audio and text samples should match"
        if audio.size(0) == 0:
            # Nothing to accumulate. This also avoids indexing `sample_rates[0]` on an empty batch.
            return
        assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
        sample_rate = int(sample_rates[0].item())
        # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
        audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
        # Pure inference: the metric never needs gradients through the CLAP model.
        with torch.no_grad():
            audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
            text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
            # cosine similarity between the text and the audio embedding
            cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
        self.cosine_sum += cosine_sim.sum(dim=0)
        # Add the pair count as a plain int instead of allocating a new (CPU) tensor each step.
        self.weight += cosine_sim.size(0)

    def compute(self):
        """Computes the average cosine similarity across all audio/text pairs.

        Raises:
            AssertionError: If no pairs have been accumulated yet.
        """
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        return (self.cosine_sum / self.weight).item()  # type: ignore