Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
""" | |
辩论控制器模块 | |
管理辩论流程和控制发言顺序 | |
""" | |
import os | |
import sys | |
import json | |
import time | |
import logging | |
import threading | |
from typing import Dict, List, Any, Optional, Callable | |
from datetime import datetime | |
from queue import Queue | |
# 在代码开头强制设置终端编码为UTF-8(仅Windows执行) | |
if os.name == 'nt': | |
os.system('chcp 65001 > nul') | |
# 获取当前脚本文件所在目录的绝对路径 | |
current_script_dir = os.path.dirname(os.path.abspath(__file__)) | |
# 项目根目录 (即 '20250907_大模型辩论' 目录, 是 'src' 的上一级) | |
project_root = os.path.dirname(current_script_dir) | |
# 定义数据输入、输出和日志目录 | |
DATA_DIR = os.path.join(project_root, 'data') | |
OUTPUT_DIR = os.path.join(project_root, 'output') | |
LOGS_DIR = os.path.join(project_root, 'logs') | |
# 确保目录存在 | |
os.makedirs(OUTPUT_DIR, exist_ok=True) | |
os.makedirs(LOGS_DIR, exist_ok=True) | |
# 配置日志 | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
handlers=[ | |
logging.FileHandler(os.path.join(LOGS_DIR, '辩论控制器日志.log'), encoding='utf-8'), | |
logging.StreamHandler(sys.stdout) | |
] | |
) | |
logger = logging.getLogger(__name__) | |
# 捕获警告并记录到日志 | |
logging.captureWarnings(True) | |
class ConversationMessage: | |
"""对话消息类""" | |
def __init__(self, role: str, content: str, model_name: str, timestamp: datetime = None): | |
""" | |
初始化对话消息 | |
Args: | |
role: 消息角色 ('user' 或 'assistant') | |
content: 消息内容 | |
model_name: 模型名称 | |
timestamp: 时间戳 | |
""" | |
self.role = role | |
self.content = content | |
self.model_name = model_name | |
self.timestamp = timestamp or datetime.now() | |
def to_dict(self) -> Dict[str, Any]: | |
"""转换为字典""" | |
return { | |
'role': self.role, | |
'content': self.content, | |
'model_name': self.model_name, | |
'timestamp': self.timestamp.isoformat() | |
} | |
def from_dict(cls, data: Dict[str, Any]) -> 'ConversationMessage': | |
"""从字典创建实例""" | |
return cls( | |
role=data['role'], | |
content=data['content'], | |
model_name=data['model_name'], | |
timestamp=datetime.fromisoformat(data['timestamp']) | |
) | |
class ConversationSession: | |
"""对话会话类""" | |
def __init__(self, topic: str, mode: str = 'debate', max_rounds: int = 5, pro_model: str = 'glm45', con_model: str = 'deepseek_v31', initial_prompt: str = "", initial_prompt_mode: str = 'append'): | |
""" | |
初始化对话会话 | |
Args: | |
topic: 话题或任务 | |
mode: 对话模式 ('debate' 或 'discussion') | |
max_rounds: 最大轮数 | |
pro_model: AI 1 (在辩论中为正方) | |
con_model: AI 2 (在辩论中为反方) | |
initial_prompt: 自定义初始提示 | |
initial_prompt_mode: 初始提示模式 ('append' 或 'override') | |
""" | |
self.topic = topic | |
self.mode = mode | |
self.max_rounds = max_rounds | |
self.pro_model = pro_model | |
self.con_model = con_model | |
self.initial_prompt = initial_prompt | |
self.initial_prompt_mode = initial_prompt_mode | |
self.messages: List[ConversationMessage] = [] | |
self.current_round = 0 | |
self.is_active = True | |
self.start_time = None | |
self.end_time = None | |
self.debate_id = f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}" | |
logger.info(f"对话会话初始化完成,模式: {mode}, 话题: {topic}, 最大轮数: {max_rounds}, AI_1: {pro_model}, AI_2: {con_model}") | |
def add_message(self, message: ConversationMessage): | |
"""添加消息到对话记录""" | |
self.messages.append(message) | |
logger.info(f"添加消息到对话记录,模型: {message.model_name}, 内容长度: {len(message.content)}") | |
def get_messages_for_model(self, model_name: str) -> List[Dict[str, str]]: | |
""" | |
获取指定模型的消息历史 | |
Args: | |
model_name: 模型名称 | |
Returns: | |
消息历史列表 | |
""" | |
# 转换为API需要的格式 | |
result = [] | |
for msg in self.messages: | |
result.append({ | |
'role': msg.role, | |
'content': msg.content | |
}) | |
return result | |
def get_debate_summary(self) -> Dict[str, Any]: | |
"""获取辩论摘要""" | |
return { | |
'debate_id': self.debate_id, | |
'topic': self.topic, | |
'max_rounds': self.max_rounds, | |
'current_round': self.current_round, | |
'total_messages': len(self.messages), | |
'is_active': self.is_active, | |
'start_time': self.start_time.isoformat() if self.start_time else None, | |
'end_time': self.end_time.isoformat() if self.end_time else None, | |
'duration': (self.end_time - self.start_time).total_seconds() if self.start_time and self.end_time else None, | |
'pro_model': self.pro_model, | |
'con_model': self.con_model | |
} | |
def save_to_file(self, file_path: str): | |
"""保存辩论记录到文件""" | |
data = { | |
'debate_info': self.get_debate_summary(), | |
'messages': [msg.to_dict() for msg in self.messages] | |
} | |
with open(file_path, 'w', encoding='utf-8') as f: | |
json.dump(data, f, ensure_ascii=False, indent=2) | |
logger.info(f"辩论记录已保存到: {file_path}") | |
def load_from_file(cls, file_path: str) -> 'ConversationSession': | |
"""从文件加载辩论记录""" | |
with open(file_path, 'r', encoding='utf-8') as f: | |
data = json.load(f) | |
debate_info = data['debate_info'] | |
session = cls( | |
topic=debate_info['topic'], | |
max_rounds=debate_info['max_rounds'], | |
pro_model=debate_info.get('pro_model', 'glm45'), | |
con_model=debate_info.get('con_model', 'deepseek_v31') | |
) | |
session.debate_id = debate_info['debate_id'] | |
session.current_round = debate_info['current_round'] | |
session.is_active = debate_info['is_active'] | |
session.start_time = datetime.fromisoformat(debate_info['start_time']) if debate_info['start_time'] else None | |
session.end_time = datetime.fromisoformat(debate_info['end_time']) if debate_info['end_time'] else None | |
session.messages = [ConversationMessage.from_dict(msg_data) for msg_data in data['messages']] | |
logger.info(f"从文件加载对话记录: {file_path}") | |
return session | |
def generate_prompt(self, speaker_name: str) -> str: | |
""" | |
根据对话历史和当前发言者生成提示 | |
Args: | |
speaker_name: 当前发言者名称 | |
Returns: | |
生成的提示字符串 | |
""" | |
is_pro = self.pro_model == speaker_name | |
role = "正方" if is_pro else "反方" | |
opponent_role = "反方" if is_pro else "正方" | |
# 计算轮次元信息 | |
# 首轮之前 current_round 可能为 0,此时视为第1轮;剩余轮数按“包含本轮”计算 | |
current_round_display = self.current_round if self.current_round and self.current_round > 0 else 1 | |
remaining_rounds = max(0, self.max_rounds - current_round_display + 1) | |
meta_info = ( | |
f"[对局信息]总轮数: {self.max_rounds};当前轮次: 第{current_round_display}轮;剩余轮数: {remaining_rounds}。\n" | |
f"[表达建议]无需模仿对方的结构或句式,请按你认为最有效的方式组织内容,保持清晰与力度。\n" | |
f"[风格限制]避免使用模板化句式开头(例如:“对方始终…/对方陷入…”);同一场对话中每轮至少更换一种表达结构,使用不同的连接词与论证展开;若指出对方问题,请给出“引用原句→具体证据→逻辑推导→结论”的完整链条,避免空泛断言。\n" | |
) | |
# 检查是否有自定义初始提示 | |
if self.initial_prompt: | |
if self.initial_prompt_mode == 'override': | |
# 如果是覆盖模式,且是第一轮的第一个发言者,直接使用自定义提示 | |
if len(self.messages) == 0: | |
return self.initial_prompt | |
else: # append 模式 | |
custom_prompt_part = f"\n\n此外,请务必遵守以下额外指示:{self.initial_prompt}" | |
else: | |
custom_prompt_part = "" | |
# 第一轮的第一个发言者(正方) | |
if len(self.messages) == 0: | |
prompt = ( | |
f"{meta_info}" | |
f"你将参与一场关于“{self.topic}”的{'辩论' if self.mode == 'debate' else '协作讨论'}。" | |
f"你是{role if self.mode == 'debate' else ('AI 1' if is_pro else 'AI 2')}," | |
f"你的任务是清晰、有力地阐述和捍卫你的立场(或提出高质量的协作建议)。请提出你的主要论点和论据。请直接开始你的陈述。" | |
) | |
return prompt + custom_prompt_part | |
# 后续发言 | |
last_message = self.messages[-1] | |
opponent_statement = last_message.content | |
if self.mode == 'debate': | |
prompt = ( | |
f"{meta_info}" | |
f"现在轮到你发言。你是{role}。请仔细阅读{opponent_role}({last_message.role})的上一轮发言," | |
f"然后提出简洁、有力的反驳,并进一步强化你自己的观点。不要说任何无关的话,直接开始你的陈述。\n\n" | |
f"**{opponent_role}的发言**:\n{opponent_statement}" | |
) | |
else: # discussion | |
prompt = ( | |
f"{meta_info}" | |
f"现在轮到你发言。你是{'AI 1' if is_pro else 'AI 2'}。请基于{'AI 2' if is_pro else 'AI 1'}({last_message.role})的发言," | |
f"继续就“{self.topic}”这个任务进行协作,提出你的想法和建议,共同推动讨论走向深入。\n\n" | |
f"**上一位参与者的发言**:\n{opponent_statement}" | |
) | |
return prompt + custom_prompt_part | |
def _generate_debate_prompt(self, speaker_name: str) -> str: | |
"""为辩论模式生成提示""" | |
is_positive_side = (speaker_name == self.pro_model) | |
role = "正方" if is_positive_side else "反方" | |
if not self.messages: | |
base_prompt = f"你将作为{role},就以下话题进行辩论:{self.topic}。请提出你的主要论点,陈述你的核心立场和关键论据。" | |
if self.initial_prompt: | |
if self.initial_prompt_mode == 'override': | |
return self.initial_prompt | |
else: | |
return f"{base_prompt}\n\n另外,请特别注意以下指示:{self.initial_prompt}" | |
return base_prompt | |
else: | |
last_message = self.messages[-1] | |
opponent_statement = last_message.content | |
if len(self.messages) == 1: | |
return f"你将作为{role},就以下话题进行辩论:{self.topic}。你的对手({'正方' if not is_positive_side else '反方'})的开场陈述是:\n\n“{opponent_statement}”\n\n请直接反驳对方的观点,并提出你自己的论点。" | |
else: | |
return f"现在轮到你({role})发言。你的对手刚刚的发言是:\n\n“{opponent_statement}”\n\n请针对他的观点进行反驳,并进一步阐述和强化你自己的立场。" | |
def _generate_discussion_prompt(self, speaker_name: str) -> str: | |
"""为协作讨论模式生成提示""" | |
is_ai1 = speaker_name == self.pro_model | |
ai_role = "AI 1" if is_ai1 else "AI 2" | |
partner_role = "AI 2" if is_ai1 else "AI 1" | |
if not self.messages: | |
base_prompt = f"你将作为 {ai_role},与 {partner_role} 一同协作,探讨如何完成以下任务:'{self.topic}'。请提出你的初步想法、策略或行动计划。" | |
if self.initial_prompt: | |
if self.initial_prompt_mode == 'override': | |
return self.initial_prompt | |
else: | |
return f"{base_prompt}\n\n另外,请特别注意以下指示:{self.initial_prompt}" | |
return base_prompt | |
else: | |
last_message = self.messages[-1] | |
partner_statement = last_message.content | |
return f"现在轮到你 ({ai_role}) 发言。你的合作伙伴 ({partner_role}) 刚刚提出的想法是:\n\n“{partner_statement}”\n\n请基于他的观点进行补充、提出不同角度的看法,或者共同推进任务 '{self.topic}' 的下一步。" | |
class DebateController: | |
"""辩论控制器类""" | |
def __init__(self, model_manager, output_callback: Optional[Callable] = None): | |
""" | |
初始化辩论控制器 | |
Args: | |
model_manager: 模型管理器实例 | |
output_callback: 输出回调函数,用于实时显示辩论内容 | |
""" | |
self.model_manager = model_manager | |
self.output_callback = output_callback | |
self.current_session: Optional[ConversationSession] = None | |
self.debate_thread: Optional[threading.Thread] = None | |
self.stop_event = threading.Event() | |
logger.info("辩论控制器初始化完成") | |
def create_debate(self, topic: str, max_rounds: int = 5, first_model: str = 'glm45') -> ConversationSession: | |
""" | |
创建新的辩论会话 | |
Args: | |
topic: 辩论话题 | |
max_rounds: 最大轮数 | |
first_model: 首发模型 | |
Returns: | |
辩论会话实例 | |
""" | |
self.current_session = ConversationSession(topic, max_rounds, first_model) | |
logger.info(f"创建新辩论会话,话题: {topic}") | |
return self.current_session | |
def start_debate(self): | |
"""开始辩论""" | |
if not self.current_session: | |
logger.error("没有活动的辩论会话") | |
return | |
if self.current_session.is_active: | |
logger.warning("辩论已经在进行中") | |
return | |
self.current_session.is_active = True | |
self.current_session.start_time = datetime.now() | |
self.stop_event.clear() | |
# 启动辩论线程 | |
self.debate_thread = threading.Thread(target=self._debate_loop) | |
self.debate_thread.daemon = True | |
self.debate_thread.start() | |
logger.info("辩论开始") | |
def stop_debate(self): | |
"""停止辩论""" | |
if not self.current_session or not self.current_session.is_active: | |
logger.warning("没有活动的辩论会话") | |
return | |
self.stop_event.set() | |
self.current_session.is_active = False | |
self.current_session.end_time = datetime.now() | |
if self.debate_thread and self.debate_thread.is_alive(): | |
self.debate_thread.join(timeout=5) | |
# 保存辩论记录 | |
self._save_debate_record() | |
logger.info("辩论已停止") | |
def _debate_loop(self): | |
"""辩论主循环""" | |
session = self.current_session | |
if not session: | |
return | |
# 确定模型顺序 | |
model_a = session.pro_model | |
model_b = session.con_model | |
# 获取模型接口 | |
model_a_interface = self.model_manager.get_model(model_a) | |
model_b_interface = self.model_manager.get_model(model_b) | |
# 构建初始提示 | |
initial_prompt_a = f"你将作为正方,就以下话题进行辩论:{session.topic}。请提出你的主要论点。" | |
initial_prompt_b = f"你将作为反方,就以下话题进行辩论:{session.topic}。请针对对方的论点进行反驳。" | |
# 第一轮:模型A提出论点 | |
if not self._should_continue_debate(session): | |
return | |
self._output_message(f"=== 辩论开始 ===\n话题: {session.topic}\n") | |
self._output_message(f"--- 第1轮 ---\n") | |
# 模型A发言 | |
self._output_message(f"{model_a} (正方): ") | |
response_a = model_a_interface.chat([{"role": "user", "content": initial_prompt_a}]) | |
self._output_message(f"{response_a}\n\n") | |
# 记录消息 | |
session.add_message(ConversationMessage("user", initial_prompt_a, "system")) | |
session.add_message(ConversationMessage("assistant", response_a, model_a)) | |
session.current_round = 1 | |
# 后续轮次 | |
while self._should_continue_debate(session): | |
session.current_round += 1 | |
self._output_message(f"--- 第{session.current_round}轮 ---\n") | |
# 模型B反驳 | |
self._output_message(f"{model_b} (反方): ") | |
prompt_b = f"{initial_prompt_b}\n\n对方的论点: {response_a}\n\n请进行反驳。" | |
response_b = model_b_interface.chat(session.get_messages_for_model(model_b) + [{"role": "user", "content": prompt_b}]) | |
self._output_message(f"{response_b}\n\n") | |
# 记录消息 | |
session.add_message(ConversationMessage("user", prompt_b, "system")) | |
session.add_message(ConversationMessage("assistant", response_b, model_b)) | |
# 检查是否应该继续 | |
if not self._should_continue_debate(session): | |
break | |
# 模型A回应 | |
self._output_message(f"{model_a} (正方): ") | |
prompt_a = f"请针对对方的反驳进行回应:{response_b}" | |
response_a = model_a_interface.chat(session.get_messages_for_model(model_a) + [{"role": "user", "content": prompt_a}]) | |
self._output_message(f"{response_a}\n\n") | |
# 记录消息 | |
session.add_message(ConversationMessage("user", prompt_a, "system")) | |
session.add_message(ConversationMessage("assistant", response_a, model_a)) | |
# 辩论结束 | |
session.is_active = False | |
session.end_time = datetime.now() | |
self._output_message("=== 辩论结束 ===\n") | |
# 保存辩论记录 | |
self._save_debate_record() | |
def _should_continue_debate(self, session: ConversationSession) -> bool: | |
"""检查是否应该继续辩论""" | |
if self.stop_event.is_set(): | |
return False | |
if session.current_round >= session.max_rounds: | |
return False | |
return True | |
def _output_message(self, message: str): | |
"""输出消息""" | |
if self.output_callback: | |
self.output_callback(message) | |
else: | |
print(message, end='', flush=True) | |
def _save_debate_record(self): | |
"""保存辩论记录""" | |
if not self.current_session: | |
return | |
# 创建输出文件路径 | |
output_dir = os.path.join(OUTPUT_DIR, "辩论记录") | |
os.makedirs(output_dir, exist_ok=True) | |
file_path = os.path.join(output_dir, f"{self.current_session.debate_id}.json") | |
self.current_session.save_to_file(file_path) | |
logger.info(f"辩论记录已保存: {file_path}") | |
# 测试代码 | |
if __name__ == "__main__": | |
# 简单测试 | |
import sys | |
import os | |
import importlib.util | |
# 动态导入模块 | |
module_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "2_模型接口.py") | |
spec = importlib.util.spec_from_file_location("model_interface", module_path) | |
model_interface = importlib.util.module_from_spec(spec) | |
spec.loader.exec_module(model_interface) | |
# 创建模型管理器 | |
api_key = "ms-b4690538-3224-493a-8f5b-4073d527f788" | |
model_manager = model_interface.ModelManager(api_key) | |
# 创建辩论控制器 | |
def print_callback(message): | |
print(message, end='', flush=True) | |
controller = DebateController(model_manager, print_callback) | |
# 创建辩论会话 | |
session = controller.create_debate("人工智能是否会取代人类的工作", max_rounds=3, first_model='glm45') | |
# 开始辩论 | |
controller.start_debate() | |
# 等待辩论结束 | |
if controller.debate_thread: | |
controller.debate_thread.join() |