Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
48811fe
1
Parent(s):
406e6ac
更新ASR模块,重构转录逻辑,新增基于Transformers的distil-whisper模型支持,并优化音频处理流程。
Browse files
examples/simple_asr.py
CHANGED
@@ -14,7 +14,7 @@ from pathlib import Path
|
|
14 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
15 |
|
16 |
from src.podcast_transcribe.audio import load_audio
|
17 |
-
from src.podcast_transcribe.asr.
|
18 |
|
19 |
logger = logging.getLogger("asr_example")
|
20 |
|
@@ -43,7 +43,7 @@ def main():
|
|
43 |
|
44 |
# 进行转录
|
45 |
print("开始转录...")
|
46 |
-
result = transcribe_audio(audio, model_name=model, device=device)
|
47 |
|
48 |
# 输出结果
|
49 |
print("\n转录结果:")
|
|
|
14 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
15 |
|
16 |
from src.podcast_transcribe.audio import load_audio
|
17 |
+
from src.podcast_transcribe.asr.asr_router import transcribe_audio
|
18 |
|
19 |
logger = logging.getLogger("asr_example")
|
20 |
|
|
|
43 |
|
44 |
# 进行转录
|
45 |
print("开始转录...")
|
46 |
+
result = transcribe_audio(audio, provider="distil_whisper_transformers", model_name=model, device=device)
|
47 |
|
48 |
# 输出结果
|
49 |
print("\n转录结果:")
|
src/podcast_transcribe/asr/asr_base.py
CHANGED
@@ -46,6 +46,54 @@ class BaseTranscriber:
|
|
46 |
"""
|
47 |
raise NotImplementedError("子类必须实现_load_model方法")
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
|
50 |
"""
|
51 |
准备音频数据
|
|
|
46 |
"""
|
47 |
raise NotImplementedError("子类必须实现_load_model方法")
|
48 |
|
49 |
+
def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
|
50 |
+
"""
|
51 |
+
转录音频,针对distil-whisper模型取消分块处理,直接处理整个音频。
|
52 |
+
|
53 |
+
参数:
|
54 |
+
audio: 要转录的AudioSegment对象
|
55 |
+
|
56 |
+
返回:
|
57 |
+
TranscriptionResult对象,包含转录结果
|
58 |
+
"""
|
59 |
+
logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频") # 移除了模型名称,因为基类不知道具体模型
|
60 |
+
|
61 |
+
# 直接处理整个音频,不进行分块
|
62 |
+
processed_audio = self._prepare_audio(audio)
|
63 |
+
samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
|
64 |
+
|
65 |
+
try:
|
66 |
+
model_result = self._perform_transcription(samples)
|
67 |
+
text = self._get_text_from_result(model_result)
|
68 |
+
segments = self._convert_segments(model_result)
|
69 |
+
language = self._detect_language(text)
|
70 |
+
|
71 |
+
logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
|
72 |
+
return TranscriptionResult(text=text, segments=segments, language=language)
|
73 |
+
except Exception as e:
|
74 |
+
logger.error(f"转录失败: {str(e)}", exc_info=True)
|
75 |
+
raise RuntimeError(f"转录失败: {str(e)}")
|
76 |
+
|
77 |
+
def _get_text_from_result(self, result):
|
78 |
+
"""
|
79 |
+
从结果中获取文本
|
80 |
+
|
81 |
+
参数:
|
82 |
+
result: 模型的转录结果
|
83 |
+
|
84 |
+
返回:
|
85 |
+
转录的文本
|
86 |
+
"""
|
87 |
+
return result.get("text", "")
|
88 |
+
|
89 |
+
def _perform_transcription(self, audio_data):
|
90 |
+
"""执行转录的抽象方法,由子类实现"""
|
91 |
+
raise NotImplementedError("子类必须实现_perform_transcription方法")
|
92 |
+
|
93 |
+
def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
|
94 |
+
"""将模型结果转换为分段的抽象方法,由子类实现"""
|
95 |
+
raise NotImplementedError("子类必须实现_convert_segments方法")
|
96 |
+
|
97 |
def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
|
98 |
"""
|
99 |
准备音频数据
|
src/podcast_transcribe/asr/{asr_distil_whisper.py → asr_distil_whisper_transformers.py}
RENAMED
@@ -1,5 +1,5 @@
|
|
1 |
"""
|
2 |
-
基于
|
3 |
"""
|
4 |
|
5 |
import os
|
@@ -15,145 +15,7 @@ from .asr_base import BaseTranscriber, TranscriptionResult
|
|
15 |
logger = logging.getLogger("asr")
|
16 |
|
17 |
|
18 |
-
class
|
19 |
-
"""抽象基类:Distil Whisper转录器的共享实现"""
|
20 |
-
|
21 |
-
def __init__(
|
22 |
-
self,
|
23 |
-
model_name: str,
|
24 |
-
**kwargs
|
25 |
-
):
|
26 |
-
"""
|
27 |
-
初始化转录器
|
28 |
-
|
29 |
-
参数:
|
30 |
-
model_name: 模型名称
|
31 |
-
**kwargs: 其他参数
|
32 |
-
"""
|
33 |
-
super().__init__(model_name=model_name, **kwargs)
|
34 |
-
|
35 |
-
def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
|
36 |
-
"""
|
37 |
-
转录音频,针对distil-whisper模型取消分块处理,直接处理整个音频。
|
38 |
-
|
39 |
-
参数:
|
40 |
-
audio: 要转录的AudioSegment对象
|
41 |
-
chunk_duration_s: 分块处理的块时长(秒)- 此参数被忽略
|
42 |
-
overlap_s: 分块间的重叠时长(秒)- 此参数被忽略
|
43 |
-
|
44 |
-
返回:
|
45 |
-
TranscriptionResult对象,包含转录结果
|
46 |
-
"""
|
47 |
-
logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频(distil-whisper模型)")
|
48 |
-
|
49 |
-
# 直接处理整个音频,不进行分块
|
50 |
-
processed_audio = self._prepare_audio(audio)
|
51 |
-
samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
|
52 |
-
|
53 |
-
try:
|
54 |
-
model_result = self._perform_transcription(samples)
|
55 |
-
text = self._get_text_from_result(model_result)
|
56 |
-
segments = self._convert_segments(model_result)
|
57 |
-
language = self._detect_language(text)
|
58 |
-
|
59 |
-
logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
|
60 |
-
return TranscriptionResult(text=text, segments=segments, language=language)
|
61 |
-
except Exception as e:
|
62 |
-
logger.error(f"转录失败: {str(e)}", exc_info=True)
|
63 |
-
raise RuntimeError(f"转录失败: {str(e)}")
|
64 |
-
|
65 |
-
def _get_text_from_result(self, result):
|
66 |
-
"""
|
67 |
-
从结果中获取文本
|
68 |
-
|
69 |
-
参数:
|
70 |
-
result: 模型的转录结果
|
71 |
-
|
72 |
-
返回:
|
73 |
-
转录的文本
|
74 |
-
"""
|
75 |
-
return result.get("text", "")
|
76 |
-
|
77 |
-
def _load_model(self):
|
78 |
-
"""加载模型的抽象方法,由子类实现"""
|
79 |
-
raise NotImplementedError("子类必须实现_load_model方法")
|
80 |
-
|
81 |
-
def _perform_transcription(self, audio_data):
|
82 |
-
"""执行转录的抽象方法,由子类实现"""
|
83 |
-
raise NotImplementedError("子类必须实现_perform_transcription方法")
|
84 |
-
|
85 |
-
def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
|
86 |
-
"""将模型结果转换为分段的抽象方法,由子类实现"""
|
87 |
-
raise NotImplementedError("子类必须实现_convert_segments方法")
|
88 |
-
|
89 |
-
|
90 |
-
class MLXDistilWhisperTranscriber(DistilWhisperTranscriber):
|
91 |
-
"""使用MLX加载和运行distil-whisper模型的转录器"""
|
92 |
-
|
93 |
-
def __init__(
|
94 |
-
self,
|
95 |
-
model_name: str = "mlx-community/distil-whisper-large-v3",
|
96 |
-
):
|
97 |
-
"""
|
98 |
-
初始化转录器
|
99 |
-
|
100 |
-
参数:
|
101 |
-
model_name: 模型名称
|
102 |
-
"""
|
103 |
-
super().__init__(model_name=model_name)
|
104 |
-
|
105 |
-
def _load_model(self):
|
106 |
-
"""加载Distil Whisper MLX模型"""
|
107 |
-
try:
|
108 |
-
# 懒加载mlx-whisper
|
109 |
-
try:
|
110 |
-
import mlx_whisper
|
111 |
-
except ImportError:
|
112 |
-
raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
|
113 |
-
|
114 |
-
logger.info(f"开始加载模型 {self.model_name}")
|
115 |
-
self.model = mlx_whisper.load_models.load_model(self.model_name)
|
116 |
-
logger.info(f"模型加载成功")
|
117 |
-
except Exception as e:
|
118 |
-
logger.error(f"加载模型失败: {str(e)}", exc_info=True)
|
119 |
-
raise RuntimeError(f"加载模型失败: {str(e)}")
|
120 |
-
|
121 |
-
def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
|
122 |
-
"""
|
123 |
-
将模型的分段结果转换为所需格式
|
124 |
-
|
125 |
-
参数:
|
126 |
-
result: 模型返回的结果
|
127 |
-
|
128 |
-
返回:
|
129 |
-
转换后的分段列表
|
130 |
-
"""
|
131 |
-
segments = []
|
132 |
-
|
133 |
-
for segment in result.get("segments", []):
|
134 |
-
segments.append({
|
135 |
-
"start": segment.get("start", 0.0),
|
136 |
-
"end": segment.get("end", 0.0),
|
137 |
-
"text": segment.get("text", "").strip()
|
138 |
-
})
|
139 |
-
|
140 |
-
return segments
|
141 |
-
|
142 |
-
def _perform_transcription(self, audio_data):
|
143 |
-
"""
|
144 |
-
执行转录
|
145 |
-
|
146 |
-
参数:
|
147 |
-
audio_data: 音频数据(numpy数组)
|
148 |
-
|
149 |
-
返回:
|
150 |
-
模型的转录结果
|
151 |
-
"""
|
152 |
-
from mlx_whisper import transcribe
|
153 |
-
return transcribe(audio_data, path_or_hf_repo=self.model_name)
|
154 |
-
|
155 |
-
|
156 |
-
class TransformersDistilWhisperTranscriber(DistilWhisperTranscriber):
|
157 |
"""使用Transformers加载和运行distil-whisper模型的转录器"""
|
158 |
|
159 |
def __init__(
|
@@ -285,35 +147,27 @@ class TransformersDistilWhisperTranscriber(DistilWhisperTranscriber):
|
|
285 |
# 如果新格式失败,回退到简单调用
|
286 |
return self.pipeline(audio_data)
|
287 |
|
288 |
-
|
289 |
# 统一的接口函数
|
290 |
def transcribe_audio(
|
291 |
audio_segment: AudioSegment,
|
292 |
model_name: str = None,
|
293 |
-
backend: Literal["mlx", "transformers"] = "transformers",
|
294 |
device: str = "cpu",
|
295 |
) -> TranscriptionResult:
|
296 |
"""
|
297 |
-
使用Distil Whisper模型转录音频
|
298 |
|
299 |
参数:
|
300 |
audio_segment: 输入的AudioSegment对象
|
301 |
model_name: 使用的模型名称,如果不指定则使用默认模型
|
302 |
-
|
303 |
-
device: 推理设备,仅对transformers后端有效
|
304 |
|
305 |
返回:
|
306 |
TranscriptionResult对象,包含转录的文本、分段和语言
|
307 |
"""
|
308 |
-
logger.info(f"调用transcribe_audio
|
309 |
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
transcriber = MLXDistilWhisperTranscriber(model_name=model)
|
314 |
-
else: # transformers
|
315 |
-
default_model = "distil-whisper/distil-large-v3.5"
|
316 |
-
model = model_name or default_model
|
317 |
-
transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)
|
318 |
|
319 |
-
return transcriber.transcribe(audio_segment)
|
|
|
1 |
"""
|
2 |
+
基于Transformers实现的语音识别模块,使用distil-whisper模型
|
3 |
"""
|
4 |
|
5 |
import os
|
|
|
15 |
logger = logging.getLogger("asr")
|
16 |
|
17 |
|
18 |
+
class TransformersDistilWhisperTranscriber(BaseTranscriber):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
"""使用Transformers加载和运行distil-whisper模型的转录器"""
|
20 |
|
21 |
def __init__(
|
|
|
147 |
# 如果新格式失败,回退到简单调用
|
148 |
return self.pipeline(audio_data)
|
149 |
|
|
|
150 |
# 统一的接口函数
|
151 |
def transcribe_audio(
|
152 |
audio_segment: AudioSegment,
|
153 |
model_name: str = None,
|
|
|
154 |
device: str = "cpu",
|
155 |
) -> TranscriptionResult:
|
156 |
"""
|
157 |
+
使用Distil Whisper模型转录音频 (Transformers后端)
|
158 |
|
159 |
参数:
|
160 |
audio_segment: 输入的AudioSegment对象
|
161 |
model_name: 使用的模型名称,如果不指定则使用默认模型
|
162 |
+
device: 推理设备,'cpu'或'cuda'
|
|
|
163 |
|
164 |
返回:
|
165 |
TranscriptionResult对象,包含转录的文本、分段和语言
|
166 |
"""
|
167 |
+
logger.info(f"调用 transcribe_audio 函数 (Transformers后端),音频长度: {len(audio_segment)/1000:.2f}秒,设备: {device}")
|
168 |
|
169 |
+
default_model = "distil-whisper/distil-large-v3.5"
|
170 |
+
model = model_name or default_model
|
171 |
+
transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
+
return transcriber.transcribe(audio_segment)
|
src/podcast_transcribe/asr/asr_router.py
CHANGED
@@ -4,11 +4,10 @@ ASR模型调用路由器
|
|
4 |
"""
|
5 |
|
6 |
import logging
|
7 |
-
from typing import Dict, Any,
|
8 |
from pydub import AudioSegment
|
9 |
import spaces
|
10 |
from .asr_base import TranscriptionResult
|
11 |
-
from . import asr_distil_whisper
|
12 |
|
13 |
# 配置日志
|
14 |
logger = logging.getLogger("asr")
|
@@ -25,8 +24,8 @@ class ASRRouter:
|
|
25 |
# 定义支持的provider配置
|
26 |
self._provider_configs = {
|
27 |
"distil_whisper_transformers": {
|
28 |
-
"module_path": "
|
29 |
-
"function_name": "transcribe_audio",
|
30 |
"default_model": "distil-whisper/distil-large-v3.5",
|
31 |
"supported_params": ["model_name", "device"],
|
32 |
"description": "基于Transformers的Distil Whisper模型"
|
@@ -50,8 +49,9 @@ class ASRRouter:
|
|
50 |
module_path = self._provider_configs[provider]["module_path"]
|
51 |
logger.info(f"获取模块: {module_path}")
|
52 |
|
53 |
-
#
|
54 |
-
|
|
|
55 |
|
56 |
self._loaded_modules[provider] = module
|
57 |
logger.info(f"模块 {module_path} 获取成功")
|
@@ -98,6 +98,10 @@ class ASRRouter:
|
|
98 |
if "model_name" not in filtered_params and "model_name" in supported_params:
|
99 |
filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
|
100 |
|
|
|
|
|
|
|
|
|
101 |
return filtered_params
|
102 |
|
103 |
def transcribe(
|
@@ -179,17 +183,17 @@ def transcribe_audio(
|
|
179 |
provider: str = "distil_whisper_transformers",
|
180 |
model_name: Optional[str] = None,
|
181 |
device: str = "cpu",
|
182 |
-
backend: str = "transformers",
|
183 |
**kwargs
|
184 |
) -> TranscriptionResult:
|
|
|
|
|
|
|
185 |
# 准备参数
|
186 |
params = kwargs.copy()
|
187 |
if model_name is not None:
|
188 |
params["model_name"] = model_name
|
189 |
-
if device != "cpu":
|
190 |
params["device"] = device
|
191 |
-
if backend is not None:
|
192 |
-
params["backend"] = backend
|
193 |
|
194 |
return _router.transcribe(audio_segment, provider, **params)
|
195 |
|
|
|
4 |
"""
|
5 |
|
6 |
import logging
|
7 |
+
from typing import Dict, Any, Optional, Callable
|
8 |
from pydub import AudioSegment
|
9 |
import spaces
|
10 |
from .asr_base import TranscriptionResult
|
|
|
11 |
|
12 |
# 配置日志
|
13 |
logger = logging.getLogger("asr")
|
|
|
24 |
# 定义支持的provider配置
|
25 |
self._provider_configs = {
|
26 |
"distil_whisper_transformers": {
|
27 |
+
"module_path": ".asr_distil_whisper_transformers",
|
28 |
+
"function_name": "transcribe_audio",
|
29 |
"default_model": "distil-whisper/distil-large-v3.5",
|
30 |
"supported_params": ["model_name", "device"],
|
31 |
"description": "基于Transformers的Distil Whisper模型"
|
|
|
49 |
module_path = self._provider_configs[provider]["module_path"]
|
50 |
logger.info(f"获取模块: {module_path}")
|
51 |
|
52 |
+
# 使用 importlib 动态导入模块
|
53 |
+
import importlib
|
54 |
+
module = importlib.import_module(module_path, package=__package__)
|
55 |
|
56 |
self._loaded_modules[provider] = module
|
57 |
logger.info(f"模块 {module_path} 获取成功")
|
|
|
98 |
if "model_name" not in filtered_params and "model_name" in supported_params:
|
99 |
filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
|
100 |
|
101 |
+
# 对于 Transformers backend,如果 device 未指定,则默认为 cpu
|
102 |
+
if provider == "distil_whisper_transformers" and "device" in supported_params and "device" not in filtered_params:
|
103 |
+
filtered_params["device"] = "cpu"
|
104 |
+
|
105 |
return filtered_params
|
106 |
|
107 |
def transcribe(
|
|
|
183 |
provider: str = "distil_whisper_transformers",
|
184 |
model_name: Optional[str] = None,
|
185 |
device: str = "cpu",
|
|
|
186 |
**kwargs
|
187 |
) -> TranscriptionResult:
|
188 |
+
"""
|
189 |
+
统一的音频转录接口,通过路由器选择后端
|
190 |
+
"""
|
191 |
# 准备参数
|
192 |
params = kwargs.copy()
|
193 |
if model_name is not None:
|
194 |
params["model_name"] = model_name
|
195 |
+
if device != "cpu": # 只有当 device 不是默认值才传递,或者根据需要传递所有支持的参数
|
196 |
params["device"] = device
|
|
|
|
|
197 |
|
198 |
return _router.transcribe(audio_segment, provider, **params)
|
199 |
|