konieshadow committed on
Commit 8289369 · 1 Parent(s): e0e9f98
Files changed (36)
  1. .DS_Store +0 -0
  2. .gitignore +131 -0
  3. .vscode/launch.json +19 -0
  4. app.py +44 -5
  5. examples/.DS_Store +0 -0
  6. examples/combined_podcast_transcription.py +102 -0
  7. examples/combined_transcription.py +104 -0
  8. examples/simple_asr.py +80 -0
  9. examples/simple_diarization.py +79 -0
  10. examples/simple_llm.py +74 -0
  11. examples/simple_rss_parser.py +40 -0
  12. examples/simple_speaker_identify.py +68 -0
  13. requirements.txt +21 -0
  14. src/.DS_Store +0 -0
  15. src/podcast_transcribe/.DS_Store +0 -0
  16. src/podcast_transcribe/__init__.py +8 -0
  17. src/podcast_transcribe/asr/asr_base.py +277 -0
  18. src/podcast_transcribe/asr/asr_distil_whisper_mlx.py +112 -0
  19. src/podcast_transcribe/asr/asr_distil_whisper_transformers.py +133 -0
  20. src/podcast_transcribe/asr/asr_parakeet_mlx.py +126 -0
  21. src/podcast_transcribe/asr/asr_router.py +273 -0
  22. src/podcast_transcribe/audio.py +62 -0
  23. src/podcast_transcribe/diarization/diarization_pyannote_mlx.py +154 -0
  24. src/podcast_transcribe/diarization/diarization_pyannote_transformers.py +170 -0
  25. src/podcast_transcribe/diarization/diarizer_base.py +118 -0
  26. src/podcast_transcribe/diarization/diarizer_router.py +276 -0
  27. src/podcast_transcribe/llm/llm_base.py +391 -0
  28. src/podcast_transcribe/llm/llm_gemma_mlx.py +62 -0
  29. src/podcast_transcribe/llm/llm_gemma_transfomers.py +61 -0
  30. src/podcast_transcribe/llm/llm_phi4_transfomers.py +369 -0
  31. src/podcast_transcribe/llm/llm_router.py +578 -0
  32. src/podcast_transcribe/rss/podcast_rss_parser.py +162 -0
  33. src/podcast_transcribe/schemas.py +63 -0
  34. src/podcast_transcribe/summary/speaker_identify.py +350 -0
  35. src/podcast_transcribe/transcriber.py +588 -0
  36. src/podcast_transcribe/webui/app.py +585 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitignore ADDED
@@ -0,0 +1,131 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ *.manifest
30
+ *.spec
31
+
32
+ # Installer logs
33
+ pip-log.txt
34
+ pip-delete-this-directory.txt
35
+
36
+ # Unit test / coverage reports
37
+ htmlcov/
38
+ .tox/
39
+ .nox/
40
+ .coverage
41
+ .coverage.*
42
+ .cache
43
+ nosetests.xml
44
+ coverage.xml
45
+ *.cover
46
+ .hypothesis/
47
+ .pytest_cache/
48
+
49
+ # Translations
50
+ *.mo
51
+ *.pot
52
+
53
+ # Django stuff:
54
+ *.log
55
+ local_settings.py
56
+ db.sqlite3
57
+
58
+ # Flask stuff:
59
+ instance/
60
+ .webassets-cache
61
+
62
+ # Scrapy stuff:
63
+ .scrapy
64
+
65
+ # Sphinx documentation
66
+ docs/_build/
67
+
68
+ # PyBuilder
69
+ target/
70
+
71
+ # Jupyter Notebook
72
+ .ipynb_checkpoints
73
+
74
+ # IPython
75
+ profile_default/
76
+ ipython_config.py
77
+
78
+ # pyenv
79
+ .python-version
80
+
81
+ # celery beat schedule file
82
+ celerybeat-schedule
83
+
84
+ # SageMath parsed files
85
+ *.sage.py
86
+
87
+ # Environments
88
+ .env
89
+ .venv
90
+ env/
91
+ venv/
92
+ ENV/
93
+ env.bak/
94
+ venv.bak/
95
+
96
+ # Spyder project settings
97
+ .spyderproject
98
+ .spyproject
99
+
100
+ # Rope project settings
101
+ .ropeproject
102
+
103
+ # mkdocs documentation
104
+ /site
105
+
106
+ # mypy
107
+ .mypy_cache/
108
+ .dmypy.json
109
+ dmypy.json
110
+
111
+ # Pyre type checker
112
+ .pyre/
113
+
114
+ # uv
115
+ .uv/
116
+
117
+ # Output files
118
+ output/
119
+ logs/
120
+
121
+ # Large models
122
+ *.bin
123
+ *.pth
124
+ *.pt
125
+ *.onnx
126
+
127
+ examples/input/
128
+ examples/output/
129
+
130
+ # temp files
131
+ _temp_*
.vscode/launch.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python Debugger: Current File",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "program": "${file}",
12
+ "console": "integratedTerminal",
13
+ "env": {
14
+ // "HTTPS_PROXY": "http://127.0.0.1:12334",
15
+ // "HF_ENDPOINT": "https://hf-mirror.com"
16
+ }
17
+ }
18
+ ]
19
+ }
app.py CHANGED
@@ -1,7 +1,46 @@
1
- import gradio as gr
2
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
1
+ #!/usr/bin/env python3
2
+ """
3
+ 播客转录工具 - 主启动文件
4
+ 这个文件用于启动 Gradio WebUI 应用
5
+ """
6
 
7
+ import sys
8
+ import os
9
 
10
+ # 将 src 目录添加到 Python 路径中
11
+ current_dir = os.path.dirname(os.path.abspath(__file__))
12
+ src_path = os.path.join(current_dir, "src")
13
+ if src_path not in sys.path:
14
+ sys.path.insert(0, src_path)
15
+
16
+ def main():
17
+ """主函数:启动 WebUI 应用"""
18
+ try:
19
+ # 导入并启动 webui 应用
20
+ from podcast_transcribe.webui.app import demo
21
+
22
+ print("🎙️ 启动播客转录工具...")
23
+ print("📍 WebUI 将在浏览器中打开")
24
+ print("🔗 默认地址: http://localhost:7860")
25
+ print("⏹️ 按 Ctrl+C 停止服务")
26
+
27
+ # 启动 Gradio 应用
28
+ demo.launch(
29
+ debug=True,
30
+ server_name="0.0.0.0", # 允许外部访问
31
+ server_port=7860, # 指定端口
32
+ share=False, # 不生成公开链接
33
+ inbrowser=True # 自动在浏览器中打开
34
+ )
35
+
36
+ except ImportError as e:
37
+ print(f"❌ 导入错误: {e}")
38
+ print("请确保已安装所有依赖包:")
39
+ print("pip install -r requirements.txt")
40
+ sys.exit(1)
41
+ except Exception as e:
42
+ print(f"❌ 启动失败: {e}")
43
+ sys.exit(1)
44
+
45
+ if __name__ == "__main__":
46
+ main()
examples/.DS_Store ADDED
Binary file (6.15 kB).
 
examples/combined_podcast_transcription.py ADDED
@@ -0,0 +1,102 @@
1
+ # 添加项目根目录到Python路径
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+ import os
6
+
7
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
8
+
9
+ from src.podcast_transcribe.transcriber import transcribe_podcast_audio
10
+ from src.podcast_transcribe.audio import load_audio
11
+ from src.podcast_transcribe.rss.podcast_rss_parser import parse_rss_xml_content
12
+ from src.podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
13
+ from src.podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult
14
+ from src.podcast_transcribe.summary.speaker_identify import recognize_speaker_names
15
+
16
+ def main():
17
+ """主函数"""
18
+ podcast_rss_xml_file = Path.joinpath(Path(__file__).parent, "input", "lexfridman.com.rss.xml")
19
+ audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
20
+ # audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
21
+
22
+ # 模型配置
23
+ asr_model_name = "mlx-community/parakeet-tdt-0.6b-v2" # ASR模型名称
24
+ diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
25
+ llm_model_path = "mlx-community/gemma-3-12b-it-4bit-DWQ"
26
+ hf_token = "" # Hugging Face API 令牌
27
+ device = "mps" # 设备类型
28
+ segmentation_batch_size = 64
29
+ parallel = True
30
+
31
+ # 检查文件是否存在
32
+ if not os.path.exists(audio_file):
33
+ print(f"错误:文件 '{audio_file}' 不存在")
34
+ return 1
35
+
36
+ # 检查HF令牌
37
+ if not hf_token:
38
+ print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
39
+ print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
40
+ return 1
41
+
42
+ try:
43
+ print(f"正在加载音频文件: {audio_file}")
44
+ # 加载音频文件
45
+ audio, _ = load_audio(audio_file)
46
+
47
+ print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
48
+
49
+
50
+ except Exception as e:
51
+ print(f"错误: {str(e)}")
52
+ import traceback
53
+ traceback.print_exc()
54
+ return 1
55
+
56
+ # Load the podcast RSS XML file
57
+ with open(podcast_rss_xml_file, "r") as f:
58
+ podcast_rss_xml = f.read()
59
+ mock_podcast_info = parse_rss_xml_content(podcast_rss_xml)
60
+
61
+
62
+ # 查找标题以 "#309" 开头的剧集
63
+ mock_episode_info = next((episode for episode in mock_podcast_info.episodes if episode.title.startswith("#309")), None)
64
+ if not mock_episode_info:
65
+ raise ValueError("Could not find episode with title starting with '#309'")
66
+
67
+ result = transcribe_podcast_audio(audio,
68
+ podcast_info=mock_podcast_info,
69
+ episode_info=mock_episode_info,
70
+ hf_token=hf_token,
71
+ asr_model_name=asr_model_name,
72
+ diarization_model_name=diarization_model_name,
74
+ device=device,
75
+ segmentation_batch_size=segmentation_batch_size,
76
+ parallel=parallel,
77
+ llm_model_name=llm_model_path)
78
+
79
+ # 输出结果
80
+ print("\n转录结果:")
81
+ print("-" * 50)
82
+ print(f"检测到的语言: {result.language}")
83
+ print(f"检测到的说话人数量: {result.num_speakers}")
84
+ print(f"总文本长度: {len(result.text)} 字符")
85
+
86
+ # 输出每个说话人的部分
87
+ speakers = set(segment.speaker for segment in result.segments)
88
+ for speaker in sorted(speakers):
89
+ speaker_segments = [seg for seg in result.segments if seg.speaker == speaker]
90
+ total_duration = sum(seg.end - seg.start for seg in speaker_segments)
91
+ print(f"\n说话人 {speaker}: 共 {len(speaker_segments)} 个片段, 总时长 {total_duration:.2f} 秒")
92
+
93
+ # 输出详细分段信息
94
+ print("\n详细分段信息:")
95
+ for i, segment in enumerate(result.segments, 1):
96
+ if i <= 20 or i > len(result.segments) - 20: # 仅显示前20个和后20个分段
97
+ print(f"段落 {i}/{len(result.segments)}: [{segment.start:.2f}s - {segment.end:.2f}s] 说话人: {segment.speaker_name if segment.speaker_name else segment.speaker} 文本: {segment.text}")
98
+ elif i == 21:
99
+ print("... 省略中间部分 ...")
100
+
101
+ if __name__ == '__main__':
102
+ main()
examples/combined_transcription.py ADDED
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 整合ASR和说话人分离的示例程序
5
+ 从本地文件读取音频,同时进行转录和说话人分离
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+ from dataclasses import asdict
13
+
14
+ # 添加项目根目录到Python路径
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+
17
+ # 导入必要的模块,使用正确的导入路径
18
+ from src.podcast_transcribe.audio import load_audio
19
+ from src.podcast_transcribe.transcriber import transcribe_audio
20
+
21
+
22
+ def main():
23
+ """主函数"""
24
+ audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
25
+ # audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
26
+
27
+ # 模型配置
28
+ asr_model_name = "distil-whisper/distil-large-v3.5" # ASR模型名称
29
+ diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
30
+ hf_token = "" # Hugging Face API 令牌
31
+ device = "mps" # 设备类型
32
+ segmentation_batch_size = 64
33
+ parallel = True
34
+
35
+ # 检查文件是否存在
36
+ if not os.path.exists(audio_file):
37
+ print(f"错误:文件 '{audio_file}' 不存在")
38
+ return 1
39
+
40
+ # 检查HF令牌
41
+ if not hf_token:
42
+ print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
43
+ print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
44
+ return 1
45
+
46
+ try:
47
+ print(f"正在加载音频文件: {audio_file}")
48
+ # 加载音频文件
49
+ audio, _ = load_audio(audio_file)
50
+
51
+ print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
52
+
53
+ result = transcribe_audio(
54
+ audio,
55
+ asr_model_name=asr_model_name,
56
+ diarization_model_name=diarization_model_name,
57
+ hf_token=hf_token,
58
+ device=device,
59
+ segmentation_batch_size=segmentation_batch_size,
60
+ parallel=parallel,
61
+ )
62
+
63
+ # 输出结果
64
+ print("\n转录结果:")
65
+ print("-" * 50)
66
+ print(f"检测到的语言: {result.language}")
67
+ print(f"检测到的说话人数量: {result.num_speakers}")
68
+ print(f"总文本长度: {len(result.text)} 字符")
69
+
70
+ # 输出每个说话人的部分
71
+ speakers = set(segment.speaker for segment in result.segments)
72
+ for speaker in sorted(speakers):
73
+ speaker_segments = [seg for seg in result.segments if seg.speaker == speaker]
74
+ total_duration = sum(seg.end - seg.start for seg in speaker_segments)
75
+ print(f"\n说话人 {speaker}: 共 {len(speaker_segments)} 个片段, 总时长 {total_duration:.2f} 秒")
76
+
77
+ # 输出详细分段信息
78
+ print("\n详细分段信息:")
79
+ for i, segment in enumerate(result.segments, 1):
80
+ if i <= 20 or i > len(result.segments) - 20: # 仅显示前20个和后20个分段
81
+ print(f"段落 {i}/{len(result.segments)}: [{segment.start:.2f}s - {segment.end:.2f}s] 说话人: {segment.speaker} 文本: {segment.text}")
82
+ elif i == 21:
83
+ print("... 省略中间部分 ...")
84
+
85
+ # 将转录结果保存为json文件,文件名取自音频文件名
86
+ output_file = Path.joinpath(Path(__file__).parent, "output", f"{audio_file.stem}.transcription.json")
87
+ # 创建上层文件夹
88
+ output_dir = Path.joinpath(Path(__file__).parent, "output")
89
+ output_dir.mkdir(parents=True, exist_ok=True)
90
+ with open(output_file, "w") as f:
91
+ json.dump(asdict(result), f)
92
+ print(f"转录结果已保存到 {output_file}")
93
+
94
+ return 0
95
+
96
+ except Exception as e:
97
+ print(f"错误: {str(e)}")
98
+ import traceback
99
+ traceback.print_exc()
100
+ return 1
101
+
102
+
103
+ if __name__ == "__main__":
104
+ sys.exit(main())
examples/simple_asr.py ADDED
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 简单的语音识别示例程序
5
+ 从本地文件读取音频并进行转录
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ # 添加项目根目录到Python路径
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+ from src.podcast_transcribe.audio import load_audio
17
+
18
+ logger = logging.getLogger("asr_example")
19
+
20
+
21
+ def main():
22
+ """主函数"""
23
+ # audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
24
+ # audio_file = "/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav" # 播客音频文件路径
25
+ audio_file = "/Users/konie/Desktop/voices/podcast1_1.wav"
26
+ # model = "distil-whisper"
27
+ model = "distil-whisper-transformers"
28
+
29
+ device = "mlx"
30
+
31
+ # 检查文件是否存在
32
+ if not os.path.exists(audio_file):
33
+ print(f"错误:文件 '{audio_file}' 不存在")
34
+ return 1
35
+
36
+ if model == "parakeet":
37
+ from src.podcast_transcribe.asr.asr_parakeet_mlx import transcribe_audio
38
+ model_name = "mlx-community/parakeet-tdt-0.6b-v2"
39
+ logger.info(f"使用Parakeet模型: {model_name}")
40
+ elif model == "distil-whisper": # distil-whisper
41
+ from src.podcast_transcribe.asr.asr_distil_whisper_mlx import transcribe_audio
42
+ model_name = "mlx-community/distil-whisper-large-v3"
43
+ logger.info(f"使用Distil Whisper模型: {model_name}")
44
+ elif model == "distil-whisper-transformers": # distil-whisper
45
+ from src.podcast_transcribe.asr.asr_distil_whisper_transformers import transcribe_audio
46
+ model_name = "distil-whisper/distil-large-v3.5"
47
+ logger.info(f"使用Distil Whisper模型: {model_name}")
48
+ else:
49
+ logger.error(f"错误:未指定模型类型")
50
+ return 1
51
+
52
+ try:
53
+ print(f"正在加载音频文件: {audio_file}")
54
+ # 加载音频文件
55
+ audio, _ = load_audio(audio_file)
56
+
57
+ print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
58
+
59
+ # 进行转录
60
+ print("开始转录...")
61
+ result = transcribe_audio(audio, model_name=model_name, device=device)
62
+
63
+ # 输出结果
64
+ print("\n转录结果:")
65
+ print("-" * 50)
66
+ print(f"检测到的语言: {result.language}")
67
+ print(f"完整文本: {result.text}")
68
+ print("\n分段信息:")
69
+ for i, segment in enumerate(result.segments, 1):
70
+ print(f"分段 {i}: [{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")
71
+
72
+ return 0
73
+
74
+ except Exception as e:
75
+ print(f"错误: {str(e)}")
76
+ return 1
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
examples/simple_diarization.py ADDED
@@ -0,0 +1,79 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 简单的说话人标注示例程序
5
+ 从本地文件读取音频并进行说话人分离
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ # 添加项目根目录到Python路径
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ from src.podcast_transcribe.audio import load_audio
16
+ from src.podcast_transcribe.diarization.diarization_pyannote_mlx import diarize_audio as diarize_audio_mlx
17
+ from src.podcast_transcribe.diarization.diarization_pyannote_transformers import diarize_audio as diarize_audio_transformers
18
+
19
+
20
+ def main():
21
+ """主函数"""
22
+ audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
23
+ # audio_file = "/Users/konie/Desktop/voices/history_in_the_baking.mp3" # 播客音频文件路径
24
+ model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
25
+ hf_token = "" # Hugging Face API 令牌
26
+ device = "mps" # 设备类型
27
+
28
+ # 检查文件是否存在
29
+ if not os.path.exists(audio_file):
30
+ print(f"错误:文件 '{audio_file}' 不存在")
31
+ return 1
32
+
33
+ # 检查令牌是否设置
34
+ if not hf_token:
35
+ print("错误:未设置HF_TOKEN环境变量,请设置后再运行")
36
+ return 1
37
+
38
+ try:
39
+ print(f"正在加载音频文件: {audio_file}")
40
+ # 加载音频文件
41
+ audio, _ = load_audio(audio_file)
42
+
43
+ print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
44
+
45
+ # 根据model_name选择合适的实现
46
+ if "pyannote/speaker-diarization" in model_name:
47
+ # 使用transformers版本进行说话人分离
48
+ print(f"使用transformers版本处理模型: {model_name}")
49
+ result = diarize_audio_transformers(audio, model_name=model_name, token=hf_token, device=device, segmentation_batch_size=128)
50
+ version_name = "Transformers"
51
+ else:
52
+ # 使用MLX版本进行说话人分离
53
+ print(f"使用MLX版本处理模型: {model_name}")
54
+ result = diarize_audio_mlx(audio, model_name=model_name, token=hf_token, device=device, segmentation_batch_size=128)
55
+ version_name = "MLX"
56
+
57
+ # 输出结果
58
+ print(f"\n{version_name}版本说话人分离结果:")
59
+ print("-" * 50)
60
+ print(f"检测到的说话人数量: {result.num_speakers}")
61
+ print(f"分段总数: {len(result.segments)}")
62
+
63
+ print("\n分段详情:")
64
+ for i, segment in enumerate(result.segments, 1):
65
+ start = segment["start"]
66
+ end = segment["end"]
67
+ speaker = segment["speaker"]
68
+ duration = end - start
69
+ print(f"分段 {i}: [{start:.2f}s - {end:.2f}s] (时长: {duration:.2f}s) 说话人: {speaker}")
70
+
71
+ return 0
72
+
73
+ except Exception as e:
74
+ print(f"错误: {str(e)}")
75
+ return 1
76
+
77
+
78
+ if __name__ == "__main__":
79
+ sys.exit(main())
examples/simple_llm.py ADDED
@@ -0,0 +1,74 @@
1
+
2
+ # 添加项目根目录到Python路径
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ sys.path.insert(0, str(Path(__file__).parent.parent))
7
+
8
+ from src.podcast_transcribe.llm.llm_phi4_transfomers import Phi4TransformersChatCompletion
9
+ from src.podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
10
+ from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion
11
+
12
+
13
+ if __name__ == "__main__":
14
+ # 示例用法:
15
+ print("正在初始化 LLM 聊天补全...")
16
+ try:
17
+ model_name = "google/gemma-3-4b-it"
18
+ use_4bit_quantization = False
19
+
20
+ # gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
21
+ # 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
22
+ if model_name.startswith("mlx-community"):
23
+ gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
24
+ elif model_name.startswith("microsoft"):
25
+ gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
26
+ else:
27
+ gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
28
+
29
+ print("\n--- 示例 1: 简单用户查询 ---")
30
+ messages_example1 = [
31
+ {"role": "user", "content": "你好,你是谁?"}
32
+ ]
33
+ response1 = gemma_chat.create(messages=messages_example1, max_tokens=50)
34
+ print("响应 1:")
35
+ print(f" 助手: {response1['choices'][0]['message']['content']}")
36
+ print(f" 用量: {response1['usage']}")
37
+
38
+ print("\n--- 示例 2: 带历史记录的对话 ---")
39
+ messages_example2 = [
40
+ {"role": "user", "content": "法国的首都是哪里?"},
41
+ {"role": "assistant", "content": "法国的首都是巴黎。"},
42
+ {"role": "user", "content": "你能告诉我一个关于它的有趣事实吗?"}
43
+ ]
44
+ response2 = gemma_chat.create(messages=messages_example2, max_tokens=100, temperature=0.8)
45
+ print("响应 2:")
46
+ print(f" 助手: {response2['choices'][0]['message']['content']}")
47
+ print(f" 用量: {response2['usage']}")
48
+
49
+ print("\n--- 示例 3: 系统提示 (实验性,效果取决于模型微调) ---")
50
+ messages_example3 = [
51
+ {"role": "system", "content": "你是一位富有诗意的助手,擅长用富有创意的方式解释复杂的编程概念。"},
52
+ {"role": "user", "content": "解释一下编程中递归的概念。"}
53
+ ]
54
+ response3 = gemma_chat.create(messages=messages_example3, max_tokens=150)
55
+ print("响应 3:")
56
+ print(f" 助手: {response3['choices'][0]['message']['content']}")
57
+ print(f" 用量: {response3['usage']}")
58
+
59
+ print("\n--- 示例 4: 使用 max_tokens 强制缩短响应 ---")
60
+ messages_example4 = [
61
+ {"role": "user", "content": "给我讲一个关于勇敢骑士的很长的故事。"}
62
+ ]
63
+ response4 = gemma_chat.create(messages=messages_example4, max_tokens=20) # 非常短
64
+ print("响应 4:")
65
+ print(f" 助手: {response4['choices'][0]['message']['content']}")
66
+ print(f" 用量: {response4['usage']}")
67
+ if response4['usage']['completion_tokens'] >= 20:
68
+ print(" 注意:由于 max_tokens,补全可能已被截断。")
69
+
70
+
71
+ except Exception as e:
72
+ print(f"示例用法期间发生错误: {e}")
73
+ import traceback
74
+ traceback.print_exc()
examples/simple_rss_parser.py ADDED
@@ -0,0 +1,40 @@
1
+
2
+ # 添加项目根目录到Python路径
3
+ import sys
4
+ from pathlib import Path
5
+
6
+
7
+ sys.path.insert(0, str(Path(__file__).parent.parent))
8
+
9
+ from src.podcast_transcribe.rss.podcast_rss_parser import parse_podcast_rss
10
+
11
+ if __name__ == '__main__':
12
+ # 使用示例:
13
+ lex_fridman_rss = "https://feeds.buzzsprout.com/2460059.rss"
14
+ print(f"正在解析 Lex Fridman Podcast RSS: {lex_fridman_rss}")
15
+ podcast_data = parse_podcast_rss(lex_fridman_rss)
16
+
17
+ if podcast_data:
18
+ print(f"Podcast Title: {podcast_data.title}")
19
+ print(f"Podcast Link: {podcast_data.link}")
20
+ print(f"Podcast Description: {podcast_data.description[:200] if podcast_data.description else 'N/A'}...")
21
+ print(f"Podcast Author: {podcast_data.author}")
22
+ print(f"Podcast Image URL: {podcast_data.image_url}")
23
+ print(f"Total episodes found: {len(podcast_data.episodes)}")
24
+
25
+ if podcast_data.episodes:
26
+ print("\n--- Sample Episode ---")
27
+ sample_episode = podcast_data.episodes[0]
28
+ print(f" 标题: {sample_episode.title}")
29
+ print(f" 发布日期: {sample_episode.published_date}")
30
+ print(f" 链接: {sample_episode.link}")
31
+ print(f" 音频 URL: {sample_episode.audio_url}")
32
+ print(f" GUID: {sample_episode.guid}")
33
+ print(f" 时长: {sample_episode.duration}")
34
+ print(f" 季: {sample_episode.season}")
35
+ print(f" 集数: {sample_episode.episode_number}")
36
+ print(f" 剧集类型: {sample_episode.episode_type}")
37
+ print(f" 摘要: {sample_episode.summary[:200] if sample_episode.summary else 'N/A'}...")
38
+ print(f" Shownotes (前 300 字符): {sample_episode.shownotes[:300] if sample_episode.shownotes else 'N/A'}...")
39
+ else:
40
+ print("解析播客 RSS feed 失败。")
examples/simple_speaker_identify.py ADDED
@@ -0,0 +1,68 @@
1
+ # 添加项目根目录到Python路径
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+ import os
6
+
7
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
8
+
9
+ from src.podcast_transcribe.rss.podcast_rss_parser import parse_rss_xml_content
10
+ from src.podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult
11
+ from src.podcast_transcribe.summary.speaker_identify import SpeakerIdentifier
12
+
13
+ if __name__ == '__main__':
14
+ transcribe_result_dump_file = Path.joinpath(Path(__file__).parent, "output", "lex_ai_john_carmack_1.transcription.json")
15
+ podcast_rss_xml_file = Path.joinpath(Path(__file__).parent, "input", "lexfridman.com.rss.xml")
16
+
17
+ # Load the transcription result
18
+ if not os.path.exists(transcribe_result_dump_file):
19
+ print(f"错误:转录结果文件 '{transcribe_result_dump_file}' 不存在。请先运行 combined_transcription.py 生成结果。")
20
+ sys.exit(1)
21
+
22
+ with open(transcribe_result_dump_file, "r", encoding="utf-8") as f:
23
+ # transcription_result = json.load(f) # 旧代码
24
+ data = json.load(f)
25
+ segments_data = data.get("segments", [])
26
+ # 确保 segments_data 中的每个元素都是字典,以避免在 EnhancedSegment(**seg) 时出错
27
+ # 假设 EnhancedSegment 的字段与 JSON 中 segment 字典的键完全对应
28
+ enhanced_segments = []
29
+ for seg_dict in segments_data:
30
+ if isinstance(seg_dict, dict):
31
+ enhanced_segments.append(EnhancedSegment(**seg_dict))
32
+ else:
33
+ # 处理非字典类型 segment 的情况,例如记录日志或抛出错误
34
+ print(f"警告: 在JSON中发现非字典类型的segment: {seg_dict}")
35
+
36
+ transcription_result = CombinedTranscriptionResult(
37
+ segments=enhanced_segments,
38
+ text=data.get("text", ""),
39
+ language=data.get("language", ""),
40
+ num_speakers=data.get("num_speakers", 0)
41
+ )
42
+
43
+ # 打印加载的 CombinedTranscriptionResult 对象的一些信息以供验证
44
+ print(f"\\n成功从JSON加载 CombinedTranscriptionResult 对象:")
45
+ print(f"类型: {type(transcription_result)}")
46
+
47
+ # Load the podcast RSS XML file
48
+ with open(podcast_rss_xml_file, "r") as f:
49
+ podcast_rss_xml = f.read()
50
+ mock_podcast_info = parse_rss_xml_content(podcast_rss_xml)
51
+
52
+
53
+ # 查找标题已 "#309" 开头的剧集
54
+ mock_episode_info = next((episode for episode in mock_podcast_info.episodes if episode.title.startswith("#309")), None)
55
+ if not mock_episode_info:
56
+ raise ValueError("Could not find episode with title starting with '#309'")
57
+
58
+
59
+ speaker_identifier = SpeakerIdentifier(
60
+ llm_model_name="mlx-community/gemma-3-12b-it-4bit-DWQ",
61
+ llm_provider="gemma-mlx"
62
+ )
63
+
64
+ # 3. Call the function
65
+ print("\\n--- Test Case 1: Normal execution ---")
66
+ speaker_names = speaker_identifier.recognize_speaker_names(transcription_result.segments, mock_podcast_info, mock_episode_info)
67
+ print("\\nRecognized Speaker Names (Test Case 1):")
68
+ print(json.dumps(speaker_names, ensure_ascii=False, indent=2))
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ pydub>=0.25.1
2
+ numpy>=2.2.5
3
+ pyannote.audio>=3.3.2
4
+ transformers>=4.51.3
5
+ torch>=2.7.0
6
+ torchaudio>=2.7.0
7
+ soundfile>=0.13.1
8
+ feedparser>=6.0.11
9
+ requests>=2.32.3
10
+ gradio>=5.30.0
11
+
12
+ # 可选依赖 - whisper.cpp 绑定
13
+ pywhispercpp>=1.3.0
14
+ bitsandbytes>=0.42.0
15
+ accelerate>=1.6.0
16
+
17
+ # MLX特定依赖 - 仅适用于Apple Silicon Mac
18
+ # mlx>=0.25.2
19
+ # mlx-lm>=0.24.0
20
+ # parakeet-mlx==0.2.6
21
+ # mlx-whisper==0.4.2
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/podcast_transcribe/.DS_Store ADDED
Binary file (6.15 kB).
 
src/podcast_transcribe/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import logging
2
+
3
+ # 设置根日志级别为INFO,这样第三方包默认使用INFO级别
4
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
5
+
6
+ # 单独设置podcast_transcribe包的日志级别
7
+ package_logger = logging.getLogger("podcast_transcribe")
8
+ package_logger.setLevel(logging.INFO)
src/podcast_transcribe/asr/asr_base.py ADDED
@@ -0,0 +1,277 @@
1
+ """
2
+ 语音识别模块基类
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ from pydub import AudioSegment
8
+ from typing import Dict, List, Union, Optional, Tuple
9
+ # from dataclasses import dataclass # dataclass is now imported from schemas if needed or already there
10
+ import logging
11
+
12
+ from ..schemas import TranscriptionResult # Added import
13
+
14
+ # 配置日志
15
+ logger = logging.getLogger("asr")
16
+
17
+
18
+ class BaseTranscriber:
19
+ """统一的语音识别基类,支持MLX和Transformers等多种框架"""
20
+
21
+ def __init__(
22
+ self,
23
+ model_name: str,
24
+ device: str = None,
25
+ ):
26
+ """
27
+ 初始化转录器
28
+
29
+ 参数:
30
+ model_name: 模型名称
31
+ device: 推理设备,'cpu'或'cuda',对于MLX框架此参数可忽略
32
+ """
33
+ self.model_name = model_name
34
+ self.device = device
35
+ self.pipeline = None # 用于Transformers
36
+ self.model = None # 用于MLX等其他框架
37
+
38
+ logger.info(f"初始化转录器,模型: {model_name}" + (f",设备: {device}" if device else ""))
39
+
40
+ # 子类需要实现_load_model方法
41
+ self._load_model()
42
+
43
+ def _load_model(self):
44
+ """
45
+ 加载模型(需要在子类中实现)
46
+ """
47
+ raise NotImplementedError("子类必须实现_load_model方法")
48
+
49
+ def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
50
+ """
51
+ 准备音频数据
52
+
53
+ 参数:
54
+ audio: 输入的AudioSegment对象
55
+
56
+ 返回:
57
+ 处理后的AudioSegment对象
58
+ """
59
+ logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
60
+
61
+ # 确保采样率为16kHz
62
+ if audio.frame_rate != 16000:
63
+ logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
64
+ audio = audio.set_frame_rate(16000)
65
+
66
+ # 确保是单声道
67
+ if audio.channels > 1:
68
+ logger.debug(f"将{audio.channels}声道音频转换为单声道")
69
+ audio = audio.set_channels(1)
70
+
71
+ logger.debug(f"音频处理完成")
72
+
73
+ return audio
74
+
75
+ def _detect_language(self, text: str) -> str:
76
+ """
77
+ 简单的语言检测(基于经验规则)
78
+
79
+ 参数:
80
+ text: 识别出的文本
81
+
82
+ 返回:
83
+ 检测到的语言代码
84
+ """
85
+ # 简单的规则检测,实际应用中应使用更准确的语言检测
86
+ chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
87
+ chinese_ratio = chinese_chars / len(text) if text else 0
88
+ logger.debug(f"语言检测: 中文字符比例 = {chinese_ratio:.2f}")
89
+
90
+ if chinese_chars > len(text) * 0.3:
91
+ return "zh"
92
+ return "en"
93
+
94
+ def _convert_segments(self, model_result) -> List[Dict[str, Union[float, str]]]:
95
+ """
96
+ 将模型的分段结果转换为所需格式(需要在子类中实现)
97
+
98
+ 参数:
99
+ model_result: 模型返回的结果
100
+
101
+ 返回:
102
+ 转换后的分段列表
103
+ """
104
+ raise NotImplementedError("子类必须实现_convert_segments方法")
105
+
106
+ def transcribe(self, audio: AudioSegment, chunk_duration_s: int = 30, overlap_s: int = 5) -> TranscriptionResult:
107
+ """
108
+ 转录音频,支持长音频分块处理。
109
+
110
+ 参数:
111
+ audio: 要转录的AudioSegment对象
112
+ chunk_duration_s: 分块处理的块时长(秒)。如果音频短于此,则不分块。
113
+ overlap_s: 分块间的重叠时长(秒)。
114
+
115
+ 返回:
116
+ TranscriptionResult对象,包含转录结果
117
+ """
118
+ logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频。分块设置: 块时长={chunk_duration_s}s, 重叠={overlap_s}s")
119
+
120
+ if overlap_s >= chunk_duration_s and len(audio)/1000.0 > chunk_duration_s :
121
+ logger.error("重叠时长必须小于块时长。")
122
+ raise ValueError("overlap_s 必须小于 chunk_duration_s。")
123
+
124
+ total_duration_ms = len(audio)
125
+ chunk_duration_ms = chunk_duration_s * 1000
126
+ overlap_ms = overlap_s * 1000
127
+
128
+ if total_duration_ms <= chunk_duration_ms:
129
+ logger.debug("音频时长不大于设定块时长,直接进行完整转录。")
130
+ processed_audio = self._prepare_audio(audio)
131
+ samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
132
+
133
+ try:
134
+ model_result = self._perform_transcription(samples)
135
+ text = self._get_text_from_result(model_result)
136
+ segments = self._convert_segments(model_result)
137
+ language = self._detect_language(text)
138
+
139
+ logger.info(f"单块转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
140
+ return TranscriptionResult(text=text, segments=segments, language=language)
141
+ except Exception as e:
142
+ logger.error(f"单块转录失败: {str(e)}", exc_info=True)
143
+ raise RuntimeError(f"单块转录失败: {str(e)}")
144
+
145
+ # 长音频分块处理
146
+ final_segments = []
147
+ # current_pos_ms 指的是当前块要处理的"新内容"的起始点在原始音频中的位置
148
+ current_pos_ms = 0
149
+
150
+ while current_pos_ms < total_duration_ms:
151
+ # 计算当前块实际送入模型处理的音频的起始和结束时间点
152
+ # 对于第一个块,start_process_ms = 0
153
+ # 对于后续块,start_process_ms 会向左回退 overlap_ms 以包含重叠区域
154
+ start_process_ms = max(0, current_pos_ms - overlap_ms)
155
+ end_process_ms = min(start_process_ms + chunk_duration_ms, total_duration_ms)
156
+
157
+ # 如果计算出的块起始点已经等于或超过总时长,说明处理完毕
158
+ if start_process_ms >= total_duration_ms:
159
+ break
160
+
161
+ chunk_audio = audio[start_process_ms:end_process_ms]
162
+
163
+ logger.info(f"处理音频块: {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s (新内容起始于: {current_pos_ms/1000.0:.2f}s)")
164
+
165
+ if len(chunk_audio) == 0:
166
+ logger.warning(f"生成了一个空的音频块,跳过。起始: {start_process_ms/1000.0:.2f}s, 结束: {end_process_ms/1000.0:.2f}s")
167
+ # 必须推进 current_pos_ms 以避免死循环
168
+ advance_ms = chunk_duration_ms - overlap_ms
169
+ if advance_ms <= 0: # 应该在函数开始时已检查 overlap_s < chunk_duration_s
170
+ raise RuntimeError("块推进时长配置错误,可能导致死循环。")
171
+ current_pos_ms += advance_ms
172
+ continue
173
+
174
+ processed_chunk_audio = self._prepare_audio(chunk_audio)
175
+ samples = np.array(processed_chunk_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
176
+
177
+ try:
178
+ model_result = self._perform_transcription(samples)
179
+ segments_chunk = self._convert_segments(model_result)
180
+
181
+ for seg in segments_chunk:
182
+ # seg["start"] 和 seg["end"] 是相对于当前块 (chunk_audio) 的起始点(即0)
183
+ # 计算 segment 在原始完整音频中的绝对起止时间
184
+ global_seg_start_s = start_process_ms / 1000.0 + seg["start"]
185
+ global_seg_end_s = start_process_ms / 1000.0 + seg["end"]
186
+
187
+ # 核心去重逻辑:
188
+ # 我们只接受那些真实开始于 current_pos_ms / 1000.0 之后的 segment。
189
+ # current_pos_ms 是当前块应该贡献的"新"内容的开始时间。
190
+ # 对于第一个块 (current_pos_ms == 0),所有 segment 都被接受(只要它们的 start >= 0)。
191
+ # 对于后续块,只有当 segment 的全局开始时间 >= 当前块新内容的开始时间时,才添加。
192
+ if global_seg_start_s >= current_pos_ms / 1000.0:
193
+ final_segments.append({
194
+ "start": global_seg_start_s,
195
+ "end": global_seg_end_s,
196
+ "text": seg["text"]
197
+ })
198
+ # 特殊处理第一个块,因为 current_pos_ms 为 0,上面的条件 global_seg_start_s >= 0 总是满足。
199
+ # 但为了更清晰,如果不是第一个块,但 segment 跨越了 current_pos_ms,
200
+ # 它的起始部分在重叠区,结束部分在非重叠区。
201
+ # 当前逻辑是,如果它的 global_seg_start_s < current_pos_ms / 1000.0,它就被丢弃。
202
+ # 这是为了确保不重复记录重叠区域的开头部分。
203
+ # 如果一个 segment 完全在重叠区内且在前一个块已被记录,此逻辑可避免重复。
204
+
205
+ except Exception as e:
206
+ logger.error(f"处理音频块 {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s 失败: {str(e)}", exc_info=True)
207
+
208
+ # 更新下一个"新内容"块的起始位置
209
+ advance_ms = chunk_duration_ms - overlap_ms
210
+ current_pos_ms += advance_ms
211
+
212
+ # 对收集到的所有 segments 按开始时间排序
213
+ final_segments.sort(key=lambda s: s["start"])
214
+
215
+ # 可选:进一步清理 segments,例如合并非常接近且文本连续的,或移除完全重复的
216
+ cleaned_segments = []
217
+ if final_segments:
218
+ cleaned_segments.append(final_segments[0])
219
+ for i in range(1, len(final_segments)):
220
+ prev_s = cleaned_segments[-1]
221
+ curr_s = final_segments[i]
222
+ # 简单的去重:如果时间戳和文本都几乎一样,则认为是重复
223
+ if abs(curr_s["start"] - prev_s["start"]) < 0.01 and \
224
+ abs(curr_s["end"] - prev_s["end"]) < 0.01 and \
225
+ curr_s["text"] == prev_s["text"]:
226
+ continue
227
+
228
+ # 如果当前 segment 的开始时间在前一个 segment 的结束时间之前,
229
+ # 并且文本有明显重叠,可能需要更智能的合并。
230
+ # 目前的逻辑通过 global_seg_start_s >= current_pos_ms / 1000.0 过滤,
231
+ # 已经大大减少了直接的 segment 重复。
232
+ # 此处的清理更多是处理模型在边界可能产生的一些微小偏差。
233
+ # 如果上一个segment的结束时间比当前segment的开始时间还要晚,说明有重叠,
234
+ # 且上一个segment包含了当前segment的开始部分。
235
+ # 这种情况下,可以考虑调整上一个的结束,或当前segment的开始和文本。
236
+ # 为简单起见,暂时直接添加,相信之前的过滤已处理主要重叠。
237
+ if curr_s["start"] < prev_s["end"] and prev_s["text"].endswith(curr_s["text"][:len(prev_s["text"]) - int((prev_s["end"] - curr_s["start"])*10) ]): # 粗略检查
238
+ # 如果curr_s的开始部分被prev_s覆盖,并且文本也对应,则调整curr_s
239
+ # pass # 暂时不处理这种细微重叠,依赖模型切分
240
+ cleaned_segments.append(curr_s) # 仍添加,依赖后续文本拼接
241
+ else:
242
+ cleaned_segments.append(curr_s)
243
+
244
+ final_text = " ".join([s["text"] for s in cleaned_segments]).strip()
245
+ language = self._detect_language(final_text)
246
+
247
+ logger.info(f"分块转录完成。最终文本长度: {len(final_text)}, 分段数: {len(cleaned_segments)}")
248
+
249
+ return TranscriptionResult(
250
+ text=final_text,
251
+ segments=cleaned_segments,
252
+ language=language
253
+ )
254
+
255
+ def _perform_transcription(self, audio_data):
256
+ """
257
+ 执行转录(需要在子类中实现)
258
+
259
+ 参数:
260
+ audio_data: 音频数据(numpy数组)
261
+
262
+ 返回:
263
+ 模型的转录结果
264
+ """
265
+ raise NotImplementedError("子类必须实现_perform_transcription方法")
266
+
267
+ def _get_text_from_result(self, result):
268
+ """
269
+ 从结果中获取文本(需要在子类中实现)
270
+
271
+ 参数:
272
+ result: 模型的转录结果
273
+
274
+ 返回:
275
+ 转录的文本
276
+ """
277
+ raise NotImplementedError("子类必须实现_get_text_from_result方法")
src/podcast_transcribe/asr/asr_distil_whisper_mlx.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ 基于MLX实现的语音识别模块,使用distil-whisper-large-v3模型
3
+ """
4
+
5
+ import os
6
+ from pydub import AudioSegment
7
+ from typing import Dict, List, Union
8
+ import logging
9
+ import numpy as np
10
+ import mlx_whisper
11
+
12
+ # 导入基类
13
+ from .asr_base import BaseTranscriber, TranscriptionResult
14
+
15
+ # 配置日志
16
+ logger = logging.getLogger("asr")
17
+
18
+
19
+ class MLXDistilWhisperTranscriber(BaseTranscriber):
20
+ """使用MLX加载和运行distil-whisper-large-v3模型的转录器"""
21
+
22
+ def __init__(
23
+ self,
24
+ model_name: str = "mlx-community/distil-whisper-large-v3",
25
+ ):
26
+ """
27
+ 初始化转录器
28
+
29
+ 参数:
30
+ model_name: 模型名称
31
+ """
32
+ super().__init__(model_name=model_name)
33
+
34
+ def _load_model(self):
35
+ """加载Distil Whisper模型"""
36
+ try:
37
+ # 懒加载mlx-whisper
38
+ try:
39
+ import mlx_whisper
40
+ except ImportError:
41
+ raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
42
+
43
+ logger.info(f"开始加载模型 {self.model_name}")
44
+ self.model = mlx_whisper.load_models.load_model(self.model_name)
45
+ logger.info(f"模型加载成功")
46
+ except Exception as e:
47
+ logger.error(f"加载模型失败: {str(e)}", exc_info=True)
48
+ raise RuntimeError(f"加载模型失败: {str(e)}")
49
+
50
+ def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
51
+ """
52
+ 将模型的分段结果转换为所需格式
53
+
54
+ 参数:
55
+ result: 模型返回的结果
56
+
57
+ 返回:
58
+ 转换后的分段列表
59
+ """
60
+ segments = []
61
+
62
+ for segment in result.get("segments", []):
63
+ segments.append({
64
+ "start": segment.get("start", 0.0),
65
+ "end": segment.get("end", 0.0),
66
+ "text": segment.get("text", "").strip()
67
+ })
68
+
69
+ return segments
70
+
71
+ def _perform_transcription(self, audio_data):
72
+ """
73
+ 执行转录
74
+
75
+ 参数:
76
+ audio_data: 音频数据(numpy数组)
77
+
78
+ 返回:
79
+ 模型的转录结果
80
+ """
81
+ return mlx_whisper.transcribe(audio_data, path_or_hf_repo=self.model_name)
82
+
83
+ def _get_text_from_result(self, result):
84
+ """
85
+ 从结果中获取文本
86
+
87
+ 参数:
88
+ result: 模型的转录结果
89
+
90
+ 返回:
91
+ 转录的文本
92
+ """
93
+ return result.get("text", "")
94
+
95
+
96
+ def transcribe_audio(
97
+ audio_segment: AudioSegment,
98
+ model_name: str = "mlx-community/distil-whisper-large-v3",
99
+ ) -> TranscriptionResult:
100
+ """
101
+ 使用MLX和distil-whisper-large-v3模型转录音频
102
+
103
+ 参数:
104
+ audio_segment: 输入的AudioSegment对象
105
+ model_name: 使用的模型名称
106
+
107
+ 返回:
108
+ TranscriptionResult对象,包含转录的文本、分段和语言
109
+ """
110
+ logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
111
+ transcriber = MLXDistilWhisperTranscriber(model_name=model_name)
112
+ return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_distil_whisper_transformers.py ADDED
@@ -0,0 +1,133 @@
1
+ """
2
+ 基于Transformers实现的语音识别模块,使用distil-whisper-large-v3.5模型
3
+ """
4
+
5
+ import os
6
+ from pydub import AudioSegment
7
+ from typing import Dict, List, Union
8
+ import logging
9
+ import numpy as np
10
+
11
+ # 导入基类
12
+ from .asr_base import BaseTranscriber, TranscriptionResult
13
+
14
+ # 配置日志
15
+ logger = logging.getLogger("asr")
16
+
17
+
18
+ class TransformersDistilWhisperTranscriber(BaseTranscriber):
19
+ """使用Transformers加载和运行distil-whisper-large-v3.5模型的转录器"""
20
+
21
+ def __init__(
22
+ self,
23
+ model_name: str = "distil-whisper/distil-large-v3.5",
24
+ device: str = "cpu",
25
+ ):
26
+ """
27
+ 初始化转录器
28
+
29
+ 参数:
30
+ model_name: 模型名称
31
+ device: 推理设备,'cpu'或'cuda'
32
+ """
33
+ super().__init__(model_name=model_name, device=device)
34
+
35
+ def _load_model(self):
36
+ """加载Distil Whisper模型"""
37
+ try:
38
+ # 懒加载transformers
39
+ try:
40
+ from transformers import pipeline
41
+ except ImportError:
42
+ raise ImportError("请先安装transformers库: pip install transformers")
43
+
44
+ logger.info(f"开始加载模型 {self.model_name}")
45
+ self.pipeline = pipeline(
46
+ "automatic-speech-recognition",
47
+ model=self.model_name,
48
+ device=0 if self.device == "cuda" else -1,
49
+ return_timestamps=True,
50
+ chunk_length_s=30,
51
+ batch_size=16,
52
+ )
53
+ logger.info(f"模型加载成功")
54
+ except Exception as e:
55
+ logger.error(f"加载模型失败: {str(e)}", exc_info=True)
56
+ raise RuntimeError(f"加载模型失败: {str(e)}")
57
+
58
+ def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
59
+ """
60
+ 将模型的分段结果转换为所需格式
61
+
62
+ 参数:
63
+ result: 模型返回的结果
64
+
65
+ 返回:
66
+ 转换后的分段列表
67
+ """
68
+ segments = []
69
+
70
+ # transformers pipeline 的结果格式
71
+ if "chunks" in result:
72
+ for chunk in result["chunks"]:
73
+ segments.append({
74
+ "start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
75
+ "end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
76
+ "text": chunk["text"].strip()
77
+ })
78
+ else:
79
+ # 如果没有分段信息,创建一个单一分段
80
+ segments.append({
81
+ "start": 0.0,
82
+ "end": 0.0, # 无法确定结束时间
83
+ "text": result.get("text", "").strip()
84
+ })
85
+
86
+ return segments
87
+
88
+ def _perform_transcription(self, audio_data):
89
+ """
90
+ 执行转录
91
+
92
+ 参数:
93
+ audio_data: 音频数据(numpy数组)
94
+
95
+ 返回:
96
+ 模型的转录结果
97
+ """
98
+ # transformers pipeline 接受numpy数组作为输入
99
+ # 音频数据已经在_prepare_audio中确保是16kHz采样率
100
+ return self.pipeline(audio_data)
101
+
102
+ def _get_text_from_result(self, result):
103
+ """
104
+ 从结果中获取文本
105
+
106
+ 参数:
107
+ result: 模型的转录结果
108
+
109
+ 返回:
110
+ 转录的文本
111
+ """
112
+ return result.get("text", "")
113
+
114
+
115
+ def transcribe_audio(
116
+ audio_segment: AudioSegment,
117
+ model_name: str = "distil-whisper/distil-large-v3.5",
118
+ device: str = "cpu",
119
+ ) -> TranscriptionResult:
120
+ """
121
+ 使用Transformers和distil-whisper-large-v3.5模型转录音频
122
+
123
+ 参数:
124
+ audio_segment: 输入的AudioSegment对象
125
+ model_name: 使用的模型名称
126
+ device: 推理设备,'cpu'或'cuda'
127
+
128
+ 返回:
129
+ TranscriptionResult对象,包含转录的文本、分段和语言
130
+ """
131
+ logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
132
+ transcriber = TransformersDistilWhisperTranscriber(model_name=model_name, device=device)
133
+ return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_parakeet_mlx.py ADDED
1
+ """
2
+ 基于MLX实现的语音识别模块,使用parakeet-tdt模型
3
+ """
4
+
5
+ import os
6
+ from pydub import AudioSegment
7
+ from typing import Dict, List, Union
8
+ import logging
9
+ import tempfile
10
+ import numpy as np
11
+ import soundfile as sf
12
+
13
+ # 导入基类
14
+ from .asr_base import BaseTranscriber, TranscriptionResult
15
+
16
+ # 配置日志
17
+ logger = logging.getLogger("asr")
18
+
19
+
20
+ class MLXParakeetTranscriber(BaseTranscriber):
21
+ """使用MLX加载和运行parakeet-tdt-0.6b-v2模型的转录器"""
22
+
23
+ def __init__(
24
+ self,
25
+ model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
26
+ ):
27
+ """
28
+ 初始化转录器
29
+
30
+ 参数:
31
+ model_name: 模型名称
32
+ """
33
+ super().__init__(model_name=model_name)
34
+
35
+ def _load_model(self):
36
+ """加载Parakeet模型"""
37
+ try:
38
+ # 懒加载parakeet_mlx
39
+ try:
40
+ from parakeet_mlx import from_pretrained
41
+ except ImportError:
42
+ raise ImportError("请先安装parakeet-mlx库: pip install parakeet-mlx")
43
+
44
+ logger.info(f"开始加载模型 {self.model_name}")
45
+ self.model = from_pretrained(self.model_name)
46
+ logger.info(f"模型加载成功")
47
+ except Exception as e:
48
+ logger.error(f"加载模型失败: {str(e)}", exc_info=True)
49
+ raise RuntimeError(f"加载模型失败: {str(e)}")
50
+
51
+ def _convert_segments(self, aligned_result) -> List[Dict[str, Union[float, str]]]:
52
+ """
53
+ 将模型的分段结果转换为所需格式
54
+
55
+ 参数:
56
+ aligned_result: 模型返回的分段结果
57
+
58
+ 返回:
59
+ 转换后的分段列表
60
+ """
61
+ segments = []
62
+
63
+ for sentence in aligned_result.sentences:
64
+ segments.append({
65
+ "start": sentence.start,
66
+ "end": sentence.end,
67
+ "text": sentence.text
68
+ })
69
+
70
+ return segments
71
+
72
+ def _perform_transcription(self, audio_data):
73
+ """
74
+ 执行转录
75
+
76
+ 参数:
77
+ audio_data: 音频数据(numpy数组)
78
+
79
+ 返回:
80
+ 模型的转录结果
81
+ """
82
+ # 由于parakeet-mlx可能不直接支持numpy数组输入
83
+ # 创建临时文件并写入音频数据
84
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=True) as temp_file:
85
+ # 确保数据在[-1, 1]范围内
86
+ if audio_data.max() > 1.0 or audio_data.min() < -1.0:
87
+ audio_data = np.clip(audio_data, -1.0, 1.0)
88
+
89
+ # 写入临时文件
90
+ sf.write(temp_file.name, audio_data, 16000, 'PCM_16')
91
+
92
+ # 使用临时文件进行转录
93
+ result = self.model.transcribe(temp_file.name)
94
+
95
+ return result
96
+
97
+ def _get_text_from_result(self, result):
98
+ """
99
+ 从结果中获取文本
100
+
101
+ 参数:
102
+ result: 模型的转录结果
103
+
104
+ 返回:
105
+ 转录的文本
106
+ """
107
+ return result.text
108
+
109
+
110
+ def transcribe_audio(
111
+ audio_segment: AudioSegment,
112
+ model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
113
+ ) -> TranscriptionResult:
114
+ """
115
+ 使用MLX和parakeet-tdt模型转录音频
116
+
117
+ 参数:
118
+ audio_segment: 输入的AudioSegment对象
119
+ model_name: 使用的模型名称
120
+
121
+ 返回:
122
+ TranscriptionResult对象,包含转录的文本、分段和语言
123
+ """
124
+ logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
125
+ transcriber = MLXParakeetTranscriber(model_name=model_name)
126
+ return transcriber.transcribe(audio_segment)
src/podcast_transcribe/asr/asr_router.py ADDED
@@ -0,0 +1,273 @@
1
+ """
2
+ ASR模型调用路由器
3
+ 根据传递的provider参数调用不同的ASR实现,支持延迟加载
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, Any, Optional, Callable
8
+ from pydub import AudioSegment
9
+ from .asr_base import TranscriptionResult
10
+ from . import asr_parakeet_mlx
11
+ from . import asr_distil_whisper_mlx
12
+ from . import asr_distil_whisper_transformers
13
+
14
+ # 配置日志
15
+ logger = logging.getLogger("asr")
16
+
17
+
18
+ class ASRRouter:
19
+ """ASR模型调用路由器,支持多种ASR实现的统一调用"""
20
+
21
+ def __init__(self):
22
+ """初始化路由器"""
23
+ self._loaded_modules = {} # 用于缓存已加载的模块
24
+ self._transcribers = {} # 用于缓存已实例化的转录器
25
+
26
+ # 定义支持的provider配置
27
+ self._provider_configs = {
28
+ "parakeet_mlx": {
29
+ "module_path": "asr_parakeet_mlx",
30
+ "function_name": "transcribe_audio",
31
+ "default_model": "mlx-community/parakeet-tdt-0.6b-v2",
32
+ "supported_params": ["model_name"],
33
+ "description": "基于MLX的Parakeet模型"
34
+ },
35
+ "distil_whisper_mlx": {
36
+ "module_path": "asr_distil_whisper_mlx",
37
+ "function_name": "transcribe_audio",
38
+ "default_model": "mlx-community/distil-whisper-large-v3",
39
+ "supported_params": ["model_name"],
40
+ "description": "基于MLX的Distil Whisper模型"
41
+ },
42
+ "distil_whisper_transformers": {
43
+ "module_path": "asr_distil_whisper_transformers",
44
+ "function_name": "transcribe_audio",
45
+ "default_model": "distil-whisper/distil-large-v3.5",
46
+ "supported_params": ["model_name", "device"],
47
+ "description": "基于Transformers的Distil Whisper模型"
48
+ }
49
+ }
50
+
51
+ def _lazy_load_module(self, provider: str):
52
+ """
53
+ 获取指定provider的模块
54
+
55
+ 参数:
56
+ provider: provider名称
57
+
58
+ 返回:
59
+ 对应的模块
60
+ """
61
+ if provider not in self._provider_configs:
62
+ raise ValueError(f"不支持的provider: {provider}")
63
+
64
+ if provider not in self._loaded_modules:
65
+ module_path = self._provider_configs[provider]["module_path"]
66
+ logger.info(f"获取模块: {module_path}")
67
+
68
+ # 根据module_path返回对应的模块
69
+ if module_path == "asr_parakeet_mlx":
70
+ module = asr_parakeet_mlx
71
+ elif module_path == "asr_distil_whisper_mlx":
72
+ module = asr_distil_whisper_mlx
73
+ elif module_path == "asr_distil_whisper_transformers":
74
+ module = asr_distil_whisper_transformers
75
+ else:
76
+ raise ImportError(f"未找到模块: {module_path}")
77
+
78
+ self._loaded_modules[provider] = module
79
+ logger.info(f"模块 {module_path} 获取成功")
80
+
81
+ return self._loaded_modules[provider]
82
+
83
+ def _get_transcribe_function(self, provider: str) -> Callable:
84
+ """
85
+ 获取指定provider的转录函数
86
+
87
+ 参数:
88
+ provider: provider名称
89
+
90
+ 返回:
91
+ 转录函数
92
+ """
93
+ module = self._lazy_load_module(provider)
94
+ function_name = self._provider_configs[provider]["function_name"]
95
+
96
+ if not hasattr(module, function_name):
97
+ raise AttributeError(f"模块中未找到函数: {function_name}")
98
+
99
+ return getattr(module, function_name)
100
+
101
+ def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
102
+ """
103
+ 过滤参数,只保留指定provider支持的参数
104
+
105
+ 参数:
106
+ provider: provider名称
107
+ params: 原始参数字典
108
+
109
+ 返回:
110
+ 过滤后的参数字典
111
+ """
112
+ supported_params = self._provider_configs[provider]["supported_params"]
113
+ filtered_params = {}
114
+
115
+ for param in supported_params:
116
+ if param in params:
117
+ filtered_params[param] = params[param]
118
+
119
+ # 如果没有指定model_name,使用默认模型
120
+ if "model_name" not in filtered_params and "model_name" in supported_params:
121
+ filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
122
+
123
+ return filtered_params
124
+
125
+ def transcribe(
126
+ self,
127
+ audio_segment: AudioSegment,
128
+ provider: str,
129
+ **kwargs
130
+ ) -> TranscriptionResult:
131
+ """
132
+ 统一的音频转录接口
133
+
134
+ 参数:
135
+ audio_segment: 输入的AudioSegment对象
136
+ provider: ASR提供者名称
137
+ **kwargs: 其他参数,如model_name, device等
138
+
139
+ 返回:
140
+ TranscriptionResult对象
141
+ """
142
+ logger.info(f"使用provider '{provider}' 进行音频转录,音频长度: {len(audio_segment)/1000:.2f}秒")
143
+
144
+ if provider not in self._provider_configs:
145
+ available_providers = list(self._provider_configs.keys())
146
+ raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
147
+
148
+ try:
149
+ # 获取转录函数
150
+ transcribe_func = self._get_transcribe_function(provider)
151
+
152
+ # 过滤并准备参数
153
+ filtered_kwargs = self._filter_params(provider, kwargs)
154
+
155
+ logger.debug(f"调用 {provider} 转录函数,参数: {filtered_kwargs}")
156
+
157
+ # 执行转录
158
+ result = transcribe_func(audio_segment, **filtered_kwargs)
159
+
160
+ logger.info(f"转录完成,文本长度: {len(result.text)}字符")
161
+ return result
162
+
163
+ except Exception as e:
164
+ logger.error(f"使用provider '{provider}' 转录音频失败: {str(e)}", exc_info=True)
165
+ raise RuntimeError(f"转录失败: {str(e)}")
166
+
167
+ def get_available_providers(self) -> Dict[str, str]:
168
+ """
169
+ 获取所有可用的provider及其描述
170
+
171
+ 返回:
172
+ provider名称到描述的映射
173
+ """
174
+ return {
175
+ provider: config["description"]
176
+ for provider, config in self._provider_configs.items()
177
+ }
178
+
179
+ def get_provider_info(self, provider: str) -> Dict[str, Any]:
180
+ """
181
+ 获取指定provider的详细信息
182
+
183
+ 参数:
184
+ provider: provider名称
185
+
186
+ 返回:
187
+ provider的配置信息
188
+ """
189
+ if provider not in self._provider_configs:
190
+ raise ValueError(f"不支持的provider: {provider}")
191
+
192
+ return self._provider_configs[provider].copy()
193
+
194
+
195
+ # 创建全局路由器实例
196
+ _router = ASRRouter()
197
+
198
+
199
+ def transcribe_audio(
200
+ audio_segment: AudioSegment,
201
+ provider: str = "distil_whisper_transformers",
202
+ model_name: Optional[str] = None,
203
+ device: str = "cpu",
204
+ **kwargs
205
+ ) -> TranscriptionResult:
206
+ """
207
+ 统一的音频转录接口函数
208
+
209
+ 参数:
210
+ audio_segment: 输入的AudioSegment对象
211
+ provider: ASR提供者,可选值:
212
+ - "parakeet_mlx": 基于MLX的Parakeet模型
213
+ - "distil_whisper_mlx": 基于MLX的Distil Whisper模型
214
+ - "distil_whisper_transformers": 基于Transformers的Distil Whisper模型
215
+ model_name: 模型名称,如果不指定则使用默认模型
216
+ device: 推理设备,仅对transformers provider有效
217
+ **kwargs: 其他参数
218
+
219
+ 返回:
220
+ TranscriptionResult对象,包含转录的文本、分段和语言
221
+
222
+ 示例:
223
+ # 使用默认MLX Distil Whisper模型
224
+ result = transcribe_audio(audio_segment, provider="distil_whisper_mlx")
225
+
226
+ # 使用Parakeet模型
227
+ result = transcribe_audio(audio_segment, provider="parakeet_mlx")
228
+
229
+ # 使用Transformers模型并指定设备
230
+ result = transcribe_audio(
231
+ audio_segment,
232
+ provider="distil_whisper_transformers",
233
+ device="cuda"
234
+ )
235
+
236
+ # 使用自定义模型
237
+ result = transcribe_audio(
238
+ audio_segment,
239
+ provider="distil_whisper_mlx",
240
+ model_name="mlx-community/whisper-large-v3"
241
+ )
242
+ """
243
+ # 准备参数
244
+ params = kwargs.copy()
245
+ if model_name is not None:
246
+ params["model_name"] = model_name
247
+ if device != "cpu":
248
+ params["device"] = device
249
+
250
+ return _router.transcribe(audio_segment, provider, **params)
251
+
252
+
253
+ def get_available_providers() -> Dict[str, str]:
254
+ """
255
+ 获取所有可用的ASR提供者
256
+
257
+ 返回:
258
+ provider名称到描述的映射
259
+ """
260
+ return _router.get_available_providers()
261
+
262
+
263
+ def get_provider_info(provider: str) -> Dict[str, Any]:
264
+ """
265
+ 获取指定provider的详细信息
266
+
267
+ 参数:
268
+ provider: provider名称
269
+
270
+ 返回:
271
+ provider的配置信息
272
+ """
273
+ return _router.get_provider_info(provider)
src/podcast_transcribe/audio.py ADDED
1
+ """
2
+ 音频处理工具模块
3
+ """
4
+
5
+ import numpy as np
6
+ from io import BytesIO
7
+ from pydub import AudioSegment
8
+ from typing import Tuple, Dict, Any
9
+
10
+
11
+ def load_audio(audio_file: str, target_sample_rate: int = 16000, mono: bool = True) -> Tuple[AudioSegment, np.ndarray]:
12
+ """
13
+ 加载音频文件并转换为目标采样率和通道数
14
+
15
+ 参数:
16
+ audio_file: 音频文件路径
17
+ target_sample_rate: 目标采样率,默认16kHz
18
+ mono: 是否转换为单声道,默认True
19
+
20
+ 返回:
21
+ AudioSegment对象和对应的numpy数组
22
+ """
23
+ try:
24
+ audio = AudioSegment.from_file(audio_file)
25
+
26
+ # 转换为单声道(如果需要)
27
+ if mono and audio.channels > 1:
28
+ audio = audio.set_channels(1)
29
+
30
+ # 转换采样率
31
+ if audio.frame_rate != target_sample_rate:
32
+ audio = audio.set_frame_rate(target_sample_rate)
33
+
34
+ # 获取音频波形(用于pyannote)
35
+ waveform = np.array(audio.get_array_of_samples()).astype(np.float32) / 32768.0
36
+
37
+ return audio, waveform
38
+
39
+ except Exception as e:
40
+ raise RuntimeError(f"无法加载音频文件: {str(e)}")
41
+
42
+
43
+ def extract_audio_segment(audio: AudioSegment, start_ms: int, end_ms: int) -> BytesIO:
44
+ """
45
+ 从音频中提取指定时间段
46
+
47
+ 参数:
48
+ audio: AudioSegment对象
49
+ start_ms: 开始时间(毫秒)
50
+ end_ms: 结束时间(毫秒)
51
+
52
+ 返回:
53
+ 包含音频段的BytesIO对象
54
+ """
55
+ try:
56
+ sub_audio = audio[start_ms:end_ms]
57
+ fp = BytesIO()
58
+ sub_audio.export(fp, format="wav")
59
+ fp.seek(0)
60
+ return fp
61
+ except Exception as e:
62
+ raise RuntimeError(f"无法提取音频段: {str(e)}")
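A short sketch of how these two helpers compose; the input path is a placeholder and the package is assumed to be importable as `podcast_transcribe`.

from podcast_transcribe.audio import load_audio, extract_audio_segment

audio, waveform = load_audio("episode.wav")      # 16 kHz mono AudioSegment plus float32 waveform
clip = extract_audio_segment(audio, 0, 10_000)   # first 10 seconds as WAV bytes in a BytesIO
with open("clip.wav", "wb") as f:
    f.write(clip.read())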
src/podcast_transcribe/diarization/diarization_pyannote_mlx.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ 基于pyannote/speaker-diarization-3.1模型实现的说话人分离模块
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ from pydub import AudioSegment
8
+ from typing import Any, Dict, List, Mapping, Text, Union, Optional, Tuple
9
+ import logging
10
+ import torch
11
+
12
+ from .diarizer_base import BaseDiarizer
13
+ from ..schemas import DiarizationResult
14
+
15
+ # 配置日志
16
+ logger = logging.getLogger("diarization")
17
+
18
+ class PyannoteTranscriber(BaseDiarizer):
19
+ """使用pyannote/speaker-diarization-3.1模型进行说话人分离"""
20
+
21
+ def __init__(
22
+ self,
23
+ model_name: str = "pyannote/speaker-diarization-3.1",
24
+ token: Optional[str] = None,
25
+ device: str = "cpu",
26
+ segmentation_batch_size: int = 32,
27
+ ):
28
+ """
29
+ 初始化说话人分离器
30
+
31
+ 参数:
32
+ model_name: 模型名称
33
+ token: Hugging Face令牌,用于访问模型
34
+ device: 推理设备,'cpu'或'cuda'
35
+ segmentation_batch_size: 分割批处理大小,默认为32
36
+ """
37
+ super().__init__(model_name, token, device, segmentation_batch_size)
38
+
39
+ # 加载模型
40
+ self._load_model()
41
+
42
+ def _load_model(self):
43
+ """加载pyannote模型"""
44
+ try:
45
+ # 懒加载pyannote.audio
46
+ try:
47
+ from pyannote.audio import Pipeline
48
+ except ImportError:
49
+ raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
50
+
51
+ if not self.token:
52
+ raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
53
+
54
+ logger.info(f"开始加载模型 {self.model_name}")
55
+ self.pipeline = Pipeline.from_pretrained(
56
+ self.model_name,
57
+ use_auth_token=self.token
58
+ )
59
+
60
+ # 设置设备
61
+ self.pipeline.to(torch.device(self.device))
62
+
63
+ # 设置分割批处理大小
64
+ if hasattr(self.pipeline, "segmentation_batch_size"):
65
+ logger.info(f"设置分割批处理大小: {self.segmentation_batch_size}")
66
+ self.pipeline.segmentation_batch_size = self.segmentation_batch_size
67
+
68
+ logger.info(f"模型加载成功")
69
+ except Exception as e:
70
+ logger.error(f"加载模型失败: {str(e)}", exc_info=True)
71
+ raise RuntimeError(f"加载模型失败: {str(e)}")
72
+
73
+
74
+
75
+ def diarize(self, audio: AudioSegment) -> DiarizationResult:
76
+ """
77
+ 对音频进行说话人分离
78
+
79
+ 参数:
80
+ audio: 要处理的AudioSegment对象
81
+
82
+ 返回:
83
+ DiarizationResult对象,包含分段结果和说话人数量
84
+ """
85
+ logger.info(f"开始处理 {len(audio)/1000:.2f} 秒的音频进行说话人分离")
86
+
87
+ # 准备音频输入
88
+ temp_audio_path = self._prepare_audio(audio)
89
+
90
+ try:
91
+ # 执行说话人分离
92
+ logger.debug("开始执行说话人分离")
93
+ from pyannote.audio.pipelines.utils.hook import ProgressHook
94
+
95
+ # 自定义 ProgressHook 类
96
+ class CustomProgressHook(ProgressHook):
97
+ def __call__(
98
+ self,
99
+ step_name: Text,
100
+ step_artifact: Any,
101
+ file: Optional[Mapping] = None,
102
+ total: Optional[int] = None,
103
+ completed: Optional[int] = None,
104
+ ):
105
 +                    if completed is not None and total is not None:
106
+ logger.info(f"处理中 {step_name}: ({completed/total*100:.1f}%)")
107
+ else:
108
+ logger.info(f"已完成 {step_name}")
109
+
110
+ with CustomProgressHook() as hook:
111
+ diarization = self.pipeline(temp_audio_path, hook=hook)
112
+
113
+ # 转换分段结果
114
+ segments, num_speakers = self._convert_segments(diarization)
115
+
116
+ logger.info(f"说话人分离完成,检测到 {num_speakers} 个说话人,生成 {len(segments)} 个分段")
117
+
118
+ return DiarizationResult(
119
+ segments=segments,
120
+ num_speakers=num_speakers
121
+ )
122
+
123
+ except Exception as e:
124
+ logger.error(f"说话人分离失败: {str(e)}", exc_info=True)
125
+ raise RuntimeError(f"说话人分离失败: {str(e)}")
126
+ finally:
127
+ # 删除临时文件
128
+ if os.path.exists(temp_audio_path):
129
+ os.remove(temp_audio_path)
130
+
131
+
132
+ def diarize_audio(
133
+ audio_segment: AudioSegment,
134
+ model_name: str = "pyannote/speaker-diarization-3.1",
135
+ token: Optional[str] = None,
136
+ device: str = "cpu",
137
+ segmentation_batch_size: int = 32,
138
+ ) -> DiarizationResult:
139
+ """
140
 +    使用pyannote模型对音频进行说话人分离
141
+
142
+ 参数:
143
+ audio_segment: 输入的AudioSegment对象
144
+ model_name: 使用的模型名称
145
+ token: Hugging Face令牌
146
+ device: 推理设备,'cpu'、'cuda'、'mps'
147
+ segmentation_batch_size: 分割批处理大小,默认为32
148
+
149
+ 返回:
150
+ DiarizationResult对象,包含分段和说话人数量
151
+ """
152
+ logger.info(f"调用diarize_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
153
+ transcriber = PyannoteTranscriber(model_name=model_name, token=token, device=device, segmentation_batch_size=segmentation_batch_size)
154
+ return transcriber.diarize(audio_segment)
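A hedged usage sketch of this module-level entry point; it assumes a valid Hugging Face token is available in `HF_TOKEN` and that `podcast_transcribe` is importable. The segment fields match `_convert_segments` above.

import os
from pydub import AudioSegment
from podcast_transcribe.diarization.diarization_pyannote_mlx import diarize_audio

audio = AudioSegment.from_file("episode.wav")    # placeholder input
result = diarize_audio(audio, token=os.environ.get("HF_TOKEN"), device="cpu")
for seg in result.segments:                      # dicts with start / end / speaker
    print(f'{seg["start"]:7.1f}s - {seg["end"]:7.1f}s  {seg["speaker"]}')
print("speakers detected:", result.num_speakers)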
src/podcast_transcribe/diarization/diarization_pyannote_transformers.py ADDED
@@ -0,0 +1,170 @@
1
+ """
2
+ 基于pyannote.audio库调用pyannote/speaker-diarization-3.1模型实现的说话人分离模块
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ from pydub import AudioSegment
8
+ from typing import Any, Dict, List, Mapping, Text, Union, Optional, Tuple
9
+ import logging
10
+ import torch
11
+
12
+ from .diarizer_base import BaseDiarizer
13
+ from ..schemas import DiarizationResult
14
+
15
+ # 配置日志
16
+ logger = logging.getLogger("diarization")
17
+
18
+
19
+ class PyannoteTransformersTranscriber(BaseDiarizer):
20
+ """使用pyannote.audio库调用pyannote/speaker-diarization-3.1模型进行说话人分离"""
21
+
22
+ def __init__(
23
+ self,
24
+ model_name: str = "pyannote/speaker-diarization-3.1",
25
+ token: Optional[str] = None,
26
+ device: str = "cpu",
27
+ segmentation_batch_size: int = 32,
28
+ ):
29
+ """
30
+ 初始化说话人分离器
31
+
32
+ 参数:
33
+ model_name: 模型名称
34
+ token: Hugging Face令牌,用于访问模型
35
+ device: 推理设备,'cpu'或'cuda'
36
+ segmentation_batch_size: 分割批处理大小,默认为32
37
+ """
38
+ super().__init__(model_name, token, device, segmentation_batch_size)
39
+
40
+ # 加载模型
41
+ self._load_model()
42
+
43
+ def _load_model(self):
44
+ """使用pyannote.audio加载模型"""
45
+ try:
46
+ # 检查依赖库
47
+ try:
48
+ from pyannote.audio import Pipeline
49
+ except ImportError:
50
+ raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
51
+
52
+ if not self.token:
53
+ raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
54
+
55
+ logger.info(f"开始使用pyannote.audio加载模型 {self.model_name}")
56
+
57
+ # 使用pyannote.audio Pipeline加载说话人分离模型
58
+ self.pipeline = Pipeline.from_pretrained(
59
+ self.model_name,
60
+ use_auth_token=self.token
61
+ )
62
+
63
+ # 设置设备
64
+ logger.info(f"将模型移动到设备: {self.device}")
65
+ self.pipeline.to(torch.device(self.device))
66
+
67
+ # 设置分割批处理大小
68
+ if hasattr(self.pipeline, "segmentation_batch_size"):
69
+ logger.info(f"设置分割批处理大小: {self.segmentation_batch_size}")
70
+ self.pipeline.segmentation_batch_size = self.segmentation_batch_size
71
+
72
+ logger.info(f"pyannote.audio模型加载成功")
73
+
74
+ except Exception as e:
75
+ logger.error(f"加载模型失败: {str(e)}", exc_info=True)
76
+ raise RuntimeError(f"模型加载失败: {str(e)}")
77
+
78
+ def diarize(self, audio: AudioSegment) -> DiarizationResult:
79
+ """
80
+ 对音频进行说话人分离
81
+
82
+ 参数:
83
+ audio: 要处理的AudioSegment对象
84
+
85
+ 返回:
86
+ DiarizationResult对象,包含分段结果和说话人数量
87
+ """
88
+ logger.info(f"开始使用pyannote.audio处理 {len(audio)/1000:.2f} 秒的音频进行说话人分离")
89
+
90
+ # 准备音频输入
91
+ temp_audio_path = self._prepare_audio(audio)
92
+
93
+ try:
94
+ # 执行说话人分离
95
+ logger.debug("开始执行说话人分离")
96
+
97
+ # 使用自定义 ProgressHook 来显示进度
98
+ try:
99
+ from pyannote.audio.pipelines.utils.hook import ProgressHook
100
+
101
+ class CustomProgressHook(ProgressHook):
102
+ def __call__(
103
+ self,
104
+ step_name: Text,
105
+ step_artifact: Any,
106
+ file: Optional[Mapping] = None,
107
+ total: Optional[int] = None,
108
+ completed: Optional[int] = None,
109
+ ):
110
+ if completed is not None and total is not None:
111
+ percentage = completed / total * 100
112
+ logger.info(f"处理中 {step_name}: ({percentage:.1f}%)")
113
+ else:
114
+ logger.info(f"已完成 {step_name}")
115
+
116
+ with CustomProgressHook() as hook:
117
+ diarization = self.pipeline(temp_audio_path, hook=hook)
118
+
119
+ except ImportError:
120
+ # 如果ProgressHook不可用,直接执行
121
+ logger.info("ProgressHook不可用,直接执行说话人分离")
122
+ diarization = self.pipeline(temp_audio_path)
123
+
124
+ # 转换分段结果
125
+ segments, num_speakers = self._convert_segments(diarization)
126
+
127
+ logger.info(f"说话人分离完成,检测到 {num_speakers} 个说话人,生成 {len(segments)} 个分段")
128
+
129
+ return DiarizationResult(
130
+ segments=segments,
131
+ num_speakers=num_speakers
132
+ )
133
+
134
+ except Exception as e:
135
+ logger.error(f"说话人分离失败: {str(e)}", exc_info=True)
136
+ raise RuntimeError(f"说话人分离失败: {str(e)}")
137
+ finally:
138
+ # 删除临时文件
139
+ if os.path.exists(temp_audio_path):
140
+ os.remove(temp_audio_path)
141
+
142
+
143
+ def diarize_audio(
144
+ audio_segment: AudioSegment,
145
+ model_name: str = "pyannote/speaker-diarization-3.1",
146
+ token: Optional[str] = None,
147
+ device: str = "cpu",
148
+ segmentation_batch_size: int = 32,
149
+ ) -> DiarizationResult:
150
+ """
151
+ 使用pyannote.audio调用pyannote模型对音频进行说话人分离
152
+
153
+ 参数:
154
+ audio_segment: 输入的AudioSegment对象
155
+ model_name: 使用的模型名称
156
+ token: Hugging Face令牌
157
+ device: 推理设备,'cpu'、'cuda'、'mps'
158
+ segmentation_batch_size: 分割批处理大小,默认为32
159
+
160
+ 返回:
161
+ DiarizationResult对象,包含分段和说话人数量
162
+ """
163
+ logger.info(f"调用pyannote.audio版本diarize_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
164
+ transcriber = PyannoteTransformersTranscriber(
165
+ model_name=model_name,
166
+ token=token,
167
+ device=device,
168
+ segmentation_batch_size=segmentation_batch_size
169
+ )
170
+ return transcriber.diarize(audio_segment)
src/podcast_transcribe/diarization/diarizer_base.py ADDED
@@ -0,0 +1,118 @@
1
+ """
2
+ 说话人分离器基础类,包含可复用的方法
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ from abc import ABC, abstractmethod
8
+ from pydub import AudioSegment
9
+ from typing import Any, Dict, List, Union, Optional, Tuple
10
+
11
+ from ..schemas import DiarizationResult
12
+
13
+ # 配置日志
14
+ logger = logging.getLogger("diarization")
15
+
16
+
17
+ class BaseDiarizer(ABC):
18
+ """说话人分离器基础类"""
19
+
20
+ def __init__(
21
+ self,
22
+ model_name: str,
23
+ token: Optional[str] = None,
24
+ device: str = "cpu",
25
+ segmentation_batch_size: int = 32,
26
+ ):
27
+ """
28
+ 初始化说话人分离器基础参数
29
+
30
+ 参数:
31
+ model_name: 模型名称
32
+ token: Hugging Face令牌,用于访问模型
33
+ device: 推理设备,'cpu'或'cuda'
34
+ segmentation_batch_size: 分割批处理大小,默认为32
35
+ """
36
+ self.model_name = model_name
37
+ self.token = token or os.environ.get("HF_TOKEN")
38
+ self.device = device
39
+ self.segmentation_batch_size = segmentation_batch_size
40
+
41
+ logger.info(f"初始化说话人分离器,模型: {model_name},设备: {device},分割批处理大小: {segmentation_batch_size}")
42
+
43
+ @abstractmethod
44
+ def _load_model(self):
45
+ """加载模型,子类需要实现"""
46
+ pass
47
+
48
+ def _prepare_audio(self, audio: AudioSegment) -> str:
49
+ """
50
+ 准备音频数据,保存为临时文件
51
+
52
+ 参数:
53
+ audio: 输入的AudioSegment对象
54
+
55
+ 返回:
56
+ 临时音频文件的路径
57
+ """
58
+ logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
59
+
60
+ # 确保采样率为16kHz (pyannote模型要求)
61
+ if audio.frame_rate != 16000:
62
+ logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
63
+ audio = audio.set_frame_rate(16000)
64
+
65
+ # 确保是单声道
66
+ if audio.channels > 1:
67
+ logger.debug(f"将{audio.channels}声道音频转换为单声道")
68
+ audio = audio.set_channels(1)
69
+
70
+ # 保存为临时文件
71
+ temp_audio_path = "_temp_audio_for_diarization.wav"
72
+ audio.export(temp_audio_path, format="wav")
73
+
74
+ logger.debug(f"音频处理完成,保存至: {temp_audio_path}")
75
+
76
+ return temp_audio_path
77
+
78
+ def _convert_segments(self, diarization) -> Tuple[List[Dict[str, Union[float, str, int]]], int]:
79
+ """
80
+ 将pyannote的分段结果转换为所需格式
81
+
82
+ 参数:
83
+ diarization: pyannote模型返回的分段结果
84
+
85
+ 返回:
86
+ 转换后的分段列表和说话人数量
87
+ """
88
+ segments = []
89
+ speakers = set()
90
+
91
+ # 遍历说话人分离结果
92
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
93
+ segments.append({
94
+ "start": turn.start,
95
+ "end": turn.end,
96
+ "speaker": speaker
97
+ })
98
+ speakers.add(speaker)
99
+
100
+ # 按开始时间排序
101
+ segments.sort(key=lambda x: x["start"])
102
+
103
+ logger.debug(f"转换了 {len(segments)} 个分段,检测到 {len(speakers)} 个说话人")
104
+
105
+ return segments, len(speakers)
106
+
107
+ @abstractmethod
108
+ def diarize(self, audio: AudioSegment) -> DiarizationResult:
109
+ """
110
+ 对音频进行说话人分离,子类需要实现
111
+
112
+ 参数:
113
+ audio: 要处理的AudioSegment对象
114
+
115
+ 返回:
116
+ DiarizationResult对象,包含分段结果和说话人数量
117
+ """
118
+ pass
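To illustrate the contract this base class defines, here is a toy subclass (not part of the repository) that satisfies both abstract methods and returns a `DiarizationResult` in the same shape used above.

from pydub import AudioSegment
from podcast_transcribe.diarization.diarizer_base import BaseDiarizer
from podcast_transcribe.schemas import DiarizationResult

class SingleSpeakerDiarizer(BaseDiarizer):
    """Toy backend that attributes the whole recording to one speaker."""

    def _load_model(self):
        pass  # nothing to load for this dummy backend

    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        segments = [{"start": 0.0, "end": len(audio) / 1000, "speaker": "SPEAKER_00"}]
        return DiarizationResult(segments=segments, num_speakers=1)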
src/podcast_transcribe/diarization/diarizer_router.py ADDED
@@ -0,0 +1,276 @@
1
+ """
2
+ 说话人分离模型调用路由器
3
+ 根据传递的provider参数调用不同的说话人分离实现,支持延迟加载
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, Any, Optional, Callable
8
+ from pydub import AudioSegment
9
+ from ..schemas import DiarizationResult
10
+ from . import diarization_pyannote_mlx
11
+ from . import diarization_pyannote_transformers
12
+
13
+ # 配置日志
14
+ logger = logging.getLogger("diarization")
15
+
16
+
17
+ class DiarizerRouter:
18
+ """说话人分离模型调用路由器,支持多种实现的统一调用"""
19
+
20
+ def __init__(self):
21
+ """初始化路由器"""
22
+ self._loaded_modules = {} # 用于缓存已加载的模块
23
+ self._diarizers = {} # 用于缓存已实例化的分离器
24
+
25
+ # 定义支持的provider配置
26
+ self._provider_configs = {
27
+ "pyannote_mlx": {
28
+ "module_path": "diarization_pyannote_mlx",
29
+ "function_name": "diarize_audio",
30
+ "default_model": "pyannote/speaker-diarization-3.1",
31
+ "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
32
+ "description": "基于pyannote.audio的原生MLX实现"
33
+ },
34
+ "pyannote_transformers": {
35
+ "module_path": "diarization_pyannote_transformers",
36
+ "function_name": "diarize_audio",
37
+ "default_model": "pyannote/speaker-diarization-3.1",
38
+ "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
39
+ "description": "基于transformers库调用pyannote模型"
40
+ }
41
+ }
42
+
43
+ def _lazy_load_module(self, provider: str):
44
+ """
45
+ 获取指定provider的模块
46
+
47
+ 参数:
48
+ provider: provider名称
49
+
50
+ 返回:
51
+ 对应的模块
52
+ """
53
+ if provider not in self._provider_configs:
54
+ raise ValueError(f"不支持的provider: {provider}")
55
+
56
+ if provider not in self._loaded_modules:
57
+ module_path = self._provider_configs[provider]["module_path"]
58
+ logger.info(f"获取模块: {module_path}")
59
+
60
+ # 根据module_path返回对应的模块
61
+ if module_path == "diarization_pyannote_mlx":
62
+ module = diarization_pyannote_mlx
63
+ elif module_path == "diarization_pyannote_transformers":
64
+ module = diarization_pyannote_transformers
65
+ else:
66
+ raise ImportError(f"未找到模块: {module_path}")
67
+
68
+ self._loaded_modules[provider] = module
69
+ logger.info(f"模块 {module_path} 获取成功")
70
+
71
+ return self._loaded_modules[provider]
72
+
73
+ def _get_diarize_function(self, provider: str) -> Callable:
74
+ """
75
+ 获取指定provider的说话人分离函数
76
+
77
+ 参数:
78
+ provider: provider名称
79
+
80
+ 返回:
81
+ 说话人分离函数
82
+ """
83
+ module = self._lazy_load_module(provider)
84
+ function_name = self._provider_configs[provider]["function_name"]
85
+
86
+ if not hasattr(module, function_name):
87
+ raise AttributeError(f"模块中未找到函数: {function_name}")
88
+
89
+ return getattr(module, function_name)
90
+
91
+ def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
92
+ """
93
+ 过滤参数,只保留指定provider支持的参数
94
+
95
+ 参数:
96
+ provider: provider名称
97
+ params: 原始参数字典
98
+
99
+ 返回:
100
+ 过滤后的参数字典
101
+ """
102
+ supported_params = self._provider_configs[provider]["supported_params"]
103
+ filtered_params = {}
104
+
105
+ for param in supported_params:
106
+ if param in params:
107
+ filtered_params[param] = params[param]
108
+
109
+ # 如果没有指定model_name,使用默认模型
110
+ if "model_name" not in filtered_params and "model_name" in supported_params:
111
+ filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
112
+
113
+ return filtered_params
114
+
115
+ def diarize(
116
+ self,
117
+ audio_segment: AudioSegment,
118
+ provider: str,
119
+ **kwargs
120
+ ) -> DiarizationResult:
121
+ """
122
+ 统一的说话人分离接口
123
+
124
+ 参数:
125
+ audio_segment: 输入的AudioSegment对象
126
+ provider: 说话人分离提供者名称
127
+ **kwargs: 其他参数,如model_name, token, device, segmentation_batch_size等
128
+
129
+ 返回:
130
+ DiarizationResult对象
131
+ """
132
+ logger.info(f"使用provider '{provider}' 进行说话人分离,音频长度: {len(audio_segment)/1000:.2f}秒")
133
+
134
+ if provider not in self._provider_configs:
135
+ available_providers = list(self._provider_configs.keys())
136
+ raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
137
+
138
+ try:
139
+ # 获取说话人分离函数
140
+ diarize_func = self._get_diarize_function(provider)
141
+
142
+ # 过滤并准备参数
143
+ filtered_kwargs = self._filter_params(provider, kwargs)
144
+
145
+ logger.debug(f"调用 {provider} 说话人分离函数,参数: {filtered_kwargs}")
146
+
147
+ # 执行说话人分离
148
+ result = diarize_func(audio_segment, **filtered_kwargs)
149
+
150
+ logger.info(f"说话人分离完成,检测到 {result.num_speakers} 个说话人,生成 {len(result.segments)} 个分段")
151
+ return result
152
+
153
+ except Exception as e:
154
+ logger.error(f"使用provider '{provider}' 进行说话人分离失败: {str(e)}", exc_info=True)
155
+ raise RuntimeError(f"说话人分离失败: {str(e)}")
156
+
157
+ def get_available_providers(self) -> Dict[str, str]:
158
+ """
159
+ 获取所有可用的provider及其描述
160
+
161
+ 返回:
162
+ provider名称到描述的映射
163
+ """
164
+ return {
165
+ provider: config["description"]
166
+ for provider, config in self._provider_configs.items()
167
+ }
168
+
169
+ def get_provider_info(self, provider: str) -> Dict[str, Any]:
170
+ """
171
+ 获取指定provider的详细信息
172
+
173
+ 参数:
174
+ provider: provider名称
175
+
176
+ 返回:
177
+ provider的配置信息
178
+ """
179
+ if provider not in self._provider_configs:
180
+ raise ValueError(f"不支持的provider: {provider}")
181
+
182
+ return self._provider_configs[provider].copy()
183
+
184
+
185
+ # 创建全局路由器实例
186
+ _router = DiarizerRouter()
187
+
188
+
189
+ def diarize_audio(
190
+ audio_segment: AudioSegment,
191
+ provider: str = "pyannote_mlx",
192
+ model_name: Optional[str] = None,
193
+ token: Optional[str] = None,
194
+ device: str = "cpu",
195
+ segmentation_batch_size: int = 32,
196
+ **kwargs
197
+ ) -> DiarizationResult:
198
+ """
199
+ 统一的音频说话人分离接口函数
200
+
201
+ 参数:
202
+ audio_segment: 输入的AudioSegment对象
203
+ provider: 说话人分离提供者,可选值:
204
+ - "pyannote_mlx": 基于pyannote.audio的原生MLX实现
205
+ - "pyannote_transformers": 基于transformers库调用pyannote模型
206
+ model_name: 模型名称,如果不指定则使用默认模型
207
+ token: Hugging Face令牌,用于访问模型
208
+ device: 推理设备,'cpu'、'cuda'、'mps'
209
+ segmentation_batch_size: 分割批处理大小,默认为32
210
+ **kwargs: 其他参数
211
+
212
+ 返回:
213
+ DiarizationResult对象,包含分段结果和说话人数量
214
+
215
+ 示例:
216
+ # 使用默认pyannote MLX实现
217
+ result = diarize_audio(audio_segment, provider="pyannote_mlx", token="your_hf_token")
218
+
219
+ # 使用transformers实现
220
+ result = diarize_audio(
221
+ audio_segment,
222
+ provider="pyannote_transformers",
223
+ token="your_hf_token"
224
+ )
225
+
226
+ # 使用GPU设备
227
+ result = diarize_audio(
228
+ audio_segment,
229
+ provider="pyannote_mlx",
230
+ token="your_hf_token",
231
+ device="cuda"
232
+ )
233
+
234
+ # 自定义批处理大小
235
+ result = diarize_audio(
236
+ audio_segment,
237
+ provider="pyannote_mlx",
238
+ token="your_hf_token",
239
+ segmentation_batch_size=64
240
+ )
241
+ """
242
+ # 准备参数
243
+ params = kwargs.copy()
244
+ if model_name is not None:
245
+ params["model_name"] = model_name
246
+ if token is not None:
247
+ params["token"] = token
248
+ if device != "cpu":
249
+ params["device"] = device
250
+ if segmentation_batch_size != 32:
251
+ params["segmentation_batch_size"] = segmentation_batch_size
252
+
253
+ return _router.diarize(audio_segment, provider, **params)
254
+
255
+
256
+ def get_available_providers() -> Dict[str, str]:
257
+ """
258
+ 获取所有可用的说话人分离提供者
259
+
260
+ 返回:
261
+ provider名称到描述的映射
262
+ """
263
+ return _router.get_available_providers()
264
+
265
+
266
+ def get_provider_info(provider: str) -> Dict[str, Any]:
267
+ """
268
+ 获取指定provider的详细信息
269
+
270
+ 参数:
271
+ provider: provider名称
272
+
273
+ 返回:
274
+ provider的配置信息
275
+ """
276
+ return _router.get_provider_info(provider)
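A brief sketch of switching backends through the router-level function, following the examples in the docstring above; the token value is a placeholder and `audio_segment` is assumed to be loaded elsewhere.

from podcast_transcribe.diarization.diarizer_router import diarize_audio, get_provider_info

print(get_provider_info("pyannote_mlx"))   # default model, supported params, description
result = diarize_audio(
    audio_segment,
    provider="pyannote_transformers",
    token="your_hf_token",                 # placeholder, do not hard-code real tokens
    device="cpu",
)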
src/podcast_transcribe/llm/llm_base.py ADDED
@@ -0,0 +1,391 @@
1
+ import time
2
+ import uuid
3
+ import torch
4
+ from typing import List, Dict, Optional, Union, Literal
5
+ from abc import ABC, abstractmethod
6
+
7
+
8
+ class BaseChatCompletion(ABC):
9
+ """Gemma 聊天完成的基类,包含公共功能"""
10
+
11
+ def __init__(self, model_name: str):
12
+ self.model_name = model_name
13
+
14
+ @abstractmethod
15
+ def _load_model_and_tokenizer(self):
16
+ """加载模型和分词器的抽象方法,由子类实现"""
17
+ pass
18
+
19
+ @abstractmethod
20
+ def _generate_response(self, prompt_str: str, temperature: float, max_tokens: int, top_p: float, **kwargs) -> str:
21
+ """生成响应的抽象方法,由子类实现"""
22
+ pass
23
+
24
+ def _format_messages_for_gemma(self, messages: List[Dict[str, str]]) -> str:
25
+ """
26
+ 为Gemma格式化消息。
27
+ Gemma期望特定的格式,通常类似于:
28
+ <start_of_turn>user
29
+ {user_message}<end_of_turn>
30
+ <start_of_turn>model
31
+ {assistant_message}<end_of_turn>
32
+ ...
33
+ <start_of_turn>user
34
+ {current_user_message}<end_of_turn>
35
+ <start_of_turn>model
36
+ """
37
+ # 尝试使用分词器的聊天模板(如果可用)
38
+ try:
39
+ # Hugging Face分词器中的apply_chat_template方法
40
+ # 通常需要一个字典列表,每个字典包含'role'和'content'。
41
+ # 我们需要确保我们的`messages`格式兼容。
42
+ # add_generation_prompt=True 至关重要,以确保模型知道轮到它发言了。
43
+ return self.tokenizer.apply_chat_template(
44
+ messages, tokenize=False, add_generation_prompt=True
45
+ )
46
+ except Exception:
47
+ # 如果apply_chat_template失败或不可用,则回退到手动格式化
48
+ prompt_parts = []
49
+ for message in messages:
50
+ role = message.get("role")
51
+ content = message.get("content")
52
+ if role == "user":
53
+ prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
54
+ elif role == "assistant":
55
+ prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
56
+ elif role == "system": # Gemma可能不会以相同的方式显式使用'system',通常是前置的
57
+ # 对于Gemma,系统提示通常只是前置到第一个用户消息或隐式处理。
58
+ # 我们会在这里前置它,尽管其有效性取决于特定的Gemma微调。
59
+ # 一种常见的模式是在开头放置系统指令,不使用特殊标记。
60
+ # 然而,为了保持结构化,我们将尝试一种通用方法。
61
+ # 如果分词器在其模板中有特定的方式来处理系统提示,
62
+ # 那么`apply_chat_template`将是首选。
63
+ # 由于我们处于回退状态,这是一个最佳猜测。
64
+ # 一些模型期望系统提示在轮次结构之外,或者在最开始。
65
+ # 为了在回退中简化,我们只做前置处理。
66
+ # 如果`apply_chat_template`不可用,更健壮的解决方案是检查模型的特定聊天模板。
67
+ prompt_parts.insert(0, f"<start_of_turn>system\n{content}<end_of_turn>")
68
+
69
+ # 添加提示,让模型开始生成
70
+ prompt_parts.append("<start_of_turn>model")
71
+ return "\n".join(prompt_parts)
72
+
73
+ def _post_process_response(self, response_text: str, prompt_str: str) -> str:
74
+ """
75
+ 后处理生成的响应文本,清理提示和特殊标记
76
+ """
77
+ # 后处理:Gemma的输出可能包含输入提示或特殊标记。
78
+ # 我们需要清理这些,以仅返回助手的最新消息。
79
+ # 一种常见的模式是,模型输出将以我们给它的提示开始,
80
+ # 或者它可能包含 <start_of_turn>model 标记,然后是其响应。
81
+
82
+ # 如果模型输出包含提示,然后是新的响应:
83
+ if response_text.startswith(prompt_str):
84
+ assistant_message_content = response_text[len(prompt_str):].strip()
85
+ else:
86
+ # 如果模型不回显提示,则可能需要更复杂的清理。
87
+ # 对于Gemma,响应通常跟随提示的最后一部分 "<start_of_turn>model\n"。
88
+ # 让我们尝试找到最后一个 "<start_of_turn>model" 并获取其后的文本。
89
+ # 这是一种启发式方法,可能需要根据实际模型输出进行调整。
90
+ parts = response_text.split("<start_of_turn>model")
91
+ if len(parts) > 1:
92
+ assistant_message_content = parts[-1].strip()
93
+ # 进一步清理 <end_of_turn> 或其他特殊标记
94
+ assistant_message_content = assistant_message_content.split("<end_of_turn>")[0].strip()
95
+ else: # 如果上述方法不起作用,则回退
96
+ assistant_message_content = response_text.strip()
97
+
98
+ return assistant_message_content
99
+
100
+ def _calculate_tokens(self, prompt_str: str, assistant_message_content: str) -> Dict[str, int]:
101
+ """
102
+ 计算token数量(近似值,因为确切的OpenAI分词可能不同)
103
+ """
104
+ # 对于提示token,我们对输入到模型的字符串进行分词。
105
+ # 对于完成token,我们对生成的助手消息进行分词。
106
+ prompt_tokens = len(self.tokenizer.encode(prompt_str))
107
+ completion_tokens = len(self.tokenizer.encode(assistant_message_content))
108
+ total_tokens = prompt_tokens + completion_tokens
109
+
110
+ return {
111
+ "prompt_tokens": prompt_tokens,
112
+ "completion_tokens": completion_tokens,
113
+ "total_tokens": total_tokens
114
+ }
115
+
116
+ def _build_chat_completion_response(self, assistant_message_content: str, token_usage: Dict[str, int]) -> Dict:
117
+ """
118
+ 构建模仿OpenAI结构的响应对象
119
+ 基于: https://platform.openai.com/docs/api-reference/chat/object
120
+ """
121
+ # 获取完成的当前时间戳
122
+ created_timestamp = int(time.time())
123
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}" # 创建一个唯一的ID
124
+
125
+ return {
126
+ "id": completion_id,
127
+ "object": "chat.completion",
128
+ "created": created_timestamp,
129
+ "model": self.model_name, # 报告我们使用的模型名称
130
+ "choices": [
131
+ {
132
+ "index": 0,
133
+ "message": {
134
+ "role": "assistant",
135
+ "content": assistant_message_content,
136
+ },
137
+ "finish_reason": "stop", # 假定为 "stop"
138
+ }
139
+ ],
140
+ "usage": token_usage,
141
+ }
142
+
143
+ def create(
144
+ self,
145
+ messages: List[Dict[str, str]],
146
+ temperature: float = 0.7,
147
+ max_tokens: int = 2048,
148
+ top_p: float = 1.0,
149
+ model: Optional[str] = None,
150
+ **kwargs,
151
+ ):
152
+ """
153
+ 创建聊天完成响应。
154
+ 模仿OpenAI的ChatCompletion.create方法。
155
+ """
156
+ if model and model != self.model_name:
157
+ # 这是一个简化的处理。在实际场景中,您可能希望加载新模型。
158
+ # 目前,我们将只打印一个警告并使用初始化的模型。
159
+ print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
160
+ f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")
161
+
162
+ # 为Gemma格式化消息
163
+ prompt_str = self._format_messages_for_gemma(messages)
164
+
165
+ # 生成响应(由子类实现)
166
+ response_text = self._generate_response(prompt_str, temperature, max_tokens, top_p, **kwargs)
167
+
168
+ # 后处理响应
169
+ assistant_message_content = self._post_process_response(response_text, prompt_str)
170
+
171
+ # 计算token使用量
172
+ token_usage = self._calculate_tokens(prompt_str, assistant_message_content)
173
+
174
+ # 构建响应对象
175
+ return self._build_chat_completion_response(assistant_message_content, token_usage)
176
+
177
+
178
+ class TransformersBaseChatCompletion(BaseChatCompletion):
179
+ """基于Transformers库的聊天完成基类,提供通用的设备管理和量化功能"""
180
+
181
+ def __init__(
182
+ self,
183
+ model_name: str,
184
+ use_4bit_quantization: bool = False,
185
+ device_map: Optional[str] = "auto",
186
+ device: Optional[str] = None,
187
+ trust_remote_code: bool = True,
188
+ torch_dtype: Optional[torch.dtype] = None
189
+ ):
190
+ super().__init__(model_name)
191
+ self.use_4bit_quantization = use_4bit_quantization
192
+ self.device_map = device_map
193
+ self.trust_remote_code = trust_remote_code
194
+ self.torch_dtype = torch_dtype or torch.float16
195
+ self.device = device
196
+
197
+ # 加载模型和分词器
198
+ self._load_model_and_tokenizer()
199
+
200
+ def _get_quantization_config(self):
201
+ """获取量化配置"""
202
+ if not self.use_4bit_quantization:
203
+ return None
204
+
205
+ if self.device and self.device.type == "mps":
206
+ print("警告: MPS 设备不支持 4bit 量化,将禁用量化")
207
+ self.use_4bit_quantization = False
208
+ return None
209
+
210
+ # 导入量化配置
211
+ try:
212
+ from transformers import BitsAndBytesConfig
213
+ except ImportError:
214
+ raise ImportError("请先安装 bitsandbytes 库: pip install bitsandbytes")
215
+
216
+ return BitsAndBytesConfig(
217
+ load_in_4bit=True,
218
+ bnb_4bit_compute_dtype=self.torch_dtype,
219
+ bnb_4bit_quant_type="nf4",
220
+ bnb_4bit_use_double_quant=True,
221
+ )
222
+
223
+ def _load_tokenizer(self):
224
+ """加载分词器"""
225
+ try:
226
+ from transformers import AutoTokenizer
227
+ except ImportError:
228
+ raise ImportError("请先安装 transformers 库: pip install transformers")
229
+
230
+ self.tokenizer = AutoTokenizer.from_pretrained(
231
+ self.model_name,
232
+ trust_remote_code=self.trust_remote_code
233
+ )
234
+
235
+ # 设置 pad_token 如果不存在
236
+ if self.tokenizer.pad_token is None:
237
+ self.tokenizer.pad_token = self.tokenizer.eos_token
238
+
239
+ def _load_model(self):
240
+ """加载模型"""
241
+ try:
242
+ from transformers import AutoModelForCausalLM
243
+ except ImportError:
244
+ raise ImportError("请先安装 transformers 库: pip install transformers")
245
+
246
+ print(f"正在加载模型: {self.model_name}")
247
+ print(f"4bit量化: {'启用' if self.use_4bit_quantization else '禁用'}")
248
+ print(f"目标设备: {self.device}")
249
+ print(f"设备映射: {self.device_map}")
250
+
251
+ # 配置模型加载参数
252
+ model_kwargs = {
253
+ "trust_remote_code": self.trust_remote_code,
254
+ "torch_dtype": self.torch_dtype,
255
+ }
256
+
257
+ # 处理量化配置
258
+ quantization_config = self._get_quantization_config()
259
+ if quantization_config:
260
+ model_kwargs["quantization_config"] = quantization_config
261
+ print(f"使用 4bit 量化配置")
262
+
263
+ # 处理设备映射
264
+ if self.device_map is not None:
265
+ if self.device and self.device.type == "mps":
266
+ print("警告: MPS 设备不支持 device_map,将手动管理设备")
267
+ else:
268
+ model_kwargs["device_map"] = self.device_map
269
+
270
+ # 加载模型
271
+ self.model = AutoModelForCausalLM.from_pretrained(
272
+ self.model_name,
273
+ **model_kwargs
274
+ )
275
+
276
+ # MPS 或手动设备管理
277
+ if self.device_map is None or (self.device and self.device.type == "mps"):
278
+ if not self.use_4bit_quantization:
279
+ print(f"手动移动模型到设备: {self.device}")
280
+ self.model = self.model.to(self.device)
281
+
282
+ print(f"模型 {self.model_name} 加载成功")
283
+
284
+ def _load_model_and_tokenizer(self):
285
+ """加载模型和分词器"""
286
+ try:
287
+ self._load_tokenizer()
288
+ self._load_model()
289
+ except Exception as e:
290
+ print(f"加载模型 {self.model_name} 时出错: {e}")
291
+ self._print_error_hints()
292
+ raise
293
+
294
+ def _print_error_hints(self):
295
+ """打印错误提示信息"""
296
+ print("请确保模型名称正确且可访问。")
297
+ if self.use_4bit_quantization:
298
+ print("如果使用量化,请确保已安装 bitsandbytes 库: pip install bitsandbytes")
299
+ if self.device and self.device.type == "mps":
300
+ print("MPS 设备注意事项:")
301
+ print("- 不支持 4bit 量化")
302
+ print("- 不支持 device_map")
303
+ print("- 确保 PyTorch 版本支持 MPS")
304
+
305
+ def _generate_response(
306
+ self,
307
+ prompt_str: str,
308
+ temperature: float,
309
+ max_tokens: int,
310
+ top_p: float,
311
+ **kwargs
312
+ ) -> str:
313
+ """使用 transformers 生成响应"""
314
+
315
+ # 对提示进行编码
316
+ inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
317
+
318
+ # 移动输入到正确的设备
319
+ if self.device_map is None or (self.device and self.device.type == "mps"):
320
+ inputs = inputs.to(self.device)
321
+
322
+ # 生成参数
323
+ generation_config = {
324
+ "max_new_tokens": max_tokens,
325
+ "temperature": temperature,
326
+ "top_p": top_p,
327
+ "do_sample": True if temperature > 0 else False,
328
+ "pad_token_id": self.tokenizer.pad_token_id,
329
+ "eos_token_id": self.tokenizer.eos_token_id,
330
+ "repetition_penalty": kwargs.get("repetition_penalty", 1.1),
331
+ "no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 3),
332
+ }
333
+
334
+ # 如果温度为0,使用贪婪解码
335
+ if temperature == 0:
336
+ generation_config["do_sample"] = False
337
+ generation_config.pop("temperature", None)
338
+ generation_config.pop("top_p", None)
339
+
340
+ try:
341
+ # 生成响应
342
+ with torch.no_grad():
343
+ outputs = self.model.generate(
344
+ inputs,
345
+ **generation_config
346
+ )
347
+
348
+ # 解码生成的文本,跳过输入部分
349
+ generated_tokens = outputs[0][len(inputs[0]):]
350
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
351
+
352
+ return generated_text
353
+
354
+ except Exception as e:
355
+ print(f"生成响应时出错: {e}")
356
+ raise
357
+
358
+ def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
359
+ """获取模型信息"""
360
+ model_info = {
361
+ "model_name": self.model_name,
362
+ "use_4bit_quantization": self.use_4bit_quantization,
363
+ "device": str(self.device),
364
+ "device_type": self.device.type,
365
+ "device_map": self.device_map,
366
+ "model_type": "transformers",
367
+ "torch_dtype": str(self.torch_dtype),
368
+ "mps_available": torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False,
369
+ "cuda_available": torch.cuda.is_available(),
370
+ }
371
+
372
+ # 添加模型配置信息(如果可用)
373
+ try:
374
+ if hasattr(self.model, "config"):
375
+ config = self.model.config
376
+ model_info.update({
377
+ "vocab_size": getattr(config, "vocab_size", "未知"),
378
+ "hidden_size": getattr(config, "hidden_size", "未知"),
379
+ "num_layers": getattr(config, "num_hidden_layers", "未知"),
380
+ "num_attention_heads": getattr(config, "num_attention_heads", "未知"),
381
+ })
382
+ except Exception:
383
+ pass
384
+
385
+ return model_info
386
+
387
+ def clear_cache(self):
388
+ """清理 GPU 缓存"""
389
+ if torch.cuda.is_available():
390
+ torch.cuda.empty_cache()
391
+ print("GPU 缓存已清理")
src/podcast_transcribe/llm/llm_gemma_mlx.py ADDED
@@ -0,0 +1,62 @@
1
+ from mlx_lm import load, generate
2
+ from mlx_lm.sample_utils import make_sampler
3
+ from typing import Dict, Union
4
+ from .llm_base import BaseChatCompletion
5
+
6
+
7
+ class GemmaMLXChatCompletion(BaseChatCompletion):
8
+ """基于 MLX 库的 Gemma 聊天完成实现"""
9
+
10
+ def __init__(self, model_name: str = "mlx-community/gemma-3-12b-it-4bit-DWQ"):
11
+ super().__init__(model_name)
12
+ self._load_model_and_tokenizer()
13
+
14
+ def _load_model_and_tokenizer(self):
15
+ """加载 MLX 模型和分词器"""
16
+ try:
17
+ print(f"正在加载 MLX 模型: {self.model_name}")
18
+ self.model, self.tokenizer = load(self.model_name)
19
+ print(f"MLX 模型 {self.model_name} 加载成功")
20
+ except Exception as e:
21
+ print(f"加载模型 {self.model_name} 时出错: {e}")
22
+ print("请确保模型名称正确且可访问。")
23
+ print("您可以尝试使用 'mlx_lm.utils.get_model_path(model_name)' 搜索可用的模型。")
24
+ raise
25
+
26
+ def _generate_response(
27
+ self,
28
+ prompt_str: str,
29
+ temperature: float,
30
+ max_tokens: int,
31
+ top_p: float,
32
+ **kwargs
33
+ ) -> str:
34
+ """使用 MLX 生成响应"""
35
+
36
+ # 为temperature和top_p创建一个采样器
37
+ sampler = make_sampler(temp=temperature, top_p=top_p)
38
+
39
+ # 生成响应
40
+ # mlx_lm中的`generate`函数接受模型、分词器、提示和其他生成参数。
41
+ # 我们需要将我们的参数映射到`generate`期望的参数。
42
+ # `mlx_lm.generate` 的 verbose 参数可用于调试。
43
+ # `temperature` 是 `mlx_lm.generate` 中温度的参数名称。
44
+ response_text = generate(
45
+ self.model,
46
+ self.tokenizer,
47
+ prompt=prompt_str,
48
+ max_tokens=max_tokens,
49
+ sampler=sampler,
50
+ # verbose=True # 取消注释以调试生成过程
51
+ )
52
+
53
+ return response_text
54
+
55
+ def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
56
+ """获取模型信息"""
57
+ return {
58
+ "model_name": self.model_name,
59
+ "model_type": "mlx",
60
+ "library": "mlx_lm"
61
+ }
62
+
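A minimal sketch of using this client on Apple Silicon; it assumes `mlx_lm` is installed and the default model can be downloaded from the Hub.

from podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion

client = GemmaMLXChatCompletion()   # defaults to mlx-community/gemma-3-12b-it-4bit-DWQ
reply = client.create(
    [{"role": "user", "content": "Give me three podcast title ideas about Apple MLX."}],
    temperature=0.7,
    max_tokens=200,
)
print(reply["choices"][0]["message"]["content"])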
src/podcast_transcribe/llm/llm_gemma_transfomers.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
+ from typing import List, Dict, Optional, Union, Literal
4
+ from .llm_base import TransformersBaseChatCompletion
5
+
6
+
7
+ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
8
+ """基于 Transformers 库的 Gemma 聊天完成实现"""
9
+
10
+ def __init__(
11
+ self,
12
+ model_name: str = "google/gemma-3-12b-it",
13
+ use_4bit_quantization: bool = False,
14
+ device_map: Optional[str] = "auto",
15
+ device: Optional[str] = None,
16
+ trust_remote_code: bool = True
17
+ ):
18
+ # Gemma 使用 float16 作为默认数据类型
19
+ super().__init__(
20
+ model_name=model_name,
21
+ use_4bit_quantization=use_4bit_quantization,
22
+ device_map=device_map,
23
+ device=device,
24
+ trust_remote_code=trust_remote_code,
25
+ torch_dtype=torch.float16
26
+ )
27
+
28
+ def _print_error_hints(self):
29
+ """打印Gemma特定的错误提示信息"""
30
+ super()._print_error_hints()
31
+ print("Gemma 特殊要求:")
32
+ print("- 建议使用 Transformers >= 4.21.0")
33
+ print("- 推荐使用 float16 数据类型")
34
+ print("- 确保有足够的GPU内存")
35
+
36
+
37
+ # 为了保持向后兼容性,也可以提供一个简化的工厂函数
38
+ def create_gemma_transformers_client(
39
+ model_name: str = "google/gemma-3-12b-it",
40
+ use_4bit_quantization: bool = False,
41
+ device: Optional[str] = None,
42
+ **kwargs
43
+ ) -> GemmaTransformersChatCompletion:
44
+ """
45
+ 创建 Gemma Transformers 客户端的工厂函数
46
+
47
+ Args:
48
+ model_name: 模型名称
49
+ use_4bit_quantization: 是否使用4bit量化
50
+ device: 指定设备 ("cpu", "cuda", "mps", 等)
51
+ **kwargs: 其他传递给构造函数的参数
52
+
53
+ Returns:
54
+ GemmaTransformersChatCompletion 实例
55
+ """
56
+ return GemmaTransformersChatCompletion(
57
+ model_name=model_name,
58
+ use_4bit_quantization=use_4bit_quantization,
59
+ device=device,
60
+ **kwargs
61
+ )
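A sketch of the factory function on a CUDA machine with 4-bit quantization enabled; GPU availability and an installed `bitsandbytes` are assumptions for illustration.

from podcast_transcribe.llm.llm_gemma_transfomers import create_gemma_transformers_client

client = create_gemma_transformers_client(use_4bit_quantization=True, device="cuda")
print(client.get_model_info()["device"])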
src/podcast_transcribe/llm/llm_phi4_transfomers.py ADDED
@@ -0,0 +1,369 @@
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
+ from typing import List, Dict, Optional, Union, Literal
4
+ from .llm_base import TransformersBaseChatCompletion
5
+
6
+
7
+ class Phi4TransformersChatCompletion(TransformersBaseChatCompletion):
8
+ """基于 Transformers 库的 Phi-4-mini-reasoning 聊天完成实现"""
9
+
10
+ def __init__(
11
+ self,
12
+ model_name: str = "microsoft/Phi-4-mini-reasoning",
13
+ use_4bit_quantization: bool = False,
14
+ device_map: Optional[str] = "auto",
15
+ device: Optional[str] = None,
16
+ trust_remote_code: bool = True
17
+ ):
18
+ # Phi-4 使用 bfloat16 作为推荐数据类型
19
+ super().__init__(
20
+ model_name=model_name,
21
+ use_4bit_quantization=use_4bit_quantization,
22
+ device_map=device_map,
23
+ device=device,
24
+ trust_remote_code=trust_remote_code,
25
+ torch_dtype=torch.bfloat16
26
+ )
27
+
28
+ def _print_error_hints(self):
29
+ """打印Phi-4特定的错误提示信息"""
30
+ super()._print_error_hints()
31
+ print("Phi-4 特殊要求:")
32
+ print("- 建议使用 Transformers >= 4.51.3")
33
+ print("- 推荐使用 bfloat16 数据类型")
34
+ print("- 模型支持 128K token 上下文长度")
35
+
36
+ def _format_phi4_messages(self, messages: List[Dict[str, str]]) -> str:
37
+ """
38
+ 格式化消息为 Phi-4 的聊天格式
39
+ Phi-4 使用特定的聊天模板格式
40
+ """
41
+ # 使用 tokenizer 的内置聊天模板
42
+ if hasattr(self.tokenizer, 'apply_chat_template'):
43
+ return self.tokenizer.apply_chat_template(
44
+ messages,
45
+ tokenize=False,
46
+ add_generation_prompt=True
47
+ )
48
+ else:
49
+ # 如果没有聊天模板,使用 Phi-4 的标准格式
50
+ formatted_prompt = ""
51
+ for message in messages:
52
+ role = message.get("role", "user")
53
+ content = message.get("content", "")
54
+
55
+ if role == "system":
56
+ formatted_prompt += f"<|system|>\n{content}<|end|>\n"
57
+ elif role == "user":
58
+ formatted_prompt += f"<|user|>\n{content}<|end|>\n"
59
+ elif role == "assistant":
60
+ formatted_prompt += f"<|assistant|>\n{content}<|end|>\n"
61
+
62
+ # 添加助手开始标记
63
+ formatted_prompt += "<|assistant|>\n"
64
+ return formatted_prompt
65
+
66
+ def _generate_response(
67
+ self,
68
+ prompt_str: str,
69
+ temperature: float,
70
+ max_tokens: int,
71
+ top_p: float,
72
+ enable_reasoning: bool = True,
73
+ **kwargs
74
+ ) -> str:
75
+ """使用 transformers 生成响应,针对 Phi-4 推理功能优化"""
76
+
77
+ # 对提示进行编码
78
+ inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
79
+
80
+ # 移动输入到正确的设备
81
 +        if self.device_map is None or (self.device and self.device.type == "mps"):
82
+ inputs = inputs.to(self.device)
83
+
84
+ # Phi-4-mini-reasoning 优化的生成参数
85
+ generation_config = {
86
+ "max_new_tokens": min(max_tokens, 32768), # Phi-4-mini 支持最大 32K token
87
+ "temperature": temperature,
88
+ "top_p": top_p,
89
+ "do_sample": True if temperature > 0 else False,
90
+ "pad_token_id": self.tokenizer.pad_token_id,
91
+ "eos_token_id": self.tokenizer.eos_token_id,
92
+ "repetition_penalty": kwargs.get("repetition_penalty", 1.1),
93
+ "no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 3),
94
+ }
95
+
96
+ # 推理模式配置
97
+ if enable_reasoning and "reasoning" in self.model_name.lower():
98
+ # 为推理任务优化的配置
99
+ generation_config.update({
100
+ "temperature": max(temperature, 0.1), # 推理模式下保持一定的温度
101
+ "top_p": min(top_p, 0.95), # 推理模式下限制 top_p
102
+ "do_sample": True, # 推理模式下总是启用采样
103
+ "early_stopping": False, # 允许完整的推理过程
104
+ })
105
+
106
+ # 如果温度为0,使用贪婪解码
107
+ if temperature == 0:
108
+ generation_config["do_sample"] = False
109
+ generation_config.pop("temperature", None)
110
+ generation_config.pop("top_p", None)
111
+
112
+ try:
113
+ # 生成响应
114
+ with torch.no_grad():
115
+ outputs = self.model.generate(
116
+ inputs,
117
+ **generation_config
118
+ )
119
+
120
+ # 解码生成的文本,跳过输入部分
121
+ generated_tokens = outputs[0][len(inputs[0]):]
122
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
123
+
124
+ return generated_text
125
+
126
+ except Exception as e:
127
+ print(f"生成响应时出错: {e}")
128
+ raise
129
+
130
+ def create(
131
+ self,
132
+ messages: List[Dict[str, str]],
133
+ temperature: float = 0.7,
134
+ max_tokens: int = 2048,
135
+ top_p: float = 1.0,
136
+ model: Optional[str] = None,
137
+ enable_reasoning: bool = True,
138
+ **kwargs,
139
+ ):
140
+ """
141
+ 创建聊天完成响应,支持Phi-4特有的推理功能
142
+ """
143
+ if model and model != self.model_name:
144
+ print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
145
+ f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")
146
+
147
+ # 检查是否为推理任务
148
+ is_reasoning_task = self._is_reasoning_task(messages)
149
+
150
+ # 格式化消息为 Phi-4 聊天格式
151
+ if is_reasoning_task and enable_reasoning:
152
+ prompt_str = self._format_reasoning_prompt(messages)
153
+ else:
154
+ prompt_str = self._format_phi4_messages(messages)
155
+
156
+ # 生成响应
157
+ response_text = self._generate_response(
158
+ prompt_str,
159
+ temperature,
160
+ max_tokens,
161
+ top_p,
162
+ enable_reasoning=enable_reasoning and is_reasoning_task,
163
+ **kwargs
164
+ )
165
+
166
+ # 后处理响应(使用基类的方法,但针对Phi-4调整)
167
+ assistant_message_content = self._post_process_phi4_response(response_text, prompt_str)
168
+
169
+ # 计算token使用量
170
+ token_usage = self._calculate_tokens(prompt_str, assistant_message_content)
171
+
172
+ # 构建响应对象
173
+ response = self._build_chat_completion_response(assistant_message_content, token_usage)
174
+
175
+ # 添加Phi-4特有的信息
176
+ response["reasoning_enabled"] = enable_reasoning and is_reasoning_task
177
+
178
+ return response
179
+
180
+ def _post_process_phi4_response(self, response_text: str, prompt_str: str) -> str:
181
+ """
182
+ 后处理Phi-4生成的响应文本
183
+ """
184
+ # Phi-4的输出通常不包含输入提示,直接返回生成的内容
185
+ assistant_message_content = response_text.strip()
186
+
187
+ # 清理可能的特殊标记
188
+ if assistant_message_content.endswith("<|end|>"):
189
+ assistant_message_content = assistant_message_content[:-7].strip()
190
+
191
+ return assistant_message_content
192
+
193
+ def _is_reasoning_task(self, messages: List[Dict[str, str]]) -> bool:
194
+ """检测是否为推理任务"""
195
+ reasoning_keywords = [
196
+ "解题", "推理", "计算", "证明", "分析", "逻辑", "步骤",
197
+ "solve", "reasoning", "calculate", "prove", "analyze", "logic", "step"
198
+ ]
199
+
200
+ for message in messages:
201
+ content = message.get("content", "").lower()
202
+ if any(keyword in content for keyword in reasoning_keywords):
203
+ return True
204
+
205
+ return False
206
+
207
+ def _format_reasoning_prompt(self, messages: List[Dict[str, str]]) -> str:
208
+ """
209
+ 为推理任务格式化特殊的提示词
210
+ """
211
+ # 添加推理指导的系统消息
212
+ reasoning_system_msg = {
213
+ "role": "system",
214
+ "content": "你是一个专业的数学推理助手。请逐步分析问题,展示详细的推理过程,包括:\n1. 问题理解\n2. 解题思路\n3. 具体步骤\n4. 最终答案\n\n每个步骤都要清晰明了。"
215
+ }
216
+
217
+ # 将推理系统消息添加到消息列表的开头
218
+ enhanced_messages = [reasoning_system_msg] + messages
219
+
220
+ # 使用标准格式化方法
221
+ return self._format_phi4_messages(enhanced_messages)
222
+
223
+ def reasoning_completion(
224
+ self,
225
+ messages: List[Dict[str, str]],
226
+ temperature: float = 0.3, # 推理任务使用较低的温度
227
+ max_tokens: int = 2048, # 推理任务需要更多 tokens
228
+ top_p: float = 0.9,
229
+ extract_reasoning_steps: bool = True,
230
+ **kwargs
231
+ ) -> Dict[str, Union[str, Dict, List]]:
232
+ """
233
+ 专门用于推理任务的聊天完成接口
234
+
235
+ Args:
236
+ messages: 对话消息列表
237
+ temperature: 采样温度(推理任务建议使用较低值)
238
+ max_tokens: 最大生成token数量
239
+ top_p: top-p采样参数
240
+ extract_reasoning_steps: 是否提取推理步骤
241
+ **kwargs: 其他参数
242
+
243
+ Returns:
244
+ 包含推理步骤的响应字典
245
+ """
246
+ # 强制启用推理模式
247
+ response = self.create(
248
+ messages=messages,
249
+ temperature=temperature,
250
+ max_tokens=max_tokens,
251
+ top_p=top_p,
252
+ enable_reasoning=True,
253
+ **kwargs
254
+ )
255
+
256
+ if extract_reasoning_steps:
257
+ # 提取推理步骤
258
+ content = response["choices"][0]["message"]["content"]
259
+ reasoning_steps = self._extract_reasoning_steps(content)
260
+ response["reasoning_steps"] = reasoning_steps
261
+
262
+ return response
263
+
264
+ def _extract_reasoning_steps(self, content: str) -> List[Dict[str, str]]:
265
+ """
266
+ 从响应内容中提取推理步骤
267
+ """
268
+ steps = []
269
+ lines = content.split('\n')
270
+ current_step = {"title": "", "content": ""}
271
+
272
+ step_patterns = [
273
+ "1. 问题理解", "2. 解题思路", "3. 具体步骤", "4. 最终答案",
274
+ "步骤", "分析", "解答", "结论", "reasoning", "step", "analysis", "solution"
275
+ ]
276
+
277
+ for line in lines:
278
+ line = line.strip()
279
+ if not line:
280
+ continue
281
+
282
+ # 检查是否是新的步骤开始
283
+ is_new_step = any(pattern in line.lower() for pattern in step_patterns)
284
+ if is_new_step and current_step["content"]:
285
+ steps.append(current_step.copy())
286
+ current_step = {"title": line, "content": ""}
287
+ elif is_new_step:
288
+ current_step["title"] = line
289
+ else:
290
+ if current_step["title"]:
291
+ current_step["content"] += line + "\n"
292
+ else:
293
+ current_step["content"] = line + "\n"
294
+
295
+ # 添加最后一个步骤
296
+ if current_step["title"] or current_step["content"]:
297
+ steps.append(current_step)
298
+
299
+ return steps
300
+
301
+ def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
302
+ """获取 Phi-4 模型信息"""
303
+ model_info = super().get_model_info()
304
+
305
+ # 添加Phi-4特有的信息
306
+ model_info.update({
307
+ "model_family": "Phi-4-mini-reasoning",
308
+ "parameters": "3.8B",
309
+ "context_length": "128K tokens",
310
+ "specialization": "数学推理优化",
311
+ })
312
+
313
+ return model_info
314
+
315
+
316
+ # 工厂函数
317
+ def create_phi4_transformers_client(
318
+ model_name: str = "microsoft/Phi-4-mini-reasoning",
319
+ use_4bit_quantization: bool = False,
320
+ device: Optional[str] = None,
321
+ **kwargs
322
+ ) -> Phi4TransformersChatCompletion:
323
+ """
324
+ 创建 Phi-4 Transformers 客户端的工厂函数
325
+
326
+ Args:
327
+ model_name: 模型名称,默认为 microsoft/Phi-4-mini-reasoning
328
+ use_4bit_quantization: 是否使用4bit量化
329
+ device: 指定设备 ("cpu", "cuda", "mps", 等)
330
+ **kwargs: 其他传递给构造函数的参数
331
+
332
+ Returns:
333
+ Phi4TransformersChatCompletion 实例
334
+ """
335
+ return Phi4TransformersChatCompletion(
336
+ model_name=model_name,
337
+ use_4bit_quantization=use_4bit_quantization,
338
+ device=device,
339
+ **kwargs
340
+ )
341
+
342
+ def create_reasoning_client(
343
+ model_name: str = "microsoft/Phi-4-mini-reasoning",
344
+ use_4bit_quantization: bool = False,
345
+ device: Optional[str] = None,
346
+ **kwargs
347
+ ) -> Phi4TransformersChatCompletion:
348
+ """
349
+ 创建专门用于推理任务的 Phi-4 客户端
350
+
351
+ Args:
352
+ model_name: 模型名称,推荐使用 microsoft/Phi-4-mini-reasoning
353
+ use_4bit_quantization: 是否使用4bit量化
354
+ device: 指定设备 ("cpu", "cuda", "mps", 等)
355
+ **kwargs: 其他传递给构造函数的参数
356
+
357
+ Returns:
358
+ 优化了推理功能的 Phi4TransformersChatCompletion 实例
359
+ """
360
+ # 确保使用推理模型
361
+ if "reasoning" not in model_name.lower():
362
+ print("警告: 建议使用包含 'reasoning' 的模型名称以获得最佳推理性能")
363
+
364
+ return Phi4TransformersChatCompletion(
365
+ model_name=model_name,
366
+ use_4bit_quantization=use_4bit_quantization,
367
+ device=device,
368
+ **kwargs
369
+ )
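A hedged sketch of the reasoning-oriented entry point; the math prompt is illustrative and downloading the default model is assumed.

from podcast_transcribe.llm.llm_phi4_transfomers import create_reasoning_client

client = create_reasoning_client(device="cpu")
result = client.reasoning_completion(
    [{"role": "user", "content": "Solve step by step: what is 12 * 37?"}],
    extract_reasoning_steps=True,
)
print(result["choices"][0]["message"]["content"])
print(result.get("reasoning_steps", []))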
src/podcast_transcribe/llm/llm_router.py ADDED
@@ -0,0 +1,578 @@
1
+ """
2
+ LLM模型调用路由器
3
+ 根据传递的provider参数调用不同的LLM实现,支持延迟加载
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, Any, Optional, List, Union
8
+ from .llm_base import BaseChatCompletion
9
+ from . import llm_gemma_mlx
10
+ from . import llm_gemma_transfomers
11
+ from . import llm_phi4_transfomers
12
+
13
+ # 配置日志
14
+ logger = logging.getLogger("llm")
15
+
16
+
17
+ class LLMRouter:
18
+ """LLM模型调用路由器,支持多种实现的统一调用"""
19
+
20
+ def __init__(self):
21
+ """初始化路由器"""
22
+ self._loaded_modules = {} # 用于缓存已加载的模块
23
+ self._llm_instances = {} # 用于缓存已实例化的LLM实例
24
+
25
+ # 定义支持的provider配置
26
+ self._provider_configs = {
27
+ "gemma-mlx": {
28
+ "module_path": "llm_gemma_mlx",
29
+ "class_name": "GemmaMLXChatCompletion",
30
+ "default_model": "mlx-community/gemma-3-12b-it-4bit-DWQ",
31
+ "supported_params": ["model_name"],
32
+ "description": "基于MLX库的Gemma聊天完成实现"
33
+ },
34
+ "gemma-transformers": {
35
+ "module_path": "llm_gemma_transfomers",
36
+ "class_name": "GemmaTransformersChatCompletion",
37
+ "default_model": "google/gemma-3-12b-it",
38
+ "supported_params": [
39
+ "model_name", "use_4bit_quantization", "device_map",
40
+ "device", "trust_remote_code"
41
+ ],
42
+ "description": "基于Transformers库的Gemma聊天完成实现"
43
+ },
44
+ "phi4-transformers": {
45
+ "module_path": "llm_phi4_transfomers",
46
+ "class_name": "Phi4TransformersChatCompletion",
47
+ "default_model": "microsoft/Phi-4-reasoning",
48
+ "supported_params": [
49
+ "model_name", "use_4bit_quantization", "device_map",
50
+ "device", "trust_remote_code", "enable_reasoning"
51
+ ],
52
+ "description": "基于Transformers库的Phi-4推理聊天完成实现"
53
+ }
54
+ }
55
+
56
+ def _lazy_load_module(self, provider: str):
57
+ """
58
+ 获取指定provider的模块
59
+
60
+ 参数:
61
+ provider: provider名称
62
+
63
+ 返回:
64
+ 对应的模块
65
+ """
66
+ if provider not in self._provider_configs:
67
+ raise ValueError(f"不支持的provider: {provider}")
68
+
69
+ if provider not in self._loaded_modules:
70
+ module_path = self._provider_configs[provider]["module_path"]
71
+ logger.info(f"获取模块: {module_path}")
72
+
73
+ # 根据module_path返回对应的模块
74
+ if module_path == "llm_gemma_mlx":
75
+ module = llm_gemma_mlx
76
+ elif module_path == "llm_gemma_transfomers":
77
+ module = llm_gemma_transfomers
78
+ elif module_path == "llm_phi4_transfomers":
79
+ module = llm_phi4_transfomers
80
+ else:
81
+ raise ImportError(f"未找到模块: {module_path}")
82
+
83
+ self._loaded_modules[provider] = module
84
+ logger.info(f"模块 {module_path} 获取成功")
85
+
86
+ return self._loaded_modules[provider]
87
+
88
+ def _get_llm_class(self, provider: str):
89
+ """
90
+ 获取指定provider的LLM类
91
+
92
+ 参数:
93
+ provider: provider名称
94
+
95
+ 返回:
96
+ LLM类
97
+ """
98
+ module = self._lazy_load_module(provider)
99
+ class_name = self._provider_configs[provider]["class_name"]
100
+
101
+ if not hasattr(module, class_name):
102
+ raise AttributeError(f"模块中未找到类: {class_name}")
103
+
104
+ return getattr(module, class_name)
105
+
106
+ def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
107
+ """
108
+ 过滤参数,只保留指定provider支持的参数
109
+
110
+ 参数:
111
+ provider: provider名称
112
+ params: 原始参数字典
113
+
114
+ 返回:
115
+ 过滤后的参数字典
116
+ """
117
+ supported_params = self._provider_configs[provider]["supported_params"]
118
+ filtered_params = {}
119
+
120
+ for param in supported_params:
121
+ if param in params:
122
+ filtered_params[param] = params[param]
123
+
124
+ # 如果没有指定model_name,使用默认模型
125
+ if "model_name" not in filtered_params and "model_name" in supported_params:
126
+ filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
127
+
128
+ return filtered_params
129
+
130
+ def _get_instance_key(self, provider: str, params: Dict[str, Any]) -> str:
131
+ """
132
+ 生成LLM实例的缓存键
133
+
134
+ 参数:
135
+ provider: provider名称
136
+ params: 参数字典
137
+
138
+ 返回:
139
+ 实例缓存键
140
+ """
141
+ # 将参数转换为可哈希的字符串
142
+ param_str = "_".join([f"{k}={v}" for k, v in sorted(params.items())])
143
+ return f"{provider}_{param_str}"
144
+
145
+ def _get_or_create_instance(self, provider: str, **kwargs) -> BaseChatCompletion:
146
+ """
147
+ 获取或创建LLM实例(支持缓存复用)
148
+
149
+ 参数:
150
+ provider: provider名称
151
+ **kwargs: 构造函数参数
152
+
153
+ 返回:
154
+ LLM实例
155
+ """
156
+ # 过滤并准备参数
157
+ filtered_kwargs = self._filter_params(provider, kwargs)
158
+
159
+ # 生成实例缓存键
160
+ instance_key = self._get_instance_key(provider, filtered_kwargs)
161
+
162
+ # 检查是否已有缓存实例
163
+ if instance_key not in self._llm_instances:
164
+ try:
165
+ # 获取LLM类
166
+ llm_class = self._get_llm_class(provider)
167
+
168
+ logger.debug(f"创建 {provider} LLM实例,参数: {filtered_kwargs}")
169
+
170
+ # 创建实例
171
+ instance = llm_class(**filtered_kwargs)
172
+
173
+ # 缓存实例
174
+ self._llm_instances[instance_key] = instance
175
+
176
+ logger.info(f"LLM实例创建成功: {provider} ({instance.model_name})")
177
+
178
+ except Exception as e:
179
+ logger.error(f"创建 {provider} LLM实例失败: {str(e)}", exc_info=True)
180
+ raise RuntimeError(f"创建LLM实例失败: {str(e)}")
181
+
182
+ return self._llm_instances[instance_key]
183
+
184
+ def chat_completion(
185
+ self,
186
+ messages: List[Dict[str, str]],
187
+ provider: str,
188
+ temperature: float = 0.7,
189
+ max_tokens: int = 2048,
190
+ top_p: float = 1.0,
191
+ model: Optional[str] = None,
192
+ **kwargs
193
+ ) -> Dict[str, Any]:
194
+ """
195
+ 统一的聊天完成接口
196
+
197
+ 参数:
198
+ messages: 消息列表,每个消息包含role和content
199
+ provider: LLM提供者名称
200
+ temperature: 温度参数,控制生成的随机性
201
+ max_tokens: 最大生成token数
202
+ top_p: nucleus采样参数
203
+ model: 可选的模型名称,如果提供则覆盖默认model_name
204
+ **kwargs: 其他参数,如device、use_4bit_quantization等
205
+
206
+ 返回:
207
+ 聊天完成响应字典
208
+ """
209
+ logger.info(f"使用provider '{provider}' 进行聊天完成,消息数量: {len(messages)}")
210
+
211
+ if provider not in self._provider_configs:
212
+ available_providers = list(self._provider_configs.keys())
213
+ raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
214
+
215
+ try:
216
+ # 如果提供了model参数,添加到kwargs中
217
+ if model is not None:
218
+ kwargs["model_name"] = model
219
+
220
+ # 获取或创建LLM实例
221
+ llm_instance = self._get_or_create_instance(provider, **kwargs)
222
+
223
+ # 调用聊天完成
224
+ result = llm_instance.create(
225
+ messages=messages,
226
+ temperature=temperature,
227
+ max_tokens=max_tokens,
228
+ top_p=top_p,
229
+ model=model,
230
+ **kwargs
231
+ )
232
+
233
+ logger.info(f"聊天完成成功,使用tokens: {result.get('usage', {}).get('total_tokens', 'unknown')}")
234
+ return result
235
+
236
+ except Exception as e:
237
+ logger.error(f"使用provider '{provider}' 进行聊天完成失败: {str(e)}", exc_info=True)
238
+ raise RuntimeError(f"聊天完成失败: {str(e)}")
239
+
240
+ def reasoning_completion(
241
+ self,
242
+ messages: List[Dict[str, str]],
243
+ provider: str = "phi4-transformers",
244
+ temperature: float = 0.3,
245
+ max_tokens: int = 2048,
246
+ top_p: float = 0.9,
247
+ model: Optional[str] = None,
248
+ extract_reasoning_steps: bool = True,
249
+ **kwargs
250
+ ) -> Dict[str, Any]:
251
+ """
252
+ 专门用于推理任务的聊天完成接口
253
+
254
+ 参数:
255
+ messages: 消息列表,每个消息包含role和content
256
+ provider: LLM提供者名称,默认使用phi4-transformers
257
+ temperature: 温度参数(推理任务建议使用较低值)
258
+ max_tokens: 最大生成token数
259
+ top_p: nucleus采样参数
260
+ model: 可选的模型名称
261
+ extract_reasoning_steps: 是否提取推理步骤
262
+ **kwargs: 其他参数
263
+
264
+ 返回:
265
+ 包含推理步骤的响应字典
266
+ """
267
+ logger.info(f"使用provider '{provider}' 进行推理完成,消息数量: {len(messages)}")
268
+
269
+ # 确保使用支持推理的provider
270
+ if provider not in ["phi4-transformers"]:
271
+ logger.warning(f"Provider '{provider}' 可能不支持推理功能,建议使用 'phi4-transformers'")
272
+
273
+ try:
274
+ # 如果提供了model参数,添加到kwargs中
275
+ if model is not None:
276
+ kwargs["model_name"] = model
277
+
278
+ # 获取或创建LLM实例
279
+ llm_instance = self._get_or_create_instance(provider, **kwargs)
280
+
281
+ # 检查实例是否支持推理完成
282
+ if hasattr(llm_instance, 'reasoning_completion'):
283
+ result = llm_instance.reasoning_completion(
284
+ messages=messages,
285
+ temperature=temperature,
286
+ max_tokens=max_tokens,
287
+ top_p=top_p,
288
+ extract_reasoning_steps=extract_reasoning_steps,
289
+ **kwargs
290
+ )
291
+ else:
292
+ # 回退到普通聊天完成
293
+ logger.warning(f"Provider '{provider}' 不支持推理完成,回退到普通聊天完成")
294
+ result = llm_instance.create(
295
+ messages=messages,
296
+ temperature=temperature,
297
+ max_tokens=max_tokens,
298
+ top_p=top_p,
299
+ model=model,
300
+ **kwargs
301
+ )
302
+
303
+ logger.info(f"推理完成成功,使用tokens: {result.get('usage', {}).get('total_tokens', 'unknown')}")
304
+ return result
305
+
306
+ except Exception as e:
307
+ logger.error(f"使用provider '{provider}' 进行推理完成失败: {str(e)}", exc_info=True)
308
+ raise RuntimeError(f"推理完成失败: {str(e)}")
309
+
310
+ def get_model_info(self, provider: str, **kwargs) -> Dict[str, Any]:
311
+ """
312
+ 获取模型信息
313
+
314
+ 参数:
315
+ provider: provider名称
316
+ **kwargs: 构造函数参数
317
+
318
+ 返回:
319
+ 模型信息字典
320
+ """
321
+ try:
322
+ llm_instance = self._get_or_create_instance(provider, **kwargs)
323
+ return llm_instance.get_model_info()
324
+ except Exception as e:
325
+ logger.error(f"获取模型信息失败: {str(e)}")
326
+ raise RuntimeError(f"获取模型信息失败: {str(e)}")
327
+
328
+ def get_available_providers(self) -> Dict[str, str]:
329
+ """
330
+ 获取所有可用的provider及其描述
331
+
332
+ 返回:
333
+ provider名称到描述的映射
334
+ """
335
+ return {
336
+ provider: config["description"]
337
+ for provider, config in self._provider_configs.items()
338
+ }
339
+
340
+ def get_provider_info(self, provider: str) -> Dict[str, Any]:
341
+ """
342
+ 获取指定provider的详细信息
343
+
344
+ 参数:
345
+ provider: provider名称
346
+
347
+ 返回:
348
+ provider的配置信息
349
+ """
350
+ if provider not in self._provider_configs:
351
+ raise ValueError(f"不支持的provider: {provider}")
352
+
353
+ return self._provider_configs[provider].copy()
354
+
355
+ def clear_cache(self):
356
+ """清理缓存的实例"""
357
+ # 清理每个实例的GPU缓存
358
+ for instance in self._llm_instances.values():
359
+ if hasattr(instance, 'clear_cache'):
360
+ instance.clear_cache()
361
+
362
+ # 清理实例缓存
363
+ self._llm_instances.clear()
364
+ logger.info("LLM实例缓存已清理")
365
+
366
+
367
+ # 创建全局路由器实例
368
+ _router = LLMRouter()
369
+
370
+
371
+ def chat_completion(
372
+ messages: List[Dict[str, str]],
373
+ provider: str = "gemma-mlx",
374
+ temperature: float = 0.7,
375
+ max_tokens: int = 2048,
376
+ top_p: float = 1.0,
377
+ model: Optional[str] = None,
378
+ device: Optional[str] = None,
379
+ use_4bit_quantization: bool = False,
380
+ device_map: Optional[str] = "auto",
381
+ trust_remote_code: bool = True,
382
+ **kwargs
383
+ ) -> Dict[str, Any]:
384
+ """
385
+ 统一的聊天完成接口函数
386
+
387
+ 参数:
388
+ messages: 消息列表,每个消息包含role和content字段
389
+ provider: LLM提供者,可选值:
390
+ - "gemma-mlx": 基于MLX库的Gemma聊天完成实现
391
+ - "gemma-transformers": 基于Transformers库的Gemma聊天完成实现
392
+ - "phi4-transformers": 基于Transformers库的Phi-4推理聊天完成实现
393
+ temperature: 温度参数,控制生成的随机性 (0.0-2.0)
394
+ max_tokens: 最大生成token数
395
+ top_p: nucleus采样参数 (0.0-1.0)
396
+ model: 模型名称,如果不指定则使用默认模型
397
+ device: 推理设备,'cpu'、'cuda'、'mps'(仅transformers provider支持)
398
+ use_4bit_quantization: 是否使用4bit量化(仅transformers provider支持)
399
+ device_map: 设备映射配置(仅transformers provider支持)
400
+ trust_remote_code: 是否信任远程代码(仅transformers provider支持)
401
+ **kwargs: 其他参数
402
+
403
+ 返回:
404
+ 聊天完成响应字典,包含生成的消息和使用统计
405
+
406
+ 示例:
407
+ # 使用默认MLX实现
408
+ response = chat_completion(
409
+ messages=[{"role": "user", "content": "你好"}],
410
+ provider="gemma-mlx"
411
+ )
412
+
413
+ # 使用Gemma transformers实现
414
+ response = chat_completion(
415
+ messages=[{"role": "user", "content": "你好"}],
416
+ provider="gemma-transformers",
417
+ model="google/gemma-3-12b-it",
418
+ device="cuda",
419
+ use_4bit_quantization=True
420
+ )
421
+
422
+ # 使用Phi-4推理实现
423
+ response = chat_completion(
424
+ messages=[{"role": "user", "content": "解这个数学题:2x + 5 = 15"}],
425
+ provider="phi4-transformers",
426
+ model="microsoft/Phi-4-mini-reasoning",
427
+ device="cuda"
428
+ )
429
+
430
+ # 自定义参数
431
+ response = chat_completion(
432
+ messages=[
433
+ {"role": "system", "content": "你是一个有用的助手"},
434
+ {"role": "user", "content": "请介绍自己"}
435
+ ],
436
+ provider="gemma-mlx",
437
+ temperature=0.8,
438
+ max_tokens=1024
439
+ )
440
+ """
441
+ # 准备参数
442
+ params = kwargs.copy()
443
+ if model is not None:
444
+ params["model_name"] = model
445
+ if device is not None:
446
+ params["device"] = device
447
+ if use_4bit_quantization:
448
+ params["use_4bit_quantization"] = use_4bit_quantization
449
+ if device_map != "auto":
450
+ params["device_map"] = device_map
451
+ if not trust_remote_code:
452
+ params["trust_remote_code"] = trust_remote_code
453
+
454
+ return _router.chat_completion(
455
+ messages=messages,
456
+ provider=provider,
457
+ temperature=temperature,
458
+ max_tokens=max_tokens,
459
+ top_p=top_p,
460
+ model=model,
461
+ **params
462
+ )
463
+
464
+
465
+ def reasoning_completion(
466
+ messages: List[Dict[str, str]],
467
+ provider: str = "phi4-transformers",
468
+ temperature: float = 0.3,
469
+ max_tokens: int = 2048,
470
+ top_p: float = 0.9,
471
+ model: Optional[str] = None,
472
+ device: Optional[str] = None,
473
+ use_4bit_quantization: bool = False,
474
+ device_map: Optional[str] = "auto",
475
+ trust_remote_code: bool = True,
476
+ extract_reasoning_steps: bool = True,
477
+ **kwargs
478
+ ) -> Dict[str, Any]:
479
+ """
480
+ 专门用于推理任务的聊天完成接口函数
481
+
482
+ 参数:
483
+ messages: 消息列表,每个消息包含role和content字段
484
+ provider: LLM提供者,默认使用phi4-transformers
485
+ temperature: 温度参数(推理任务建议使用较低值)
486
+ max_tokens: 最大生成token数
487
+ top_p: nucleus采样参数
488
+ model: 模型名称,如果不指定则使用默认模型
489
+ device: 推理设备
490
+ use_4bit_quantization: 是否使用4bit量化
491
+ device_map: 设备映射配置
492
+ trust_remote_code: 是否信任远程代码
493
+ extract_reasoning_steps: 是否提取推理步骤
494
+ **kwargs: 其他参数
495
+
496
+ 返回:
497
+ 包含推理步骤的响应字典
498
+
499
+ 示例:
500
+ # 数学推理任务
501
+ response = reasoning_completion(
502
+ messages=[{"role": "user", "content": "解这个方程:3x + 7 = 22"}],
503
+ provider="phi4-transformers",
504
+ extract_reasoning_steps=True
505
+ )
506
+
507
+ # 逻辑推理任务
508
+ response = reasoning_completion(
509
+ messages=[{"role": "user", "content": "如果所有的猫都是动物,而小花是一只猫,那么小花是什么?"}],
510
+ provider="phi4-transformers",
511
+ temperature=0.2
512
+ )
513
+ """
514
+ # 准备参数
515
+ params = kwargs.copy()
516
+ if model is not None:
517
+ params["model_name"] = model
518
+ if device is not None:
519
+ params["device"] = device
520
+ if use_4bit_quantization:
521
+ params["use_4bit_quantization"] = use_4bit_quantization
522
+ if device_map != "auto":
523
+ params["device_map"] = device_map
524
+ if not trust_remote_code:
525
+ params["trust_remote_code"] = trust_remote_code
526
+
527
+ return _router.reasoning_completion(
528
+ messages=messages,
529
+ provider=provider,
530
+ temperature=temperature,
531
+ max_tokens=max_tokens,
532
+ top_p=top_p,
533
+ model=model,
534
+ extract_reasoning_steps=extract_reasoning_steps,
535
+ **params
536
+ )
537
+
538
+
539
+ def get_model_info(provider: str = "gemma-mlx", **kwargs) -> Dict[str, Any]:
540
+ """
541
+ 获取模型信息
542
+
543
+ 参数:
544
+ provider: provider名称
545
+ **kwargs: 构造函数参数
546
+
547
+ 返回:
548
+ 模型信息字典
549
+ """
550
+ return _router.get_model_info(provider, **kwargs)
551
+
552
+
553
+ def get_available_providers() -> Dict[str, str]:
554
+ """
555
+ 获取所有可用的LLM提供者
556
+
557
+ 返回:
558
+ provider名称到描述的映射
559
+ """
560
+ return _router.get_available_providers()
561
+
562
+
563
+ def get_provider_info(provider: str) -> Dict[str, Any]:
564
+ """
565
+ 获取指定provider的详细信息
566
+
567
+ 参数:
568
+ provider: provider名称
569
+
570
+ 返回:
571
+ provider的配置信息
572
+ """
573
+ return _router.get_provider_info(provider)
574
+
575
+
576
+ def clear_cache():
577
+ """清理缓存的LLM实例"""
578
+ _router.clear_cache()
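A minimal usage sketch of the module-level entry points above (the import path follows this package's layout; the provider and model names are the defaults documented in the docstrings, and which providers actually load depends on the local hardware):

from podcast_transcribe.llm import llm_router

# List the registered providers and their descriptions
print(llm_router.get_available_providers())

# Run a chat completion through the router; the created instance is cached and reused
response = llm_router.chat_completion(
    messages=[{"role": "user", "content": "Introduce yourself in one sentence."}],
    provider="gemma-transformers",        # or "gemma-mlx" on Apple Silicon
    model="google/gemma-3-12b-it",
    temperature=0.7,
    max_tokens=256,
)
print(response["choices"][0]["message"]["content"])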
src/podcast_transcribe/rss/podcast_rss_parser.py ADDED
@@ -0,0 +1,162 @@
1
+ import requests
2
+ import feedparser
3
+ # from dataclasses import dataclass, field # 已移除
4
+ from typing import Optional, Union
5
+ from datetime import datetime
6
+ import time
7
+
8
+ from ..schemas import PodcastEpisode, PodcastChannel
9
+
10
+ def _parse_date(date_str: Union[str, time.struct_time, None]) -> Optional[datetime]:
11
+ if not date_str:
12
+ return None
13
+ try:
14
+ # feedparser 已经将日期解析为 time.struct_time 类型
15
+ # 我们将其转换为 datetime 类型
16
+ if isinstance(date_str, time.struct_time):
17
+ return datetime.fromtimestamp(time.mktime(date_str))
18
+ # 如果 feedparser 解析失败或返回字符串,则回退使用其他字符串格式解析
19
+ # 这是一种常见的 RSS 日期格式
20
+ return datetime.strptime(date_str, '%a, %d %b %Y %H:%M:%S %z')
21
+ except (ValueError, TypeError):
22
+ try:
23
+ return datetime.strptime(date_str, '%a, %d %b %Y %H:%M:%S %Z') # 处理 GMT, EST 等时区
24
+ except (ValueError, TypeError):
25
+ # 如果时区缺失或无法解析,则尝试不带时区解析
26
+ try:
27
+ return datetime.strptime(date_str[:-6], '%a, %d %b %Y %H:%M:%S')
28
+ except (ValueError, TypeError):
29
+ print(f"Warning: Could not parse date string: {date_str}")
30
+ return None
31
+
32
+ def fetch_rss_content(rss_url: str) -> Optional[bytes]:
33
+ """
34
+ 通过 HTTP 请求获取 RSS feed 的内容。
35
+
36
+ 参数:
37
+ rss_url: 播客 RSS feed 的 URL。
38
+
39
+ 返回:
40
+ bytes 类型的 RSS 内容,如果获取失败则返回 None。
41
+ """
42
+ headers = {
43
+ 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36'
44
+ }
45
+ try:
46
+ response = requests.get(rss_url, headers=headers, timeout=30)
47
+ response.raise_for_status() # 针对 HTTP 错误抛出异常
48
+ return response.content
49
+ except requests.exceptions.RequestException as e:
50
+ print(f"获取 RSS feed 时出错: {e}")
51
+ return None
52
+
53
+ def parse_rss_xml_content(rss_content: bytes) -> Optional[PodcastChannel]:
54
+ """
55
+ 解析播客 RSS XML 内容,并返回其主要信息和剧集详情。
56
+
57
+ 参数:
58
+ rss_content: bytes 类型的 RSS XML 内容。
59
+
60
+ 返回:
61
+ 一个包含已解析信息的 PodcastChannel 对象,如果解析失败则返回 None。
62
+ """
63
+ feed = feedparser.parse(rss_content)
64
+
65
+ if feed.bozo:
66
+ # 如果 feed 格式不正确,bozo 为 True
67
+ # feed.bozo_exception 包含异常信息
68
+ print(f"警告: RSS feed 可能格式不正确。Bozo 异常: {feed.bozo_exception}")
69
+ # 即使格式不完全正确,feedparser 通常仍会尝试解析,所以我们不在此处直接返回 None
70
+ # 但如果关键的 feed 或 channel_info 缺失,后续会自然失败
71
+
72
+ channel_info = feed.get('feed', {})
73
+ if not channel_info: # 如果连基本的 feed 结构都没有,则认为解析失败
74
+ print("错误: RSS 内容无法解析为有效的 feed 结构。")
75
+ return None
76
+
77
+ podcast_channel = PodcastChannel(
78
+ title=channel_info.get('title'),
79
+ link=channel_info.get('link'),
80
+ description=channel_info.get('subtitle') or channel_info.get('description'),
81
+ language=channel_info.get('language'),
82
+ image_url=channel_info.get('image', {}).get('href') if channel_info.get('image') else None,
83
+ author=channel_info.get('author') or channel_info.get('itunes_author'),
84
+ last_build_date=_parse_date(channel_info.get('updated_parsed') or channel_info.get('published_parsed'))
85
+ )
86
+
87
+ for entry in feed.entries:
88
+ # 确定 shownotes:优先使用 content:encoded,然后是 itunes:summary,其次是 description/summary
89
+ shownotes = None
90
+ # 1. 优先尝试 <content:encoded>
91
+ # entry.content 是一个 FeedParserDict 对象列表
92
+ if 'content' in entry and entry.content:
93
+ for content_item in entry.content:
94
+ # 检查 content_item 是否有 value 属性并且该值非空
95
+ if hasattr(content_item, 'value') and content_item.value:
96
+ shownotes = content_item.value
97
+ break # 找到第一个有效的 content:encoded,停止查找
98
+
99
+ # 2. 如果没有从 content:encoded 获得,尝试 itunes:summary
100
+ if not shownotes and 'itunes_summary' in entry:
101
+ shownotes = entry.itunes_summary
102
+
103
+ # 3. 最后回退到 summary 或 description
104
+ if not shownotes: # 回退到 summary 或 description
105
+ shownotes = entry.get('summary') or entry.get('description')
106
+
107
+ # 从 enclosures 获取音频 URL
108
+ audio_url = None
109
+ if 'enclosures' in entry:
110
+ for enc in entry.enclosures:
111
+ if enc.get('type', '').startswith('audio/'):
112
+ audio_url = enc.get('href')
113
+ break
114
+
115
+ # 解析特定于剧集的 iTunes 标签
116
+ itunes_season = None
117
+ try:
118
+ itunes_season_str = entry.get('itunes_season')
119
+ if itunes_season_str:
120
+ itunes_season = int(itunes_season_str)
121
+ except (ValueError, TypeError):
122
+ pass # 如果不是有效整数则忽略
123
+
124
+ itunes_episode_number = None
125
+ try:
126
+ itunes_episode_number_str = entry.get('itunes_episode')
127
+ if itunes_episode_number_str:
128
+ itunes_episode_number = int(itunes_episode_number_str)
129
+ except (ValueError, TypeError):
130
+ pass # 如果不是有效整数则忽略
131
+
132
+ episode = PodcastEpisode(
133
+ title=entry.get('title'),
134
+ link=entry.get('link'),
135
+ published_date=_parse_date(entry.get('published_parsed')),
136
+ summary=entry.get('summary'), # 这通常是较短的版本
137
+ shownotes=shownotes, # 这是我们尝试获取的更详细版本
138
+ audio_url=audio_url,
139
+ guid=entry.get('id') or entry.get('guid'),
140
+ duration=entry.get('itunes_duration'),
141
+ episode_type=entry.get('itunes_episodetype'),
142
+ season=itunes_season,
143
+ episode_number=itunes_episode_number
144
+ )
145
+ podcast_channel.episodes.append(episode)
146
+
147
+ return podcast_channel
148
+
149
+ def parse_podcast_rss(rss_url: str) -> Optional[PodcastChannel]:
150
+ """
151
+ 从给定的 RSS URL 获取并解析播客数据。
152
+
153
+ 参数:
154
+ rss_url: 播客 RSS feed 的 URL。
155
+
156
+ 返回:
157
+ 一个包含已解析信息的 PodcastChannel 对象,如果获取或解析失败则返回 None。
158
+ """
159
+ rss_content = fetch_rss_content(rss_url)
160
+ if rss_content:
161
+ return parse_rss_xml_content(rss_content)
162
+ return None
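A minimal usage sketch for the parser above (the feed URL is a placeholder):

from podcast_transcribe.rss.podcast_rss_parser import parse_podcast_rss

channel = parse_podcast_rss("https://example.com/feed.xml")  # placeholder URL
if channel:
    print(channel.title, "-", len(channel.episodes), "episodes")
    for ep in channel.episodes[:3]:
        # Each episode carries the audio enclosure URL and iTunes metadata
        print(ep.title, ep.audio_url, ep.duration)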
src/podcast_transcribe/schemas.py ADDED
@@ -0,0 +1,63 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional, Dict, Union
3
+ from datetime import datetime
4
+
5
+ @dataclass
6
+ class EnhancedSegment:
7
+ """增强的转录分段,包含说话人信息"""
8
+ start: float # 开始时间(秒)
9
+ end: float # 结束时间(秒)
10
+ text: str # 转录的文本
11
+ speaker: str # 说话人ID
12
+ language: str # 检测到的语言
13
+ speaker_name: Optional[str] = None # 识别出的说话人名称
14
+
15
+
16
+ @dataclass
17
+ class CombinedTranscriptionResult:
18
+ """结合ASR和说话人分离的转录结果"""
19
+ segments: List[EnhancedSegment] # 包含说话人和文本的分段
20
+ text: str # 完整转录文本
21
+ language: str # 检测到的语言
22
+ num_speakers: int # 检测到的说话人数量
23
+
24
+
25
+ @dataclass
26
+ class PodcastEpisode:
27
+ title: Optional[str] = None
28
+ link: Optional[str] = None
29
+ published_date: Optional[datetime] = None
30
+ summary: Optional[str] = None # 简短摘要
31
+ shownotes: Optional[str] = None # 详细的shownotes,通常是HTML格式
32
+ audio_url: Optional[str] = None
33
+ guid: Optional[str] = None
34
+ duration: Optional[str] = None # 例如,来自 <itunes:duration>
35
+ episode_type: Optional[str] = None # 例如,来自 <itunes:episodetype>
36
+ season: Optional[int] = None # 例如,来自 <itunes:season>
37
+ episode_number: Optional[int] = None # 例如,来自 <itunes:episode>
38
+
39
+ @dataclass
40
+ class PodcastChannel:
41
+ title: Optional[str] = None
42
+ link: Optional[str] = None
43
+ description: Optional[str] = None
44
+ language: Optional[str] = None
45
+ image_url: Optional[str] = None
46
+ author: Optional[str] = None # 例如,来自 <itunes:author>
47
+ last_build_date: Optional[datetime] = None
48
+ episodes: List[PodcastEpisode] = field(default_factory=list)
49
+
50
+
51
+ @dataclass
52
+ class TranscriptionResult:
53
+ """转录结果数据类"""
54
+ text: str # 转录的文本
55
+ segments: List[Dict[str, Union[float, str]]] # 包含时间戳的分段
56
+ language: str # 检测到的语言
57
+
58
+
59
+ @dataclass
60
+ class DiarizationResult:
61
+ """说话人分离结果数据类"""
62
+ segments: List[Dict[str, Union[float, str, int]]] # 包含时间戳和说话人ID的分段
63
+ num_speakers: int # 检测到的说话人数量
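A small construction sketch showing how these dataclasses fit together (all values are illustrative only):

from podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult

seg = EnhancedSegment(
    start=0.0, end=3.2, text="Welcome to the show.",
    speaker="SPEAKER_00", language="en", speaker_name="Podcast Host",
)
result = CombinedTranscriptionResult(
    segments=[seg], text=seg.text, language="en", num_speakers=1,
)
print(result.num_speakers, result.segments[0].speaker_name)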
src/podcast_transcribe/summary/speaker_identify.py ADDED
@@ -0,0 +1,350 @@
1
+ from typing import List, Dict, Optional
2
+ import json
3
+ import re
4
+
5
+ from ..schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
6
+ from ..llm import llm_router
7
+
8
+
9
+ class SpeakerIdentifier:
10
+ """
11
+ 说话人识别器类,用于根据转录分段和播客元数据识别说话人的真实姓名或昵称
12
+ """
13
+
14
+ def __init__(self, llm_model_name: str, llm_provider: str):
15
+ """
16
+ 初始化说话人识别器
17
+
18
+ 参数:
19
+ llm_model_name: LLM模型名称,如果为None则使用默认模型
20
+ llm_provider: LLM提供者名称,例如 "gemma-mlx"、"gemma-transformers"
21
+ """
22
+ self.llm_model_name = llm_model_name
23
+ self.llm_provider = llm_provider
24
+
25
+ def _clean_html(self, html_string: Optional[str]) -> str:
26
+ """
27
+ 简单地从字符串中移除HTML标签并清理多余空白。
28
+ """
29
+ if not html_string:
30
+ return ""
31
+ # 移除HTML标签
32
+ text = re.sub(r'<[^>]+>', ' ', html_string)
33
+ # 替换HTML实体(简单版本,只处理常见几个)
34
+ text = text.replace('&nbsp;', ' ').replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>')
35
+ # 移除多余的空白符
36
+ text = re.sub(r'\s+', ' ', text).strip()
37
+ return text
38
+
39
+ def _get_dialogue_samples(
40
+ self,
41
+ segments: List[EnhancedSegment],
42
+ max_samples_per_speaker: int = 3, # 增加样本数量
43
+ max_length_per_sample: int = 200 # 增加样本长度
44
+ ) -> Dict[str, List[str]]:
45
+ """
46
+ 为每个说话人提取对话样本。
47
+ """
48
+ speaker_dialogues: Dict[str, List[str]] = {}
49
+ for segment in segments:
50
+ speaker = segment.speaker
51
+ if speaker == "UNKNOWN" or not segment.text.strip(): # 跳过未知说话人或空文本
52
+ continue
53
+
54
+ if speaker not in speaker_dialogues:
55
+ speaker_dialogues[speaker] = []
56
+
57
+ if len(speaker_dialogues[speaker]) < max_samples_per_speaker:
58
+ text_sample = segment.text.strip()[:max_length_per_sample]
59
+ if len(segment.text.strip()) > max_length_per_sample:
60
+ text_sample += "..."
61
+ speaker_dialogues[speaker].append(text_sample)
62
+ return speaker_dialogues
63
+
64
+ def recognize_speaker_names(
65
+ self,
66
+ segments: List[EnhancedSegment],
67
+ podcast_info: Optional[PodcastChannel],
68
+ episode_info: Optional[PodcastEpisode],
69
+ max_shownotes_length: int = 1500,
70
+ max_desc_length: int = 500
71
+ ) -> Dict[str, str]:
72
+ """
73
+ 使用LLM根据转录分段和播客/剧集元数据识别说话人的真实姓名或昵称。
74
+
75
+ 参数:
76
+ segments: 转录后的 EnhancedSegment 列表。
77
+ podcast_info: 包含播客元数据的 PodcastChannel 对象。
78
+ episode_info: 包含单集播客元数据的 PodcastEpisode 对象。
79
+ max_shownotes_length: 用于Prompt的 Shownotes 最大字符数。
80
+ max_desc_length: 用于Prompt的播客描述最大字符数。
81
+
82
+ 返回:
83
+ 一个字典,键是原始的 "SPEAKER_XX",值是识别出的说话人名称。
84
+ """
85
+ unique_speaker_ids = sorted(list(set(seg.speaker for seg in segments if seg.speaker != "UNKNOWN" and seg.text.strip())))
86
+ if not unique_speaker_ids:
87
+ print("未能从 segments 中提取到有效的 speaker_ids。")
88
+ return {}
89
+
90
+ dialogue_samples = self._get_dialogue_samples(segments)
91
+
92
+ # 增加每个说话人的话语分析信息,包括话语频率和长度
93
+ speaker_stats = {}
94
+ for segment in segments:
95
+ speaker = segment.speaker
96
+ if speaker == "UNKNOWN" or not segment.text.strip():
97
+ continue
98
+
99
+ if speaker not in speaker_stats:
100
+ speaker_stats[speaker] = {
101
+ "total_segments": 0,
102
+ "total_chars": 0,
103
+ "avg_segment_length": 0,
104
+ "intro_likely": False # 是否有介绍性质的话语
105
+ }
106
+
107
+ speaker_stats[speaker]["total_segments"] += 1
108
+ speaker_stats[speaker]["total_chars"] += len(segment.text)
109
+
110
+ # 检测可能的自我介绍或他人介绍
111
+ lower_text = segment.text.lower()
112
+ intro_patterns = [
113
+ r'欢迎来到', r'欢迎收听', r'我是', r'我叫', r'大家好', r'今天的嘉宾是', r'我们请到了',
114
+ r'welcome to', r'i\'m your host', r'this is', r'today we have', r'joining us',
115
+ r'our guest', r'my name is'
116
+ ]
117
+ if any(re.search(pattern, lower_text) for pattern in intro_patterns):
118
+ speaker_stats[speaker]["intro_likely"] = True
119
+
120
+ # 计算平均话语长度
121
+ for speaker, stats in speaker_stats.items():
122
+ if stats["total_segments"] > 0:
123
+ stats["avg_segment_length"] = stats["total_chars"] / stats["total_segments"]
124
+
125
+ # 创建增强的说话人信息,包含统计数据
126
+ speaker_info_for_prompt = []
127
+ for speaker_id in unique_speaker_ids:
128
+ samples = dialogue_samples.get(speaker_id, ["(No dialogue samples available)"])
129
+ stats = speaker_stats.get(speaker_id, {"total_segments": 0, "avg_segment_length": 0, "intro_likely": False})
130
+
131
+ speaker_info_for_prompt.append({
132
+ "speaker_id": speaker_id,
133
+ "dialogue_samples": samples,
134
+ "speech_stats": {
135
+ "total_segments": stats["total_segments"],
136
+ "avg_segment_length": round(stats["avg_segment_length"], 2),
137
+ "has_intro_pattern": stats["intro_likely"]
138
+ }
139
+ })
140
+
141
+ # 安全地访问属性,提供默认值
142
+ podcast_title = podcast_info.title if podcast_info and podcast_info.title else "Unknown Podcast"
143
+ podcast_author = podcast_info.author if podcast_info and podcast_info.author else "Unknown"
144
+
145
+ raw_podcast_desc = podcast_info.description if podcast_info and podcast_info.description else ""
146
+ cleaned_podcast_desc = self._clean_html(raw_podcast_desc)
147
+ podcast_desc_for_prompt = cleaned_podcast_desc[:max_desc_length]
148
+ if len(cleaned_podcast_desc) > max_desc_length:
149
+ podcast_desc_for_prompt += "..."
150
+
151
+ episode_title = episode_info.title if episode_info and episode_info.title else "Unknown Episode"
152
+
153
+ raw_episode_summary = episode_info.summary if episode_info and episode_info.summary else ""
154
+ cleaned_episode_summary = self._clean_html(raw_episode_summary)
155
+ episode_summary_for_prompt = cleaned_episode_summary[:max_desc_length] # 使用与描述相同的长度限制
156
+ if len(cleaned_episode_summary) > max_desc_length:
157
+ episode_summary_for_prompt += "..."
158
+
159
+ raw_episode_shownotes = episode_info.shownotes if episode_info and episode_info.shownotes else ""
160
+ cleaned_episode_shownotes = self._clean_html(raw_episode_shownotes)
161
+ episode_shownotes_for_prompt = cleaned_episode_shownotes[:max_shownotes_length]
162
+ if len(cleaned_episode_shownotes) > max_shownotes_length:
163
+ episode_shownotes_for_prompt += "..."
164
+
165
+ system_prompt = """You are an experienced podcast content analyst. Your task is to accurately identify the real names, nicknames, or roles of different speakers (tagged in SPEAKER_XX format) in a podcast episode, based on the provided metadata, episode information, dialogue snippets, and speech patterns. Your analysis should NOT rely on the order of speakers or speaker IDs."""
166
+
167
+ user_prompt_template = f"""
168
+ Contextual Information:
169
+
170
+ 1. **Podcast Information**:
171
+ * Podcast Title: {podcast_title}
172
+ * Podcast Author/Producer: {podcast_author} (This information often points to the main host or production team)
173
+ * Podcast Description: {podcast_desc_for_prompt}
174
+
175
+ 2. **Current Episode Information**:
176
+ * Episode Title: {episode_title}
177
+ * Episode Summary: {episode_summary_for_prompt}
178
+ * Detailed Episode Notes (Shownotes):
179
+ ```text
180
+ {episode_shownotes_for_prompt}
181
+ ```
182
+ (Pay close attention to any host names, guest names, positions, or social media handles mentioned in the Shownotes.)
183
+
184
+ 3. **Speakers to Identify and Their Information**:
185
+ ```json
186
+ {json.dumps(speaker_info_for_prompt, ensure_ascii=False, indent=2)}
187
+ ```
188
+ (Analyze dialogue samples and speech statistics to understand speaker roles and identities. DO NOT use speaker IDs to determine roles - SPEAKER_00 is not necessarily the host.)
189
+
190
+ Task:
191
+ Based on all the information above, assign the most accurate name or role to each "speaker_id".
192
+
193
+ Analysis Guidance:
194
+ * A host typically has more frequent, shorter segments, often introduces the show or guests, and may mention the podcast name
195
+ * In panel discussion formats, there might be multiple hosts or co-hosts of similar speaking patterns
196
+ * In interview formats, the host typically asks questions while guests give longer answers
197
+ * Speakers who make introductory statements or welcome listeners are likely hosts
198
+ * Use dialogue content (not just speaking patterns) to identify names and roles
199
+
200
+ Output Requirements and Guidelines:
201
+ * Please return the result strictly in JSON format. The keys of the JSON object should be the original "speaker_id" (e.g., "SPEAKER_00"), and the values should be the identified person's name or role (string type).
202
+ * **Prioritize Specific Names/Nicknames**: If there is sufficient information (e.g., guests explicitly listed in Shownotes, or names mentioned in dialogue), please use the identified specific names, such as "John Doe", "AI Assistant", "Dr. Evelyn Reed". Do NOT append roles like "(Host)" or "(Guest)" if a specific name is found.
203
+ * **Host Identification**:
204
+ * Hosts may be identified by analyzing speech patterns - they often speak more frequently in shorter segments
205
+ * Look for introduction patterns in dialogue where speakers welcome listeners or introduce the show
206
+ * The podcast author (if provided and credible) is often a host but verify through dialogue
207
+ * There may be multiple hosts (co-hosts) in panel-style podcasts
208
+ * If a host's name is identified, use the identified name directly (e.g., "Lex Fridman"). Do not append "(Host)".
209
+ * If the host's name cannot be determined but the role is clearly a host, use "Podcast Host".
210
+ * **Guest Identification**:
211
+ * Guests often give longer responses and speak less frequently than hosts
212
+ * For other non-host speakers, if a specific name is identified, use the identified name directly (e.g., "John Carmack"). Do not append "(Guest)".
213
+ * If specific names cannot be identified for guests, label them sequentially as "Guest 1", "Guest 2", etc.
214
+ * **Handling Multiple Hosts/Guests**: If there are multiple hosts or guests and they can be distinguished by name, use their names. If you cannot distinguish specific identities but know there are multiple hosts, use "Host 1", "Host 2", etc. Similarly for guests without specific names, use "Guest 1", "Guest 2".
215
+ * **Ensure Completeness**: The returned JSON object must include all "speaker_id"s listed in the input as keys.
216
+
217
+ JSON Output Example:
218
+ ```json
219
+ {{
220
+ "SPEAKER_00": "Jane Smith",
221
+ "SPEAKER_01": "Podcast Host",
222
+ "SPEAKER_02": "Alex Green"
223
+ }}
224
+ ```
225
+ Note that in this example, SPEAKER_01 is identified as the host, not SPEAKER_00, based on content analysis, not ID order.
226
+
227
+ Please begin your analysis and provide the JSON result.
228
+ """
229
+
230
+ messages = [
231
+ {"role": "system", "content": system_prompt},
232
+ {"role": "user", "content": user_prompt_template}
233
+ ]
234
+
235
+ # 预设默认映射,使用更智能的启发式方法而不是简单依赖顺序
236
+ final_map = {}
237
+
238
+ # 尝试使用说话模式启发式方法来初步识别角色
239
+ # 1. 说话次数最多的可能是主持人
240
+ # 2. 有介绍性话语的可能是主持人
241
+ # 3. 其他角色先标记为嘉宾
242
+
243
+ host_candidates = []
244
+ for speaker_id, stats in speaker_stats.items():
245
+ if stats["intro_likely"]:
246
+ host_candidates.append((speaker_id, True, stats["total_segments"])) # 有介绍性话语,优先视为主持人
247
+ else:
248
+ # 按说话次数排序
249
+ host_candidates.append((speaker_id, False, stats["total_segments"]))
250
+
251
+ # 按可能性排序(介绍性话语 > 说话次数)
252
+ host_candidates.sort(key=lambda x: (x[1], x[2]), reverse=True)
253
+
254
+ if host_candidates:
255
+ # 最可能的主持人
256
+ host_id = host_candidates[0][0]
257
+ final_map[host_id] = "Podcast Host"
258
+
259
+ # 其他人先标为嘉宾
260
+ guest_counter = 1
261
+ for speaker_id in unique_speaker_ids:
262
+ if speaker_id != host_id:
263
+ final_map[speaker_id] = f"Guest {guest_counter}"
264
+ guest_counter += 1
265
+ else:
266
+ # 如果没有明显线索,使用传统的顺序方法作为备选
267
+ is_host_assigned = False
268
+ guest_counter = 1
269
+ for speaker_id in unique_speaker_ids:
270
+ if not is_host_assigned:
271
+ final_map[speaker_id] = "Podcast Host"
272
+ is_host_assigned = True
273
+ else:
274
+ final_map[speaker_id] = f"Guest {guest_counter}"
275
+ guest_counter += 1
276
+
277
+ try:
278
+ response = llm_router.chat_completion(
279
+ messages=messages,
280
+ provider=self.llm_provider,
281
+ model=self.llm_model_name,
282
+ temperature=0.1,
283
+ max_tokens=1024
284
+ )
285
+ assistant_response_content = response["choices"][0]["message"]["content"]
286
+
287
+ parsed_llm_output = None
288
+ # 尝试从Markdown代码块中提取JSON
289
+ json_match = re.search(r'```json\s*(\{.*?\})\s*```', assistant_response_content, re.DOTALL)
290
+ if json_match:
291
+ json_str = json_match.group(1)
292
+ else:
293
+ # 如果没有markdown块,尝试找到第一个 '{' 到最后一个 '}'
294
+ first_brace = assistant_response_content.find('{')
295
+ last_brace = assistant_response_content.rfind('}')
296
+ if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
297
+ json_str = assistant_response_content[first_brace : last_brace+1]
298
+ else: # 如果还是找不到,就认为整个回复都是JSON(可能需要更复杂的清理)
299
+ json_str = assistant_response_content.strip()
300
+
301
+ try:
302
+ parsed_llm_output = json.loads(json_str)
303
+ if not isinstance(parsed_llm_output, dict): # 确保解析出来是字典
304
+ print(f"LLM返回的JSON不是一个字典: {parsed_llm_output}")
305
+ parsed_llm_output = None # 重置,以便使用默认值
306
+ except json.JSONDecodeError as e:
307
+ print(f"LLM返回的JSON解析失败: {e}")
308
+ print(f"用于解析的字符串: '{json_str}'")
309
+ # parsed_llm_output 保持为 None,将使用默认值
310
+
311
+ if parsed_llm_output:
312
+ # 直接使用LLM的有效输出,不再依赖预设的角色分配逻辑
313
+ final_map = {}
314
+ unknown_counter = 1
315
+
316
+ # 先处理LLM识别出的角色
317
+ for spk_id in unique_speaker_ids:
318
+ if spk_id in parsed_llm_output and isinstance(parsed_llm_output[spk_id], str) and parsed_llm_output[spk_id].strip():
319
+ final_map[spk_id] = parsed_llm_output[spk_id].strip()
320
+ else:
321
+ # 如果LLM没有给出特定ID的结果,使用"Unknown Speaker"
322
+ final_map[spk_id] = f"Unknown Speaker {unknown_counter}"
323
+ unknown_counter += 1
324
+
325
+ # 检查是否有"Host"或"主持人"标识
326
+ has_host = any("主持人" in name or "Host" in name for name in final_map.values())
327
+
328
+ # 如果没有任何主持人标识,且存在"Unknown Speaker",可以考虑将最活跃的未知说话人设为主持人
329
+ if not has_host and any("Unknown Speaker" in name for name in final_map.values()):
330
+ # 找出最活跃的未知说话人
331
+ most_active_unknown = None
332
+ max_segments = 0
333
+
334
+ for spk_id, name in final_map.items():
335
+ if "Unknown Speaker" in name and spk_id in speaker_stats:
336
+ if speaker_stats[spk_id]["total_segments"] > max_segments:
337
+ max_segments = speaker_stats[spk_id]["total_segments"]
338
+ most_active_unknown = spk_id
339
+
340
+ if most_active_unknown:
341
+ final_map[most_active_unknown] = "Podcast Host"
342
+
343
+ return final_map
344
+
345
+ except Exception as e:
346
+ import traceback
347
+ print(f"调用LLM或处理响应时发生严重错误: {e}")
348
+ print(traceback.format_exc())
349
+ # 发生任何严重错误,返回初始的启发式映射
350
+ return final_map
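A minimal usage sketch for the identifier above (the segments, podcast metadata, and model/provider names are illustrative; the chosen provider must be available locally):

from podcast_transcribe.schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
from podcast_transcribe.summary.speaker_identify import SpeakerIdentifier

segments = [
    EnhancedSegment(start=0.0, end=4.0, text="Welcome to the show, I'm your host.",
                    speaker="SPEAKER_00", language="en"),
    EnhancedSegment(start=4.0, end=9.0, text="Thanks for having me.",
                    speaker="SPEAKER_01", language="en"),
]
channel = PodcastChannel(title="Example Podcast", author="Jane Smith")
episode = PodcastEpisode(title="Episode 1", shownotes="<p>Guest: John Doe</p>")

identifier = SpeakerIdentifier(llm_model_name="google/gemma-3-12b-it",
                               llm_provider="gemma-transformers")
name_map = identifier.recognize_speaker_names(segments, channel, episode)
print(name_map)  # e.g. {"SPEAKER_00": "Jane Smith", "SPEAKER_01": "John Doe"}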
src/podcast_transcribe/transcriber.py ADDED
@@ -0,0 +1,588 @@
1
+ """
2
+ 整合ASR和说话人分离的转录器模块(当前仅支持非流式处理)
3
+ """
4
+
5
+ import os
6
+ from pydub import AudioSegment
7
+ from typing import Dict, List, Union, Optional, Any
8
+ import logging
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ import re
11
+
12
+ from .summary.speaker_identify import SpeakerIdentifier # 新增导入
13
+
14
+ # 导入ASR和说话人分离模块,使用相对导入
15
+ from .asr import asr_router
16
+ from .asr.asr_base import TranscriptionResult
17
+ from .diarization import diarizer_router
18
+ from .schemas import EnhancedSegment, CombinedTranscriptionResult, PodcastChannel, PodcastEpisode, DiarizationResult
19
+
20
+ # 配置日志
21
+ logger = logging.getLogger("podcast_transcribe")
22
+
23
+ class CombinedTranscriber:
24
+ """整合ASR和说话人分离的转录器"""
25
+
26
+ def __init__(
27
+ self,
28
+ asr_model_name: str,
29
+ asr_provider: str,
30
+ diarization_provider: str,
31
+ diarization_model_name: str,
32
+ llm_model_name: Optional[str] = None,
33
+ llm_provider: Optional[str] = None,
34
+ hf_token: Optional[str] = None,
35
+ device: Optional[str] = None,
36
+ segmentation_batch_size: int = 64,
37
+ parallel: bool = False,
38
+ ):
39
+ """
40
+ 初始化转录器
41
+
42
+ 参数:
43
+ asr_model_name: ASR模型名称
44
+ asr_provider: ASR提供者名称
45
+ diarization_provider: 说话人分离提供者名称
46
+ diarization_model_name: 说话人分离模型名称
47
+ hf_token: Hugging Face令牌
48
+ device: 推理设备,'cpu'或'cuda'
49
+ segmentation_batch_size: 分割批处理大小,默认为64
50
+ parallel: 是否并行执行ASR和说话人分离,默认为False
51
+ """
52
+ if not device:
53
+ import torch
54
+ if torch.backends.mps.is_available():
55
+ device = "mps"
56
+ if not llm_model_name:
57
+ llm_model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
58
+ if not llm_provider:
59
+ llm_provider = "gemma-mlx"
60
+
61
+ elif torch.cuda.is_available():
62
+ device = "cuda"
63
+ if not llm_model_name:
64
+ llm_model_name = "google/gemma-3-12b-it"
65
+ if not llm_provider:
66
+ llm_provider = "gemma-transformers"
67
+ else:
68
+ device = "cpu"
69
+ if not llm_model_name:
70
+ llm_model_name = "google/gemma-3-12b-it"
71
+ if not llm_provider:
72
+ llm_provider = "gemma-transformers"
73
+
74
+ self.asr_model_name = asr_model_name
75
+ self.asr_provider = asr_provider
76
+ self.diarization_provider = diarization_provider
77
+ self.diarization_model_name = diarization_model_name
78
+ self.hf_token = hf_token or os.environ.get("HF_TOKEN")
79
+ self.device = device
80
+ self.segmentation_batch_size = segmentation_batch_size
81
+ self.parallel = parallel
82
+
83
+ self.speaker_identifier = SpeakerIdentifier(
84
+ llm_model_name=llm_model_name,
85
+ llm_provider=llm_provider
86
+ )
87
+
88
+ logger.info(f"初始化组合转录器,ASR提供者: {asr_provider},ASR模型: {asr_model_name},分离提供者: {diarization_provider},分离模型: {diarization_model_name},分割批处理大小: {segmentation_batch_size},并行执行: {parallel},推理设备: {device}")
89
+
90
+
91
+ def _merge_adjacent_text_segments(self, segments: List[EnhancedSegment]) -> List[EnhancedSegment]:
92
+ """
93
+ 合并相邻的、可能属于同一句子的 EnhancedSegment。
94
+ 合并条件:同一说话人,时间基本连续,文本内容可拼接。
95
+ """
96
+ if not segments:
97
+ return []
98
+
99
+ merged_segments: List[EnhancedSegment] = []
100
+ if not segments: # 重复检查,可移除
101
+ return merged_segments
102
+
103
+ current_merged_segment = segments[0]
104
+
105
+ for i in range(1, len(segments)):
106
+ next_segment = segments[i]
107
+
108
+ time_gap_seconds = next_segment.start - current_merged_segment.end
109
+
110
+ can_merge_text = False
111
+ if current_merged_segment.text and next_segment.text:
112
+ current_text_stripped = current_merged_segment.text.strip()
113
+ if current_text_stripped and not current_text_stripped[-1] in ".。?!?!":
114
+ can_merge_text = True
115
+
116
+ if (current_merged_segment.speaker == next_segment.speaker and
117
+ 0 <= time_gap_seconds < 0.75 and
118
+ can_merge_text):
119
+ current_merged_segment = EnhancedSegment(
120
+ start=current_merged_segment.start,
121
+ end=next_segment.end,
122
+ text=(current_merged_segment.text.strip() + " " + next_segment.text.strip()).strip(),
123
+ speaker=current_merged_segment.speaker,
124
+ language=current_merged_segment.language
125
+ )
126
+ else:
127
+ merged_segments.append(current_merged_segment)
128
+ current_merged_segment = next_segment
129
+
130
+ merged_segments.append(current_merged_segment)
131
+
132
+ return merged_segments
133
+
134
+ def _run_asr(self, audio: AudioSegment) -> TranscriptionResult:
135
+ """执行ASR处理"""
136
+ logger.debug("执行ASR...")
137
+ return asr_router.transcribe_audio(
138
+ audio,
139
+ provider=self.asr_provider,
140
+ model_name=self.asr_model_name,
141
+ device=self.device
142
+ )
143
+
144
+ def _run_diarization(self, audio: AudioSegment) -> DiarizationResult:
145
+ """执行说话人分离处理"""
146
+ logger.debug("执行说话人分离...")
147
+ return diarizer_router.diarize_audio(
148
+ audio,
149
+ provider=self.diarization_provider,
150
+ model_name=self.diarization_model_name,
151
+ token=self.hf_token,
152
+ device=self.device,
153
+ segmentation_batch_size=self.segmentation_batch_size
154
+ )
155
+
156
+ def transcribe(self, audio: AudioSegment) -> CombinedTranscriptionResult:
157
+ """
158
+ 转录整个音频 (新的非流式逻辑将在这里实现)
159
+
160
+ 参数:
161
+ audio: 要转录的AudioSegment对象
162
+
163
+ 返回:
164
+ 包含完整转录和说话人信息的结果
165
+ """
166
+ logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频 (非流式)")
167
+
168
+ if self.parallel:
169
+ # 并行执行ASR和说话人分离
170
+ logger.info("并行执行ASR和说话人分离")
171
+ with ThreadPoolExecutor(max_workers=2) as executor:
172
+ asr_future = executor.submit(self._run_asr, audio)
173
+ diarization_future = executor.submit(self._run_diarization, audio)
174
+
175
+ asr_result: TranscriptionResult = asr_future.result()
176
+ diarization_result: DiarizationResult = diarization_future.result()
177
+
178
+ logger.debug(f"ASR完成,识别语言: {asr_result.language},得到 {len(asr_result.segments)} 个分段")
179
+ logger.debug(f"说话人分离完成,得到 {len(diarization_result.segments)} 个说话人分段,检测到 {diarization_result.num_speakers} 个说话人")
180
+ else:
181
+ # 顺序执行ASR和说话人分离
182
+ # 步骤1: 对整个音频执行ASR
183
+ logger.debug("执行ASR...")
184
+ asr_result: TranscriptionResult = asr_router.transcribe_audio(
185
+ audio,
186
+ provider=self.asr_provider,
187
+ model_name=self.asr_model_name,
188
+ device=self.device
189
+ )
190
+ logger.debug(f"ASR完成,识别语言: {asr_result.language},得到 {len(asr_result.segments)} 个分段")
191
+
192
+ # 步骤2: 对整个音频执行说话人分离
193
+ logger.debug("执行说话人分离...")
194
+ diarization_result: DiarizationResult = diarizer_router.diarize_audio(
195
+ audio,
196
+ provider=self.diarization_provider,
197
+ model_name=self.diarization_model_name,
198
+ token=self.hf_token,
199
+ device=self.device,
200
+ segmentation_batch_size=self.segmentation_batch_size
201
+ )
202
+ logger.debug(f"说话人分离完成,得到 {len(diarization_result.segments)} 个说话人分段,检测到 {diarization_result.num_speakers} 个说话人")
203
+
204
+ # 步骤3: 创建增强分段
205
+ all_enhanced_segments: List[EnhancedSegment] = self._create_enhanced_segments_with_splitting(
206
+ asr_result.segments,
207
+ diarization_result.segments,
208
+ asr_result.language
209
+ )
210
+
211
+ # 步骤4: (可选)合并相邻的文本分段
212
+ if all_enhanced_segments:
213
+ logger.debug(f"合并前有 {len(all_enhanced_segments)} 个增强分段,尝试合并相邻分段...")
214
+ final_segments = self._merge_adjacent_text_segments(all_enhanced_segments)
215
+ logger.debug(f"合并后有 {len(final_segments)} 个增强分段")
216
+ else:
217
+ final_segments = []
218
+ logger.debug("没有增强分段可供合并。")
219
+
220
+ # 整理合并的文本
221
+ full_text = " ".join([segment.text for segment in final_segments]).strip()
222
+
223
+ # 计算最终说话人数
224
+ num_speakers_set = set(s.speaker for s in final_segments if s.speaker != "UNKNOWN")
225
+
226
+ return CombinedTranscriptionResult(
227
+ segments=final_segments,
228
+ text=full_text,
229
+ language=asr_result.language or "unknown",
230
+ num_speakers=len(num_speakers_set) if num_speakers_set else diarization_result.num_speakers
231
+ )
232
+
233
+ # 新方法:根据标点分割ASR文本片段
234
+ def _split_asr_segment_by_punctuation(
235
+ self,
236
+ asr_seg_text: str,
237
+ asr_seg_start: float,
238
+ asr_seg_end: float
239
+ ) -> List[Dict[str, Any]]:
240
+ """
241
+ 根据标点符号分割ASR文本片段,并按字符比例估算子片段的时间戳。
242
+ 返回: 字典列表,每个字典包含 'text', 'start', 'end'。
243
+ """
244
+ sentence_terminators = ".。?!?!;;"
245
+ # 正则表达式:匹配句子内容以及紧随其后的标点(如果存在)
246
+ # 使用 re.split 保留分隔符,然后重组
247
+ parts = re.split(f'([{sentence_terminators}])', asr_seg_text)
248
+
249
+ sub_texts_final = []
250
+ current_s = ""
251
+ for s_part in parts:
252
+ if not s_part:
253
+ continue
254
+ current_s += s_part
255
+ if s_part in sentence_terminators:
256
+ if current_s.strip():
257
+ sub_texts_final.append(current_s.strip())
258
+ current_s = ""
259
+ if current_s.strip():
260
+ sub_texts_final.append(current_s.strip())
261
+
262
+ if not sub_texts_final or (len(sub_texts_final) == 1 and sub_texts_final[0] == asr_seg_text.strip()):
263
+ # 没有有效分割或分割后只有一个句子(等于原始文本)
264
+ return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
265
+
266
+ output_sub_segments = []
267
+ total_text_len = len(asr_seg_text) # 使用原始文本长度进行比例计算
268
+ if total_text_len == 0:
269
+ return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
270
+
271
+ current_time = asr_seg_start
272
+ original_duration = asr_seg_end - asr_seg_start
273
+
274
+ for i, sub_text in enumerate(sub_texts_final):
275
+ sub_len = len(sub_text)
276
+ sub_duration = (sub_len / total_text_len) * original_duration
277
+
278
+ sub_start_time = current_time
279
+ sub_end_time = current_time + sub_duration
280
+
281
+ # 对于最后一个分片,确保其结束时间与原始分段的结束时间一致,以避免累积误差
282
+ if i == len(sub_texts_final) - 1:
283
+ sub_end_time = asr_seg_end
284
+
285
+ # 确保结束时间不超过原始结束时间,并且开始时间不晚于结束时间
286
+ sub_end_time = min(sub_end_time, asr_seg_end)
287
+ if sub_start_time >= sub_end_time and sub_start_time == asr_seg_end: # 如果开始等于原始结束,允许微小片段
288
+ if sub_text: # 仅当有文本时
289
+ output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
290
+ elif sub_start_time < sub_end_time:
291
+ output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
292
+
293
+ current_time = sub_end_time
294
+ if current_time >= asr_seg_end and i < len(sub_texts_final) -1: # 如果时间已用完,但还有句子
295
+ # 将剩余句子附加到最后一个有效的时间段,或创建零长度的段
296
+ logger.warning(f"时间已在分割过程中用尽,但仍有文本未分配时间。原始段: [{asr_seg_start}-{asr_seg_end}], 当前子句: '{sub_text}'")
297
+ # 为后续未分配时间的文本创建零时长或极短时长的片段,附着在末尾
298
+ for k in range(i + 1, len(sub_texts_final)):
299
+ remaining_text = sub_texts_final[k]
300
+ if remaining_text:
301
+ output_sub_segments.append({"text": remaining_text, "start": asr_seg_end, "end": asr_seg_end})
302
+ break
303
+
304
+
305
+ # 如果处理后没有任何子分段(例如原始文本为空,或分割逻辑问题),返回原始信息作为一个分段
306
+ if not output_sub_segments and asr_seg_text.strip():
307
+ return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
308
+ elif not output_sub_segments and not asr_seg_text.strip():
309
+ return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
310
+
311
+
312
+ return output_sub_segments
313
+
314
+ # 新的核心方法:创建增强分段,包含说话人分配和按需分裂逻辑
315
+ def _create_enhanced_segments_with_splitting(
316
+ self,
317
+ asr_segments: List[Dict[str, Union[float, str]]],
318
+ diarization_segments: List[Dict[str, Union[float, str, int]]],
319
+ language: str
320
+ ) -> List[EnhancedSegment]:
321
+ """
322
+ 为ASR分段分配说话人,如果ASR分段跨越多个说话人,则尝试按标点分裂。
323
+ """
324
+ final_enhanced_segments: List[EnhancedSegment] = []
325
+
326
+ if not asr_segments:
327
+ return []
328
+
329
+ # 为了快速查找,可以预处理 diarization_segments,但对于数量不多的情况,直接遍历也可
330
+ # diarization_segments.sort(key=lambda x: x['start']) # 确保有序
331
+
332
+ for asr_seg in asr_segments:
333
+ asr_start = float(asr_seg["start"])
334
+ asr_end = float(asr_seg["end"])
335
+ asr_text = str(asr_seg["text"]).strip()
336
+
337
+ if not asr_text or asr_start >= asr_end: # 跳过无效的ASR分段
338
+ continue
339
+
340
+ # 找出与当前ASR分段在时间上重叠的所有说话人分段
341
+ overlapping_diar_segs = []
342
+ for diar_seg in diarization_segments:
343
+ diar_start = float(diar_seg["start"])
344
+ diar_end = float(diar_seg["end"])
345
+
346
+ overlap_start = max(asr_start, diar_start)
347
+ overlap_end = min(asr_end, diar_end)
348
+
349
+ if overlap_end > overlap_start: # 有重叠
350
+ overlapping_diar_segs.append({
351
+ "speaker": str(diar_seg["speaker"]),
352
+ "start": diar_start,
353
+ "end": diar_end,
354
+ "overlap_duration": overlap_end - overlap_start
355
+ })
356
+
357
+ distinct_speakers_in_overlap = set(d['speaker'] for d in overlapping_diar_segs)
358
+
359
+ segments_to_process_further: List[Dict[str, Any]] = []
360
+
361
+ if len(distinct_speakers_in_overlap) > 1:
362
+ logger.debug(f"ASR段 [{asr_start:.2f}-{asr_end:.2f}] \"{asr_text[:50]}...\" 跨越 {len(distinct_speakers_in_overlap)} 个说话人。尝试按标点分裂。")
363
+ # 跨多个说话人,尝试按标点分裂ASR segment
364
+ sub_asr_segments_data = self._split_asr_segment_by_punctuation(
365
+ asr_text,
366
+ asr_start,
367
+ asr_end
368
+ )
369
+ if len(sub_asr_segments_data) > 1:
370
+ logger.debug(f"成功将ASR段分裂成 {len(sub_asr_segments_data)} 个子句。")
371
+ segments_to_process_further.extend(sub_asr_segments_data)
372
+ else:
373
+ # 单一说话人或无说话人重叠(也视为单一处理单位)
374
+ segments_to_process_further.append({"text": asr_text, "start": asr_start, "end": asr_end})
375
+
376
+ # 为每个原始或分裂后的ASR(子)分段分配说话人
377
+ for current_proc_seg_data in segments_to_process_further:
378
+ proc_text = current_proc_seg_data["text"].strip()
379
+ proc_start = current_proc_seg_data["start"]
380
+ proc_end = current_proc_seg_data["end"]
381
+
382
+ if not proc_text or proc_start >= proc_end: # 跳过无效的子分段
383
+ continue
384
+
385
+ # 为当前处理的(可能是子)分段确定最佳说话人
386
+ speaker_overlaps_for_proc_seg = {}
387
+ for diar_seg_info in overlapping_diar_segs: # 使用之前计算的、与原始ASR段重叠的diar_segs
388
+ # 现在需要计算这个 diar_seg_info 与 proc_seg 的重叠
389
+ overlap_start = max(proc_start, diar_seg_info["start"])
390
+ overlap_end = min(proc_end, diar_seg_info["end"])
391
+
392
+ if overlap_end > overlap_start:
393
+ overlap_duration = overlap_end - overlap_start
394
+ speaker = diar_seg_info["speaker"]
395
+ speaker_overlaps_for_proc_seg[speaker] = \
396
+ speaker_overlaps_for_proc_seg.get(speaker, 0) + overlap_duration
397
+
398
+ best_speaker = "UNKNOWN"
399
+ if speaker_overlaps_for_proc_seg:
400
+ best_speaker = max(speaker_overlaps_for_proc_seg.items(), key=lambda x: x[1])[0]
401
+ elif overlapping_diar_segs: # 如果子分段本身没有重叠,但原始ASR段有
402
+ # 可以选择原始ASR段中占比最大的,或者最近的
403
+ # 为简化,如果子分段无直接重叠,也可能标记为UNKNOWN,或尝试找最近的
404
+ # 这里采用:如果子分段无直接重叠,但在原始ASR段中有说话人,则使用原始ASR段中重叠最长的
405
+ # (此逻辑分支效果待观察,更简单的是直接UNKNOWN)
406
+ # 此处简化:若子分段无重叠,则为UNKNOWN
407
+ pass # best_speaker 默认为 UNKNOWN
408
+
409
+ # 如果 best_speaker 仍为 UNKNOWN,但原始ASR段只有一个说话者,则使用该说话者
410
+ if best_speaker == "UNKNOWN" and len(distinct_speakers_in_overlap) == 1:
411
+ best_speaker = list(distinct_speakers_in_overlap)[0]
412
+ elif best_speaker == "UNKNOWN" and not overlapping_diar_segs:
413
+ # 如果整个ASR段都没有任何说话人信息,则确实是UNKNOWN
414
+ pass
415
+
416
+
417
+ final_enhanced_segments.append(
418
+ EnhancedSegment(
419
+ start=proc_start,
420
+ end=proc_end,
421
+ text=proc_text,
422
+ speaker=best_speaker,
423
+ language=language # 所有子分段继承原始ASR段的语言
424
+ )
425
+ )
426
+
427
+ # 对最终结果按开始时间排序
428
+ final_enhanced_segments.sort(key=lambda seg: seg.start)
429
+ return final_enhanced_segments
430
+
431
+ def transcribe_podcast(
432
+ self,
433
+ audio: AudioSegment,
434
+ podcast_info: PodcastChannel,
435
+ episode_info: PodcastEpisode,
436
+ ) -> CombinedTranscriptionResult:
437
+ """
438
+ 专门针对播客剧集的音频转录方法
439
+
440
+ 参数:
441
+ audio: 要转录的AudioSegment对象
442
+ podcast_info: 播客频道信息
443
+ episode_info: 播客剧集信息
444
+
445
+ 返回:
446
+ 包含完整转录和识别后说话人名称的结果
447
+ """
448
+ logger.info(f"开始转录播客剧集 {len(audio)/1000:.2f} 秒的音频")
449
+
450
+ # 1. 先执行基础转录流程
451
+ transcription_result = self.transcribe(audio)
452
+
453
+ # 3. 识别说话人名称
454
+ logger.info("识别说话人名称...")
455
+ speaker_name_map = self.speaker_identifier.recognize_speaker_names(
456
+ transcription_result.segments,
457
+ podcast_info,
458
+ episode_info
459
+ )
460
+
461
+ # 4. 将识别的说话人名称添加到转录结果中
462
+ enhanced_segments_with_names = []
463
+ for segment in transcription_result.segments:
464
+ # 复制原始段落并添加说话人名称
465
+ speaker_id = segment.speaker
466
+ speaker_name = speaker_name_map.get(speaker_id, None)
467
+
468
+ # 创建新的段落对象,包含说话人名称
469
+ new_segment = EnhancedSegment(
470
+ start=segment.start,
471
+ end=segment.end,
472
+ text=segment.text,
473
+ speaker=speaker_id,
474
+ language=segment.language,
475
+ speaker_name=speaker_name
476
+ )
477
+ enhanced_segments_with_names.append(new_segment)
478
+
479
+ # 5. 创建并返回新的转录结果
480
+ return CombinedTranscriptionResult(
481
+ segments=enhanced_segments_with_names,
482
+ text=transcription_result.text,
483
+ language=transcription_result.language,
484
+ num_speakers=transcription_result.num_speakers
485
+ )
486
+
487
+
488
+ def transcribe_audio(
489
+ audio_segment: AudioSegment,
490
+ asr_model_name: str = "distil-whisper/distil-large-v3.5",
491
+ asr_provider: str = "distil_whisper_transformers",
492
+ diarization_model_name: str = "pyannote/speaker-diarization-3.1",
493
+ diarization_provider: str = "pyannote_transformers",
494
+ hf_token: Optional[str] = None,
495
+ device: Optional[str] = None,
496
+ segmentation_batch_size: int = 64,
497
+ parallel: bool = False,
498
+ ) -> CombinedTranscriptionResult: # 返回类型固定为 CombinedTranscriptionResult
499
+ """
500
+ 整合ASR和说话人分离的音频转录函数 (仅支持非流式)
501
+
502
+ 参数:
503
+ audio_segment: 输入的AudioSegment对象
504
+ asr_model_name: ASR模型名称
505
+ asr_provider: ASR提供者名称
506
+ diarization_model_name: 说话人分离模型名称
507
+ diarization_provider: 说话人分离提供者名称
508
+ hf_token: Hugging Face令牌
509
+ device: 推理设备,'cpu'或'cuda'
510
+ segmentation_batch_size: 分割批处理大小,默认为64
511
+ parallel: 是否并行执行ASR和说话人分离,默认为False
512
+
513
+ 返回:
514
+ 完整转录结果
515
+ """
516
+ logger.info(f"调用transcribe_audio函数 (非流式),音频长度: {len(audio_segment)/1000:.2f}秒")
517
+
518
+ transcriber = CombinedTranscriber(
519
+ asr_model_name=asr_model_name,
520
+ asr_provider=asr_provider,
521
+ diarization_model_name=diarization_model_name,
522
+ diarization_provider=diarization_provider,
523
+ hf_token=hf_token,
524
+ device=device,
525
+ segmentation_batch_size=segmentation_batch_size,
526
+ parallel=parallel
527
+ )
528
+
529
+ # 直接调用 transcribe 方法
530
+ return transcriber.transcribe(audio_segment)
531
+
532
+ def transcribe_podcast_audio(
533
+ audio_segment: AudioSegment,
534
+ podcast_info: PodcastChannel,
535
+ episode_info: PodcastEpisode,
536
+ asr_model_name: str = "distil-whisper/distil-large-v3.5",
537
+ asr_provider: str = "distil_whisper_transformers",
538
+ diarization_model_name: str = "pyannote/speaker-diarization-3.1",
539
+ diarization_provider: str = "pyannote_transformers",
540
+ llm_model_name: Optional[str] = None,
541
+ llm_provider: Optional[str] = None,
542
+ hf_token: Optional[str] = None,
543
+ device: Optional[str] = None,
544
+ segmentation_batch_size: int = 64,
545
+ parallel: bool = False,
546
+ ) -> CombinedTranscriptionResult:
547
+ """
548
+ 针对播客剧集的音频转录函数,包含说话人名称识别
549
+
550
+ 参数:
551
+ audio_segment: 输入的AudioSegment对象
552
+ podcast_info: 播客频道信息
553
+ episode_info: 播客剧集信息
554
+ asr_model_name: ASR模型名称
555
+ asr_provider: ASR提供者名称
556
+ diarization_provider: 说话人分离提供者名称
557
+ diarization_model_name: 说话人分离模型名称
558
+ llm_model_name: LLM模型名称,如果为None则根据运行设备自动选择默认模型
559
+ llm_provider: LLM提供者名称,如果为None则根据运行设备自动选择默认提供者
560
+ hf_token: Hugging Face令牌
561
+ device: 推理设备,'cpu'或'cuda'
562
+ segmentation_batch_size: 分割批处理大小,默认为64
563
+ parallel: 是否并行执行ASR和说话人分离,默认为False
564
+
565
+ 返回:
566
+ 包含说话人名称的完整转录结果
567
+ """
568
+ logger.info(f"调用transcribe_podcast_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
569
+
570
+ transcriber = CombinedTranscriber(
571
+ asr_model_name=asr_model_name,
572
+ asr_provider=asr_provider,
573
+ diarization_provider=diarization_provider,
574
+ diarization_model_name=diarization_model_name,
575
+ llm_model_name=llm_model_name,
576
+ llm_provider=llm_provider,
577
+ hf_token=hf_token,
578
+ device=device,
579
+ segmentation_batch_size=segmentation_batch_size,
580
+ parallel=parallel
581
+ )
582
+
583
+ # 调用播客专用转录方法
584
+ return transcriber.transcribe_podcast(
585
+ audio=audio_segment,
586
+ podcast_info=podcast_info,
587
+ episode_info=episode_info,
588
+ )
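A minimal usage sketch for the convenience functions above (the audio path and token are placeholders; the ASR and diarization models fall back to the defaults shown in the signatures):

import os
from pydub import AudioSegment
from podcast_transcribe.transcriber import transcribe_audio

audio = AudioSegment.from_file("episode.mp3")          # placeholder path
result = transcribe_audio(
    audio,
    hf_token=os.environ.get("HF_TOKEN"),               # needed for the pyannote diarization models
    parallel=True,                                     # run ASR and diarization concurrently
)
print(result.language, result.num_speakers)
for seg in result.segments[:5]:
    print(f"[{seg.start:.1f}-{seg.end:.1f}] {seg.speaker}: {seg.text}")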
src/podcast_transcribe/webui/app.py ADDED
@@ -0,0 +1,585 @@
1
+ import gradio as gr
2
+ import requests
3
+ import io
4
+ from pydub import AudioSegment
5
+ import traceback # 用于打印更详细的错误信息
6
+ import tempfile
7
+ import os
8
+ import uuid
9
+ import atexit
10
+ import shutil
11
+
+ # Try the package import first; this works when running via `python -m src.podcast_transcribe.webui.app`
+ try:
+     from podcast_transcribe.rss.podcast_rss_parser import parse_podcast_rss
+     from podcast_transcribe.schemas import PodcastChannel, PodcastEpisode, CombinedTranscriptionResult, EnhancedSegment
+     from podcast_transcribe.transcriber import transcribe_podcast_audio
+ except ImportError:
+     # If this script is run directly and the project root is not on PYTHONPATH,
+     # the project root has to be added to sys.path manually.
+     import sys
+     import os
+     # Directory containing this script (src/podcast_transcribe/webui)
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     # Project root (three levels up: webui -> podcast_transcribe -> src -> project_root;
+     # the parent directory of src is the project root)
+     project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
+     # Add the src directory to sys.path, since modules are imported as podcast_transcribe.xxx
+     src_path = os.path.join(project_root, "src")
+     if src_path not in sys.path:
+         sys.path.insert(0, src_path)
+
+     from podcast_transcribe.rss.podcast_rss_parser import parse_podcast_rss
+     from podcast_transcribe.schemas import PodcastChannel, PodcastEpisode, CombinedTranscriptionResult, EnhancedSegment
+     from podcast_transcribe.transcriber import transcribe_podcast_audio
+
+ # Paths of all temporary files used by the application
+ temp_files = []
+
+ def cleanup_temp_files():
+     """Clean up the temporary files used by the application."""
+     global temp_files
+     print(f"Application exiting, cleaning up {len(temp_files)} temporary file(s)...")
+
+     for filepath in temp_files:
+         try:
+             if os.path.exists(filepath):
+                 os.remove(filepath)
+                 print(f"Deleted temporary file: {filepath}")
+         except Exception as e:
+             print(f"Could not delete temporary file {filepath}: {e}")
+
+     # Clear the list
+     temp_files = []
+
+ # Register the cleanup function to run on application exit
+ atexit.register(cleanup_temp_files)
+
+ def parse_rss_feed(rss_url: str):
+     """Callback: parse the RSS feed."""
+     print(f"Starting RSS parsing: {rss_url}")
+
+     if not rss_url:
+         print("RSS URL is empty")
+         return {
+             status_message_area: gr.update(value="Error: please enter an RSS URL."),
+             podcast_title_display: gr.update(value="", visible=False),
+             episode_dropdown: gr.update(choices=[], value=None, interactive=False),
+             podcast_data_state: None,
+             audio_player: gr.update(value=None),
+             current_audio_url_state: None,
+             episode_shownotes: gr.update(value="", visible=False),
+             transcription_output_df: gr.update(value=None, headers=["Speaker ID", "Speaker Name", "Transcript", "Time Range"]),
+             transcribe_button: gr.update(interactive=False),
+             selected_episode_index_state: None
+         }
+
+     try:
+         print(f"Parsing RSS: {rss_url}")
+         # The status message used to be updated first; since a generator is no longer used,
+         # the UI is updated directly after parsing.
+
+         podcast_data: PodcastChannel = parse_podcast_rss(rss_url)
+         print(f"RSS parse result: channel title={podcast_data.title if podcast_data else 'None'}, episode count={len(podcast_data.episodes) if podcast_data and podcast_data.episodes else 0}")
+
+         if podcast_data and podcast_data.episodes:
+             choices = []
+             for i, episode in enumerate(podcast_data.episodes):
+                 # Use (title (duration), guid or index) as the option.
+                 # If the guid is unreliable or missing, the index can be used instead.
+                 label = f"{episode.title or 'Untitled'} (duration: {episode.duration or 'unknown'})"
+                 # The episode object could be passed directly as the value, or just a unique identifier.
+                 # For simplicity the index is used as the unique ID, since the full episode is retrieved from podcast_data_state.
+                 choices.append((label, i))
+
+             # Render the podcast title
+             podcast_title = f"## 🎙️ {podcast_data.title or 'Unknown podcast'}"
+             if podcast_data.author:
+                 podcast_title += f"\n**Host/Producer:** {podcast_data.author}"
+             if podcast_data.description:
+                 # Limit the description length to keep the page compact
+                 description = podcast_data.description[:300]
+                 if len(podcast_data.description) > 300:
+                     description += "..."
+                 podcast_title += f"\n\n**About:** {description}"
+
+             return {
+                 status_message_area: gr.update(value=f"Parsed {len(podcast_data.episodes)} episodes successfully. Please select one."),
+                 podcast_title_display: gr.update(value=podcast_title, visible=True),
+                 episode_dropdown: gr.update(choices=choices, value=None, interactive=True),
+                 podcast_data_state: podcast_data,
+                 audio_player: gr.update(value=None),
+                 current_audio_url_state: None,
+                 episode_shownotes: gr.update(value="", visible=False),
+                 transcription_output_df: gr.update(value=None),
+                 transcribe_button: gr.update(interactive=False),
+                 selected_episode_index_state: None
+             }
+         elif podcast_data:  # channel info but no episodes
+             print("Parsing succeeded but no episodes were found")
+             podcast_title = f"## 🎙️ {podcast_data.title or 'Unknown podcast'}"
+             if podcast_data.author:
+                 podcast_title += f"\n**Host/Producer:** {podcast_data.author}"
+
+             return {
+                 status_message_area: gr.update(value="Parsing succeeded, but no episodes were found."),
+                 podcast_title_display: gr.update(value=podcast_title, visible=True),
+                 episode_dropdown: gr.update(choices=[], value=None, interactive=False),
+                 podcast_data_state: podcast_data,  # still stored, just in case
+                 audio_player: gr.update(value=None),
+                 current_audio_url_state: None,
+                 episode_shownotes: gr.update(value="", visible=False),
+                 transcription_output_df: gr.update(value=None),
+                 transcribe_button: gr.update(interactive=False),
+                 selected_episode_index_state: None
+             }
+         else:
+             print(f"Failed to parse RSS: {rss_url}")
+             return {
+                 status_message_area: gr.update(value=f"Failed to parse RSS: {rss_url}. Please check the URL or your network connection."),
+                 podcast_title_display: gr.update(value="", visible=False),
+                 episode_dropdown: gr.update(choices=[], value=None, interactive=False),
+                 podcast_data_state: None,
+                 audio_player: gr.update(value=None),
+                 current_audio_url_state: None,
+                 episode_shownotes: gr.update(value="", visible=False),
+                 transcription_output_df: gr.update(value=None),
+                 transcribe_button: gr.update(interactive=False),
+                 selected_episode_index_state: None
+             }
+     except Exception as e:
+         print(f"Error while parsing RSS: {e}")
+         traceback.print_exc()
+         return {
+             status_message_area: gr.update(value=f"A serious error occurred while parsing the RSS feed: {e}"),
+             podcast_title_display: gr.update(value="", visible=False),
+             episode_dropdown: gr.update(choices=[], value=None, interactive=False),
+             podcast_data_state: None,
+             audio_player: gr.update(value=None),
+             current_audio_url_state: None,
+             episode_shownotes: gr.update(value="", visible=False),
+             transcription_output_df: gr.update(value=None),
+             transcribe_button: gr.update(interactive=False),
+             selected_episode_index_state: None
+         }
+
+ def load_episode_audio(selected_episode_index: int, podcast_data: PodcastChannel):
+     """Callback: load the audio when the user selects an episode from the dropdown."""
+     global temp_files
+     print(f"Loading episode audio, selected index: {selected_episode_index}")
+
+     if selected_episode_index is None or podcast_data is None or not podcast_data.episodes:
+         print("No episode selected or no podcast data")
+         return {
+             audio_player: gr.update(value=None),
+             current_audio_url_state: None,
+             status_message_area: gr.update(value="Please parse an RSS feed and select an episode first."),
+             episode_shownotes: gr.update(value="", visible=False),
+             transcription_output_df: gr.update(value=None),
+             local_audio_file_path: None,
+             transcribe_button: gr.update(interactive=False),
+             selected_episode_index_state: None
+         }
+
+     try:
+         episode = podcast_data.episodes[selected_episode_index]
+         audio_url = episode.audio_url
+         print(f"Got episode info, title: {episode.title}, audio URL: {audio_url}")
+
+         # Prepare the episode details for display
+         episode_shownotes_content = ""
+
+         # Build the shownotes content
+         if episode.shownotes:
+             # Strip HTML tags and format the shownotes
+             import re
+             # Simple HTML tag removal
+             clean_shownotes = re.sub(r'<[^>]+>', '', episode.shownotes)
+             # Replace HTML entities
+             clean_shownotes = clean_shownotes.replace('&nbsp;', ' ').replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>')
+             # Collapse extra whitespace
+             clean_shownotes = re.sub(r'\s+', ' ', clean_shownotes).strip()
+
+             episode_shownotes_content = f"### 📝 Episode Details\n\n**Title:** {episode.title or 'Untitled'}\n\n"
+             if episode.published_date:
+                 episode_shownotes_content += f"**Published:** {episode.published_date.strftime('%Y-%m-%d')}\n\n"
+             if episode.duration:
+                 episode_shownotes_content += f"**Duration:** {episode.duration}\n\n"
+
+             episode_shownotes_content += f"**Show notes:**\n\n{clean_shownotes}"
+         elif episode.summary:
+             # If there are no shownotes, fall back to the summary
+             episode_shownotes_content = f"### 📝 Episode Details\n\n**Title:** {episode.title or 'Untitled'}\n\n"
+             if episode.published_date:
+                 episode_shownotes_content += f"**Published:** {episode.published_date.strftime('%Y-%m-%d')}\n\n"
+             if episode.duration:
+                 episode_shownotes_content += f"**Duration:** {episode.duration}\n\n"
+
+             episode_shownotes_content += f"**Summary:**\n\n{episode.summary}"
+         else:
+             # Bare-minimum information
+             episode_shownotes_content = f"### 📝 Episode Details\n\n**Title:** {episode.title or 'Untitled'}\n\n"
+             if episode.published_date:
+                 episode_shownotes_content += f"**Published:** {episode.published_date.strftime('%Y-%m-%d')}\n\n"
+             if episode.duration:
+                 episode_shownotes_content += f"**Duration:** {episode.duration}\n\n"
+
+         if audio_url:
+             # Update the status message
+             print(f"Downloading audio: {audio_url}")
+
+             # Download the audio file
+             try:
+                 headers = {
+                     'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36'
+                 }
+
+                 # Create a temporary file
+                 temp_dir = tempfile.gettempdir()
+                 # Use a UUID to build a unique filename and avoid collisions
+                 unique_filename = f"podcast_audio_{uuid.uuid4().hex}"
+
+                 # Send a HEAD request first to get the content type
+                 head_response = requests.head(audio_url, timeout=30, headers=headers)
+
+                 # Determine the file extension from the content type
+                 content_type = head_response.headers.get('Content-Type', '').lower()
+                 if 'mp3' in content_type:
+                     file_ext = '.mp3'
+                 elif 'mpeg' in content_type:
+                     file_ext = '.mp3'
+                 elif 'mp4' in content_type or 'm4a' in content_type:
+                     file_ext = '.mp4'
+                 elif 'wav' in content_type:
+                     file_ext = '.wav'
+                 elif 'ogg' in content_type:
+                     file_ext = '.ogg'
+                 else:
+                     # Default extension
+                     file_ext = '.mp3'
+
+                 temp_filepath = os.path.join(temp_dir, unique_filename + file_ext)
+
+                 # Add the file path to the global temp-file list
+                 temp_files.append(temp_filepath)
+
+                 # Save to the temporary file.
+                 # Stream the download to avoid loading the whole file into memory at once.
+                 with open(temp_filepath, 'wb') as f:
+                     # Use a streaming response with a reasonably large chunk size for efficiency
+                     response = requests.get(audio_url, timeout=60, headers=headers, stream=True)
+                     response.raise_for_status()
+
+                     # Get the total size from the response (if the server provides it)
+                     total_size = int(response.headers.get('content-length', 0))
+                     downloaded = 0
+                     chunk_size = 8192  # 8 KB chunks
+
+                     # Download in chunks and write to the file
+                     for chunk in response.iter_content(chunk_size=chunk_size):
+                         if chunk:  # skip keep-alive chunks
+                             f.write(chunk)
+                             downloaded += len(chunk)
+                             # Download-progress updates could be reported here
+                             if total_size > 0:
+                                 download_percentage = downloaded / total_size
+                                 print(f"Download progress: {download_percentage:.1%}")
+
+                 print(f"Audio downloaded to temporary file: {temp_filepath}")
+
+                 return {
+                     audio_player: gr.update(value=temp_filepath, label=f"Now playing: {episode.title or 'Untitled'}"),
+                     current_audio_url_state: audio_url,
+                     status_message_area: gr.update(value=f"Loaded episode: {episode.title or 'Untitled'}."),
+                     episode_shownotes: gr.update(value=episode_shownotes_content, visible=True),
+                     transcription_output_df: gr.update(value=None),
+                     local_audio_file_path: temp_filepath,
+                     transcribe_button: gr.update(interactive=True),
+                     selected_episode_index_state: selected_episode_index
+                 }
+             except requests.exceptions.RequestException as e:
+                 print(f"Failed to download audio: {e}")
+                 traceback.print_exc()
+                 return {
+                     audio_player: gr.update(value=None),
+                     current_audio_url_state: None,
+                     status_message_area: gr.update(value=f"Error: failed to download audio: {e}"),
+                     episode_shownotes: gr.update(value=episode_shownotes_content, visible=True),
+                     transcription_output_df: gr.update(value=None),
+                     local_audio_file_path: None,
+                     transcribe_button: gr.update(interactive=False),
+                     selected_episode_index_state: None
+                 }
+         else:
+             print(f"Episode '{episode.title}' has no valid audio URL")
+             return {
+                 audio_player: gr.update(value=None),
+                 current_audio_url_state: None,
+                 status_message_area: gr.update(value=f"Error: the selected episode '{episode.title}' does not provide a valid audio URL."),
+                 episode_shownotes: gr.update(value=episode_shownotes_content, visible=True),
+                 transcription_output_df: gr.update(value=None),
+                 local_audio_file_path: None,
+                 transcribe_button: gr.update(interactive=False),
+                 selected_episode_index_state: None
+             }
+     except IndexError:
+         print(f"Invalid episode index: {selected_episode_index}")
+         return {
+             audio_player: gr.update(value=None),
+             current_audio_url_state: None,
+             status_message_area: gr.update(value="Error: the selected episode index is invalid."),
+             episode_shownotes: gr.update(value="", visible=False),
+             transcription_output_df: gr.update(value=None),
+             local_audio_file_path: None,
+             transcribe_button: gr.update(interactive=False),
+             selected_episode_index_state: None
+         }
+     except Exception as e:
+         print(f"Error while loading audio: {e}")
+         traceback.print_exc()
+         return {
+             audio_player: gr.update(value=None),
+             current_audio_url_state: None,
+             status_message_area: gr.update(value=f"A serious error occurred while loading the audio: {e}"),
+             episode_shownotes: gr.update(value="", visible=False),
+             transcription_output_df: gr.update(value=None),
+             local_audio_file_path: None,
+             transcribe_button: gr.update(interactive=False),
+             selected_episode_index_state: None
+         }
+
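Note: the download branch above streams the enclosure into a uniquely named temporary file in 8 KB chunks instead of buffering it in memory. The same pattern, condensed into a standalone helper for clarity (the helper name and defaults are illustrative, not part of this commit):

    import os
    import tempfile
    import uuid
    import requests

    def download_to_temp(url: str, suffix: str = ".mp3", chunk_size: int = 8192) -> str:
        """Stream a remote file into the system temp dir and return its path."""
        path = os.path.join(tempfile.gettempdir(), f"podcast_audio_{uuid.uuid4().hex}{suffix}")
        with requests.get(url, stream=True, timeout=60) as resp:
            resp.raise_for_status()
            with open(path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        return path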
+ def disable_buttons_before_transcription(local_audio_file_path: str):
+     """Disable the buttons before transcription starts."""
+     print("Disabling UI buttons to prevent interaction during transcription")
+     return {
+         parse_button: gr.update(interactive=False),
+         episode_dropdown: gr.update(interactive=False),
+         transcribe_button: gr.update(interactive=False),
+         status_message_area: gr.update(value="Starting transcription, please be patient...")
+     }
+
+ def start_transcription(local_audio_file_path: str, podcast_data: PodcastChannel, selected_episode_index: int, progress=gr.Progress(track_tqdm=True)):
+     """Callback: transcribe the currently loaded audio."""
+     print(f"Starting transcription of local audio file: {local_audio_file_path}, selected episode index: {selected_episode_index}")
+
+     if not local_audio_file_path or not os.path.exists(local_audio_file_path):
+         print("No local audio file available")
+         return {
+             transcription_output_df: gr.update(value=None),
+             status_message_area: gr.update(value="Error: no valid audio file to transcribe. Please select an episode first."),
+             parse_button: gr.update(interactive=True),
+             episode_dropdown: gr.update(interactive=True),
+             transcribe_button: gr.update(interactive=True)
+         }
+
+     try:
+         # Update the status message first (the buttons have already been disabled)
+         progress(0, desc="Initializing transcription...")
+
+         # Use the progress callback to report progress
+         progress(0.2, desc="Loading audio file...")
+
+         # Load the audio from the file
+         audio_segment = AudioSegment.from_file(local_audio_file_path)
+         print(f"Audio loaded, duration: {len(audio_segment)/1000}s")
+
+         progress(0.4, desc="Audio loaded, starting transcription (this may take a while)...")
+
+         # Get the currently selected episode's info
+         episode_info = None
+         if podcast_data and podcast_data.episodes and selected_episode_index is not None:
+             if 0 <= selected_episode_index < len(podcast_data.episodes):
+                 episode_info = podcast_data.episodes[selected_episode_index]
+                 print(f"Got currently selected episode info: {episode_info.title if episode_info else 'None'}")
+
+         # Call the transcription function
+         print("Starting audio transcription...")
+         result: CombinedTranscriptionResult = transcribe_podcast_audio(
+             audio_segment,
+             podcast_info=podcast_data,
+             episode_info=episode_info,
+             segmentation_batch_size=64,
+             parallel=True,
+         )
+         print(f"Transcription finished, result: {result is not None}, segment count: {len(result.segments) if result and result.segments else 0}")
+         progress(0.9, desc="Transcription finished, formatting results...")
+
+         if result and result.segments:
+             formatted_segments = []
+             for seg in result.segments:
+                 time_str = f"{seg.start:.2f}s - {seg.end:.2f}s"
+                 formatted_segments.append([seg.speaker, seg.speaker_name, seg.text, time_str])
+
+             progress(1.0, desc="Transcription results ready!")
+             return {
+                 transcription_output_df: gr.update(value=formatted_segments),
+                 status_message_area: gr.update(value=f"Transcription finished! {len(result.segments)} segments, {result.num_speakers} speaker(s) detected."),
+                 parse_button: gr.update(interactive=True),
+                 episode_dropdown: gr.update(interactive=True),
+                 transcribe_button: gr.update(interactive=True)
+             }
+         elif result:  # a result but no segments
+             progress(1.0, desc="Transcription finished, but no text segments")
+             return {
+                 transcription_output_df: gr.update(value=None),
+                 status_message_area: gr.update(value="Transcription finished, but no text segments were produced."),
+                 parse_button: gr.update(interactive=True),
+                 episode_dropdown: gr.update(interactive=True),
+                 transcribe_button: gr.update(interactive=True)
+             }
+         else:  # result is None
+             progress(1.0, desc="Transcription failed")
+             return {
+                 transcription_output_df: gr.update(value=None),
+                 status_message_area: gr.update(value="Transcription failed; no result was returned."),
+                 parse_button: gr.update(interactive=True),
+                 episode_dropdown: gr.update(interactive=True),
+                 transcribe_button: gr.update(interactive=True)
+             }
+     except Exception as e:
+         print(f"Error during transcription: {e}")
+         traceback.print_exc()
+         progress(1.0, desc="Transcription failed: processing error")
+         return {
+             transcription_output_df: gr.update(value=None),
+             status_message_area: gr.update(value=f"A serious error occurred during transcription: {e}"),
+             parse_button: gr.update(interactive=True),
+             episode_dropdown: gr.update(interactive=True),
+             transcribe_button: gr.update(interactive=True)
+         }
+
+ # --- Gradio UI definition ---
+ with gr.Blocks(title="Podcast Transcription Tool v2", css="""
+ .status-message-container {
+     min-height: 50px;
+     height: auto;
+     max-height: none;
+     overflow-y: visible;
+     white-space: normal;
+     word-wrap: break-word;
+     margin-top: 10px;
+     margin-bottom: 10px;
+     border-radius: 6px;
+     background-color: rgba(32, 36, 45, 0.03);
+     border: 1px solid rgba(32, 36, 45, 0.1);
+     color: #303030;
+ }
+ .episode-cover {
+     max-width: 300px;
+     max-height: 300px;
+     border-radius: 8px;
+     box-shadow: 0 4px 8px rgba(0,0,0,0.1);
+ }
+ """) as demo:
+     gr.Markdown("# 🎙️ Podcast Transcription Tool")
+
+     # State management
+     podcast_data_state = gr.State(None)  # stores the parsed PodcastChannel object
+     current_audio_url_state = gr.State(None)  # stores the audio URL of the currently selected episode
+     local_audio_file_path = gr.State(None)  # stores the path of the locally downloaded audio file
+     selected_episode_index_state = gr.State(None)  # stores the index of the currently selected episode
+
+     with gr.Row():
+         rss_url_input = gr.Textbox(
+             label="Podcast RSS URL",
+             placeholder="e.g. https://your-podcast-feed.com/rss.xml",
+             elem_id="rss-url-input"
+         )
+         parse_button = gr.Button("🔗 Parse RSS", elem_id="parse-rss-button")
+
+     status_message_area = gr.Markdown(
+         "",
+         elem_id="status-message",
+         elem_classes="status-message-container",  # custom CSS class
+         show_label=False
+     )
+
+     # Podcast title display area
+     podcast_title_display = gr.Markdown(
+         "",
+         visible=False,
+         elem_id="podcast-title-display"
+     )
+
+     episode_dropdown = gr.Dropdown(
+         label="Select episode",
+         choices=[],
+         interactive=False,  # not interactive initially; enabled after a successful parse
+         elem_id="episode-dropdown"
+     )
+
+     # Episode details display area
+     with gr.Row():
+         with gr.Column(scale=2):
+             episode_shownotes = gr.Markdown(
+                 "",
+                 visible=False,
+                 elem_id="episode-shownotes"
+             )
+
+     audio_player = gr.Audio(
+         label="Podcast audio player",
+         interactive=False,  # the audio source is set by the app; users cannot change it directly
+         elem_id="audio-player"
+     )
+
+     transcribe_button = gr.Button("🔊 Start transcription", elem_id="transcribe-button", interactive=False)
+
+     gr.Markdown("## 📝 Transcription Results")
+     transcription_output_df = gr.DataFrame(
+         headers=["Speaker ID", "Speaker Name", "Transcript", "Time Range"],
+         interactive=False,
+         wrap=True,  # allow text wrapping
+         row_count=(10, "dynamic"),  # show 10 rows, scrollable
+         col_count=(4, "fixed"),
+         elem_id="transcription-output"
+     )
+
+     # --- Event handling ---
+     parse_button.click(
+         fn=parse_rss_feed,
+         inputs=[rss_url_input],
+         outputs=[
+             status_message_area,
+             podcast_title_display,
+             episode_dropdown,
+             podcast_data_state,
+             audio_player,
+             current_audio_url_state,
+             episode_shownotes,
+             transcription_output_df,
+             transcribe_button,
+             selected_episode_index_state
+         ]
+     )
+
+     episode_dropdown.change(
+         fn=load_episode_audio,
+         inputs=[episode_dropdown, podcast_data_state],
+         outputs=[
+             audio_player,
+             current_audio_url_state,
+             status_message_area,
+             episode_shownotes,
+             transcription_output_df,
+             local_audio_file_path,
+             transcribe_button,
+             selected_episode_index_state
+         ]
+     )
+
+     # First disable the buttons, then run the transcription
+     transcribe_button.click(
+         fn=disable_buttons_before_transcription,
+         inputs=[local_audio_file_path],
+         outputs=[parse_button, episode_dropdown, transcribe_button, status_message_area]
+     ).then(
+         fn=start_transcription,
+         inputs=[local_audio_file_path, podcast_data_state, selected_episode_index_state],
+         outputs=[transcription_output_df, status_message_area, parse_button, episode_dropdown, transcribe_button]
+     )
+
+ if __name__ == "__main__":
+     try:
+         # demo.launch(debug=True, share=True)  # share=True generates a public link
+         demo.launch(debug=True)
+     finally:
+         # Make sure temporary files are cleaned up when the application exits
+         cleanup_temp_files()
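Note: as the import fallback at the top of the file indicates, the UI is intended to be started as a module, e.g. `python -m src.podcast_transcribe.webui.app` from the project root. A programmatic launch sketch (assuming the src directory is on PYTHONPATH; the `share=True` flag mirrors the commented-out line above and is optional):

    from podcast_transcribe.webui.app import demo

    demo.launch(debug=True, share=True)  # share=True exposes a temporary public link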