import io
import threading
import time
import os

import numpy as np
import torch
import torchaudio
import onnxruntime
import whisper
from funasr_detach import AutoModel

from utils import resample_audio, energy_norm_fn, trim_silence


class StepAudioTokenizer:
    def __init__(
        self,
        encoder_path,
    ):
        # Streaming Paraformer ASR encoder whose outputs are quantized into the
        # vq02 codes.
        funasr_model_path = os.path.join(
            encoder_path,
            "dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
        )
        # K-means codebook used to quantize the Paraformer encoder features.
        kms_path = os.path.join(encoder_path, "km_iter_3000.npy")
        # ONNX speech tokenizer (cosy tokenizer) that produces the vq06 codes.
        cosy_tokenizer_path = os.path.join(encoder_path, "speech_tokenizer_v1.onnx")

        self.funasr_model = AutoModel(model=funasr_model_path, model_revision="master")
        self.kms = torch.tensor(np.load(kms_path))

        providers = ["CUDAExecutionProvider"]
        session_option = onnxruntime.SessionOptions()
        session_option.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        session_option.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            cosy_tokenizer_path, sess_options=session_option, providers=providers
        )

        # Streaming configuration for the Paraformer encoder.
        self.chunk_size = [0, 4, 5]
        self.encoder_chunk_look_back = 4
        self.decoder_chunk_look_back = 1

        # Per-session caches for incremental vq02 extraction; locks allow the
        # tokenizer to be shared across threads.
        self.vq02_sessions = {}
        self.vq02_lock = threading.Lock()
        self.vq06_lock = threading.Lock()

    def __call__(self, audio, sr):
        _, vq02, vq06 = self.wav2token(audio, sr, False)
        text = self.merge_vq0206_to_token_str(vq02, vq06)
        return text

    def preprocess_wav(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        audio = resample_audio(audio, sample_rate, 16000)
        if energy_norm:
            audio = energy_norm_fn(audio)
        if enable_trim:
            audio = audio.cpu().numpy().squeeze(0)
            audio = trim_silence(audio, 16000)
            audio = torch.from_numpy(audio)
            audio = audio.unsqueeze(0)
        return audio

    def wav2token(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        audio = self.preprocess_wav(
            audio, sample_rate, enable_trim=enable_trim, energy_norm=energy_norm
        )
        # vq02 codes are offset by 65536 and vq06 codes by a further 1024, so the
        # two codebooks occupy disjoint id ranges.
        vq02_ori = self.get_vq02_code(audio)
        vq02 = [int(x) + 65536 for x in vq02_ori]
        vq06_ori = self.get_vq06_code(audio)
        vq06 = [int(x) + 65536 + 1024 for x in vq06_ori]
        # Interleave the streams: 2 vq02 tokens followed by 3 vq06 tokens per chunk.
        chunk = 1
        chunk_nums = min(len(vq06) // (3 * chunk), len(vq02) // (2 * chunk))
        speech_tokens = []
        for idx in range(chunk_nums):
            speech_tokens += vq02[idx * chunk * 2 : (idx + 1) * chunk * 2]
            speech_tokens += vq06[idx * chunk * 3 : (idx + 1) * chunk * 3]
        return speech_tokens, vq02_ori, vq06_ori

    def get_vq02_code(self, audio, session_id=None, is_final=True):
        # Serialize the waveform to an in-memory WAV for the FunASR frontend.
        _tmp_wav = io.BytesIO()
        torchaudio.save(_tmp_wav, audio, 16000, format="wav")
        _tmp_wav.seek(0)
        with self.vq02_lock:
            # Reuse the streaming encoder cache for this session, if any.
            cache = {}
            if session_id in self.vq02_sessions:
                cache = self.vq02_sessions[session_id].get("cache", {})
            res, new_cache = self.funasr_model.infer_encoder(
                input=[_tmp_wav],
                chunk_size=self.chunk_size,
                encoder_chunk_look_back=self.encoder_chunk_look_back,
                decoder_chunk_look_back=self.decoder_chunk_look_back,
                device=0,
                is_final=is_final,
                cache=cache,
            )
            # Quantize the encoder outputs against the k-means codebook.
            c_list = []
            for j, res_ in enumerate(res):
                feat = res_["enc_out"]
                if len(feat) > 0:
                    c_list = self.dump_label([feat], self.kms)[0]
            if is_final:
                if session_id in self.vq02_sessions:
                    self.vq02_sessions.pop(session_id)
            else:
                if isinstance(session_id, str) and len(session_id) > 0:
                    self.vq02_sessions[session_id] = {
                        "cache": new_cache,
                        "update_time": time.time(),
                    }
        return c_list

    def get_vq06_code(self, audio):
        def split_audio(audio, chunk_duration=480000):
            # Split into fixed-size chunks, dropping trailing fragments shorter
            # than 480 samples (30 ms at 16 kHz).
            start = 0
            chunks = []
            while start < len(audio):
                end = min(start + chunk_duration, len(audio))
                chunk = audio[start:end]
                if len(chunk) >= 480:
                    chunks.append(chunk)
                start = end
            return chunks

        with self.vq06_lock:
            audio = audio.squeeze(0)
            chunk_audios = split_audio(audio, chunk_duration=30 * 16000)  # at most 30 s per chunk
            speech_tokens = []
            for chunk in chunk_audios:
                duration = round(chunk.shape[0] / 16000, 2)
                feat = whisper.log_mel_spectrogram(chunk, n_mels=128)
                feat = feat.unsqueeze(0)
                feat_len = np.array([feat.shape[2]], dtype=np.int32)
                ort_inputs = self.ort_session.get_inputs()
                chunk_token = (
                    self.ort_session.run(
                        None,
                        {
                            ort_inputs[0].name: feat.detach().cpu().numpy(),
                            ort_inputs[1].name: feat_len,
                        },
                    )[0]
                    .flatten()
                    .tolist()
                )
                # The ONNX tokenizer emits roughly 25 tokens per second of audio.
                assert abs(len(chunk_token) - duration * 25) <= 2
                speech_tokens += chunk_token
        return speech_tokens

    def kmean_cluster(self, samples, means):
        dists = torch.cdist(samples, means)
        indices = dists.argmin(dim=1).cpu().numpy()
        return indices.tolist()

    def dump_label(self, samples, mean):
        # Concatenate all feature sequences, quantize them against the k-means
        # codebook, then split the indices back per input sequence.
        dims = samples[0].shape[-1]
        x_lens = [x.shape[1] for x in samples]
        total_len = sum(x_lens)
        x_sel = torch.FloatTensor(1, total_len, dims)
        start_len = 0
        for sample in samples:
            sample_len = sample.shape[1]
            end_len = start_len + sample_len
            x_sel[:, start_len:end_len] = sample
            start_len = end_len
        dense_x = x_sel.squeeze(0)
        indices = self.kmean_cluster(dense_x, mean)
        indices_list = []
        start_len = 0
        for x_len in x_lens:
            end_len = start_len + x_len
            indices_list.append(indices[start_len:end_len])
            start_len = end_len
        return indices_list

    def merge_vq0206_to_token_str(self, vq02, vq06):
        _vq06 = [1024 + x for x in vq06]
        result = []
        i = 0
        j = 0
        while i < len(vq02) - 1 and j < len(_vq06) - 2:
            sublist = vq02[i : i + 2] + _vq06[j : j + 3]
            result.extend(sublist)
            i += 2
            j += 3
        return "".join([f"<audio_{x}>" for x in result])