Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
8289369
1
Parent(s):
e0e9f98
init
Browse files- .DS_Store +0 -0
- .gitignore +131 -0
- .vscode/launch.json +19 -0
- app.py +44 -5
- examples/.DS_Store +0 -0
- examples/combined_podcast_transcription.py +102 -0
- examples/combined_transcription.py +104 -0
- examples/simple_asr.py +80 -0
- examples/simple_diarization.py +79 -0
- examples/simple_llm.py +74 -0
- examples/simple_rss_parser.py +40 -0
- examples/simple_speaker_identify.py +68 -0
- requirements.txt +21 -0
- src/.DS_Store +0 -0
- src/podcast_transcribe/.DS_Store +0 -0
- src/podcast_transcribe/__init__.py +8 -0
- src/podcast_transcribe/asr/asr_base.py +277 -0
- src/podcast_transcribe/asr/asr_distil_whisper_mlx.py +112 -0
- src/podcast_transcribe/asr/asr_distil_whisper_transformers.py +133 -0
- src/podcast_transcribe/asr/asr_parakeet_mlx.py +126 -0
- src/podcast_transcribe/asr/asr_router.py +273 -0
- src/podcast_transcribe/audio.py +62 -0
- src/podcast_transcribe/diarization/diarization_pyannote_mlx.py +154 -0
- src/podcast_transcribe/diarization/diarization_pyannote_transformers.py +170 -0
- src/podcast_transcribe/diarization/diarizer_base.py +118 -0
- src/podcast_transcribe/diarization/diarizer_router.py +276 -0
- src/podcast_transcribe/llm/llm_base.py +391 -0
- src/podcast_transcribe/llm/llm_gemma_mlx.py +62 -0
- src/podcast_transcribe/llm/llm_gemma_transfomers.py +61 -0
- src/podcast_transcribe/llm/llm_phi4_transfomers.py +369 -0
- src/podcast_transcribe/llm/llm_router.py +578 -0
- src/podcast_transcribe/rss/podcast_rss_parser.py +162 -0
- src/podcast_transcribe/schemas.py +63 -0
- src/podcast_transcribe/summary/speaker_identify.py +350 -0
- src/podcast_transcribe/transcriber.py +588 -0
- src/podcast_transcribe/webui/app.py +585 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitignore
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
*.manifest
|
30 |
+
*.spec
|
31 |
+
|
32 |
+
# Installer logs
|
33 |
+
pip-log.txt
|
34 |
+
pip-delete-this-directory.txt
|
35 |
+
|
36 |
+
# Unit test / coverage reports
|
37 |
+
htmlcov/
|
38 |
+
.tox/
|
39 |
+
.nox/
|
40 |
+
.coverage
|
41 |
+
.coverage.*
|
42 |
+
.cache
|
43 |
+
nosetests.xml
|
44 |
+
coverage.xml
|
45 |
+
*.cover
|
46 |
+
.hypothesis/
|
47 |
+
.pytest_cache/
|
48 |
+
|
49 |
+
# Translations
|
50 |
+
*.mo
|
51 |
+
*.pot
|
52 |
+
|
53 |
+
# Django stuff:
|
54 |
+
*.log
|
55 |
+
local_settings.py
|
56 |
+
db.sqlite3
|
57 |
+
|
58 |
+
# Flask stuff:
|
59 |
+
instance/
|
60 |
+
.webassets-cache
|
61 |
+
|
62 |
+
# Scrapy stuff:
|
63 |
+
.scrapy
|
64 |
+
|
65 |
+
# Sphinx documentation
|
66 |
+
docs/_build/
|
67 |
+
|
68 |
+
# PyBuilder
|
69 |
+
target/
|
70 |
+
|
71 |
+
# Jupyter Notebook
|
72 |
+
.ipynb_checkpoints
|
73 |
+
|
74 |
+
# IPython
|
75 |
+
profile_default/
|
76 |
+
ipython_config.py
|
77 |
+
|
78 |
+
# pyenv
|
79 |
+
.python-version
|
80 |
+
|
81 |
+
# celery beat schedule file
|
82 |
+
celerybeat-schedule
|
83 |
+
|
84 |
+
# SageMath parsed files
|
85 |
+
*.sage.py
|
86 |
+
|
87 |
+
# Environments
|
88 |
+
.env
|
89 |
+
.venv
|
90 |
+
env/
|
91 |
+
venv/
|
92 |
+
ENV/
|
93 |
+
env.bak/
|
94 |
+
venv.bak/
|
95 |
+
|
96 |
+
# Spyder project settings
|
97 |
+
.spyderproject
|
98 |
+
.spyproject
|
99 |
+
|
100 |
+
# Rope project settings
|
101 |
+
.ropeproject
|
102 |
+
|
103 |
+
# mkdocs documentation
|
104 |
+
/site
|
105 |
+
|
106 |
+
# mypy
|
107 |
+
.mypy_cache/
|
108 |
+
.dmypy.json
|
109 |
+
dmypy.json
|
110 |
+
|
111 |
+
# Pyre type checker
|
112 |
+
.pyre/
|
113 |
+
|
114 |
+
# uv
|
115 |
+
.uv/
|
116 |
+
|
117 |
+
# Output files
|
118 |
+
output/
|
119 |
+
logs/
|
120 |
+
|
121 |
+
# Large models
|
122 |
+
*.bin
|
123 |
+
*.pth
|
124 |
+
*.pt
|
125 |
+
*.onnx
|
126 |
+
|
127 |
+
examples/input/
|
128 |
+
examples/output/
|
129 |
+
|
130 |
+
# temp files
|
131 |
+
_temp_*
|
.vscode/launch.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
// Use IntelliSense to learn about possible attributes.
|
3 |
+
// Hover to view descriptions of existing attributes.
|
4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
5 |
+
"version": "0.2.0",
|
6 |
+
"configurations": [
|
7 |
+
{
|
8 |
+
"name": "Python Debugger: Current File",
|
9 |
+
"type": "debugpy",
|
10 |
+
"request": "launch",
|
11 |
+
"program": "${file}",
|
12 |
+
"console": "integratedTerminal",
|
13 |
+
"env": {
|
14 |
+
// "HTTPS_PROXY": "http://127.0.0.1:12334",
|
15 |
+
// "HF_ENDPOINT": "https://hf-mirror.com"
|
16 |
+
}
|
17 |
+
}
|
18 |
+
]
|
19 |
+
}
|
app.py
CHANGED
@@ -1,7 +1,46 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
播客转录工具 - 主启动文件
|
4 |
+
这个文件用于启动 Gradio WebUI 应用
|
5 |
+
"""
|
6 |
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
|
10 |
+
# 将 src 目录添加到 Python 路径中
|
11 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
12 |
+
src_path = os.path.join(current_dir, "src")
|
13 |
+
if src_path not in sys.path:
|
14 |
+
sys.path.insert(0, src_path)
|
15 |
+
|
16 |
+
def main():
|
17 |
+
"""主函数:启动 WebUI 应用"""
|
18 |
+
try:
|
19 |
+
# 导入并启动 webui 应用
|
20 |
+
from podcast_transcribe.webui.app import demo
|
21 |
+
|
22 |
+
print("🎙️ 启动播客转录工具...")
|
23 |
+
print("📍 WebUI 将在浏览器中打开")
|
24 |
+
print("🔗 默认地址: http://localhost:7860")
|
25 |
+
print("⏹️ 按 Ctrl+C 停止服务")
|
26 |
+
|
27 |
+
# 启动 Gradio 应用
|
28 |
+
demo.launch(
|
29 |
+
debug=True,
|
30 |
+
server_name="0.0.0.0", # 允许外部访问
|
31 |
+
server_port=7860, # 指定端口
|
32 |
+
share=False, # 不生成公开链接
|
33 |
+
inbrowser=True # 自动在浏览器中打开
|
34 |
+
)
|
35 |
+
|
36 |
+
except ImportError as e:
|
37 |
+
print(f"❌ 导入错误: {e}")
|
38 |
+
print("请确保已安装所有依赖包:")
|
39 |
+
print("pip install -r requirements.txt")
|
40 |
+
sys.exit(1)
|
41 |
+
except Exception as e:
|
42 |
+
print(f"❌ 启动失败: {e}")
|
43 |
+
sys.exit(1)
|
44 |
+
|
45 |
+
if __name__ == "__main__":
|
46 |
+
main()
|
examples/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
examples/combined_podcast_transcription.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 添加项目根目录到Python路径
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
import os
|
6 |
+
|
7 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
8 |
+
|
9 |
+
from src.podcast_transcribe.transcriber import transcribe_podcast_audio
|
10 |
+
from src.podcast_transcribe.audio import load_audio
|
11 |
+
from src.podcast_transcribe.rss.podcast_rss_parser import parse_rss_xml_content
|
12 |
+
from podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
|
13 |
+
from src.podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult
|
14 |
+
from src.podcast_transcribe.summary.speaker_identify import recognize_speaker_names
|
15 |
+
|
16 |
+
def main():
|
17 |
+
"""主函数"""
|
18 |
+
podcast_rss_xml_file = Path.joinpath(Path(__file__).parent, "input", "lexfridman.com.rss.xml")
|
19 |
+
audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
|
20 |
+
# audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
|
21 |
+
|
22 |
+
# 模型配置
|
23 |
+
asr_model_name = "mlx-community/parakeet-tdt-0.6b-v2" # ASR模型名称
|
24 |
+
diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
|
25 |
+
llm_model_path = "mlx-community/gemma-3-12b-it-4bit-DWQ"
|
26 |
+
hf_token = "" # Hugging Face API 令牌
|
27 |
+
device = "mps" # 设备类型
|
28 |
+
segmentation_batch_size = 64
|
29 |
+
parallel = True
|
30 |
+
|
31 |
+
# 检查文件是否存在
|
32 |
+
if not os.path.exists(audio_file):
|
33 |
+
print(f"错误:文件 '{audio_file}' 不存在")
|
34 |
+
return 1
|
35 |
+
|
36 |
+
# 检查HF令牌
|
37 |
+
if not hf_token:
|
38 |
+
print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
|
39 |
+
print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
|
40 |
+
return 1
|
41 |
+
|
42 |
+
try:
|
43 |
+
print(f"正在加载音频文件: {audio_file}")
|
44 |
+
# 加载音频文件
|
45 |
+
audio, _ = load_audio(audio_file)
|
46 |
+
|
47 |
+
print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
|
48 |
+
|
49 |
+
|
50 |
+
except Exception as e:
|
51 |
+
print(f"错误: {str(e)}")
|
52 |
+
import traceback
|
53 |
+
traceback.print_exc()
|
54 |
+
return 1
|
55 |
+
|
56 |
+
# Load the podcast RSS XML file
|
57 |
+
with open(podcast_rss_xml_file, "r") as f:
|
58 |
+
podcast_rss_xml = f.read()
|
59 |
+
mock_podcast_info = parse_rss_xml_content(podcast_rss_xml)
|
60 |
+
|
61 |
+
|
62 |
+
# 查找标题已 "#309" 开头的剧集
|
63 |
+
mock_episode_info = next((episode for episode in mock_podcast_info.episodes if episode.title.startswith("#309")), None)
|
64 |
+
if not mock_episode_info:
|
65 |
+
raise ValueError("Could not find episode with title starting with '#309'")
|
66 |
+
|
67 |
+
result = transcribe_podcast_audio(audio,
|
68 |
+
podcast_info=mock_podcast_info,
|
69 |
+
episode_info=mock_episode_info,
|
70 |
+
hf_token=hf_token,
|
71 |
+
asr_model_name=asr_model_name,
|
72 |
+
diarization_model_name=diarization_model_name,
|
73 |
+
llm_model_name=llm_model_path,
|
74 |
+
device=device,
|
75 |
+
segmentation_batch_size=segmentation_batch_size,
|
76 |
+
parallel=parallel,
|
77 |
+
llm_model_name=llm_model_path)
|
78 |
+
|
79 |
+
# 输出结果
|
80 |
+
print("\n转录结果:")
|
81 |
+
print("-" * 50)
|
82 |
+
print(f"检测到的语言: {result.language}")
|
83 |
+
print(f"检测到的说话人数量: {result.num_speakers}")
|
84 |
+
print(f"总文本长度: {len(result.text)} 字符")
|
85 |
+
|
86 |
+
# 输出每个说话人的部分
|
87 |
+
speakers = set(segment.speaker for segment in result.segments)
|
88 |
+
for speaker in sorted(speakers):
|
89 |
+
speaker_segments = [seg for seg in result.segments if seg.speaker == speaker]
|
90 |
+
total_duration = sum(seg.end - seg.start for seg in speaker_segments)
|
91 |
+
print(f"\n说话人 {speaker}: 共 {len(speaker_segments)} 个片段, 总时长 {total_duration:.2f} 秒")
|
92 |
+
|
93 |
+
# 输出详细分段信息
|
94 |
+
print("\n详细分段信息:")
|
95 |
+
for i, segment in enumerate(result.segments, 1):
|
96 |
+
if i <= 20 or i > len(result.segments) - 20: # 仅显示前20个和后20个分段
|
97 |
+
print(f"段落 {i}/{len(result.segments)}: [{segment.start:.2f}s - {segment.end:.2f}s] 说话人: {segment.speaker_name if segment.speaker_name else segment.speaker} 文本: {segment.text}")
|
98 |
+
elif i == 21:
|
99 |
+
print("... 省略中间部分 ...")
|
100 |
+
|
101 |
+
if __name__ == '__main__':
|
102 |
+
main()
|
examples/combined_transcription.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
整合ASR和说话人分离的示例程序
|
5 |
+
从本地文件读取音频,同时进行转录和说话人分离
|
6 |
+
"""
|
7 |
+
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
from pathlib import Path
|
12 |
+
from dataclasses import asdict
|
13 |
+
|
14 |
+
# 添加项目根目录到Python路径
|
15 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
16 |
+
|
17 |
+
# 导入必要的模块,使用正确的导入路径
|
18 |
+
from src.podcast_transcribe.audio import load_audio
|
19 |
+
from src.podcast_transcribe.transcriber import transcribe_audio
|
20 |
+
|
21 |
+
|
22 |
+
def main():
|
23 |
+
"""主函数"""
|
24 |
+
audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
|
25 |
+
# audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
|
26 |
+
|
27 |
+
# 模型配置
|
28 |
+
asr_model_name = "distil-whisper/distil-large-v3.5" # ASR模型名称
|
29 |
+
diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
|
30 |
+
hf_token = "" # Hugging Face API 令牌
|
31 |
+
device = "mps" # 设备类型
|
32 |
+
segmentation_batch_size = 64
|
33 |
+
parallel = True
|
34 |
+
|
35 |
+
# 检查文件是否存在
|
36 |
+
if not os.path.exists(audio_file):
|
37 |
+
print(f"错误:文件 '{audio_file}' 不存在")
|
38 |
+
return 1
|
39 |
+
|
40 |
+
# 检查HF令牌
|
41 |
+
if not hf_token:
|
42 |
+
print("警告:未设置HF_TOKEN环境变量,必须设置此环境变量才能使用pyannote说话人分离模型")
|
43 |
+
print("请执行:export HF_TOKEN='你的HuggingFace令牌'")
|
44 |
+
return 1
|
45 |
+
|
46 |
+
try:
|
47 |
+
print(f"正在加载音频文件: {audio_file}")
|
48 |
+
# 加载音频文件
|
49 |
+
audio, _ = load_audio(audio_file)
|
50 |
+
|
51 |
+
print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
|
52 |
+
|
53 |
+
result = transcribe_audio(
|
54 |
+
audio,
|
55 |
+
asr_model_name=asr_model_name,
|
56 |
+
diarization_model_name=diarization_model_name,
|
57 |
+
hf_token=hf_token,
|
58 |
+
device=device,
|
59 |
+
segmentation_batch_size=segmentation_batch_size,
|
60 |
+
parallel=parallel,
|
61 |
+
)
|
62 |
+
|
63 |
+
# 输出结果
|
64 |
+
print("\n转录结果:")
|
65 |
+
print("-" * 50)
|
66 |
+
print(f"检测到的语言: {result.language}")
|
67 |
+
print(f"检测到的说话人数量: {result.num_speakers}")
|
68 |
+
print(f"总文本长度: {len(result.text)} 字符")
|
69 |
+
|
70 |
+
# 输出每个说话人的部分
|
71 |
+
speakers = set(segment.speaker for segment in result.segments)
|
72 |
+
for speaker in sorted(speakers):
|
73 |
+
speaker_segments = [seg for seg in result.segments if seg.speaker == speaker]
|
74 |
+
total_duration = sum(seg.end - seg.start for seg in speaker_segments)
|
75 |
+
print(f"\n说话人 {speaker}: 共 {len(speaker_segments)} 个片段, 总时长 {total_duration:.2f} 秒")
|
76 |
+
|
77 |
+
# 输出详细分段信息
|
78 |
+
print("\n详细分段信息:")
|
79 |
+
for i, segment in enumerate(result.segments, 1):
|
80 |
+
if i <= 20 or i > len(result.segments) - 20: # 仅显示前20个和后20个分段
|
81 |
+
print(f"段落 {i}/{len(result.segments)}: [{segment.start:.2f}s - {segment.end:.2f}s] 说话人: {segment.speaker} 文本: {segment.text}")
|
82 |
+
elif i == 21:
|
83 |
+
print("... 省略中间部分 ...")
|
84 |
+
|
85 |
+
# 将转录结果保存为json文件,文件名取自音频文件名
|
86 |
+
output_file = Path.joinpath(Path(__file__).parent, "output", f"{audio_file.stem}.transcription.json")
|
87 |
+
# 创建上层文件夹
|
88 |
+
output_dir = Path.joinpath(Path(__file__).parent, "output")
|
89 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
90 |
+
with open(output_file, "w") as f:
|
91 |
+
json.dump(asdict(result), f)
|
92 |
+
print(f"转录结果已保存到 {output_file}")
|
93 |
+
|
94 |
+
return 0
|
95 |
+
|
96 |
+
except Exception as e:
|
97 |
+
print(f"错误: {str(e)}")
|
98 |
+
import traceback
|
99 |
+
traceback.print_exc()
|
100 |
+
return 1
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
sys.exit(main())
|
examples/simple_asr.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
简单的语音识别示例程序
|
5 |
+
从本地文件读取音频并进行转录
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
# 添加项目根目录到Python路径
|
14 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
15 |
+
|
16 |
+
from src.podcast_transcribe.audio import load_audio
|
17 |
+
|
18 |
+
logger = logging.getLogger("asr_example")
|
19 |
+
|
20 |
+
|
21 |
+
def main():
|
22 |
+
"""主函数"""
|
23 |
+
# audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
|
24 |
+
# audio_file = "/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav" # 播客音频文件路径
|
25 |
+
audio_file = "/Users/konie/Desktop/voices/podcast1_1.wav"
|
26 |
+
# model = "distil-whisper"
|
27 |
+
model = "distil-whisper-transformers"
|
28 |
+
|
29 |
+
device = "mlx"
|
30 |
+
|
31 |
+
# 检查文件是否存在
|
32 |
+
if not os.path.exists(audio_file):
|
33 |
+
print(f"错误:文件 '{audio_file}' 不存在")
|
34 |
+
return 1
|
35 |
+
|
36 |
+
if model == "parakeet":
|
37 |
+
from src.podcast_transcribe.asr.asr_parakeet_mlx import transcribe_audio
|
38 |
+
model_name = "mlx-community/parakeet-tdt-0.6b-v2"
|
39 |
+
logger.info(f"使用Parakeet模型: {model_name}")
|
40 |
+
elif model == "distil-whisper": # distil-whisper
|
41 |
+
from src.podcast_transcribe.asr.asr_distil_whisper_mlx import transcribe_audio
|
42 |
+
model_name = "mlx-community/distil-whisper-large-v3"
|
43 |
+
logger.info(f"使用Distil Whisper模型: {model_name}")
|
44 |
+
elif model == "distil-whisper-transformers": # distil-whisper
|
45 |
+
from src.podcast_transcribe.asr.asr_distil_whisper_transformers import transcribe_audio
|
46 |
+
model_name = "distil-whisper/distil-large-v3.5"
|
47 |
+
logger.info(f"使用Distil Whisper模型: {model_name}")
|
48 |
+
else:
|
49 |
+
logger.error(f"错误:未指定模型类型")
|
50 |
+
return 1
|
51 |
+
|
52 |
+
try:
|
53 |
+
print(f"正在加载音频文件: {audio_file}")
|
54 |
+
# 加载音频文件
|
55 |
+
audio, _ = load_audio(audio_file)
|
56 |
+
|
57 |
+
print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
|
58 |
+
|
59 |
+
# 进行转录
|
60 |
+
print("开始转录...")
|
61 |
+
result = transcribe_audio(audio, model_name=model_name, device=device)
|
62 |
+
|
63 |
+
# 输出结果
|
64 |
+
print("\n转录结果:")
|
65 |
+
print("-" * 50)
|
66 |
+
print(f"检测到的语言: {result.language}")
|
67 |
+
print(f"完整文本: {result.text}")
|
68 |
+
print("\n分段信息:")
|
69 |
+
for i, segment in enumerate(result.segments, 1):
|
70 |
+
print(f"分段 {i}: [{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")
|
71 |
+
|
72 |
+
return 0
|
73 |
+
|
74 |
+
except Exception as e:
|
75 |
+
print(f"错误: {str(e)}")
|
76 |
+
return 1
|
77 |
+
|
78 |
+
|
79 |
+
if __name__ == "__main__":
|
80 |
+
main()
|
examples/simple_diarization.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
简单的说话人标注示例程序
|
5 |
+
从本地文件读取音频并进行说话人分离
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
# 添加项目根目录到Python路径
|
13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
14 |
+
|
15 |
+
from src.podcast_transcribe.audio import load_audio
|
16 |
+
from src.podcast_transcribe.diarization.diarization_pyannote_mlx import diarize_audio as diarize_audio_mlx
|
17 |
+
from src.podcast_transcribe.diarization.diarization_pyannote_transformers import diarize_audio as diarize_audio_transformers
|
18 |
+
|
19 |
+
|
20 |
+
def main():
|
21 |
+
"""主函数"""
|
22 |
+
audio_file = Path.joinpath(Path(__file__).parent, "input", "lex_ai_john_carmack_1.wav") # 播客音频文件路径
|
23 |
+
# audio_file = "/Users/konie/Desktop/voices/history_in_the_baking.mp3" # 播客音频文件路径
|
24 |
+
model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
|
25 |
+
hf_token = "" # Hugging Face API 令牌
|
26 |
+
device = "mps" # 设备类型
|
27 |
+
|
28 |
+
# 检查文件是否存在
|
29 |
+
if not os.path.exists(audio_file):
|
30 |
+
print(f"错误:文件 '{audio_file}' 不存在")
|
31 |
+
return 1
|
32 |
+
|
33 |
+
# 检查令牌是否设置
|
34 |
+
if not hf_token:
|
35 |
+
print("错误:未设置HF_TOKEN环境变量,请设置后再运行")
|
36 |
+
return 1
|
37 |
+
|
38 |
+
try:
|
39 |
+
print(f"正在加载音频文件: {audio_file}")
|
40 |
+
# 加载音频文件
|
41 |
+
audio, _ = load_audio(audio_file)
|
42 |
+
|
43 |
+
print(f"音频信息: 时长={audio.duration_seconds:.2f}秒, 通道数={audio.channels}, 采样率={audio.frame_rate}Hz")
|
44 |
+
|
45 |
+
# 根据model_name选择合适的实现
|
46 |
+
if "pyannote/speaker-diarization" in model_name:
|
47 |
+
# 使用transformers版本进行说话人分离
|
48 |
+
print(f"使用transformers版本处理模型: {model_name}")
|
49 |
+
result = diarize_audio_transformers(audio, model_name=model_name, token=hf_token, device=device, segmentation_batch_size=128)
|
50 |
+
version_name = "Transformers"
|
51 |
+
else:
|
52 |
+
# 使用MLX版本进行说话人分离
|
53 |
+
print(f"使用MLX版本处理模型: {model_name}")
|
54 |
+
result = diarize_audio_mlx(audio, model_name=model_name, token=hf_token, device=device, segmentation_batch_size=128)
|
55 |
+
version_name = "MLX"
|
56 |
+
|
57 |
+
# 输出结果
|
58 |
+
print(f"\n{version_name}版本说话人分离结果:")
|
59 |
+
print("-" * 50)
|
60 |
+
print(f"检测到的说话人数量: {result.num_speakers}")
|
61 |
+
print(f"分段总数: {len(result.segments)}")
|
62 |
+
|
63 |
+
print("\n分段详情:")
|
64 |
+
for i, segment in enumerate(result.segments, 1):
|
65 |
+
start = segment["start"]
|
66 |
+
end = segment["end"]
|
67 |
+
speaker = segment["speaker"]
|
68 |
+
duration = end - start
|
69 |
+
print(f"分段 {i}: [{start:.2f}s - {end:.2f}s] (时长: {duration:.2f}s) 说话人: {speaker}")
|
70 |
+
|
71 |
+
return 0
|
72 |
+
|
73 |
+
except Exception as e:
|
74 |
+
print(f"错误: {str(e)}")
|
75 |
+
return 1
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
sys.exit(main())
|
examples/simple_llm.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# 添加项目根目录到Python路径
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
7 |
+
|
8 |
+
from src.podcast_transcribe.llm.llm_phi4_transfomers import Phi4TransformersChatCompletion
|
9 |
+
from src.podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
|
10 |
+
from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion
|
11 |
+
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
# 示例用法:
|
15 |
+
print("正在初始化 LLM 聊天补全...")
|
16 |
+
try:
|
17 |
+
model_name = "google/gemma-3-4b-it"
|
18 |
+
use_4bit_quantization = False
|
19 |
+
|
20 |
+
# gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
|
21 |
+
# 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
|
22 |
+
if model_name.startswith("mlx-community"):
|
23 |
+
gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
|
24 |
+
elif model_name.startswith("microsoft"):
|
25 |
+
gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
|
26 |
+
else:
|
27 |
+
gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization)
|
28 |
+
|
29 |
+
print("\n--- 示例 1: 简单用户查询 ---")
|
30 |
+
messages_example1 = [
|
31 |
+
{"role": "user", "content": "你好,你是谁?"}
|
32 |
+
]
|
33 |
+
response1 = gemma_chat.create(messages=messages_example1, max_tokens=50)
|
34 |
+
print("响应 1:")
|
35 |
+
print(f" 助手: {response1['choices'][0]['message']['content']}")
|
36 |
+
print(f" 用量: {response1['usage']}")
|
37 |
+
|
38 |
+
print("\n--- 示例 2: 带历史记录的对话 ---")
|
39 |
+
messages_example2 = [
|
40 |
+
{"role": "user", "content": "法国的首都是哪里?"},
|
41 |
+
{"role": "assistant", "content": "法国的首都是巴黎。"},
|
42 |
+
{"role": "user", "content": "你能告诉我一个关于它的有趣事实吗?"}
|
43 |
+
]
|
44 |
+
response2 = gemma_chat.create(messages=messages_example2, max_tokens=100, temperature=0.8)
|
45 |
+
print("响应 2:")
|
46 |
+
print(f" 助手: {response2['choices'][0]['message']['content']}")
|
47 |
+
print(f" 用量: {response2['usage']}")
|
48 |
+
|
49 |
+
print("\n--- 示例 3: 系统提示 (实验性,效果取决于模型微调) ---")
|
50 |
+
messages_example3 = [
|
51 |
+
{"role": "system", "content": "你是一位富有诗意的助手,擅长用富有创意的方式解释复杂的编程概念。"},
|
52 |
+
{"role": "user", "content": "解释一下编程中递归的概念。"}
|
53 |
+
]
|
54 |
+
response3 = gemma_chat.create(messages=messages_example3, max_tokens=150)
|
55 |
+
print("响应 3:")
|
56 |
+
print(f" 助手: {response3['choices'][0]['message']['content']}")
|
57 |
+
print(f" 用量: {response3['usage']}")
|
58 |
+
|
59 |
+
print("\n--- 示例 4: 使用 max_tokens 强制缩短响应 ---")
|
60 |
+
messages_example4 = [
|
61 |
+
{"role": "user", "content": "给我讲一个关于勇敢骑士的很长的故事。"}
|
62 |
+
]
|
63 |
+
response4 = gemma_chat.create(messages=messages_example4, max_tokens=20) # 非常短
|
64 |
+
print("响应 4:")
|
65 |
+
print(f" 助手: {response4['choices'][0]['message']['content']}")
|
66 |
+
print(f" 用量: {response4['usage']}")
|
67 |
+
if response4['usage']['completion_tokens'] >= 20:
|
68 |
+
print(" 注意:由于 max_tokens,补全可能已被截断。")
|
69 |
+
|
70 |
+
|
71 |
+
except Exception as e:
|
72 |
+
print(f"示例用法期间发生错误: {e}")
|
73 |
+
import traceback
|
74 |
+
traceback.print_exc()
|
examples/simple_rss_parser.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# 添加项目根目录到Python路径
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
|
7 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
8 |
+
|
9 |
+
from src.podcast_transcribe.rss.podcast_rss_parser import parse_podcast_rss
|
10 |
+
|
11 |
+
if __name__ == '__main__':
|
12 |
+
# 使用示例:
|
13 |
+
lex_fridman_rss = "https://feeds.buzzsprout.com/2460059.rss"
|
14 |
+
print(f"正在解析 Lex Fridman Podcast RSS: {lex_fridman_rss}")
|
15 |
+
podcast_data = parse_podcast_rss(lex_fridman_rss)
|
16 |
+
|
17 |
+
if podcast_data:
|
18 |
+
print(f"Podcast Title: {podcast_data.title}")
|
19 |
+
print(f"Podcast Link: {podcast_data.link}")
|
20 |
+
print(f"Podcast Description: {podcast_data.description[:200] if podcast_data.description else 'N/A'}...")
|
21 |
+
print(f"Podcast Author: {podcast_data.author}")
|
22 |
+
print(f"Podcast Image URL: {podcast_data.image_url}")
|
23 |
+
print(f"Total episodes found: {len(podcast_data.episodes)}")
|
24 |
+
|
25 |
+
if podcast_data.episodes:
|
26 |
+
print("\n--- Sample Episode ---")
|
27 |
+
sample_episode = podcast_data.episodes[0]
|
28 |
+
print(f" 标题: {sample_episode.title}")
|
29 |
+
print(f" 发布日期: {sample_episode.published_date}")
|
30 |
+
print(f" 链接: {sample_episode.link}")
|
31 |
+
print(f" 音频 URL: {sample_episode.audio_url}")
|
32 |
+
print(f" GUID: {sample_episode.guid}")
|
33 |
+
print(f" 时长: {sample_episode.duration}")
|
34 |
+
print(f" 季: {sample_episode.season}")
|
35 |
+
print(f" 集数: {sample_episode.episode_number}")
|
36 |
+
print(f" 剧集类型: {sample_episode.episode_type}")
|
37 |
+
print(f" 摘要: {sample_episode.summary[:200] if sample_episode.summary else 'N/A'}...")
|
38 |
+
print(f" Shownotes (前 300 字符): {sample_episode.shownotes[:300] if sample_episode.shownotes else 'N/A'}...")
|
39 |
+
else:
|
40 |
+
print("解析播客 RSS feed 失败。")
|
examples/simple_speaker_identify.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 添加项目根目录到Python路径
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
import os
|
6 |
+
|
7 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
8 |
+
|
9 |
+
from src.podcast_transcribe.rss.podcast_rss_parser import parse_rss_xml_content
|
10 |
+
from src.podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult
|
11 |
+
from src.podcast_transcribe.summary.speaker_identify import SpeakerIdentifier
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
transcribe_result_dump_file = Path.joinpath(Path(__file__).parent, "output", "lex_ai_john_carmack_1.transcription.json")
|
15 |
+
podcast_rss_xml_file = Path.joinpath(Path(__file__).parent, "input", "lexfridman.com.rss.xml")
|
16 |
+
|
17 |
+
# Load the transcription result
|
18 |
+
if not os.path.exists(transcribe_result_dump_file):
|
19 |
+
print(f"错误:转录结果文件 '{transcribe_result_dump_file}' 不存在。请先运行 combined_transcription.py 生成结果。")
|
20 |
+
sys.exit(1)
|
21 |
+
|
22 |
+
with open(transcribe_result_dump_file, "r", encoding="utf-8") as f:
|
23 |
+
# transcription_result = json.load(f) # 旧代码
|
24 |
+
data = json.load(f)
|
25 |
+
segments_data = data.get("segments", [])
|
26 |
+
# 确保 segments_data 中的每个元素都是字典,以避免在 EnhancedSegment(**seg) 时出错
|
27 |
+
# 假设 EnhancedSegment 的字段与 JSON 中 segment 字典的键完全对应
|
28 |
+
enhanced_segments = []
|
29 |
+
for seg_dict in segments_data:
|
30 |
+
if isinstance(seg_dict, dict):
|
31 |
+
enhanced_segments.append(EnhancedSegment(**seg_dict))
|
32 |
+
else:
|
33 |
+
# 处理非字典类型 segment 的情况,例如记录日志或抛出错误
|
34 |
+
print(f"警告: 在JSON中发现非字典类型的segment: {seg_dict}")
|
35 |
+
|
36 |
+
transcription_result = CombinedTranscriptionResult(
|
37 |
+
segments=enhanced_segments,
|
38 |
+
text=data.get("text", ""),
|
39 |
+
language=data.get("language", ""),
|
40 |
+
num_speakers=data.get("num_speakers", 0)
|
41 |
+
)
|
42 |
+
|
43 |
+
# 打印加载的 CombinedTranscriptionResult 对象的一些信息以供验证
|
44 |
+
print(f"\\n成功从JSON加载 CombinedTranscriptionResult 对象:")
|
45 |
+
print(f"类型: {type(transcription_result)}")
|
46 |
+
|
47 |
+
# Load the podcast RSS XML file
|
48 |
+
with open(podcast_rss_xml_file, "r") as f:
|
49 |
+
podcast_rss_xml = f.read()
|
50 |
+
mock_podcast_info = parse_rss_xml_content(podcast_rss_xml)
|
51 |
+
|
52 |
+
|
53 |
+
# 查找标题已 "#309" 开头的剧集
|
54 |
+
mock_episode_info = next((episode for episode in mock_podcast_info.episodes if episode.title.startswith("#309")), None)
|
55 |
+
if not mock_episode_info:
|
56 |
+
raise ValueError("Could not find episode with title starting with '#309'")
|
57 |
+
|
58 |
+
|
59 |
+
speaker_identifier = SpeakerIdentifier(
|
60 |
+
llm_model_name="mlx-community/gemma-3-12b-it-4bit-DWQ",
|
61 |
+
llm_provider="gemma-mlx"
|
62 |
+
)
|
63 |
+
|
64 |
+
# 3. Call the function
|
65 |
+
print("\\n--- Test Case 1: Normal execution ---")
|
66 |
+
speaker_names = speaker_identifier.recognize_speaker_names(transcription_result.segments, mock_podcast_info, mock_episode_info)
|
67 |
+
print("\\nRecognized Speaker Names (Test Case 1):")
|
68 |
+
print(json.dumps(speaker_names, ensure_ascii=False, indent=2))
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pydub>=0.25.1
|
2 |
+
numpy>=2.2.5
|
3 |
+
pyannote.audio>=3.3.2
|
4 |
+
transformers>=4.51.3
|
5 |
+
torch>=2.7.0
|
6 |
+
torchaudio>=2.7.0
|
7 |
+
soundfile>=0.13.1
|
8 |
+
feedparser>=6.0.11
|
9 |
+
requests>=2.32.3
|
10 |
+
gradio>=5.30.0
|
11 |
+
|
12 |
+
# 可选依赖 - whisper.cpp 绑定
|
13 |
+
pywhispercpp>=1.3.0
|
14 |
+
bitsandbytes>=0.42.0
|
15 |
+
accelerate>=1.6.0
|
16 |
+
|
17 |
+
# MLX特定依赖 - 仅适用于Apple Silicon Mac
|
18 |
+
# mlx>=0.25.2
|
19 |
+
# mlx-lm>=0.24.0
|
20 |
+
# parakeet-mlx=0.2.6
|
21 |
+
# mlx-whisper=0.4.2
|
src/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/podcast_transcribe/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/podcast_transcribe/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
# 设置根日志级别为INFO,这样第三方包默认使用INFO级别
|
4 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
5 |
+
|
6 |
+
# 单独设置podcast_transcribe包的日志级别
|
7 |
+
package_logger = logging.getLogger("podcast_transcribe")
|
8 |
+
package_logger.setLevel(logging.INFO)
|
src/podcast_transcribe/asr/asr_base.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
语音识别模块基类
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from typing import Dict, List, Union, Optional, Tuple
|
9 |
+
# from dataclasses import dataclass # dataclass is now imported from schemas if needed or already there
|
10 |
+
import logging
|
11 |
+
|
12 |
+
from ..schemas import TranscriptionResult # Added import
|
13 |
+
|
14 |
+
# 配置日志
|
15 |
+
logger = logging.getLogger("asr")
|
16 |
+
|
17 |
+
|
18 |
+
class BaseTranscriber:
|
19 |
+
"""统一的语音识别基类,支持MLX和Transformers等多种框架"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
model_name: str,
|
24 |
+
device: str = None,
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
初始化转录器
|
28 |
+
|
29 |
+
参数:
|
30 |
+
model_name: 模型名称
|
31 |
+
device: 推理设备,'cpu'或'cuda',对于MLX框架此参数可忽略
|
32 |
+
"""
|
33 |
+
self.model_name = model_name
|
34 |
+
self.device = device
|
35 |
+
self.pipeline = None # 用于Transformers
|
36 |
+
self.model = None # 用于MLX等其他框架
|
37 |
+
|
38 |
+
logger.info(f"初始化转录器,模型: {model_name}" + (f",设备: {device}" if device else ""))
|
39 |
+
|
40 |
+
# 子类需要实现_load_model方法
|
41 |
+
self._load_model()
|
42 |
+
|
43 |
+
def _load_model(self):
|
44 |
+
"""
|
45 |
+
加载模型(需要在子类中实现)
|
46 |
+
"""
|
47 |
+
raise NotImplementedError("子类必须实现_load_model方法")
|
48 |
+
|
49 |
+
def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
|
50 |
+
"""
|
51 |
+
准备音频数据
|
52 |
+
|
53 |
+
参数:
|
54 |
+
audio: 输入的AudioSegment对象
|
55 |
+
|
56 |
+
返回:
|
57 |
+
处理后的AudioSegment对象
|
58 |
+
"""
|
59 |
+
logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
|
60 |
+
|
61 |
+
# 确保采样率为16kHz
|
62 |
+
if audio.frame_rate != 16000:
|
63 |
+
logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
|
64 |
+
audio = audio.set_frame_rate(16000)
|
65 |
+
|
66 |
+
# 确保是单声道
|
67 |
+
if audio.channels > 1:
|
68 |
+
logger.debug(f"将{audio.channels}声道音频转换为单声道")
|
69 |
+
audio = audio.set_channels(1)
|
70 |
+
|
71 |
+
logger.debug(f"音频处理完成")
|
72 |
+
|
73 |
+
return audio
|
74 |
+
|
75 |
+
def _detect_language(self, text: str) -> str:
|
76 |
+
"""
|
77 |
+
简单的语言检测(基于经验规则)
|
78 |
+
|
79 |
+
参数:
|
80 |
+
text: 识别出的文本
|
81 |
+
|
82 |
+
返回:
|
83 |
+
检测到的语言代码
|
84 |
+
"""
|
85 |
+
# 简单的规则检测,实际应用中应使用更准确的语言检测
|
86 |
+
chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
|
87 |
+
chinese_ratio = chinese_chars / len(text) if text else 0
|
88 |
+
logger.debug(f"语言检测: 中文字符比例 = {chinese_ratio:.2f}")
|
89 |
+
|
90 |
+
if chinese_chars > len(text) * 0.3:
|
91 |
+
return "zh"
|
92 |
+
return "en"
|
93 |
+
|
94 |
+
def _convert_segments(self, model_result) -> List[Dict[str, Union[float, str]]]:
|
95 |
+
"""
|
96 |
+
将模型的分段结果转换为所需格式(需要在子类中实现)
|
97 |
+
|
98 |
+
参数:
|
99 |
+
model_result: 模型返回的结果
|
100 |
+
|
101 |
+
返回:
|
102 |
+
转换后的分段列表
|
103 |
+
"""
|
104 |
+
raise NotImplementedError("子类必须实现_convert_segments方法")
|
105 |
+
|
106 |
+
def transcribe(self, audio: AudioSegment, chunk_duration_s: int = 30, overlap_s: int = 5) -> TranscriptionResult:
|
107 |
+
"""
|
108 |
+
转录音频,支持长音频分块处理。
|
109 |
+
|
110 |
+
参数:
|
111 |
+
audio: 要转录的AudioSegment对象
|
112 |
+
chunk_duration_s: 分块处理的块时长(秒)。如果音频短于此,则不分块。
|
113 |
+
overlap_s: 分块间的重叠时长(秒)。
|
114 |
+
|
115 |
+
返回:
|
116 |
+
TranscriptionResult对象,包含转录结果
|
117 |
+
"""
|
118 |
+
logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频。分块设置: 块时长={chunk_duration_s}s, 重叠={overlap_s}s")
|
119 |
+
|
120 |
+
if overlap_s >= chunk_duration_s and len(audio)/1000.0 > chunk_duration_s :
|
121 |
+
logger.error("重叠时长必须小于块时长。")
|
122 |
+
raise ValueError("overlap_s 必须小于 chunk_duration_s。")
|
123 |
+
|
124 |
+
total_duration_ms = len(audio)
|
125 |
+
chunk_duration_ms = chunk_duration_s * 1000
|
126 |
+
overlap_ms = overlap_s * 1000
|
127 |
+
|
128 |
+
if total_duration_ms <= chunk_duration_ms:
|
129 |
+
logger.debug("音频时长不大于设定块时长,直接进行完整转录。")
|
130 |
+
processed_audio = self._prepare_audio(audio)
|
131 |
+
samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
|
132 |
+
|
133 |
+
try:
|
134 |
+
model_result = self._perform_transcription(samples)
|
135 |
+
text = self._get_text_from_result(model_result)
|
136 |
+
segments = self._convert_segments(model_result)
|
137 |
+
language = self._detect_language(text)
|
138 |
+
|
139 |
+
logger.info(f"单块转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
|
140 |
+
return TranscriptionResult(text=text, segments=segments, language=language)
|
141 |
+
except Exception as e:
|
142 |
+
logger.error(f"单块转录失败: {str(e)}", exc_info=True)
|
143 |
+
raise RuntimeError(f"单块转录失败: {str(e)}")
|
144 |
+
|
145 |
+
# 长音频分块处理
|
146 |
+
final_segments = []
|
147 |
+
# current_pos_ms 指的是当前块要处理的"新内容"的起始点在原始音频中的位置
|
148 |
+
current_pos_ms = 0
|
149 |
+
|
150 |
+
while current_pos_ms < total_duration_ms:
|
151 |
+
# 计算当前块实际送入模型处理的音频的起始和结束时间点
|
152 |
+
# 对于第一个块,start_process_ms = 0
|
153 |
+
# 对于后续块,start_process_ms 会向左回退 overlap_ms 以包含重叠区域
|
154 |
+
start_process_ms = max(0, current_pos_ms - overlap_ms)
|
155 |
+
end_process_ms = min(start_process_ms + chunk_duration_ms, total_duration_ms)
|
156 |
+
|
157 |
+
# 如果计算出的块起始点已经等于或超过总时长,说明处理完毕
|
158 |
+
if start_process_ms >= total_duration_ms:
|
159 |
+
break
|
160 |
+
|
161 |
+
chunk_audio = audio[start_process_ms:end_process_ms]
|
162 |
+
|
163 |
+
logger.info(f"处理音频块: {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s (新内容起始于: {current_pos_ms/1000.0:.2f}s)")
|
164 |
+
|
165 |
+
if len(chunk_audio) == 0:
|
166 |
+
logger.warning(f"生成了一个空的音频块,跳过。起始: {start_process_ms/1000.0:.2f}s, 结束: {end_process_ms/1000.0:.2f}s")
|
167 |
+
# 必须推进 current_pos_ms 以避免死循环
|
168 |
+
advance_ms = chunk_duration_ms - overlap_ms
|
169 |
+
if advance_ms <= 0: # 应该在函数开始时已检查 overlap_s < chunk_duration_s
|
170 |
+
raise RuntimeError("块推进时长配置错误,可能导致死循环。")
|
171 |
+
current_pos_ms += advance_ms
|
172 |
+
continue
|
173 |
+
|
174 |
+
processed_chunk_audio = self._prepare_audio(chunk_audio)
|
175 |
+
samples = np.array(processed_chunk_audio.get_array_of_samples(), dtype=np.float32) / 32768.0
|
176 |
+
|
177 |
+
try:
|
178 |
+
model_result = self._perform_transcription(samples)
|
179 |
+
segments_chunk = self._convert_segments(model_result)
|
180 |
+
|
181 |
+
for seg in segments_chunk:
|
182 |
+
# seg["start"] 和 seg["end"] 是相对于当前块 (chunk_audio) 的起始点(即0)
|
183 |
+
# 计算 segment 在原始完整音频中的绝对起止时间
|
184 |
+
global_seg_start_s = start_process_ms / 1000.0 + seg["start"]
|
185 |
+
global_seg_end_s = start_process_ms / 1000.0 + seg["end"]
|
186 |
+
|
187 |
+
# 核心去重逻辑:
|
188 |
+
# 我们只接受那些真实开始于 current_pos_ms / 1000.0 之后的 segment。
|
189 |
+
# current_pos_ms 是当前块应该贡献的"新"内容的开始时间。
|
190 |
+
# 对于第一个块 (current_pos_ms == 0),所有 segment 都被接受(只要它们的 start >= 0)。
|
191 |
+
# 对于后续块,只有当 segment 的全局开始时间 >= 当前块新内容的开始时间时,才添加。
|
192 |
+
if global_seg_start_s >= current_pos_ms / 1000.0:
|
193 |
+
final_segments.append({
|
194 |
+
"start": global_seg_start_s,
|
195 |
+
"end": global_seg_end_s,
|
196 |
+
"text": seg["text"]
|
197 |
+
})
|
198 |
+
# 特殊处理第一个块,因为 current_pos_ms 为 0,上面的条件 global_seg_start_s >= 0 总是满足。
|
199 |
+
# 但为了更清晰,如果不是第一个块,但 segment 跨越了 current_pos_ms,
|
200 |
+
# 它的起始部分在重叠区,结束部分在非重叠区。
|
201 |
+
# 当前逻辑是,如果它的 global_seg_start_s < current_pos_ms / 1000.0,它就被丢弃。
|
202 |
+
# 这是为了确保不重复记录重叠区域的开头部分。
|
203 |
+
# 如果一个 segment 完全在重叠区内且在前一个块已被记录,此逻辑可避免重复。
|
204 |
+
|
205 |
+
except Exception as e:
|
206 |
+
logger.error(f"处理音频块 {start_process_ms/1000.0:.2f}s - {end_process_ms/1000.0:.2f}s 失败: {str(e)}", exc_info=True)
|
207 |
+
|
208 |
+
# 更新下一个"新内容"块的起始位置
|
209 |
+
advance_ms = chunk_duration_ms - overlap_ms
|
210 |
+
current_pos_ms += advance_ms
|
211 |
+
|
212 |
+
# 对收集到的所有 segments 按开始时间排序
|
213 |
+
final_segments.sort(key=lambda s: s["start"])
|
214 |
+
|
215 |
+
# 可选:进一步清理 segments,例如合并非常接近且文本连续的,或移除完全重复的
|
216 |
+
cleaned_segments = []
|
217 |
+
if final_segments:
|
218 |
+
cleaned_segments.append(final_segments[0])
|
219 |
+
for i in range(1, len(final_segments)):
|
220 |
+
prev_s = cleaned_segments[-1]
|
221 |
+
curr_s = final_segments[i]
|
222 |
+
# 简单的去重:如果时间戳和文本都几乎一样,则认为是重复
|
223 |
+
if abs(curr_s["start"] - prev_s["start"]) < 0.01 and \
|
224 |
+
abs(curr_s["end"] - prev_s["end"]) < 0.01 and \
|
225 |
+
curr_s["text"] == prev_s["text"]:
|
226 |
+
continue
|
227 |
+
|
228 |
+
# 如果当前 segment 的开始时间在前一个 segment 的结束时间之前,
|
229 |
+
# 并且文本有明显重叠,可能需要更智能的合并。
|
230 |
+
# 目前的逻辑通过 global_seg_start_s >= current_pos_ms / 1000.0 过滤,
|
231 |
+
# 已经大大减少了直接的 segment 重复。
|
232 |
+
# 此处的清理更多是处理模型在边界可能产生的一些微小偏差。
|
233 |
+
# 如果上一个segment的结束时间比当前segment的开始时间还要晚,说明有重叠,
|
234 |
+
# 且上一个segment包含了当前segment的开始部分。
|
235 |
+
# 这种情况下,可以考虑调整上一个的结束,或当前segment的开始和文本。
|
236 |
+
# 为简单起见,暂时直接添加,相信之前的过滤已处理主要重叠。
|
237 |
+
if curr_s["start"] < prev_s["end"] and prev_s["text"].endswith(curr_s["text"][:len(prev_s["text"]) - int((prev_s["end"] - curr_s["start"])*10) ]): # 粗略检查
|
238 |
+
# 如果curr_s的开始部分被prev_s覆盖,并且文本也对应,则调整curr_s
|
239 |
+
# pass # 暂时不处理这种细微重叠,依赖模型切分
|
240 |
+
cleaned_segments.append(curr_s) # 仍添加,依赖后续文本拼接
|
241 |
+
else:
|
242 |
+
cleaned_segments.append(curr_s)
|
243 |
+
|
244 |
+
final_text = " ".join([s["text"] for s in cleaned_segments]).strip()
|
245 |
+
language = self._detect_language(final_text)
|
246 |
+
|
247 |
+
logger.info(f"分块转录完成。最终文本长度: {len(final_text)}, 分段数: {len(cleaned_segments)}")
|
248 |
+
|
249 |
+
return TranscriptionResult(
|
250 |
+
text=final_text,
|
251 |
+
segments=cleaned_segments,
|
252 |
+
language=language
|
253 |
+
)
|
254 |
+
|
255 |
+
def _perform_transcription(self, audio_data):
|
256 |
+
"""
|
257 |
+
执行转录(需要在子类中实现)
|
258 |
+
|
259 |
+
参数:
|
260 |
+
audio_data: 音频数据(numpy数组)
|
261 |
+
|
262 |
+
返回:
|
263 |
+
模型的转录结果
|
264 |
+
"""
|
265 |
+
raise NotImplementedError("子类必须实现_perform_transcription方法")
|
266 |
+
|
267 |
+
def _get_text_from_result(self, result):
|
268 |
+
"""
|
269 |
+
从结果中获取文本(需要在子类中实现)
|
270 |
+
|
271 |
+
参数:
|
272 |
+
result: 模型的转录结果
|
273 |
+
|
274 |
+
返回:
|
275 |
+
转录的文本
|
276 |
+
"""
|
277 |
+
raise NotImplementedError("子类必须实现_get_text_from_result方法")
|
src/podcast_transcribe/asr/asr_distil_whisper_mlx.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
基于MLX实现的语音识别模块,使用distil-whisper-large-v3模型
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pydub import AudioSegment
|
7 |
+
from typing import Dict, List, Union
|
8 |
+
import logging
|
9 |
+
import numpy as np
|
10 |
+
import mlx_whisper
|
11 |
+
|
12 |
+
# 导入基类
|
13 |
+
from .asr_base import BaseTranscriber, TranscriptionResult
|
14 |
+
|
15 |
+
# 配置日志
|
16 |
+
logger = logging.getLogger("asr")
|
17 |
+
|
18 |
+
|
19 |
+
class MLXDistilWhisperTranscriber(BaseTranscriber):
|
20 |
+
"""使用MLX加载和运行distil-whisper-large-v3模型的转录器"""
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
model_name: str = "mlx-community/distil-whisper-large-v3",
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
初始化转录器
|
28 |
+
|
29 |
+
参数:
|
30 |
+
model_name: 模型名称
|
31 |
+
"""
|
32 |
+
super().__init__(model_name=model_name)
|
33 |
+
|
34 |
+
def _load_model(self):
|
35 |
+
"""加载Distil Whisper模型"""
|
36 |
+
try:
|
37 |
+
# 懒加载mlx-whisper
|
38 |
+
try:
|
39 |
+
import mlx_whisper
|
40 |
+
except ImportError:
|
41 |
+
raise ImportError("请先安装mlx-whisper库: pip install mlx-whisper")
|
42 |
+
|
43 |
+
logger.info(f"开始加载模型 {self.model_name}")
|
44 |
+
self.model = mlx_whisper.load_models.load_model(self.model_name)
|
45 |
+
logger.info(f"模型加载成功")
|
46 |
+
except Exception as e:
|
47 |
+
logger.error(f"加载模型失败: {str(e)}", exc_info=True)
|
48 |
+
raise RuntimeError(f"加载模型失败: {str(e)}")
|
49 |
+
|
50 |
+
def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
|
51 |
+
"""
|
52 |
+
将模型的分段结果转换为所需格式
|
53 |
+
|
54 |
+
参数:
|
55 |
+
result: 模型返回的结果
|
56 |
+
|
57 |
+
返回:
|
58 |
+
转换后的分段列表
|
59 |
+
"""
|
60 |
+
segments = []
|
61 |
+
|
62 |
+
for segment in result.get("segments", []):
|
63 |
+
segments.append({
|
64 |
+
"start": segment.get("start", 0.0),
|
65 |
+
"end": segment.get("end", 0.0),
|
66 |
+
"text": segment.get("text", "").strip()
|
67 |
+
})
|
68 |
+
|
69 |
+
return segments
|
70 |
+
|
71 |
+
def _perform_transcription(self, audio_data):
|
72 |
+
"""
|
73 |
+
执行转录
|
74 |
+
|
75 |
+
参数:
|
76 |
+
audio_data: 音频数据(numpy数组)
|
77 |
+
|
78 |
+
返回:
|
79 |
+
模型的转录结果
|
80 |
+
"""
|
81 |
+
return mlx_whisper.transcribe(audio_data, path_or_hf_repo=self.model_name)
|
82 |
+
|
83 |
+
def _get_text_from_result(self, result):
|
84 |
+
"""
|
85 |
+
从结果中获取文本
|
86 |
+
|
87 |
+
参数:
|
88 |
+
result: 模型的转录结果
|
89 |
+
|
90 |
+
返回:
|
91 |
+
转录的文本
|
92 |
+
"""
|
93 |
+
return result.get("text", "")
|
94 |
+
|
95 |
+
|
96 |
+
def transcribe_audio(
|
97 |
+
audio_segment: AudioSegment,
|
98 |
+
model_name: str = "mlx-community/distil-whisper-large-v3",
|
99 |
+
) -> TranscriptionResult:
|
100 |
+
"""
|
101 |
+
使用MLX和distil-whisper-large-v3模型转录音频
|
102 |
+
|
103 |
+
参数:
|
104 |
+
audio_segment: 输入的AudioSegment对象
|
105 |
+
model_name: 使用的模型名称
|
106 |
+
|
107 |
+
返回:
|
108 |
+
TranscriptionResult对象,包含转录的文本、分段和语言
|
109 |
+
"""
|
110 |
+
logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
|
111 |
+
transcriber = MLXDistilWhisperTranscriber(model_name=model_name)
|
112 |
+
return transcriber.transcribe(audio_segment)
|
src/podcast_transcribe/asr/asr_distil_whisper_transformers.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
基于Transformers实现的语音识别模块,使用distil-whisper-large-v3.5模型
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pydub import AudioSegment
|
7 |
+
from typing import Dict, List, Union
|
8 |
+
import logging
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
# 导入基类
|
12 |
+
from .asr_base import BaseTranscriber, TranscriptionResult
|
13 |
+
|
14 |
+
# 配置日志
|
15 |
+
logger = logging.getLogger("asr")
|
16 |
+
|
17 |
+
|
18 |
+
class TransformersDistilWhisperTranscriber(BaseTranscriber):
|
19 |
+
"""使用Transformers加载和运行distil-whisper-large-v3.5模型的转录器"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
model_name: str = "distil-whisper/distil-large-v3.5",
|
24 |
+
device: str = "cpu",
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
初始化转录器
|
28 |
+
|
29 |
+
参数:
|
30 |
+
model_name: 模型名称
|
31 |
+
device: 推理设备,'cpu'或'cuda'
|
32 |
+
"""
|
33 |
+
super().__init__(model_name=model_name, device=device)
|
34 |
+
|
35 |
+
def _load_model(self):
|
36 |
+
"""加载Distil Whisper模型"""
|
37 |
+
try:
|
38 |
+
# 懒加载transformers
|
39 |
+
try:
|
40 |
+
from transformers import pipeline
|
41 |
+
except ImportError:
|
42 |
+
raise ImportError("请先安装transformers库: pip install transformers")
|
43 |
+
|
44 |
+
logger.info(f"开始加载模型 {self.model_name}")
|
45 |
+
self.pipeline = pipeline(
|
46 |
+
"automatic-speech-recognition",
|
47 |
+
model=self.model_name,
|
48 |
+
device=0 if self.device == "cuda" else -1,
|
49 |
+
return_timestamps=True,
|
50 |
+
chunk_length_s=30,
|
51 |
+
batch_size=16,
|
52 |
+
)
|
53 |
+
logger.info(f"模型加载成功")
|
54 |
+
except Exception as e:
|
55 |
+
logger.error(f"加载模型失败: {str(e)}", exc_info=True)
|
56 |
+
raise RuntimeError(f"加载模型失败: {str(e)}")
|
57 |
+
|
58 |
+
def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
|
59 |
+
"""
|
60 |
+
将模型的分段结果转换为所需格式
|
61 |
+
|
62 |
+
参数:
|
63 |
+
result: 模型返回的结果
|
64 |
+
|
65 |
+
返回:
|
66 |
+
转换后的分段列表
|
67 |
+
"""
|
68 |
+
segments = []
|
69 |
+
|
70 |
+
# transformers pipeline 的结果格式
|
71 |
+
if "chunks" in result:
|
72 |
+
for chunk in result["chunks"]:
|
73 |
+
segments.append({
|
74 |
+
"start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
|
75 |
+
"end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
|
76 |
+
"text": chunk["text"].strip()
|
77 |
+
})
|
78 |
+
else:
|
79 |
+
# 如果没有分段信息,创建一个单一分段
|
80 |
+
segments.append({
|
81 |
+
"start": 0.0,
|
82 |
+
"end": 0.0, # 无法确定结束时间
|
83 |
+
"text": result.get("text", "").strip()
|
84 |
+
})
|
85 |
+
|
86 |
+
return segments
|
87 |
+
|
88 |
+
def _perform_transcription(self, audio_data):
|
89 |
+
"""
|
90 |
+
执行转录
|
91 |
+
|
92 |
+
参数:
|
93 |
+
audio_data: 音频数据(numpy数组)
|
94 |
+
|
95 |
+
返回:
|
96 |
+
模型的转录结果
|
97 |
+
"""
|
98 |
+
# transformers pipeline 接受numpy数组作为输入
|
99 |
+
# 音频数据已经在_prepare_audio中确保是16kHz采样率
|
100 |
+
return self.pipeline(audio_data)
|
101 |
+
|
102 |
+
def _get_text_from_result(self, result):
|
103 |
+
"""
|
104 |
+
从结果中获取文本
|
105 |
+
|
106 |
+
参数:
|
107 |
+
result: 模型的转录结果
|
108 |
+
|
109 |
+
返回:
|
110 |
+
转录的文本
|
111 |
+
"""
|
112 |
+
return result.get("text", "")
|
113 |
+
|
114 |
+
|
115 |
+
def transcribe_audio(
|
116 |
+
audio_segment: AudioSegment,
|
117 |
+
model_name: str = "distil-whisper/distil-large-v3.5",
|
118 |
+
device: str = "cpu",
|
119 |
+
) -> TranscriptionResult:
|
120 |
+
"""
|
121 |
+
使用Transformers和distil-whisper-large-v3.5模型转录音频
|
122 |
+
|
123 |
+
参数:
|
124 |
+
audio_segment: 输入的AudioSegment对象
|
125 |
+
model_name: 使用的模型名称
|
126 |
+
device: 推理设备,'cpu'或'cuda'
|
127 |
+
|
128 |
+
返回:
|
129 |
+
TranscriptionResult对象,包含转录的文本、分段和语言
|
130 |
+
"""
|
131 |
+
logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
|
132 |
+
transcriber = TransformersDistilWhisperTranscriber(model_name=model_name, device=device)
|
133 |
+
return transcriber.transcribe(audio_segment)
|
src/podcast_transcribe/asr/asr_parakeet_mlx.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
基于MLX实现的语音识别模块,使用parakeet-tdt模型
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pydub import AudioSegment
|
7 |
+
from typing import Dict, List, Union
|
8 |
+
import logging
|
9 |
+
import tempfile
|
10 |
+
import numpy as np
|
11 |
+
import soundfile as sf
|
12 |
+
|
13 |
+
# 导入基类
|
14 |
+
from .asr_base import BaseTranscriber, TranscriptionResult
|
15 |
+
|
16 |
+
# 配置日志
|
17 |
+
logger = logging.getLogger("asr")
|
18 |
+
|
19 |
+
|
20 |
+
class MLXParakeetTranscriber(BaseTranscriber):
|
21 |
+
"""使用MLX加载和运行parakeet-tdt-0.6b-v2模型的转录器"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
|
26 |
+
):
|
27 |
+
"""
|
28 |
+
初始化转录器
|
29 |
+
|
30 |
+
参数:
|
31 |
+
model_name: 模型名称
|
32 |
+
"""
|
33 |
+
super().__init__(model_name=model_name)
|
34 |
+
|
35 |
+
def _load_model(self):
|
36 |
+
"""加载Parakeet模型"""
|
37 |
+
try:
|
38 |
+
# 懒加载parakeet_mlx
|
39 |
+
try:
|
40 |
+
from parakeet_mlx import from_pretrained
|
41 |
+
except ImportError:
|
42 |
+
raise ImportError("请先安装parakeet-mlx库: pip install parakeet-mlx")
|
43 |
+
|
44 |
+
logger.info(f"开始加载模型 {self.model_name}")
|
45 |
+
self.model = from_pretrained(self.model_name)
|
46 |
+
logger.info(f"模型加载成功")
|
47 |
+
except Exception as e:
|
48 |
+
logger.error(f"加载模型失败: {str(e)}", exc_info=True)
|
49 |
+
raise RuntimeError(f"加载模型失败: {str(e)}")
|
50 |
+
|
51 |
+
def _convert_segments(self, aligned_result) -> List[Dict[str, Union[float, str]]]:
|
52 |
+
"""
|
53 |
+
将模型的分段结果转换为所需格式
|
54 |
+
|
55 |
+
参数:
|
56 |
+
aligned_result: 模型返回的分段结果
|
57 |
+
|
58 |
+
返回:
|
59 |
+
转换后的分段列表
|
60 |
+
"""
|
61 |
+
segments = []
|
62 |
+
|
63 |
+
for sentence in aligned_result.sentences:
|
64 |
+
segments.append({
|
65 |
+
"start": sentence.start,
|
66 |
+
"end": sentence.end,
|
67 |
+
"text": sentence.text
|
68 |
+
})
|
69 |
+
|
70 |
+
return segments
|
71 |
+
|
72 |
+
def _perform_transcription(self, audio_data):
|
73 |
+
"""
|
74 |
+
执行转录
|
75 |
+
|
76 |
+
参数:
|
77 |
+
audio_data: 音频数据(numpy数组)
|
78 |
+
|
79 |
+
返回:
|
80 |
+
模型的转录结果
|
81 |
+
"""
|
82 |
+
# 由于parakeet-mlx可能不直接支持numpy数组输入
|
83 |
+
# 创建临时文件并写入音频数据
|
84 |
+
with tempfile.NamedTemporaryFile(suffix='.wav', delete=True) as temp_file:
|
85 |
+
# 确保数据在[-1, 1]范围内
|
86 |
+
if audio_data.max() > 1.0 or audio_data.min() < -1.0:
|
87 |
+
audio_data = np.clip(audio_data, -1.0, 1.0)
|
88 |
+
|
89 |
+
# 写入临时文件
|
90 |
+
sf.write(temp_file.name, audio_data, 16000, 'PCM_16')
|
91 |
+
|
92 |
+
# 使用临时文件进行转录
|
93 |
+
result = self.model.transcribe(temp_file.name)
|
94 |
+
|
95 |
+
return result
|
96 |
+
|
97 |
+
def _get_text_from_result(self, result):
|
98 |
+
"""
|
99 |
+
从结果中获取文本
|
100 |
+
|
101 |
+
参数:
|
102 |
+
result: 模型的转录结果
|
103 |
+
|
104 |
+
返回:
|
105 |
+
转录的文本
|
106 |
+
"""
|
107 |
+
return result.text
|
108 |
+
|
109 |
+
|
110 |
+
def transcribe_audio(
|
111 |
+
audio_segment: AudioSegment,
|
112 |
+
model_name: str = "mlx-community/parakeet-tdt-0.6b-v2",
|
113 |
+
) -> TranscriptionResult:
|
114 |
+
"""
|
115 |
+
使用MLX和parakeet-tdt模型转录音频
|
116 |
+
|
117 |
+
参数:
|
118 |
+
audio_segment: 输入的AudioSegment对象
|
119 |
+
model_name: 使用的模型名称
|
120 |
+
|
121 |
+
返回:
|
122 |
+
TranscriptionResult对象,包含转录的文本、分段和语言
|
123 |
+
"""
|
124 |
+
logger.info(f"调用transcribe_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
|
125 |
+
transcriber = MLXParakeetTranscriber(model_name=model_name)
|
126 |
+
return transcriber.transcribe(audio_segment)
|
src/podcast_transcribe/asr/asr_router.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ASR模型调用路由器
|
3 |
+
根据传递的provider参数调用不同的ASR实现,支持延迟加载
|
4 |
+
"""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from typing import Dict, Any, Optional, Callable
|
8 |
+
from pydub import AudioSegment
|
9 |
+
from .asr_base import TranscriptionResult
|
10 |
+
from . import asr_parakeet_mlx
|
11 |
+
from . import asr_distil_whisper_mlx
|
12 |
+
from . import asr_distil_whisper_transformers
|
13 |
+
|
14 |
+
# 配置日志
|
15 |
+
logger = logging.getLogger("asr")
|
16 |
+
|
17 |
+
|
18 |
+
class ASRRouter:
|
19 |
+
"""ASR模型调用路由器,支持多种ASR实现的统一调用"""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
"""初始化路由器"""
|
23 |
+
self._loaded_modules = {} # 用于缓存已加载的模块
|
24 |
+
self._transcribers = {} # 用于缓存已实例化的转录器
|
25 |
+
|
26 |
+
# 定义支持的provider配置
|
27 |
+
self._provider_configs = {
|
28 |
+
"parakeet_mlx": {
|
29 |
+
"module_path": "asr_parakeet_mlx",
|
30 |
+
"function_name": "transcribe_audio",
|
31 |
+
"default_model": "mlx-community/parakeet-tdt-0.6b-v2",
|
32 |
+
"supported_params": ["model_name"],
|
33 |
+
"description": "基于MLX的Parakeet模型"
|
34 |
+
},
|
35 |
+
"distil_whisper_mlx": {
|
36 |
+
"module_path": "asr_distil_whisper_mlx",
|
37 |
+
"function_name": "transcribe_audio",
|
38 |
+
"default_model": "mlx-community/distil-whisper-large-v3",
|
39 |
+
"supported_params": ["model_name"],
|
40 |
+
"description": "基于MLX的Distil Whisper模型"
|
41 |
+
},
|
42 |
+
"distil_whisper_transformers": {
|
43 |
+
"module_path": "asr_distil_whisper_transformers",
|
44 |
+
"function_name": "transcribe_audio",
|
45 |
+
"default_model": "distil-whisper/distil-large-v3.5",
|
46 |
+
"supported_params": ["model_name", "device"],
|
47 |
+
"description": "基于Transformers的Distil Whisper模型"
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
def _lazy_load_module(self, provider: str):
|
52 |
+
"""
|
53 |
+
获取指定provider的模块
|
54 |
+
|
55 |
+
参数:
|
56 |
+
provider: provider名称
|
57 |
+
|
58 |
+
返回:
|
59 |
+
对应的模块
|
60 |
+
"""
|
61 |
+
if provider not in self._provider_configs:
|
62 |
+
raise ValueError(f"不支持的provider: {provider}")
|
63 |
+
|
64 |
+
if provider not in self._loaded_modules:
|
65 |
+
module_path = self._provider_configs[provider]["module_path"]
|
66 |
+
logger.info(f"获取模块: {module_path}")
|
67 |
+
|
68 |
+
# 根据module_path返回对应的模块
|
69 |
+
if module_path == "asr_parakeet_mlx":
|
70 |
+
module = asr_parakeet_mlx
|
71 |
+
elif module_path == "asr_distil_whisper_mlx":
|
72 |
+
module = asr_distil_whisper_mlx
|
73 |
+
elif module_path == "asr_distil_whisper_transformers":
|
74 |
+
module = asr_distil_whisper_transformers
|
75 |
+
else:
|
76 |
+
raise ImportError(f"未找到模块: {module_path}")
|
77 |
+
|
78 |
+
self._loaded_modules[provider] = module
|
79 |
+
logger.info(f"模块 {module_path} 获取成功")
|
80 |
+
|
81 |
+
return self._loaded_modules[provider]
|
82 |
+
|
83 |
+
def _get_transcribe_function(self, provider: str) -> Callable:
|
84 |
+
"""
|
85 |
+
获取指定provider的转录函数
|
86 |
+
|
87 |
+
参数:
|
88 |
+
provider: provider名称
|
89 |
+
|
90 |
+
返回:
|
91 |
+
转录函数
|
92 |
+
"""
|
93 |
+
module = self._lazy_load_module(provider)
|
94 |
+
function_name = self._provider_configs[provider]["function_name"]
|
95 |
+
|
96 |
+
if not hasattr(module, function_name):
|
97 |
+
raise AttributeError(f"模块中未找到函数: {function_name}")
|
98 |
+
|
99 |
+
return getattr(module, function_name)
|
100 |
+
|
101 |
+
def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
102 |
+
"""
|
103 |
+
过滤参数,只保留指定provider支持的参数
|
104 |
+
|
105 |
+
参数:
|
106 |
+
provider: provider名称
|
107 |
+
params: 原始参数字典
|
108 |
+
|
109 |
+
返回:
|
110 |
+
过滤后的参数字典
|
111 |
+
"""
|
112 |
+
supported_params = self._provider_configs[provider]["supported_params"]
|
113 |
+
filtered_params = {}
|
114 |
+
|
115 |
+
for param in supported_params:
|
116 |
+
if param in params:
|
117 |
+
filtered_params[param] = params[param]
|
118 |
+
|
119 |
+
# 如果没有指定model_name,使用默认模型
|
120 |
+
if "model_name" not in filtered_params and "model_name" in supported_params:
|
121 |
+
filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
|
122 |
+
|
123 |
+
return filtered_params
|
124 |
+
|
125 |
+
def transcribe(
|
126 |
+
self,
|
127 |
+
audio_segment: AudioSegment,
|
128 |
+
provider: str,
|
129 |
+
**kwargs
|
130 |
+
) -> TranscriptionResult:
|
131 |
+
"""
|
132 |
+
统一的音频转录接口
|
133 |
+
|
134 |
+
参数:
|
135 |
+
audio_segment: 输入的AudioSegment对象
|
136 |
+
provider: ASR提供者名称
|
137 |
+
**kwargs: 其他参数,如model_name, device等
|
138 |
+
|
139 |
+
返回:
|
140 |
+
TranscriptionResult��象
|
141 |
+
"""
|
142 |
+
logger.info(f"使用provider '{provider}' 进行音频转录,音频长度: {len(audio_segment)/1000:.2f}秒")
|
143 |
+
|
144 |
+
if provider not in self._provider_configs:
|
145 |
+
available_providers = list(self._provider_configs.keys())
|
146 |
+
raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
|
147 |
+
|
148 |
+
try:
|
149 |
+
# 获取转录函数
|
150 |
+
transcribe_func = self._get_transcribe_function(provider)
|
151 |
+
|
152 |
+
# 过滤并准备参数
|
153 |
+
filtered_kwargs = self._filter_params(provider, kwargs)
|
154 |
+
|
155 |
+
logger.debug(f"调用 {provider} 转录函数,参数: {filtered_kwargs}")
|
156 |
+
|
157 |
+
# 执行转录
|
158 |
+
result = transcribe_func(audio_segment, **filtered_kwargs)
|
159 |
+
|
160 |
+
logger.info(f"转录完成,文本长度: {len(result.text)}字符")
|
161 |
+
return result
|
162 |
+
|
163 |
+
except Exception as e:
|
164 |
+
logger.error(f"使用provider '{provider}' 转录音频失败: {str(e)}", exc_info=True)
|
165 |
+
raise RuntimeError(f"转录失败: {str(e)}")
|
166 |
+
|
167 |
+
def get_available_providers(self) -> Dict[str, str]:
|
168 |
+
"""
|
169 |
+
获取所有可用的provider及其描述
|
170 |
+
|
171 |
+
返回:
|
172 |
+
provider名称到描述的映射
|
173 |
+
"""
|
174 |
+
return {
|
175 |
+
provider: config["description"]
|
176 |
+
for provider, config in self._provider_configs.items()
|
177 |
+
}
|
178 |
+
|
179 |
+
def get_provider_info(self, provider: str) -> Dict[str, Any]:
|
180 |
+
"""
|
181 |
+
获取指定provider的详细信息
|
182 |
+
|
183 |
+
参数:
|
184 |
+
provider: provider名称
|
185 |
+
|
186 |
+
返回:
|
187 |
+
provider的配置信息
|
188 |
+
"""
|
189 |
+
if provider not in self._provider_configs:
|
190 |
+
raise ValueError(f"不支持的provider: {provider}")
|
191 |
+
|
192 |
+
return self._provider_configs[provider].copy()
|
193 |
+
|
194 |
+
|
195 |
+
# 创建全局路由器实例
|
196 |
+
_router = ASRRouter()
|
197 |
+
|
198 |
+
|
199 |
+
def transcribe_audio(
|
200 |
+
audio_segment: AudioSegment,
|
201 |
+
provider: str = "distil_whisper_transformers",
|
202 |
+
model_name: Optional[str] = None,
|
203 |
+
device: str = "cpu",
|
204 |
+
**kwargs
|
205 |
+
) -> TranscriptionResult:
|
206 |
+
"""
|
207 |
+
统一的音频转录接口函数
|
208 |
+
|
209 |
+
参数:
|
210 |
+
audio_segment: 输入的AudioSegment对象
|
211 |
+
provider: ASR提供者,可选值:
|
212 |
+
- "parakeet_mlx": 基于MLX的Parakeet模型
|
213 |
+
- "distil_whisper_mlx": 基于MLX的Distil Whisper模型
|
214 |
+
- "distil_whisper_transformers": 基于Transformers的Distil Whisper模型
|
215 |
+
model_name: 模型名称,如果不指定则使用默认模型
|
216 |
+
device: 推理设备,仅对transformers provider有效
|
217 |
+
**kwargs: 其他参数
|
218 |
+
|
219 |
+
返回:
|
220 |
+
TranscriptionResult对象,包含转录的文本、分段和语言
|
221 |
+
|
222 |
+
示例:
|
223 |
+
# 使用默认MLX Distil Whisper模型
|
224 |
+
result = transcribe_audio(audio_segment, provider="distil_whisper_mlx")
|
225 |
+
|
226 |
+
# 使用Parakeet模型
|
227 |
+
result = transcribe_audio(audio_segment, provider="parakeet_mlx")
|
228 |
+
|
229 |
+
# 使用Transformers模型并指定设备
|
230 |
+
result = transcribe_audio(
|
231 |
+
audio_segment,
|
232 |
+
provider="distil_whisper_transformers",
|
233 |
+
device="cuda"
|
234 |
+
)
|
235 |
+
|
236 |
+
# 使用自定义模型
|
237 |
+
result = transcribe_audio(
|
238 |
+
audio_segment,
|
239 |
+
provider="distil_whisper_mlx",
|
240 |
+
model_name="mlx-community/whisper-large-v3"
|
241 |
+
)
|
242 |
+
"""
|
243 |
+
# 准备参数
|
244 |
+
params = kwargs.copy()
|
245 |
+
if model_name is not None:
|
246 |
+
params["model_name"] = model_name
|
247 |
+
if device != "cpu":
|
248 |
+
params["device"] = device
|
249 |
+
|
250 |
+
return _router.transcribe(audio_segment, provider, **params)
|
251 |
+
|
252 |
+
|
253 |
+
def get_available_providers() -> Dict[str, str]:
|
254 |
+
"""
|
255 |
+
获取所有可用的ASR提供者
|
256 |
+
|
257 |
+
返回:
|
258 |
+
provider名称到描述的映射
|
259 |
+
"""
|
260 |
+
return _router.get_available_providers()
|
261 |
+
|
262 |
+
|
263 |
+
def get_provider_info(provider: str) -> Dict[str, Any]:
|
264 |
+
"""
|
265 |
+
获取指定provider的详细信息
|
266 |
+
|
267 |
+
参数:
|
268 |
+
provider: provider名称
|
269 |
+
|
270 |
+
返回:
|
271 |
+
provider的配置信息
|
272 |
+
"""
|
273 |
+
return _router.get_provider_info(provider)
|
src/podcast_transcribe/audio.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
音频处理工具模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from io import BytesIO
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from typing import Tuple, Dict, Any
|
9 |
+
|
10 |
+
|
11 |
+
def load_audio(audio_file: str, target_sample_rate: int = 16000, mono: bool = True) -> Tuple[AudioSegment, np.ndarray]:
|
12 |
+
"""
|
13 |
+
加载音频文件并转换为目标采样率和通道数
|
14 |
+
|
15 |
+
参数:
|
16 |
+
audio_file: 音频文件路径
|
17 |
+
target_sample_rate: 目标采样率,默认16kHz
|
18 |
+
mono: 是否转换为单声道,默认True
|
19 |
+
|
20 |
+
返回:
|
21 |
+
AudioSegment对象和对应的numpy数组
|
22 |
+
"""
|
23 |
+
try:
|
24 |
+
audio = AudioSegment.from_file(audio_file)
|
25 |
+
|
26 |
+
# 转换为单声道(如果需要)
|
27 |
+
if mono and audio.channels > 1:
|
28 |
+
audio = audio.set_channels(1)
|
29 |
+
|
30 |
+
# 转换采样率
|
31 |
+
if audio.frame_rate != target_sample_rate:
|
32 |
+
audio = audio.set_frame_rate(target_sample_rate)
|
33 |
+
|
34 |
+
# 获取音频波形(用于pyannote)
|
35 |
+
waveform = np.array(audio.get_array_of_samples()).astype(np.float32) / 32768.0
|
36 |
+
|
37 |
+
return audio, waveform
|
38 |
+
|
39 |
+
except Exception as e:
|
40 |
+
raise RuntimeError(f"无法加载音频文件: {str(e)}")
|
41 |
+
|
42 |
+
|
43 |
+
def extract_audio_segment(audio: AudioSegment, start_ms: int, end_ms: int) -> BytesIO:
|
44 |
+
"""
|
45 |
+
从音频中提取指定时间段
|
46 |
+
|
47 |
+
参数:
|
48 |
+
audio: AudioSegment对象
|
49 |
+
start_ms: 开始时间(毫秒)
|
50 |
+
end_ms: 结束时间(毫秒)
|
51 |
+
|
52 |
+
返回:
|
53 |
+
包含音频段的BytesIO对象
|
54 |
+
"""
|
55 |
+
try:
|
56 |
+
sub_audio = audio[start_ms:end_ms]
|
57 |
+
fp = BytesIO()
|
58 |
+
sub_audio.export(fp, format="wav")
|
59 |
+
fp.seek(0)
|
60 |
+
return fp
|
61 |
+
except Exception as e:
|
62 |
+
raise RuntimeError(f"无法提取音频段: {str(e)}")
|
src/podcast_transcribe/diarization/diarization_pyannote_mlx.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
基于pyannote/speaker-diarization-3.1模型实现的说话人分离模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from typing import Any, Dict, List, Mapping, Text, Union, Optional, Tuple
|
9 |
+
import logging
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .diarizer_base import BaseDiarizer
|
13 |
+
from ..schemas import DiarizationResult
|
14 |
+
|
15 |
+
# 配置日志
|
16 |
+
logger = logging.getLogger("diarization")
|
17 |
+
|
18 |
+
class PyannoteTranscriber(BaseDiarizer):
|
19 |
+
"""使用pyannote/speaker-diarization-3.1模型进行说话人分离"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
model_name: str = "pyannote/speaker-diarization-3.1",
|
24 |
+
token: Optional[str] = None,
|
25 |
+
device: str = "cpu",
|
26 |
+
segmentation_batch_size: int = 32,
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
初始化说话人分离器
|
30 |
+
|
31 |
+
参数:
|
32 |
+
model_name: 模型名称
|
33 |
+
token: Hugging Face令牌,用于访问模型
|
34 |
+
device: 推理设备,'cpu'或'cuda'
|
35 |
+
segmentation_batch_size: 分割批处理大小,默认为32
|
36 |
+
"""
|
37 |
+
super().__init__(model_name, token, device, segmentation_batch_size)
|
38 |
+
|
39 |
+
# 加载模型
|
40 |
+
self._load_model()
|
41 |
+
|
42 |
+
def _load_model(self):
|
43 |
+
"""加载pyannote模型"""
|
44 |
+
try:
|
45 |
+
# 懒加载pyannote.audio
|
46 |
+
try:
|
47 |
+
from pyannote.audio import Pipeline
|
48 |
+
except ImportError:
|
49 |
+
raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
|
50 |
+
|
51 |
+
if not self.token:
|
52 |
+
raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
|
53 |
+
|
54 |
+
logger.info(f"开始加载模型 {self.model_name}")
|
55 |
+
self.pipeline = Pipeline.from_pretrained(
|
56 |
+
self.model_name,
|
57 |
+
use_auth_token=self.token
|
58 |
+
)
|
59 |
+
|
60 |
+
# 设置设备
|
61 |
+
self.pipeline.to(torch.device(self.device))
|
62 |
+
|
63 |
+
# 设置分割批处理大小
|
64 |
+
if hasattr(self.pipeline, "segmentation_batch_size"):
|
65 |
+
logger.info(f"设置分割批处理大小: {self.segmentation_batch_size}")
|
66 |
+
self.pipeline.segmentation_batch_size = self.segmentation_batch_size
|
67 |
+
|
68 |
+
logger.info(f"模型加载成功")
|
69 |
+
except Exception as e:
|
70 |
+
logger.error(f"加载模型失败: {str(e)}", exc_info=True)
|
71 |
+
raise RuntimeError(f"加载模型失败: {str(e)}")
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
def diarize(self, audio: AudioSegment) -> DiarizationResult:
|
76 |
+
"""
|
77 |
+
对音频进行说话人分离
|
78 |
+
|
79 |
+
参数:
|
80 |
+
audio: 要处理的AudioSegment对象
|
81 |
+
|
82 |
+
返回:
|
83 |
+
DiarizationResult对象,包含分段结果和说话人数量
|
84 |
+
"""
|
85 |
+
logger.info(f"开始处理 {len(audio)/1000:.2f} 秒的音频进行说话人分离")
|
86 |
+
|
87 |
+
# 准备音频输入
|
88 |
+
temp_audio_path = self._prepare_audio(audio)
|
89 |
+
|
90 |
+
try:
|
91 |
+
# 执行说话人分离
|
92 |
+
logger.debug("开始执行说话人分离")
|
93 |
+
from pyannote.audio.pipelines.utils.hook import ProgressHook
|
94 |
+
|
95 |
+
# 自定义 ProgressHook 类
|
96 |
+
class CustomProgressHook(ProgressHook):
|
97 |
+
def __call__(
|
98 |
+
self,
|
99 |
+
step_name: Text,
|
100 |
+
step_artifact: Any,
|
101 |
+
file: Optional[Mapping] = None,
|
102 |
+
total: Optional[int] = None,
|
103 |
+
completed: Optional[int] = None,
|
104 |
+
):
|
105 |
+
if completed is not None:
|
106 |
+
logger.info(f"处理中 {step_name}: ({completed/total*100:.1f}%)")
|
107 |
+
else:
|
108 |
+
logger.info(f"已完成 {step_name}")
|
109 |
+
|
110 |
+
with CustomProgressHook() as hook:
|
111 |
+
diarization = self.pipeline(temp_audio_path, hook=hook)
|
112 |
+
|
113 |
+
# 转换分段结果
|
114 |
+
segments, num_speakers = self._convert_segments(diarization)
|
115 |
+
|
116 |
+
logger.info(f"说话人分离完成,检测到 {num_speakers} 个说话人,生成 {len(segments)} 个分段")
|
117 |
+
|
118 |
+
return DiarizationResult(
|
119 |
+
segments=segments,
|
120 |
+
num_speakers=num_speakers
|
121 |
+
)
|
122 |
+
|
123 |
+
except Exception as e:
|
124 |
+
logger.error(f"说话人分离失败: {str(e)}", exc_info=True)
|
125 |
+
raise RuntimeError(f"说话人分离失败: {str(e)}")
|
126 |
+
finally:
|
127 |
+
# 删除临时文件
|
128 |
+
if os.path.exists(temp_audio_path):
|
129 |
+
os.remove(temp_audio_path)
|
130 |
+
|
131 |
+
|
132 |
+
def diarize_audio(
|
133 |
+
audio_segment: AudioSegment,
|
134 |
+
model_name: str = "pyannote/speaker-diarization-3.1",
|
135 |
+
token: Optional[str] = None,
|
136 |
+
device: str = "cpu",
|
137 |
+
segmentation_batch_size: int = 32,
|
138 |
+
) -> DiarizationResult:
|
139 |
+
"""
|
140 |
+
使用pyannote模型对音频进行说话人��离
|
141 |
+
|
142 |
+
参数:
|
143 |
+
audio_segment: 输入的AudioSegment对象
|
144 |
+
model_name: 使用的模型名称
|
145 |
+
token: Hugging Face令牌
|
146 |
+
device: 推理设备,'cpu'、'cuda'、'mps'
|
147 |
+
segmentation_batch_size: 分割批处理大小,默认为32
|
148 |
+
|
149 |
+
返回:
|
150 |
+
DiarizationResult对象,包含分段和说话人数量
|
151 |
+
"""
|
152 |
+
logger.info(f"调用diarize_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
|
153 |
+
transcriber = PyannoteTranscriber(model_name=model_name, token=token, device=device, segmentation_batch_size=segmentation_batch_size)
|
154 |
+
return transcriber.diarize(audio_segment)
|
src/podcast_transcribe/diarization/diarization_pyannote_transformers.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
基于pyannote.audio库调用pyannote/speaker-diarization-3.1模型实现的说话人分离模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from typing import Any, Dict, List, Mapping, Text, Union, Optional, Tuple
|
9 |
+
import logging
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .diarizer_base import BaseDiarizer
|
13 |
+
from ..schemas import DiarizationResult
|
14 |
+
|
15 |
+
# 配置日志
|
16 |
+
logger = logging.getLogger("diarization")
|
17 |
+
|
18 |
+
|
19 |
+
class PyannoteTransformersTranscriber(BaseDiarizer):
|
20 |
+
"""使用pyannote.audio库调用pyannote/speaker-diarization-3.1模型进行说话人分离"""
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
model_name: str = "pyannote/speaker-diarization-3.1",
|
25 |
+
token: Optional[str] = None,
|
26 |
+
device: str = "cpu",
|
27 |
+
segmentation_batch_size: int = 32,
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
初始化说话人分离器
|
31 |
+
|
32 |
+
参数:
|
33 |
+
model_name: 模型名称
|
34 |
+
token: Hugging Face令牌,用于访问模型
|
35 |
+
device: 推理设备,'cpu'或'cuda'
|
36 |
+
segmentation_batch_size: 分割批处理大小,默认为32
|
37 |
+
"""
|
38 |
+
super().__init__(model_name, token, device, segmentation_batch_size)
|
39 |
+
|
40 |
+
# 加载模型
|
41 |
+
self._load_model()
|
42 |
+
|
43 |
+
def _load_model(self):
|
44 |
+
"""使用pyannote.audio加载模型"""
|
45 |
+
try:
|
46 |
+
# 检查依赖库
|
47 |
+
try:
|
48 |
+
from pyannote.audio import Pipeline
|
49 |
+
except ImportError:
|
50 |
+
raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
|
51 |
+
|
52 |
+
if not self.token:
|
53 |
+
raise ValueError("需要提供Hugging Face令牌才能使用pyannote模型。请通过参数传入或设置HF_TOKEN环境变量。")
|
54 |
+
|
55 |
+
logger.info(f"开始使用pyannote.audio加载模型 {self.model_name}")
|
56 |
+
|
57 |
+
# 使用pyannote.audio Pipeline加载说话人分离模型
|
58 |
+
self.pipeline = Pipeline.from_pretrained(
|
59 |
+
self.model_name,
|
60 |
+
use_auth_token=self.token
|
61 |
+
)
|
62 |
+
|
63 |
+
# 设置设备
|
64 |
+
logger.info(f"将模型移动到设备: {self.device}")
|
65 |
+
self.pipeline.to(torch.device(self.device))
|
66 |
+
|
67 |
+
# 设置分割批处理大小
|
68 |
+
if hasattr(self.pipeline, "segmentation_batch_size"):
|
69 |
+
logger.info(f"设置分割批处理大小: {self.segmentation_batch_size}")
|
70 |
+
self.pipeline.segmentation_batch_size = self.segmentation_batch_size
|
71 |
+
|
72 |
+
logger.info(f"pyannote.audio模型加载成功")
|
73 |
+
|
74 |
+
except Exception as e:
|
75 |
+
logger.error(f"加载模型失败: {str(e)}", exc_info=True)
|
76 |
+
raise RuntimeError(f"模型加载失败: {str(e)}")
|
77 |
+
|
78 |
+
def diarize(self, audio: AudioSegment) -> DiarizationResult:
|
79 |
+
"""
|
80 |
+
对音频进行说话人分离
|
81 |
+
|
82 |
+
参数:
|
83 |
+
audio: 要处理的AudioSegment对象
|
84 |
+
|
85 |
+
返回:
|
86 |
+
DiarizationResult对象,包含分段结果和说话人数量
|
87 |
+
"""
|
88 |
+
logger.info(f"开始使用pyannote.audio处理 {len(audio)/1000:.2f} 秒的音频进行说话人分离")
|
89 |
+
|
90 |
+
# 准备音频输入
|
91 |
+
temp_audio_path = self._prepare_audio(audio)
|
92 |
+
|
93 |
+
try:
|
94 |
+
# 执行说话人分离
|
95 |
+
logger.debug("开始执行说话人分离")
|
96 |
+
|
97 |
+
# 使用自定义 ProgressHook 来显示进度
|
98 |
+
try:
|
99 |
+
from pyannote.audio.pipelines.utils.hook import ProgressHook
|
100 |
+
|
101 |
+
class CustomProgressHook(ProgressHook):
|
102 |
+
def __call__(
|
103 |
+
self,
|
104 |
+
step_name: Text,
|
105 |
+
step_artifact: Any,
|
106 |
+
file: Optional[Mapping] = None,
|
107 |
+
total: Optional[int] = None,
|
108 |
+
completed: Optional[int] = None,
|
109 |
+
):
|
110 |
+
if completed is not None and total is not None:
|
111 |
+
percentage = completed / total * 100
|
112 |
+
logger.info(f"处理中 {step_name}: ({percentage:.1f}%)")
|
113 |
+
else:
|
114 |
+
logger.info(f"已完成 {step_name}")
|
115 |
+
|
116 |
+
with CustomProgressHook() as hook:
|
117 |
+
diarization = self.pipeline(temp_audio_path, hook=hook)
|
118 |
+
|
119 |
+
except ImportError:
|
120 |
+
# 如果ProgressHook不可用,直接执行
|
121 |
+
logger.info("ProgressHook不可用,直接执行说话人分离")
|
122 |
+
diarization = self.pipeline(temp_audio_path)
|
123 |
+
|
124 |
+
# 转换分段结果
|
125 |
+
segments, num_speakers = self._convert_segments(diarization)
|
126 |
+
|
127 |
+
logger.info(f"说话人分离完成,检测到 {num_speakers} 个说话人,生成 {len(segments)} 个分段")
|
128 |
+
|
129 |
+
return DiarizationResult(
|
130 |
+
segments=segments,
|
131 |
+
num_speakers=num_speakers
|
132 |
+
)
|
133 |
+
|
134 |
+
except Exception as e:
|
135 |
+
logger.error(f"说话人分离失败: {str(e)}", exc_info=True)
|
136 |
+
raise RuntimeError(f"说话人分离失败: {str(e)}")
|
137 |
+
finally:
|
138 |
+
# 删除临时文件
|
139 |
+
if os.path.exists(temp_audio_path):
|
140 |
+
os.remove(temp_audio_path)
|
141 |
+
|
142 |
+
|
143 |
+
def diarize_audio(
|
144 |
+
audio_segment: AudioSegment,
|
145 |
+
model_name: str = "pyannote/speaker-diarization-3.1",
|
146 |
+
token: Optional[str] = None,
|
147 |
+
device: str = "cpu",
|
148 |
+
segmentation_batch_size: int = 32,
|
149 |
+
) -> DiarizationResult:
|
150 |
+
"""
|
151 |
+
使用pyannote.audio调用pyannote模型对音频进行说话人分离
|
152 |
+
|
153 |
+
参数:
|
154 |
+
audio_segment: 输入的AudioSegment对象
|
155 |
+
model_name: 使用的模型名称
|
156 |
+
token: Hugging Face令牌
|
157 |
+
device: 推理设备,'cpu'、'cuda'、'mps'
|
158 |
+
segmentation_batch_size: 分割批处理大小,默认为32
|
159 |
+
|
160 |
+
返回:
|
161 |
+
DiarizationResult对象,包含分段和说话人数量
|
162 |
+
"""
|
163 |
+
logger.info(f"调用pyannote.audio版本diarize_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
|
164 |
+
transcriber = PyannoteTransformersTranscriber(
|
165 |
+
model_name=model_name,
|
166 |
+
token=token,
|
167 |
+
device=device,
|
168 |
+
segmentation_batch_size=segmentation_batch_size
|
169 |
+
)
|
170 |
+
return transcriber.diarize(audio_segment)
|
src/podcast_transcribe/diarization/diarizer_base.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
说话人分离器基础类,包含可复用的方法
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import logging
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
from pydub import AudioSegment
|
9 |
+
from typing import Any, Dict, List, Union, Optional, Tuple
|
10 |
+
|
11 |
+
from ..schemas import DiarizationResult
|
12 |
+
|
13 |
+
# 配置日志
|
14 |
+
logger = logging.getLogger("diarization")
|
15 |
+
|
16 |
+
|
17 |
+
class BaseDiarizer(ABC):
|
18 |
+
"""说话人分离器基础类"""
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
model_name: str,
|
23 |
+
token: Optional[str] = None,
|
24 |
+
device: str = "cpu",
|
25 |
+
segmentation_batch_size: int = 32,
|
26 |
+
):
|
27 |
+
"""
|
28 |
+
初始化说话人分离器基础参数
|
29 |
+
|
30 |
+
参数:
|
31 |
+
model_name: 模型名称
|
32 |
+
token: Hugging Face令牌,用于访问模型
|
33 |
+
device: 推理设备,'cpu'或'cuda'
|
34 |
+
segmentation_batch_size: 分割批处理大小,默认为32
|
35 |
+
"""
|
36 |
+
self.model_name = model_name
|
37 |
+
self.token = token or os.environ.get("HF_TOKEN")
|
38 |
+
self.device = device
|
39 |
+
self.segmentation_batch_size = segmentation_batch_size
|
40 |
+
|
41 |
+
logger.info(f"初始化说话人分离器,模型: {model_name},设备: {device},分割批处理大小: {segmentation_batch_size}")
|
42 |
+
|
43 |
+
@abstractmethod
|
44 |
+
def _load_model(self):
|
45 |
+
"""加载模型,子类需要实现"""
|
46 |
+
pass
|
47 |
+
|
48 |
+
def _prepare_audio(self, audio: AudioSegment) -> str:
|
49 |
+
"""
|
50 |
+
准备音频数据,保存为临时文件
|
51 |
+
|
52 |
+
参数:
|
53 |
+
audio: 输入的AudioSegment对象
|
54 |
+
|
55 |
+
返回:
|
56 |
+
临时音频文件的路径
|
57 |
+
"""
|
58 |
+
logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
|
59 |
+
|
60 |
+
# 确保采样率为16kHz (pyannote模型要求)
|
61 |
+
if audio.frame_rate != 16000:
|
62 |
+
logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
|
63 |
+
audio = audio.set_frame_rate(16000)
|
64 |
+
|
65 |
+
# 确保是单声道
|
66 |
+
if audio.channels > 1:
|
67 |
+
logger.debug(f"将{audio.channels}声道音频转换为单声道")
|
68 |
+
audio = audio.set_channels(1)
|
69 |
+
|
70 |
+
# 保存为临时文件
|
71 |
+
temp_audio_path = "_temp_audio_for_diarization.wav"
|
72 |
+
audio.export(temp_audio_path, format="wav")
|
73 |
+
|
74 |
+
logger.debug(f"音频处理完成,保存至: {temp_audio_path}")
|
75 |
+
|
76 |
+
return temp_audio_path
|
77 |
+
|
78 |
+
def _convert_segments(self, diarization) -> Tuple[List[Dict[str, Union[float, str, int]]], int]:
|
79 |
+
"""
|
80 |
+
将pyannote的分段结果转换为所需格式
|
81 |
+
|
82 |
+
参数:
|
83 |
+
diarization: pyannote模型返回的分段结果
|
84 |
+
|
85 |
+
返回:
|
86 |
+
转换后的分段列表和说话人数量
|
87 |
+
"""
|
88 |
+
segments = []
|
89 |
+
speakers = set()
|
90 |
+
|
91 |
+
# 遍历说话人分离结果
|
92 |
+
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
93 |
+
segments.append({
|
94 |
+
"start": turn.start,
|
95 |
+
"end": turn.end,
|
96 |
+
"speaker": speaker
|
97 |
+
})
|
98 |
+
speakers.add(speaker)
|
99 |
+
|
100 |
+
# 按开始时间排序
|
101 |
+
segments.sort(key=lambda x: x["start"])
|
102 |
+
|
103 |
+
logger.debug(f"转换了 {len(segments)} 个分段,检测到 {len(speakers)} 个说话人")
|
104 |
+
|
105 |
+
return segments, len(speakers)
|
106 |
+
|
107 |
+
@abstractmethod
|
108 |
+
def diarize(self, audio: AudioSegment) -> DiarizationResult:
|
109 |
+
"""
|
110 |
+
对音频进行说话人分离,子类需要实现
|
111 |
+
|
112 |
+
参数:
|
113 |
+
audio: 要处理的AudioSegment对象
|
114 |
+
|
115 |
+
返回:
|
116 |
+
DiarizationResult对象,包含分段结果和说话人数量
|
117 |
+
"""
|
118 |
+
pass
|
src/podcast_transcribe/diarization/diarizer_router.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
说话人分离模型调用路由器
|
3 |
+
根据传递的provider参数调用不同的说话人分离实现,支持延迟加载
|
4 |
+
"""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from typing import Dict, Any, Optional, Callable
|
8 |
+
from pydub import AudioSegment
|
9 |
+
from ..schemas import DiarizationResult
|
10 |
+
from . import diarization_pyannote_mlx
|
11 |
+
from . import diarization_pyannote_transformers
|
12 |
+
|
13 |
+
# 配置日志
|
14 |
+
logger = logging.getLogger("diarization")
|
15 |
+
|
16 |
+
|
17 |
+
class DiarizerRouter:
|
18 |
+
"""说话人分离模型调用路由器,支持多种实现的统一调用"""
|
19 |
+
|
20 |
+
def __init__(self):
|
21 |
+
"""初始化路由器"""
|
22 |
+
self._loaded_modules = {} # 用于缓存已加载的模块
|
23 |
+
self._diarizers = {} # 用于缓存已实例化的分离器
|
24 |
+
|
25 |
+
# 定义支持的provider配置
|
26 |
+
self._provider_configs = {
|
27 |
+
"pyannote_mlx": {
|
28 |
+
"module_path": "diarization_pyannote_mlx",
|
29 |
+
"function_name": "diarize_audio",
|
30 |
+
"default_model": "pyannote/speaker-diarization-3.1",
|
31 |
+
"supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
|
32 |
+
"description": "基于pyannote.audio的原生MLX实现"
|
33 |
+
},
|
34 |
+
"pyannote_transformers": {
|
35 |
+
"module_path": "diarization_pyannote_transformers",
|
36 |
+
"function_name": "diarize_audio",
|
37 |
+
"default_model": "pyannote/speaker-diarization-3.1",
|
38 |
+
"supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
|
39 |
+
"description": "基于transformers库调用pyannote模型"
|
40 |
+
}
|
41 |
+
}
|
42 |
+
|
43 |
+
def _lazy_load_module(self, provider: str):
|
44 |
+
"""
|
45 |
+
获取指定provider的模块
|
46 |
+
|
47 |
+
参数:
|
48 |
+
provider: provider名称
|
49 |
+
|
50 |
+
返回:
|
51 |
+
对应的模块
|
52 |
+
"""
|
53 |
+
if provider not in self._provider_configs:
|
54 |
+
raise ValueError(f"不支持的provider: {provider}")
|
55 |
+
|
56 |
+
if provider not in self._loaded_modules:
|
57 |
+
module_path = self._provider_configs[provider]["module_path"]
|
58 |
+
logger.info(f"获取模块: {module_path}")
|
59 |
+
|
60 |
+
# 根据module_path返回对应的模块
|
61 |
+
if module_path == "diarization_pyannote_mlx":
|
62 |
+
module = diarization_pyannote_mlx
|
63 |
+
elif module_path == "diarization_pyannote_transformers":
|
64 |
+
module = diarization_pyannote_transformers
|
65 |
+
else:
|
66 |
+
raise ImportError(f"未找到模块: {module_path}")
|
67 |
+
|
68 |
+
self._loaded_modules[provider] = module
|
69 |
+
logger.info(f"模块 {module_path} 获取成功")
|
70 |
+
|
71 |
+
return self._loaded_modules[provider]
|
72 |
+
|
73 |
+
def _get_diarize_function(self, provider: str) -> Callable:
|
74 |
+
"""
|
75 |
+
获取指定provider的说话人分离函数
|
76 |
+
|
77 |
+
参数:
|
78 |
+
provider: provider名称
|
79 |
+
|
80 |
+
返回:
|
81 |
+
说话人分离函数
|
82 |
+
"""
|
83 |
+
module = self._lazy_load_module(provider)
|
84 |
+
function_name = self._provider_configs[provider]["function_name"]
|
85 |
+
|
86 |
+
if not hasattr(module, function_name):
|
87 |
+
raise AttributeError(f"模块中未找到函数: {function_name}")
|
88 |
+
|
89 |
+
return getattr(module, function_name)
|
90 |
+
|
91 |
+
def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
92 |
+
"""
|
93 |
+
过滤参数,只保留指定provider支持的参数
|
94 |
+
|
95 |
+
参数:
|
96 |
+
provider: provider名称
|
97 |
+
params: 原始参数字典
|
98 |
+
|
99 |
+
返回:
|
100 |
+
过滤后的参数字典
|
101 |
+
"""
|
102 |
+
supported_params = self._provider_configs[provider]["supported_params"]
|
103 |
+
filtered_params = {}
|
104 |
+
|
105 |
+
for param in supported_params:
|
106 |
+
if param in params:
|
107 |
+
filtered_params[param] = params[param]
|
108 |
+
|
109 |
+
# 如果没有指定model_name,使用默认模型
|
110 |
+
if "model_name" not in filtered_params and "model_name" in supported_params:
|
111 |
+
filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
|
112 |
+
|
113 |
+
return filtered_params
|
114 |
+
|
115 |
+
def diarize(
|
116 |
+
self,
|
117 |
+
audio_segment: AudioSegment,
|
118 |
+
provider: str,
|
119 |
+
**kwargs
|
120 |
+
) -> DiarizationResult:
|
121 |
+
"""
|
122 |
+
统一的说话人分离接口
|
123 |
+
|
124 |
+
参数:
|
125 |
+
audio_segment: 输入的AudioSegment对象
|
126 |
+
provider: 说话人分离提供者名称
|
127 |
+
**kwargs: 其他参数,如model_name, token, device, segmentation_batch_size等
|
128 |
+
|
129 |
+
返回:
|
130 |
+
DiarizationResult对象
|
131 |
+
"""
|
132 |
+
logger.info(f"使用provider '{provider}' 进行说话人分离,音频长度: {len(audio_segment)/1000:.2f}秒")
|
133 |
+
|
134 |
+
if provider not in self._provider_configs:
|
135 |
+
available_providers = list(self._provider_configs.keys())
|
136 |
+
raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
|
137 |
+
|
138 |
+
try:
|
139 |
+
# 获取说话人分离函数
|
140 |
+
diarize_func = self._get_diarize_function(provider)
|
141 |
+
|
142 |
+
# 过滤并准备参数
|
143 |
+
filtered_kwargs = self._filter_params(provider, kwargs)
|
144 |
+
|
145 |
+
logger.debug(f"调用 {provider} 说话人分离函数,参数: {filtered_kwargs}")
|
146 |
+
|
147 |
+
# 执行说话人分离
|
148 |
+
result = diarize_func(audio_segment, **filtered_kwargs)
|
149 |
+
|
150 |
+
logger.info(f"说话人分离完成,检测到 {result.num_speakers} 个说话人,生成 {len(result.segments)} 个分段")
|
151 |
+
return result
|
152 |
+
|
153 |
+
except Exception as e:
|
154 |
+
logger.error(f"使用provider '{provider}' 进行说话人分离失败: {str(e)}", exc_info=True)
|
155 |
+
raise RuntimeError(f"说话人分离失败: {str(e)}")
|
156 |
+
|
157 |
+
def get_available_providers(self) -> Dict[str, str]:
|
158 |
+
"""
|
159 |
+
获取所有可用的provider及其描述
|
160 |
+
|
161 |
+
返回:
|
162 |
+
provider名称到描述的映射
|
163 |
+
"""
|
164 |
+
return {
|
165 |
+
provider: config["description"]
|
166 |
+
for provider, config in self._provider_configs.items()
|
167 |
+
}
|
168 |
+
|
169 |
+
def get_provider_info(self, provider: str) -> Dict[str, Any]:
|
170 |
+
"""
|
171 |
+
获取指定provider的详细信息
|
172 |
+
|
173 |
+
参数:
|
174 |
+
provider: provider名称
|
175 |
+
|
176 |
+
返回:
|
177 |
+
provider的配置信息
|
178 |
+
"""
|
179 |
+
if provider not in self._provider_configs:
|
180 |
+
raise ValueError(f"不支持的provider: {provider}")
|
181 |
+
|
182 |
+
return self._provider_configs[provider].copy()
|
183 |
+
|
184 |
+
|
185 |
+
# 创建全局路由器实例
|
186 |
+
_router = DiarizerRouter()
|
187 |
+
|
188 |
+
|
189 |
+
def diarize_audio(
|
190 |
+
audio_segment: AudioSegment,
|
191 |
+
provider: str = "pyannote_mlx",
|
192 |
+
model_name: Optional[str] = None,
|
193 |
+
token: Optional[str] = None,
|
194 |
+
device: str = "cpu",
|
195 |
+
segmentation_batch_size: int = 32,
|
196 |
+
**kwargs
|
197 |
+
) -> DiarizationResult:
|
198 |
+
"""
|
199 |
+
统一的音频说话人分离接口函数
|
200 |
+
|
201 |
+
参数:
|
202 |
+
audio_segment: 输入的AudioSegment对象
|
203 |
+
provider: 说话人分离提供者,可选值:
|
204 |
+
- "pyannote_mlx": 基于pyannote.audio的原生MLX实现
|
205 |
+
- "pyannote_transformers": 基于transformers库调用pyannote模型
|
206 |
+
model_name: 模型名称,如果不指定则使用默认模型
|
207 |
+
token: Hugging Face令牌,用于访问模型
|
208 |
+
device: 推理设备,'cpu'、'cuda'、'mps'
|
209 |
+
segmentation_batch_size: 分割批处理大小,默认为32
|
210 |
+
**kwargs: 其他参数
|
211 |
+
|
212 |
+
返回:
|
213 |
+
DiarizationResult对象,包含分段结果和说话人数量
|
214 |
+
|
215 |
+
示例:
|
216 |
+
# 使用默认pyannote MLX实现
|
217 |
+
result = diarize_audio(audio_segment, provider="pyannote_mlx", token="your_hf_token")
|
218 |
+
|
219 |
+
# 使用transformers实现
|
220 |
+
result = diarize_audio(
|
221 |
+
audio_segment,
|
222 |
+
provider="pyannote_transformers",
|
223 |
+
token="your_hf_token"
|
224 |
+
)
|
225 |
+
|
226 |
+
# 使用GPU设备
|
227 |
+
result = diarize_audio(
|
228 |
+
audio_segment,
|
229 |
+
provider="pyannote_mlx",
|
230 |
+
token="your_hf_token",
|
231 |
+
device="cuda"
|
232 |
+
)
|
233 |
+
|
234 |
+
# 自定义批处理大小
|
235 |
+
result = diarize_audio(
|
236 |
+
audio_segment,
|
237 |
+
provider="pyannote_mlx",
|
238 |
+
token="your_hf_token",
|
239 |
+
segmentation_batch_size=64
|
240 |
+
)
|
241 |
+
"""
|
242 |
+
# 准备参数
|
243 |
+
params = kwargs.copy()
|
244 |
+
if model_name is not None:
|
245 |
+
params["model_name"] = model_name
|
246 |
+
if token is not None:
|
247 |
+
params["token"] = token
|
248 |
+
if device != "cpu":
|
249 |
+
params["device"] = device
|
250 |
+
if segmentation_batch_size != 32:
|
251 |
+
params["segmentation_batch_size"] = segmentation_batch_size
|
252 |
+
|
253 |
+
return _router.diarize(audio_segment, provider, **params)
|
254 |
+
|
255 |
+
|
256 |
+
def get_available_providers() -> Dict[str, str]:
|
257 |
+
"""
|
258 |
+
获取所有可用的说话人分离提供者
|
259 |
+
|
260 |
+
返回:
|
261 |
+
provider名称到描述的映射
|
262 |
+
"""
|
263 |
+
return _router.get_available_providers()
|
264 |
+
|
265 |
+
|
266 |
+
def get_provider_info(provider: str) -> Dict[str, Any]:
|
267 |
+
"""
|
268 |
+
获取指定provider的详细信息
|
269 |
+
|
270 |
+
参数:
|
271 |
+
provider: provider名称
|
272 |
+
|
273 |
+
返回:
|
274 |
+
provider的配置信息
|
275 |
+
"""
|
276 |
+
return _router.get_provider_info(provider)
|
src/podcast_transcribe/llm/llm_base.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import uuid
|
3 |
+
import torch
|
4 |
+
from typing import List, Dict, Optional, Union, Literal
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
|
7 |
+
|
8 |
+
class BaseChatCompletion(ABC):
|
9 |
+
"""Gemma 聊天完成的基类,包含公共功能"""
|
10 |
+
|
11 |
+
def __init__(self, model_name: str):
|
12 |
+
self.model_name = model_name
|
13 |
+
|
14 |
+
@abstractmethod
|
15 |
+
def _load_model_and_tokenizer(self):
|
16 |
+
"""加载模型和分词器的抽象方法,由子类实现"""
|
17 |
+
pass
|
18 |
+
|
19 |
+
@abstractmethod
|
20 |
+
def _generate_response(self, prompt_str: str, temperature: float, max_tokens: int, top_p: float, **kwargs) -> str:
|
21 |
+
"""生成响应的抽象方法,由子类实现"""
|
22 |
+
pass
|
23 |
+
|
24 |
+
def _format_messages_for_gemma(self, messages: List[Dict[str, str]]) -> str:
|
25 |
+
"""
|
26 |
+
为Gemma格式化消息。
|
27 |
+
Gemma期望特定的格式,通常类似于:
|
28 |
+
<start_of_turn>user
|
29 |
+
{user_message}<end_of_turn>
|
30 |
+
<start_of_turn>model
|
31 |
+
{assistant_message}<end_of_turn>
|
32 |
+
...
|
33 |
+
<start_of_turn>user
|
34 |
+
{current_user_message}<end_of_turn>
|
35 |
+
<start_of_turn>model
|
36 |
+
"""
|
37 |
+
# 尝试使用分词器的聊天模板(如果可用)
|
38 |
+
try:
|
39 |
+
# Hugging Face分词器中的apply_chat_template方法
|
40 |
+
# 通常需要一个字典列表,每个字典包含'role'和'content'。
|
41 |
+
# 我们需要确保我们的`messages`格式兼容。
|
42 |
+
# add_generation_prompt=True 至关重要,以确保模型知道轮到它发言了。
|
43 |
+
return self.tokenizer.apply_chat_template(
|
44 |
+
messages, tokenize=False, add_generation_prompt=True
|
45 |
+
)
|
46 |
+
except Exception:
|
47 |
+
# 如果apply_chat_template失败或不可用,则回退到手动格式化
|
48 |
+
prompt_parts = []
|
49 |
+
for message in messages:
|
50 |
+
role = message.get("role")
|
51 |
+
content = message.get("content")
|
52 |
+
if role == "user":
|
53 |
+
prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
|
54 |
+
elif role == "assistant":
|
55 |
+
prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
|
56 |
+
elif role == "system": # Gemma可能不会以相同的方式显式使用'system',通常是前置的
|
57 |
+
# 对于Gemma,系统提示通常只是前置到第一个用户消息或隐式处理。
|
58 |
+
# 我们会在这里前置它,尽管其有效性取决于特定的Gemma微调。
|
59 |
+
# 一种常见的模式是在开头放置系统指令,不使用特殊标记。
|
60 |
+
# 然而,为了保持结构化,我们将尝试一种通用方法。
|
61 |
+
# 如果分词器在其模板中有特定的方式来处理系统提示,
|
62 |
+
# 那么`apply_chat_template`将是首选。
|
63 |
+
# 由于我们处于回退状态,这是一个最佳猜测。
|
64 |
+
# 一些模型期望系统提示在轮次结构之外,或者在最开始。
|
65 |
+
# 为了在回退中简化,我们只做前置处理。
|
66 |
+
# 如果`apply_chat_template`不可用,更健壮的解决方案是检查模型的特定聊天模板。
|
67 |
+
prompt_parts.insert(0, f"<start_of_turn>system\n{content}<end_of_turn>")
|
68 |
+
|
69 |
+
# 添加提示,让模型开始生成
|
70 |
+
prompt_parts.append("<start_of_turn>model")
|
71 |
+
return "\n".join(prompt_parts)
|
72 |
+
|
73 |
+
def _post_process_response(self, response_text: str, prompt_str: str) -> str:
|
74 |
+
"""
|
75 |
+
后处理生成的响应文本,清理提示和特殊标记
|
76 |
+
"""
|
77 |
+
# 后处理:Gemma的输出可能包含输入提示或特殊标记。
|
78 |
+
# 我们需要清理这些,以仅返回助手的最新消息。
|
79 |
+
# 一种常见的模式是,模型输出将以我们给它的提示开始,
|
80 |
+
# 或者它可能包含 <start_of_turn>model 标记,然后是其响应。
|
81 |
+
|
82 |
+
# 如果模型输出包含提示,然后是新的响应:
|
83 |
+
if response_text.startswith(prompt_str):
|
84 |
+
assistant_message_content = response_text[len(prompt_str):].strip()
|
85 |
+
else:
|
86 |
+
# 如果模型不回显提示,则可能需要更复杂的清理。
|
87 |
+
# 对于Gemma,响应通常跟随提示的最后一部分 "<start_of_turn>model\n"。
|
88 |
+
# 让我们尝试找到最后一个 "<start_of_turn>model" 并获取其后的文本。
|
89 |
+
# 这是一种启发式方法,可能需要根据实际模型输出进行调整。
|
90 |
+
parts = response_text.split("<start_of_turn>model")
|
91 |
+
if len(parts) > 1:
|
92 |
+
assistant_message_content = parts[-1].strip()
|
93 |
+
# 进一步清理 <end_of_turn> 或其他特殊标记
|
94 |
+
assistant_message_content = assistant_message_content.split("<end_of_turn>")[0].strip()
|
95 |
+
else: # 如果上述方法不起作用,则回退
|
96 |
+
assistant_message_content = response_text.strip()
|
97 |
+
|
98 |
+
return assistant_message_content
|
99 |
+
|
100 |
+
def _calculate_tokens(self, prompt_str: str, assistant_message_content: str) -> Dict[str, int]:
|
101 |
+
"""
|
102 |
+
计算token数量(近似值,因为确切的OpenAI分词可能不同)
|
103 |
+
"""
|
104 |
+
# 对于提示token,我们对输入到模型的字符串进行分词。
|
105 |
+
# 对于完成token,我们对生成的助手消息进行分词。
|
106 |
+
prompt_tokens = len(self.tokenizer.encode(prompt_str))
|
107 |
+
completion_tokens = len(self.tokenizer.encode(assistant_message_content))
|
108 |
+
total_tokens = prompt_tokens + completion_tokens
|
109 |
+
|
110 |
+
return {
|
111 |
+
"prompt_tokens": prompt_tokens,
|
112 |
+
"completion_tokens": completion_tokens,
|
113 |
+
"total_tokens": total_tokens
|
114 |
+
}
|
115 |
+
|
116 |
+
def _build_chat_completion_response(self, assistant_message_content: str, token_usage: Dict[str, int]) -> Dict:
|
117 |
+
"""
|
118 |
+
构建模仿OpenAI结构的响应对象
|
119 |
+
基于: https://platform.openai.com/docs/api-reference/chat/object
|
120 |
+
"""
|
121 |
+
# 获取完成的当前时间戳
|
122 |
+
created_timestamp = int(time.time())
|
123 |
+
completion_id = f"chatcmpl-{uuid.uuid4().hex}" # 创建一个唯一的ID
|
124 |
+
|
125 |
+
return {
|
126 |
+
"id": completion_id,
|
127 |
+
"object": "chat.completion",
|
128 |
+
"created": created_timestamp,
|
129 |
+
"model": self.model_name, # 报告我们使用的模型名称
|
130 |
+
"choices": [
|
131 |
+
{
|
132 |
+
"index": 0,
|
133 |
+
"message": {
|
134 |
+
"role": "assistant",
|
135 |
+
"content": assistant_message_content,
|
136 |
+
},
|
137 |
+
"finish_reason": "stop", # 假定为 "stop"
|
138 |
+
}
|
139 |
+
],
|
140 |
+
"usage": token_usage,
|
141 |
+
}
|
142 |
+
|
143 |
+
def create(
|
144 |
+
self,
|
145 |
+
messages: List[Dict[str, str]],
|
146 |
+
temperature: float = 0.7,
|
147 |
+
max_tokens: int = 2048,
|
148 |
+
top_p: float = 1.0,
|
149 |
+
model: Optional[str] = None,
|
150 |
+
**kwargs,
|
151 |
+
):
|
152 |
+
"""
|
153 |
+
创建聊天完成响应。
|
154 |
+
模仿OpenAI的ChatCompletion.create方法。
|
155 |
+
"""
|
156 |
+
if model and model != self.model_name:
|
157 |
+
# 这是一个简化的处理。在实际场景中,您可能希望加载新模型。
|
158 |
+
# 目前,我们将只打印一个警告并使用初始化的模型。
|
159 |
+
print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
|
160 |
+
f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")
|
161 |
+
|
162 |
+
# 为Gemma格式化消息
|
163 |
+
prompt_str = self._format_messages_for_gemma(messages)
|
164 |
+
|
165 |
+
# 生成响应(由子类实现)
|
166 |
+
response_text = self._generate_response(prompt_str, temperature, max_tokens, top_p, **kwargs)
|
167 |
+
|
168 |
+
# 后处理响应
|
169 |
+
assistant_message_content = self._post_process_response(response_text, prompt_str)
|
170 |
+
|
171 |
+
# 计算token使用量
|
172 |
+
token_usage = self._calculate_tokens(prompt_str, assistant_message_content)
|
173 |
+
|
174 |
+
# 构建响应对象
|
175 |
+
return self._build_chat_completion_response(assistant_message_content, token_usage)
|
176 |
+
|
177 |
+
|
178 |
+
class TransformersBaseChatCompletion(BaseChatCompletion):
|
179 |
+
"""基于Transformers库的聊天完成基类,提供通用的设备管理和量化功能"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
model_name: str,
|
184 |
+
use_4bit_quantization: bool = False,
|
185 |
+
device_map: Optional[str] = "auto",
|
186 |
+
device: Optional[str] = None,
|
187 |
+
trust_remote_code: bool = True,
|
188 |
+
torch_dtype: Optional[torch.dtype] = None
|
189 |
+
):
|
190 |
+
super().__init__(model_name)
|
191 |
+
self.use_4bit_quantization = use_4bit_quantization
|
192 |
+
self.device_map = device_map
|
193 |
+
self.trust_remote_code = trust_remote_code
|
194 |
+
self.torch_dtype = torch_dtype or torch.float16
|
195 |
+
self.device = device
|
196 |
+
|
197 |
+
# 加载模型和分词器
|
198 |
+
self._load_model_and_tokenizer()
|
199 |
+
|
200 |
+
def _get_quantization_config(self):
|
201 |
+
"""获取量化配置"""
|
202 |
+
if not self.use_4bit_quantization:
|
203 |
+
return None
|
204 |
+
|
205 |
+
if self.device and self.device.type == "mps":
|
206 |
+
print("警告: MPS 设备不支持 4bit 量化,将禁用量化")
|
207 |
+
self.use_4bit_quantization = False
|
208 |
+
return None
|
209 |
+
|
210 |
+
# 导入量化配置
|
211 |
+
try:
|
212 |
+
from transformers import BitsAndBytesConfig
|
213 |
+
except ImportError:
|
214 |
+
raise ImportError("请先安装 bitsandbytes 库: pip install bitsandbytes")
|
215 |
+
|
216 |
+
return BitsAndBytesConfig(
|
217 |
+
load_in_4bit=True,
|
218 |
+
bnb_4bit_compute_dtype=self.torch_dtype,
|
219 |
+
bnb_4bit_quant_type="nf4",
|
220 |
+
bnb_4bit_use_double_quant=True,
|
221 |
+
)
|
222 |
+
|
223 |
+
def _load_tokenizer(self):
|
224 |
+
"""加载分词器"""
|
225 |
+
try:
|
226 |
+
from transformers import AutoTokenizer
|
227 |
+
except ImportError:
|
228 |
+
raise ImportError("请先安装 transformers 库: pip install transformers")
|
229 |
+
|
230 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
231 |
+
self.model_name,
|
232 |
+
trust_remote_code=self.trust_remote_code
|
233 |
+
)
|
234 |
+
|
235 |
+
# 设置 pad_token 如果不存在
|
236 |
+
if self.tokenizer.pad_token is None:
|
237 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
238 |
+
|
239 |
+
def _load_model(self):
|
240 |
+
"""加载模型"""
|
241 |
+
try:
|
242 |
+
from transformers import AutoModelForCausalLM
|
243 |
+
except ImportError:
|
244 |
+
raise ImportError("请先安装 transformers 库: pip install transformers")
|
245 |
+
|
246 |
+
print(f"正在加载模型: {self.model_name}")
|
247 |
+
print(f"4bit量化: {'启用' if self.use_4bit_quantization else '禁用'}")
|
248 |
+
print(f"目标设备: {self.device}")
|
249 |
+
print(f"设备映射: {self.device_map}")
|
250 |
+
|
251 |
+
# 配置模型加载参数
|
252 |
+
model_kwargs = {
|
253 |
+
"trust_remote_code": self.trust_remote_code,
|
254 |
+
"torch_dtype": self.torch_dtype,
|
255 |
+
}
|
256 |
+
|
257 |
+
# 处理量化配置
|
258 |
+
quantization_config = self._get_quantization_config()
|
259 |
+
if quantization_config:
|
260 |
+
model_kwargs["quantization_config"] = quantization_config
|
261 |
+
print(f"使用 4bit 量化配置")
|
262 |
+
|
263 |
+
# 处理设备映射
|
264 |
+
if self.device_map is not None:
|
265 |
+
if self.device and self.device.type == "mps":
|
266 |
+
print("警告: MPS 设备不支持 device_map,将手动管理设备")
|
267 |
+
else:
|
268 |
+
model_kwargs["device_map"] = self.device_map
|
269 |
+
|
270 |
+
# 加载模型
|
271 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
272 |
+
self.model_name,
|
273 |
+
**model_kwargs
|
274 |
+
)
|
275 |
+
|
276 |
+
# MPS 或手动设备管理
|
277 |
+
if self.device_map is None or (self.device and self.device.type == "mps"):
|
278 |
+
if not self.use_4bit_quantization:
|
279 |
+
print(f"手动移动模型到设备: {self.device}")
|
280 |
+
self.model = self.model.to(self.device)
|
281 |
+
|
282 |
+
print(f"模型 {self.model_name} 加载成功")
|
283 |
+
|
284 |
+
def _load_model_and_tokenizer(self):
|
285 |
+
"""加载模型和分词器"""
|
286 |
+
try:
|
287 |
+
self._load_tokenizer()
|
288 |
+
self._load_model()
|
289 |
+
except Exception as e:
|
290 |
+
print(f"加载模型 {self.model_name} 时出错: {e}")
|
291 |
+
self._print_error_hints()
|
292 |
+
raise
|
293 |
+
|
294 |
+
def _print_error_hints(self):
|
295 |
+
"""打印错误提示信息"""
|
296 |
+
print("请确保模型名称正确且可访问。")
|
297 |
+
if self.use_4bit_quantization:
|
298 |
+
print("如果使用量化,请确保已安装 bitsandbytes 库: pip install bitsandbytes")
|
299 |
+
if self.device and self.device.type == "mps":
|
300 |
+
print("MPS 设备注意事项:")
|
301 |
+
print("- 不支持 4bit 量化")
|
302 |
+
print("- 不支持 device_map")
|
303 |
+
print("- 确保 PyTorch 版本支持 MPS")
|
304 |
+
|
305 |
+
def _generate_response(
|
306 |
+
self,
|
307 |
+
prompt_str: str,
|
308 |
+
temperature: float,
|
309 |
+
max_tokens: int,
|
310 |
+
top_p: float,
|
311 |
+
**kwargs
|
312 |
+
) -> str:
|
313 |
+
"""使用 transformers 生成响应"""
|
314 |
+
|
315 |
+
# 对提示进行编码
|
316 |
+
inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
|
317 |
+
|
318 |
+
# 移动输入到正确的设备
|
319 |
+
if self.device_map is None or (self.device and self.device.type == "mps"):
|
320 |
+
inputs = inputs.to(self.device)
|
321 |
+
|
322 |
+
# 生成参数
|
323 |
+
generation_config = {
|
324 |
+
"max_new_tokens": max_tokens,
|
325 |
+
"temperature": temperature,
|
326 |
+
"top_p": top_p,
|
327 |
+
"do_sample": True if temperature > 0 else False,
|
328 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
329 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
330 |
+
"repetition_penalty": kwargs.get("repetition_penalty", 1.1),
|
331 |
+
"no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 3),
|
332 |
+
}
|
333 |
+
|
334 |
+
# 如果温度为0,使用贪婪解码
|
335 |
+
if temperature == 0:
|
336 |
+
generation_config["do_sample"] = False
|
337 |
+
generation_config.pop("temperature", None)
|
338 |
+
generation_config.pop("top_p", None)
|
339 |
+
|
340 |
+
try:
|
341 |
+
# 生成响应
|
342 |
+
with torch.no_grad():
|
343 |
+
outputs = self.model.generate(
|
344 |
+
inputs,
|
345 |
+
**generation_config
|
346 |
+
)
|
347 |
+
|
348 |
+
# 解码生成的文本,跳过输入部分
|
349 |
+
generated_tokens = outputs[0][len(inputs[0]):]
|
350 |
+
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
351 |
+
|
352 |
+
return generated_text
|
353 |
+
|
354 |
+
except Exception as e:
|
355 |
+
print(f"生成响应时出错: {e}")
|
356 |
+
raise
|
357 |
+
|
358 |
+
def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
|
359 |
+
"""获取模型信息"""
|
360 |
+
model_info = {
|
361 |
+
"model_name": self.model_name,
|
362 |
+
"use_4bit_quantization": self.use_4bit_quantization,
|
363 |
+
"device": str(self.device),
|
364 |
+
"device_type": self.device.type,
|
365 |
+
"device_map": self.device_map,
|
366 |
+
"model_type": "transformers",
|
367 |
+
"torch_dtype": str(self.torch_dtype),
|
368 |
+
"mps_available": torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False,
|
369 |
+
"cuda_available": torch.cuda.is_available(),
|
370 |
+
}
|
371 |
+
|
372 |
+
# 添加模型配置信息(如果可用)
|
373 |
+
try:
|
374 |
+
if hasattr(self.model, "config"):
|
375 |
+
config = self.model.config
|
376 |
+
model_info.update({
|
377 |
+
"vocab_size": getattr(config, "vocab_size", "未知"),
|
378 |
+
"hidden_size": getattr(config, "hidden_size", "未知"),
|
379 |
+
"num_layers": getattr(config, "num_hidden_layers", "未知"),
|
380 |
+
"num_attention_heads": getattr(config, "num_attention_heads", "未知"),
|
381 |
+
})
|
382 |
+
except Exception:
|
383 |
+
pass
|
384 |
+
|
385 |
+
return model_info
|
386 |
+
|
387 |
+
def clear_cache(self):
|
388 |
+
"""清理 GPU 缓存"""
|
389 |
+
if torch.cuda.is_available():
|
390 |
+
torch.cuda.empty_cache()
|
391 |
+
print("GPU 缓存已清理")
|
src/podcast_transcribe/llm/llm_gemma_mlx.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mlx_lm import load, generate
|
2 |
+
from mlx_lm.sample_utils import make_sampler
|
3 |
+
from typing import Dict, Union
|
4 |
+
from .llm_base import BaseChatCompletion
|
5 |
+
|
6 |
+
|
7 |
+
class GemmaMLXChatCompletion(BaseChatCompletion):
|
8 |
+
"""基于 MLX 库的 Gemma 聊天完成实现"""
|
9 |
+
|
10 |
+
def __init__(self, model_name: str = "mlx-community/gemma-3-12b-it-4bit-DWQ"):
|
11 |
+
super().__init__(model_name)
|
12 |
+
self._load_model_and_tokenizer()
|
13 |
+
|
14 |
+
def _load_model_and_tokenizer(self):
|
15 |
+
"""加载 MLX 模型和分词器"""
|
16 |
+
try:
|
17 |
+
print(f"正在加载 MLX 模型: {self.model_name}")
|
18 |
+
self.model, self.tokenizer = load(self.model_name)
|
19 |
+
print(f"MLX 模型 {self.model_name} 加载成功")
|
20 |
+
except Exception as e:
|
21 |
+
print(f"加载模型 {self.model_name} 时出错: {e}")
|
22 |
+
print("请确保模型名称正确且可访问。")
|
23 |
+
print("您可以尝试使用 'mlx_lm.utils.get_model_path(model_name)' 搜索可用的模型。")
|
24 |
+
raise
|
25 |
+
|
26 |
+
def _generate_response(
|
27 |
+
self,
|
28 |
+
prompt_str: str,
|
29 |
+
temperature: float,
|
30 |
+
max_tokens: int,
|
31 |
+
top_p: float,
|
32 |
+
**kwargs
|
33 |
+
) -> str:
|
34 |
+
"""使用 MLX 生成响应"""
|
35 |
+
|
36 |
+
# 为temperature和top_p创建一个采样器
|
37 |
+
sampler = make_sampler(temp=temperature, top_p=top_p)
|
38 |
+
|
39 |
+
# 生成响应
|
40 |
+
# mlx_lm中的`generate`函数接受模型、分词器、提示和其他生成参数。
|
41 |
+
# 我们需要将我们的参数映射到`generate`期望的参数。
|
42 |
+
# `mlx_lm.generate` 的 verbose 参数可用于调试。
|
43 |
+
# `temperature` 是 `mlx_lm.generate` 中温度的参数名称。
|
44 |
+
response_text = generate(
|
45 |
+
self.model,
|
46 |
+
self.tokenizer,
|
47 |
+
prompt=prompt_str,
|
48 |
+
max_tokens=max_tokens,
|
49 |
+
sampler=sampler,
|
50 |
+
# verbose=True # 取消注释以调试生成过程
|
51 |
+
)
|
52 |
+
|
53 |
+
return response_text
|
54 |
+
|
55 |
+
def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
|
56 |
+
"""获取模型信息"""
|
57 |
+
return {
|
58 |
+
"model_name": self.model_name,
|
59 |
+
"model_type": "mlx",
|
60 |
+
"library": "mlx_lm"
|
61 |
+
}
|
62 |
+
|
src/podcast_transcribe/llm/llm_gemma_transfomers.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
3 |
+
from typing import List, Dict, Optional, Union, Literal
|
4 |
+
from .llm_base import TransformersBaseChatCompletion
|
5 |
+
|
6 |
+
|
7 |
+
class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
|
8 |
+
"""基于 Transformers 库的 Gemma 聊天完成实现"""
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
model_name: str = "google/gemma-3-12b-it",
|
13 |
+
use_4bit_quantization: bool = False,
|
14 |
+
device_map: Optional[str] = "auto",
|
15 |
+
device: Optional[str] = None,
|
16 |
+
trust_remote_code: bool = True
|
17 |
+
):
|
18 |
+
# Gemma 使用 float16 作为默认数据类型
|
19 |
+
super().__init__(
|
20 |
+
model_name=model_name,
|
21 |
+
use_4bit_quantization=use_4bit_quantization,
|
22 |
+
device_map=device_map,
|
23 |
+
device=device,
|
24 |
+
trust_remote_code=trust_remote_code,
|
25 |
+
torch_dtype=torch.float16
|
26 |
+
)
|
27 |
+
|
28 |
+
def _print_error_hints(self):
|
29 |
+
"""打印Gemma特定的错误提示信息"""
|
30 |
+
super()._print_error_hints()
|
31 |
+
print("Gemma 特殊要求:")
|
32 |
+
print("- 建议使用 Transformers >= 4.21.0")
|
33 |
+
print("- 推荐使用 float16 数据类型")
|
34 |
+
print("- 确保有足够的GPU内存")
|
35 |
+
|
36 |
+
|
37 |
+
# 为了保持向后兼容性,也可以提供一个简化的工厂函数
|
38 |
+
def create_gemma_transformers_client(
|
39 |
+
model_name: str = "google/gemma-3-12b-it",
|
40 |
+
use_4bit_quantization: bool = False,
|
41 |
+
device: Optional[str] = None,
|
42 |
+
**kwargs
|
43 |
+
) -> GemmaTransformersChatCompletion:
|
44 |
+
"""
|
45 |
+
创建 Gemma Transformers 客户端的工厂函数
|
46 |
+
|
47 |
+
Args:
|
48 |
+
model_name: 模型名称
|
49 |
+
use_4bit_quantization: 是否使用4bit量化
|
50 |
+
device: 指定设备 ("cpu", "cuda", "mps", 等)
|
51 |
+
**kwargs: 其他传递给构造函数的参数
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
GemmaTransformersChatCompletion 实例
|
55 |
+
"""
|
56 |
+
return GemmaTransformersChatCompletion(
|
57 |
+
model_name=model_name,
|
58 |
+
use_4bit_quantization=use_4bit_quantization,
|
59 |
+
device=device,
|
60 |
+
**kwargs
|
61 |
+
)
|
src/podcast_transcribe/llm/llm_phi4_transfomers.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
3 |
+
from typing import List, Dict, Optional, Union, Literal
|
4 |
+
from .llm_base import TransformersBaseChatCompletion
|
5 |
+
|
6 |
+
|
7 |
+
class Phi4TransformersChatCompletion(TransformersBaseChatCompletion):
|
8 |
+
"""基于 Transformers 库的 Phi-4-mini-reasoning 聊天完成实现"""
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
model_name: str = "microsoft/Phi-4-mini-reasoning",
|
13 |
+
use_4bit_quantization: bool = False,
|
14 |
+
device_map: Optional[str] = "auto",
|
15 |
+
device: Optional[str] = None,
|
16 |
+
trust_remote_code: bool = True
|
17 |
+
):
|
18 |
+
# Phi-4 使用 bfloat16 作为推荐数据类型
|
19 |
+
super().__init__(
|
20 |
+
model_name=model_name,
|
21 |
+
use_4bit_quantization=use_4bit_quantization,
|
22 |
+
device_map=device_map,
|
23 |
+
device=device,
|
24 |
+
trust_remote_code=trust_remote_code,
|
25 |
+
torch_dtype=torch.bfloat16
|
26 |
+
)
|
27 |
+
|
28 |
+
def _print_error_hints(self):
|
29 |
+
"""打印Phi-4特定的错误提示信息"""
|
30 |
+
super()._print_error_hints()
|
31 |
+
print("Phi-4 特殊要求:")
|
32 |
+
print("- 建议使用 Transformers >= 4.51.3")
|
33 |
+
print("- 推荐使用 bfloat16 数据类型")
|
34 |
+
print("- 模型支持 128K token 上下文长度")
|
35 |
+
|
36 |
+
def _format_phi4_messages(self, messages: List[Dict[str, str]]) -> str:
|
37 |
+
"""
|
38 |
+
格式化消息为 Phi-4 的聊天格式
|
39 |
+
Phi-4 使用特定的聊天模板格式
|
40 |
+
"""
|
41 |
+
# 使用 tokenizer 的内置聊天模板
|
42 |
+
if hasattr(self.tokenizer, 'apply_chat_template'):
|
43 |
+
return self.tokenizer.apply_chat_template(
|
44 |
+
messages,
|
45 |
+
tokenize=False,
|
46 |
+
add_generation_prompt=True
|
47 |
+
)
|
48 |
+
else:
|
49 |
+
# 如果没有聊天模板,使用 Phi-4 的标准格式
|
50 |
+
formatted_prompt = ""
|
51 |
+
for message in messages:
|
52 |
+
role = message.get("role", "user")
|
53 |
+
content = message.get("content", "")
|
54 |
+
|
55 |
+
if role == "system":
|
56 |
+
formatted_prompt += f"<|system|>\n{content}<|end|>\n"
|
57 |
+
elif role == "user":
|
58 |
+
formatted_prompt += f"<|user|>\n{content}<|end|>\n"
|
59 |
+
elif role == "assistant":
|
60 |
+
formatted_prompt += f"<|assistant|>\n{content}<|end|>\n"
|
61 |
+
|
62 |
+
# 添加助手开始标记
|
63 |
+
formatted_prompt += "<|assistant|>\n"
|
64 |
+
return formatted_prompt
|
65 |
+
|
66 |
+
def _generate_response(
|
67 |
+
self,
|
68 |
+
prompt_str: str,
|
69 |
+
temperature: float,
|
70 |
+
max_tokens: int,
|
71 |
+
top_p: float,
|
72 |
+
enable_reasoning: bool = True,
|
73 |
+
**kwargs
|
74 |
+
) -> str:
|
75 |
+
"""使用 transformers 生成响应,针对 Phi-4 推理功能优化"""
|
76 |
+
|
77 |
+
# 对提示进行编码
|
78 |
+
inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
|
79 |
+
|
80 |
+
# 移动输入到正确的设备
|
81 |
+
if self.device_map is None or self.device.type == "mps":
|
82 |
+
inputs = inputs.to(self.device)
|
83 |
+
|
84 |
+
# Phi-4-mini-reasoning 优化的生成参数
|
85 |
+
generation_config = {
|
86 |
+
"max_new_tokens": min(max_tokens, 32768), # Phi-4-mini 支持最大 32K token
|
87 |
+
"temperature": temperature,
|
88 |
+
"top_p": top_p,
|
89 |
+
"do_sample": True if temperature > 0 else False,
|
90 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
91 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
92 |
+
"repetition_penalty": kwargs.get("repetition_penalty", 1.1),
|
93 |
+
"no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 3),
|
94 |
+
}
|
95 |
+
|
96 |
+
# 推理模式配置
|
97 |
+
if enable_reasoning and "reasoning" in self.model_name.lower():
|
98 |
+
# 为推理任务优化的配置
|
99 |
+
generation_config.update({
|
100 |
+
"temperature": max(temperature, 0.1), # 推理模式下保持一定的温度
|
101 |
+
"top_p": min(top_p, 0.95), # 推理模式下限制 top_p
|
102 |
+
"do_sample": True, # 推理模式下总是启用采样
|
103 |
+
"early_stopping": False, # 允许完整的推理过程
|
104 |
+
})
|
105 |
+
|
106 |
+
# 如果温度为0,使用贪婪解码
|
107 |
+
if temperature == 0:
|
108 |
+
generation_config["do_sample"] = False
|
109 |
+
generation_config.pop("temperature", None)
|
110 |
+
generation_config.pop("top_p", None)
|
111 |
+
|
112 |
+
try:
|
113 |
+
# 生成响应
|
114 |
+
with torch.no_grad():
|
115 |
+
outputs = self.model.generate(
|
116 |
+
inputs,
|
117 |
+
**generation_config
|
118 |
+
)
|
119 |
+
|
120 |
+
# 解码生成的文本,跳过输入部分
|
121 |
+
generated_tokens = outputs[0][len(inputs[0]):]
|
122 |
+
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
123 |
+
|
124 |
+
return generated_text
|
125 |
+
|
126 |
+
except Exception as e:
|
127 |
+
print(f"生成响应时出错: {e}")
|
128 |
+
raise
|
129 |
+
|
130 |
+
def create(
|
131 |
+
self,
|
132 |
+
messages: List[Dict[str, str]],
|
133 |
+
temperature: float = 0.7,
|
134 |
+
max_tokens: int = 2048,
|
135 |
+
top_p: float = 1.0,
|
136 |
+
model: Optional[str] = None,
|
137 |
+
enable_reasoning: bool = True,
|
138 |
+
**kwargs,
|
139 |
+
):
|
140 |
+
"""
|
141 |
+
创建聊天完成响应,支持Phi-4特有的推理功能
|
142 |
+
"""
|
143 |
+
if model and model != self.model_name:
|
144 |
+
print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
|
145 |
+
f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")
|
146 |
+
|
147 |
+
# 检查是否为推理任务
|
148 |
+
is_reasoning_task = self._is_reasoning_task(messages)
|
149 |
+
|
150 |
+
# 格式化消息为 Phi-4 聊天格式
|
151 |
+
if is_reasoning_task and enable_reasoning:
|
152 |
+
prompt_str = self._format_reasoning_prompt(messages)
|
153 |
+
else:
|
154 |
+
prompt_str = self._format_phi4_messages(messages)
|
155 |
+
|
156 |
+
# 生成响应
|
157 |
+
response_text = self._generate_response(
|
158 |
+
prompt_str,
|
159 |
+
temperature,
|
160 |
+
max_tokens,
|
161 |
+
top_p,
|
162 |
+
enable_reasoning=enable_reasoning and is_reasoning_task,
|
163 |
+
**kwargs
|
164 |
+
)
|
165 |
+
|
166 |
+
# 后处理响应(使用基类的方法,但针对Phi-4调整)
|
167 |
+
assistant_message_content = self._post_process_phi4_response(response_text, prompt_str)
|
168 |
+
|
169 |
+
# 计算token使用量
|
170 |
+
token_usage = self._calculate_tokens(prompt_str, assistant_message_content)
|
171 |
+
|
172 |
+
# 构建响应对象
|
173 |
+
response = self._build_chat_completion_response(assistant_message_content, token_usage)
|
174 |
+
|
175 |
+
# 添加Phi-4特有的信息
|
176 |
+
response["reasoning_enabled"] = enable_reasoning and is_reasoning_task
|
177 |
+
|
178 |
+
return response
|
179 |
+
|
180 |
+
def _post_process_phi4_response(self, response_text: str, prompt_str: str) -> str:
|
181 |
+
"""
|
182 |
+
后处理Phi-4生成的响应文本
|
183 |
+
"""
|
184 |
+
# Phi-4的输出通常不包含输入提示,直接返回生成的内容
|
185 |
+
assistant_message_content = response_text.strip()
|
186 |
+
|
187 |
+
# 清理可能的特殊标记
|
188 |
+
if assistant_message_content.endswith("<|end|>"):
|
189 |
+
assistant_message_content = assistant_message_content[:-7].strip()
|
190 |
+
|
191 |
+
return assistant_message_content
|
192 |
+
|
193 |
+
def _is_reasoning_task(self, messages: List[Dict[str, str]]) -> bool:
|
194 |
+
"""检测是否为推理任务"""
|
195 |
+
reasoning_keywords = [
|
196 |
+
"解题", "推理", "计算", "证明", "分析", "逻辑", "步骤",
|
197 |
+
"solve", "reasoning", "calculate", "prove", "analyze", "logic", "step"
|
198 |
+
]
|
199 |
+
|
200 |
+
for message in messages:
|
201 |
+
content = message.get("content", "").lower()
|
202 |
+
if any(keyword in content for keyword in reasoning_keywords):
|
203 |
+
return True
|
204 |
+
|
205 |
+
return False
|
206 |
+
|
207 |
+
def _format_reasoning_prompt(self, messages: List[Dict[str, str]]) -> str:
|
208 |
+
"""
|
209 |
+
为推理任务格式化特殊的提示词
|
210 |
+
"""
|
211 |
+
# 添加推理指导的系统消息
|
212 |
+
reasoning_system_msg = {
|
213 |
+
"role": "system",
|
214 |
+
"content": "你是一个专业的数学推理助手。请逐步分析问题,展示详细的推理过程,包括:\n1. 问题理解\n2. 解题思路\n3. 具体步骤\n4. 最终答案\n\n每个步骤都要清晰明了。"
|
215 |
+
}
|
216 |
+
|
217 |
+
# 将推理系统消息添加到消息列表的开头
|
218 |
+
enhanced_messages = [reasoning_system_msg] + messages
|
219 |
+
|
220 |
+
# 使用标准格式化方法
|
221 |
+
return self._format_phi4_messages(enhanced_messages)
|
222 |
+
|
223 |
+
def reasoning_completion(
|
224 |
+
self,
|
225 |
+
messages: List[Dict[str, str]],
|
226 |
+
temperature: float = 0.3, # 推理任务使用较低的温度
|
227 |
+
max_tokens: int = 2048, # 推理任务需要更多 tokens
|
228 |
+
top_p: float = 0.9,
|
229 |
+
extract_reasoning_steps: bool = True,
|
230 |
+
**kwargs
|
231 |
+
) -> Dict[str, Union[str, Dict, List]]:
|
232 |
+
"""
|
233 |
+
专门用于推理任务的聊天完成接口
|
234 |
+
|
235 |
+
Args:
|
236 |
+
messages: 对话消息列表
|
237 |
+
temperature: 采样温度(推理任务建议使用较低值)
|
238 |
+
max_tokens: 最大生成token数量
|
239 |
+
top_p: top-p采样参数
|
240 |
+
extract_reasoning_steps: 是否提取推理步骤
|
241 |
+
**kwargs: 其他参数
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
包含推理步骤的响应字典
|
245 |
+
"""
|
246 |
+
# 强制启用推理模式
|
247 |
+
response = self.create(
|
248 |
+
messages=messages,
|
249 |
+
temperature=temperature,
|
250 |
+
max_tokens=max_tokens,
|
251 |
+
top_p=top_p,
|
252 |
+
enable_reasoning=True,
|
253 |
+
**kwargs
|
254 |
+
)
|
255 |
+
|
256 |
+
if extract_reasoning_steps:
|
257 |
+
# 提取推理步骤
|
258 |
+
content = response["choices"][0]["message"]["content"]
|
259 |
+
reasoning_steps = self._extract_reasoning_steps(content)
|
260 |
+
response["reasoning_steps"] = reasoning_steps
|
261 |
+
|
262 |
+
return response
|
263 |
+
|
264 |
+
def _extract_reasoning_steps(self, content: str) -> List[Dict[str, str]]:
|
265 |
+
"""
|
266 |
+
从响应内容中提取推理步骤
|
267 |
+
"""
|
268 |
+
steps = []
|
269 |
+
lines = content.split('\n')
|
270 |
+
current_step = {"title": "", "content": ""}
|
271 |
+
|
272 |
+
step_patterns = [
|
273 |
+
"1. 问题理解", "2. 解题思路", "3. 具体步骤", "4. 最终答案",
|
274 |
+
"步骤", "分析", "解答", "结论", "reasoning", "step", "analysis", "solution"
|
275 |
+
]
|
276 |
+
|
277 |
+
for line in lines:
|
278 |
+
line = line.strip()
|
279 |
+
if not line:
|
280 |
+
continue
|
281 |
+
|
282 |
+
# 检查是否是新的步骤开始
|
283 |
+
is_new_step = any(pattern in line.lower() for pattern in step_patterns)
|
284 |
+
if is_new_step and current_step["content"]:
|
285 |
+
steps.append(current_step.copy())
|
286 |
+
current_step = {"title": line, "content": ""}
|
287 |
+
elif is_new_step:
|
288 |
+
current_step["title"] = line
|
289 |
+
else:
|
290 |
+
if current_step["title"]:
|
291 |
+
current_step["content"] += line + "\n"
|
292 |
+
else:
|
293 |
+
current_step["content"] = line + "\n"
|
294 |
+
|
295 |
+
# 添加最后一个步骤
|
296 |
+
if current_step["title"] or current_step["content"]:
|
297 |
+
steps.append(current_step)
|
298 |
+
|
299 |
+
return steps
|
300 |
+
|
301 |
+
def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
|
302 |
+
"""获取 Phi-4 模型信息"""
|
303 |
+
model_info = super().get_model_info()
|
304 |
+
|
305 |
+
# 添加Phi-4特有的信息
|
306 |
+
model_info.update({
|
307 |
+
"model_family": "Phi-4-mini-reasoning",
|
308 |
+
"parameters": "3.8B",
|
309 |
+
"context_length": "128K tokens",
|
310 |
+
"specialization": "数学推理优化",
|
311 |
+
})
|
312 |
+
|
313 |
+
return model_info
|
314 |
+
|
315 |
+
|
316 |
+
# 工厂函数
|
317 |
+
def create_phi4_transformers_client(
|
318 |
+
model_name: str = "microsoft/Phi-4-mini-reasoning",
|
319 |
+
use_4bit_quantization: bool = False,
|
320 |
+
device: Optional[str] = None,
|
321 |
+
**kwargs
|
322 |
+
) -> Phi4TransformersChatCompletion:
|
323 |
+
"""
|
324 |
+
创建 Phi-4 Transformers 客户端的工厂函数
|
325 |
+
|
326 |
+
Args:
|
327 |
+
model_name: 模型名称,默认为 microsoft/Phi-4-mini-reasoning
|
328 |
+
use_4bit_quantization: 是否使用4bit量化
|
329 |
+
device: 指定设备 ("cpu", "cuda", "mps", 等)
|
330 |
+
**kwargs: 其他传递给构造函数的参数
|
331 |
+
|
332 |
+
Returns:
|
333 |
+
Phi4TransformersChatCompletion 实例
|
334 |
+
"""
|
335 |
+
return Phi4TransformersChatCompletion(
|
336 |
+
model_name=model_name,
|
337 |
+
use_4bit_quantization=use_4bit_quantization,
|
338 |
+
device=device,
|
339 |
+
**kwargs
|
340 |
+
)
|
341 |
+
|
342 |
+
def create_reasoning_client(
|
343 |
+
model_name: str = "microsoft/Phi-4-mini-reasoning",
|
344 |
+
use_4bit_quantization: bool = False,
|
345 |
+
device: Optional[str] = None,
|
346 |
+
**kwargs
|
347 |
+
) -> Phi4TransformersChatCompletion:
|
348 |
+
"""
|
349 |
+
创建专门用于推理任务的 Phi-4 客户端
|
350 |
+
|
351 |
+
Args:
|
352 |
+
model_name: 模型名称,推荐使用 microsoft/Phi-4-mini-reasoning
|
353 |
+
use_4bit_quantization: 是否使用4bit量化
|
354 |
+
device: 指定设备 ("cpu", "cuda", "mps", 等)
|
355 |
+
**kwargs: 其他传递给构造函数的参数
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
优化了推理功能的 Phi4TransformersChatCompletion 实例
|
359 |
+
"""
|
360 |
+
# 确保使用推理模型
|
361 |
+
if "reasoning" not in model_name.lower():
|
362 |
+
print("警告: 建议使用包含 'reasoning' 的模型名称以获得最佳推理性能")
|
363 |
+
|
364 |
+
return Phi4TransformersChatCompletion(
|
365 |
+
model_name=model_name,
|
366 |
+
use_4bit_quantization=use_4bit_quantization,
|
367 |
+
device=device,
|
368 |
+
**kwargs
|
369 |
+
)
|
src/podcast_transcribe/llm/llm_router.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM模型调用路由器
|
3 |
+
根据传递的provider参数调用不同的LLM实现,支持延迟加载
|
4 |
+
"""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from typing import Dict, Any, Optional, List, Union
|
8 |
+
from .llm_base import BaseChatCompletion
|
9 |
+
from . import llm_gemma_mlx
|
10 |
+
from . import llm_gemma_transfomers
|
11 |
+
from . import llm_phi4_transfomers
|
12 |
+
|
13 |
+
# 配置日志
|
14 |
+
logger = logging.getLogger("llm")
|
15 |
+
|
16 |
+
|
17 |
+
class LLMRouter:
|
18 |
+
"""LLM模型调用路由器,支持多种实现的统一调用"""
|
19 |
+
|
20 |
+
def __init__(self):
|
21 |
+
"""初始化路由器"""
|
22 |
+
self._loaded_modules = {} # 用于缓存已加载的模块
|
23 |
+
self._llm_instances = {} # 用于缓存已实例化的LLM实例
|
24 |
+
|
25 |
+
# 定义支持的provider配置
|
26 |
+
self._provider_configs = {
|
27 |
+
"gemma-mlx": {
|
28 |
+
"module_path": "llm_gemma_mlx",
|
29 |
+
"class_name": "GemmaMLXChatCompletion",
|
30 |
+
"default_model": "mlx-community/gemma-3-12b-it-4bit-DWQ",
|
31 |
+
"supported_params": ["model_name"],
|
32 |
+
"description": "基于MLX库的Gemma聊天完成实现"
|
33 |
+
},
|
34 |
+
"gemma-transformers": {
|
35 |
+
"module_path": "llm_gemma_transfomers",
|
36 |
+
"class_name": "GemmaTransformersChatCompletion",
|
37 |
+
"default_model": "google/gemma-3-12b-it",
|
38 |
+
"supported_params": [
|
39 |
+
"model_name", "use_4bit_quantization", "device_map",
|
40 |
+
"device", "trust_remote_code"
|
41 |
+
],
|
42 |
+
"description": "基于Transformers库的Gemma聊天完成实现"
|
43 |
+
},
|
44 |
+
"phi4-transformers": {
|
45 |
+
"module_path": "llm_phi4_transfomers",
|
46 |
+
"class_name": "Phi4TransformersChatCompletion",
|
47 |
+
"default_model": "microsoft/Phi-4-reasoning",
|
48 |
+
"supported_params": [
|
49 |
+
"model_name", "use_4bit_quantization", "device_map",
|
50 |
+
"device", "trust_remote_code", "enable_reasoning"
|
51 |
+
],
|
52 |
+
"description": "基于Transformers库的Phi-4推理聊天完成实现"
|
53 |
+
}
|
54 |
+
}
|
55 |
+
|
56 |
+
def _lazy_load_module(self, provider: str):
|
57 |
+
"""
|
58 |
+
获取指定provider的模块
|
59 |
+
|
60 |
+
参数:
|
61 |
+
provider: provider名称
|
62 |
+
|
63 |
+
返回:
|
64 |
+
对应的模块
|
65 |
+
"""
|
66 |
+
if provider not in self._provider_configs:
|
67 |
+
raise ValueError(f"不支持的provider: {provider}")
|
68 |
+
|
69 |
+
if provider not in self._loaded_modules:
|
70 |
+
module_path = self._provider_configs[provider]["module_path"]
|
71 |
+
logger.info(f"获取模块: {module_path}")
|
72 |
+
|
73 |
+
# 根据module_path返回对应的模块
|
74 |
+
if module_path == "llm_gemma_mlx":
|
75 |
+
module = llm_gemma_mlx
|
76 |
+
elif module_path == "llm_gemma_transfomers":
|
77 |
+
module = llm_gemma_transfomers
|
78 |
+
elif module_path == "llm_phi4_transfomers":
|
79 |
+
module = llm_phi4_transfomers
|
80 |
+
else:
|
81 |
+
raise ImportError(f"未找到模块: {module_path}")
|
82 |
+
|
83 |
+
self._loaded_modules[provider] = module
|
84 |
+
logger.info(f"模块 {module_path} 获取成功")
|
85 |
+
|
86 |
+
return self._loaded_modules[provider]
|
87 |
+
|
88 |
+
def _get_llm_class(self, provider: str):
|
89 |
+
"""
|
90 |
+
获取指定provider的LLM类
|
91 |
+
|
92 |
+
参数:
|
93 |
+
provider: provider名称
|
94 |
+
|
95 |
+
返回:
|
96 |
+
LLM类
|
97 |
+
"""
|
98 |
+
module = self._lazy_load_module(provider)
|
99 |
+
class_name = self._provider_configs[provider]["class_name"]
|
100 |
+
|
101 |
+
if not hasattr(module, class_name):
|
102 |
+
raise AttributeError(f"模块中未找到类: {class_name}")
|
103 |
+
|
104 |
+
return getattr(module, class_name)
|
105 |
+
|
106 |
+
def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
107 |
+
"""
|
108 |
+
过滤参数,只保留指定provider支持的参数
|
109 |
+
|
110 |
+
参数:
|
111 |
+
provider: provider名称
|
112 |
+
params: 原始参数字典
|
113 |
+
|
114 |
+
返回:
|
115 |
+
过滤后的参数字典
|
116 |
+
"""
|
117 |
+
supported_params = self._provider_configs[provider]["supported_params"]
|
118 |
+
filtered_params = {}
|
119 |
+
|
120 |
+
for param in supported_params:
|
121 |
+
if param in params:
|
122 |
+
filtered_params[param] = params[param]
|
123 |
+
|
124 |
+
# 如果没有指定model_name,使用默认模型
|
125 |
+
if "model_name" not in filtered_params and "model_name" in supported_params:
|
126 |
+
filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
|
127 |
+
|
128 |
+
return filtered_params
|
129 |
+
|
130 |
+
def _get_instance_key(self, provider: str, params: Dict[str, Any]) -> str:
|
131 |
+
"""
|
132 |
+
生成LLM实例的缓存键
|
133 |
+
|
134 |
+
参数:
|
135 |
+
provider: provider名称
|
136 |
+
params: 参数字典
|
137 |
+
|
138 |
+
返回:
|
139 |
+
实例缓存键
|
140 |
+
"""
|
141 |
+
# 将参数转换为可哈希的字符串
|
142 |
+
param_str = "_".join([f"{k}={v}" for k, v in sorted(params.items())])
|
143 |
+
return f"{provider}_{param_str}"
|
144 |
+
|
145 |
+
def _get_or_create_instance(self, provider: str, **kwargs) -> BaseChatCompletion:
|
146 |
+
"""
|
147 |
+
获取或创建LLM实例(支持缓存复用)
|
148 |
+
|
149 |
+
参数:
|
150 |
+
provider: provider名称
|
151 |
+
**kwargs: 构造函数参数
|
152 |
+
|
153 |
+
返回:
|
154 |
+
LLM实例
|
155 |
+
"""
|
156 |
+
# 过滤并准备参数
|
157 |
+
filtered_kwargs = self._filter_params(provider, kwargs)
|
158 |
+
|
159 |
+
# 生成实例缓存键
|
160 |
+
instance_key = self._get_instance_key(provider, filtered_kwargs)
|
161 |
+
|
162 |
+
# 检查是否已有缓存实例
|
163 |
+
if instance_key not in self._llm_instances:
|
164 |
+
try:
|
165 |
+
# 获取LLM类
|
166 |
+
llm_class = self._get_llm_class(provider)
|
167 |
+
|
168 |
+
logger.debug(f"创建 {provider} LLM实例,参数: {filtered_kwargs}")
|
169 |
+
|
170 |
+
# 创建实例
|
171 |
+
instance = llm_class(**filtered_kwargs)
|
172 |
+
|
173 |
+
# 缓存实例
|
174 |
+
self._llm_instances[instance_key] = instance
|
175 |
+
|
176 |
+
logger.info(f"LLM实例创建成功: {provider} ({instance.model_name})")
|
177 |
+
|
178 |
+
except Exception as e:
|
179 |
+
logger.error(f"创建 {provider} LLM实例失败: {str(e)}", exc_info=True)
|
180 |
+
raise RuntimeError(f"创建LLM实例失败: {str(e)}")
|
181 |
+
|
182 |
+
return self._llm_instances[instance_key]
|
183 |
+
|
184 |
+
def chat_completion(
|
185 |
+
self,
|
186 |
+
messages: List[Dict[str, str]],
|
187 |
+
provider: str,
|
188 |
+
temperature: float = 0.7,
|
189 |
+
max_tokens: int = 2048,
|
190 |
+
top_p: float = 1.0,
|
191 |
+
model: Optional[str] = None,
|
192 |
+
**kwargs
|
193 |
+
) -> Dict[str, Any]:
|
194 |
+
"""
|
195 |
+
统一的聊天完成接口
|
196 |
+
|
197 |
+
参数:
|
198 |
+
messages: 消息列表,每个消息包含role和content
|
199 |
+
provider: LLM提供者名称
|
200 |
+
temperature: 温度参数,控制生成的随机性
|
201 |
+
max_tokens: 最大生成token数
|
202 |
+
top_p: nucleus采样参数
|
203 |
+
model: 可选的模型名称,如果提供则覆盖默认model_name
|
204 |
+
**kwargs: 其他参数,如device、use_4bit_quantization等
|
205 |
+
|
206 |
+
返回:
|
207 |
+
聊天完成响应字典
|
208 |
+
"""
|
209 |
+
logger.info(f"使用provider '{provider}' 进行聊天完成,消息数量: {len(messages)}")
|
210 |
+
|
211 |
+
if provider not in self._provider_configs:
|
212 |
+
available_providers = list(self._provider_configs.keys())
|
213 |
+
raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
|
214 |
+
|
215 |
+
try:
|
216 |
+
# 如果提供了model参数,添加到kwargs中
|
217 |
+
if model is not None:
|
218 |
+
kwargs["model_name"] = model
|
219 |
+
|
220 |
+
# 获取或创建LLM实例
|
221 |
+
llm_instance = self._get_or_create_instance(provider, **kwargs)
|
222 |
+
|
223 |
+
# 调用聊天完成
|
224 |
+
result = llm_instance.create(
|
225 |
+
messages=messages,
|
226 |
+
temperature=temperature,
|
227 |
+
max_tokens=max_tokens,
|
228 |
+
top_p=top_p,
|
229 |
+
model=model,
|
230 |
+
**kwargs
|
231 |
+
)
|
232 |
+
|
233 |
+
logger.info(f"聊天完成成功,使用tokens: {result.get('usage', {}).get('total_tokens', 'unknown')}")
|
234 |
+
return result
|
235 |
+
|
236 |
+
except Exception as e:
|
237 |
+
logger.error(f"使用provider '{provider}' 进行聊天完成失败: {str(e)}", exc_info=True)
|
238 |
+
raise RuntimeError(f"聊天完成失败: {str(e)}")
|
239 |
+
|
240 |
+
def reasoning_completion(
|
241 |
+
self,
|
242 |
+
messages: List[Dict[str, str]],
|
243 |
+
provider: str = "phi4-transformers",
|
244 |
+
temperature: float = 0.3,
|
245 |
+
max_tokens: int = 2048,
|
246 |
+
top_p: float = 0.9,
|
247 |
+
model: Optional[str] = None,
|
248 |
+
extract_reasoning_steps: bool = True,
|
249 |
+
**kwargs
|
250 |
+
) -> Dict[str, Any]:
|
251 |
+
"""
|
252 |
+
专门用于推理任务的聊天完成接口
|
253 |
+
|
254 |
+
参数:
|
255 |
+
messages: 消息列表,每个消息包含role和content
|
256 |
+
provider: LLM提供者名称,默认使用phi4-transformers
|
257 |
+
temperature: 温度参数(推理任务建议使用较低值)
|
258 |
+
max_tokens: 最大生成token数
|
259 |
+
top_p: nucleus采样参数
|
260 |
+
model: 可选的模型名称
|
261 |
+
extract_reasoning_steps: 是否提取推理步骤
|
262 |
+
**kwargs: 其他参数
|
263 |
+
|
264 |
+
返回:
|
265 |
+
包含推理步骤的响应字典
|
266 |
+
"""
|
267 |
+
logger.info(f"使用provider '{provider}' 进行推理完成,消息数量: {len(messages)}")
|
268 |
+
|
269 |
+
# 确保使用支持推理的provider
|
270 |
+
if provider not in ["phi4-transformers"]:
|
271 |
+
logger.warning(f"Provider '{provider}' 可能不支持推理功能,建议使用 'phi4-transformers'")
|
272 |
+
|
273 |
+
try:
|
274 |
+
# 如果提供了model参数,添加到kwargs中
|
275 |
+
if model is not None:
|
276 |
+
kwargs["model_name"] = model
|
277 |
+
|
278 |
+
# 获取或创建LLM实例
|
279 |
+
llm_instance = self._get_or_create_instance(provider, **kwargs)
|
280 |
+
|
281 |
+
# 检查实例是否支持推理完成
|
282 |
+
if hasattr(llm_instance, 'reasoning_completion'):
|
283 |
+
result = llm_instance.reasoning_completion(
|
284 |
+
messages=messages,
|
285 |
+
temperature=temperature,
|
286 |
+
max_tokens=max_tokens,
|
287 |
+
top_p=top_p,
|
288 |
+
extract_reasoning_steps=extract_reasoning_steps,
|
289 |
+
**kwargs
|
290 |
+
)
|
291 |
+
else:
|
292 |
+
# 回退到普通聊天完成
|
293 |
+
logger.warning(f"Provider '{provider}' 不支持推理完成,回退到普通聊天完成")
|
294 |
+
result = llm_instance.create(
|
295 |
+
messages=messages,
|
296 |
+
temperature=temperature,
|
297 |
+
max_tokens=max_tokens,
|
298 |
+
top_p=top_p,
|
299 |
+
model=model,
|
300 |
+
**kwargs
|
301 |
+
)
|
302 |
+
|
303 |
+
logger.info(f"推理完成成功,使用tokens: {result.get('usage', {}).get('total_tokens', 'unknown')}")
|
304 |
+
return result
|
305 |
+
|
306 |
+
except Exception as e:
|
307 |
+
logger.error(f"使用provider '{provider}' 进行推理完成失败: {str(e)}", exc_info=True)
|
308 |
+
raise RuntimeError(f"推理完成失败: {str(e)}")
|
309 |
+
|
310 |
+
def get_model_info(self, provider: str, **kwargs) -> Dict[str, Any]:
|
311 |
+
"""
|
312 |
+
获取模型信息
|
313 |
+
|
314 |
+
参数:
|
315 |
+
provider: provider名称
|
316 |
+
**kwargs: 构造函数参数
|
317 |
+
|
318 |
+
返回:
|
319 |
+
模型信息字典
|
320 |
+
"""
|
321 |
+
try:
|
322 |
+
llm_instance = self._get_or_create_instance(provider, **kwargs)
|
323 |
+
return llm_instance.get_model_info()
|
324 |
+
except Exception as e:
|
325 |
+
logger.error(f"获取模型信息失败: {str(e)}")
|
326 |
+
raise RuntimeError(f"获取模型信息失败: {str(e)}")
|
327 |
+
|
328 |
+
def get_available_providers(self) -> Dict[str, str]:
|
329 |
+
"""
|
330 |
+
获取所有可用的provider及其描述
|
331 |
+
|
332 |
+
返回:
|
333 |
+
provider名称到描述的映射
|
334 |
+
"""
|
335 |
+
return {
|
336 |
+
provider: config["description"]
|
337 |
+
for provider, config in self._provider_configs.items()
|
338 |
+
}
|
339 |
+
|
340 |
+
def get_provider_info(self, provider: str) -> Dict[str, Any]:
|
341 |
+
"""
|
342 |
+
获取指定provider的详细信息
|
343 |
+
|
344 |
+
参数:
|
345 |
+
provider: provider名称
|
346 |
+
|
347 |
+
返回:
|
348 |
+
provider的配置信息
|
349 |
+
"""
|
350 |
+
if provider not in self._provider_configs:
|
351 |
+
raise ValueError(f"不支持的provider: {provider}")
|
352 |
+
|
353 |
+
return self._provider_configs[provider].copy()
|
354 |
+
|
355 |
+
def clear_cache(self):
|
356 |
+
"""清理缓存的实例"""
|
357 |
+
# 清理每个实例的GPU缓存
|
358 |
+
for instance in self._llm_instances.values():
|
359 |
+
if hasattr(instance, 'clear_cache'):
|
360 |
+
instance.clear_cache()
|
361 |
+
|
362 |
+
# 清理实例缓存
|
363 |
+
self._llm_instances.clear()
|
364 |
+
logger.info("LLM实例缓存已清理")
|
365 |
+
|
366 |
+
|
367 |
+
# 创建全局路由器实例
|
368 |
+
_router = LLMRouter()
|
369 |
+
|
370 |
+
|
371 |
+
def chat_completion(
|
372 |
+
messages: List[Dict[str, str]],
|
373 |
+
provider: str = "gemma-mlx",
|
374 |
+
temperature: float = 0.7,
|
375 |
+
max_tokens: int = 2048,
|
376 |
+
top_p: float = 1.0,
|
377 |
+
model: Optional[str] = None,
|
378 |
+
device: Optional[str] = None,
|
379 |
+
use_4bit_quantization: bool = False,
|
380 |
+
device_map: Optional[str] = "auto",
|
381 |
+
trust_remote_code: bool = True,
|
382 |
+
**kwargs
|
383 |
+
) -> Dict[str, Any]:
|
384 |
+
"""
|
385 |
+
统一的聊天完成接口函数
|
386 |
+
|
387 |
+
参数:
|
388 |
+
messages: 消息列表,每个消息包含role和content字段
|
389 |
+
provider: LLM提供者,可选值:
|
390 |
+
- "gemma-mlx": 基于MLX库的Gemma聊天完成实现
|
391 |
+
- "gemma-transformers": 基于Transformers库的Gemma聊天完成实现
|
392 |
+
- "phi4-transformers": 基于Transformers库的Phi-4推理聊天完成实现
|
393 |
+
temperature: 温度参数,控制生成的随机性 (0.0-2.0)
|
394 |
+
max_tokens: 最大生成token数
|
395 |
+
top_p: nucleus采样参数 (0.0-1.0)
|
396 |
+
model: 模型名称,如果不指定则使用默认模型
|
397 |
+
device: 推理设备,'cpu'、'cuda'、'mps'(仅transformers provider支持)
|
398 |
+
use_4bit_quantization: 是否使用4bit量化(仅transformers provider支持)
|
399 |
+
device_map: 设备映射配置(仅transformers provider支持)
|
400 |
+
trust_remote_code: 是否信任远程代码(仅transformers provider支持)
|
401 |
+
**kwargs: 其他参数
|
402 |
+
|
403 |
+
返回:
|
404 |
+
聊天完成响应字典,包含生成的消息和使用统计
|
405 |
+
|
406 |
+
示例:
|
407 |
+
# 使用默认MLX实现
|
408 |
+
response = chat_completion(
|
409 |
+
messages=[{"role": "user", "content": "你好"}],
|
410 |
+
provider="gemma-mlx"
|
411 |
+
)
|
412 |
+
|
413 |
+
# 使用Gemma transformers实现
|
414 |
+
response = chat_completion(
|
415 |
+
messages=[{"role": "user", "content": "你好"}],
|
416 |
+
provider="gemma-transformers",
|
417 |
+
model="google/gemma-3-12b-it",
|
418 |
+
device="cuda",
|
419 |
+
use_4bit_quantization=True
|
420 |
+
)
|
421 |
+
|
422 |
+
# 使用Phi-4推理实现
|
423 |
+
response = chat_completion(
|
424 |
+
messages=[{"role": "user", "content": "解这个数学题:2x + 5 = 15"}],
|
425 |
+
provider="phi4-transformers",
|
426 |
+
model="microsoft/Phi-4-mini-reasoning",
|
427 |
+
device="cuda"
|
428 |
+
)
|
429 |
+
|
430 |
+
# 自定义参数
|
431 |
+
response = chat_completion(
|
432 |
+
messages=[
|
433 |
+
{"role": "system", "content": "你是一个有用的助手"},
|
434 |
+
{"role": "user", "content": "请介绍自己"}
|
435 |
+
],
|
436 |
+
provider="gemma-mlx",
|
437 |
+
temperature=0.8,
|
438 |
+
max_tokens=1024
|
439 |
+
)
|
440 |
+
"""
|
441 |
+
# 准备参数
|
442 |
+
params = kwargs.copy()
|
443 |
+
if model is not None:
|
444 |
+
params["model_name"] = model
|
445 |
+
if device is not None:
|
446 |
+
params["device"] = device
|
447 |
+
if use_4bit_quantization:
|
448 |
+
params["use_4bit_quantization"] = use_4bit_quantization
|
449 |
+
if device_map != "auto":
|
450 |
+
params["device_map"] = device_map
|
451 |
+
if not trust_remote_code:
|
452 |
+
params["trust_remote_code"] = trust_remote_code
|
453 |
+
|
454 |
+
return _router.chat_completion(
|
455 |
+
messages=messages,
|
456 |
+
provider=provider,
|
457 |
+
temperature=temperature,
|
458 |
+
max_tokens=max_tokens,
|
459 |
+
top_p=top_p,
|
460 |
+
model=model,
|
461 |
+
**params
|
462 |
+
)
|
463 |
+
|
464 |
+
|
465 |
+
def reasoning_completion(
|
466 |
+
messages: List[Dict[str, str]],
|
467 |
+
provider: str = "phi4-transformers",
|
468 |
+
temperature: float = 0.3,
|
469 |
+
max_tokens: int = 2048,
|
470 |
+
top_p: float = 0.9,
|
471 |
+
model: Optional[str] = None,
|
472 |
+
device: Optional[str] = None,
|
473 |
+
use_4bit_quantization: bool = False,
|
474 |
+
device_map: Optional[str] = "auto",
|
475 |
+
trust_remote_code: bool = True,
|
476 |
+
extract_reasoning_steps: bool = True,
|
477 |
+
**kwargs
|
478 |
+
) -> Dict[str, Any]:
|
479 |
+
"""
|
480 |
+
专门用于推理任务的聊天完成接口函数
|
481 |
+
|
482 |
+
参数:
|
483 |
+
messages: 消息列表,每个消息包含role和content字段
|
484 |
+
provider: LLM提供者,默认使用phi4-transformers
|
485 |
+
temperature: 温度参数(推理任务建议使用较低值)
|
486 |
+
max_tokens: 最大生成token数
|
487 |
+
top_p: nucleus采样参数
|
488 |
+
model: 模型名称,如果不指定则使用默认模型
|
489 |
+
device: 推理设备
|
490 |
+
use_4bit_quantization: 是否使用4bit量化
|
491 |
+
device_map: 设备映射配置
|
492 |
+
trust_remote_code: 是否信任远程代码
|
493 |
+
extract_reasoning_steps: 是否提取推理步骤
|
494 |
+
**kwargs: 其他参数
|
495 |
+
|
496 |
+
返回:
|
497 |
+
包含推理步骤的响应字典
|
498 |
+
|
499 |
+
示例:
|
500 |
+
# 数学推理任务
|
501 |
+
response = reasoning_completion(
|
502 |
+
messages=[{"role": "user", "content": "解这个方程:3x + 7 = 22"}],
|
503 |
+
provider="phi4-transformers",
|
504 |
+
extract_reasoning_steps=True
|
505 |
+
)
|
506 |
+
|
507 |
+
# 逻辑推理任务
|
508 |
+
response = reasoning_completion(
|
509 |
+
messages=[{"role": "user", "content": "如果所有的猫都是动物,而小花是一只猫,那么小花是什么?"}],
|
510 |
+
provider="phi4-transformers",
|
511 |
+
temperature=0.2
|
512 |
+
)
|
513 |
+
"""
|
514 |
+
# 准备参数
|
515 |
+
params = kwargs.copy()
|
516 |
+
if model is not None:
|
517 |
+
params["model_name"] = model
|
518 |
+
if device is not None:
|
519 |
+
params["device"] = device
|
520 |
+
if use_4bit_quantization:
|
521 |
+
params["use_4bit_quantization"] = use_4bit_quantization
|
522 |
+
if device_map != "auto":
|
523 |
+
params["device_map"] = device_map
|
524 |
+
if not trust_remote_code:
|
525 |
+
params["trust_remote_code"] = trust_remote_code
|
526 |
+
|
527 |
+
return _router.reasoning_completion(
|
528 |
+
messages=messages,
|
529 |
+
provider=provider,
|
530 |
+
temperature=temperature,
|
531 |
+
max_tokens=max_tokens,
|
532 |
+
top_p=top_p,
|
533 |
+
model=model,
|
534 |
+
extract_reasoning_steps=extract_reasoning_steps,
|
535 |
+
**params
|
536 |
+
)
|
537 |
+
|
538 |
+
|
539 |
+
def get_model_info(provider: str = "gemma-mlx", **kwargs) -> Dict[str, Any]:
|
540 |
+
"""
|
541 |
+
获取模型信息
|
542 |
+
|
543 |
+
参数:
|
544 |
+
provider: provider名称
|
545 |
+
**kwargs: 构造函数参数
|
546 |
+
|
547 |
+
返回:
|
548 |
+
模型信息字典
|
549 |
+
"""
|
550 |
+
return _router.get_model_info(provider, **kwargs)
|
551 |
+
|
552 |
+
|
553 |
+
def get_available_providers() -> Dict[str, str]:
|
554 |
+
"""
|
555 |
+
获取所有可用的LLM提供者
|
556 |
+
|
557 |
+
返回:
|
558 |
+
provider名称到描述的映射
|
559 |
+
"""
|
560 |
+
return _router.get_available_providers()
|
561 |
+
|
562 |
+
|
563 |
+
def get_provider_info(provider: str) -> Dict[str, Any]:
|
564 |
+
"""
|
565 |
+
获取指定provider的详细信息
|
566 |
+
|
567 |
+
参数:
|
568 |
+
provider: provider名称
|
569 |
+
|
570 |
+
返回:
|
571 |
+
provider的配置信息
|
572 |
+
"""
|
573 |
+
return _router.get_provider_info(provider)
|
574 |
+
|
575 |
+
|
576 |
+
def clear_cache():
|
577 |
+
"""清理缓存的LLM实例"""
|
578 |
+
_router.clear_cache()
|
src/podcast_transcribe/rss/podcast_rss_parser.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import feedparser
|
3 |
+
# from dataclasses import dataclass, field # 已移除
|
4 |
+
from typing import Optional # , List, Dict # List 和 Dict 不再需要
|
5 |
+
from datetime import datetime
|
6 |
+
import time
|
7 |
+
|
8 |
+
from ..schemas import PodcastEpisode, PodcastChannel
|
9 |
+
|
10 |
+
def _parse_date(date_str: Optional[str]) -> Optional[datetime]:
|
11 |
+
if not date_str:
|
12 |
+
return None
|
13 |
+
try:
|
14 |
+
# feedparser 已经将日期解析为 time.struct_time 类型
|
15 |
+
# 我们将其转换为 datetime 类型
|
16 |
+
if isinstance(date_str, time.struct_time):
|
17 |
+
return datetime.fromtimestamp(time.mktime(date_str))
|
18 |
+
# 如果 feedparser 解析失败或返回字符串,则回退使用其他字符串格式解析
|
19 |
+
# 这是一种常见的 RSS 日期格式
|
20 |
+
return datetime.strptime(date_str, '%a, %d %b %Y %H:%M:%S %z')
|
21 |
+
except (ValueError, TypeError):
|
22 |
+
try:
|
23 |
+
return datetime.strptime(date_str, '%a, %d %b %Y %H:%M:%S %Z') # 处理 GMT, EST 等时区
|
24 |
+
except (ValueError, TypeError):
|
25 |
+
# 如果时区缺失或无法解析,则尝试不带时区解析
|
26 |
+
try:
|
27 |
+
return datetime.strptime(date_str[:-6], '%a, %d %b %Y %H:%M:%S')
|
28 |
+
except (ValueError, TypeError):
|
29 |
+
print(f"Warning: Could not parse date string: {date_str}")
|
30 |
+
return None
|
31 |
+
|
32 |
+
def fetch_rss_content(rss_url: str) -> Optional[bytes]:
|
33 |
+
"""
|
34 |
+
通过 HTTP 请求获取 RSS feed 的内容。
|
35 |
+
|
36 |
+
参数:
|
37 |
+
rss_url: 播客 RSS feed 的 URL。
|
38 |
+
|
39 |
+
返回:
|
40 |
+
bytes 类型的 RSS 内容,如果获取失败则返回 None。
|
41 |
+
"""
|
42 |
+
headers = {
|
43 |
+
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36'
|
44 |
+
}
|
45 |
+
try:
|
46 |
+
response = requests.get(rss_url, headers=headers, timeout=30)
|
47 |
+
response.raise_for_status() # 针对 HTTP 错误抛出异常
|
48 |
+
return response.content
|
49 |
+
except requests.exceptions.RequestException as e:
|
50 |
+
print(f"获取 RSS feed 时出错: {e}")
|
51 |
+
return None
|
52 |
+
|
53 |
+
def parse_rss_xml_content(rss_content: bytes) -> Optional[PodcastChannel]:
|
54 |
+
"""
|
55 |
+
解析播客 RSS XML 内容,并返回其主要信息和剧集详情。
|
56 |
+
|
57 |
+
参数:
|
58 |
+
rss_content: bytes 类型的 RSS XML 内容。
|
59 |
+
|
60 |
+
返回:
|
61 |
+
一个包含已解析信息的 PodcastChannel 对象,如果解析失败则返回 None。
|
62 |
+
"""
|
63 |
+
feed = feedparser.parse(rss_content)
|
64 |
+
|
65 |
+
if feed.bozo:
|
66 |
+
# 如果 feed 格式不正确,bozo 为 True
|
67 |
+
# feed.bozo_exception 包含异常信息
|
68 |
+
print(f"警告: RSS feed 可能格式不正确。Bozo 异常: {feed.bozo_exception}")
|
69 |
+
# 即使格式不完全正确,feedparser 通常仍会尝试解析,所以我们不在此处直接返回 None
|
70 |
+
# 但如果关键的 feed 或 channel_info 缺失,后续会自然失败
|
71 |
+
|
72 |
+
channel_info = feed.get('feed', {})
|
73 |
+
if not channel_info: # 如果连基本的 feed 结构都没有,则认为解析失败
|
74 |
+
print("错误: RSS 内容无法解析为有效的 feed 结构。")
|
75 |
+
return None
|
76 |
+
|
77 |
+
podcast_channel = PodcastChannel(
|
78 |
+
title=channel_info.get('title'),
|
79 |
+
link=channel_info.get('link'),
|
80 |
+
description=channel_info.get('subtitle') or channel_info.get('description'),
|
81 |
+
language=channel_info.get('language'),
|
82 |
+
image_url=channel_info.get('image', {}).get('href') if channel_info.get('image') else None,
|
83 |
+
author=channel_info.get('author') or channel_info.get('itunes_author'),
|
84 |
+
last_build_date=_parse_date(channel_info.get('updated_parsed') or channel_info.get('published_parsed'))
|
85 |
+
)
|
86 |
+
|
87 |
+
for entry in feed.entries:
|
88 |
+
# 确定 shownotes:优先使用 content:encoded,然后是 itunes:summary,其次是 description/summary
|
89 |
+
shownotes = None
|
90 |
+
# 1. 优先尝试 <content:encoded>
|
91 |
+
# entry.content 是一个 FeedParserDict 对象列表
|
92 |
+
if 'content' in entry and entry.content:
|
93 |
+
for content_item in entry.content:
|
94 |
+
# 检查 content_item 是否有 value 属性并且该值非空
|
95 |
+
if hasattr(content_item, 'value') and content_item.value:
|
96 |
+
shownotes = content_item.value
|
97 |
+
break # 找到第一个有效的 content:encoded,停止查找
|
98 |
+
|
99 |
+
# 2. 如果没有从 content:encoded 获得,尝试 itunes:summary
|
100 |
+
if not shownotes and 'itunes_summary' in entry:
|
101 |
+
shownotes = entry.itunes_summary
|
102 |
+
|
103 |
+
# 3. 最后回退到 summary 或 description
|
104 |
+
if not shownotes: # 回退到 summary 或 description
|
105 |
+
shownotes = entry.get('summary') or entry.get('description')
|
106 |
+
|
107 |
+
# 从 enclosures 获取音频 URL
|
108 |
+
audio_url = None
|
109 |
+
if 'enclosures' in entry:
|
110 |
+
for enc in entry.enclosures:
|
111 |
+
if enc.get('type', '').startswith('audio/'):
|
112 |
+
audio_url = enc.get('href')
|
113 |
+
break
|
114 |
+
|
115 |
+
# 解析特定于剧集的 iTunes 标签
|
116 |
+
itunes_season = None
|
117 |
+
try:
|
118 |
+
itunes_season_str = entry.get('itunes_season')
|
119 |
+
if itunes_season_str:
|
120 |
+
itunes_season = int(itunes_season_str)
|
121 |
+
except (ValueError, TypeError):
|
122 |
+
pass # 如果不是有效整数则忽略
|
123 |
+
|
124 |
+
itunes_episode_number = None
|
125 |
+
try:
|
126 |
+
itunes_episode_number_str = entry.get('itunes_episode')
|
127 |
+
if itunes_episode_number_str:
|
128 |
+
itunes_episode_number = int(itunes_episode_number_str)
|
129 |
+
except (ValueError, TypeError):
|
130 |
+
pass # 如果不是有效整数则忽略
|
131 |
+
|
132 |
+
episode = PodcastEpisode(
|
133 |
+
title=entry.get('title'),
|
134 |
+
link=entry.get('link'),
|
135 |
+
published_date=_parse_date(entry.get('published_parsed')),
|
136 |
+
summary=entry.get('summary'), # 这通常是较短的版本
|
137 |
+
shownotes=shownotes, # 这是我们尝试获取的更详细版本
|
138 |
+
audio_url=audio_url,
|
139 |
+
guid=entry.get('id') or entry.get('guid'),
|
140 |
+
duration=entry.get('itunes_duration'),
|
141 |
+
episode_type=entry.get('itunes_episodetype'),
|
142 |
+
season=itunes_season,
|
143 |
+
episode_number=itunes_episode_number
|
144 |
+
)
|
145 |
+
podcast_channel.episodes.append(episode)
|
146 |
+
|
147 |
+
return podcast_channel
|
148 |
+
|
149 |
+
def parse_podcast_rss(rss_url: str) -> Optional[PodcastChannel]:
|
150 |
+
"""
|
151 |
+
从给定的 RSS URL 获取并解析播客数据。
|
152 |
+
|
153 |
+
参数:
|
154 |
+
rss_url: 播客 RSS feed 的 URL。
|
155 |
+
|
156 |
+
返回:
|
157 |
+
一个包含已解析信息的 PodcastChannel 对象,如果获取或解析失败则返回 None。
|
158 |
+
"""
|
159 |
+
rss_content = fetch_rss_content(rss_url)
|
160 |
+
if rss_content:
|
161 |
+
return parse_rss_xml_content(rss_content)
|
162 |
+
return None
|
src/podcast_transcribe/schemas.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List, Optional, Dict, Union
|
3 |
+
from datetime import datetime
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class EnhancedSegment:
|
7 |
+
"""增强的转录分段,包含说话人信息"""
|
8 |
+
start: float # 开始时间(秒)
|
9 |
+
end: float # 结束时间(秒)
|
10 |
+
text: str # 转录的文本
|
11 |
+
speaker: str # 说话人ID
|
12 |
+
language: str # 检测到的语言
|
13 |
+
speaker_name: Optional[str] = None # 识别出的说话人名称
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class CombinedTranscriptionResult:
|
18 |
+
"""结合ASR和说话人分离的转录结果"""
|
19 |
+
segments: List[EnhancedSegment] # 包含说话人和文本的分段
|
20 |
+
text: str # 完整转录文本
|
21 |
+
language: str # 检测到的语言
|
22 |
+
num_speakers: int # 检测到的说话人数量
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class PodcastEpisode:
|
27 |
+
title: Optional[str] = None
|
28 |
+
link: Optional[str] = None
|
29 |
+
published_date: Optional[datetime] = None
|
30 |
+
summary: Optional[str] = None # 简短摘要
|
31 |
+
shownotes: Optional[str] = None # 详细的shownotes,通常是HTML格式
|
32 |
+
audio_url: Optional[str] = None
|
33 |
+
guid: Optional[str] = None
|
34 |
+
duration: Optional[str] = None # 例如,来自 <itunes:duration>
|
35 |
+
episode_type: Optional[str] = None # 例如,来自 <itunes:episodetype>
|
36 |
+
season: Optional[int] = None # 例如,来自 <itunes:season>
|
37 |
+
episode_number: Optional[int] = None # 例如,来自 <itunes:episode>
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class PodcastChannel:
|
41 |
+
title: Optional[str] = None
|
42 |
+
link: Optional[str] = None
|
43 |
+
description: Optional[str] = None
|
44 |
+
language: Optional[str] = None
|
45 |
+
image_url: Optional[str] = None
|
46 |
+
author: Optional[str] = None # 例如,来自 <itunes:author>
|
47 |
+
last_build_date: Optional[datetime] = None
|
48 |
+
episodes: List[PodcastEpisode] = field(default_factory=list)
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class TranscriptionResult:
|
53 |
+
"""转录结果数据类"""
|
54 |
+
text: str # 转录的文本
|
55 |
+
segments: List[Dict[str, Union[float, str]]] # 包含时间戳的分段
|
56 |
+
language: str # 检测到的语言
|
57 |
+
|
58 |
+
|
59 |
+
@dataclass
|
60 |
+
class DiarizationResult:
|
61 |
+
"""说话人分离结果数据类"""
|
62 |
+
segments: List[Dict[str, Union[float, str, int]]] # 包含时间戳和说话人ID的分段
|
63 |
+
num_speakers: int # 检测到的说话人数量
|
src/podcast_transcribe/summary/speaker_identify.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Optional
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
|
5 |
+
from ..schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
|
6 |
+
from ..llm import llm_router
|
7 |
+
|
8 |
+
|
9 |
+
class SpeakerIdentifier:
|
10 |
+
"""
|
11 |
+
说话人识别器类,用于根据转录分段和播客元数据识别说话人的真实姓名或昵称
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, llm_model_name: str, llm_provider: str):
|
15 |
+
"""
|
16 |
+
初始化说话人识别器
|
17 |
+
|
18 |
+
参数:
|
19 |
+
llm_model_name: LLM模型名称,如果为None则使用默认模型
|
20 |
+
llm_provider: LLM提供者,默认为"gemma-mlx"
|
21 |
+
"""
|
22 |
+
self.llm_model_name = llm_model_name
|
23 |
+
self.llm_provider = llm_provider
|
24 |
+
|
25 |
+
def _clean_html(self, html_string: Optional[str]) -> str:
|
26 |
+
"""
|
27 |
+
简单地从字符串中移除HTML标签并清理多余空白。
|
28 |
+
"""
|
29 |
+
if not html_string:
|
30 |
+
return ""
|
31 |
+
# 移除HTML标签
|
32 |
+
text = re.sub(r'<[^>]+>', ' ', html_string)
|
33 |
+
# 替换HTML实体(简单版本,只处理常见几个)
|
34 |
+
text = text.replace(' ', ' ').replace('&', '&').replace('<', '<').replace('>', '>')
|
35 |
+
# 移除多余的空白符
|
36 |
+
text = re.sub(r'\\s+', ' ', text).strip()
|
37 |
+
return text
|
38 |
+
|
39 |
+
def _get_dialogue_samples(
|
40 |
+
self,
|
41 |
+
segments: List[EnhancedSegment],
|
42 |
+
max_samples_per_speaker: int = 3, # 增加样本数量
|
43 |
+
max_length_per_sample: int = 200 # 增加样本长度
|
44 |
+
) -> Dict[str, List[str]]:
|
45 |
+
"""
|
46 |
+
为每个说话人提取对话样本。
|
47 |
+
"""
|
48 |
+
speaker_dialogues: Dict[str, List[str]] = {}
|
49 |
+
for segment in segments:
|
50 |
+
speaker = segment.speaker
|
51 |
+
if speaker == "UNKNOWN" or not segment.text.strip(): # 跳过未知说话人或空文本
|
52 |
+
continue
|
53 |
+
|
54 |
+
if speaker not in speaker_dialogues:
|
55 |
+
speaker_dialogues[speaker] = []
|
56 |
+
|
57 |
+
if len(speaker_dialogues[speaker]) < max_samples_per_speaker:
|
58 |
+
text_sample = segment.text.strip()[:max_length_per_sample]
|
59 |
+
if len(segment.text.strip()) > max_length_per_sample:
|
60 |
+
text_sample += "..."
|
61 |
+
speaker_dialogues[speaker].append(text_sample)
|
62 |
+
return speaker_dialogues
|
63 |
+
|
64 |
+
def recognize_speaker_names(
|
65 |
+
self,
|
66 |
+
segments: List[EnhancedSegment],
|
67 |
+
podcast_info: Optional[PodcastChannel],
|
68 |
+
episode_info: Optional[PodcastEpisode],
|
69 |
+
max_shownotes_length: int = 1500,
|
70 |
+
max_desc_length: int = 500
|
71 |
+
) -> Dict[str, str]:
|
72 |
+
"""
|
73 |
+
使用LLM根据转录分段和播客/剧集元数据识别说话人的真实姓名或昵称。
|
74 |
+
|
75 |
+
参数:
|
76 |
+
segments: 转录后的 EnhancedSegment 列表。
|
77 |
+
podcast_info: 包含播客元数据的 PodcastChannel 对象。
|
78 |
+
episode_info: 包含单集播客元数据的 PodcastEpisode 对象。
|
79 |
+
max_shownotes_length: 用于Prompt的 Shownotes 最大字符数。
|
80 |
+
max_desc_length: 用于Prompt的播客描述最大字符数。
|
81 |
+
|
82 |
+
返回:
|
83 |
+
一个字典,键是原始的 "SPEAKER_XX",值是识别出的说话人名称。
|
84 |
+
"""
|
85 |
+
unique_speaker_ids = sorted(list(set(seg.speaker for seg in segments if seg.speaker != "UNKNOWN" and seg.text.strip())))
|
86 |
+
if not unique_speaker_ids:
|
87 |
+
print("未能从 segments 中提取到有效的 speaker_ids。")
|
88 |
+
return {}
|
89 |
+
|
90 |
+
dialogue_samples = self._get_dialogue_samples(segments)
|
91 |
+
|
92 |
+
# 增加每个说话人的话语分析信息,包括话语频率和长度
|
93 |
+
speaker_stats = {}
|
94 |
+
for segment in segments:
|
95 |
+
speaker = segment.speaker
|
96 |
+
if speaker == "UNKNOWN" or not segment.text.strip():
|
97 |
+
continue
|
98 |
+
|
99 |
+
if speaker not in speaker_stats:
|
100 |
+
speaker_stats[speaker] = {
|
101 |
+
"total_segments": 0,
|
102 |
+
"total_chars": 0,
|
103 |
+
"avg_segment_length": 0,
|
104 |
+
"intro_likely": False # 是否有介绍性质的话语
|
105 |
+
}
|
106 |
+
|
107 |
+
speaker_stats[speaker]["total_segments"] += 1
|
108 |
+
speaker_stats[speaker]["total_chars"] += len(segment.text)
|
109 |
+
|
110 |
+
# 检测可能的自我介绍或他人介绍
|
111 |
+
lower_text = segment.text.lower()
|
112 |
+
intro_patterns = [
|
113 |
+
r'欢迎来到', r'欢迎收听', r'我是', r'我叫', r'大家好', r'今天的嘉宾是', r'我们请到了',
|
114 |
+
r'welcome to', r'i\'m your host', r'this is', r'today we have', r'joining us',
|
115 |
+
r'our guest', r'my name is'
|
116 |
+
]
|
117 |
+
if any(re.search(pattern, lower_text) for pattern in intro_patterns):
|
118 |
+
speaker_stats[speaker]["intro_likely"] = True
|
119 |
+
|
120 |
+
# 计算平均话语长度
|
121 |
+
for speaker, stats in speaker_stats.items():
|
122 |
+
if stats["total_segments"] > 0:
|
123 |
+
stats["avg_segment_length"] = stats["total_chars"] / stats["total_segments"]
|
124 |
+
|
125 |
+
# 创建增强的说话人信息,包含统计数据
|
126 |
+
speaker_info_for_prompt = []
|
127 |
+
for speaker_id in unique_speaker_ids:
|
128 |
+
samples = dialogue_samples.get(speaker_id, ["(No dialogue samples available)"])
|
129 |
+
stats = speaker_stats.get(speaker_id, {"total_segments": 0, "avg_segment_length": 0, "intro_likely": False})
|
130 |
+
|
131 |
+
speaker_info_for_prompt.append({
|
132 |
+
"speaker_id": speaker_id,
|
133 |
+
"dialogue_samples": samples,
|
134 |
+
"speech_stats": {
|
135 |
+
"total_segments": stats["total_segments"],
|
136 |
+
"avg_segment_length": round(stats["avg_segment_length"], 2),
|
137 |
+
"has_intro_pattern": stats["intro_likely"]
|
138 |
+
}
|
139 |
+
})
|
140 |
+
|
141 |
+
# 安全地访问属性,提供默认值
|
142 |
+
podcast_title = podcast_info.title if podcast_info and podcast_info.title else "Unknown Podcast"
|
143 |
+
podcast_author = podcast_info.author if podcast_info and podcast_info.author else "Unknown"
|
144 |
+
|
145 |
+
raw_podcast_desc = podcast_info.description if podcast_info and podcast_info.description else ""
|
146 |
+
cleaned_podcast_desc = self._clean_html(raw_podcast_desc)
|
147 |
+
podcast_desc_for_prompt = cleaned_podcast_desc[:max_desc_length]
|
148 |
+
if len(cleaned_podcast_desc) > max_desc_length:
|
149 |
+
podcast_desc_for_prompt += "..."
|
150 |
+
|
151 |
+
episode_title = episode_info.title if episode_info and episode_info.title else "Unknown Episode"
|
152 |
+
|
153 |
+
raw_episode_summary = episode_info.summary if episode_info and episode_info.summary else ""
|
154 |
+
cleaned_episode_summary = self._clean_html(raw_episode_summary)
|
155 |
+
episode_summary_for_prompt = cleaned_episode_summary[:max_desc_length] # 使用与描述相同的长度限制
|
156 |
+
if len(cleaned_episode_summary) > max_desc_length:
|
157 |
+
episode_summary_for_prompt += "..."
|
158 |
+
|
159 |
+
raw_episode_shownotes = episode_info.shownotes if episode_info and episode_info.shownotes else ""
|
160 |
+
cleaned_episode_shownotes = self._clean_html(raw_episode_shownotes)
|
161 |
+
episode_shownotes_for_prompt = cleaned_episode_shownotes[:max_shownotes_length]
|
162 |
+
if len(cleaned_episode_shownotes) > max_shownotes_length:
|
163 |
+
episode_shownotes_for_prompt += "..."
|
164 |
+
|
165 |
+
system_prompt = """You are an experienced podcast content analyst. Your task is to accurately identify the real names, nicknames, or roles of different speakers (tagged in SPEAKER_XX format) in a podcast episode, based on the provided metadata, episode information, dialogue snippets, and speech patterns. Your analysis should NOT rely on the order of speakers or speaker IDs."""
|
166 |
+
|
167 |
+
user_prompt_template = f"""
|
168 |
+
Contextual Information:
|
169 |
+
|
170 |
+
1. **Podcast Information**:
|
171 |
+
* Podcast Title: {podcast_title}
|
172 |
+
* Podcast Author/Producer: {podcast_author} (This information often points to the main host or production team)
|
173 |
+
* Podcast Description: {podcast_desc_for_prompt}
|
174 |
+
|
175 |
+
2. **Current Episode Information**:
|
176 |
+
* Episode Title: {episode_title}
|
177 |
+
* Episode Summary: {episode_summary_for_prompt}
|
178 |
+
* Detailed Episode Notes (Shownotes):
|
179 |
+
```text
|
180 |
+
{episode_shownotes_for_prompt}
|
181 |
+
```
|
182 |
+
(Pay close attention to any host names, guest names, positions, or social media handles mentioned in the Shownotes.)
|
183 |
+
|
184 |
+
3. **Speakers to Identify and Their Information**:
|
185 |
+
```json
|
186 |
+
{json.dumps(speaker_info_for_prompt, ensure_ascii=False, indent=2)}
|
187 |
+
```
|
188 |
+
(Analyze dialogue samples and speech statistics to understand speaker roles and identities. DO NOT use speaker IDs to determine roles - SPEAKER_00 is not necessarily the host.)
|
189 |
+
|
190 |
+
Task:
|
191 |
+
Based on all the information above, assign the most accurate name or role to each "speaker_id".
|
192 |
+
|
193 |
+
Analysis Guidance:
|
194 |
+
* A host typically has more frequent, shorter segments, often introduces the show or guests, and may mention the podcast name
|
195 |
+
* In panel discussion formats, there might be multiple hosts or co-hosts of similar speaking patterns
|
196 |
+
* In interview formats, the host typically asks questions while guests give longer answers
|
197 |
+
* Speakers who make introductory statements or welcome listeners are likely hosts
|
198 |
+
* Use dialogue content (not just speaking patterns) to identify names and roles
|
199 |
+
|
200 |
+
Output Requirements and Guidelines:
|
201 |
+
* Please return the result strictly in JSON format. The keys of the JSON object should be the original "speaker_id" (e.g., "SPEAKER_00"), and the values should be the identified person's name or role (string type).
|
202 |
+
* **Prioritize Specific Names/Nicknames**: If there is sufficient information (e.g., guests explicitly listed in Shownotes, or names mentioned in dialogue), please use the identified specific names, such as "John Doe", "AI Assistant", "Dr. Evelyn Reed". Do NOT append roles like "(Host)" or "(Guest)" if a specific name is found.
|
203 |
+
* **Host Identification**:
|
204 |
+
* Hosts may be identified by analyzing speech patterns - they often speak more frequently in shorter segments
|
205 |
+
* Look for introduction patterns in dialogue where speakers welcome listeners or introduce the show
|
206 |
+
* The podcast author (if provided and credible) is often a host but verify through dialogue
|
207 |
+
* There may be multiple hosts (co-hosts) in panel-style podcasts
|
208 |
+
* If a host's name is identified, use the identified name directly (e.g., "Lex Fridman"). Do not append "(Host)".
|
209 |
+
* If the host's name cannot be determined but the role is clearly a host, use "Podcast Host".
|
210 |
+
* **Guest Identification**:
|
211 |
+
* Guests often give longer responses and speak less frequently than hosts
|
212 |
+
* For other non-host speakers, if a specific name is identified, use the identified name directly (e.g., "John Carmack"). Do not append "(Guest)".
|
213 |
+
* If specific names cannot be identified for guests, label them sequentially as "Guest 1", "Guest 2", etc.
|
214 |
+
* **Handling Multiple Hosts/Guests**: If there are multiple hosts or guests and they can be distinguished by name, use their names. If you cannot distinguish specific identities but know there are multiple hosts, use "Host 1", "Host 2", etc. Similarly for guests without specific names, use "Guest 1", "Guest 2".
|
215 |
+
* **Ensure Completeness**: The returned JSON object must include all "speaker_id"s listed in the input as keys.
|
216 |
+
|
217 |
+
JSON Output Example:
|
218 |
+
```json
|
219 |
+
{{
|
220 |
+
"SPEAKER_00": "Jane Smith",
|
221 |
+
"SPEAKER_01": "Podcast Host",
|
222 |
+
"SPEAKER_02": "Alex Green"
|
223 |
+
}}
|
224 |
+
```
|
225 |
+
Note that in this example, SPEAKER_01 is identified as the host, not SPEAKER_00, based on content analysis, not ID order.
|
226 |
+
|
227 |
+
Please begin your analysis and provide the JSON result.
|
228 |
+
"""
|
229 |
+
|
230 |
+
messages = [
|
231 |
+
{"role": "system", "content": system_prompt},
|
232 |
+
{"role": "user", "content": user_prompt_template}
|
233 |
+
]
|
234 |
+
|
235 |
+
# 预设默认映射,使用更智能的启发式方法而不是简单依赖顺序
|
236 |
+
final_map = {}
|
237 |
+
|
238 |
+
# 尝试使用说话模式启发式方法来初步识别角色
|
239 |
+
# 1. 说话次数最多的可能是主持人
|
240 |
+
# 2. 有介绍性话语的可能是主持人
|
241 |
+
# 3. 其他角色先标记为嘉宾
|
242 |
+
|
243 |
+
host_candidates = []
|
244 |
+
for speaker_id, stats in speaker_stats.items():
|
245 |
+
if stats["intro_likely"]:
|
246 |
+
host_candidates.append((speaker_id, 2)) # 优先级2:有介绍性话语
|
247 |
+
else:
|
248 |
+
# 按说话次数排序
|
249 |
+
host_candidates.append((speaker_id, stats["total_segments"]))
|
250 |
+
|
251 |
+
# 按可能性排序(介绍性话语 > 说话次数)
|
252 |
+
host_candidates.sort(key=lambda x: (-1 if x[1] == 2 else 0, x[1]), reverse=True)
|
253 |
+
|
254 |
+
if host_candidates:
|
255 |
+
# 最可能的主持人
|
256 |
+
host_id = host_candidates[0][0]
|
257 |
+
final_map[host_id] = "Podcast Host"
|
258 |
+
|
259 |
+
# 其他人先标为嘉宾
|
260 |
+
guest_counter = 1
|
261 |
+
for speaker_id in unique_speaker_ids:
|
262 |
+
if speaker_id != host_id:
|
263 |
+
final_map[speaker_id] = f"Guest {guest_counter}"
|
264 |
+
guest_counter += 1
|
265 |
+
else:
|
266 |
+
# 如果没有明显线索,使用传统的顺序方法作为备选
|
267 |
+
is_host_assigned = False
|
268 |
+
guest_counter = 1
|
269 |
+
for speaker_id in unique_speaker_ids:
|
270 |
+
if not is_host_assigned:
|
271 |
+
final_map[speaker_id] = "Podcast Host"
|
272 |
+
is_host_assigned = True
|
273 |
+
else:
|
274 |
+
final_map[speaker_id] = f"Guest {guest_counter}"
|
275 |
+
guest_counter += 1
|
276 |
+
|
277 |
+
try:
|
278 |
+
response = llm_router.chat_completion(
|
279 |
+
messages=messages,
|
280 |
+
provider=self.llm_provider,
|
281 |
+
model=self.llm_model_name,
|
282 |
+
temperature=0.1,
|
283 |
+
max_tokens=1024
|
284 |
+
)
|
285 |
+
assistant_response_content = response["choices"][0]["message"]["content"]
|
286 |
+
|
287 |
+
parsed_llm_output = None
|
288 |
+
# 尝试从Markdown代码块中提取JSON
|
289 |
+
json_match = re.search(r'```json\s*(\{.*?\})\s*```', assistant_response_content, re.DOTALL)
|
290 |
+
if json_match:
|
291 |
+
json_str = json_match.group(1)
|
292 |
+
else:
|
293 |
+
# 如果没有markdown块,尝试找到第一个 '{' 到最后一个 '}'
|
294 |
+
first_brace = assistant_response_content.find('{')
|
295 |
+
last_brace = assistant_response_content.rfind('}')
|
296 |
+
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
|
297 |
+
json_str = assistant_response_content[first_brace : last_brace+1]
|
298 |
+
else: # 如果还是找不到,就认为整个回复都是JSON(可能需要更复杂的清理)
|
299 |
+
json_str = assistant_response_content.strip()
|
300 |
+
|
301 |
+
try:
|
302 |
+
parsed_llm_output = json.loads(json_str)
|
303 |
+
if not isinstance(parsed_llm_output, dict): # 确保解析出来是字典
|
304 |
+
print(f"LLM返回的JSON不是一个字典: {parsed_llm_output}")
|
305 |
+
parsed_llm_output = None # 重置,以便使用默认值
|
306 |
+
except json.JSONDecodeError as e:
|
307 |
+
print(f"LLM返回的JSON解析失败: {e}")
|
308 |
+
print(f"用于解析的字符串: '{json_str}'")
|
309 |
+
# parsed_llm_output 保持为 None,将使用默认值
|
310 |
+
|
311 |
+
if parsed_llm_output:
|
312 |
+
# 直接使用LLM的有效输出,不再依赖预设的角色分配逻辑
|
313 |
+
final_map = {}
|
314 |
+
unknown_counter = 1
|
315 |
+
|
316 |
+
# 先处理LLM识别出的角色
|
317 |
+
for spk_id in unique_speaker_ids:
|
318 |
+
if spk_id in parsed_llm_output and isinstance(parsed_llm_output[spk_id], str) and parsed_llm_output[spk_id].strip():
|
319 |
+
final_map[spk_id] = parsed_llm_output[spk_id].strip()
|
320 |
+
else:
|
321 |
+
# 如果LLM没有给出特定ID的结果,使用"Unknown Speaker"
|
322 |
+
final_map[spk_id] = f"Unknown Speaker {unknown_counter}"
|
323 |
+
unknown_counter += 1
|
324 |
+
|
325 |
+
# 检查是否有"Host"或"主持人"标识
|
326 |
+
has_host = any("主持人" in name or "Host" in name for name in final_map.values())
|
327 |
+
|
328 |
+
# 如果没有任何主持人标识,且存在"Unknown Speaker",可以考虑将最活跃的未知说话人设为主持人
|
329 |
+
if not has_host and any("Unknown Speaker" in name for name in final_map.values()):
|
330 |
+
# 找出最活跃的未知说话人
|
331 |
+
most_active_unknown = None
|
332 |
+
max_segments = 0
|
333 |
+
|
334 |
+
for spk_id, name in final_map.items():
|
335 |
+
if "Unknown Speaker" in name and spk_id in speaker_stats:
|
336 |
+
if speaker_stats[spk_id]["total_segments"] > max_segments:
|
337 |
+
max_segments = speaker_stats[spk_id]["total_segments"]
|
338 |
+
most_active_unknown = spk_id
|
339 |
+
|
340 |
+
if most_active_unknown:
|
341 |
+
final_map[most_active_unknown] = "Podcast Host"
|
342 |
+
|
343 |
+
return final_map
|
344 |
+
|
345 |
+
except Exception as e:
|
346 |
+
import traceback
|
347 |
+
print(f"调用LLM或处理响应时发生严重错误: {e}")
|
348 |
+
print(traceback.format_exc())
|
349 |
+
# 发生任何严重错误,返回初始的启发式映射
|
350 |
+
return final_map
|
src/podcast_transcribe/transcriber.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
整合ASR和说话人分离的转录器模块,支持流式处理长语音对话
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pydub import AudioSegment
|
7 |
+
from typing import Dict, List, Union, Optional, Any
|
8 |
+
import logging
|
9 |
+
from concurrent.futures import ThreadPoolExecutor
|
10 |
+
import re
|
11 |
+
|
12 |
+
from .summary.speaker_identify import SpeakerIdentifier # 新增导入
|
13 |
+
|
14 |
+
# 导入ASR和说话人分离模块,使用相对导入
|
15 |
+
from .asr import asr_router
|
16 |
+
from .asr.asr_base import TranscriptionResult
|
17 |
+
from .diarization import diarizer_router
|
18 |
+
from .schemas import EnhancedSegment, CombinedTranscriptionResult, PodcastChannel, PodcastEpisode, DiarizationResult
|
19 |
+
|
20 |
+
# 配置日志
|
21 |
+
logger = logging.getLogger("podcast_transcribe")
|
22 |
+
|
23 |
+
class CombinedTranscriber:
|
24 |
+
"""整合ASR和说话人分离的转录器"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
asr_model_name: str,
|
29 |
+
asr_provider: str,
|
30 |
+
diarization_provider: str,
|
31 |
+
diarization_model_name: str,
|
32 |
+
llm_model_name: Optional[str] = None,
|
33 |
+
llm_provider: Optional[str] = None,
|
34 |
+
hf_token: Optional[str] = None,
|
35 |
+
device: Optional[str] = None,
|
36 |
+
segmentation_batch_size: int = 64,
|
37 |
+
parallel: bool = False,
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
初始化转录器
|
41 |
+
|
42 |
+
参数:
|
43 |
+
asr_model_name: ASR模型名称
|
44 |
+
asr_provider: ASR提供者名称
|
45 |
+
diarization_provider: 说话人分离提供者名称
|
46 |
+
diarization_model_name: 说话人分离模型名称
|
47 |
+
hf_token: Hugging Face令牌
|
48 |
+
device: 推理设备,'cpu'或'cuda'
|
49 |
+
segmentation_batch_size: 分割批处理大小,默认为64
|
50 |
+
parallel: 是否并行执行ASR和说话人分离,默认为False
|
51 |
+
"""
|
52 |
+
if not device:
|
53 |
+
import torch
|
54 |
+
if torch.backends.mps.is_available():
|
55 |
+
device = "mps"
|
56 |
+
if not llm_model_name:
|
57 |
+
llm_model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
|
58 |
+
if not llm_provider:
|
59 |
+
llm_provider = "gemma-mlx"
|
60 |
+
|
61 |
+
elif torch.cuda.is_available():
|
62 |
+
device = "cuda"
|
63 |
+
if not llm_model_name:
|
64 |
+
llm_model_name = "google/gemma-3-12b-it"
|
65 |
+
if not llm_provider:
|
66 |
+
llm_provider = "gemma-transformers"
|
67 |
+
else:
|
68 |
+
device = "cpu"
|
69 |
+
if not llm_model_name:
|
70 |
+
llm_model_name = "google/gemma-3-12b-it"
|
71 |
+
if not llm_provider:
|
72 |
+
llm_provider = "gemma-transformers"
|
73 |
+
|
74 |
+
self.asr_model_name = asr_model_name
|
75 |
+
self.asr_provider = asr_provider
|
76 |
+
self.diarization_provider = diarization_provider
|
77 |
+
self.diarization_model_name = diarization_model_name
|
78 |
+
self.hf_token = hf_token or os.environ.get("HF_TOKEN")
|
79 |
+
self.device = device
|
80 |
+
self.segmentation_batch_size = segmentation_batch_size
|
81 |
+
self.parallel = parallel
|
82 |
+
|
83 |
+
self.speaker_identifier = SpeakerIdentifier(
|
84 |
+
llm_model_name=llm_model_name,
|
85 |
+
llm_provider=llm_provider
|
86 |
+
)
|
87 |
+
|
88 |
+
logger.info(f"初始化组合转录器,ASR提供者: {asr_provider},ASR模型: {asr_model_name},分离提供者: {diarization_provider},分离模型: {diarization_model_name},分割批处理大小: {segmentation_batch_size},并行执行: {parallel},推理设备: {device}")
|
89 |
+
|
90 |
+
|
91 |
+
def _merge_adjacent_text_segments(self, segments: List[EnhancedSegment]) -> List[EnhancedSegment]:
|
92 |
+
"""
|
93 |
+
合并相邻的、可能属于同一句子的 EnhancedSegment。
|
94 |
+
合并条件:同一说话人,时间基本连续,文本内容可拼接。
|
95 |
+
"""
|
96 |
+
if not segments:
|
97 |
+
return []
|
98 |
+
|
99 |
+
merged_segments: List[EnhancedSegment] = []
|
100 |
+
if not segments: # 重复检查,可移除
|
101 |
+
return merged_segments
|
102 |
+
|
103 |
+
current_merged_segment = segments[0]
|
104 |
+
|
105 |
+
for i in range(1, len(segments)):
|
106 |
+
next_segment = segments[i]
|
107 |
+
|
108 |
+
time_gap_seconds = next_segment.start - current_merged_segment.end
|
109 |
+
|
110 |
+
can_merge_text = False
|
111 |
+
if current_merged_segment.text and next_segment.text:
|
112 |
+
current_text_stripped = current_merged_segment.text.strip()
|
113 |
+
if current_text_stripped and not current_text_stripped[-1] in ".。?!?!":
|
114 |
+
can_merge_text = True
|
115 |
+
|
116 |
+
if (current_merged_segment.speaker == next_segment.speaker and
|
117 |
+
0 <= time_gap_seconds < 0.75 and
|
118 |
+
can_merge_text):
|
119 |
+
current_merged_segment = EnhancedSegment(
|
120 |
+
start=current_merged_segment.start,
|
121 |
+
end=next_segment.end,
|
122 |
+
text=(current_merged_segment.text.strip() + " " + next_segment.text.strip()).strip(),
|
123 |
+
speaker=current_merged_segment.speaker,
|
124 |
+
language=current_merged_segment.language
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
merged_segments.append(current_merged_segment)
|
128 |
+
current_merged_segment = next_segment
|
129 |
+
|
130 |
+
merged_segments.append(current_merged_segment)
|
131 |
+
|
132 |
+
return merged_segments
|
133 |
+
|
134 |
+
def _run_asr(self, audio: AudioSegment) -> TranscriptionResult:
|
135 |
+
"""执行ASR处理"""
|
136 |
+
logger.debug("执行ASR...")
|
137 |
+
return asr_router.transcribe_audio(
|
138 |
+
audio,
|
139 |
+
provider=self.asr_provider,
|
140 |
+
model_name=self.asr_model_name,
|
141 |
+
device=self.device
|
142 |
+
)
|
143 |
+
|
144 |
+
def _run_diarization(self, audio: AudioSegment) -> DiarizationResult:
|
145 |
+
"""执行说话人分离处理"""
|
146 |
+
logger.debug("执行说话人分离...")
|
147 |
+
return diarizer_router.diarize_audio(
|
148 |
+
audio,
|
149 |
+
provider=self.diarization_provider,
|
150 |
+
model_name=self.diarization_model_name,
|
151 |
+
token=self.hf_token,
|
152 |
+
device=self.device,
|
153 |
+
segmentation_batch_size=self.segmentation_batch_size
|
154 |
+
)
|
155 |
+
|
156 |
+
def transcribe(self, audio: AudioSegment) -> CombinedTranscriptionResult:
|
157 |
+
"""
|
158 |
+
转录整个音频 (新的非流式逻辑将在这里实现)
|
159 |
+
|
160 |
+
参数:
|
161 |
+
audio: 要转录的AudioSegment对象
|
162 |
+
|
163 |
+
返回:
|
164 |
+
包含完整转录和说话人信息的结果
|
165 |
+
"""
|
166 |
+
logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频 (非流式)")
|
167 |
+
|
168 |
+
if self.parallel:
|
169 |
+
# 并行执行ASR和说话人分离
|
170 |
+
logger.info("并行执行ASR和说话人分离")
|
171 |
+
with ThreadPoolExecutor(max_workers=2) as executor:
|
172 |
+
asr_future = executor.submit(self._run_asr, audio)
|
173 |
+
diarization_future = executor.submit(self._run_diarization, audio)
|
174 |
+
|
175 |
+
asr_result: TranscriptionResult = asr_future.result()
|
176 |
+
diarization_result: DiarizationResult = diarization_future.result()
|
177 |
+
|
178 |
+
logger.debug(f"ASR完成,识别语言: {asr_result.language},得到 {len(asr_result.segments)} 个分段")
|
179 |
+
logger.debug(f"说话人分离完成,得到 {len(diarization_result.segments)} 个说话人分段,检测到 {diarization_result.num_speakers} 个说话人")
|
180 |
+
else:
|
181 |
+
# 顺序执行ASR和说话人分离
|
182 |
+
# 步骤1: 对整个音频执行ASR
|
183 |
+
logger.debug("执行ASR...")
|
184 |
+
asr_result: TranscriptionResult = asr_router.transcribe_audio(
|
185 |
+
audio,
|
186 |
+
provider=self.asr_provider,
|
187 |
+
model_name=self.asr_model_name,
|
188 |
+
device=self.device
|
189 |
+
)
|
190 |
+
logger.debug(f"ASR完成,识别语言: {asr_result.language},得到 {len(asr_result.segments)} 个分段")
|
191 |
+
|
192 |
+
# 步骤2: 对整个音频执行说话人分离
|
193 |
+
logger.debug("执行说话人分离...")
|
194 |
+
diarization_result: DiarizationResult = diarizer_router.diarize_audio(
|
195 |
+
audio,
|
196 |
+
provider=self.diarization_provider,
|
197 |
+
model_name=self.diarization_model_name,
|
198 |
+
token=self.hf_token,
|
199 |
+
device=self.device,
|
200 |
+
segmentation_batch_size=self.segmentation_batch_size
|
201 |
+
)
|
202 |
+
logger.debug(f"说话人分离完成,得到 {len(diarization_result.segments)} 个说话人分段,检测到 {diarization_result.num_speakers} 个说话人")
|
203 |
+
|
204 |
+
# 步骤3: 创建增强分段
|
205 |
+
all_enhanced_segments: List[EnhancedSegment] = self._create_enhanced_segments_with_splitting(
|
206 |
+
asr_result.segments,
|
207 |
+
diarization_result.segments,
|
208 |
+
asr_result.language
|
209 |
+
)
|
210 |
+
|
211 |
+
# 步骤4: (可选)合并相邻的文本分段
|
212 |
+
if all_enhanced_segments:
|
213 |
+
logger.debug(f"合并前有 {len(all_enhanced_segments)} 个增强分段,尝试合并相邻分段...")
|
214 |
+
final_segments = self._merge_adjacent_text_segments(all_enhanced_segments)
|
215 |
+
logger.debug(f"合并后有 {len(final_segments)} 个增强分段")
|
216 |
+
else:
|
217 |
+
final_segments = []
|
218 |
+
logger.debug("没有增强分段可供合并。")
|
219 |
+
|
220 |
+
# 整理合并的文本
|
221 |
+
full_text = " ".join([segment.text for segment in final_segments]).strip()
|
222 |
+
|
223 |
+
# 计算最终说话人数
|
224 |
+
num_speakers_set = set(s.speaker for s in final_segments if s.speaker != "UNKNOWN")
|
225 |
+
|
226 |
+
return CombinedTranscriptionResult(
|
227 |
+
segments=final_segments,
|
228 |
+
text=full_text,
|
229 |
+
language=asr_result.language or "unknown",
|
230 |
+
num_speakers=len(num_speakers_set) if num_speakers_set else diarization_result.num_speakers
|
231 |
+
)
|
232 |
+
|
233 |
+
# 新方法:根据标点分割ASR文本片段
|
234 |
+
def _split_asr_segment_by_punctuation(
|
235 |
+
self,
|
236 |
+
asr_seg_text: str,
|
237 |
+
asr_seg_start: float,
|
238 |
+
asr_seg_end: float
|
239 |
+
) -> List[Dict[str, Any]]:
|
240 |
+
"""
|
241 |
+
根据标点符号分割ASR文本片段,并按字符比例估算子片段的时间戳。
|
242 |
+
返回: 字典列表,每个字典包含 'text', 'start', 'end'。
|
243 |
+
"""
|
244 |
+
sentence_terminators = ".。?!?!;;"
|
245 |
+
# 正则表达式:匹配句子内容以及紧随其后的标点(如果存在)
|
246 |
+
# 使用 re.split 保留分隔符,然后重组
|
247 |
+
parts = re.split(f'([{sentence_terminators}])', asr_seg_text)
|
248 |
+
|
249 |
+
sub_texts_final = []
|
250 |
+
current_s = ""
|
251 |
+
for s_part in parts:
|
252 |
+
if not s_part:
|
253 |
+
continue
|
254 |
+
current_s += s_part
|
255 |
+
if s_part in sentence_terminators:
|
256 |
+
if current_s.strip():
|
257 |
+
sub_texts_final.append(current_s.strip())
|
258 |
+
current_s = ""
|
259 |
+
if current_s.strip():
|
260 |
+
sub_texts_final.append(current_s.strip())
|
261 |
+
|
262 |
+
if not sub_texts_final or (len(sub_texts_final) == 1 and sub_texts_final[0] == asr_seg_text.strip()):
|
263 |
+
# 没有有效分割或分割后只有一个句子(等于原始文本)
|
264 |
+
return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
|
265 |
+
|
266 |
+
output_sub_segments = []
|
267 |
+
total_text_len = len(asr_seg_text) # 使用原始文本长度进行比例计算
|
268 |
+
if total_text_len == 0:
|
269 |
+
return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
|
270 |
+
|
271 |
+
current_time = asr_seg_start
|
272 |
+
original_duration = asr_seg_end - asr_seg_start
|
273 |
+
|
274 |
+
for i, sub_text in enumerate(sub_texts_final):
|
275 |
+
sub_len = len(sub_text)
|
276 |
+
sub_duration = (sub_len / total_text_len) * original_duration
|
277 |
+
|
278 |
+
sub_start_time = current_time
|
279 |
+
sub_end_time = current_time + sub_duration
|
280 |
+
|
281 |
+
# 对于最后一个分片,确保其结束时间与原始分段的结束时间一致,以避免累积误差
|
282 |
+
if i == len(sub_texts_final) - 1:
|
283 |
+
sub_end_time = asr_seg_end
|
284 |
+
|
285 |
+
# 确保结束时间不超过原始结束时间,并且开始时间不晚于结束时间
|
286 |
+
sub_end_time = min(sub_end_time, asr_seg_end)
|
287 |
+
if sub_start_time >= sub_end_time and sub_start_time == asr_seg_end : # 如果开始等于原始结束,允许微小片段
|
288 |
+
if sub_text: # 仅当有文本时
|
289 |
+
output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
|
290 |
+
elif sub_start_time < sub_end_time :
|
291 |
+
output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
|
292 |
+
|
293 |
+
current_time = sub_end_time
|
294 |
+
if current_time >= asr_seg_end and i < len(sub_texts_final) -1: # 如果时间已用完,但还有句子
|
295 |
+
# 将剩余句子附加到最后一个有效的时间段,或创建零长度的段
|
296 |
+
logger.warning(f"时间已在分割过程中用尽,但仍有文本未分配时间。原始段: [{asr_seg_start}-{asr_seg_end}], 当前子句: '{sub_text}'")
|
297 |
+
# 为后续未分配时间的文本创建零时长或极短时长的片段,附着在末尾
|
298 |
+
for k in range(i + 1, len(sub_texts_final)):
|
299 |
+
remaining_text = sub_texts_final[k]
|
300 |
+
if remaining_text:
|
301 |
+
output_sub_segments.append({"text": remaining_text, "start": asr_seg_end, "end": asr_seg_end})
|
302 |
+
break
|
303 |
+
|
304 |
+
|
305 |
+
# 如果处理后没有任何子分段(例如原始文本为空,或分割逻辑问题),返回原始信息作为一个分段
|
306 |
+
if not output_sub_segments and asr_seg_text.strip():
|
307 |
+
return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
|
308 |
+
elif not output_sub_segments and not asr_seg_text.strip():
|
309 |
+
return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
|
310 |
+
|
311 |
+
|
312 |
+
return output_sub_segments
|
313 |
+
|
314 |
+
# 新的核心方法:创建增强分段,包含说话人分配和按需分裂逻辑
|
315 |
+
def _create_enhanced_segments_with_splitting(
|
316 |
+
self,
|
317 |
+
asr_segments: List[Dict[str, Union[float, str]]],
|
318 |
+
diarization_segments: List[Dict[str, Union[float, str, int]]],
|
319 |
+
language: str
|
320 |
+
) -> List[EnhancedSegment]:
|
321 |
+
"""
|
322 |
+
为ASR分段分配说话人,如果ASR分段跨越多个说话人,则尝试按标点分裂。
|
323 |
+
"""
|
324 |
+
final_enhanced_segments: List[EnhancedSegment] = []
|
325 |
+
|
326 |
+
if not asr_segments:
|
327 |
+
return []
|
328 |
+
|
329 |
+
# 为了快速查找,可以预处理 diarization_segments,但对于数量不多的情况,直接遍历也可
|
330 |
+
# diarization_segments.sort(key=lambda x: x['start']) # 确保有序
|
331 |
+
|
332 |
+
for asr_seg in asr_segments:
|
333 |
+
asr_start = float(asr_seg["start"])
|
334 |
+
asr_end = float(asr_seg["end"])
|
335 |
+
asr_text = str(asr_seg["text"]).strip()
|
336 |
+
|
337 |
+
if not asr_text or asr_start >= asr_end: # 跳过无效的ASR分段
|
338 |
+
continue
|
339 |
+
|
340 |
+
# 找出与当前ASR分段在时间上重叠的所有说话人分段
|
341 |
+
overlapping_diar_segs = []
|
342 |
+
for diar_seg in diarization_segments:
|
343 |
+
diar_start = float(diar_seg["start"])
|
344 |
+
diar_end = float(diar_seg["end"])
|
345 |
+
|
346 |
+
overlap_start = max(asr_start, diar_start)
|
347 |
+
overlap_end = min(asr_end, diar_end)
|
348 |
+
|
349 |
+
if overlap_end > overlap_start: # 有重叠
|
350 |
+
overlapping_diar_segs.append({
|
351 |
+
"speaker": str(diar_seg["speaker"]),
|
352 |
+
"start": diar_start,
|
353 |
+
"end": diar_end,
|
354 |
+
"overlap_duration": overlap_end - overlap_start
|
355 |
+
})
|
356 |
+
|
357 |
+
distinct_speakers_in_overlap = set(d['speaker'] for d in overlapping_diar_segs)
|
358 |
+
|
359 |
+
segments_to_process_further: List[Dict[str, Any]] = []
|
360 |
+
|
361 |
+
if len(distinct_speakers_in_overlap) > 1:
|
362 |
+
logger.debug(f"ASR段 [{asr_start:.2f}-{asr_end:.2f}] \"{asr_text[:50]}...\" 跨越 {len(distinct_speakers_in_overlap)} 个说话人。尝试按标点分裂。")
|
363 |
+
# 跨多个说话人,尝试按标点分裂ASR segment
|
364 |
+
sub_asr_segments_data = self._split_asr_segment_by_punctuation(
|
365 |
+
asr_text,
|
366 |
+
asr_start,
|
367 |
+
asr_end
|
368 |
+
)
|
369 |
+
if len(sub_asr_segments_data) > 1:
|
370 |
+
logger.debug(f"成功将ASR段分裂成 {len(sub_asr_segments_data)} 个子句。")
|
371 |
+
segments_to_process_further.extend(sub_asr_segments_data)
|
372 |
+
else:
|
373 |
+
# 单一说话人或无说话人重叠(也视为单一处理单位)
|
374 |
+
segments_to_process_further.append({"text": asr_text, "start": asr_start, "end": asr_end})
|
375 |
+
|
376 |
+
# 为每个原始或分裂后的ASR(子)分段分配说话人
|
377 |
+
for current_proc_seg_data in segments_to_process_further:
|
378 |
+
proc_text = current_proc_seg_data["text"].strip()
|
379 |
+
proc_start = current_proc_seg_data["start"]
|
380 |
+
proc_end = current_proc_seg_data["end"]
|
381 |
+
|
382 |
+
if not proc_text or proc_start >= proc_end: # 跳过无效的子分段
|
383 |
+
continue
|
384 |
+
|
385 |
+
# 为当前处理的(可能是子)分段确定最佳说话人
|
386 |
+
speaker_overlaps_for_proc_seg = {}
|
387 |
+
for diar_seg_info in overlapping_diar_segs: # 使用之前计算的、与原始ASR段重叠的diar_segs
|
388 |
+
# 现在需要计算这个 diar_seg_info 与 proc_seg 的重叠
|
389 |
+
overlap_start = max(proc_start, diar_seg_info["start"])
|
390 |
+
overlap_end = min(proc_end, diar_seg_info["end"])
|
391 |
+
|
392 |
+
if overlap_end > overlap_start:
|
393 |
+
overlap_duration = overlap_end - overlap_start
|
394 |
+
speaker = diar_seg_info["speaker"]
|
395 |
+
speaker_overlaps_for_proc_seg[speaker] = \
|
396 |
+
speaker_overlaps_for_proc_seg.get(speaker, 0) + overlap_duration
|
397 |
+
|
398 |
+
best_speaker = "UNKNOWN"
|
399 |
+
if speaker_overlaps_for_proc_seg:
|
400 |
+
best_speaker = max(speaker_overlaps_for_proc_seg.items(), key=lambda x: x[1])[0]
|
401 |
+
elif overlapping_diar_segs: # 如果子分段本身没有重叠,但原始ASR段有
|
402 |
+
# 可以选择原始ASR段中占比最大的,或者最近的
|
403 |
+
# 为简化,如果子分段无直接重叠,也可能标记为UNKNOWN,或尝试找最近的
|
404 |
+
# 这里采用:如果子分段无直接重叠,但在原始ASR段中有说话人,则使用原始ASR段中重叠最长的
|
405 |
+
# (此逻辑分支效果待观察,更简单的是直接UNKNOWN)
|
406 |
+
# 此处简化:若子分段无重叠,则为UNKNOWN
|
407 |
+
pass # best_speaker 默认为 UNKNOWN
|
408 |
+
|
409 |
+
# 如果 best_speaker 仍为 UNKNOWN,但原始ASR段只有一个说话者,则使用该说话者
|
410 |
+
if best_speaker == "UNKNOWN" and len(distinct_speakers_in_overlap) == 1:
|
411 |
+
best_speaker = list(distinct_speakers_in_overlap)[0]
|
412 |
+
elif best_speaker == "UNKNOWN" and not overlapping_diar_segs:
|
413 |
+
# 如果整个ASR段都没有任何说话人信息,则确实是UNKNOWN
|
414 |
+
pass
|
415 |
+
|
416 |
+
|
417 |
+
final_enhanced_segments.append(
|
418 |
+
EnhancedSegment(
|
419 |
+
start=proc_start,
|
420 |
+
end=proc_end,
|
421 |
+
text=proc_text,
|
422 |
+
speaker=best_speaker,
|
423 |
+
language=language # 所有子分段继承原始ASR段的语言
|
424 |
+
)
|
425 |
+
)
|
426 |
+
|
427 |
+
# 对最终结果按开始时间排序
|
428 |
+
final_enhanced_segments.sort(key=lambda seg: seg.start)
|
429 |
+
return final_enhanced_segments
|
430 |
+
|
431 |
+
def transcribe_podcast(
|
432 |
+
self,
|
433 |
+
audio: AudioSegment,
|
434 |
+
podcast_info: PodcastChannel,
|
435 |
+
episode_info: PodcastEpisode,
|
436 |
+
) -> CombinedTranscriptionResult:
|
437 |
+
"""
|
438 |
+
专门针对播客剧集的音频转录方法
|
439 |
+
|
440 |
+
参数:
|
441 |
+
audio: 要转录的AudioSegment对象
|
442 |
+
podcast_info: 播客频道信息
|
443 |
+
episode_info: 播客剧集信息
|
444 |
+
|
445 |
+
返回:
|
446 |
+
包含完整转录和识别后说话人名称的结果
|
447 |
+
"""
|
448 |
+
logger.info(f"开始转录播客剧集 {len(audio)/1000:.2f} 秒的音频")
|
449 |
+
|
450 |
+
# 1. 先执行基础转录流程
|
451 |
+
transcription_result = self.transcribe(audio)
|
452 |
+
|
453 |
+
# 3. 识别说话人名称
|
454 |
+
logger.info("识别说话人名称...")
|
455 |
+
speaker_name_map = self.speaker_identifier.recognize_speaker_names(
|
456 |
+
transcription_result.segments,
|
457 |
+
podcast_info,
|
458 |
+
episode_info
|
459 |
+
)
|
460 |
+
|
461 |
+
# 4. 将识别的说话人名称添加到转录结果中
|
462 |
+
enhanced_segments_with_names = []
|
463 |
+
for segment in transcription_result.segments:
|
464 |
+
# 复制原始段落并添加说话人名称
|
465 |
+
speaker_id = segment.speaker
|
466 |
+
speaker_name = speaker_name_map.get(speaker_id, None)
|
467 |
+
|
468 |
+
# 创建新的段落对象,包含说话人名称
|
469 |
+
new_segment = EnhancedSegment(
|
470 |
+
start=segment.start,
|
471 |
+
end=segment.end,
|
472 |
+
text=segment.text,
|
473 |
+
speaker=speaker_id,
|
474 |
+
language=segment.language,
|
475 |
+
speaker_name=speaker_name
|
476 |
+
)
|
477 |
+
enhanced_segments_with_names.append(new_segment)
|
478 |
+
|
479 |
+
# 5. 创建并返回新的转录结果
|
480 |
+
return CombinedTranscriptionResult(
|
481 |
+
segments=enhanced_segments_with_names,
|
482 |
+
text=transcription_result.text,
|
483 |
+
language=transcription_result.language,
|
484 |
+
num_speakers=transcription_result.num_speakers
|
485 |
+
)
|
486 |
+
|
487 |
+
|
488 |
+
def transcribe_audio(
|
489 |
+
audio_segment: AudioSegment,
|
490 |
+
asr_model_name: str = "distil-whisper/distil-large-v3.5",
|
491 |
+
asr_provider: str = "distil_whisper_transformers",
|
492 |
+
diarization_model_name: str = "pyannote/speaker-diarization-3.1",
|
493 |
+
diarization_provider: str = "pyannote_transformers",
|
494 |
+
hf_token: Optional[str] = None,
|
495 |
+
device: Optional[str] = None,
|
496 |
+
segmentation_batch_size: int = 64,
|
497 |
+
parallel: bool = False,
|
498 |
+
) -> CombinedTranscriptionResult: # 返回类型固定为 CombinedTranscriptionResult
|
499 |
+
"""
|
500 |
+
整合ASR和说话人分离的音频转录函数 (仅支持非流式)
|
501 |
+
|
502 |
+
参数:
|
503 |
+
audio_segment: 输入的AudioSegment对象
|
504 |
+
asr_model_name: ASR模型名称
|
505 |
+
asr_provider: ASR提供者名称
|
506 |
+
diarization_model_name: 说话人分离模型名称
|
507 |
+
diarization_provider: 说话人分离提供者名称
|
508 |
+
hf_token: Hugging Face令牌
|
509 |
+
device: 推理设备,'cpu'或'cuda'
|
510 |
+
segmentation_batch_size: 分割批处理大小,默认为64
|
511 |
+
parallel: 是否并行执行ASR和说话人分离,默认为False
|
512 |
+
|
513 |
+
返回:
|
514 |
+
完整转录结果
|
515 |
+
"""
|
516 |
+
logger.info(f"调用transcribe_audio函数 (非流式),音频长度: {len(audio_segment)/1000:.2f}秒")
|
517 |
+
|
518 |
+
transcriber = CombinedTranscriber(
|
519 |
+
asr_model_name=asr_model_name,
|
520 |
+
asr_provider=asr_provider,
|
521 |
+
diarization_model_name=diarization_model_name,
|
522 |
+
diarization_provider=diarization_provider,
|
523 |
+
hf_token=hf_token,
|
524 |
+
device=device,
|
525 |
+
segmentation_batch_size=segmentation_batch_size,
|
526 |
+
parallel=parallel
|
527 |
+
)
|
528 |
+
|
529 |
+
# 直接调用 transcribe 方法
|
530 |
+
return transcriber.transcribe(audio_segment)
|
531 |
+
|
532 |
+
def transcribe_podcast_audio(
|
533 |
+
audio_segment: AudioSegment,
|
534 |
+
podcast_info: PodcastChannel,
|
535 |
+
episode_info: PodcastEpisode,
|
536 |
+
asr_model_name: str = "distil-whisper/distil-large-v3.5",
|
537 |
+
asr_provider: str = "distil_whisper_transformers",
|
538 |
+
diarization_model_name: str = "pyannote/speaker-diarization-3.1",
|
539 |
+
diarization_provider: str = "pyannote_transformers",
|
540 |
+
llm_model_name: Optional[str] = None,
|
541 |
+
llm_provider: Optional[str] = None,
|
542 |
+
hf_token: Optional[str] = None,
|
543 |
+
device: Optional[str] = None,
|
544 |
+
segmentation_batch_size: int = 64,
|
545 |
+
parallel: bool = False,
|
546 |
+
) -> CombinedTranscriptionResult:
|
547 |
+
"""
|
548 |
+
针对播客剧集的音频转录函数,包含说话人名称识别
|
549 |
+
|
550 |
+
参数:
|
551 |
+
audio_segment: 输入的AudioSegment对象
|
552 |
+
podcast_info: 播客频道信息
|
553 |
+
episode_info: 播客剧集信息
|
554 |
+
asr_model_name: ASR模型名称
|
555 |
+
asr_provider: ASR提供者名称
|
556 |
+
diarization_provider: 说话人分离提供者名称
|
557 |
+
diarization_model_name: 说话人分离模型名称
|
558 |
+
llm_model_name: LLM模型名称,如果为None则无法识别说话人名称
|
559 |
+
llm_provider: LLM提供者名称,如果为None则无法识别说话人名称
|
560 |
+
hf_token: Hugging Face令牌
|
561 |
+
device: 推理设备,'cpu'或'cuda'
|
562 |
+
segmentation_batch_size: 分割批处理大小,默认为64
|
563 |
+
parallel: 是否并行执行ASR和说话人分离,默认为False
|
564 |
+
|
565 |
+
返回:
|
566 |
+
包含说话人名称的完整转录结果
|
567 |
+
"""
|
568 |
+
logger.info(f"调用transcribe_podcast_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
|
569 |
+
|
570 |
+
transcriber = CombinedTranscriber(
|
571 |
+
asr_model_name=asr_model_name,
|
572 |
+
asr_provider=asr_provider,
|
573 |
+
diarization_provider=diarization_provider,
|
574 |
+
diarization_model_name=diarization_model_name,
|
575 |
+
llm_model_name=llm_model_name,
|
576 |
+
llm_provider=llm_provider,
|
577 |
+
hf_token=hf_token,
|
578 |
+
device=device,
|
579 |
+
segmentation_batch_size=segmentation_batch_size,
|
580 |
+
parallel=parallel
|
581 |
+
)
|
582 |
+
|
583 |
+
# 调用播客专用转录方法
|
584 |
+
return transcriber.transcribe_podcast(
|
585 |
+
audio=audio_segment,
|
586 |
+
podcast_info=podcast_info,
|
587 |
+
episode_info=episode_info,
|
588 |
+
)
|
src/podcast_transcribe/webui/app.py
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import requests
|
3 |
+
import io
|
4 |
+
from pydub import AudioSegment
|
5 |
+
import traceback # 用于打印更详细的错误信息
|
6 |
+
import tempfile
|
7 |
+
import os
|
8 |
+
import uuid
|
9 |
+
import atexit
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
# 尝试相对导入,这在通过 `python -m src.podcast_transcribe.webui.app` 运行时有效
|
13 |
+
try:
|
14 |
+
from podcast_transcribe.rss.podcast_rss_parser import parse_podcast_rss
|
15 |
+
from podcast_transcribe.schemas import PodcastChannel, PodcastEpisode, CombinedTranscriptionResult, EnhancedSegment
|
16 |
+
from podcast_transcribe.transcriber import transcribe_podcast_audio
|
17 |
+
except ImportError:
|
18 |
+
# 如果直接运行此脚本,并且项目根目录不在PYTHONPATH中,
|
19 |
+
# 则需要将项目根目录添加到 sys.path
|
20 |
+
import sys
|
21 |
+
import os
|
22 |
+
# 获取当前脚本文件所在的目录 (src/podcast_transcribe/webui)
|
23 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
# 获取项目根目录 (向上三级: webui -> podcast_transcribe -> src -> project_root)
|
25 |
+
# 修正:应该是 src 的父目录是项目根
|
26 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
|
27 |
+
# 将 src 目录添加到 sys.path,因为模块是 podcast_transcribe.xxx
|
28 |
+
src_path = os.path.join(project_root, "src")
|
29 |
+
if src_path not in sys.path:
|
30 |
+
sys.path.insert(0, src_path)
|
31 |
+
|
32 |
+
from podcast_transcribe.rss.podcast_rss_parser import parse_podcast_rss
|
33 |
+
from podcast_transcribe.schemas import PodcastChannel, PodcastEpisode, CombinedTranscriptionResult, EnhancedSegment
|
34 |
+
from podcast_transcribe.transcriber import transcribe_podcast_audio
|
35 |
+
|
36 |
+
# 用于存储应用程序使用的所有临时文件路径
|
37 |
+
temp_files = []
|
38 |
+
|
39 |
+
def cleanup_temp_files():
|
40 |
+
"""清理应用程序使用的临时文件"""
|
41 |
+
global temp_files
|
42 |
+
print(f"应用程序退出,清理 {len(temp_files)} 个临时文件...")
|
43 |
+
|
44 |
+
for filepath in temp_files:
|
45 |
+
try:
|
46 |
+
if os.path.exists(filepath):
|
47 |
+
os.remove(filepath)
|
48 |
+
print(f"已删除临时文件: {filepath}")
|
49 |
+
except Exception as e:
|
50 |
+
print(f"无法删除临时文件 {filepath}: {e}")
|
51 |
+
|
52 |
+
# 清空列表
|
53 |
+
temp_files = []
|
54 |
+
|
55 |
+
# 注册应用程序退出时的清理函数
|
56 |
+
atexit.register(cleanup_temp_files)
|
57 |
+
|
58 |
+
def parse_rss_feed(rss_url: str):
|
59 |
+
"""回调函数:解析 RSS Feed"""
|
60 |
+
print(f"开始解析RSS: {rss_url}")
|
61 |
+
|
62 |
+
if not rss_url:
|
63 |
+
print("RSS地址为空")
|
64 |
+
return {
|
65 |
+
status_message_area: gr.update(value="错误:请输入 RSS 地址。"),
|
66 |
+
podcast_title_display: gr.update(value="", visible=False),
|
67 |
+
episode_dropdown: gr.update(choices=[], value=None, interactive=False),
|
68 |
+
podcast_data_state: None,
|
69 |
+
audio_player: gr.update(value=None),
|
70 |
+
current_audio_url_state: None,
|
71 |
+
episode_shownotes: gr.update(value="", visible=False),
|
72 |
+
transcription_output_df: gr.update(value=None, headers=["说话人", "文本", "时间"]),
|
73 |
+
transcribe_button: gr.update(interactive=False),
|
74 |
+
selected_episode_index_state: None
|
75 |
+
}
|
76 |
+
|
77 |
+
try:
|
78 |
+
print(f"正在解析RSS: {rss_url}")
|
79 |
+
# 先更新状态消息,但由于不再使用生成器,我们直接在解析后更新UI
|
80 |
+
|
81 |
+
podcast_data: PodcastChannel = parse_podcast_rss(rss_url)
|
82 |
+
print(f"RSS解析结果: 频道名称={podcast_data.title if podcast_data else 'None'}, 剧集数量={len(podcast_data.episodes) if podcast_data and podcast_data.episodes else 0}")
|
83 |
+
|
84 |
+
if podcast_data and podcast_data.episodes:
|
85 |
+
choices = []
|
86 |
+
for i, episode in enumerate(podcast_data.episodes):
|
87 |
+
# 使用 (标题 (时长), guid 或索引) 作为选项
|
88 |
+
# 如果 guid 不可靠或缺失,可以使用索引
|
89 |
+
label = f"{episode.title or '无标题'} (时长: {episode.duration or '未知'})"
|
90 |
+
# 将 episode 对象直接作为值传递,或仅传递一个唯一标识符
|
91 |
+
# 为了简单起见,我们使用索引作为唯一ID,因为我们需要从 podcast_data_state 中检索完整的 episode
|
92 |
+
choices.append((label, i))
|
93 |
+
|
94 |
+
# 显示播客标题
|
95 |
+
podcast_title = f"## 🎙️ {podcast_data.title or '未知播客'}"
|
96 |
+
if podcast_data.author:
|
97 |
+
podcast_title += f"\n**主播/制作人:** {podcast_data.author}"
|
98 |
+
if podcast_data.description:
|
99 |
+
# 限制描述长度,避免界面过长
|
100 |
+
description = podcast_data.description[:300]
|
101 |
+
if len(podcast_data.description) > 300:
|
102 |
+
description += "..."
|
103 |
+
podcast_title += f"\n\n**播客简介:** {description}"
|
104 |
+
|
105 |
+
return {
|
106 |
+
status_message_area: gr.update(value=f"成功解析到 {len(podcast_data.episodes)} 个剧集。请选择一个剧集。"),
|
107 |
+
podcast_title_display: gr.update(value=podcast_title, visible=True),
|
108 |
+
episode_dropdown: gr.update(choices=choices, value=None, interactive=True),
|
109 |
+
podcast_data_state: podcast_data,
|
110 |
+
audio_player: gr.update(value=None),
|
111 |
+
current_audio_url_state: None,
|
112 |
+
episode_shownotes: gr.update(value="", visible=False),
|
113 |
+
transcription_output_df: gr.update(value=None),
|
114 |
+
transcribe_button: gr.update(interactive=False),
|
115 |
+
selected_episode_index_state: None
|
116 |
+
}
|
117 |
+
elif podcast_data: # 有 channel 信息但没有 episodes
|
118 |
+
print("解析成功但未找到剧集")
|
119 |
+
podcast_title = f"## 🎙️ {podcast_data.title or '未知播客'}"
|
120 |
+
if podcast_data.author:
|
121 |
+
podcast_title += f"\n**主播/制作人:** {podcast_data.author}"
|
122 |
+
|
123 |
+
return {
|
124 |
+
status_message_area: gr.update(value="解析成功,但未找到任何剧集。"),
|
125 |
+
podcast_title_display: gr.update(value=podcast_title, visible=True),
|
126 |
+
episode_dropdown: gr.update(choices=[], value=None, interactive=False),
|
127 |
+
podcast_data_state: podcast_data, # 仍然存储,以防万一
|
128 |
+
audio_player: gr.update(value=None),
|
129 |
+
current_audio_url_state: None,
|
130 |
+
episode_shownotes: gr.update(value="", visible=False),
|
131 |
+
transcription_output_df: gr.update(value=None),
|
132 |
+
transcribe_button: gr.update(interactive=False),
|
133 |
+
selected_episode_index_state: None
|
134 |
+
}
|
135 |
+
else:
|
136 |
+
print(f"解析RSS失败: {rss_url}")
|
137 |
+
return {
|
138 |
+
status_message_area: gr.update(value=f"解析 RSS失败: {rss_url}。请检查URL或网络连接。"),
|
139 |
+
podcast_title_display: gr.update(value="", visible=False),
|
140 |
+
episode_dropdown: gr.update(choices=[], value=None, interactive=False),
|
141 |
+
podcast_data_state: None,
|
142 |
+
audio_player: gr.update(value=None),
|
143 |
+
current_audio_url_state: None,
|
144 |
+
episode_shownotes: gr.update(value="", visible=False),
|
145 |
+
transcription_output_df: gr.update(value=None),
|
146 |
+
transcribe_button: gr.update(interactive=False),
|
147 |
+
selected_episode_index_state: None
|
148 |
+
}
|
149 |
+
except Exception as e:
|
150 |
+
print(f"解析 RSS 时发生错误: {e}")
|
151 |
+
traceback.print_exc()
|
152 |
+
return {
|
153 |
+
status_message_area: gr.update(value=f"解析 RSS 时发生严重错误: {e}"),
|
154 |
+
podcast_title_display: gr.update(value="", visible=False),
|
155 |
+
episode_dropdown: gr.update(choices=[], value=None, interactive=False),
|
156 |
+
podcast_data_state: None,
|
157 |
+
audio_player: gr.update(value=None),
|
158 |
+
current_audio_url_state: None,
|
159 |
+
episode_shownotes: gr.update(value="", visible=False),
|
160 |
+
transcription_output_df: gr.update(value=None),
|
161 |
+
transcribe_button: gr.update(interactive=False),
|
162 |
+
selected_episode_index_state: None
|
163 |
+
}
|
164 |
+
|
165 |
+
def load_episode_audio(selected_episode_index: int, podcast_data: PodcastChannel):
|
166 |
+
"""回调函数:当用户从下拉菜单选择一个剧集时加载音频"""
|
167 |
+
global temp_files
|
168 |
+
print(f"开始加载剧集音频,选择的索引: {selected_episode_index}")
|
169 |
+
|
170 |
+
if selected_episode_index is None or podcast_data is None or not podcast_data.episodes:
|
171 |
+
print("未选择剧集或无播客数据")
|
172 |
+
return {
|
173 |
+
audio_player: gr.update(value=None),
|
174 |
+
current_audio_url_state: None,
|
175 |
+
status_message_area: gr.update(value="请先解析 RSS 并选择一个剧集。"),
|
176 |
+
episode_shownotes: gr.update(value="", visible=False),
|
177 |
+
transcription_output_df: gr.update(value=None),
|
178 |
+
local_audio_file_path: None,
|
179 |
+
transcribe_button: gr.update(interactive=False),
|
180 |
+
selected_episode_index_state: None
|
181 |
+
}
|
182 |
+
|
183 |
+
try:
|
184 |
+
episode = podcast_data.episodes[selected_episode_index]
|
185 |
+
audio_url = episode.audio_url
|
186 |
+
print(f"获取到剧集信息,标题: {episode.title}, 音频URL: {audio_url}")
|
187 |
+
|
188 |
+
# 准备剧集信息显示
|
189 |
+
episode_shownotes_content = ""
|
190 |
+
|
191 |
+
# 准备shownotes内容
|
192 |
+
if episode.shownotes:
|
193 |
+
# 清理HTML标签并格式化shownotes
|
194 |
+
import re
|
195 |
+
# 简单的HTML标签清理
|
196 |
+
clean_shownotes = re.sub(r'<[^>]+>', '', episode.shownotes)
|
197 |
+
# 替换HTML实体
|
198 |
+
clean_shownotes = clean_shownotes.replace(' ', ' ').replace('&', '&').replace('<', '<').replace('>', '>')
|
199 |
+
# 清理多余空白
|
200 |
+
clean_shownotes = re.sub(r'\s+', ' ', clean_shownotes).strip()
|
201 |
+
|
202 |
+
episode_shownotes_content = f"### 📝 剧集详情\n\n**标题:** {episode.title or '无标题'}\n\n"
|
203 |
+
if episode.published_date:
|
204 |
+
episode_shownotes_content += f"**发布日期:** {episode.published_date.strftime('%Y年%m月%d日')}\n\n"
|
205 |
+
if episode.duration:
|
206 |
+
episode_shownotes_content += f"**时长:** {episode.duration}\n\n"
|
207 |
+
|
208 |
+
episode_shownotes_content += f"**节目介绍:**\n\n{clean_shownotes}"
|
209 |
+
elif episode.summary:
|
210 |
+
# 如果没有shownotes,使用summary
|
211 |
+
episode_shownotes_content = f"### 📝 剧集详情\n\n**标题:** {episode.title or '无标题'}\n\n"
|
212 |
+
if episode.published_date:
|
213 |
+
episode_shownotes_content += f"**发布日期:** {episode.published_date.strftime('%Y年%m月%d日')}\n\n"
|
214 |
+
if episode.duration:
|
215 |
+
episode_shownotes_content += f"**时长:** {episode.duration}\n\n"
|
216 |
+
|
217 |
+
episode_shownotes_content += f"**节目简介:**\n\n{episode.summary}"
|
218 |
+
else:
|
219 |
+
# 最基本的信息
|
220 |
+
episode_shownotes_content = f"### 📝 剧集详情\n\n**标题:** {episode.title or '无标题'}\n\n"
|
221 |
+
if episode.published_date:
|
222 |
+
episode_shownotes_content += f"**发布日期:** {episode.published_date.strftime('%Y年%m月%d日')}\n\n"
|
223 |
+
if episode.duration:
|
224 |
+
episode_shownotes_content += f"**时长:** {episode.duration}\n\n"
|
225 |
+
|
226 |
+
if audio_url:
|
227 |
+
# 更新状态消息
|
228 |
+
print(f"正在下载音频: {audio_url}")
|
229 |
+
|
230 |
+
# 下载音频文件
|
231 |
+
try:
|
232 |
+
headers = {
|
233 |
+
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36'
|
234 |
+
}
|
235 |
+
|
236 |
+
# 创建临时文件
|
237 |
+
temp_dir = tempfile.gettempdir()
|
238 |
+
# 使用UUID生成唯一文件名,避免冲突
|
239 |
+
unique_filename = f"podcast_audio_{uuid.uuid4().hex}"
|
240 |
+
|
241 |
+
# 先发送一个HEAD请求获取内容类型
|
242 |
+
head_response = requests.head(audio_url, timeout=30, headers=headers)
|
243 |
+
|
244 |
+
# 根据内容类型确定文件扩展名
|
245 |
+
content_type = head_response.headers.get('Content-Type', '').lower()
|
246 |
+
if 'mp3' in content_type:
|
247 |
+
file_ext = '.mp3'
|
248 |
+
elif 'mpeg' in content_type:
|
249 |
+
file_ext = '.mp3'
|
250 |
+
elif 'mp4' in content_type or 'm4a' in content_type:
|
251 |
+
file_ext = '.mp4'
|
252 |
+
elif 'wav' in content_type:
|
253 |
+
file_ext = '.wav'
|
254 |
+
elif 'ogg' in content_type:
|
255 |
+
file_ext = '.ogg'
|
256 |
+
else:
|
257 |
+
# 默认扩展名
|
258 |
+
file_ext = '.mp3'
|
259 |
+
|
260 |
+
temp_filepath = os.path.join(temp_dir, unique_filename + file_ext)
|
261 |
+
|
262 |
+
# 将文件路径添加到全局临时文件列表
|
263 |
+
temp_files.append(temp_filepath)
|
264 |
+
|
265 |
+
# 保存到临时文件
|
266 |
+
# 使用流式下载,避免一次性加载整个文件到内存
|
267 |
+
with open(temp_filepath, 'wb') as f:
|
268 |
+
# 使用流式响应并设置较大的块大小提高效率
|
269 |
+
response = requests.get(audio_url, timeout=60, headers=headers, stream=True)
|
270 |
+
response.raise_for_status()
|
271 |
+
|
272 |
+
# 从响应中获取文件大小(如果服务器提供)
|
273 |
+
total_size = int(response.headers.get('content-length', 0))
|
274 |
+
downloaded = 0
|
275 |
+
chunk_size = 8192 # 8KB 的块大小
|
276 |
+
|
277 |
+
# 分块下载并写入文件
|
278 |
+
for chunk in response.iter_content(chunk_size=chunk_size):
|
279 |
+
if chunk: # 过滤掉保持连接活跃的空块
|
280 |
+
f.write(chunk)
|
281 |
+
downloaded += len(chunk)
|
282 |
+
# 可以在这里添加下载进度更新
|
283 |
+
if total_size > 0:
|
284 |
+
download_percentage = downloaded / total_size
|
285 |
+
print(f"下载进度: {download_percentage:.1%}")
|
286 |
+
|
287 |
+
print(f"音频已下载到临时文件: {temp_filepath}")
|
288 |
+
|
289 |
+
return {
|
290 |
+
audio_player: gr.update(value=temp_filepath, label=f"当前播放: {episode.title or '无标题'}"),
|
291 |
+
current_audio_url_state: audio_url,
|
292 |
+
status_message_area: gr.update(value=f"已加载剧集: {episode.title or '无标题'}。"),
|
293 |
+
episode_shownotes: gr.update(value=episode_shownotes_content, visible=True),
|
294 |
+
transcription_output_df: gr.update(value=None),
|
295 |
+
local_audio_file_path: temp_filepath,
|
296 |
+
transcribe_button: gr.update(interactive=True),
|
297 |
+
selected_episode_index_state: selected_episode_index
|
298 |
+
}
|
299 |
+
except requests.exceptions.RequestException as e:
|
300 |
+
print(f"下载音频失败: {e}")
|
301 |
+
traceback.print_exc()
|
302 |
+
return {
|
303 |
+
audio_player: gr.update(value=None),
|
304 |
+
current_audio_url_state: None,
|
305 |
+
status_message_area: gr.update(value=f"错误:下载音频失败: {e}"),
|
306 |
+
episode_shownotes: gr.update(value=episode_shownotes_content, visible=True),
|
307 |
+
transcription_output_df: gr.update(value=None),
|
308 |
+
local_audio_file_path: None,
|
309 |
+
transcribe_button: gr.update(interactive=False),
|
310 |
+
selected_episode_index_state: None
|
311 |
+
}
|
312 |
+
else:
|
313 |
+
print(f"剧集 '{episode.title}' 缺少有效的音频URL")
|
314 |
+
return {
|
315 |
+
audio_player: gr.update(value=None),
|
316 |
+
current_audio_url_state: None,
|
317 |
+
status_message_area: gr.update(value=f"错误:选中的剧集 '{episode.title}' 没有提供有效的音频URL。"),
|
318 |
+
episode_shownotes: gr.update(value=episode_shownotes_content, visible=True),
|
319 |
+
transcription_output_df: gr.update(value=None),
|
320 |
+
local_audio_file_path: None,
|
321 |
+
transcribe_button: gr.update(interactive=False),
|
322 |
+
selected_episode_index_state: None
|
323 |
+
}
|
324 |
+
except IndexError:
|
325 |
+
print(f"无效的剧集索引: {selected_episode_index}")
|
326 |
+
return {
|
327 |
+
audio_player: gr.update(value=None),
|
328 |
+
current_audio_url_state: None,
|
329 |
+
status_message_area: gr.update(value="错误:选择的剧集索引无效。"),
|
330 |
+
episode_shownotes: gr.update(value="", visible=False),
|
331 |
+
transcription_output_df: gr.update(value=None),
|
332 |
+
local_audio_file_path: None,
|
333 |
+
transcribe_button: gr.update(interactive=False),
|
334 |
+
selected_episode_index_state: None
|
335 |
+
}
|
336 |
+
except Exception as e:
|
337 |
+
print(f"加载音频时发生错误: {e}")
|
338 |
+
traceback.print_exc()
|
339 |
+
return {
|
340 |
+
audio_player: gr.update(value=None),
|
341 |
+
current_audio_url_state: None,
|
342 |
+
status_message_area: gr.update(value=f"加载音频时发生严重错误: {e}"),
|
343 |
+
episode_shownotes: gr.update(value="", visible=False),
|
344 |
+
transcription_output_df: gr.update(value=None),
|
345 |
+
local_audio_file_path: None,
|
346 |
+
transcribe_button: gr.update(interactive=False),
|
347 |
+
selected_episode_index_state: None
|
348 |
+
}
|
349 |
+
|
350 |
+
def disable_buttons_before_transcription(local_audio_file_path: str):
|
351 |
+
"""在开始转录前禁用按钮"""
|
352 |
+
print("禁用界面按钮以防止转录期间的交互")
|
353 |
+
return {
|
354 |
+
parse_button: gr.update(interactive=False),
|
355 |
+
episode_dropdown: gr.update(interactive=False),
|
356 |
+
transcribe_button: gr.update(interactive=False),
|
357 |
+
status_message_area: gr.update(value="开始转录过程,请耐心等待...")
|
358 |
+
}
|
359 |
+
|
360 |
+
def start_transcription(local_audio_file_path: str, podcast_data: PodcastChannel, selected_episode_index: int, progress=gr.Progress(track_tqdm=True)):
|
361 |
+
"""回调函数:开始转录当前加载的音频"""
|
362 |
+
print(f"开始转录本地音频文件: {local_audio_file_path}, 选中剧集索引: {selected_episode_index}")
|
363 |
+
|
364 |
+
if not local_audio_file_path or not os.path.exists(local_audio_file_path):
|
365 |
+
print("没有可用的本地音频文件")
|
366 |
+
return {
|
367 |
+
transcription_output_df: gr.update(value=None),
|
368 |
+
status_message_area: gr.update(value="错误:没有有效的音频文件用于转录。请先选择一个剧集。"),
|
369 |
+
parse_button: gr.update(interactive=True),
|
370 |
+
episode_dropdown: gr.update(interactive=True),
|
371 |
+
transcribe_button: gr.update(interactive=True)
|
372 |
+
}
|
373 |
+
|
374 |
+
try:
|
375 |
+
# 先更新状态消息并禁用按钮
|
376 |
+
progress(0, desc="初始化转录过程...")
|
377 |
+
|
378 |
+
# 使用progress回调来更新进度
|
379 |
+
progress(0.2, desc="加载音频文件...")
|
380 |
+
|
381 |
+
# 从文件加载音频
|
382 |
+
audio_segment = AudioSegment.from_file(local_audio_file_path)
|
383 |
+
print(f"音频加载完成,时长: {len(audio_segment)/1000}秒")
|
384 |
+
|
385 |
+
progress(0.4, desc="音频加载完成,开始转录 (此过程可能需要较长时间)...")
|
386 |
+
|
387 |
+
# 获取当前选中的剧集信息
|
388 |
+
episode_info = None
|
389 |
+
if podcast_data and podcast_data.episodes and selected_episode_index is not None:
|
390 |
+
if 0 <= selected_episode_index < len(podcast_data.episodes):
|
391 |
+
episode_info = podcast_data.episodes[selected_episode_index]
|
392 |
+
print(f"获取到当前选中剧集信息: {episode_info.title if episode_info else '无'}")
|
393 |
+
|
394 |
+
# 调用转录函数
|
395 |
+
print("开始转录音频...")
|
396 |
+
result: CombinedTranscriptionResult = transcribe_podcast_audio(audio_segment,
|
397 |
+
podcast_info=podcast_data,
|
398 |
+
episode_info=episode_info,
|
399 |
+
segmentation_batch_size=64,
|
400 |
+
parallel=True)
|
401 |
+
print(f"转录完成,结果: {result is not None}, 段落数: {len(result.segments) if result and result.segments else 0}")
|
402 |
+
progress(0.9, desc="转录完成,正在格式化结果...")
|
403 |
+
|
404 |
+
if result and result.segments:
|
405 |
+
formatted_segments = []
|
406 |
+
for seg in result.segments:
|
407 |
+
time_str = f"{seg.start:.2f}s - {seg.end:.2f}s"
|
408 |
+
formatted_segments.append([seg.speaker, seg.speaker_name, seg.text, time_str])
|
409 |
+
|
410 |
+
progress(1.0, desc="转录结果已生成!")
|
411 |
+
return {
|
412 |
+
transcription_output_df: gr.update(value=formatted_segments),
|
413 |
+
status_message_area: gr.update(value=f"转录完成!共 {len(result.segments)} 个片段。检测到 {result.num_speakers} 个说话人。"),
|
414 |
+
parse_button: gr.update(interactive=True),
|
415 |
+
episode_dropdown: gr.update(interactive=True),
|
416 |
+
transcribe_button: gr.update(interactive=True)
|
417 |
+
}
|
418 |
+
elif result: # 有 result 但没有 segments
|
419 |
+
progress(1.0, desc="转录完成,但无文本片段")
|
420 |
+
return {
|
421 |
+
transcription_output_df: gr.update(value=None),
|
422 |
+
status_message_area: gr.update(value="转录完成,但未生成任何文本片段。"),
|
423 |
+
parse_button: gr.update(interactive=True),
|
424 |
+
episode_dropdown: gr.update(interactive=True),
|
425 |
+
transcribe_button: gr.update(interactive=True)
|
426 |
+
}
|
427 |
+
else: # result 为 None
|
428 |
+
progress(1.0, desc="转录失败")
|
429 |
+
return {
|
430 |
+
transcription_output_df: gr.update(value=None),
|
431 |
+
status_message_area: gr.update(value="转录失败,未能获取结果。"),
|
432 |
+
parse_button: gr.update(interactive=True),
|
433 |
+
episode_dropdown: gr.update(interactive=True),
|
434 |
+
transcribe_button: gr.update(interactive=True)
|
435 |
+
}
|
436 |
+
except Exception as e:
|
437 |
+
print(f"转录过程中发生错误: {e}")
|
438 |
+
traceback.print_exc()
|
439 |
+
progress(1.0, desc="转录失败: 处理错误")
|
440 |
+
return {
|
441 |
+
transcription_output_df: gr.update(value=None),
|
442 |
+
status_message_area: gr.update(value=f"转录过程中发生严重错误: {e}"),
|
443 |
+
parse_button: gr.update(interactive=True),
|
444 |
+
episode_dropdown: gr.update(interactive=True),
|
445 |
+
transcribe_button: gr.update(interactive=True)
|
446 |
+
}
|
447 |
+
|
448 |
+
# --- Gradio 界面定义 ---
|
449 |
+
with gr.Blocks(title="播客转录工具 v2", css="""
|
450 |
+
.status-message-container {
|
451 |
+
min-height: 50px;
|
452 |
+
height: auto;
|
453 |
+
max-height: none;
|
454 |
+
overflow-y: visible;
|
455 |
+
white-space: normal;
|
456 |
+
word-wrap: break-word;
|
457 |
+
margin-top: 10px;
|
458 |
+
margin-bottom: 10px;
|
459 |
+
border-radius: 6px;
|
460 |
+
background-color: rgba(32, 36, 45, 0.03);
|
461 |
+
border: 1px solid rgba(32, 36, 45, 0.1);
|
462 |
+
color: #303030;
|
463 |
+
}
|
464 |
+
.episode-cover {
|
465 |
+
max-width: 300px;
|
466 |
+
max-height: 300px;
|
467 |
+
border-radius: 8px;
|
468 |
+
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
|
469 |
+
}
|
470 |
+
""") as demo:
|
471 |
+
gr.Markdown("# 🎙️ 播客转录工具")
|
472 |
+
|
473 |
+
# 状态管理
|
474 |
+
podcast_data_state = gr.State(None) # 存储解析后的 PodcastChannel 对象
|
475 |
+
current_audio_url_state = gr.State(None) # 存储当前选中剧集的音频URL
|
476 |
+
local_audio_file_path = gr.State(None) # 存储下载到本地的音频文件路径
|
477 |
+
selected_episode_index_state = gr.State(None) # 存储当前选中的剧集索引
|
478 |
+
|
479 |
+
with gr.Row():
|
480 |
+
rss_url_input = gr.Textbox(
|
481 |
+
label="播客 RSS 地址",
|
482 |
+
placeholder="例如: https://your-podcast-feed.com/rss.xml",
|
483 |
+
elem_id="rss-url-input"
|
484 |
+
)
|
485 |
+
parse_button = gr.Button("🔗 解析 RSS", elem_id="parse-rss-button")
|
486 |
+
|
487 |
+
status_message_area = gr.Markdown(
|
488 |
+
"",
|
489 |
+
elem_id="status-message",
|
490 |
+
elem_classes="status-message-container", # 添加自定义CSS类
|
491 |
+
show_label=False
|
492 |
+
)
|
493 |
+
|
494 |
+
# 播客标题显示区域
|
495 |
+
podcast_title_display = gr.Markdown(
|
496 |
+
"",
|
497 |
+
visible=False,
|
498 |
+
elem_id="podcast-title-display"
|
499 |
+
)
|
500 |
+
|
501 |
+
episode_dropdown = gr.Dropdown(
|
502 |
+
label="选择剧集",
|
503 |
+
choices=[],
|
504 |
+
interactive=False, # 初始时不可交互,解析成功后设为 True
|
505 |
+
elem_id="episode-dropdown"
|
506 |
+
)
|
507 |
+
|
508 |
+
# 剧集信息显示区域
|
509 |
+
with gr.Row():
|
510 |
+
with gr.Column(scale=2):
|
511 |
+
episode_shownotes = gr.Markdown(
|
512 |
+
"",
|
513 |
+
visible=False,
|
514 |
+
elem_id="episode-shownotes"
|
515 |
+
)
|
516 |
+
|
517 |
+
audio_player = gr.Audio(
|
518 |
+
label="播客音频播放器",
|
519 |
+
interactive=False, # 音频源由程序控制,用户不能直接修改
|
520 |
+
elem_id="audio-player"
|
521 |
+
)
|
522 |
+
|
523 |
+
transcribe_button = gr.Button("🔊 开始转录", elem_id="transcribe-button", interactive=False)
|
524 |
+
|
525 |
+
gr.Markdown("## 📝 转录结果")
|
526 |
+
transcription_output_df = gr.DataFrame(
|
527 |
+
headers=["说话人ID", "说话人名称", "转录文本", "起止时间"],
|
528 |
+
interactive=False,
|
529 |
+
wrap=True, # 允许文本换行
|
530 |
+
row_count=(10, "dynamic"), # 显示10行,可滚动
|
531 |
+
col_count=(4, "fixed"),
|
532 |
+
elem_id="transcription-output"
|
533 |
+
)
|
534 |
+
|
535 |
+
# --- 事件处理 ---
|
536 |
+
parse_button.click(
|
537 |
+
fn=parse_rss_feed,
|
538 |
+
inputs=[rss_url_input],
|
539 |
+
outputs=[
|
540 |
+
status_message_area,
|
541 |
+
podcast_title_display,
|
542 |
+
episode_dropdown,
|
543 |
+
podcast_data_state,
|
544 |
+
audio_player,
|
545 |
+
current_audio_url_state,
|
546 |
+
episode_shownotes,
|
547 |
+
transcription_output_df,
|
548 |
+
transcribe_button,
|
549 |
+
selected_episode_index_state
|
550 |
+
]
|
551 |
+
)
|
552 |
+
|
553 |
+
episode_dropdown.change(
|
554 |
+
fn=load_episode_audio,
|
555 |
+
inputs=[episode_dropdown, podcast_data_state],
|
556 |
+
outputs=[
|
557 |
+
audio_player,
|
558 |
+
current_audio_url_state,
|
559 |
+
status_message_area,
|
560 |
+
episode_shownotes,
|
561 |
+
transcription_output_df,
|
562 |
+
local_audio_file_path,
|
563 |
+
transcribe_button,
|
564 |
+
selected_episode_index_state
|
565 |
+
]
|
566 |
+
)
|
567 |
+
|
568 |
+
# 首先禁用按钮,然后执行转录
|
569 |
+
transcribe_button.click(
|
570 |
+
fn=disable_buttons_before_transcription,
|
571 |
+
inputs=[local_audio_file_path],
|
572 |
+
outputs=[parse_button, episode_dropdown, transcribe_button, status_message_area]
|
573 |
+
).then(
|
574 |
+
fn=start_transcription,
|
575 |
+
inputs=[local_audio_file_path, podcast_data_state, selected_episode_index_state],
|
576 |
+
outputs=[transcription_output_df, status_message_area, parse_button, episode_dropdown, transcribe_button]
|
577 |
+
)
|
578 |
+
|
579 |
+
if __name__ == "__main__":
|
580 |
+
try:
|
581 |
+
# demo.launch(debug=True, share=True) # share=True 会生成一个公开链接
|
582 |
+
demo.launch(debug=True)
|
583 |
+
finally:
|
584 |
+
# 确保在应用程序退出时清理临时文件
|
585 |
+
cleanup_temp_files()
|