Spaces: Running on Zero

Commit · 924aa01
Parent(s): 48811fe
Update the LLM model to google/gemma-3-4b-it, remove the no-longer-used Phi-4 model, improve device parameter support, and enhance the speaker identifier's logging.
Browse files
- examples/combined_podcast_transcription.py  +4 -10
- examples/simple_llm.py  +5 -5
- examples/simple_speaker_identify.py  +4 -2
- src/podcast_transcribe/llm/llm_base.py  +0 -7
- src/podcast_transcribe/llm/llm_gemma_transfomers.py  +3 -2
- src/podcast_transcribe/llm/llm_phi4_transfomers.py  +0 -369
- src/podcast_transcribe/llm/llm_router.py  +23 -32
- src/podcast_transcribe/summary/speaker_identify.py  +9 -2
- src/podcast_transcribe/transcriber.py  +12 -20
examples/combined_podcast_transcription.py  CHANGED

@@ -9,9 +9,6 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 from src.podcast_transcribe.transcriber import transcribe_podcast_audio
 from src.podcast_transcribe.audio import load_audio
 from src.podcast_transcribe.rss.podcast_rss_parser import parse_rss_xml_content
-from podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
-from src.podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult
-from src.podcast_transcribe.summary.speaker_identify import recognize_speaker_names

 def main():
     """主函数"""
@@ -20,9 +17,10 @@ def main():
     # audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")

     # 模型配置
-    asr_model_name = "
+    asr_model_name = "distil-whisper/distil-large-v3.5"  # ASR模型名称
     diarization_model_name = "pyannote/speaker-diarization-3.1"  # 说话人分离模型名称
-
+    llm_model_name = "google/gemma-3-4b-it"
+    llm_provider = "gemma-transformers"
     device = "mps"  # 设备类型
     segmentation_batch_size = 64
     parallel = True
@@ -60,13 +58,9 @@ def main():
     result = transcribe_podcast_audio(audio,
                                       podcast_info=mock_podcast_info,
                                       episode_info=mock_episode_info,
-                                      asr_model_name=asr_model_name,
-                                      diarization_model_name=diarization_model_name,
-                                      llm_model_name=llm_model_path,
                                       device=device,
                                       segmentation_batch_size=segmentation_batch_size,
-                                      parallel=parallel,
-                                      llm_model_name=llm_model_path)
+                                      parallel=parallel,)

     # 输出结果
     print("\n转录结果:")
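For orientation, a minimal sketch of driving the updated example end to end. The audio path is a placeholder, podcast_info and episode_info are assumed to have been parsed from RSS as in the example above, and the LLM is left to the new transcribe_podcast_audio defaults (google/gemma-3-4b-it via gemma-transformers):

    from src.podcast_transcribe.audio import load_audio
    from src.podcast_transcribe.transcriber import transcribe_podcast_audio

    # placeholder file path; assumes load_audio accepts a local path
    audio = load_audio("/path/to/episode.wav")

    # podcast_info / episode_info come from the RSS parser, as in the example
    result = transcribe_podcast_audio(
        audio,
        podcast_info=podcast_info,
        episode_info=episode_info,
        device="mps",                 # or "cuda" / "cpu"
        segmentation_batch_size=64,
        parallel=True,
    )
    print(result)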
examples/simple_llm.py  CHANGED

@@ -1,11 +1,10 @@
-
+import torch  # 导入 torch
 # 添加项目根目录到Python路径
 import sys
 from pathlib import Path

 sys.path.insert(0, str(Path(__file__).parent.parent))

-from src.podcast_transcribe.llm.llm_phi4_transfomers import Phi4TransformersChatCompletion
 from src.podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
 from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion

@@ -14,6 +13,7 @@ if __name__ == "__main__":
     # 示例用法:
     print("正在初始化 LLM 聊天补全...")
     try:
+        # model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
         model_name = "google/gemma-3-4b-it"
         use_4bit_quantization = False
         device = "mps"
@@ -22,10 +22,10 @@ if __name__ == "__main__":
         # 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
         if model_name.startswith("mlx-community"):
             gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
-        elif model_name.startswith("microsoft"):
-            gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
         else:
-
+            # 如果设备是 mps,则使用 float32 以增加稳定性
+            dtype_to_use = torch.float32 if device == "mps" else torch.float16
+            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device, torch_dtype=dtype_to_use)

     print("\n--- 示例 1: 简单用户查询 ---")
     messages_example1 = [
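The mps/float32 guard added above can be reused anywhere a Transformers-backed Gemma client is constructed. A minimal standalone sketch of the same dtype selection, assuming only that float16 is the preferred default off MPS:

    import torch

    def pick_dtype(device: str) -> torch.dtype:
        # float16 on MPS has shown instability for this model in this repo,
        # so fall back to float32 there and keep float16 elsewhere
        return torch.float32 if device == "mps" else torch.float16

    assert pick_dtype("mps") is torch.float32
    assert pick_dtype("cuda") is torch.float16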
examples/simple_speaker_identify.py  CHANGED

@@ -13,6 +13,7 @@ from src.podcast_transcribe.summary.speaker_identify import SpeakerIdentifier
 if __name__ == '__main__':
     transcribe_result_dump_file = Path.joinpath(Path(__file__).parent, "output", "lex_ai_john_carmack_1.transcription.json")
     podcast_rss_xml_file = Path.joinpath(Path(__file__).parent, "input", "lexfridman.com.rss.xml")
+    device = "mps"

     # Load the transcription result
     if not os.path.exists(transcribe_result_dump_file):
@@ -57,8 +58,9 @@


     speaker_identifier = SpeakerIdentifier(
-        llm_model_name="
-        llm_provider="gemma-
+        llm_model_name="google/gemma-3-4b-it",
+        llm_provider="gemma-transformers",
+        device=device
     )

     # 3. Call the function
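A minimal sketch of the updated construction on its own, assuming the package is importable from the project root as in this example; the device value is whatever accelerator is available:

    from src.podcast_transcribe.summary.speaker_identify import SpeakerIdentifier

    speaker_identifier = SpeakerIdentifier(
        llm_model_name="google/gemma-3-4b-it",
        llm_provider="gemma-transformers",
        device="mps",  # or "cuda" / "cpu"
    )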
src/podcast_transcribe/llm/llm_base.py  CHANGED

@@ -146,19 +146,12 @@ class BaseChatCompletion(ABC):
         temperature: float = 0.7,
         max_tokens: int = 2048,
         top_p: float = 1.0,
-        model: Optional[str] = None,
         **kwargs,
     ):
         """
         创建聊天完成响应。
         模仿OpenAI的ChatCompletion.create方法。
         """
-        if model and model != self.model_name:
-            # 这是一个简化的处理。在实际场景中,您可能希望加载新模型。
-            # 目前,我们将只打印一个警告并使用初始化的模型。
-            print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
-                  f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")
-
         # 为Gemma格式化消息
         prompt_str = self._format_messages_for_gemma(messages)

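With the per-call model argument removed, the model is now fixed when the completion object (or the router) is created, and create() only takes sampling parameters. A hedged sketch of the resulting call shape, assuming client is any BaseChatCompletion subclass instance such as GemmaTransformersChatCompletion:

    # model choice is fixed at construction time; create() only takes sampling params
    response = client.create(
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.7,
        max_tokens=256,
    )
    print(response["choices"][0]["message"]["content"])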
src/podcast_transcribe/llm/llm_gemma_transfomers.py  CHANGED

@@ -13,7 +13,8 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
         use_4bit_quantization: bool = False,
         device_map: Optional[str] = None,
         device: Optional[str] = None,
-        trust_remote_code: bool = True
+        trust_remote_code: bool = True,
+        torch_dtype: Optional[torch.dtype] = None
     ):
         # Gemma 使用 float16 作为默认数据类型
         super().__init__(
@@ -22,7 +23,7 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
             device_map=device_map,
             device=device,
             trust_remote_code=trust_remote_code,
-            torch_dtype=torch.float16
+            torch_dtype=torch_dtype if torch_dtype is not None else torch.float16
         )

     def _print_error_hints(self):
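The new torch_dtype parameter lets callers override the Gemma default of float16. A sketch, assuming the class is imported as in examples/simple_llm.py; omitting torch_dtype keeps the previous float16 behaviour:

    import torch
    from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion

    # explicit float32 on MPS; leave torch_dtype out to keep the float16 default
    chat = GemmaTransformersChatCompletion(
        model_name="google/gemma-3-4b-it",
        device="mps",
        torch_dtype=torch.float32,
    )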
src/podcast_transcribe/llm/llm_phi4_transfomers.py  DELETED

@@ -1,369 +0,0 @@
The entire file was removed. Its previous content:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Optional, Union, Literal
from .llm_base import TransformersBaseChatCompletion


class Phi4TransformersChatCompletion(TransformersBaseChatCompletion):
    """基于 Transformers 库的 Phi-4-mini-reasoning 聊天完成实现"""

    def __init__(
        self,
        model_name: str = "microsoft/Phi-4-mini-reasoning",
        use_4bit_quantization: bool = False,
        device_map: Optional[str] = None,
        device: Optional[str] = None,
        trust_remote_code: bool = True
    ):
        # Phi-4 使用 bfloat16 作为推荐数据类型
        super().__init__(
            model_name=model_name,
            use_4bit_quantization=use_4bit_quantization,
            device_map=device_map,
            device=device,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch.bfloat16
        )

    def _print_error_hints(self):
        """打印Phi-4特定的错误提示信息"""
        super()._print_error_hints()
        print("Phi-4 特殊要求:")
        print("- 建议使用 Transformers >= 4.51.3")
        print("- 推荐使用 bfloat16 数据类型")
        print("- 模型支持 128K token 上下文长度")

    def _format_phi4_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        格式化消息为 Phi-4 的聊天格式
        Phi-4 使用特定的聊天模板格式
        """
        # 使用 tokenizer 的内置聊天模板
        if hasattr(self.tokenizer, 'apply_chat_template'):
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            # 如果没有聊天模板,使用 Phi-4 的标准格式
            formatted_prompt = ""
            for message in messages:
                role = message.get("role", "user")
                content = message.get("content", "")

                if role == "system":
                    formatted_prompt += f"<|system|>\n{content}<|end|>\n"
                elif role == "user":
                    formatted_prompt += f"<|user|>\n{content}<|end|>\n"
                elif role == "assistant":
                    formatted_prompt += f"<|assistant|>\n{content}<|end|>\n"

            # 添加助手开始标记
            formatted_prompt += "<|assistant|>\n"
            return formatted_prompt

    def _generate_response(
        self,
        prompt_str: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        enable_reasoning: bool = True,
        **kwargs
    ) -> str:
        """使用 transformers 生成响应,针对 Phi-4 推理功能优化"""

        # 对提示进行编码
        inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")

        # 移动输入到正确的设备
        if self.device_map is None or self.device.type == "mps":
            inputs = inputs.to(self.device)

        # Phi-4-mini-reasoning 优化的生成参数
        generation_config = {
            "max_new_tokens": min(max_tokens, 32768),  # Phi-4-mini 支持最大 32K token
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True if temperature > 0 else False,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "repetition_penalty": kwargs.get("repetition_penalty", 1.1),
            "no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 3),
        }

        # 推理模式配置
        if enable_reasoning and "reasoning" in self.model_name.lower():
            # 为推理任务优化的配置
            generation_config.update({
                "temperature": max(temperature, 0.1),  # 推理模式下保持一定的温度
                "top_p": min(top_p, 0.95),  # 推理模式下限制 top_p
                "do_sample": True,  # 推理模式下总是启用采样
                "early_stopping": False,  # 允许完整的推理过程
            })

        # 如果温度为0,使用贪婪解码
        if temperature == 0:
            generation_config["do_sample"] = False
            generation_config.pop("temperature", None)
            generation_config.pop("top_p", None)

        try:
            # 生成响应
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    **generation_config
                )

            # 解码生成的文本,跳过输入部分
            generated_tokens = outputs[0][len(inputs[0]):]
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            return generated_text

        except Exception as e:
            print(f"生成响应时出错: {e}")
            raise

    def create(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 1.0,
        model: Optional[str] = None,
        enable_reasoning: bool = True,
        **kwargs,
    ):
        """
        创建聊天完成响应,支持Phi-4特有的推理功能
        """
        if model and model != self.model_name:
            print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
                  f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")

        # 检查是否为推理任务
        is_reasoning_task = self._is_reasoning_task(messages)

        # 格式化消息为 Phi-4 聊天格式
        if is_reasoning_task and enable_reasoning:
            prompt_str = self._format_reasoning_prompt(messages)
        else:
            prompt_str = self._format_phi4_messages(messages)

        # 生成响应
        response_text = self._generate_response(
            prompt_str,
            temperature,
            max_tokens,
            top_p,
            enable_reasoning=enable_reasoning and is_reasoning_task,
            **kwargs
        )

        # 后处理响应(使用基类的方法,但针对Phi-4调整)
        assistant_message_content = self._post_process_phi4_response(response_text, prompt_str)

        # 计算token使用量
        token_usage = self._calculate_tokens(prompt_str, assistant_message_content)

        # 构建响应对象
        response = self._build_chat_completion_response(assistant_message_content, token_usage)

        # 添加Phi-4特有的信息
        response["reasoning_enabled"] = enable_reasoning and is_reasoning_task

        return response

    def _post_process_phi4_response(self, response_text: str, prompt_str: str) -> str:
        """
        后处理Phi-4生成的响应文本
        """
        # Phi-4的输出通常不包含输入提示,直接返回生成的内容
        assistant_message_content = response_text.strip()

        # 清理可能的特殊标记
        if assistant_message_content.endswith("<|end|>"):
            assistant_message_content = assistant_message_content[:-7].strip()

        return assistant_message_content

    def _is_reasoning_task(self, messages: List[Dict[str, str]]) -> bool:
        """检测是否为推理任务"""
        reasoning_keywords = [
            "解题", "推理", "计算", "证明", "分析", "逻辑", "步骤",
            "solve", "reasoning", "calculate", "prove", "analyze", "logic", "step"
        ]

        for message in messages:
            content = message.get("content", "").lower()
            if any(keyword in content for keyword in reasoning_keywords):
                return True

        return False

    def _format_reasoning_prompt(self, messages: List[Dict[str, str]]) -> str:
        """
        为推理任务格式化特殊的提示词
        """
        # 添加推理指导的系统消息
        reasoning_system_msg = {
            "role": "system",
            "content": "你是一个专业的数学推理助手。请逐步分析问题,展示详细的推理过程,包括:\n1. 问题理解\n2. 解题思路\n3. 具体步骤\n4. 最终答案\n\n每个步骤都要清晰明了。"
        }

        # 将推理系统消息添加到消息列表的开头
        enhanced_messages = [reasoning_system_msg] + messages

        # 使用标准格式化方法
        return self._format_phi4_messages(enhanced_messages)

    def reasoning_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.3,  # 推理任务使用较低的温度
        max_tokens: int = 2048,  # 推理任务需要更多 tokens
        top_p: float = 0.9,
        extract_reasoning_steps: bool = True,
        **kwargs
    ) -> Dict[str, Union[str, Dict, List]]:
        """
        专门用于推理任务的聊天完成接口

        Args:
            messages: 对话消息列表
            temperature: 采样温度(推理任务建议使用较低值)
            max_tokens: 最大生成token数量
            top_p: top-p采样参数
            extract_reasoning_steps: 是否提取推理步骤
            **kwargs: 其他参数

        Returns:
            包含推理步骤的响应字典
        """
        # 强制启用推理模式
        response = self.create(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            enable_reasoning=True,
            **kwargs
        )

        if extract_reasoning_steps:
            # 提取推理步骤
            content = response["choices"][0]["message"]["content"]
            reasoning_steps = self._extract_reasoning_steps(content)
            response["reasoning_steps"] = reasoning_steps

        return response

    def _extract_reasoning_steps(self, content: str) -> List[Dict[str, str]]:
        """
        从响应内容中提取推理步骤
        """
        steps = []
        lines = content.split('\n')
        current_step = {"title": "", "content": ""}

        step_patterns = [
            "1. 问题理解", "2. 解题思路", "3. 具体步骤", "4. 最终答案",
            "步骤", "分析", "解答", "结论", "reasoning", "step", "analysis", "solution"
        ]

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # 检查是否是新的步骤开始
            is_new_step = any(pattern in line.lower() for pattern in step_patterns)
            if is_new_step and current_step["content"]:
                steps.append(current_step.copy())
                current_step = {"title": line, "content": ""}
            elif is_new_step:
                current_step["title"] = line
            else:
                if current_step["title"]:
                    current_step["content"] += line + "\n"
                else:
                    current_step["content"] = line + "\n"

        # 添加最后一个步骤
        if current_step["title"] or current_step["content"]:
            steps.append(current_step)

        return steps

    def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
        """获取 Phi-4 模型信息"""
        model_info = super().get_model_info()

        # 添加Phi-4特有的信息
        model_info.update({
            "model_family": "Phi-4-mini-reasoning",
            "parameters": "3.8B",
            "context_length": "128K tokens",
            "specialization": "数学推理优化",
        })

        return model_info


# 工厂函数
def create_phi4_transformers_client(
    model_name: str = "microsoft/Phi-4-mini-reasoning",
    use_4bit_quantization: bool = False,
    device: Optional[str] = None,
    **kwargs
) -> Phi4TransformersChatCompletion:
    """
    创建 Phi-4 Transformers 客户端的工厂函数

    Args:
        model_name: 模型名称,默认为 microsoft/Phi-4-mini-reasoning
        use_4bit_quantization: 是否使用4bit量化
        device: 指定设备 ("cpu", "cuda", "mps", 等)
        **kwargs: 其他传递给构造函数的参数

    Returns:
        Phi4TransformersChatCompletion 实例
    """
    return Phi4TransformersChatCompletion(
        model_name=model_name,
        use_4bit_quantization=use_4bit_quantization,
        device=device,
        **kwargs
    )


def create_reasoning_client(
    model_name: str = "microsoft/Phi-4-mini-reasoning",
    use_4bit_quantization: bool = False,
    device: Optional[str] = None,
    **kwargs
) -> Phi4TransformersChatCompletion:
    """
    创建专门用于推理任务的 Phi-4 客户端

    Args:
        model_name: 模型名称,推荐使用 microsoft/Phi-4-mini-reasoning
        use_4bit_quantization: 是否使用4bit量化
        device: 指定设备 ("cpu", "cuda", "mps", 等)
        **kwargs: 其他传递给构造函数的参数

    Returns:
        优化了推理功能的 Phi4TransformersChatCompletion 实例
    """
    # 确保使用推理模型
    if "reasoning" not in model_name.lower():
        print("警告: 建议使用包含 'reasoning' 的模型名称以获得最佳推理性能")

    return Phi4TransformersChatCompletion(
        model_name=model_name,
        use_4bit_quantization=use_4bit_quantization,
        device=device,
        **kwargs
    )
src/podcast_transcribe/llm/llm_router.py  CHANGED

@@ -4,13 +4,13 @@ LLM模型调用路由器
 """

 import logging
+import torch
 from typing import Dict, Any, Optional, List, Union

 import spaces
 from .llm_base import BaseChatCompletion
 from . import llm_gemma_mlx
 from . import llm_gemma_transfomers
-from . import llm_phi4_transfomers

 # 配置日志
 logger = logging.getLogger("llm")
@@ -39,19 +39,9 @@ class LLMRouter:
             "default_model": "google/gemma-3-4b-it",
             "supported_params": [
                 "model_name", "use_4bit_quantization", "device_map",
-                "device", "trust_remote_code"
+                "device", "trust_remote_code", "torch_dtype"
             ],
             "description": "基于Transformers库的Gemma聊天完成实现"
-        },
-        "phi4-transformers": {
-            "module_path": "llm_phi4_transfomers",
-            "class_name": "Phi4TransformersChatCompletion",
-            "default_model": "microsoft/Phi-4-reasoning",
-            "supported_params": [
-                "model_name", "use_4bit_quantization", "device_map",
-                "device", "trust_remote_code", "enable_reasoning"
-            ],
-            "description": "基于Transformers库的Phi-4推理聊天完成实现"
         }
     }

@@ -77,8 +67,6 @@
             module = llm_gemma_mlx
         elif module_path == "llm_gemma_transfomers":
             module = llm_gemma_transfomers
-        elif module_path == "llm_phi4_transfomers":
-            module = llm_phi4_transfomers
         else:
             raise ImportError(f"未找到模块: {module_path}")

@@ -219,6 +207,12 @@
         if model is not None:
             kwargs["model_name"] = model

+        # 如果设备是 mps,并且是 transformers provider,则强制使用 float32
+        current_device = kwargs.get("device")
+        if current_device == "mps":
+            if provider == "gemma-transformers":
+                kwargs["torch_dtype"] = torch.float32
+
         # 获取或创建LLM实例
         llm_instance = self._get_or_create_instance(provider, **kwargs)

@@ -242,7 +236,7 @@
     def reasoning_completion(
         self,
         messages: List[Dict[str, str]],
-        provider: str = "
+        provider: str = "gemma-transformers",
         temperature: float = 0.3,
         max_tokens: int = 2048,
         top_p: float = 0.9,
@@ -255,7 +249,7 @@

         参数:
             messages: 消息列表,每个消息包含role和content
-            provider: LLM提供者名称,默认使用
+            provider: LLM提供者名称,默认使用gemma-transformers
             temperature: 温度参数(推理任务建议使用较低值)
             max_tokens: 最大生成token数
             top_p: nucleus采样参数
@@ -269,14 +263,20 @@
         logger.info(f"使用provider '{provider}' 进行推理完成,消息数量: {len(messages)}")

         # 确保使用支持推理的provider
-        if provider not in ["
-            logger.warning(f"Provider '{provider}' 可能不支持推理功能,建议使用 '
+        if provider not in ["gemma-transformers"]:
+            logger.warning(f"Provider '{provider}' 可能不支持推理功能,建议使用 'gemma-transformers'")

         try:
             # 如果提供了model参数,添加到kwargs中
             if model is not None:
                 kwargs["model_name"] = model

+            # 如果设备是 mps,并且是 transformers provider,则强制使用 float32
+            current_device = kwargs.get("device")
+            if current_device == "mps":
+                if provider == "gemma-transformers":
+                    kwargs["torch_dtype"] = torch.float32
+
             # 获取或创建LLM实例
             llm_instance = self._get_or_create_instance(provider, **kwargs)

@@ -372,7 +372,7 @@ _router = LLMRouter()
 @spaces.GPU(duration=60)
 def chat_completion(
     messages: List[Dict[str, str]],
-    provider: str = "gemma-
+    provider: str = "gemma-transformers",
     temperature: float = 0.7,
     max_tokens: int = 2048,
     top_p: float = 1.0,
@@ -391,7 +391,6 @@
         provider: LLM提供者,可选值:
             - "gemma-mlx": 基于MLX库的Gemma聊天完成实现
             - "gemma-transformers": 基于Transformers库的Gemma聊天完成实现
-            - "phi4-transformers": 基于Transformers库的Phi-4推理聊天完成实现
         temperature: 温度参数,控制生成的随机性 (0.0-2.0)
         max_tokens: 最大生成token数
         top_p: nucleus采样参数 (0.0-1.0)
@@ -421,14 +420,6 @@
             use_4bit_quantization=True
         )

-        # 使用Phi-4推理实现
-        response = chat_completion(
-            messages=[{"role": "user", "content": "解这个数学题:2x + 5 = 15"}],
-            provider="phi4-transformers",
-            model="microsoft/Phi-4-mini-reasoning",
-            device="cuda"
-        )
-
         # 自定义参数
         response = chat_completion(
             messages=[
@@ -466,7 +457,7 @@
 @spaces.GPU(duration=60)
 def reasoning_completion(
     messages: List[Dict[str, str]],
-    provider: str = "
+    provider: str = "gemma-transformers",
     temperature: float = 0.3,
     max_tokens: int = 2048,
     top_p: float = 0.9,
@@ -483,7 +474,7 @@

         参数:
             messages: 消息列表,每个消息包含role和content字段
-            provider: LLM提供者,默认使用
+            provider: LLM提供者,默认使用gemma-transformers
         temperature: 温度参数(推理任务建议使用较低值)
         max_tokens: 最大生成token数
         top_p: nucleus采样参数
@@ -502,14 +493,14 @@
         # 数学推理任务
         response = reasoning_completion(
             messages=[{"role": "user", "content": "解这个方程:3x + 7 = 22"}],
-            provider="
+            provider="gemma-transformers",
             extract_reasoning_steps=True
         )

         # 逻辑推理任务
         response = reasoning_completion(
             messages=[{"role": "user", "content": "如果所有的猫都是动物,而小花是一只猫,那么小花是什么?"}],
-            provider="
+            provider="gemma-transformers",
             temperature=0.2
         )
         """
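After this commit the module-level helpers default to the gemma-transformers provider, and passing device="mps" makes the router inject torch_dtype=torch.float32 before instantiating the client. A usage sketch based on the docstring examples above; the prompt text is illustrative only:

    from src.podcast_transcribe.llm import llm_router

    response = llm_router.chat_completion(
        messages=[{"role": "user", "content": "介绍一下播客转录的流程"}],
        model="google/gemma-3-4b-it",
        device="mps",  # router forces torch_dtype=torch.float32 for gemma-transformers on MPS
    )
    print(response["choices"][0]["message"]["content"])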
src/podcast_transcribe/summary/speaker_identify.py  CHANGED

@@ -1,3 +1,4 @@
+import logging
 from typing import List, Dict, Optional
 import json
 import re
@@ -5,22 +6,26 @@ import re
 from ..schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
 from ..llm import llm_router

+# 配置日志
+logger = logging.getLogger("speaker_identify")

 class SpeakerIdentifier:
     """
     说话人识别器类,用于根据转录分段和播客元数据识别说话人的真实姓名或昵称
     """

-    def __init__(self, llm_model_name: str, llm_provider: str):
+    def __init__(self, llm_model_name: str, llm_provider: str, device: Optional[str] = None):
         """
         初始化说话人识别器

         参数:
             llm_model_name: LLM模型名称,如果为None则使用默认模型
             llm_provider: LLM提供者,默认为"gemma-mlx"
+            device: 计算设备,例如 "cpu", "cuda", "mps"
         """
         self.llm_model_name = llm_model_name
         self.llm_provider = llm_provider
+        self.device = device

     def _clean_html(self, html_string: Optional[str]) -> str:
         """
@@ -280,8 +285,10 @@ Please begin your analysis and provide the JSON result.
             provider=self.llm_provider,
             model=self.llm_model_name,
             temperature=0.1,
-            max_tokens=1024
+            max_tokens=1024,
+            device=self.device
         )
+        logger.info(f"LLM调用日志,请求参数:【{messages}】, 响应: 【{response}】")
         assistant_response_content = response["choices"][0]["message"]["content"]

         parsed_llm_output = None
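The new module logger only emits output if the host application configures logging. A minimal sketch for surfacing the request/response log line added above (standard library logging, no repo-specific assumptions):

    import logging

    # enable INFO-level output so the speaker_identify logger's messages are visible
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("speaker_identify").setLevel(logging.INFO)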
src/podcast_transcribe/transcriber.py  CHANGED

@@ -29,8 +29,8 @@ class CombinedTranscriber:
         asr_provider: str,
         diarization_provider: str,
         diarization_model_name: str,
-        llm_model_name:
-        llm_provider:
+        llm_model_name: str,
+        llm_provider: str,
         device: Optional[str] = None,
         segmentation_batch_size: int = 64,
         parallel: bool = False,
@@ -43,6 +43,8 @@ class CombinedTranscriber:
             asr_provider: ASR提供者名称
             diarization_provider: 说话人分离提供者名称
             diarization_model_name: 说话人分离模型名称
+            llm_model_name: LLM模型名称
+            llm_provider: LLM提供者名称
             device: 推理设备,'cpu'或'cuda'
             segmentation_batch_size: 分割批处理大小,默认为64
             parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -51,23 +53,10 @@ class CombinedTranscriber:
         import torch
         if torch.backends.mps.is_available():
             device = "mps"
-            if not llm_model_name:
-                llm_model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
-            if not llm_provider:
-                llm_provider = "gemma-mlx"
-
         elif torch.cuda.is_available():
             device = "cuda"
-            if not llm_model_name:
-                llm_model_name = "google/gemma-3-4b-it"
-            if not llm_provider:
-                llm_provider = "gemma-transformers"
         else:
             device = "cpu"
-            if not llm_model_name:
-                llm_model_name = "google/gemma-3-4b-it"
-            if not llm_provider:
-                llm_provider = "gemma-transformers"

         self.asr_model_name = asr_model_name
         self.asr_provider = asr_provider
@@ -79,7 +68,8 @@ class CombinedTranscriber:

         self.speaker_identifier = SpeakerIdentifier(
             llm_model_name=llm_model_name,
-            llm_provider=llm_provider
+            llm_provider=llm_provider,
+            device=device
         )

         logger.info(f"初始化组合转录器,ASR提供者: {asr_provider},ASR模型: {asr_model_name},分离提供者: {diarization_provider},分离模型: {diarization_model_name},分割批处理大小: {segmentation_batch_size},并行执行: {parallel},推理设备: {device}")
@@ -513,6 +503,8 @@ def transcribe_audio(
         asr_provider=asr_provider,
         diarization_model_name=diarization_model_name,
         diarization_provider=diarization_provider,
+        llm_model_name="",
+        llm_provider="",
         device=device,
         segmentation_batch_size=segmentation_batch_size,
         parallel=parallel
@@ -529,8 +521,8 @@ def transcribe_podcast_audio(
     asr_provider: str = "distil_whisper_transformers",
     diarization_model_name: str = "pyannote/speaker-diarization-3.1",
     diarization_provider: str = "pyannote_transformers",
-    llm_model_name:
-    llm_provider:
+    llm_model_name: str = "google/gemma-3-4b-it",
+    llm_provider: str = "gemma-transformers",
    device: Optional[str] = None,
    segmentation_batch_size: int = 64,
    parallel: bool = False,
@@ -546,8 +538,8 @@ def transcribe_podcast_audio(
        asr_provider: ASR提供者名称
        diarization_provider: 说话人分离提供者名称
        diarization_model_name: 说话人分离模型名称
-        llm_model_name: LLM
-        llm_provider: LLM
+        llm_model_name: LLM模型名称
+        llm_provider: LLM提供者名称
        device: 推理设备,'cpu'或'cuda'
        segmentation_batch_size: 分割批处理大小,默认为64
        parallel: 是否并行执行ASR和说话人分离,默认为False
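With llm_model_name and llm_provider now required by CombinedTranscriber, direct construction looks roughly like the following. This is a hedged sketch: the import path and the asr_model_name value mirror the examples above rather than a documented API, and keyword arguments are used so parameter order does not matter:

    from src.podcast_transcribe.transcriber import CombinedTranscriber

    transcriber = CombinedTranscriber(
        asr_model_name="distil-whisper/distil-large-v3.5",
        asr_provider="distil_whisper_transformers",
        diarization_provider="pyannote_transformers",
        diarization_model_name="pyannote/speaker-diarization-3.1",
        llm_model_name="google/gemma-3-4b-it",
        llm_provider="gemma-transformers",
        device=None,  # auto-selects mps / cuda / cpu at init time
        parallel=True,
    )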