import os
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM

from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS
from utils import load_audio, speech_adjust, volumn_adjust


class StepAudio:
    def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
        # Load optimus_ths for flash attention. Make sure LD_LIBRARY_PATH includes
        # `nvidia/cuda_nvrtc/lib`; if not, set it manually, e.g.
        # LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib
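        # Illustrative only (not in the original): assuming the nvidia-cuda-nvrtc
        # wheel is installed, its lib directory can be located before launch with
        #   python -c "import os, nvidia.cuda_nvrtc; print(os.path.join(os.path.dirname(nvidia.cuda_nvrtc.__file__), 'lib'))"
        # The dynamic loader reads LD_LIBRARY_PATH at process start, so export it
        # in the shell before starting Python rather than from inside this script.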
        try:
            if torch.__version__ >= "2.5":
                torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so'))
            elif torch.__version__ >= "2.3":
                torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so'))
            elif torch.__version__ >= "2.2":
                torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so'))
            print("Loaded optimus_ths successfully; flash attention will be enabled")
        except Exception as err:
            print(f"Failed to load optimus_ths; flash attention is disabled: {err}")
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            llm_path, trust_remote_code=True
        )
        self.encoder = StepAudioTokenizer(tokenizer_path)
        self.decoder = StepAudioTTS(tts_path, self.encoder)
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    def __call__(
        self,
        messages: list,
        speaker_id: str,
        speed_ratio: float = 1.0,
        volumn_ratio: float = 1.0,
    ):
        # Render the chat history (text and/or audio tokens) into a single prompt.
        text_with_audio = self.apply_chat_template(messages)
        token_ids = self.llm_tokenizer.encode(text_with_audio, return_tensors="pt")
        # Move the prompt to the model's device; with device_map="auto" the
        # tokenizer output starts on CPU.
        token_ids = token_ids.to(self.llm.device)
        outputs = self.llm.generate(
            token_ids, max_new_tokens=2048, temperature=0.7, top_p=0.9, do_sample=True
        )
        # Keep only the newly generated tokens, dropping the trailing end-of-turn token.
        output_token_ids = outputs[:, token_ids.shape[-1] : -1].tolist()[0]
        output_text = self.llm_tokenizer.decode(output_token_ids)
        # Synthesize speech from the generated text with the requested speaker.
        output_audio, sr = self.decoder(output_text, speaker_id)
        if speed_ratio != 1.0:
            output_audio = speech_adjust(output_audio, sr, speed_ratio)
        if volumn_ratio != 1.0:
            output_audio = volumn_adjust(output_audio, volumn_ratio)
        return output_text, output_audio, sr

    def encode_audio(self, audio_path):
        # Tokenize a waveform file into audio tokens that can be embedded in the prompt.
        audio_wav, sr = load_audio(audio_path)
        audio_tokens = self.encoder(audio_wav, sr)
        return audio_tokens

    def apply_chat_template(self, messages: list):
        # Wrap each turn in <|BOT|>{role}\n ... <|EOT|> and leave the prompt open
        # at the assistant turn so the LLM continues from there.
        text_with_audio = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"
            if isinstance(content, str):
                text_with_audio += f"<|BOT|>{role}\n{content}<|EOT|>"
            elif isinstance(content, dict):
                if content["type"] == "text":
                    text_with_audio += f"<|BOT|>{role}\n{content['text']}<|EOT|>"
                elif content["type"] == "audio":
                    audio_tokens = self.encode_audio(content["audio"])
                    text_with_audio += f"<|BOT|>{role}\n{audio_tokens}<|EOT|>"
            elif content is None:
                text_with_audio += f"<|BOT|>{role}\n"
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        if not text_with_audio.endswith("<|BOT|>assistant\n"):
            text_with_audio += "<|BOT|>assistant\n"
        return text_with_audio
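
# Illustrative only (not in the original): for a single text turn such as
#   [{"role": "user", "content": "Hello"}]
# apply_chat_template renders
#   <|BOT|>human\nHello<|EOT|><|BOT|>assistant\n
# closing the user turn with <|EOT|> and leaving the assistant turn open so the
# LLM generates the reply from there.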

if __name__ == "__main__":
    model = StepAudio(
        tokenizer_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-encoder",
        tts_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-decoder",
        llm_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-v18",
    )
    # Ensure the output directory exists before saving.
    os.makedirs("output", exist_ok=True)
    # Example 1: text prompt in, text plus synthesized speech out.
    # The Chinese prompt reads: "Hello, I'm your friend, my name is Xiaoming,
    # what's your name?"
    text, audio, sr = model(
        [{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}],
        "闫雨婷",
    )
    torchaudio.save("output/output_e2e_tqta.wav", audio, sr)
    # Example 2: audio prompt in (the file generated above), text plus speech out.
    text, audio, sr = model(
        [
            {
                "role": "user",
                "content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"},
            }
        ],
        "闫雨婷",
    )
    torchaudio.save("output/output_e2e_aqta.wav", audio, sr)