from __future__ import annotations

from typing import List

import numpy as np
import torch
from transformers import Qwen2Tokenizer, Qwen2TokenizerFast, Wav2Vec2FeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack


class MiAudioLLMProcessorKwargs(ProcessingKwargs):
    _defaults = {
        "text_kwargs": {
            "padding": True,
            "padding_side": "left",
        },
        "audio_kwargs": {},
    }


def calculate_mel_frames_dasheng(
    audio_length_samples: int,
    n_fft: int = 512,
    hop_size: int = 160,
    dasheng_subsampling: int = 4,
    center: bool = True,
    model_subsampling: int = 5,
) -> int:
    """Calculate the number of Mel-spectrogram frames."""
    if center:
        audio_length_samples = audio_length_samples + n_fft
    return (
        int(1 + ((audio_length_samples - n_fft) / hop_size))
        // dasheng_subsampling
        // model_subsampling
    )
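# A worked example for the frame count above, using the defaults and assuming
# 16 kHz input (an assumption; hop_size=160 corresponds to a 10 ms hop at that
# rate). One second of audio, i.e. 16 000 samples, gives:
#
#   audio_length_samples = 16000 + 512 = 16512   (center padding)
#   int(1 + (16512 - 512) / 160)        = 101    Mel frames
#   101 // 4                            = 25     after Dasheng's 4x subsampling
#   25 // 5                             = 5      after the model's 5x subsampling
#
# so one second of audio occupies 5 <|AUDIO|> placeholder tokens.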

class MiAudioLLMProcessor(ProcessorMixin):
    attributes = ["feature_extractor", "tokenizer"]
    valid_kwargs = [
        "chat_template",
        "audio_token",
        "audio_bos_token",
        "audio_eos_token",
    ]
    feature_extractor_class = "Wav2Vec2FeatureExtractor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        feature_extractor: Wav2Vec2FeatureExtractor | None = None,
        tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast | None = None,
        model_subsampling: int = 5,
        chat_template: str | None = None,  # TODO: can this be removed?
        audio_token: str = "<|AUDIO|>",
        audio_bos_token: str = "<|audio_bos|>",
        audio_eos_token: str = "<|audio_eos|>",
    ):
        if chat_template is None:
            chat_template = self.default_chat_template
        assert tokenizer is not None, "Tokenizer needs to be passed"
        self.audio_token = (
            tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
        )
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        self.audio_bos_token = (
            tokenizer.audio_bos_token
            if hasattr(tokenizer, "audio_bos_token")
            else audio_bos_token
        )
        self.audio_eos_token = (
            tokenizer.audio_eos_token
            if hasattr(tokenizer, "audio_eos_token")
            else audio_eos_token
        )
        self.model_subsampling = model_subsampling
        # Fix normalization: raw waveforms must reach the model unnormalized.
        if feature_extractor is not None and feature_extractor.do_normalize is True:
            feature_extractor.do_normalize = False
        super().__init__(feature_extractor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        text: List[str] | None = None,
        audio: List[np.ndarray] | List[torch.Tensor] | None = None,
        **kwargs: Unpack[MiAudioLLMProcessorKwargs],
    ) -> BatchFeature:
        if text is None:
            raise ValueError("You need to specify `text` input to process.")
        elif isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) or not isinstance(text[0], str):
            raise ValueError(
                "Invalid input text. Please provide a string, or a list of strings"
            )

        output_kwargs = self._merge_kwargs(
            MiAudioLLMProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if audio is not None:
            if isinstance(audio[0], torch.Tensor):
                # Squeeze a possible leading channel dimension first, then
                # convert to numpy for the feature extractor.
                audio = [sample_.squeeze(0) for sample_ in audio]
                audio = [sample_.numpy() for sample_ in audio]
            if not all(x_.ndim == 1 for x_ in audio):
                raise ValueError("All samples in a list must be 1D.")

            # Ensure we have as many audios as audio tokens.
            num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
            num_audios = 1 if type(audio) is np.ndarray else len(audio)
            if num_audio_tokens != num_audios:
                raise ValueError(
                    f"Found {num_audio_tokens} {self.audio_token} "
                    f"token{'s' if num_audio_tokens > 1 else ''} in provided text "
                    f"but received {num_audios} audio{'s' if num_audios > 1 else ''}"
                )

            # Some kwargs should not be changed so we can expand text with
            # audio tokens below.
            output_kwargs["audio_kwargs"]["return_attention_mask"] = True
            output_kwargs["audio_kwargs"]["padding"] = True
            output_kwargs["audio_kwargs"]["return_tensors"] = "pt"

            # Pad the waveform batch.
            audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
            # Remove the attention mask; Dasheng uses lengths instead.
            audio_feature_mask = audio_inputs.pop("attention_mask")

            expanded_text = []
            audio_lengths = audio_feature_mask.sum(-1).tolist()
            audio_inputs["audio_length"] = torch.tensor(audio_lengths).long()
            # Pass the placeholder id to the model so that it knows which id
            # marks the audio positions.
            audio_inputs["audio_token_id"] = self.audio_token_id

            for sample in text:
                replace_str = []
                while self.audio_token in sample:
                    audio_length = audio_lengths.pop(0)
                    num_audio_tokens = calculate_mel_frames_dasheng(
                        audio_length, model_subsampling=self.model_subsampling
                    )
                    expanded_audio_token = self.audio_token * num_audio_tokens

                    audio_token_start_idx = sample.find(self.audio_token)
                    audio_token_end_idx = audio_token_start_idx + len(self.audio_token)

                    has_bos = (
                        sample[
                            audio_token_start_idx
                            - len(self.audio_bos_token) : audio_token_start_idx
                        ]
                        == self.audio_bos_token
                    )
                    has_eos = (
                        sample[
                            audio_token_end_idx : audio_token_end_idx
                            + len(self.audio_eos_token)
                        ]
                        == self.audio_eos_token
                    )

                    # Wrap the expansion in bos/eos tokens unless this audio
                    # token is already surrounded by them.
                    if not has_bos and not has_eos:
                        expanded_audio_token = (
                            self.audio_bos_token
                            + expanded_audio_token
                            + self.audio_eos_token
                        )

                    replace_str.append(expanded_audio_token)
                    sample = sample.replace(self.audio_token, "<placeholder>", 1)

                while "<placeholder>" in sample:
                    sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
                expanded_text.append(sample)
            text = expanded_text

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
        inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        if hasattr(self, "_check_special_mm_tokens"):
            self._check_special_mm_tokens(text, inputs, modalities=["audio"])

        if audio is not None:
            inputs.update(audio_inputs)

        return BatchFeature(data={**inputs}, tensor_type=return_tensors)
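    # A minimal usage sketch for `__call__` (the checkpoint name and the silent
    # one-second 16 kHz waveform are hypothetical, for illustration only):
    #
    #   processor = MiAudioLLMProcessor.from_pretrained("mi-audio-llm")  # hypothetical id
    #   batch = processor(
    #       text=["Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?"],
    #       audio=[np.zeros(16000, dtype=np.float32)],
    #   )
    #   # `batch` holds the tokenizer outputs plus `audio_length` and
    #   # `audio_token_id`; the single <|AUDIO|> placeholder is expanded to one
    #   # token per subsampled Mel frame (5 for one second of audio, see the
    #   # worked example above).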
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of
        this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's
        [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this
        method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(
            dict.fromkeys(
                tokenizer_input_names + feature_extractor_input_names + ["audio_length"]
            )
        )

    @property
    # NOTE: we don't have default templates anymore, and the below is kept only
    # because the hub config is not yet updated!
    def default_chat_template(self):
        """
        This default ChatML-style template formats inputs in the form of a chat
        history. For each message in the chat history:

        * the template will output the role of the speaker followed by the content of the message.
        * content is a list of strings and audios.
        * If the content element is an audio, the template will output a sequence of <|AUDIO|> tokens

        Example:

        ```python
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
                {"type": "text", "text": "What's that sound?"},
            ]},
            {"role": "assistant", "content": "It is the sound of glass shattering."},
            {"role": "user", "content": [
                {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"},
                {"type": "text", "text": "How about this one?"},
            ]},
        ]

        result = template.render(messages=messages, add_generation_prompt=True)
        ```
        """
        # fmt: off
        return (
            "{% set audio_count = namespace(value=0) %}"
            "{% for message in messages %}"
            "{% if loop.first and message['role'] != 'system' %}"
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            "{% endif %}"
            "<|im_start|>{{ message['role'] }}\n"
            "{% if message['content'] is string %}"
            "{{ message['content'] }}<|im_end|>\n"
            "{% else %}"
            "{% for content in message['content'] %}"
            "{% if 'audio' in content or 'audio_url' in content or content['type'] == 'audio' %}"
            "{% set audio_count.value = audio_count.value + 1 %}"
            "Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
            "{% elif 'text' in content %}"
            "{{ content['text'] }}"
            "{% endif %}"
            "{% endfor %}"
            "<|im_end|>\n"
            "{% endif %}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "<|im_start|>assistant\n"
            "{% endif %}"
        )
        # fmt: on
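# Rendering `default_chat_template` for a single user turn that contains one
# audio and one text segment (with add_generation_prompt=True) produces, per
# the template above:
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
#   What's that sound?<|im_end|>
#   <|im_start|>assistant
#
# The single <|AUDIO|> token is later expanded by `__call__` to match the
# number of subsampled Mel frames of the accompanying waveform.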