Spaces: Running on Zero
Commit 5d74f79 · Parent(s): d709cdc
优化asr (Optimize ASR)
Browse files
- examples/combined_podcast_transcription.py +1 -9
- examples/combined_transcription.py +0 -8
- examples/simple_asr.py +6 -21
- examples/simple_diarization.py +2 -8
- requirements.txt +0 -1
- src/podcast_transcribe/asr/asr_base.py +1 -186
- src/podcast_transcribe/asr/asr_distil_whisper.py +319 -0
- src/podcast_transcribe/asr/asr_distil_whisper_mlx.py +0 -111
- src/podcast_transcribe/asr/asr_distil_whisper_transformers.py +0 -133
- src/podcast_transcribe/asr/asr_parakeet_mlx.py +0 -126
- src/podcast_transcribe/asr/asr_router.py +8 -65
- src/podcast_transcribe/diarization/diarization_pyannote_mlx.py +0 -3
- src/podcast_transcribe/diarization/diarization_pyannote_transformers.py +1 -4
- src/podcast_transcribe/diarization/diarizer_base.py +0 -1
- src/podcast_transcribe/diarization/diarizer_router.py +1 -4
- src/podcast_transcribe/transcriber.py +0 -11
examples/combined_podcast_transcription.py
CHANGED
@@ -20,10 +20,9 @@ def main():
     # audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
 
     # 模型配置
-    asr_model_name = "mlx-community/
+    asr_model_name = "mlx-community/"  # ASR模型名称
     diarization_model_name = "pyannote/speaker-diarization-3.1"  # 说话人分离模型名称
     llm_model_path = "mlx-community/gemma-3-12b-it-4bit-DWQ"
-    hf_token = ""  # Hugging Face API 令牌
     device = "mps"  # 设备类型
     segmentation_batch_size = 64
     parallel = True
@@ -33,12 +32,6 @@ def main():
         print(f"错误:文件 '{audio_file}' 不存在")
         return 1
 
-    # 检查HF令牌
-    if not hf_token:
-        print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
-        print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
-        return 1
-
     try:
         print(f"正在加载音频文件: {audio_file}")
         # 加载音频文件
@@ -67,7 +60,6 @@ def main():
     result = transcribe_podcast_audio(audio,
                                       podcast_info=mock_podcast_info,
                                       episode_info=mock_episode_info,
-                                      hf_token=hf_token,
                                       asr_model_name=asr_model_name,
                                       diarization_model_name=diarization_model_name,
                                       llm_model_name=llm_model_path,
examples/combined_transcription.py
CHANGED
@@ -27,7 +27,6 @@ def main():
     # 模型配置
     asr_model_name = "distil-whisper/distil-large-v3.5"  # ASR模型名称
     diarization_model_name = "pyannote/speaker-diarization-3.1"  # 说话人分离模型名称
-    hf_token = ""  # Hugging Face API 令牌
     device = "mps"  # 设备类型
     segmentation_batch_size = 64
     parallel = True
@@ -37,12 +36,6 @@ def main():
         print(f"错误:文件 '{audio_file}' 不存在")
         return 1
 
-    # 检查HF令牌
-    if not hf_token:
-        print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
-        print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
-        return 1
-
     try:
         print(f"正在加载音频文件: {audio_file}")
         # 加载音频文件
@@ -54,7 +47,6 @@ def main():
         audio,
         asr_model_name=asr_model_name,
         diarization_model_name=diarization_model_name,
-        hf_token=hf_token,
         device=device,
         segmentation_batch_size=segmentation_batch_size,
         parallel=parallel,
examples/simple_asr.py
CHANGED
@@ -14,41 +14,26 @@ from pathlib import Path
 sys.path.insert(0, str(Path(__file__).parent.parent))
 
 from src.podcast_transcribe.audio import load_audio
+from src.podcast_transcribe.asr.asr_distil_whisper import transcribe_audio
 
 logger = logging.getLogger("asr_example")
 
 
 def main():
     """主函数"""
-
+    audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav")  # 播客音频文件路径
     # audio_file = "/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav"  # 播客音频文件路径
-    audio_file = "/Users/konie/Desktop/voices/podcast1_1.wav"
+    # audio_file = "/Users/konie/Desktop/voices/podcast1_1.wav"
     # model = "distil-whisper"
-    model = "distil-whisper-
+    model = "distil-whisper/distil-large-v3.5"
 
-    device = "
+    device = "mps"
 
     # 检查文件是否存在
     if not os.path.exists(audio_file):
        print(f"错误:文件 '{audio_file}' 不存在")
        return 1
 
-    if model == "parakeet":
-        from src.podcast_transcribe.asr.asr_parakeet_mlx import transcribe_audio
-        model_name = "mlx-community/parakeet-tdt-0.6b-v2"
-        logger.info(f"使用Parakeet模型: {model_name}")
-    elif model == "distil-whisper":  # distil-whisper
-        from src.podcast_transcribe.asr.asr_distil_whisper_mlx import transcribe_audio
-        model_name = "mlx-community/distil-whisper-large-v3"
-        logger.info(f"使用Distil Whisper模型: {model_name}")
-    elif model == "distil-whisper-transformers":  # distil-whisper
-        from src.podcast_transcribe.asr.asr_distil_whisper_transformers import transcribe_audio
-        model_name = "distil-whisper/distil-large-v3.5"
-        logger.info(f"使用Distil Whisper模型: {model_name}")
-    else:
-        logger.error(f"错误:未指定模型类型")
-        return 1
-
     try:
         print(f"正在加载音频文件: {audio_file}")
         # 加载音频文件
@@ -58,7 +43,7 @@ def main():
 
     # 进行转录
     print("开始转录...")
-    result = transcribe_audio(audio, model_name=
+    result = transcribe_audio(audio, model_name=model, device=device)
 
     # 输出结果
     print("\n转录结果:")
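With the per-backend modules and the if/elif model dispatch removed, the example now needs only one import and one call. A minimal standalone sketch of the same flow, assuming load_audio returns a pydub AudioSegment and that TranscriptionResult exposes a text attribute, as its constructor in asr_base.py suggests:

    from pathlib import Path

    from src.podcast_transcribe.audio import load_audio
    from src.podcast_transcribe.asr.asr_distil_whisper import transcribe_audio

    # Input path is illustrative; any audio file the loader accepts works.
    audio = load_audio(Path("examples/input/lex_ai_john_carmack_1.wav"))
    result = transcribe_audio(audio, model_name="distil-whisper/distil-large-v3.5", device="mps")
    print(result.text)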
examples/simple_diarization.py
CHANGED
@@ -22,7 +22,6 @@ def main():
     audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav")  # 播客音频文件路径
     # audio_file = "/Users/konie/Desktop/voices/history_in_the_baking.mp3"  # 播客音频文件路径
     model_name = "pyannote/speaker-diarization-3.1"  # 说话人分离模型名称
-    hf_token = ""  # Hugging Face API 令牌
     device = "mps"  # 设备类型
 
     # 检查文件是否存在
@@ -30,11 +29,6 @@ def main():
         print(f"错误:文件 '{audio_file}' 不存在")
         return 1
 
-    # 检查令牌是否设置
-    if not hf_token:
-        print("错误:未设置HF_TOKEN环境变量,请设置后再运行")
-        return 1
-
     try:
         print(f"正在加载音频文件: {audio_file}")
         # 加载音频文件
@@ -46,12 +40,12 @@ def main():
     if "pyannote/speaker-diarization" in model_name:
         # 使用transformers版本进行说话人分离
         print(f"使用transformers版本处理模型: {model_name}")
-        result = diarize_audio_transformers(audio, model_name=model_name,
+        result = diarize_audio_transformers(audio, model_name=model_name, device=device, segmentation_batch_size=128)
         version_name = "Transformers"
     else:
         # 使用MLX版本进行说话人分离
         print(f"使用MLX版本处理模型: {model_name}")
-        result = diarize_audio_mlx(audio, model_name=model_name,
+        result = diarize_audio_mlx(audio, model_name=model_name, device=device, segmentation_batch_size=128)
         version_name = "MLX"
 
     # 输出结果
requirements.txt
CHANGED
@@ -18,5 +18,4 @@ accelerate>=1.6.0
 # MLX特定依赖 - 仅适用于Apple Silicon Mac
 # mlx>=0.25.2
 # mlx-lm>=0.24.0
-# parakeet-mlx>=0.2.6
 # mlx-whisper>=0.4.2
src/podcast_transcribe/asr/asr_base.py
CHANGED
@@ -89,189 +89,4 @@ class BaseTranscriber:
 
         if chinese_chars > len(text) * 0.3:
             return "zh"
-        return "en"
-
-    def _convert_segments(self, model_result) -> List[Dict[str, Union[float, str]]]:
-        """
-        将模型的分段结果转换为所需格式(需要在子类中实现)
-
-        参数:
-            model_result: 模型返回的结果
-
-        返回:
-            转换后的分段列表
-        """
-        raise NotImplementedError("子类必须实现_convert_segments方法")
-
-    def transcribe(self, audio: AudioSegment, chunk_duration_s: int = 30, overlap_s: int = 5) -> TranscriptionResult:
-        """
-        转录音频,支持长音频分块处理。
-
-        参数:
-            audio: 要转录的AudioSegment对象
-            chunk_duration_s: 分块处理的块时长(秒)。如果音频短于此,则不分块。
-            overlap_s: 分块间的重叠时长(秒)。
-
-        返回:
-            TranscriptionResult对象,包含转录结果
-        """
-        logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频。分块设置: 块时长={chunk_duration_s}s, 重叠={overlap_s}s")
-
-        if overlap_s >= chunk_duration_s and len(audio)/1000.0 > chunk_duration_s:
-            logger.error("重叠时长必须小于块时长。")
-            raise ValueError("overlap_s 必须小于 chunk_duration_s。")
-
-        total_duration_ms = len(audio)
-        chunk_duration_ms = chunk_duration_s * 1000
-        overlap_ms = overlap_s * 1000
-
-        if total_duration_ms <= chunk_duration_ms:
-            logger.debug("音频时长不大于设定块时长,直接进行完整转录。")
-            processed_audio = self._prepare_audio(audio)
-            samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
-
-            try:
-                model_result = self._perform_transcription(samples)
-                text = self._get_text_from_result(model_result)
-                segments = self._convert_segments(model_result)
-                language = self._detect_language(text)
-
-                logger.info(f"单块转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
-                return TranscriptionResult(text=text, segments=segments, language=language)
-            except Exception as e:
-                logger.error(f"单块转录失败: {str(e)}", exc_info=True)
-                raise RuntimeError(f"单块转录失败: {str(e)}")
-
-        # 长音频分块处理
-        final_segments = []
-        # current_pos_ms 指的是当前块要处理的"新内容"的起始点在原始音频中的位置
-        current_pos_ms = 0
-
-        while current_pos_ms < total_duration_ms:
-            # 计算当前块实际送入模型处理的音频的起始和结束时间点
-            # 对于第一个块,start_process_ms = 0
-            # 对于后续块,start_process_ms 会向左回退 overlap_ms 以包含重叠区域
-            start_process_ms = max(0, current_pos_ms - overlap_ms)
-            end_process_ms = min(start_process_ms + chunk_duration_ms, total_duration_ms)
-
-            # 如果计算出的块起始点已经等于或超过总时长,说明处理完毕
-            if start_process_ms >= total_duration_ms:
-                break
-
-            chunk_audio = audio[start_process_ms:end_process_ms]
-
-            logger.info(f"处理音频块: {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s (新内容起始于: {current_pos_ms/1000.0:.2f}s)")
-
-            if len(chunk_audio) == 0:
-                logger.warning(f"生成了一个空的音频块,跳过。起始: {start_process_ms/1000.0:.2f}s, 结束: {end_process_ms/1000.0:.2f}s")
-                # 必须推进 current_pos_ms 以避免死循环
-                advance_ms = chunk_duration_ms - overlap_ms
-                if advance_ms <= 0:  # 应该在函数开始时已检查 overlap_s < chunk_duration_s
-                    raise RuntimeError("块推进时长配置错误,可能导致死循环。")
-                current_pos_ms += advance_ms
-                continue
-
-            processed_chunk_audio = self._prepare_audio(chunk_audio)
-            samples = np.array(processed_chunk_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
-
-            try:
-                model_result = self._perform_transcription(samples)
-                segments_chunk = self._convert_segments(model_result)
-
-                for seg in segments_chunk:
-                    # seg["start"] 和 seg["end"] 是相对于当前块 (chunk_audio) 的起始点(即0)
-                    # 计算 segment 在原始完整音频中的绝对起止时间
-                    global_seg_start_s = start_process_ms / 1000.0 + seg["start"]
-                    global_seg_end_s = start_process_ms / 1000.0 + seg["end"]
-
-                    # 核心去重逻辑:
-                    # 我们只接受那些真实开始于 current_pos_ms / 1000.0 之后的 segment。
-                    # current_pos_ms 是当前块应该贡献的"新"内容的开始时间。
-                    # 对于第一个块 (current_pos_ms == 0),所有 segment 都被接受(只要它们的 start >= 0)。
-                    # 对于后续块,只有当 segment 的全局开始时间 >= 当前块新内容的开始时间时,才添加。
-                    if global_seg_start_s >= current_pos_ms / 1000.0:
-                        final_segments.append({
-                            "start": global_seg_start_s,
-                            "end": global_seg_end_s,
-                            "text": seg["text"]
-                        })
-                    # 特殊处理第一个块,因为 current_pos_ms 为 0,上面的条件 global_seg_start_s >= 0 总是满足。
-                    # 但为了更清晰,如果不是第一个块,但 segment 跨越了 current_pos_ms,
-                    # 它的起始部分在重叠区,结束部分在非重叠区。
-                    # 当前逻辑是,如果它的 global_seg_start_s < current_pos_ms / 1000.0,它就被丢弃。
-                    # 这是为了确保不重复记录重叠区域的开头部分。
-                    # 如果一个 segment 完全在重叠区内且在前一个块已被记录,此逻辑可避免重复。
-
-            except Exception as e:
-                logger.error(f"处理音频块 {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s 失败: {str(e)}", exc_info=True)
-
-            # 更新下一个"新内容"块的起始位置
-            advance_ms = chunk_duration_ms - overlap_ms
-            current_pos_ms += advance_ms
-
-        # 对收集到的所有 segments 按开始时间排序
-        final_segments.sort(key=lambda s: s["start"])
-
-        # 可选:进一步清理 segments,例如合并非常接近且文本连续的,或移除完全重复的
-        cleaned_segments = []
-        if final_segments:
-            cleaned_segments.append(final_segments[0])
-            for i in range(1, len(final_segments)):
-                prev_s = cleaned_segments[-1]
-                curr_s = final_segments[i]
-                # 简单的去重:如果时间戳和文本都几乎一样,则认为是重复
-                if abs(curr_s["start"] - prev_s["start"]) < 0.01 and \
-                   abs(curr_s["end"] - prev_s["end"]) < 0.01 and \
-                   curr_s["text"] == prev_s["text"]:
-                    continue
-
-                # 如果当前 segment 的开始时间在前一个 segment 的结束时间之前,
-                # 并且文本有明显重叠,可能需要更智能的合并。
-                # 目前的逻辑通过 global_seg_start_s >= current_pos_ms / 1000.0 过滤,
-                # 已经大大减少了直接的 segment 重复。
-                # 此处的清理更多是处理模型在边界可能产生的一些微小偏差。
-                # 如果上一个segment的结束时间比当前segment的开始时间还要晚,说明有重叠,
-                # 且上一个segment包含了当前segment的开始部分。
-                # 这种情况下,可以考虑调整上一个的结束,或当前segment的开始和文本。
-                # 为简单起见,暂时直接添加,相信之前的过滤已处理主要重叠。
-                if curr_s["start"] < prev_s["end"] and prev_s["text"].endswith(curr_s["text"][:len(prev_s["text"]) - int((prev_s["end"] - curr_s["start"])*10)]):  # 粗略检查
-                    # 如果curr_s的开始部分被prev_s覆盖,并且文本也对应,则调整curr_s
-                    # pass  # 暂时不处理这种细微重叠,依赖模型切分
-                    cleaned_segments.append(curr_s)  # 仍添加,依赖后续文本拼接
-                else:
-                    cleaned_segments.append(curr_s)
-
-        final_text = " ".join([s["text"] for s in cleaned_segments]).strip()
-        language = self._detect_language(final_text)
-
-        logger.info(f"分块转录完成。最终文本长度: {len(final_text)}, 分段数: {len(cleaned_segments)}")
-
-        return TranscriptionResult(
-            text=final_text,
-            segments=cleaned_segments,
-            language=language
-        )
-
-    def _perform_transcription(self, audio_data):
-        """
-        执行转录(需要在子类中实现)
-
-        参数:
-            audio_data: 音频数据(numpy数组)
-
-        返回:
-            模型的转录结果
-        """
-        raise NotImplementedError("子类必须实现_perform_transcription方法")
-
-    def _get_text_from_result(self, result):
-        """
-        从结果中获取文本(需要在子类中实现)
-
-        参数:
-            result: 模型的转录结果
-
-        返回:
-            转录的文本
-        """
-        raise NotImplementedError("子类必须实现_get_text_from_result方法")
+        return "en"
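For reference, the deleted BaseTranscriber.transcribe implemented chunked long-audio transcription by hand: each chunk rewinds overlap_s seconds into the previous one, and a segment is kept only when its global start time falls inside the chunk's new (non-overlapping) region. A condensed sketch of that bookkeeping, with illustrative values; the new Transformers path delegates the same job to the pipeline via chunk_length_s=30 and stride_length_s=5:

    # Sketch of the removed overlap-and-dedup loop; all values are illustrative.
    chunk_ms, overlap_ms, total_ms = 30_000, 5_000, 95_000
    pos_ms = 0  # start of the "new" content each chunk must contribute
    while pos_ms < total_ms:
        start_ms = max(0, pos_ms - overlap_ms)   # rewind into the previous chunk
        end_ms = min(start_ms + chunk_ms, total_ms)
        # transcribe audio[start_ms:end_ms]; for each resulting segment, keep it
        # only if start_ms / 1000 + seg["start"] >= pos_ms / 1000, i.e. it begins
        # in this chunk's non-overlap region (avoids double-counting the overlap).
        pos_ms += chunk_ms - overlap_ms          # advance past the new content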
src/podcast_transcribe/asr/asr_distil_whisper.py
ADDED
@@ -0,0 +1,319 @@
+"""
+基于MLX或Transformers实现的语音识别模块,使用distil-whisper模型
+"""
+
+import os
+from pydub import AudioSegment
+from typing import Dict, List, Union, Literal
+import logging
+import numpy as np
+
+# 导入基类
+from .asr_base import BaseTranscriber, TranscriptionResult
+
+# 配置日志
+logger = logging.getLogger("asr")
+
+
+class DistilWhisperTranscriber(BaseTranscriber):
+    """抽象基类:Distil Whisper转录器的共享实现"""
+
+    def __init__(
+        self,
+        model_name: str,
+        **kwargs
+    ):
+        """
+        初始化转录器
+
+        参数:
+            model_name: 模型名称
+            **kwargs: 其他参数
+        """
+        super().__init__(model_name=model_name, **kwargs)
+
+    def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
+        """
+        转录音频,针对distil-whisper模型取消分块处理,直接处理整个音频。
+
+        参数:
+            audio: 要转录的AudioSegment对象
+            chunk_duration_s: 分块处理的块时长(秒)- 此参数被忽略
+            overlap_s: 分块间的重叠时长(秒)- 此参数被忽略
+
+        返回:
+            TranscriptionResult对象,包含转录结果
+        """
+        logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频(distil-whisper模型)")
+
+        # 直接处理整个音频,不进行分块
+        processed_audio = self._prepare_audio(audio)
+        samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
+
+        try:
+            model_result = self._perform_transcription(samples)
+            text = self._get_text_from_result(model_result)
+            segments = self._convert_segments(model_result)
+            language = self._detect_language(text)
+
+            logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
+            return TranscriptionResult(text=text, segments=segments, language=language)
+        except Exception as e:
+            logger.error(f"转录失败: {str(e)}", exc_info=True)
+            raise RuntimeError(f"转录失败: {str(e)}")
+
+    def _get_text_from_result(self, result):
+        """
+        从结果中获取文本
+
+        参数:
+            result: 模型的转录结果
+
+        返回:
+            转录的文本
+        """
+        return result.get("text", "")
+
+    def _load_model(self):
+        """加载模型的抽象方法,由子类实现"""
+        raise NotImplementedError("子类必须实现_load_model方法")
+
+    def _perform_transcription(self, audio_data):
+        """执行转录的抽象方法,由子类实现"""
+        raise NotImplementedError("子类必须实现_perform_transcription方法")
+
+    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
+        """将模型结果转换为分段的抽象方法,由子类实现"""
+        raise NotImplementedError("子类必须实现_convert_segments方法")
+
+
+class MLXDistilWhisperTranscriber(DistilWhisperTranscriber):
+    """使用MLX加载和运行distil-whisper模型的转录器"""
+
+    def __init__(
+        self,
+        model_name: str = "mlx-community/distil-whisper-large-v3",
+    ):
+        """
+        初始化转录器
+
+        参数:
+            model_name: 模型名称
+        """
+        super().__init__(model_name=model_name)
+
+    def _load_model(self):
+        """加载Distil Whisper MLX模型"""
+        try:
+            # 懒加载mlx-whisper
+            try:
+                import mlx_whisper
+            except ImportError:
+                raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
+
+            logger.info(f"开始加载模型 {self.model_name}")
+            self.model = mlx_whisper.load_models.load_model(self.model_name)
+            logger.info(f"模型加载成功")
+        except Exception as e:
+            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
+            raise RuntimeError(f"加载模型失败: {str(e)}")
+
+    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
+        """
+        将模型的分段结果转换为所需格式
+
+        参数:
+            result: 模型返回的结果
+
+        返回:
+            转换后的分段列表
+        """
+        segments = []
+
+        for segment in result.get("segments", []):
+            segments.append({
+                "start": segment.get("start", 0.0),
+                "end": segment.get("end", 0.0),
+                "text": segment.get("text", "").strip()
+            })
+
+        return segments
+
+    def _perform_transcription(self, audio_data):
+        """
+        执行转录
+
+        参数:
+            audio_data: 音频数据(numpy数组)
+
+        返回:
+            模型的转录结果
+        """
+        from mlx_whisper import transcribe
+        return transcribe(audio_data, path_or_hf_repo=self.model_name)
+
+
+class TransformersDistilWhisperTranscriber(DistilWhisperTranscriber):
+    """使用Transformers加载和运行distil-whisper模型的转录器"""
+
+    def __init__(
+        self,
+        model_name: str = "distil-whisper/distil-large-v3.5",
+        device: str = "cpu",
+    ):
+        """
+        初始化转录器
+
+        参数:
+            model_name: 模型名称
+            device: 推理设备,'cpu'或'cuda'
+        """
+        super().__init__(model_name=model_name, device=device)
+
+    def _load_model(self):
+        """加载Distil Whisper Transformers模型"""
+        try:
+            # 懒加载transformers
+            try:
+                from transformers import pipeline
+            except ImportError:
+                raise ImportError("请先安装transformers库: pip install transformers")
+
+            logger.info(f"开始加载模型 {self.model_name} 设备: {self.device}")
+
+            pipeline_device_arg = None
+            if self.device == "cuda":
+                pipeline_device_arg = 0  # 使用第一个 CUDA 设备
+            elif self.device == "mps":
+                pipeline_device_arg = "mps"  # 使用 MPS 设备
+            elif self.device == "cpu":
+                pipeline_device_arg = -1  # 使用 CPU
+            else:
+                # 对于其他未明确支持的 device 字符串,记录警告并默认使用 CPU
+                logger.warning(f"不支持的设备字符串 '{self.device}',将默认使用 CPU。")
+                pipeline_device_arg = -1
+
+            # 导入必要的模块来配置模型
+            import warnings
+            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+
+            # 抑制特定的警告
+            warnings.filterwarnings("ignore", message="The input name `inputs` is deprecated")
+            warnings.filterwarnings("ignore", message="You have passed task=transcribe")
+            warnings.filterwarnings("ignore", message="The attention mask is not set")
+
+            self.pipeline = pipeline(
+                "automatic-speech-recognition",
+                model=self.model_name,
+                device=pipeline_device_arg,
+                return_timestamps=True,
+                chunk_length_s=30,  # 使用30秒的块长度
+                stride_length_s=5,  # 块之间5秒的重叠
+                batch_size=1,  # 顺序处理
+                # 添加以下参数来减少警告
+                generate_kwargs={
+                    "task": "transcribe",
+                    "language": None,  # 自动检测语言
+                    "forced_decoder_ids": None,  # 避免冲突
+                }
+            )
+            logger.info(f"模型加载成功")
+        except Exception as e:
+            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
+            raise RuntimeError(f"加载模型失败: {str(e)}")
+
+    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
+        """
+        将模型的分段结果转换为所需格式
+
+        参数:
+            result: 模型返回的结果
+
+        返回:
+            转换后的分段列表
+        """
+        segments = []
+
+        # transformers pipeline 的结果格式
+        if "chunks" in result:
+            for chunk in result["chunks"]:
+                segments.append({
+                    "start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
+                    "end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
+                    "text": chunk["text"].strip()
+                })
+        else:
+            # 如果没有分段信息,创建一个单一分段
+            segments.append({
+                "start": 0.0,
+                "end": 0.0,  # 无法确定结束时间
+                "text": result.get("text", "").strip()
+            })
+
+        return segments
+
+    def _perform_transcription(self, audio_data):
+        """
+        执行转录
+
+        参数:
+            audio_data: 音频数据(numpy数组)
+
+        返回:
+            模型的转录结果
+        """
+        # transformers pipeline 接受numpy数组作为输入
+        # 音频数据已经在_prepare_audio中确保是16kHz采样率
+
+        # 确保音频数据格式正确
+        if audio_data.dtype != np.float32:
+            audio_data = audio_data.astype(np.float32)
+
+        # 使用正确的参数名称调用pipeline
+        try:
+            result = self.pipeline(
+                audio_data,
+                generate_kwargs={
+                    "task": "transcribe",
+                    "language": None,  # 自动检测语言
+                    "forced_decoder_ids": None,  # 避免冲突
+                }
+            )
+            return result
+        except Exception as e:
+            logger.warning(f"使用新参数格式失败,尝试使用默认参数: {str(e)}")
+            # 如果新格式失败,回退到简单调用
+            return self.pipeline(audio_data)
+
+
+# 统一的接口函数
+def transcribe_audio(
+    audio_segment: AudioSegment,
+    model_name: str = None,
+    backend: Literal["mlx", "transformers"] = "transformers",
+    device: str = "cpu",
+) -> TranscriptionResult:
+    """
+    使用Distil Whisper模型转录音频
+
+    参数:
+        audio_segment: 输入的AudioSegment对象
+        model_name: 使用的模型名称,如果不指定则使用默认模型
+        backend: 后端类型,'mlx'或'transformers'
+        device: 推理设备,仅对transformers后端有效
+
+    返回:
+        TranscriptionResult对象,包含转录的文本、分段和语言
+    """
+    logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒,后端: {backend}")
+
+    if backend == "mlx":
+        default_model = "mlx-community/distil-whisper-large-v3"
+        model = model_name or default_model
+        transcriber = MLXDistilWhisperTranscriber(model_name=model)
+    else:  # transformers
+        default_model = "distil-whisper/distil-large-v3.5"
+        model = model_name or default_model
+        transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)
+
+    return transcriber.transcribe(audio_segment)
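The added module collapses the former MLX and Transformers implementations behind a single transcribe_audio entry point selected by the backend argument. A usage sketch grounded in the signatures above (the input file name is hypothetical; segments are the start/end/text dicts built by _convert_segments):

    from pydub import AudioSegment
    from src.podcast_transcribe.asr.asr_distil_whisper import transcribe_audio

    audio = AudioSegment.from_wav("podcast.wav")  # hypothetical input file

    # Transformers backend (default model: distil-whisper/distil-large-v3.5)
    result = transcribe_audio(audio, backend="transformers", device="cpu")

    # MLX backend on Apple Silicon (default model: mlx-community/distil-whisper-large-v3)
    # result = transcribe_audio(audio, backend="mlx")

    for seg in result.segments:
        print(f"{seg['start']:.2f}s-{seg['end']:.2f}s: {seg['text']}")

Note the design shift: chunking moved out of the shared base class; the Transformers path delegates long-audio handling to the pipeline (chunk_length_s=30, stride_length_s=5), while the MLX path passes the whole sample array to mlx_whisper.transcribe.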
src/podcast_transcribe/asr/asr_distil_whisper_mlx.py
DELETED
@@ -1,111 +0,0 @@
-"""
-基于MLX实现的语音识别模块,使用distil-whisper-large-v3模型
-"""
-
-import os
-from pydub import AudioSegment
-from typing import Dict, List, Union
-import logging
-
-# 导入基类
-from .asr_base import BaseTranscriber, TranscriptionResult
-
-# 配置日志
-logger = logging.getLogger("asr")
-
-
-class MLXDistilWhisperTranscriber(BaseTranscriber):
-    """使用MLX加载和运行distil-whisper-large-v3模型的转录器"""
-
-    def __init__(
-        self,
-        model_name: str = "mlx-community/distil-whisper-large-v3",
-    ):
-        """
-        初始化转录器
-
-        参数:
-            model_name: 模型名称
-        """
-        super().__init__(model_name=model_name)
-
-    def _load_model(self):
-        """加载Distil Whisper模型"""
-        try:
-            # 懒加载mlx-whisper
-            try:
-                import mlx_whisper
-            except ImportError:
-                raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
-
-            logger.info(f"开始加载模型 {self.model_name}")
-            self.model = mlx_whisper.load_models.load_model(self.model_name)
-            logger.info(f"模型加载成功")
-        except Exception as e:
-            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
-            raise RuntimeError(f"加载模型失败: {str(e)}")
-
-    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
-        """
-        将模型的分段结果转换为所需格式
-
-        参数:
-            result: 模型返回的结果
-
-        返回:
-            转换后的分段列表
-        """
-        segments = []
-
-        for segment in result.get("segments", []):
-            segments.append({
-                "start": segment.get("start", 0.0),
-                "end": segment.get("end", 0.0),
-                "text": segment.get("text", "").strip()
-            })
-
-        return segments
-
-    def _perform_transcription(self, audio_data):
-        """
-        执行转录
-
-        参数:
-            audio_data: 音频数据(numpy数组)
-
-        返回:
-            模型的转录结果
-        """
-        from mlx_whisper import transcribe
-        return transcribe(audio_data, path_or_hf_repo=self.model_name)
-
-    def _get_text_from_result(self, result):
-        """
-        从结果中获取文本
-
-        参数:
-            result: 模型的转录结果
-
-        返回:
-            转录的文本
-        """
-        return result.get("text", "")
-
-
-def transcribe_audio(
-    audio_segment: AudioSegment,
-    model_name: str = "mlx-community/distil-whisper-large-v3",
-) -> TranscriptionResult:
-    """
-    使用MLX和distil-whisper-large-v3模型转录音频
-
-    参数:
-        audio_segment: 输入的AudioSegment对象
-        model_name: 使用的模型名称
-
-    返回:
-        TranscriptionResult对象,包含转录的文本、分段和语言
-    """
-    logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
-    transcriber = MLXDistilWhisperTranscriber(model_name=model_name)
-    return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_distil_whisper_transformers.py
DELETED
@@ -1,133 +0,0 @@
-"""
-基于Transformers实现的语音识别模块,使用distil-whisper-large-v3.5模型
-"""
-
-import os
-from pydub import AudioSegment
-from typing import Dict, List, Union
-import logging
-import numpy as np
-
-# 导入基类
-from .asr_base import BaseTranscriber, TranscriptionResult
-
-# 配置日志
-logger = logging.getLogger("asr")
-
-
-class TransformersDistilWhisperTranscriber(BaseTranscriber):
-    """使用Transformers加载和运行distil-whisper-large-v3.5模型的转录器"""
-
-    def __init__(
-        self,
-        model_name: str = "distil-whisper/distil-large-v3.5",
-        device: str = "cpu",
-    ):
-        """
-        初始化转录器
-
-        参数:
-            model_name: 模型名称
-            device: 推理设备,'cpu'或'cuda'
-        """
-        super().__init__(model_name=model_name, device=device)
-
-    def _load_model(self):
-        """加载Distil Whisper模型"""
-        try:
-            # 懒加载transformers
-            try:
-                from transformers import pipeline
-            except ImportError:
-                raise ImportError("请先安装transformers库: pip install transformers")
-
-            logger.info(f"开始加载模型 {self.model_name}")
-            self.pipeline = pipeline(
-                "automatic-speech-recognition",
-                model=self.model_name,
-                device=0 if self.device == "cuda" else -1,
-                return_timestamps=True,
-                chunk_length_s=25,
-                batch_size=32,
-            )
-            logger.info(f"模型加载成功")
-        except Exception as e:
-            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
-            raise RuntimeError(f"加载模型失败: {str(e)}")
-
-    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
-        """
-        将模型的分段结果转换为所需格式
-
-        参数:
-            result: 模型返回的结果
-
-        返回:
-            转换后的分段列表
-        """
-        segments = []
-
-        # transformers pipeline 的结果格式
-        if "chunks" in result:
-            for chunk in result["chunks"]:
-                segments.append({
-                    "start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
-                    "end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
-                    "text": chunk["text"].strip()
-                })
-        else:
-            # 如果没有分段信息,创建一个单一分段
-            segments.append({
-                "start": 0.0,
-                "end": 0.0,  # 无法确定结束时间
-                "text": result.get("text", "").strip()
-            })
-
-        return segments
-
-    def _perform_transcription(self, audio_data):
-        """
-        执行转录
-
-        参数:
-            audio_data: 音频数据(numpy数组)
-
-        返回:
-            模型的转录结果
-        """
-        # transformers pipeline 接受numpy数组作为输入
-        # 音频数据已经在_prepare_audio中确保是16kHz采样率
-        return self.pipeline(audio_data)
-
-    def _get_text_from_result(self, result):
-        """
-        从结果中获取文本
-
-        参数:
-            result: 模型的转录结果
-
-        返回:
-            转录的文本
-        """
-        return result.get("text", "")
-
-
-def transcribe_audio(
-    audio_segment: AudioSegment,
-    model_name: str = "distil-whisper/distil-large-v3.5",
-    device: str = "cpu",
-) -> TranscriptionResult:
-    """
-    使用Transformers和distil-whisper-large-v3.5模型转录音频
-
-    参数:
-        audio_segment: 输入的AudioSegment对象
-        model_name: 使用的模型名称
-        device: 推理设备,'cpu'或'cuda'
-
-    返回:
-        TranscriptionResult对象,包含转录的文本、分段和语言
-    """
-    logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
-    transcriber = TransformersDistilWhisperTranscriber(model_name=model_name, device=device)
-    return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_parakeet_mlx.py
DELETED
@@ -1,126 +0,0 @@
-"""
-基于MLX实现的语音识别模块,使用parakeet-tdt模型
-"""
-
-import os
-from pydub import AudioSegment
-from typing import Dict, List, Union
-import logging
-import tempfile
-import numpy as np
-import soundfile as sf
-
-# 导入基类
-from .asr_base import BaseTranscriber, TranscriptionResult
-
-# 配置日志
-logger = logging.getLogger("asr")
-
-
-class MLXParakeetTranscriber(BaseTranscriber):
-    """使用MLX加载和运行parakeet-tdt-0.6b-v2模型的转录器"""
-
-    def __init__(
-        self,
-        model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
-    ):
-        """
-        初始化转录器
-
-        参数:
-            model_name: 模型名称
-        """
-        super().__init__(model_name=model_name)
-
-    def _load_model(self):
-        """加载Parakeet模型"""
-        try:
-            # 懒加载parakeet_mlx
-            try:
-                from parakeet_mlx import from_pretrained
-            except ImportError:
-                raise ImportError("请先安装parakeet-mlx库: pip install parakeet-mlx")
-
-            logger.info(f"开始加载模型 {self.model_name}")
-            self.model = from_pretrained(self.model_name)
-            logger.info(f"模型加载成功")
-        except Exception as e:
-            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
-            raise RuntimeError(f"加载模型失败: {str(e)}")
-
-    def _convert_segments(self, aligned_result) -> List[Dict[str, Union[float, str]]]:
-        """
-        将模型的分段结果转换为所需格式
-
-        参数:
-            aligned_result: 模型返回的分段结果
-
-        返回:
-            转换后的分段列表
-        """
-        segments = []
-
-        for sentence in aligned_result.sentences:
-            segments.append({
-                "start": sentence.start,
-                "end": sentence.end,
-                "text": sentence.text
-            })
-
-        return segments
-
-    def _perform_transcription(self, audio_data):
-        """
-        执行转录
-
-        参数:
-            audio_data: 音频数据(numpy数组)
-
-        返回:
-            模型的转录结果
-        """
-        # 由于parakeet-mlx可能不直接支持numpy数组输入
-        # 创建临时文件并写入音频数据
-        with tempfile.NamedTemporaryFile(suffix='.wav', delete=True) as temp_file:
-            # 确保数据在[-1, 1]范围内
-            if audio_data.max() > 1.0 or audio_data.min() < -1.0:
-                audio_data = np.clip(audio_data, -1.0, 1.0)
-
-            # 写入临时文件
-            sf.write(temp_file.name, audio_data, 16000, 'PCM_16')
-
-            # 使用临时文件进行转录
-            result = self.model.transcribe(temp_file.name)
-
-        return result
-
-    def _get_text_from_result(self, result):
-        """
-        从结果中获取文本
-
-        参数:
-            result: 模型的转录结果
-
-        返回:
-            转录的文本
-        """
-        return result.text
-
-
-def transcribe_audio(
-    audio_segment: AudioSegment,
-    model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
-) -> TranscriptionResult:
-    """
-    使用MLX和parakeet-tdt模型转录音频
-
-    参数:
-        audio_segment: 输入的AudioSegment对象
-        model_name: 使用的模型名称
-
-    返回:
-        TranscriptionResult对象,包含转录的文本、分段和语言
-    """
-    logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
-    transcriber = MLXParakeetTranscriber(model_name=model_name)
-    return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_router.py
CHANGED
@@ -4,13 +4,11 @@ ASR模型调用路由器
 """
 
 import logging
-from typing import Dict, Any, Optional, Callable
+from typing import Dict, Any, Literal, Optional, Callable
 from pydub import AudioSegment
 import spaces
 from .asr_base import TranscriptionResult
-from . import asr_parakeet_mlx
-from . import asr_distil_whisper_mlx
-from . import asr_distil_whisper_transformers
+from . import asr_distil_whisper
 
 # 配置日志
 logger = logging.getLogger("asr")
@@ -26,22 +24,8 @@ class ASRRouter:
 
         # 定义支持的provider配置
         self._provider_configs = {
-            "parakeet_mlx": {
-                "module_path": "asr_parakeet_mlx",
-                "function_name": "transcribe_audio",
-                "default_model": "mlx-community/parakeet-tdt-0.6b-v2",
-                "supported_params": ["model_name"],
-                "description": "基于MLX的Parakeet模型"
-            },
-            "distil_whisper_mlx": {
-                "module_path": "asr_distil_whisper_mlx",
-                "function_name": "transcribe_audio",
-                "default_model": "mlx-community/distil-whisper-large-v3",
-                "supported_params": ["model_name"],
-                "description": "基于MLX的Distil Whisper模型"
-            },
             "distil_whisper_transformers": {
-                "module_path": "asr_distil_whisper_transformers",
+                "module_path": "asr_distil_whisper",
                 "function_name": "transcribe_audio",
                 "default_model": "distil-whisper/distil-large-v3.5",
                 "supported_params": ["model_name", "device"],
@@ -66,15 +50,8 @@ class ASRRouter:
         module_path = self._provider_configs[provider]["module_path"]
         logger.info(f"获取模块: {module_path}")
 
-        #
-        if module_path == "asr_parakeet_mlx":
-            module = asr_parakeet_mlx
-        elif module_path == "asr_distil_whisper_mlx":
-            module = asr_distil_whisper_mlx
-        elif module_path == "asr_distil_whisper_transformers":
-            module = asr_distil_whisper_transformers
-        else:
-            raise ImportError(f"未找到模块: {module_path}")
+        # 所有provider现在都指向同一个模块
+        module = asr_distil_whisper
 
         self._loaded_modules[provider] = module
         logger.info(f"模块 {module_path} 获取成功")
@@ -202,51 +179,17 @@ def transcribe_audio(
     provider: str = "distil_whisper_transformers",
     model_name: Optional[str] = None,
     device: str = "cpu",
+    backend: str = "transformers",
     **kwargs
 ) -> TranscriptionResult:
-    """
-    统一的音频转录接口函数
-
-    参数:
-        audio_segment: 输入的AudioSegment对象
-        provider: ASR提供者,可选值:
-            - "parakeet_mlx": 基于MLX的Parakeet模型
-            - "distil_whisper_mlx": 基于MLX的Distil Whisper模型
-            - "distil_whisper_transformers": 基于Transformers的Distil Whisper模型
-        model_name: 模型名称,如果不指定则使用默认模型
-        device: 推理设备,仅对transformers provider有效
-        **kwargs: 其他参数
-
-    返回:
-        TranscriptionResult对象,包含转录的文本、分段和语言
-
-    示例:
-        # 使用默认MLX Distil Whisper模型
-        result = transcribe_audio(audio_segment, provider="distil_whisper_mlx")
-
-        # 使用Parakeet模型
-        result = transcribe_audio(audio_segment, provider="parakeet_mlx")
-
-        # 使用Transformers模型并指定设备
-        result = transcribe_audio(
-            audio_segment,
-            provider="distil_whisper_transformers",
-            device="cuda"
-        )
-
-        # 使用自定义模型
-        result = transcribe_audio(
-            audio_segment,
-            provider="distil_whisper_mlx",
-            model_name="mlx-community/whisper-large-v3"
-        )
-    """
     # 准备参数
     params = kwargs.copy()
     if model_name is not None:
         params["model_name"] = model_name
     if device != "cpu":
         params["device"] = device
+    if backend is not None:
+        params["backend"] = backend
 
     return _router.transcribe(audio_segment, provider, **params)
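After this change the router keeps a single provider entry, and module lookup resolves unconditionally to the unified asr_distil_whisper module. Note that the new guard `if backend is not None` is effectively always true, since the parameter defaults to "transformers", so the kwarg is always forwarded. A call through the routed interface might look like this sketch (input file hypothetical):

    from pydub import AudioSegment
    from src.podcast_transcribe.asr.asr_router import transcribe_audio

    audio = AudioSegment.from_wav("podcast.wav")  # hypothetical input file
    result = transcribe_audio(
        audio,
        provider="distil_whisper_transformers",  # the only provider left after this commit
        device="cuda",
        backend="transformers",
    )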
src/podcast_transcribe/diarization/diarization_pyannote_mlx.py
CHANGED
@@ -48,9 +48,6 @@ class PyannoteTranscriber(BaseDiarizer):
         except ImportError:
             raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
 
-        if not self.token:
-            raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
-
         logger.info(f"开始加载模型 {self.model_name}")
         self.pipeline = Pipeline.from_pretrained(
             self.model_name,
src/podcast_transcribe/diarization/diarization_pyannote_transformers.py
CHANGED
@@ -48,10 +48,7 @@ class PyannoteTransformersTranscriber(BaseDiarizer):
             from pyannote.audio import Pipeline
         except ImportError:
             raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
-
-        if not self.token:
-            raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
-
+
         logger.info(f"开始使用pyannote.audio加载模型 {self.model_name}")
 
         # 使用pyannote.audio Pipeline加载说话人分离模型
src/podcast_transcribe/diarization/diarizer_base.py
CHANGED
@@ -34,7 +34,6 @@ class BaseDiarizer(ABC):
             segmentation_batch_size: 分割批处理大小,默认为32
         """
         self.model_name = model_name
-        self.token = token or os.environ.get("HF_TOKEN")
         self.device = device
         self.segmentation_batch_size = segmentation_batch_size
 
src/podcast_transcribe/diarization/diarizer_router.py
CHANGED
@@ -215,20 +215,18 @@ def diarize_audio(
 
     示例:
         # 使用默认pyannote MLX实现
-        result = diarize_audio(audio_segment, provider="pyannote_mlx"
+        result = diarize_audio(audio_segment, provider="pyannote_mlx")
 
         # 使用transformers实现
         result = diarize_audio(
             audio_segment,
             provider="pyannote_transformers",
-            token="your_hf_token"
         )
 
         # 使用GPU设备
         result = diarize_audio(
             audio_segment,
             provider="pyannote_mlx",
-            token="your_hf_token",
             device="cuda"
         )
 
@@ -236,7 +234,6 @@ def diarize_audio(
         result = diarize_audio(
             audio_segment,
             provider="pyannote_mlx",
-            token="your_hf_token",
             segmentation_batch_size=64
         )
     """
src/podcast_transcribe/transcriber.py
CHANGED
@@ -31,7 +31,6 @@ class CombinedTranscriber:
         diarization_model_name: str,
         llm_model_name: Optional[str] = None,
         llm_provider: Optional[str] = None,
-        hf_token: Optional[str] = None,
         device: Optional[str] = None,
         segmentation_batch_size: int = 64,
         parallel: bool = False,
@@ -44,7 +43,6 @@ class CombinedTranscriber:
             asr_provider: ASR提供者名称
             diarization_provider: 说话人分离提供者名称
             diarization_model_name: 说话人分离模型名称
-            hf_token: Hugging Face令牌
             device: 推理设备,'cpu'或'cuda'
             segmentation_batch_size: 分割批处理大小,默认为64
             parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -75,7 +73,6 @@ class CombinedTranscriber:
         self.asr_provider = asr_provider
         self.diarization_provider = diarization_provider
         self.diarization_model_name = diarization_model_name
-        self.hf_token = hf_token or os.environ.get("HF_TOKEN")
         self.device = device
         self.segmentation_batch_size = segmentation_batch_size
         self.parallel = parallel
@@ -148,7 +145,6 @@ class CombinedTranscriber:
             audio,
             provider=self.diarization_provider,
             model_name=self.diarization_model_name,
-            token=self.hf_token,
             device=self.device,
             segmentation_batch_size=self.segmentation_batch_size
         )
@@ -195,7 +191,6 @@ class CombinedTranscriber:
             audio,
             provider=self.diarization_provider,
             model_name=self.diarization_model_name,
-            token=self.hf_token,
             device=self.device,
             segmentation_batch_size=self.segmentation_batch_size
         )
@@ -491,7 +486,6 @@ def transcribe_audio(
     asr_provider: str = "distil_whisper_transformers",
     diarization_model_name: str = "pyannote/speaker-diarization-3.1",
     diarization_provider: str = "pyannote_transformers",
-    hf_token: Optional[str] = None,
     device: Optional[str] = None,
     segmentation_batch_size: int = 64,
     parallel: bool = False,
@@ -505,7 +499,6 @@ def transcribe_audio(
         asr_provider: ASR提供者名称
         diarization_model_name: 说话人分离模型名称
         diarization_provider: 说话人分离提供者名称
-        hf_token: Hugging Face令牌
         device: 推理设备,'cpu'或'cuda'
         segmentation_batch_size: 分割批处理大小,默认为64
         parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -520,7 +513,6 @@ def transcribe_audio(
         asr_provider=asr_provider,
         diarization_model_name=diarization_model_name,
         diarization_provider=diarization_provider,
-        hf_token=hf_token,
         device=device,
         segmentation_batch_size=segmentation_batch_size,
         parallel=parallel
@@ -539,7 +531,6 @@ def transcribe_podcast_audio(
     diarization_provider: str = "pyannote_transformers",
     llm_model_name: Optional[str] = None,
    llm_provider: Optional[str] = None,
-    hf_token: Optional[str] = None,
     device: Optional[str] = None,
     segmentation_batch_size: int = 64,
     parallel: bool = False,
@@ -557,7 +548,6 @@ def transcribe_podcast_audio(
         diarization_model_name: 说话人分离模型名称
         llm_model_name: LLM模型名称,如果为None则无法识别说话人名称
         llm_provider: LLM提供者名称,如果为None则无法识别说话人名称
-        hf_token: Hugging Face令牌
         device: 推理设备,'cpu'或'cuda'
         segmentation_batch_size: 分割批处理大小,默认为64
         parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -574,7 +564,6 @@ def transcribe_podcast_audio(
         diarization_model_name=diarization_model_name,
         llm_model_name=llm_model_name,
         llm_provider=llm_provider,
-        hf_token=hf_token,
         device=device,
         segmentation_batch_size=segmentation_batch_size,
         parallel=parallel
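With hf_token stripped from every signature and field, nothing in this commit passes credentials to pyannote explicitly. Presumably Pipeline.from_pretrained now relies on ambient Hugging Face authentication (for example a prior `huggingface-cli login`, which huggingface_hub picks up automatically); the diff does not show this, so treat it as an assumption. A sketch of the slimmed top-level call, reusing the mock podcast/episode objects built in examples/combined_podcast_transcription.py:

    from src.podcast_transcribe.transcriber import transcribe_podcast_audio

    # audio, mock_podcast_info and mock_episode_info as constructed in the example
    # script; authentication, if required, is assumed to come from the environment.
    result = transcribe_podcast_audio(
        audio,
        podcast_info=mock_podcast_info,
        episode_info=mock_episode_info,
        asr_model_name=asr_model_name,
        diarization_model_name="pyannote/speaker-diarization-3.1",
        llm_model_name="mlx-community/gemma-3-12b-it-4bit-DWQ",
        device="mps",
        segmentation_batch_size=64,
        parallel=True,
    )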