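"""Step-Audio TTS pipeline.

A causal LM is prompted with chat-formatted text plus a reference speaker's
audio codes and generates audio tokens, which a CosyVoice vocoder decodes to
a 22.05 kHz waveform. Supports registered speakers, one-shot voice cloning,
and instruction tags such as (RAP) or (哼唱) ("humming") that switch synthesis
to a music-tuned vocoder.
"""
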
import json
import os
import re

import torch
import torchaudio
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList

from cosyvoice.cli.cosyvoice import CosyVoice


class RepetitionAwareLogitsProcessor(LogitsProcessor):
    """Bans a token that is starting to loop.

    If the most recent token accounts for more than ``threshold`` of the last
    ``window_size`` tokens, its logit is set to -inf so it cannot be sampled
    again on this step.
    """

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        window_size = 10
        threshold = 0.1

        window = input_ids[:, -window_size:]
        if window.shape[1] < window_size:
            # Not enough context yet; leave the scores untouched.
            return scores

        # For each sequence, count how often its latest token occurs in the window.
        last_tokens = window[:, -1].unsqueeze(-1)
        repeat_counts = (window == last_tokens).sum(dim=1)
        repeat_ratios = repeat_counts.float() / window_size

        # Mask out the repeated token wherever the ratio exceeds the threshold.
        mask = repeat_ratios > threshold
        scores[mask, last_tokens[mask].squeeze(-1)] = float("-inf")
        return scores
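

# Quick sanity check for the processor (illustrative only, not part of the
# pipeline): with the last ten tokens all equal to id 7, the repeat ratio is
# 1.0 > 0.1, so id 7 must be banned.
#
#   processor = RepetitionAwareLogitsProcessor()
#   ids = torch.full((1, 10), 7, dtype=torch.long)
#   scores = processor(ids, torch.zeros(1, 100))
#   assert scores[0, 7] == float("-inf")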


class StepAudioTTS:
    def __init__(
        self,
        model_path,
        encoder,
    ):
        # Causal LM that generates audio tokens from chat-formatted prompts.
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        # Two CosyVoice vocoders: a general-purpose one and a music-tuned one
        # used for RAP / humming instructions.
        self.common_cosy_model = CosyVoice(
            os.path.join(model_path, "CosyVoice-300M-25Hz")
        )
        self.music_cosy_model = CosyVoice(
            os.path.join(model_path, "CosyVoice-300M-25Hz-Music")
        )
        # External speech tokenizer; must expose wav2token(wav, sample_rate).
        self.encoder = encoder
        # System prompts (kept in Chinese, as the LM expects them verbatim):
        # - sys_prompt_for_rap / sys_prompt_for_vocal: rap or hum the text in
        #   the voice from the dialogue history.
        # - sys_prompt_wo_spk / sys_prompt_with_spk: voice-actor instructions
        #   listing the supported emotion / language / music / speed tags; the
        #   "_with_spk" variant additionally names the voice to use via [{}].
        self.sys_prompt_dict = {
            "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。",
            "sys_prompt_for_vocal": "请参考对话历史里的音色,用哼唱的方式将文本内容大声唱出来。",
            "sys_prompt_wo_spk": '作为一名卓越的声优演员,你的任务是根据文本中（）或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有（）或()括号,则根据文本语义内容自由演绎。',
            "sys_prompt_with_spk": '作为一名卓越的声优演员,你的任务是根据文本中（）或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有（）或()括号,则根据文本语义内容自由演绎。',
        }
        self.register_speakers()

    def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None):
        if clone_dict:
            # One-shot cloning: extract prompt features from the provided wav
            # and register them as a temporary speaker.
            (
                clone_prompt_code,
                clone_prompt_token,
                clone_prompt_token_len,
                clone_speech_feat,
                clone_speech_feat_len,
                clone_speech_embedding,
            ) = self.preprocess_prompt_wav(clone_dict["wav_path"])

            prompt_speaker = clone_dict["speaker"]
            self.speakers_info[prompt_speaker] = {
                "prompt_text": clone_dict["prompt_text"],
                "prompt_code": clone_prompt_code,
                "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16),
                "cosy_speech_feat_len": clone_speech_feat_len,
                "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16),
                "cosy_prompt_token": clone_prompt_token,
                "cosy_prompt_token_len": clone_prompt_token_len,
            }

        # RAP / humming requests use dedicated per-speaker prompts and the
        # music-tuned vocoder; everything else uses the common one.
        instruction_name = self.detect_instruction_name(text)
        if instruction_name in ("RAP", "哼唱"):
            prompt_speaker_info = self.speakers_info[
                f"{prompt_speaker}{instruction_name}"
            ]
            cosy_model = self.music_cosy_model
        else:
            prompt_speaker_info = self.speakers_info[prompt_speaker]
            cosy_model = self.common_cosy_model

        if clone_dict:
            # For cloning, the speaker name must not leak into the system
            # prompt, so fall back to the speaker-less prompt variant.
            prompt_speaker = ""

        token_ids = self.tokenize(
            text,
            prompt_speaker_info["prompt_text"],
            prompt_speaker,
            prompt_speaker_info["prompt_code"],
        )
        output_ids = self.llm.generate(
            torch.tensor([token_ids]).to(torch.long).to("cuda"),
            max_length=8192,
            temperature=0.7,
            do_sample=True,
            logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]),
        )
        output_ids = output_ids[:, len(token_ids) : -1]  # strip prompt and EOS token

        # Audio tokens occupy ids >= 65536 in the LM vocabulary; shift them
        # back into CosyVoice's range before decoding to a 22.05 kHz waveform.
        return (
            cosy_model.token_to_wav_offline(
                output_ids - 65536,
                prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16),
                prompt_speaker_info["cosy_speech_feat_len"],
                prompt_speaker_info["cosy_prompt_token"],
                prompt_speaker_info["cosy_prompt_token_len"],
                prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16),
            ),
            22050,
        )

    def register_speakers(self):
        """Load every built-in speaker's prompt text and precomputed features."""
        self.speakers_info = {}

        with open("speakers/speakers_info.json", "r") as f:
            speakers_info = json.load(f)

        for speaker_id, prompt_text in speakers_info.items():
            prompt_wav_path = f"speakers/{speaker_id}_prompt.wav"
            (
                prompt_code,
                prompt_token,
                prompt_token_len,
                speech_feat,
                speech_feat_len,
                speech_embedding,
            ) = self.preprocess_prompt_wav(prompt_wav_path)

            self.speakers_info[speaker_id] = {
                "prompt_text": prompt_text,
                "prompt_code": prompt_code,
                "cosy_speech_feat": speech_feat.to(torch.bfloat16),
                "cosy_speech_feat_len": speech_feat_len,
                "cosy_speech_embedding": speech_embedding.to(torch.bfloat16),
                "cosy_prompt_token": prompt_token,
                "cosy_prompt_token_len": prompt_token_len,
            }
            print(f"Registered speaker: {speaker_id}")

    def detect_instruction_name(self, text):
        """Return the instruction tag at the start of ``text``, if any.

        Tags are wrapped in half- or full-width parentheses, e.g. "(RAP)" or
        "（哼唱）"; untagged text yields "".
        """
        instruction_name = ""
        match_group = re.match(r"^([(（][^()（）]*[)）]).*$", text, re.DOTALL)
        if match_group is not None:
            instruction = match_group.group(1)
            instruction_name = instruction.strip("()（）")
        return instruction_name

    def tokenize(
        self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list
    ):
        # Pick the system prompt: humming, rap, named speaker, or generic.
        rap_or_vocal = self.detect_instruction_name(text) in ("RAP", "哼唱")
        if rap_or_vocal:
            if "哼唱" in text:
                prompt = self.sys_prompt_dict["sys_prompt_for_vocal"]
            else:
                prompt = self.sys_prompt_dict["sys_prompt_for_rap"]
        elif prompt_speaker:
            prompt = self.sys_prompt_dict["sys_prompt_with_spk"].format(prompt_speaker)
        else:
            prompt = self.sys_prompt_dict["sys_prompt_wo_spk"]

        sys_tokens = self.tokenizer.encode(f"system\n{prompt}")

        history = [1]
        history.extend([4] + sys_tokens + [3])

        # Encode with a leading "\n" and drop its tokens, so the text is
        # tokenized exactly as it would appear mid-template.
        _prefix_tokens = self.tokenizer.encode("\n")
        prompt_token_encode = self.tokenizer.encode("\n" + prompt_text)
        prompt_tokens = prompt_token_encode[len(_prefix_tokens) :]

        target_token_encode = self.tokenizer.encode("\n" + text)
        target_tokens = target_token_encode[len(_prefix_tokens) :]

        qrole_toks = self.tokenizer.encode("human\n")
        arole_toks = self.tokenizer.encode("assistant\n")

        # One exemplar turn (prompt text -> prompt audio codes), then the
        # target turn left open for the model to complete with audio tokens.
        history.extend(
            [4]
            + qrole_toks
            + prompt_tokens
            + [3]
            + [4]
            + arole_toks
            + prompt_code
            + [3]
            + [4]
            + qrole_toks
            + target_tokens
            + [3]
            + [4]
            + arole_toks
        )
        return history
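
    # The assembled sequence mirrors a two-turn chat; ids 1 / 4 / 3 are
    # assumed here to be BOS / turn-start / turn-end special tokens:
    #
    #   <1><4>system\n{system prompt}<3>
    #   <4>human\n{prompt_text}<3><4>assistant\n{prompt_code}<3>
    #   <4>human\n{text}<3><4>assistant\n   <- the LM continues with audio tokens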

    def preprocess_prompt_wav(self, prompt_wav_path: str):
        prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path)

        # CosyVoice expects 16 kHz audio for speaker embeddings and 22.05 kHz
        # audio for speech features.
        prompt_wav_16k = torchaudio.transforms.Resample(
            orig_freq=prompt_wav_sr, new_freq=16000
        )(prompt_wav)
        prompt_wav_22k = torchaudio.transforms.Resample(
            orig_freq=prompt_wav_sr, new_freq=22050
        )(prompt_wav)

        speech_feat, speech_feat_len = (
            self.common_cosy_model.frontend._extract_speech_feat(prompt_wav_22k)
        )
        speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding(
            prompt_wav_16k
        )

        # LM-space audio token ids; subtracting 65536 maps them into
        # CosyVoice's token range (the inverse of the shift in __call__).
        prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr)
        prompt_token = torch.tensor([prompt_code], dtype=torch.long) - 65536
        prompt_token_len = torch.tensor([prompt_token.shape[1]], dtype=torch.long)

        return (
            prompt_code,
            prompt_token,
            prompt_token_len,
            speech_feat,
            speech_feat_len,
            speech_embedding,
        )
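

# Illustrative usage (a sketch, not part of the original module; the model
# path, tag, and speaker id are hypothetical, and ``encoder`` is whatever
# speech tokenizer object implements wav2token(wav, sample_rate)):
#
#   tts = StepAudioTTS("Step-Audio-TTS-3B", encoder)
#   audio, sr = tts("(RAP)这是一个测试。", "speaker_id")
#   torchaudio.save("output.wav", audio, sr)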