konieshadow committed
Commit 5d74f79 · 1 Parent(s): d709cdc
examples/combined_podcast_transcription.py CHANGED
@@ -20,10 +20,9 @@ def main():
20
  # audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
21
 
22
  # 模型配置
23
- asr_model_name = "mlx-community/parakeet-tdt-0.6b-v2" # ASR模型名称
24
  diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
25
  llm_model_path = "mlx-community/gemma-3-12b-it-4bit-DWQ"
26
- hf_token = "" # Hugging Face API 令牌
27
  device = "mps" # 设备类型
28
  segmentation_batch_size = 64
29
  parallel = True
@@ -33,12 +32,6 @@ def main():
33
  print(f"错误:文件 '{audio_file}' 不存在")
34
  return 1
35
 
36
- # 检查HF令牌
37
- if not hf_token:
38
- print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
39
- print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
40
- return 1
41
-
42
  try:
43
  print(f"正在加载音频文件: {audio_file}")
44
  # 加载音频文件
@@ -67,7 +60,6 @@ def main():
67
  result = transcribe_podcast_audio(audio,
68
  podcast_info=mock_podcast_info,
69
  episode_info=mock_episode_info,
70
- hf_token=hf_token,
71
  asr_model_name=asr_model_name,
72
  diarization_model_name=diarization_model_name,
73
  llm_model_name=llm_model_path,
 
20
  # audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
21
 
22
  # 模型配置
23
+ asr_model_name = "mlx-community/" # ASR模型名称
24
  diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
25
  llm_model_path = "mlx-community/gemma-3-12b-it-4bit-DWQ"
 
26
  device = "mps" # 设备类型
27
  segmentation_batch_size = 64
28
  parallel = True
 
32
  print(f"错误:文件 '{audio_file}' 不存在")
33
  return 1
34
 
 
 
 
 
 
 
35
  try:
36
  print(f"正在加载音频文件: {audio_file}")
37
  # 加载音频文件
 
60
  result = transcribe_podcast_audio(audio,
61
  podcast_info=mock_podcast_info,
62
  episode_info=mock_episode_info,
 
63
  asr_model_name=asr_model_name,
64
  diarization_model_name=diarization_model_name,
65
  llm_model_name=llm_model_path,
examples/combined_transcription.py CHANGED
@@ -27,7 +27,6 @@ def main():
27
  # 模型配置
28
  asr_model_name = "distil-whisper/distil-large-v3.5" # ASR模型名称
29
  diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
30
- hf_token = "" # Hugging Face API 令牌
31
  device = "mps" # 设备类型
32
  segmentation_batch_size = 64
33
  parallel = True
@@ -37,12 +36,6 @@ def main():
37
  print(f"错误:文件 '{audio_file}' 不存在")
38
  return 1
39
 
40
- # 检查HF令牌
41
- if not hf_token:
42
- print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
43
- print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
44
- return 1
45
-
46
  try:
47
  print(f"正在加载音频文件: {audio_file}")
48
  # 加载音频文件
@@ -54,7 +47,6 @@ def main():
54
  audio,
55
  asr_model_name=asr_model_name,
56
  diarization_model_name=diarization_model_name,
57
- hf_token=hf_token,
58
  device=device,
59
  segmentation_batch_size=segmentation_batch_size,
60
  parallel=parallel,
 
27
  # 模型配置
28
  asr_model_name = "distil-whisper/distil-large-v3.5" # ASR模型名称
29
  diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
 
30
  device = "mps" # 设备类型
31
  segmentation_batch_size = 64
32
  parallel = True
 
36
  print(f"错误:文件 '{audio_file}' 不存在")
37
  return 1
38
 
 
 
 
 
 
 
39
  try:
40
  print(f"正在加载音频文件: {audio_file}")
41
  # 加载音频文件
 
47
  audio,
48
  asr_model_name=asr_model_name,
49
  diarization_model_name=diarization_model_name,
 
50
  device=device,
51
  segmentation_batch_size=segmentation_batch_size,
52
  parallel=parallel,
examples/simple_asr.py CHANGED
@@ -14,41 +14,26 @@ from pathlib import Path
14
  sys.path.insert(0, str(Path(__file__).parent.parent))
15
 
16
  from src.podcast_transcribe.audio import load_audio
 
17
 
18
  logger = logging.getLogger("asr_example")
19
 
20
 
21
  def main():
22
  """主函数"""
23
- # audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
24
  # audio_file = "/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav" # 播客音频文件路径
25
- audio_file = "/Users/konie/Desktop/voices/podcast1_1.wav"
26
  # model = "distil-whisper"
27
- model = "distil-whisper-transformers"
28
 
29
- device = "mlx"
30
 
31
  # 检查文件是否存在
32
  if not os.path.exists(audio_file):
33
  print(f"错误:文件 '{audio_file}' 不存在")
34
  return 1
35
 
36
- if model == "parakeet":
37
- from src.podcast_transcribe.asr.asr_parakeet_mlx import transcribe_audio
38
- model_name = "mlx-community/parakeet-tdt-0.6b-v2"
39
- logger.info(f"使用Parakeet模型: {model_name}")
40
- elif model == "distil-whisper": # distil-whisper
41
- from src.podcast_transcribe.asr.asr_distil_whisper_mlx import transcribe_audio
42
- model_name = "mlx-community/distil-whisper-large-v3"
43
- logger.info(f"使用Distil Whisper模型: {model_name}")
44
- elif model == "distil-whisper-transformers": # distil-whisper
45
- from src.podcast_transcribe.asr.asr_distil_whisper_transformers import transcribe_audio
46
- model_name = "distil-whisper/distil-large-v3.5"
47
- logger.info(f"使用Distil Whisper模型: {model_name}")
48
- else:
49
- logger.error(f"错误:未指定模型类型")
50
- return 1
51
-
52
  try:
53
  print(f"正在加载音频文件: {audio_file}")
54
  # 加载音频文件
@@ -58,7 +43,7 @@ def main():
58
 
59
  # 进行转录
60
  print("开始转录...")
61
- result = transcribe_audio(audio, model_name=model_name, device=device)
62
 
63
  # 输出结果
64
  print("\n转录结果:")
 
14
  sys.path.insert(0, str(Path(__file__).parent.parent))
15
 
16
  from src.podcast_transcribe.audio import load_audio
17
+ from src.podcast_transcribe.asr.asr_distil_whisper import transcribe_audio
18
 
19
  logger = logging.getLogger("asr_example")
20
 
21
 
22
  def main():
23
  """主函数"""
24
+ audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
25
  # audio_file = "/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav" # 播客音频文件路径
26
+ # audio_file = "/Users/konie/Desktop/voices/podcast1_1.wav"
27
  # model = "distil-whisper"
28
+ model = "distil-whisper/distil-large-v3.5"
29
 
30
+ device = "mps"
31
 
32
  # 检查文件是否存在
33
  if not os.path.exists(audio_file):
34
  print(f"错误:文件 '{audio_file}' 不存在")
35
  return 1
36
 
37
  try:
38
  print(f"正在加载音频文件: {audio_file}")
39
  # 加载音频文件
 
43
 
44
  # 进行转录
45
  print("开始转录...")
46
+ result = transcribe_audio(audio, model_name=model, device=device)
47
 
48
  # 输出结果
49
  print("\n转录结果:")
examples/simple_diarization.py CHANGED
@@ -22,7 +22,6 @@ def main():
22
  audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
23
  # audio_file = "/Users/konie/Desktop/voices/history_in_the_baking.mp3" # 播客音频文件路径
24
  model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
25
- hf_token = "" # Hugging Face API 令牌
26
  device = "mps" # 设备类型
27
 
28
  # 检查文件是否存在
@@ -30,11 +29,6 @@ def main():
30
  print(f"错误:文件 '{audio_file}' 不存在")
31
  return 1
32
 
33
- # 检查令牌是否设置
34
- if not hf_token:
35
- print("错误:未设置HF_TOKEN环境变量,请设置后再运行")
36
- return 1
37
-
38
  try:
39
  print(f"正在加载音频文件: {audio_file}")
40
  # 加载音频文件
@@ -46,12 +40,12 @@ def main():
46
  if "pyannote/speaker-diarization" in model_name:
47
  # 使用transformers版本进行说话人分离
48
  print(f"使用transformers版本处理模型: {model_name}")
49
- result = diarize_audio_transformers(audio, model_name=model_name, token=hf_token, device=device, segmentation_batch_size=128)
50
  version_name = "Transformers"
51
  else:
52
  # 使用MLX版本进行说话人分离
53
  print(f"使用MLX版本处理模型: {model_name}")
54
- result = diarize_audio_mlx(audio, model_name=model_name, token=hf_token, device=device, segmentation_batch_size=128)
55
  version_name = "MLX"
56
 
57
  # 输出结果
 
22
  audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
23
  # audio_file = "/Users/konie/Desktop/voices/history_in_the_baking.mp3" # 播客音频文件路径
24
  model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
 
25
  device = "mps" # 设备类型
26
 
27
  # 检查文件是否存在
 
29
  print(f"错误:文件 '{audio_file}' 不存在")
30
  return 1
31
 
 
 
 
 
 
32
  try:
33
  print(f"正在加载音频文件: {audio_file}")
34
  # 加载音频文件
 
40
  if "pyannote/speaker-diarization" in model_name:
41
  # 使用transformers版本进行说话人分离
42
  print(f"使用transformers版本处理模型: {model_name}")
43
+ result = diarize_audio_transformers(audio, model_name=model_name, device=device, segmentation_batch_size=128)
44
  version_name = "Transformers"
45
  else:
46
  # 使用MLX版本进行说话人分离
47
  print(f"使用MLX版本处理模型: {model_name}")
48
+ result = diarize_audio_mlx(audio, model_name=model_name, device=device, segmentation_batch_size=128)
49
  version_name = "MLX"
50
 
51
  # 输出结果
requirements.txt CHANGED
@@ -18,5 +18,4 @@ accelerate>=1.6.0
18
  # MLX特定依赖 - 仅适用于Apple Silicon Mac
19
  # mlx>=0.25.2
20
  # mlx-lm>=0.24.0
21
- # parakeet-mlx>=0.2.6
22
  # mlx-whisper>=0.4.2
 
18
  # MLX特定依赖 - 仅适用于Apple Silicon Mac
19
  # mlx>=0.25.2
20
  # mlx-lm>=0.24.0
 
21
  # mlx-whisper>=0.4.2
src/podcast_transcribe/asr/asr_base.py CHANGED
@@ -89,189 +89,4 @@ class BaseTranscriber:
89
 
90
  if chinese_chars > len(text) * 0.3:
91
  return "zh"
92
- return "en"
93
-
94
- def _convert_segments(self, model_result) -> List[Dict[str, Union[float, str]]]:
95
- """
96
- 将模型的分段结果转换为所需格式(需要在子类中实现)
97
-
98
- 参数:
99
- model_result: 模型返回的结果
100
-
101
- 返回:
102
- 转换后的分段列表
103
- """
104
- raise NotImplementedError("子类必须实现_convert_segments方法")
105
-
106
- def transcribe(self, audio: AudioSegment, chunk_duration_s: int = 30, overlap_s: int = 5) -> TranscriptionResult:
107
- """
108
- 转录音频,支持长音频分块处理。
109
-
110
- 参数:
111
- audio: 要转录的AudioSegment对象
112
- chunk_duration_s: 分块处理的块时长(秒)。如果音频短于此,则不分块。
113
- overlap_s: 分块间的重叠时长(秒)。
114
-
115
- 返回:
116
- TranscriptionResult对象,包含转录结果
117
- """
118
- logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频。分块设置: 块时长={chunk_duration_s}s, 重叠={overlap_s}s")
119
-
120
- if overlap_s >= chunk_duration_s and len(audio)/1000.0 > chunk_duration_s :
121
- logger.error("重叠时长必须小于块时长。")
122
- raise ValueError("overlap_s 必须小于 chunk_duration_s。")
123
-
124
- total_duration_ms = len(audio)
125
- chunk_duration_ms = chunk_duration_s * 1000
126
- overlap_ms = overlap_s * 1000
127
-
128
- if total_duration_ms <= chunk_duration_ms:
129
- logger.debug("音频时长不大于设定块时长,直接进行完整转录。")
130
- processed_audio = self._prepare_audio(audio)
131
- samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
132
-
133
- try:
134
- model_result = self._perform_transcription(samples)
135
- text = self._get_text_from_result(model_result)
136
- segments = self._convert_segments(model_result)
137
- language = self._detect_language(text)
138
-
139
- logger.info(f"单块转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
140
- return TranscriptionResult(text=text, segments=segments, language=language)
141
- except Exception as e:
142
- logger.error(f"单块转录失败: {str(e)}", exc_info=True)
143
- raise RuntimeError(f"单块转录失败: {str(e)}")
144
-
145
- # 长音频分块处理
146
- final_segments = []
147
- # current_pos_ms 指的是当前块要处理的"新内容"的起始点在原始音频中的位置
148
- current_pos_ms = 0
149
-
150
- while current_pos_ms < total_duration_ms:
151
- # 计算当前块实际送入模型处理的音频的起始和结束时间点
152
- # 对于第一个块,start_process_ms = 0
153
- # 对于后续块,start_process_ms 会向左回退 overlap_ms 以包含重叠区域
154
- start_process_ms = max(0, current_pos_ms - overlap_ms)
155
- end_process_ms = min(start_process_ms + chunk_duration_ms, total_duration_ms)
156
-
157
- # 如果计算出的块起始点已经等于或超过总时长,说明处理完毕
158
- if start_process_ms >= total_duration_ms:
159
- break
160
-
161
- chunk_audio = audio[start_process_ms:end_process_ms]
162
-
163
- logger.info(f"处理音频块: {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s (新内容起始于: {current_pos_ms/1000.0:.2f}s)")
164
-
165
- if len(chunk_audio) == 0:
166
- logger.warning(f"生成了一个空的音频块,跳过。起始: {start_process_ms/1000.0:.2f}s, 结束: {end_process_ms/1000.0:.2f}s")
167
- # 必须推进 current_pos_ms 以避免死循环
168
- advance_ms = chunk_duration_ms - overlap_ms
169
- if advance_ms <= 0: # 应该在函数开始时已检查 overlap_s < chunk_duration_s
170
- raise RuntimeError("块推进时长配置错误,可能导致死循环。")
171
- current_pos_ms += advance_ms
172
- continue
173
-
174
- processed_chunk_audio = self._prepare_audio(chunk_audio)
175
- samples = np.array(processed_chunk_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
176
-
177
- try:
178
- model_result = self._perform_transcription(samples)
179
- segments_chunk = self._convert_segments(model_result)
180
-
181
- for seg in segments_chunk:
182
- # seg["start"] 和 seg["end"] 是相对于当前块 (chunk_audio) 的起始点(即0)
183
- # 计算 segment 在原始完整音频中的绝对起止时间
184
- global_seg_start_s = start_process_ms / 1000.0 + seg["start"]
185
- global_seg_end_s = start_process_ms / 1000.0 + seg["end"]
186
-
187
- # 核心去重逻辑:
188
- # 我们只接受那些真实开始于 current_pos_ms / 1000.0 之后的 segment。
189
- # current_pos_ms 是当前块应该贡献的"新"内容的开始时间。
190
- # 对于第一个块 (current_pos_ms == 0),所有 segment 都被接受(只要它们的 start >= 0)。
191
- # 对于后续块,只有当 segment 的全局开始时间 >= 当前块新内容的开始时间时,才添加。
192
- if global_seg_start_s >= current_pos_ms / 1000.0:
193
- final_segments.append({
194
- "start": global_seg_start_s,
195
- "end": global_seg_end_s,
196
- "text": seg["text"]
197
- })
198
- # 特殊处理第一个块,因为 current_pos_ms 为 0,上面的条件 global_seg_start_s >= 0 总是满足。
199
- # 但为了更清晰,如果不是第一个块,但 segment 跨越了 current_pos_ms,
200
- # 它的起始部分在重叠区,结束部分在非重叠区。
201
- # 当前逻辑是,如果它的 global_seg_start_s < current_pos_ms / 1000.0,它就被丢弃。
202
- # 这是为了确保不重复记录重叠区域的开头部分。
203
- # 如果一个 segment 完全在重叠区内且在前一个块已被记录,此逻辑可避免重复。
204
-
205
- except Exception as e:
206
- logger.error(f"处理音频块 {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s 失败: {str(e)}", exc_info=True)
207
-
208
- # 更新下一个"新内容"块的起始位置
209
- advance_ms = chunk_duration_ms - overlap_ms
210
- current_pos_ms += advance_ms
211
-
212
- # 对收集到的所有 segments 按开始时间排序
213
- final_segments.sort(key=lambda s: s["start"])
214
-
215
- # 可选:进一步清理 segments,例如合并非常接近且文本连续的,或移除完全重复的
216
- cleaned_segments = []
217
- if final_segments:
218
- cleaned_segments.append(final_segments[0])
219
- for i in range(1, len(final_segments)):
220
- prev_s = cleaned_segments[-1]
221
- curr_s = final_segments[i]
222
- # 简单的去重:如果时间戳和文本都几乎一样,则认为是重复
223
- if abs(curr_s["start"] - prev_s["start"]) < 0.01 and \
224
- abs(curr_s["end"] - prev_s["end"]) < 0.01 and \
225
- curr_s["text"] == prev_s["text"]:
226
- continue
227
-
228
- # 如果当前 segment 的开始时间在前一个 segment 的结束时间之前,
229
- # 并且文本有明显重叠,可能需要更智能的合并。
230
- # 目前的逻辑通过 global_seg_start_s >= current_pos_ms / 1000.0 过滤,
231
- # 已经大大减少了直接的 segment 重复。
232
- # 此处的清理更多是处理模型在边界可能产生的一些微小偏差。
233
- # 如果上一个segment的结束时间比当前segment的开始时间还要晚,说明有重叠,
234
- # 且上一个segment包含了当前segment的开始部分。
235
- # 这种情况下,可以考虑调整上一个的结束,或当前segment的开始和文本。
236
- # 为简单起见,暂时直接添加,相信之前的过滤已处理主要重叠。
237
- if curr_s["start"] < prev_s["end"] and prev_s["text"].endswith(curr_s["text"][:len(prev_s["text"]) - int((prev_s["end"] - curr_s["start"])*10) ]): # 粗略检查
238
- # 如果curr_s的开始部分被prev_s覆盖,并且文本也对应,则调整curr_s
239
- # pass # 暂时不处理这种细微重叠,依赖模型切分
240
- cleaned_segments.append(curr_s) # 仍添加,依赖后续文本拼接
241
- else:
242
- cleaned_segments.append(curr_s)
243
-
244
- final_text = " ".join([s["text"] for s in cleaned_segments]).strip()
245
- language = self._detect_language(final_text)
246
-
247
- logger.info(f"分块转录完成。最终文本长度: {len(final_text)}, 分段数: {len(cleaned_segments)}")
248
-
249
- return TranscriptionResult(
250
- text=final_text,
251
- segments=cleaned_segments,
252
- language=language
253
- )
254
-
255
- def _perform_transcription(self, audio_data):
256
- """
257
- 执行转录(需要在子类中实现)
258
-
259
- 参数:
260
- audio_data: 音频数据(numpy数组)
261
-
262
- 返回:
263
- 模型的转录结果
264
- """
265
- raise NotImplementedError("子类必须实现_perform_transcription方法")
266
-
267
- def _get_text_from_result(self, result):
268
- """
269
- 从结果中获取文本(需要在子类中实现)
270
-
271
- 参数:
272
- result: 模型的转录结果
273
-
274
- 返回:
275
- 转录的文本
276
- """
277
- raise NotImplementedError("子类必须实现_get_text_from_result方法")
 
89
 
90
  if chinese_chars > len(text) * 0.3:
91
  return "zh"
92
+ return "en"
src/podcast_transcribe/asr/asr_distil_whisper.py ADDED
@@ -0,0 +1,319 @@
1
+ """
2
+ 基于MLX或Transformers实现的语音识别模块,使用distil-whisper模型
3
+ """
4
+
5
+ import os
6
+ from pydub import AudioSegment
7
+ from typing import Dict, List, Union, Literal
8
+ import logging
9
+ import numpy as np
10
+
11
+ # 导入基类
12
+ from .asr_base import BaseTranscriber, TranscriptionResult
13
+
14
+ # 配置日志
15
+ logger = logging.getLogger("asr")
16
+
17
+
18
+ class DistilWhisperTranscriber(BaseTranscriber):
19
+ """抽象基类:Distil Whisper转录器的共享实现"""
20
+
21
+ def __init__(
22
+ self,
23
+ model_name: str,
24
+ **kwargs
25
+ ):
26
+ """
27
+ 初始化转录器
28
+
29
+ 参数:
30
+ model_name: 模型名称
31
+ **kwargs: 其他参数
32
+ """
33
+ super().__init__(model_name=model_name, **kwargs)
34
+
35
+ def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
36
+ """
37
+ 转录音频,针对distil-whisper模型取消分块处理,直接处理整个音频。
38
+
39
+ 参数:
40
+ audio: 要转录的AudioSegment对象
41
+ chunk_duration_s: 分块处理的块时长(秒)- 此参数被忽略
42
+ overlap_s: 分块间的重叠时长(秒)- 此参数被忽略
43
+
44
+ 返回:
45
+ TranscriptionResult对象,包含转录结果
46
+ """
47
+ logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频(distil-whisper模型)")
48
+
49
+ # 直接处理整个音频,不进行分块
50
+ processed_audio = self._prepare_audio(audio)
51
+ samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
52
+
53
+ try:
54
+ model_result = self._perform_transcription(samples)
55
+ text = self._get_text_from_result(model_result)
56
+ segments = self._convert_segments(model_result)
57
+ language = self._detect_language(text)
58
+
59
+ logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
60
+ return TranscriptionResult(text=text, segments=segments, language=language)
61
+ except Exception as e:
62
+ logger.error(f"转录失败: {str(e)}", exc_info=True)
63
+ raise RuntimeError(f"转录失败: {str(e)}")
64
+
65
+ def _get_text_from_result(self, result):
66
+ """
67
+ 从结果中获取文本
68
+
69
+ 参数:
70
+ result: 模型的转录结果
71
+
72
+ 返回:
73
+ 转录的文本
74
+ """
75
+ return result.get("text", "")
76
+
77
+ def _load_model(self):
78
+ """加载模型的抽象方法,由子类实现"""
79
+ raise NotImplementedError("子类必须实现_load_model方法")
80
+
81
+ def _perform_transcription(self, audio_data):
82
+ """执行转录的抽象方法,由子类实现"""
83
+ raise NotImplementedError("子类必须实现_perform_transcription方法")
84
+
85
+ def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
86
+ """将模型结果转换为分段的抽象方法,由子类实现"""
87
+ raise NotImplementedError("子类必须实现_convert_segments方法")
88
+
89
+
90
+ class MLXDistilWhisperTranscriber(DistilWhisperTranscriber):
91
+ """使用MLX加载和运行distil-whisper模型的转录器"""
92
+
93
+ def __init__(
94
+ self,
95
+ model_name: str = "mlx-community/distil-whisper-large-v3",
96
+ ):
97
+ """
98
+ 初始化转录器
99
+
100
+ 参数:
101
+ model_name: 模型名称
102
+ """
103
+ super().__init__(model_name=model_name)
104
+
105
+ def _load_model(self):
106
+ """加载Distil Whisper MLX模型"""
107
+ try:
108
+ # 懒加载mlx-whisper
109
+ try:
110
+ import mlx_whisper
111
+ except ImportError:
112
+ raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
113
+
114
+ logger.info(f"开始加载模型 {self.model_name}")
115
+ self.model = mlx_whisper.load_models.load_model(self.model_name)
116
+ logger.info(f"模型加载成功")
117
+ except Exception as e:
118
+ logger.error(f"加载模型失败: {str(e)}", exc_info=True)
119
+ raise RuntimeError(f"加载模型失败: {str(e)}")
120
+
121
+ def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
122
+ """
123
+ 将模型的分段结果转换为所需格式
124
+
125
+ 参数:
126
+ result: 模型返回的结果
127
+
128
+ 返回:
129
+ 转换后的分段列表
130
+ """
131
+ segments = []
132
+
133
+ for segment in result.get("segments", []):
134
+ segments.append({
135
+ "start": segment.get("start", 0.0),
136
+ "end": segment.get("end", 0.0),
137
+ "text": segment.get("text", "").strip()
138
+ })
139
+
140
+ return segments
141
+
142
+ def _perform_transcription(self, audio_data):
143
+ """
144
+ 执行转录
145
+
146
+ 参数:
147
+ audio_data: 音频数据(numpy数组)
148
+
149
+ 返回:
150
+ 模型的转录结果
151
+ """
152
+ from mlx_whisper import transcribe
153
+ return transcribe(audio_data, path_or_hf_repo=self.model_name)
154
+
155
+
156
+ class TransformersDistilWhisperTranscriber(DistilWhisperTranscriber):
157
+ """使用Transformers加载和运行distil-whisper模型的转录器"""
158
+
159
+ def __init__(
160
+ self,
161
+ model_name: str = "distil-whisper/distil-large-v3.5",
162
+ device: str = "cpu",
163
+ ):
164
+ """
165
+ 初始化转录器
166
+
167
+ 参数:
168
+ model_name: 模型名称
169
+ device: 推理设备,'cpu'或'cuda'
170
+ """
171
+ super().__init__(model_name=model_name, device=device)
172
+
173
+ def _load_model(self):
174
+ """加载Distil Whisper Transformers模型"""
175
+ try:
176
+ # 懒加载transformers
177
+ try:
178
+ from transformers import pipeline
179
+ except ImportError:
180
+ raise ImportError("请先安装transformers库: pip install transformers")
181
+
182
+ logger.info(f"开始加载模型 {self.model_name} 设备: {self.device}")
183
+
184
+ pipeline_device_arg = None
185
+ if self.device == "cuda":
186
+ pipeline_device_arg = 0 # 使用第一个 CUDA 设备
187
+ elif self.device == "mps":
188
+ pipeline_device_arg = "mps" # 使用 MPS 设备
189
+ elif self.device == "cpu":
190
+ pipeline_device_arg = -1 # 使用 CPU
191
+ else:
192
+ # 对于其他未明确支持的 device 字符串,记录警告并默认使用 CPU
193
+ logger.warning(f"不支持的设备字符串 '{self.device}',将默认使用 CPU。")
194
+ pipeline_device_arg = -1
195
+
196
+ # 导入必要的模块来配置模型
197
+ import warnings
198
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
199
+
200
+ # 抑制特定的警告
201
+ warnings.filterwarnings("ignore", message="The input name `inputs` is deprecated")
202
+ warnings.filterwarnings("ignore", message="You have passed task=transcribe")
203
+ warnings.filterwarnings("ignore", message="The attention mask is not set")
204
+
205
+ self.pipeline = pipeline(
206
+ "automatic-speech-recognition",
207
+ model=self.model_name,
208
+ device=pipeline_device_arg,
209
+ return_timestamps=True,
210
+ chunk_length_s=30, # 使用30秒的块长度
211
+ stride_length_s=5, # 块之间5秒的重叠
212
+ batch_size=1, # 顺序处理
213
+ # 添加以下参数来减少警告
214
+ generate_kwargs={
215
+ "task": "transcribe",
216
+ "language": None, # 自动检测语言
217
+ "forced_decoder_ids": None, # 避免冲突
218
+ }
219
+ )
220
+ logger.info(f"模型加载成功")
221
+ except Exception as e:
222
+ logger.error(f"加载模型失败: {str(e)}", exc_info=True)
223
+ raise RuntimeError(f"加载模型失败: {str(e)}")
224
+
225
+ def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
226
+ """
227
+ 将模型的分段结果转换为所需格式
228
+
229
+ 参数:
230
+ result: 模型返回的结果
231
+
232
+ 返回:
233
+ 转换后的分段列表
234
+ """
235
+ segments = []
236
+
237
+ # transformers pipeline 的结果格式
238
+ if "chunks" in result:
239
+ for chunk in result["chunks"]:
240
+ segments.append({
241
+ "start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
242
+ "end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
243
+ "text": chunk["text"].strip()
244
+ })
245
+ else:
246
+ # 如果没有分段信息,创建一个单一分段
247
+ segments.append({
248
+ "start": 0.0,
249
+ "end": 0.0, # 无法确定结束时间
250
+ "text": result.get("text", "").strip()
251
+ })
252
+
253
+ return segments
254
+
255
+ def _perform_transcription(self, audio_data):
256
+ """
257
+ 执行转录
258
+
259
+ 参数:
260
+ audio_data: 音频数据(numpy数组)
261
+
262
+ 返回:
263
+ 模型的转录结果
264
+ """
265
+ # transformers pipeline 接受numpy数组作为输入
266
+ # 音频数据已经在_prepare_audio中确保是16kHz采样率
267
+
268
+ # 确保音频数据格式正确
269
+ if audio_data.dtype != np.float32:
270
+ audio_data = audio_data.astype(np.float32)
271
+
272
+ # 使用正确的参数名称调用pipeline
273
+ try:
274
+ result = self.pipeline(
275
+ audio_data,
276
+ generate_kwargs={
277
+ "task": "transcribe",
278
+ "language": None, # 自动检测语言
279
+ "forced_decoder_ids": None, # 避免冲突
280
+ }
281
+ )
282
+ return result
283
+ except Exception as e:
284
+ logger.warning(f"使用新参数格式失败,尝试使用默认参数: {str(e)}")
285
+ # 如果新格式失败,回退到简单调用
286
+ return self.pipeline(audio_data)
287
+
288
+
289
+ # 统一的接口函数
290
+ def transcribe_audio(
291
+ audio_segment: AudioSegment,
292
+ model_name: str = None,
293
+ backend: Literal["mlx", "transformers"] = "transformers",
294
+ device: str = "cpu",
295
+ ) -> TranscriptionResult:
296
+ """
297
+ 使用Distil Whisper模型转录音频
298
+
299
+ 参数:
300
+ audio_segment: 输入的AudioSegment对象
301
+ model_name: 使用的模型名称,如果不指定则使用默认模型
302
+ backend: 后端类型,'mlx'或'transformers'
303
+ device: 推理设备,仅对transformers后端有效
304
+
305
+ 返回:
306
+ TranscriptionResult对象,包含转录的文本、分段和语言
307
+ """
308
+ logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒,后端: {backend}")
309
+
310
+ if backend == "mlx":
311
+ default_model = "mlx-community/distil-whisper-large-v3"
312
+ model = model_name or default_model
313
+ transcriber = MLXDistilWhisperTranscriber(model_name=model)
314
+ else: # transformers
315
+ default_model = "distil-whisper/distil-large-v3.5"
316
+ model = model_name or default_model
317
+ transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)
318
+
319
+ return transcriber.transcribe(audio_segment)
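For reference, a minimal usage sketch of the unified transcribe_audio entry point added above (the audio path is illustrative; the backend argument selects between the Transformers and MLX implementations):

    from pydub import AudioSegment
    from src.podcast_transcribe.asr.asr_distil_whisper import transcribe_audio

    audio = AudioSegment.from_file("examples/input/lex_ai_john_carmack_1.wav")  # illustrative path

    # Transformers backend (default model distil-whisper/distil-large-v3.5)
    result = transcribe_audio(audio, backend="transformers", device="mps")

    # MLX backend on Apple Silicon (default model mlx-community/distil-whisper-large-v3):
    # result = transcribe_audio(audio, backend="mlx")

    print(result.language, len(result.segments))
    print(result.text)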
src/podcast_transcribe/asr/asr_distil_whisper_mlx.py DELETED
@@ -1,111 +0,0 @@
1
- """
2
- 基于MLX实现的语音识别模块,使用distil-whisper-large-v3模型
3
- """
4
-
5
- import os
6
- from pydub import AudioSegment
7
- from typing import Dict, List, Union
8
- import logging
9
-
10
- # 导入基类
11
- from .asr_base import BaseTranscriber, TranscriptionResult
12
-
13
- # 配置日志
14
- logger = logging.getLogger("asr")
15
-
16
-
17
- class MLXDistilWhisperTranscriber(BaseTranscriber):
18
- """使用MLX加载和运行distil-whisper-large-v3模型的转录器"""
19
-
20
- def __init__(
21
- self,
22
- model_name: str = "mlx-community/distil-whisper-large-v3",
23
- ):
24
- """
25
- 初始化转录器
26
-
27
- 参数:
28
- model_name: 模型名称
29
- """
30
- super().__init__(model_name=model_name)
31
-
32
- def _load_model(self):
33
- """加载Distil Whisper模型"""
34
- try:
35
- # 懒加载mlx-whisper
36
- try:
37
- import mlx_whisper
38
- except ImportError:
39
- raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
40
-
41
- logger.info(f"开始加载模型 {self.model_name}")
42
- self.model = mlx_whisper.load_models.load_model(self.model_name)
43
- logger.info(f"模型加载成功")
44
- except Exception as e:
45
- logger.error(f"加载模型失败: {str(e)}", exc_info=True)
46
- raise RuntimeError(f"加载模型失败: {str(e)}")
47
-
48
- def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
49
- """
50
- 将模型的分段结果转换为所需格式
51
-
52
- 参数:
53
- result: 模型返回的结果
54
-
55
- 返回:
56
- 转换后的分段列表
57
- """
58
- segments = []
59
-
60
- for segment in result.get("segments", []):
61
- segments.append({
62
- "start": segment.get("start", 0.0),
63
- "end": segment.get("end", 0.0),
64
- "text": segment.get("text", "").strip()
65
- })
66
-
67
- return segments
68
-
69
- def _perform_transcription(self, audio_data):
70
- """
71
- 执行转录
72
-
73
- 参数:
74
- audio_data: 音频数据(numpy数组)
75
-
76
- 返回:
77
- 模型的转录结果
78
- """
79
- from mlx_whisper import transcribe
80
- return transcribe(audio_data, path_or_hf_repo=self.model_name)
81
-
82
- def _get_text_from_result(self, result):
83
- """
84
- 从结果中获取文本
85
-
86
- 参数:
87
- result: 模型的转录结果
88
-
89
- 返回:
90
- 转录的文本
91
- """
92
- return result.get("text", "")
93
-
94
-
95
- def transcribe_audio(
96
- audio_segment: AudioSegment,
97
- model_name: str = "mlx-community/distil-whisper-large-v3",
98
- ) -> TranscriptionResult:
99
- """
100
- 使用MLX和distil-whisper-large-v3模型转录音频
101
-
102
- 参数:
103
- audio_segment: 输入的AudioSegment对象
104
- model_name: 使用的模型名称
105
-
106
- 返回:
107
- TranscriptionResult对象,包含转录的文本、分段和语言
108
- """
109
- logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
110
- transcriber = MLXDistilWhisperTranscriber(model_name=model_name)
111
- return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_distil_whisper_transformers.py DELETED
@@ -1,133 +0,0 @@
1
- """
2
- 基于Transformers实现的语音识别模块,使用distil-whisper-large-v3.5模型
3
- """
4
-
5
- import os
6
- from pydub import AudioSegment
7
- from typing import Dict, List, Union
8
- import logging
9
- import numpy as np
10
-
11
- # 导入基类
12
- from .asr_base import BaseTranscriber, TranscriptionResult
13
-
14
- # 配置日志
15
- logger = logging.getLogger("asr")
16
-
17
-
18
- class TransformersDistilWhisperTranscriber(BaseTranscriber):
19
- """使用Transformers加载和运行distil-whisper-large-v3.5模型的转录器"""
20
-
21
- def __init__(
22
- self,
23
- model_name: str = "distil-whisper/distil-large-v3.5",
24
- device: str = "cpu",
25
- ):
26
- """
27
- 初始化转录器
28
-
29
- 参数:
30
- model_name: 模型名称
31
- device: 推理设备,'cpu'或'cuda'
32
- """
33
- super().__init__(model_name=model_name, device=device)
34
-
35
- def _load_model(self):
36
- """加载Distil Whisper模型"""
37
- try:
38
- # 懒加载transformers
39
- try:
40
- from transformers import pipeline
41
- except ImportError:
42
- raise ImportError("请先安装transformers库: pip install transformers")
43
-
44
- logger.info(f"开始加载模型 {self.model_name}")
45
- self.pipeline = pipeline(
46
- "automatic-speech-recognition",
47
- model=self.model_name,
48
- device=0 if self.device == "cuda" else -1,
49
- return_timestamps=True,
50
- chunk_length_s=25,
51
- batch_size=32,
52
- )
53
- logger.info(f"模型加载成功")
54
- except Exception as e:
55
- logger.error(f"加载模型失败: {str(e)}", exc_info=True)
56
- raise RuntimeError(f"加载模型失败: {str(e)}")
57
-
58
- def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
59
- """
60
- 将模型的分段结果转换为所需格式
61
-
62
- 参数:
63
- result: 模型返回的结果
64
-
65
- 返回:
66
- 转换后的分段列表
67
- """
68
- segments = []
69
-
70
- # transformers pipeline 的结果格式
71
- if "chunks" in result:
72
- for chunk in result["chunks"]:
73
- segments.append({
74
- "start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
75
- "end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
76
- "text": chunk["text"].strip()
77
- })
78
- else:
79
- # 如果没有分段信息,创建一个单一分段
80
- segments.append({
81
- "start": 0.0,
82
- "end": 0.0, # 无法确定结束时间
83
- "text": result.get("text", "").strip()
84
- })
85
-
86
- return segments
87
-
88
- def _perform_transcription(self, audio_data):
89
- """
90
- 执行转录
91
-
92
- 参数:
93
- audio_data: 音频数据(numpy数组)
94
-
95
- 返回:
96
- 模型的转录结果
97
- """
98
- # transformers pipeline 接受numpy数组作为输入
99
- # 音频数据已经在_prepare_audio中确保是16kHz采样率
100
- return self.pipeline(audio_data)
101
-
102
- def _get_text_from_result(self, result):
103
- """
104
- 从结果中获取文本
105
-
106
- 参数:
107
- result: 模型的转录结果
108
-
109
- 返回:
110
- 转录的文本
111
- """
112
- return result.get("text", "")
113
-
114
-
115
- def transcribe_audio(
116
- audio_segment: AudioSegment,
117
- model_name: str = "distil-whisper/distil-large-v3.5",
118
- device: str = "cpu",
119
- ) -> TranscriptionResult:
120
- """
121
- 使用Transformers和distil-whisper-large-v3.5模型转录音频
122
-
123
- 参数:
124
- audio_segment: 输入的AudioSegment对象
125
- model_name: 使用的模型名称
126
- device: 推理设备,'cpu'或'cuda'
127
-
128
- 返回:
129
- TranscriptionResult对象,包含转录的文本、分段和语言
130
- """
131
- logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
132
- transcriber = TransformersDistilWhisperTranscriber(model_name=model_name, device=device)
133
- return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_parakeet_mlx.py DELETED
@@ -1,126 +0,0 @@
1
- """
2
- 基于MLX实现的语音识别模块,使用parakeet-tdt模型
3
- """
4
-
5
- import os
6
- from pydub import AudioSegment
7
- from typing import Dict, List, Union
8
- import logging
9
- import tempfile
10
- import numpy as np
11
- import soundfile as sf
12
-
13
- # 导入基类
14
- from .asr_base import BaseTranscriber, TranscriptionResult
15
-
16
- # 配置日志
17
- logger = logging.getLogger("asr")
18
-
19
-
20
- class MLXParakeetTranscriber(BaseTranscriber):
21
- """使用MLX加载和运行parakeet-tdt-0.6b-v2模型的转录器"""
22
-
23
- def __init__(
24
- self,
25
- model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
26
- ):
27
- """
28
- 初始化转录器
29
-
30
- 参数:
31
- model_name: 模型名称
32
- """
33
- super().__init__(model_name=model_name)
34
-
35
- def _load_model(self):
36
- """加载Parakeet模型"""
37
- try:
38
- # 懒加载parakeet_mlx
39
- try:
40
- from parakeet_mlx import from_pretrained
41
- except ImportError:
42
- raise ImportError("请先安装parakeet-mlx库: pip install parakeet-mlx")
43
-
44
- logger.info(f"开始加载模型 {self.model_name}")
45
- self.model = from_pretrained(self.model_name)
46
- logger.info(f"模型加载成功")
47
- except Exception as e:
48
- logger.error(f"加载模型失败: {str(e)}", exc_info=True)
49
- raise RuntimeError(f"加载模型失败: {str(e)}")
50
-
51
- def _convert_segments(self, aligned_result) -> List[Dict[str, Union[float, str]]]:
52
- """
53
- 将模型的分段结果转换为所需格式
54
-
55
- 参数:
56
- aligned_result: 模型返回的分段结果
57
-
58
- 返回:
59
- 转换后的分段列表
60
- """
61
- segments = []
62
-
63
- for sentence in aligned_result.sentences:
64
- segments.append({
65
- "start": sentence.start,
66
- "end": sentence.end,
67
- "text": sentence.text
68
- })
69
-
70
- return segments
71
-
72
- def _perform_transcription(self, audio_data):
73
- """
74
- 执行转录
75
-
76
- 参数:
77
- audio_data: 音频数据(numpy数组)
78
-
79
- 返回:
80
- 模型的转录结果
81
- """
82
- # 由于parakeet-mlx可能不直接支持numpy数组输入
83
- # 创建临时文件并写入音频数据
84
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=True) as temp_file:
85
- # 确保数据在[-1, 1]范围内
86
- if audio_data.max() > 1.0 or audio_data.min() < -1.0:
87
- audio_data = np.clip(audio_data, -1.0, 1.0)
88
-
89
- # 写入临时文件
90
- sf.write(temp_file.name, audio_data, 16000, 'PCM_16')
91
-
92
- # 使用临时文件进行转录
93
- result = self.model.transcribe(temp_file.name)
94
-
95
- return result
96
-
97
- def _get_text_from_result(self, result):
98
- """
99
- 从结果中获取文本
100
-
101
- 参数:
102
- result: 模型的转录结果
103
-
104
- 返回:
105
- 转录的文本
106
- """
107
- return result.text
108
-
109
-
110
- def transcribe_audio(
111
- audio_segment: AudioSegment,
112
- model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
113
- ) -> TranscriptionResult:
114
- """
115
- 使用MLX和parakeet-tdt模型转录音频
116
-
117
- 参数:
118
- audio_segment: 输入的AudioSegment对象
119
- model_name: 使用的模型名称
120
-
121
- 返回:
122
- TranscriptionResult对象,包含转录的文本、分段和语言
123
- """
124
- logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
125
- transcriber = MLXParakeetTranscriber(model_name=model_name)
126
- return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_router.py CHANGED
@@ -4,13 +4,11 @@ ASR模型调用路由器
4
  """
5
 
6
  import logging
7
- from typing import Dict, Any, Optional, Callable
8
  from pydub import AudioSegment
9
  import spaces
10
  from .asr_base import TranscriptionResult
11
- from . import asr_parakeet_mlx
12
- from . import asr_distil_whisper_mlx
13
- from . import asr_distil_whisper_transformers
14
 
15
  # 配置日志
16
  logger = logging.getLogger("asr")
@@ -26,22 +24,8 @@ class ASRRouter:
26
 
27
  # 定义支持的provider配置
28
  self._provider_configs = {
29
- "parakeet_mlx": {
30
- "module_path": "asr_parakeet_mlx",
31
- "function_name": "transcribe_audio",
32
- "default_model": "mlx-community/parakeet-tdt-0.6b-v2",
33
- "supported_params": ["model_name"],
34
- "description": "基于MLX的Parakeet模型"
35
- },
36
- "distil_whisper_mlx": {
37
- "module_path": "asr_distil_whisper_mlx",
38
- "function_name": "transcribe_audio",
39
- "default_model": "mlx-community/distil-whisper-large-v3",
40
- "supported_params": ["model_name"],
41
- "description": "基于MLX的Distil Whisper模型"
42
- },
43
  "distil_whisper_transformers": {
44
- "module_path": "asr_distil_whisper_transformers",
45
  "function_name": "transcribe_audio",
46
  "default_model": "distil-whisper/distil-large-v3.5",
47
  "supported_params": ["model_name", "device"],
@@ -66,15 +50,8 @@ class ASRRouter:
66
  module_path = self._provider_configs[provider]["module_path"]
67
  logger.info(f"获取模块: {module_path}")
68
 
69
- # 根据module_path返回对应的模块
70
- if module_path == "asr_parakeet_mlx":
71
- module = asr_parakeet_mlx
72
- elif module_path == "asr_distil_whisper_mlx":
73
- module = asr_distil_whisper_mlx
74
- elif module_path == "asr_distil_whisper_transformers":
75
- module = asr_distil_whisper_transformers
76
- else:
77
- raise ImportError(f"未找到模块: {module_path}")
78
 
79
  self._loaded_modules[provider] = module
80
  logger.info(f"模块 {module_path} 获取成功")
@@ -202,51 +179,17 @@ def transcribe_audio(
202
  provider: str = "distil_whisper_transformers",
203
  model_name: Optional[str] = None,
204
  device: str = "cpu",
 
205
  **kwargs
206
  ) -> TranscriptionResult:
207
- """
208
- 统一的音频转录接口函数
209
-
210
- 参数:
211
- audio_segment: 输入的AudioSegment对象
212
- provider: ASR提供者,可选值:
213
- - "parakeet_mlx": 基于MLX的Parakeet模型
214
- - "distil_whisper_mlx": 基于MLX的Distil Whisper模型
215
- - "distil_whisper_transformers": 基于Transformers的Distil Whisper模型
216
- model_name: 模型名称,如果不指定则使用默认模型
217
- device: 推理设备,仅对transformers provider有效
218
- **kwargs: 其他参数
219
-
220
- 返回:
221
- TranscriptionResult对象,包含转录的文本、分段和语言
222
-
223
- 示例:
224
- # 使用默认MLX Distil Whisper模型
225
- result = transcribe_audio(audio_segment, provider="distil_whisper_mlx")
226
-
227
- # 使用Parakeet模型
228
- result = transcribe_audio(audio_segment, provider="parakeet_mlx")
229
-
230
- # 使用Transformers模型并指定设备
231
- result = transcribe_audio(
232
- audio_segment,
233
- provider="distil_whisper_transformers",
234
- device="cuda"
235
- )
236
-
237
- # 使用自定义模型
238
- result = transcribe_audio(
239
- audio_segment,
240
- provider="distil_whisper_mlx",
241
- model_name="mlx-community/whisper-large-v3"
242
- )
243
- """
244
  # 准备参数
245
  params = kwargs.copy()
246
  if model_name is not None:
247
  params["model_name"] = model_name
248
  if device != "cpu":
249
  params["device"] = device
 
 
250
 
251
  return _router.transcribe(audio_segment, provider, **params)
252
 
 
4
  """
5
 
6
  import logging
7
+ from typing import Dict, Any, Literal, Optional, Callable
8
  from pydub import AudioSegment
9
  import spaces
10
  from .asr_base import TranscriptionResult
11
+ from . import asr_distil_whisper
 
 
12
 
13
  # 配置日志
14
  logger = logging.getLogger("asr")
 
24
 
25
  # 定义支持的provider配置
26
  self._provider_configs = {
27
  "distil_whisper_transformers": {
28
+ "module_path": "asr_distil_whisper",
29
  "function_name": "transcribe_audio",
30
  "default_model": "distil-whisper/distil-large-v3.5",
31
  "supported_params": ["model_name", "device"],
 
50
  module_path = self._provider_configs[provider]["module_path"]
51
  logger.info(f"获取模块: {module_path}")
52
 
53
+ # 所有provider现在都指向同一个模块
54
+ module = asr_distil_whisper
55
 
56
  self._loaded_modules[provider] = module
57
  logger.info(f"模块 {module_path} 获取成功")
 
179
  provider: str = "distil_whisper_transformers",
180
  model_name: Optional[str] = None,
181
  device: str = "cpu",
182
+ backend: str = "transformers",
183
  **kwargs
184
  ) -> TranscriptionResult:
185
  # 准备参数
186
  params = kwargs.copy()
187
  if model_name is not None:
188
  params["model_name"] = model_name
189
  if device != "cpu":
190
  params["device"] = device
191
+ if backend is not None:
192
+ params["backend"] = backend
193
 
194
  return _router.transcribe(audio_segment, provider, **params)
195
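A corresponding sketch for the router-level API after this change (model name and device mirror the defaults used elsewhere in this commit; the new backend argument is forwarded on to asr_distil_whisper):

    from pydub import AudioSegment
    from src.podcast_transcribe.asr.asr_router import transcribe_audio

    audio = AudioSegment.from_file("examples/input/lex_ai_john_carmack_1.wav")  # illustrative path
    result = transcribe_audio(
        audio,
        provider="distil_whisper_transformers",   # the only remaining ASR provider
        model_name="distil-whisper/distil-large-v3.5",
        device="mps",
        backend="transformers",                   # or "mlx" on Apple Silicon
    )
    print(result.text)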
 
src/podcast_transcribe/diarization/diarization_pyannote_mlx.py CHANGED
@@ -48,9 +48,6 @@ class PyannoteTranscriber(BaseDiarizer):
48
  except ImportError:
49
  raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
50
 
51
- if not self.token:
52
- raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
53
-
54
  logger.info(f"开始加载模型 {self.model_name}")
55
  self.pipeline = Pipeline.from_pretrained(
56
  self.model_name,
 
48
  except ImportError:
49
  raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
50
 
 
 
 
51
  logger.info(f"开始加载模型 {self.model_name}")
52
  self.pipeline = Pipeline.from_pretrained(
53
  self.model_name,
src/podcast_transcribe/diarization/diarization_pyannote_transformers.py CHANGED
@@ -48,10 +48,7 @@ class PyannoteTransformersTranscriber(BaseDiarizer):
48
  from pyannote.audio import Pipeline
49
  except ImportError:
50
  raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
51
-
52
- if not self.token:
53
- raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
54
-
55
  logger.info(f"开始使用pyannote.audio加载模型 {self.model_name}")
56
 
57
  # 使用pyannote.audio Pipeline加载说话人分离模型
 
48
  from pyannote.audio import Pipeline
49
  except ImportError:
50
  raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
51
+
 
 
 
52
  logger.info(f"开始使用pyannote.audio加载模型 {self.model_name}")
53
 
54
  # 使用pyannote.audio Pipeline加载说话人分离模型
src/podcast_transcribe/diarization/diarizer_base.py CHANGED
@@ -34,7 +34,6 @@ class BaseDiarizer(ABC):
34
  segmentation_batch_size: 分割批处理大小,默认为32
35
  """
36
  self.model_name = model_name
37
- self.token = token or os.environ.get("HF_TOKEN")
38
  self.device = device
39
  self.segmentation_batch_size = segmentation_batch_size
40
 
 
34
  segmentation_batch_size: 分割批处理大小,默认为32
35
  """
36
  self.model_name = model_name
 
37
  self.device = device
38
  self.segmentation_batch_size = segmentation_batch_size
39
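With the explicit hf_token/token plumbing removed here and in the example scripts, loading the gated pyannote/speaker-diarization-3.1 pipeline presumably relies on credentials that huggingface_hub resolves on its own (a cached huggingface-cli login or the HF_TOKEN environment variable). A purely illustrative pre-flight check, not part of this commit:

    import os

    # Illustrative only: warn if no Hugging Face credential is obviously configured
    # before the gated pyannote pipeline is loaded.
    if not os.environ.get("HF_TOKEN"):
        print("HF_TOKEN is not set; run `huggingface-cli login` or export HF_TOKEN, "
              "otherwise pyannote/speaker-diarization-3.1 may fail to load.")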
 
src/podcast_transcribe/diarization/diarizer_router.py CHANGED
@@ -215,20 +215,18 @@ def diarize_audio(
215
 
216
  示例:
217
  # 使用默认pyannote MLX实现
218
- result = diarize_audio(audio_segment, provider="pyannote_mlx", token="your_hf_token")
219
 
220
  # 使用transformers实现
221
  result = diarize_audio(
222
  audio_segment,
223
  provider="pyannote_transformers",
224
- token="your_hf_token"
225
  )
226
 
227
  # 使用GPU设备
228
  result = diarize_audio(
229
  audio_segment,
230
  provider="pyannote_mlx",
231
- token="your_hf_token",
232
  device="cuda"
233
  )
234
 
@@ -236,7 +234,6 @@ def diarize_audio(
236
  result = diarize_audio(
237
  audio_segment,
238
  provider="pyannote_mlx",
239
- token="your_hf_token",
240
  segmentation_batch_size=64
241
  )
242
  """
 
215
 
216
  示例:
217
  # 使用默认pyannote MLX实现
218
+ result = diarize_audio(audio_segment, provider="pyannote_mlx")
219
 
220
  # 使用transformers实现
221
  result = diarize_audio(
222
  audio_segment,
223
  provider="pyannote_transformers",
 
224
  )
225
 
226
  # 使用GPU设备
227
  result = diarize_audio(
228
  audio_segment,
229
  provider="pyannote_mlx",
 
230
  device="cuda"
231
  )
232
 
 
234
  result = diarize_audio(
235
  audio_segment,
236
  provider="pyannote_mlx",
 
237
  segmentation_batch_size=64
238
  )
239
  """
src/podcast_transcribe/transcriber.py CHANGED
@@ -31,7 +31,6 @@ class CombinedTranscriber:
31
  diarization_model_name: str,
32
  llm_model_name: Optional[str] = None,
33
  llm_provider: Optional[str] = None,
34
- hf_token: Optional[str] = None,
35
  device: Optional[str] = None,
36
  segmentation_batch_size: int = 64,
37
  parallel: bool = False,
@@ -44,7 +43,6 @@ class CombinedTranscriber:
44
  asr_provider: ASR提供者名称
45
  diarization_provider: 说话人分离提供者名称
46
  diarization_model_name: 说话人分离模型名称
47
- hf_token: Hugging Face令牌
48
  device: 推理设备,'cpu'或'cuda'
49
  segmentation_batch_size: 分割批处理大小,默认为64
50
  parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -75,7 +73,6 @@ class CombinedTranscriber:
75
  self.asr_provider = asr_provider
76
  self.diarization_provider = diarization_provider
77
  self.diarization_model_name = diarization_model_name
78
- self.hf_token = hf_token or os.environ.get("HF_TOKEN")
79
  self.device = device
80
  self.segmentation_batch_size = segmentation_batch_size
81
  self.parallel = parallel
@@ -148,7 +145,6 @@ class CombinedTranscriber:
148
  audio,
149
  provider=self.diarization_provider,
150
  model_name=self.diarization_model_name,
151
- token=self.hf_token,
152
  device=self.device,
153
  segmentation_batch_size=self.segmentation_batch_size
154
  )
@@ -195,7 +191,6 @@ class CombinedTranscriber:
195
  audio,
196
  provider=self.diarization_provider,
197
  model_name=self.diarization_model_name,
198
- token=self.hf_token,
199
  device=self.device,
200
  segmentation_batch_size=self.segmentation_batch_size
201
  )
@@ -491,7 +486,6 @@ def transcribe_audio(
491
  asr_provider: str = "distil_whisper_transformers",
492
  diarization_model_name: str = "pyannote/speaker-diarization-3.1",
493
  diarization_provider: str = "pyannote_transformers",
494
- hf_token: Optional[str] = None,
495
  device: Optional[str] = None,
496
  segmentation_batch_size: int = 64,
497
  parallel: bool = False,
@@ -505,7 +499,6 @@ def transcribe_audio(
505
  asr_provider: ASR提供者名称
506
  diarization_model_name: 说话人分离模型名称
507
  diarization_provider: 说话人分离提供者名称
508
- hf_token: Hugging Face令牌
509
  device: 推理设备,'cpu'或'cuda'
510
  segmentation_batch_size: 分割批处理大小,默认为64
511
  parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -520,7 +513,6 @@ def transcribe_audio(
520
  asr_provider=asr_provider,
521
  diarization_model_name=diarization_model_name,
522
  diarization_provider=diarization_provider,
523
- hf_token=hf_token,
524
  device=device,
525
  segmentation_batch_size=segmentation_batch_size,
526
  parallel=parallel
@@ -539,7 +531,6 @@ def transcribe_podcast_audio(
539
  diarization_provider: str = "pyannote_transformers",
540
  llm_model_name: Optional[str] = None,
541
  llm_provider: Optional[str] = None,
542
- hf_token: Optional[str] = None,
543
  device: Optional[str] = None,
544
  segmentation_batch_size: int = 64,
545
  parallel: bool = False,
@@ -557,7 +548,6 @@ def transcribe_podcast_audio(
557
  diarization_model_name: 说话人分离模型名称
558
  llm_model_name: LLM模型名称,如果为None则无法识别说话人名称
559
  llm_provider: LLM提供者名称,如果为None则无法识别说话人名称
560
- hf_token: Hugging Face令牌
561
  device: 推理设备,'cpu'或'cuda'
562
  segmentation_batch_size: 分割批处理大小,默认为64
563
  parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -574,7 +564,6 @@ def transcribe_podcast_audio(
574
  diarization_model_name=diarization_model_name,
575
  llm_model_name=llm_model_name,
576
  llm_provider=llm_provider,
577
- hf_token=hf_token,
578
  device=device,
579
  segmentation_batch_size=segmentation_batch_size,
580
  parallel=parallel
 
31
  diarization_model_name: str,
32
  llm_model_name: Optional[str] = None,
33
  llm_provider: Optional[str] = None,
 
34
  device: Optional[str] = None,
35
  segmentation_batch_size: int = 64,
36
  parallel: bool = False,
 
43
  asr_provider: ASR提供者名称
44
  diarization_provider: 说话人分离提供者名称
45
  diarization_model_name: 说话人分离模型名称
 
46
  device: 推理设备,'cpu'或'cuda'
47
  segmentation_batch_size: 分割批处理大小,默认为64
48
  parallel: 是否并行执行ASR和说话人分离,默认为False
 
73
  self.asr_provider = asr_provider
74
  self.diarization_provider = diarization_provider
75
  self.diarization_model_name = diarization_model_name
 
76
  self.device = device
77
  self.segmentation_batch_size = segmentation_batch_size
78
  self.parallel = parallel
 
145
  audio,
146
  provider=self.diarization_provider,
147
  model_name=self.diarization_model_name,
 
148
  device=self.device,
149
  segmentation_batch_size=self.segmentation_batch_size
150
  )
 
191
  audio,
192
  provider=self.diarization_provider,
193
  model_name=self.diarization_model_name,
 
194
  device=self.device,
195
  segmentation_batch_size=self.segmentation_batch_size
196
  )
 
486
  asr_provider: str = "distil_whisper_transformers",
487
  diarization_model_name: str = "pyannote/speaker-diarization-3.1",
488
  diarization_provider: str = "pyannote_transformers",
 
489
  device: Optional[str] = None,
490
  segmentation_batch_size: int = 64,
491
  parallel: bool = False,
 
499
  asr_provider: ASR提供者名称
500
  diarization_model_name: 说话人分离模型名称
501
  diarization_provider: 说话人分离提供者名称
 
502
  device: 推理设备,'cpu'或'cuda'
503
  segmentation_batch_size: 分割批处理大小,默认为64
504
  parallel: 是否并行执行ASR和说话人分离,默认为False
 
513
  asr_provider=asr_provider,
514
  diarization_model_name=diarization_model_name,
515
  diarization_provider=diarization_provider,
 
516
  device=device,
517
  segmentation_batch_size=segmentation_batch_size,
518
  parallel=parallel
 
531
  diarization_provider: str = "pyannote_transformers",
532
  llm_model_name: Optional[str] = None,
533
  llm_provider: Optional[str] = None,
 
534
  device: Optional[str] = None,
535
  segmentation_batch_size: int = 64,
536
  parallel: bool = False,
 
548
  diarization_model_name: 说话人分离模型名称
549
  llm_model_name: LLM模型名称,如果为None则无法识别说话人名称
550
  llm_provider: LLM提供者名称,如果为None则无法识别说话人名称
 
551
  device: 推理设备,'cpu'或'cuda'
552
  segmentation_batch_size: 分割批处理大小,默认为64
553
  parallel: 是否并行执行ASR和说话人分离,默认为False
 
564
  diarization_model_name=diarization_model_name,
565
  llm_model_name=llm_model_name,
566
  llm_provider=llm_provider,
 
567
  device=device,
568
  segmentation_batch_size=segmentation_batch_size,
569
  parallel=parallel