import os
import re
import json
import torchaudio
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from cosyvoice.cli.cosyvoice import CosyVoice


class RepetitionAwareLogitsProcessor(LogitsProcessor):
    """Bans the most recent token once it repeats too often inside a sliding window."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        window_size = 10
        threshold = 0.1

        window = input_ids[:, -window_size:]
        if window.shape[1] < window_size:
            return scores

        last_tokens = window[:, -1].unsqueeze(-1)
        repeat_counts = (window == last_tokens).sum(dim=1)
        repeat_ratios = repeat_counts.float() / window_size

        mask = repeat_ratios > threshold
        scores[mask, last_tokens[mask].squeeze(-1)] = float("-inf")
        return scores
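

# A minimal, self-contained sanity check of the processor above. It is not part of
# the original pipeline; the function name and the toy tensors are only illustrative.
def _demo_repetition_processor():
    # With window_size=10 and threshold=0.1, the most recent token is banned as soon
    # as it already occurs at least twice in the last 10 generated tokens.
    processor = RepetitionAwareLogitsProcessor()
    input_ids = torch.tensor([[5, 6, 7, 8, 9, 10, 11, 12, 7, 7]])  # token 7 repeats
    scores = torch.zeros(1, 16)
    scores = processor(input_ids, scores)
    assert scores[0, 7].item() == float("-inf")  # repeated token can no longer be sampled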


class StepAudioTTS:
    def __init__(
        self,
        model_path,
        encoder,
    ):
        # Language model that generates audio token ids from the tokenized chat prompt.
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        # CosyVoice decoders: one for ordinary speech, one for RAP / humming.
        self.common_cosy_model = CosyVoice(
            os.path.join(model_path, "CosyVoice-300M-25Hz")
        )
        self.music_cosy_model = CosyVoice(
            os.path.join(model_path, "CosyVoice-300M-25Hz-Music")
        )
        self.encoder = encoder
        # System prompts (Chinese) that steer the style of the generated speech:
        #   - sys_prompt_for_rap / sys_prompt_for_vocal: RAP or humming in the voice
        #     heard in the dialogue history;
        #   - sys_prompt_wo_spk / sys_prompt_with_spk: expressive narration driven by
        #     (emotion / language / humming / speed) tags in the text, the latter with
        #     a named speaker voice filled in via str.format.
        self.sys_prompt_dict = {
            "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。",
            "sys_prompt_for_vocal": "请参考对话历史里的音色,用哼唱的方式将文本内容大声唱出来。",
            "sys_prompt_wo_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。',
            "sys_prompt_with_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。',
        }
        self.register_speakers()

    def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None):
        # Optional voice cloning: register a new speaker from a reference wav on the fly.
        if clone_dict:
            (
                clone_prompt_code,
                clone_prompt_token,
                clone_prompt_token_len,
                clone_speech_feat,
                clone_speech_feat_len,
                clone_speech_embedding,
            ) = self.preprocess_prompt_wav(clone_dict["wav_path"])
            prompt_speaker = clone_dict["speaker"]
            self.speakers_info[prompt_speaker] = {
                "prompt_text": clone_dict["prompt_text"],
                "prompt_code": clone_prompt_code,
                "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16),
                "cosy_speech_feat_len": clone_speech_feat_len,
                "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16),
                "cosy_prompt_token": clone_prompt_token,
                "cosy_prompt_token_len": clone_prompt_token_len,
            }

        instruction_name = self.detect_instruction_name(text)
        # RAP and humming use the dedicated "<speaker>RAP" / "<speaker>哼唱" prompt
        # entries and the music decoder; everything else uses the common decoder.
        if instruction_name in ("RAP", "哼唱"):
            prompt_speaker_info = self.speakers_info[
                f"{prompt_speaker}{instruction_name}"
            ]
            cosy_model = self.music_cosy_model
        else:
            prompt_speaker_info = self.speakers_info[prompt_speaker]
            cosy_model = self.common_cosy_model

        if clone_dict:
            prompt_speaker = ""

        token_ids = self.tokenize(
            text,
            prompt_speaker_info["prompt_text"],
            prompt_speaker,
            prompt_speaker_info["prompt_code"],
        )
        output_ids = self.llm.generate(
            torch.tensor([token_ids]).to(torch.long).to("cuda"),
            max_length=8192,
            temperature=0.7,
            do_sample=True,
            logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]),
        )
        output_ids = output_ids[:, len(token_ids) : -1]  # skip eos token
        # Audio token ids are offset by 65536 in the LLM vocabulary; shift them back
        # before decoding. The decoded waveform is returned together with its 22050 Hz
        # sample rate.
        return (
            cosy_model.token_to_wav_offline(
                output_ids - 65536,
                prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16),
                prompt_speaker_info["cosy_speech_feat_len"],
                prompt_speaker_info["cosy_prompt_token"],
                prompt_speaker_info["cosy_prompt_token_len"],
                prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16),
            ),
            22050,
        )

    def register_speakers(self):
        self.speakers_info = {}

        with open("speakers/speakers_info.json", "r") as f:
            speakers_info = json.load(f)

        for speaker_id, prompt_text in speakers_info.items():
            prompt_wav_path = f"speakers/{speaker_id}_prompt.wav"
            (
                prompt_code,
                prompt_token,
                prompt_token_len,
                speech_feat,
                speech_feat_len,
                speech_embedding,
            ) = self.preprocess_prompt_wav(prompt_wav_path)

            self.speakers_info[speaker_id] = {
                "prompt_text": prompt_text,
                "prompt_code": prompt_code,
                "cosy_speech_feat": speech_feat.to(torch.bfloat16),
                "cosy_speech_feat_len": speech_feat_len,
                "cosy_speech_embedding": speech_embedding.to(torch.bfloat16),
                "cosy_prompt_token": prompt_token,
                "cosy_prompt_token_len": prompt_token_len,
            }
            print(f"Registered speaker: {speaker_id}")

    def detect_instruction_name(self, text):
        # Extract the leading "(...)" / "(...)" style tag from the text,
        # e.g. "(RAP)你好" -> "RAP". Returns "" when no leading tag is present.
        instruction_name = ""
        match_group = re.match(r"^([(\(][^\(\)()]*[)\)]).*$", text, re.DOTALL)
        if match_group is not None:
            instruction = match_group.group(1)
            instruction_name = instruction.strip("()()")
        return instruction_name

    def tokenize(
        self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list
    ):
        # Pick the system prompt that matches the requested style.
        rap_or_vocal = self.detect_instruction_name(text) in ("RAP", "哼唱")

        if rap_or_vocal:
            if "哼唱" in text:
                prompt = self.sys_prompt_dict["sys_prompt_for_vocal"]
            else:
                prompt = self.sys_prompt_dict["sys_prompt_for_rap"]
        elif prompt_speaker:
            prompt = self.sys_prompt_dict["sys_prompt_with_spk"].format(prompt_speaker)
        else:
            prompt = self.sys_prompt_dict["sys_prompt_wo_spk"]

        sys_tokens = self.tokenizer.encode(f"system\n{prompt}")

        history = [1]
        history.extend([4] + sys_tokens + [3])

        # Encode with a leading "\n", then drop its tokens so only the text tokens remain.
        _prefix_tokens = self.tokenizer.encode("\n")
        prompt_token_encode = self.tokenizer.encode("\n" + prompt_text)
        prompt_tokens = prompt_token_encode[len(_prefix_tokens) :]

        target_token_encode = self.tokenizer.encode("\n" + text)
        target_tokens = target_token_encode[len(_prefix_tokens) :]

        qrole_toks = self.tokenizer.encode("human\n")
        arole_toks = self.tokenizer.encode("assistant\n")

        # Chat layout: system prompt, then a (prompt text -> prompt audio codes) exchange
        # as an in-context example, then the target text with the assistant turn left open
        # for the model to fill with audio codes.
        history.extend(
            [4]
            + qrole_toks
            + prompt_tokens
            + [3]
            + [4]
            + arole_toks
            + prompt_code
            + [3]
            + [4]
            + qrole_toks
            + target_tokens
            + [3]
            + [4]
            + arole_toks
        )
        return history

    def preprocess_prompt_wav(self, prompt_wav_path: str):
        prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path)

        # Resample to 16 kHz for the speaker-embedding extractor and to 22.05 kHz
        # for the speech-feature extractor.
        prompt_wav_16k = torchaudio.transforms.Resample(
            orig_freq=prompt_wav_sr, new_freq=16000
        )(prompt_wav)
        prompt_wav_22k = torchaudio.transforms.Resample(
            orig_freq=prompt_wav_sr, new_freq=22050
        )(prompt_wav)

        speech_feat, speech_feat_len = self.common_cosy_model.frontend._extract_speech_feat(
            prompt_wav_22k
        )
        speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding(
            prompt_wav_16k
        )

        # Discrete audio codes of the prompt: prompt_code stays in the LLM vocabulary
        # space, while prompt_token is shifted back by 65536 for the CosyVoice decoder.
        prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr)
        prompt_token = torch.tensor([prompt_code], dtype=torch.long) - 65536
        prompt_token_len = torch.tensor([prompt_token.shape[1]], dtype=torch.long)

        return (
            prompt_code,
            prompt_token,
            prompt_token_len,
            speech_feat,
            speech_feat_len,
            speech_embedding,
        )
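

# A hedged usage sketch, not part of the original file. It assumes a local
# Step-Audio-TTS checkpoint directory and an `encoder` object exposing
# wav2token(wav, sample_rate) as relied on by preprocess_prompt_wav above; the
# path and speaker id below are placeholders, and token_to_wav_offline is assumed
# to return a (channels, samples) waveform tensor.
if __name__ == "__main__":
    model_path = "Step-Audio-TTS-3B"  # placeholder checkpoint directory
    encoder = ...  # supply a speech tokenizer with a wav2token(wav, sr) method
    tts = StepAudioTTS(model_path, encoder)

    # The speaker id must match a key in speakers/speakers_info.json.
    # Text: "(happy1) The weather is great today!"
    audio, sample_rate = tts("(高兴1)今天天气真好!", prompt_speaker="speaker_id")
    torchaudio.save("output.wav", audio.to(torch.float32).cpu(), sample_rate)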