konieshadow committed
Commit: 924aa01
Parent(s): 48811fe

Update the LLM model to google/gemma-3-4b-it, remove the no-longer-used Phi-4 model, improve device parameter support, and enhance logging in the speaker identifier.

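The practical effect of the new defaults is easiest to see at the router level. Below is a minimal sketch of the updated call path, modeled on the docstring examples in llm_router.py; the import path assumes the project root is on sys.path (as in the example scripts), and the message content is illustrative:

from src.podcast_transcribe.llm.llm_router import chat_completion

# After this commit the default provider is "gemma-transformers" with google/gemma-3-4b-it.
response = chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    provider="gemma-transformers",
    model="google/gemma-3-4b-it",
    device="mps",  # on MPS the router now forces torch_dtype=torch.float32
)
print(response["choices"][0]["message"]["content"])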
examples/combined_podcast_transcription.py CHANGED
@@ -9,9 +9,6 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 from src.podcast_transcribe.transcriber import transcribe_podcast_audio
 from src.podcast_transcribe.audio import load_audio
 from src.podcast_transcribe.rss.podcast_rss_parser import parse_rss_xml_content
-from podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
-from src.podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult
-from src.podcast_transcribe.summary.speaker_identify import recognize_speaker_names
 
 def main():
     """主函数"""
@@ -20,9 +17,10 @@ def main():
     # audio_file = Path("/Users/konie/Desktop/voices/lex_ai_john_carmack_30.wav")
 
     # 模型配置
-    asr_model_name = "mlx-community/" # ASR模型名称
+    asr_model_name = "distil-whisper/distil-large-v3.5" # ASR模型名称
     diarization_model_name = "pyannote/speaker-diarization-3.1" # 说话人分离模型名称
-    llm_model_path = "mlx-community/gemma-3-12b-it-4bit-DWQ"
+    llm_model_name = "google/gemma-3-4b-it"
+    llm_provider = "gemma-transformers"
     device = "mps" # 设备类型
     segmentation_batch_size = 64
     parallel = True
@@ -60,13 +58,9 @@ def main():
     result = transcribe_podcast_audio(audio,
                                       podcast_info=mock_podcast_info,
                                       episode_info=mock_episode_info,
-                                      asr_model_name=asr_model_name,
-                                      diarization_model_name=diarization_model_name,
-                                      llm_model_name=llm_model_path,
                                       device=device,
                                       segmentation_batch_size=segmentation_batch_size,
-                                      parallel=parallel,
-                                      llm_model_name=llm_model_path)
+                                      parallel=parallel,)
 
     # 输出结果
     print("\n转录结果:")
examples/simple_llm.py CHANGED
@@ -1,11 +1,10 @@
-
+import torch # 导入 torch
 # 添加项目根目录到Python路径
 import sys
 from pathlib import Path
 
 sys.path.insert(0, str(Path(__file__).parent.parent))
 
-from src.podcast_transcribe.llm.llm_phi4_transfomers import Phi4TransformersChatCompletion
 from src.podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
 from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion
 
@@ -14,6 +13,7 @@ if __name__ == "__main__":
     # 示例用法:
     print("正在初始化 LLM 聊天补全...")
     try:
+        # model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
         model_name = "google/gemma-3-4b-it"
         use_4bit_quantization = False
         device = "mps"
@@ -22,10 +22,10 @@ if __name__ == "__main__":
         # 或者,如果您有更小、更快的模型,可以尝试使用,例如:"mlx-community/gemma-2b-it-8bit"
         if model_name.startswith("mlx-community"):
             gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
-        elif model_name.startswith("microsoft"):
-            gemma_chat = Phi4TransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
         else:
-            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device)
+            # 如果设备是 mps,则使用 float32 以增加稳定性
+            dtype_to_use = torch.float32 if device == "mps" else torch.float16
+            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, use_4bit_quantization=use_4bit_quantization, device=device, torch_dtype=dtype_to_use)
 
         print("\n--- 示例 1: 简单用户查询 ---")
         messages_example1 = [
examples/simple_speaker_identify.py CHANGED
@@ -13,6 +13,7 @@ from src.podcast_transcribe.summary.speaker_identify import SpeakerIdentifier
 if __name__ == '__main__':
     transcribe_result_dump_file = Path.joinpath(Path(__file__).parent, "output", "lex_ai_john_carmack_1.transcription.json")
     podcast_rss_xml_file = Path.joinpath(Path(__file__).parent, "input", "lexfridman.com.rss.xml")
+    device = "mps"
 
     # Load the transcription result
     if not os.path.exists(transcribe_result_dump_file):
@@ -57,8 +58,9 @@ if __name__ == '__main__':
 
 
     speaker_identifier = SpeakerIdentifier(
-        llm_model_name="mlx-community/gemma-3-12b-it-4bit-DWQ",
-        llm_provider="gemma-mlx"
+        llm_model_name="google/gemma-3-4b-it",
+        llm_provider="gemma-transformers",
+        device=device
     )
 
     # 3. Call the function
src/podcast_transcribe/llm/llm_base.py CHANGED
@@ -146,19 +146,12 @@ class BaseChatCompletion(ABC):
         temperature: float = 0.7,
         max_tokens: int = 2048,
         top_p: float = 1.0,
-        model: Optional[str] = None,
         **kwargs,
     ):
         """
        创建聊天完成响应。
        模仿OpenAI的ChatCompletion.create方法。
        """
-        if model and model != self.model_name:
-            # 这是一个简化的处理。在实际场景中,您可能希望加载新模型。
-            # 目前,我们将只打印一个警告并使用初始化的模型。
-            print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
-                  f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")
-
         # 为Gemma格式化消息
         prompt_str = self._format_messages_for_gemma(messages)
 
src/podcast_transcribe/llm/llm_gemma_transfomers.py CHANGED
@@ -13,7 +13,8 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
         use_4bit_quantization: bool = False,
         device_map: Optional[str] = None,
         device: Optional[str] = None,
-        trust_remote_code: bool = True
+        trust_remote_code: bool = True,
+        torch_dtype: Optional[torch.dtype] = None
     ):
         # Gemma 使用 float16 作为默认数据类型
         super().__init__(
@@ -22,7 +23,7 @@ class GemmaTransformersChatCompletion(TransformersBaseChatCompletion):
             device_map=device_map,
             device=device,
             trust_remote_code=trust_remote_code,
-            torch_dtype=torch.float16
+            torch_dtype=torch_dtype if torch_dtype is not None else torch.float16
         )
 
     def _print_error_hints(self):
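The new torch_dtype parameter lets callers override Gemma's float16 default; the updated example script switches to float32 on Apple MPS for stability. A minimal sketch of that selection logic, mirroring examples/simple_llm.py (constructing the class downloads google/gemma-3-4b-it):

import torch
from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion

device = "mps"
# Per the updated example, use float32 on MPS for better stability; float16 elsewhere.
dtype_to_use = torch.float32 if device == "mps" else torch.float16

gemma_chat = GemmaTransformersChatCompletion(
    model_name="google/gemma-3-4b-it",
    use_4bit_quantization=False,
    device=device,
    torch_dtype=dtype_to_use,
)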
src/podcast_transcribe/llm/llm_phi4_transfomers.py DELETED
@@ -1,369 +0,0 @@
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-from typing import List, Dict, Optional, Union, Literal
-from .llm_base import TransformersBaseChatCompletion
-
-
-class Phi4TransformersChatCompletion(TransformersBaseChatCompletion):
-    """基于 Transformers 库的 Phi-4-mini-reasoning 聊天完成实现"""
-
-    def __init__(
-        self,
-        model_name: str = "microsoft/Phi-4-mini-reasoning",
-        use_4bit_quantization: bool = False,
-        device_map: Optional[str] = None,
-        device: Optional[str] = None,
-        trust_remote_code: bool = True
-    ):
-        # Phi-4 使用 bfloat16 作为推荐数据类型
-        super().__init__(
-            model_name=model_name,
-            use_4bit_quantization=use_4bit_quantization,
-            device_map=device_map,
-            device=device,
-            trust_remote_code=trust_remote_code,
-            torch_dtype=torch.bfloat16
-        )
-
-    def _print_error_hints(self):
-        """打印Phi-4特定的错误提示信息"""
-        super()._print_error_hints()
-        print("Phi-4 特殊要求:")
-        print("- 建议使用 Transformers >= 4.51.3")
-        print("- 推荐使用 bfloat16 数据类型")
-        print("- 模型支持 128K token 上下文长度")
-
-    def _format_phi4_messages(self, messages: List[Dict[str, str]]) -> str:
-        """
-        格式化消息为 Phi-4 的聊天格式
-        Phi-4 使用特定的聊天模板格式
-        """
-        # 使用 tokenizer 的内置聊天模板
-        if hasattr(self.tokenizer, 'apply_chat_template'):
-            return self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True
-            )
-        else:
-            # 如果没有聊天模板,使用 Phi-4 的标准格式
-            formatted_prompt = ""
-            for message in messages:
-                role = message.get("role", "user")
-                content = message.get("content", "")
-
-                if role == "system":
-                    formatted_prompt += f"<|system|>\n{content}<|end|>\n"
-                elif role == "user":
-                    formatted_prompt += f"<|user|>\n{content}<|end|>\n"
-                elif role == "assistant":
-                    formatted_prompt += f"<|assistant|>\n{content}<|end|>\n"
-
-            # 添加助手开始标记
-            formatted_prompt += "<|assistant|>\n"
-            return formatted_prompt
-
-    def _generate_response(
-        self,
-        prompt_str: str,
-        temperature: float,
-        max_tokens: int,
-        top_p: float,
-        enable_reasoning: bool = True,
-        **kwargs
-    ) -> str:
-        """使用 transformers 生成响应,针对 Phi-4 推理功能优化"""
-
-        # 对提示进行编码
-        inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
-
-        # 移动输入到正确的设备
-        if self.device_map is None or self.device.type == "mps":
-            inputs = inputs.to(self.device)
-
-        # Phi-4-mini-reasoning 优化的生成参数
-        generation_config = {
-            "max_new_tokens": min(max_tokens, 32768),  # Phi-4-mini 支持最大 32K token
-            "temperature": temperature,
-            "top_p": top_p,
-            "do_sample": True if temperature > 0 else False,
-            "pad_token_id": self.tokenizer.pad_token_id,
-            "eos_token_id": self.tokenizer.eos_token_id,
-            "repetition_penalty": kwargs.get("repetition_penalty", 1.1),
-            "no_repeat_ngram_size": kwargs.get("no_repeat_ngram_size", 3),
-        }
-
-        # 推理模式配置
-        if enable_reasoning and "reasoning" in self.model_name.lower():
-            # 为推理任务优化的配置
-            generation_config.update({
-                "temperature": max(temperature, 0.1),  # 推理模式下保持一定的温度
-                "top_p": min(top_p, 0.95),  # 推理模式下限制 top_p
-                "do_sample": True,  # 推理模式下总是启用采样
-                "early_stopping": False,  # 允许完整的推理过程
-            })
-
-        # 如果温度为0,使用贪婪解码
-        if temperature == 0:
-            generation_config["do_sample"] = False
-            generation_config.pop("temperature", None)
-            generation_config.pop("top_p", None)
-
-        try:
-            # 生成响应
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    inputs,
-                    **generation_config
-                )
-
-            # 解码生成的文本,跳过输入部分
-            generated_tokens = outputs[0][len(inputs[0]):]
-            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
-            return generated_text
-
-        except Exception as e:
-            print(f"生成响应时出错: {e}")
-            raise
-
-    def create(
-        self,
-        messages: List[Dict[str, str]],
-        temperature: float = 0.7,
-        max_tokens: int = 2048,
-        top_p: float = 1.0,
-        model: Optional[str] = None,
-        enable_reasoning: bool = True,
-        **kwargs,
-    ):
-        """
-        创建聊天完成响应,支持Phi-4特有的推理功能
-        """
-        if model and model != self.model_name:
-            print(f"警告: 'model' 参数 ({model}) 与初始化的模型 ({self.model_name}) 不同。"
-                  f"正在使用初始化的模型。要使用不同的模型,请重新初始化该类。")
-
-        # 检查是否为推理任务
-        is_reasoning_task = self._is_reasoning_task(messages)
-
-        # 格式化消息为 Phi-4 聊天格式
-        if is_reasoning_task and enable_reasoning:
-            prompt_str = self._format_reasoning_prompt(messages)
-        else:
-            prompt_str = self._format_phi4_messages(messages)
-
-        # 生成响应
-        response_text = self._generate_response(
-            prompt_str,
-            temperature,
-            max_tokens,
-            top_p,
-            enable_reasoning=enable_reasoning and is_reasoning_task,
-            **kwargs
-        )
-
-        # 后处理响应(使用基类的方法,但针对Phi-4调整)
-        assistant_message_content = self._post_process_phi4_response(response_text, prompt_str)
-
-        # 计算token使用量
-        token_usage = self._calculate_tokens(prompt_str, assistant_message_content)
-
-        # 构建响应对象
-        response = self._build_chat_completion_response(assistant_message_content, token_usage)
-
-        # 添加Phi-4特有的信息
-        response["reasoning_enabled"] = enable_reasoning and is_reasoning_task
-
-        return response
-
-    def _post_process_phi4_response(self, response_text: str, prompt_str: str) -> str:
-        """
-        后处理Phi-4生成的响应文本
-        """
-        # Phi-4的输出通常不包含输入提示,直接返回生成的内容
-        assistant_message_content = response_text.strip()
-
-        # 清理可能的特殊标记
-        if assistant_message_content.endswith("<|end|>"):
-            assistant_message_content = assistant_message_content[:-7].strip()
-
-        return assistant_message_content
-
-    def _is_reasoning_task(self, messages: List[Dict[str, str]]) -> bool:
-        """检测是否为推理任务"""
-        reasoning_keywords = [
-            "解题", "推理", "计算", "证明", "分析", "逻辑", "步骤",
-            "solve", "reasoning", "calculate", "prove", "analyze", "logic", "step"
-        ]
-
-        for message in messages:
-            content = message.get("content", "").lower()
-            if any(keyword in content for keyword in reasoning_keywords):
-                return True
-
-        return False
-
-    def _format_reasoning_prompt(self, messages: List[Dict[str, str]]) -> str:
-        """
-        为推理任务格式化特殊的提示词
-        """
-        # 添加推理指导的系统消息
-        reasoning_system_msg = {
-            "role": "system",
-            "content": "你是一个专业的数学推理助手。请逐步分析问题,展示详细的推理过程,包括:\n1. 问题理解\n2. 解题思路\n3. 具体步骤\n4. 最终答案\n\n每个步骤都要清晰明了。"
-        }
-
-        # 将推理系统消息添加到消息列表的开头
-        enhanced_messages = [reasoning_system_msg] + messages
-
-        # 使用标准格式化方法
-        return self._format_phi4_messages(enhanced_messages)
-
-    def reasoning_completion(
-        self,
-        messages: List[Dict[str, str]],
-        temperature: float = 0.3,  # 推理任务使用较低的温度
-        max_tokens: int = 2048,  # 推理任务需要更多 tokens
-        top_p: float = 0.9,
-        extract_reasoning_steps: bool = True,
-        **kwargs
-    ) -> Dict[str, Union[str, Dict, List]]:
-        """
-        专门用于推理任务的聊天完成接口
-
-        Args:
-            messages: 对话消息列表
-            temperature: 采样温度(推理任务建议使用较低值)
-            max_tokens: 最大生成token数量
-            top_p: top-p采样参数
-            extract_reasoning_steps: 是否提取推理步骤
-            **kwargs: 其他参数
-
-        Returns:
-            包含推理步骤的响应字典
-        """
-        # 强制启用推理模式
-        response = self.create(
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            top_p=top_p,
-            enable_reasoning=True,
-            **kwargs
-        )
-
-        if extract_reasoning_steps:
-            # 提取推理步骤
-            content = response["choices"][0]["message"]["content"]
-            reasoning_steps = self._extract_reasoning_steps(content)
-            response["reasoning_steps"] = reasoning_steps
-
-        return response
-
-    def _extract_reasoning_steps(self, content: str) -> List[Dict[str, str]]:
-        """
-        从响应内容中提取推理步骤
-        """
-        steps = []
-        lines = content.split('\n')
-        current_step = {"title": "", "content": ""}
-
-        step_patterns = [
-            "1. 问题理解", "2. 解题思路", "3. 具体步骤", "4. 最终答案",
-            "步骤", "分析", "解答", "结论", "reasoning", "step", "analysis", "solution"
-        ]
-
-        for line in lines:
-            line = line.strip()
-            if not line:
-                continue
-
-            # 检查是否是新的步骤开始
-            is_new_step = any(pattern in line.lower() for pattern in step_patterns)
-            if is_new_step and current_step["content"]:
-                steps.append(current_step.copy())
-                current_step = {"title": line, "content": ""}
-            elif is_new_step:
-                current_step["title"] = line
-            else:
-                if current_step["title"]:
-                    current_step["content"] += line + "\n"
-                else:
-                    current_step["content"] = line + "\n"
-
-        # 添加最后一个步骤
-        if current_step["title"] or current_step["content"]:
-            steps.append(current_step)
-
-        return steps
-
-    def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
-        """获取 Phi-4 模型信息"""
-        model_info = super().get_model_info()
-
-        # 添加Phi-4特有的信息
-        model_info.update({
-            "model_family": "Phi-4-mini-reasoning",
-            "parameters": "3.8B",
-            "context_length": "128K tokens",
-            "specialization": "数学推理优化",
-        })
-
-        return model_info
-
-
-# 工厂函数
-def create_phi4_transformers_client(
-    model_name: str = "microsoft/Phi-4-mini-reasoning",
-    use_4bit_quantization: bool = False,
-    device: Optional[str] = None,
-    **kwargs
-) -> Phi4TransformersChatCompletion:
-    """
-    创建 Phi-4 Transformers 客户端的工厂函数
-
-    Args:
-        model_name: 模型名称,默认为 microsoft/Phi-4-mini-reasoning
-        use_4bit_quantization: 是否使用4bit量化
-        device: 指定设备 ("cpu", "cuda", "mps", 等)
-        **kwargs: 其他传递给构造函数的参数
-
-    Returns:
-        Phi4TransformersChatCompletion 实例
-    """
-    return Phi4TransformersChatCompletion(
-        model_name=model_name,
-        use_4bit_quantization=use_4bit_quantization,
-        device=device,
-        **kwargs
-    )
-
-def create_reasoning_client(
-    model_name: str = "microsoft/Phi-4-mini-reasoning",
-    use_4bit_quantization: bool = False,
-    device: Optional[str] = None,
-    **kwargs
-) -> Phi4TransformersChatCompletion:
-    """
-    创建专门用于推理任务的 Phi-4 客户端
-
-    Args:
-        model_name: 模型名称,推荐使用 microsoft/Phi-4-mini-reasoning
-        use_4bit_quantization: 是否使用4bit量化
-        device: 指定设备 ("cpu", "cuda", "mps", 等)
-        **kwargs: 其他传递给构造函数的参数
-
-    Returns:
-        优化了推理功能的 Phi4TransformersChatCompletion 实例
-    """
-    # 确保使用推理模型
-    if "reasoning" not in model_name.lower():
-        print("警告: 建议使用包含 'reasoning' 的模型名称以获得最佳推理性能")
-
-    return Phi4TransformersChatCompletion(
-        model_name=model_name,
-        use_4bit_quantization=use_4bit_quantization,
-        device=device,
-        **kwargs
-    )
src/podcast_transcribe/llm/llm_router.py CHANGED
@@ -4,13 +4,13 @@ LLM模型调用路由器
 """
 
 import logging
+import torch
 from typing import Dict, Any, Optional, List, Union
 
 import spaces
 from .llm_base import BaseChatCompletion
 from . import llm_gemma_mlx
 from . import llm_gemma_transfomers
-from . import llm_phi4_transfomers
 
 # 配置日志
 logger = logging.getLogger("llm")
@@ -39,19 +39,9 @@ class LLMRouter:
            "default_model": "google/gemma-3-4b-it",
            "supported_params": [
                "model_name", "use_4bit_quantization", "device_map",
-                "device", "trust_remote_code"
+                "device", "trust_remote_code", "torch_dtype"
            ],
            "description": "基于Transformers库的Gemma聊天完成实现"
-        },
-        "phi4-transformers": {
-            "module_path": "llm_phi4_transfomers",
-            "class_name": "Phi4TransformersChatCompletion",
-            "default_model": "microsoft/Phi-4-reasoning",
-            "supported_params": [
-                "model_name", "use_4bit_quantization", "device_map",
-                "device", "trust_remote_code", "enable_reasoning"
-            ],
-            "description": "基于Transformers库的Phi-4推理聊天完成实现"
        }
    }
 
@@ -77,8 +67,6 @@ class LLMRouter:
            module = llm_gemma_mlx
        elif module_path == "llm_gemma_transfomers":
            module = llm_gemma_transfomers
-        elif module_path == "llm_phi4_transfomers":
-            module = llm_phi4_transfomers
        else:
            raise ImportError(f"未找到模块: {module_path}")
 
@@ -219,6 +207,12 @@
        if model is not None:
            kwargs["model_name"] = model
 
+        # 如果设备是 mps,并且是 transformers provider,则强制使用 float32
+        current_device = kwargs.get("device")
+        if current_device == "mps":
+            if provider == "gemma-transformers":
+                kwargs["torch_dtype"] = torch.float32
+
        # 获取或创建LLM实例
        llm_instance = self._get_or_create_instance(provider, **kwargs)
 
@@ -242,7 +236,7 @@
    def reasoning_completion(
        self,
        messages: List[Dict[str, str]],
-        provider: str = "phi4-transformers",
+        provider: str = "gemma-transformers",
        temperature: float = 0.3,
        max_tokens: int = 2048,
        top_p: float = 0.9,
@@ -255,7 +249,7 @@
 
        参数:
            messages: 消息列表,每个消息包含role和content
-            provider: LLM提供者名称,默认使用phi4-transformers
+            provider: LLM提供者名称,默认使用gemma-transformers
            temperature: 温度参数(推理任务建议使用较低值)
            max_tokens: 最大生成token数
            top_p: nucleus采样参数
@@ -269,14 +263,20 @@
        logger.info(f"使用provider '{provider}' 进行推理完成,消息数量: {len(messages)}")
 
        # 确保使用支持推理的provider
-        if provider not in ["phi4-transformers"]:
-            logger.warning(f"Provider '{provider}' 可能不支持推理功能,建议使用 'phi4-transformers'")
+        if provider not in ["gemma-transformers"]:
+            logger.warning(f"Provider '{provider}' 可能不支持推理功能,建议使用 'gemma-transformers'")
 
        try:
            # 如果提供了model参数,添加到kwargs中
            if model is not None:
                kwargs["model_name"] = model
 
+            # 如果设备是 mps,并且是 transformers provider,则强制使用 float32
+            current_device = kwargs.get("device")
+            if current_device == "mps":
+                if provider == "gemma-transformers":
+                    kwargs["torch_dtype"] = torch.float32
+
            # 获取或创建LLM实例
            llm_instance = self._get_or_create_instance(provider, **kwargs)
 
@@ -372,7 +372,7 @@ _router = LLMRouter()
 @spaces.GPU(duration=60)
 def chat_completion(
    messages: List[Dict[str, str]],
-    provider: str = "gemma-mlx",
+    provider: str = "gemma-transformers",
    temperature: float = 0.7,
    max_tokens: int = 2048,
    top_p: float = 1.0,
@@ -391,7 +391,6 @@
        provider: LLM提供者,可选值:
            - "gemma-mlx": 基于MLX库的Gemma聊天完成实现
            - "gemma-transformers": 基于Transformers库的Gemma聊天完成实现
-            - "phi4-transformers": 基于Transformers库的Phi-4推理聊天完成实现
        temperature: 温度参数,控制生成的随机性 (0.0-2.0)
        max_tokens: 最大生成token数
        top_p: nucleus采样参数 (0.0-1.0)
@@ -421,14 +420,6 @@
            use_4bit_quantization=True
        )
 
-        # 使用Phi-4推理实现
-        response = chat_completion(
-            messages=[{"role": "user", "content": "解这个数学题:2x + 5 = 15"}],
-            provider="phi4-transformers",
-            model="microsoft/Phi-4-mini-reasoning",
-            device="cuda"
-        )
-
        # 自定义参数
        response = chat_completion(
            messages=[
@@ -466,7 +457,7 @@
 @spaces.GPU(duration=60)
 def reasoning_completion(
    messages: List[Dict[str, str]],
-    provider: str = "phi4-transformers",
+    provider: str = "gemma-transformers",
    temperature: float = 0.3,
    max_tokens: int = 2048,
    top_p: float = 0.9,
@@ -483,7 +474,7 @@
 
    参数:
        messages: 消息列表,每个消息包含role和content字段
-        provider: LLM提供者,默认使用phi4-transformers
+        provider: LLM提供者,默认使用gemma-transformers
        temperature: 温度参数(推理任务建议使用较低值)
        max_tokens: 最大生成token数
        top_p: nucleus采样参数
@@ -502,14 +493,14 @@
        # 数学推理任务
        response = reasoning_completion(
            messages=[{"role": "user", "content": "解这个方程:3x + 7 = 22"}],
-            provider="phi4-transformers",
+            provider="gemma-transformers",
            extract_reasoning_steps=True
        )
 
        # 逻辑推理任务
        response = reasoning_completion(
            messages=[{"role": "user", "content": "如果所有的猫都是动物,而小花是一只猫,那么小花是什么?"}],
-            provider="phi4-transformers",
+            provider="gemma-transformers",
            temperature=0.2
        )
    """
src/podcast_transcribe/summary/speaker_identify.py CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import List, Dict, Optional
 import json
 import re
@@ -5,22 +6,26 @@ import re
 from ..schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
 from ..llm import llm_router
 
+# 配置日志
+logger = logging.getLogger("speaker_identify")
 
 class SpeakerIdentifier:
     """
     说话人识别器类,用于根据转录分段和播客元数据识别说话人的真实姓名或昵称
     """
 
-    def __init__(self, llm_model_name: str, llm_provider: str):
+    def __init__(self, llm_model_name: str, llm_provider: str, device: Optional[str] = None):
        """
        初始化说话人识别器
 
        参数:
            llm_model_name: LLM模型名称,如果为None则使用默认模型
            llm_provider: LLM提供者,默认为"gemma-mlx"
+            device: 计算设备,例如 "cpu", "cuda", "mps"
        """
        self.llm_model_name = llm_model_name
        self.llm_provider = llm_provider
+        self.device = device
 
    def _clean_html(self, html_string: Optional[str]) -> str:
        """
@@ -280,8 +285,10 @@ Please begin your analysis and provide the JSON result.
            provider=self.llm_provider,
            model=self.llm_model_name,
            temperature=0.1,
-            max_tokens=1024
+            max_tokens=1024,
+            device=self.device
        )
+        logger.info(f"LLM调用日志,请求参数:【{messages}】, 响应: 【{response}】")
        assistant_response_content = response["choices"][0]["message"]["content"]
 
        parsed_llm_output = None
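With the added device parameter, SpeakerIdentifier stores the compute device and forwards it to the router's chat_completion call (which in turn applies the float32 override on MPS). A minimal construction sketch, mirroring examples/simple_speaker_identify.py:

from src.podcast_transcribe.summary.speaker_identify import SpeakerIdentifier

speaker_identifier = SpeakerIdentifier(
    llm_model_name="google/gemma-3-4b-it",
    llm_provider="gemma-transformers",
    device="mps",  # stored on the instance and passed as device= to the LLM call
)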
src/podcast_transcribe/transcriber.py CHANGED
@@ -29,8 +29,8 @@ class CombinedTranscriber:
         asr_provider: str,
         diarization_provider: str,
         diarization_model_name: str,
-        llm_model_name: Optional[str] = None,
-        llm_provider: Optional[str] = None,
+        llm_model_name: str,
+        llm_provider: str,
         device: Optional[str] = None,
         segmentation_batch_size: int = 64,
         parallel: bool = False,
@@ -43,6 +43,8 @@
            asr_provider: ASR提供者名称
            diarization_provider: 说话人分离提供者名称
            diarization_model_name: 说话人分离模型名称
+            llm_model_name: LLM模型名称
+            llm_provider: LLM提供者名称
            device: 推理设备,'cpu'或'cuda'
            segmentation_batch_size: 分割批处理大小,默认为64
            parallel: 是否并行执行ASR和说话人分离,默认为False
@@ -51,23 +53,10 @@
        import torch
        if torch.backends.mps.is_available():
            device = "mps"
-            if not llm_model_name:
-                llm_model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
-            if not llm_provider:
-                llm_provider = "gemma-mlx"
-
        elif torch.cuda.is_available():
            device = "cuda"
-            if not llm_model_name:
-                llm_model_name = "google/gemma-3-4b-it"
-            if not llm_provider:
-                llm_provider = "gemma-transformers"
        else:
            device = "cpu"
-            if not llm_model_name:
-                llm_model_name = "google/gemma-3-4b-it"
-            if not llm_provider:
-                llm_provider = "gemma-transformers"
 
        self.asr_model_name = asr_model_name
        self.asr_provider = asr_provider
@@ -79,7 +68,8 @@
 
        self.speaker_identifier = SpeakerIdentifier(
            llm_model_name=llm_model_name,
-            llm_provider=llm_provider
+            llm_provider=llm_provider,
+            device=device
        )
 
        logger.info(f"初始化组合转录器,ASR提供者: {asr_provider},ASR模型: {asr_model_name},分离提供者: {diarization_provider},分离模型: {diarization_model_name},分割批处理大小: {segmentation_batch_size},并行执行: {parallel},推理设备: {device}")
@@ -513,6 +503,8 @@ def transcribe_audio(
        asr_provider=asr_provider,
        diarization_model_name=diarization_model_name,
        diarization_provider=diarization_provider,
+        llm_model_name="",
+        llm_provider="",
        device=device,
        segmentation_batch_size=segmentation_batch_size,
        parallel=parallel
@@ -529,8 +521,8 @@ def transcribe_podcast_audio(
    asr_provider: str = "distil_whisper_transformers",
    diarization_model_name: str = "pyannote/speaker-diarization-3.1",
    diarization_provider: str = "pyannote_transformers",
-    llm_model_name: Optional[str] = None,
-    llm_provider: Optional[str] = None,
+    llm_model_name: str = "google/gemma-3-4b-it",
+    llm_provider: str = "gemma-transformers",
    device: Optional[str] = None,
    segmentation_batch_size: int = 64,
    parallel: bool = False,
@@ -546,8 +538,8 @@
        asr_provider: ASR提供者名称
        diarization_provider: 说话人分离提供者名称
        diarization_model_name: 说话人分离模型名称
-        llm_model_name: LLM模型名称,如果为None则无法识别说话人名称
-        llm_provider: LLM提供者名称,如果为None则无法识别说话人名称
+        llm_model_name: LLM模型名称
+        llm_provider: LLM提供者名称
        device: 推理设备,'cpu'或'cuda'
        segmentation_batch_size: 分割批处理大小,默认为64
        parallel: 是否并行执行ASR和说话人分离,默认为False