Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
AI大模型辩论系统Web版本 | |
基于FastAPI的Web应用,提供图形用户界面的辩论系统 | |
""" | |
import os | |
import sys | |
import json | |
import logging | |
from datetime import datetime, timezone, timedelta | |
from typing import Optional, Dict, Any | |
import importlib.util | |
import asyncio | |
import socket | |
import uuid | |
import math | |
import re | |
from fastapi import Query | |
from fastapi.responses import HTMLResponse, StreamingResponse | |
from urllib.parse import quote # 新增:导入URL编码工具 | |
# 在代码开头强制设置终端编码为UTF-8(仅在Windows执行) | |
if os.name == 'nt': | |
os.system('chcp 65001 > nul') | |
# 获取当前脚本文件所在目录的绝对路径 | |
current_script_dir = os.path.dirname(os.path.abspath(__file__)) | |
# 项目根目录 (即 '20250907_大模型辩论' 目录, 是 'src' 的上一级) | |
project_root = os.path.dirname(current_script_dir) | |
# 定义数据输入、输出和日志目录 | |
DATA_DIR = os.path.join(project_root, 'data') | |
OUTPUT_DIR = os.path.join(project_root, 'output') | |
LOGS_DIR = os.path.join(project_root, 'logs') | |
# 确保目录存在 | |
os.makedirs(OUTPUT_DIR, exist_ok=True) | |
os.makedirs(LOGS_DIR, exist_ok=True) | |
# 配置日志 | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
handlers=[ | |
logging.FileHandler(os.path.join(LOGS_DIR, '对话系统Web日志.log'), encoding='utf-8'), | |
logging.StreamHandler(sys.stdout) | |
] | |
) | |
logger = logging.getLogger(__name__) | |
# 捕获警告并记录到日志 | |
logging.captureWarnings(True) | |
# 导入FastAPI相关模块 | |
try: | |
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
import uvicorn | |
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware | |
logger.info("FastAPI模块导入成功") | |
except ImportError as e: | |
logger.error(f"导入FastAPI模块失败: {str(e)}") | |
print("请安装FastAPI: pip install fastapi uvicorn") | |
sys.exit(1) | |
# 导入自定义模块 | |
try: | |
# 动态导入模型接口模块 | |
model_interface_path = os.path.join(current_script_dir, "model_interface.py") | |
model_interface_spec = importlib.util.spec_from_file_location("model_interface", model_interface_path) | |
model_interface = importlib.util.module_from_spec(model_interface_spec) | |
model_interface_spec.loader.exec_module(model_interface) | |
# 动态导入对话控制器模块 | |
debate_controller_path = os.path.join(current_script_dir, "debate_controller.py") | |
debate_controller_spec = importlib.util.spec_from_file_location("debate_controller", debate_controller_path) | |
debate_controller = importlib.util.module_from_spec(debate_controller_spec) | |
debate_controller_spec.loader.exec_module(debate_controller) | |
# 从模块中获取需要的类 | |
ModelManager = model_interface.ModelManager | |
ConversationMessage = debate_controller.ConversationMessage | |
ConversationSession = debate_controller.ConversationSession | |
except Exception as e: | |
logger.error(f"导入模块失败: {str(e)}") | |
sys.exit(1) | |
# 创建FastAPI应用 | |
app = FastAPI(title="AI大模型对话系统", description="基于FastAPI的AI大模型对话系统Web版本") | |
# 添加中间件以处理反向代理头信息 | |
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") | |
# 设置静态文件目录 | |
static_dir = os.path.join(project_root, "static") | |
app.mount("/static", StaticFiles(directory=static_dir), name="static") | |
# 设置模板目录 | |
templates_dir = os.path.join(project_root, "templates") | |
templates = Jinja2Templates(directory=templates_dir) | |
# 补回:连接管理器 | |
class ConnectionManager: | |
"""连接管理器,负责管理所有WebSocket连接及其状态""" | |
def __init__(self): | |
self.active_connections: Dict[WebSocket, Dict[str, Any]] = {} | |
async def connect(self, websocket: WebSocket): | |
await websocket.accept() | |
self.active_connections[websocket] = {"session": None, "task": None, "judge_task": None, "conv_id": None, "judge_id": None} | |
logger.info(f"新连接建立: {websocket.client}. 当前总连接数: {len(self.active_connections)}") | |
def disconnect(self, websocket: WebSocket): | |
if websocket in self.active_connections: | |
task = self.active_connections[websocket].get("task") | |
if task and not task.done(): | |
task.cancel() | |
logger.info(f"连接 {websocket.client} 的对话任务已被取消。") | |
jtask = self.active_connections[websocket].get("judge_task") | |
if jtask and not jtask.done(): | |
jtask.cancel() | |
logger.info(f"连接 {websocket.client} 的评判任务已被取消。") | |
del self.active_connections[websocket] | |
logger.info(f"连接断开: {websocket.client}. 当前总连接数: {len(self.active_connections)}") | |
def get_session(self, websocket: WebSocket) -> Optional[ConversationSession]: | |
return self.active_connections.get(websocket, {}).get("session") | |
def get_task(self, websocket: WebSocket) -> Optional[asyncio.Task]: | |
return self.active_connections.get(websocket, {}).get("task") | |
def get_judge_task(self, websocket: WebSocket) -> Optional[asyncio.Task]: | |
return self.active_connections.get(websocket, {}).get("judge_task") | |
def set_conv_id(self, websocket: WebSocket, conv_id: str): | |
if websocket in self.active_connections: | |
self.active_connections[websocket]["conv_id"] = conv_id | |
def get_conv_id(self, websocket: WebSocket) -> Optional[str]: | |
return self.active_connections.get(websocket, {}).get("conv_id") | |
def set_judge_id(self, websocket: WebSocket, judge_id: str): | |
if websocket in self.active_connections: | |
self.active_connections[websocket]["judge_id"] = judge_id | |
def get_current_judge_id(self, websocket: WebSocket) -> Optional[str]: | |
return self.active_connections.get(websocket, {}).get("judge_id") | |
def set_conversation(self, websocket: WebSocket, session: ConversationSession, task: asyncio.Task): | |
if websocket in self.active_connections: | |
self.active_connections[websocket]["session"] = session | |
self.active_connections[websocket]["task"] = task | |
def set_judge_task(self, websocket: WebSocket, task: asyncio.Task): | |
if websocket in self.active_connections: | |
self.active_connections[websocket]["judge_task"] = task | |
def clear_conversation(self, websocket: WebSocket): | |
if websocket in self.active_connections: | |
self.active_connections[websocket]["session"] = None | |
self.active_connections[websocket]["task"] = None | |
def cancel_all(self, websocket: WebSocket): | |
if websocket in self.active_connections: | |
t = self.active_connections[websocket].get("task") | |
if t and not t.done(): | |
t.cancel() | |
jt = self.active_connections[websocket].get("judge_task") | |
if jt and not jt.done(): | |
jt.cancel() | |
# 刷新流ID,确保旧任务输出被丢弃 | |
self.active_connections[websocket]["conv_id"] = str(uuid.uuid4()) | |
self.active_connections[websocket]["judge_id"] = str(uuid.uuid4()) | |
def _get_api_key_from_env() -> str: | |
"""从多种环境变量名中获取API Key""" | |
candidate_keys = [ | |
"MODELSCOPE_API_KEY", "MODELSCOPE_TOKEN", "MS_API_KEY", | |
"MS_TOKEN", "API_KEY" | |
] | |
for k in candidate_keys: | |
v = os.environ.get(k) | |
if v: | |
logger.info(f"已从环境变量 {k} 读取到API Key(不显示具体值)") | |
return v | |
logger.warning("未在环境变量中找到ModelScope API Key,请在Space Secrets中设置 MODELSCOPE_API_KEY") | |
return "" | |
# 实例化连接管理器和模型管理器 | |
manager = ConnectionManager() | |
model_manager = ModelManager(_get_api_key_from_env()) | |
# 新增:内存中的会话备份,避免必须写磁盘 | |
recent_sessions: Dict[str, ConversationSession] = {} | |
# 新增:内存缓存最近一次评判结果 | |
recent_judges: Dict[str, str] = {} | |
async def startup_event(): | |
"""应用启动时执行的事件""" | |
os.makedirs(static_dir, exist_ok=True) | |
os.makedirs(templates_dir, exist_ok=True) | |
create_templates() | |
async def read_root(request: Request): | |
"""主页路由,返回Web界面""" | |
return templates.TemplateResponse("index.html", { | |
"request": request, | |
"title": "AI大模型对话系统" | |
}) | |
async def get_status(): | |
"""获取系统状态""" | |
return { | |
"status": "running", | |
"timestamp": datetime.now().isoformat(), | |
"models": ["glm45", "deepseek_v31", "qwen", "qwen_instruct"], | |
"active_connections": len(manager.active_connections), | |
"has_api_key": bool(_get_api_key_from_env()) | |
} | |
async def websocket_endpoint(websocket: WebSocket): | |
"""WebSocket端点,用于实时通信""" | |
await manager.connect(websocket) | |
try: | |
while True: | |
data = await websocket.receive_text() | |
await handle_websocket_message(websocket, data) | |
except WebSocketDisconnect: | |
manager.disconnect(websocket) | |
except Exception as e: | |
logger.error(f"WebSocket处理错误: {str(e)}", exc_info=True) | |
manager.disconnect(websocket) | |
async def handle_websocket_message(websocket: WebSocket, message: str): | |
"""处理WebSocket消息""" | |
try: | |
data = json.loads(message) | |
action = data.get("action") | |
if action == "start_conversation": | |
# 强行停止正在进行的对话与评判 | |
manager.cancel_all(websocket) | |
# 启动新对话(后台任务) | |
new_task = asyncio.create_task(start_conversation(websocket, data)) | |
# 任务将在 start_conversation 内部注册到 manager 中 | |
elif action == "stop_conversation": | |
await stop_conversation(websocket) | |
elif action == "judge_debate": | |
# 评判改为后台任务,并记录以便可取消 | |
jtask = asyncio.create_task(judge_debate(websocket, data)) | |
manager.set_judge_task(websocket, jtask) | |
elif action == "summarize_collaboration": | |
# 协作总结改为后台任务,并记录以便可取消 | |
stask = asyncio.create_task(summarize_collaboration(websocket, data)) | |
manager.set_judge_task(websocket, stask) | |
else: | |
await websocket.send_text(json.dumps({ | |
"type": "error", | |
"message": f"未知操作: {action}" | |
})) | |
except json.JSONDecodeError: | |
await websocket.send_text(json.dumps({ | |
"type": "error", | |
"message": "无效的JSON格式" | |
})) | |
except Exception as e: | |
logger.error(f"处理WebSocket消息时出错: {str(e)}", exc_info=True) | |
await websocket.send_text(json.dumps({ | |
"type": "error", | |
"message": f"处理消息时出错: {str(e)}" | |
})) | |
async def start_conversation(websocket: WebSocket, data: dict): | |
"""开始一场独立的对话""" | |
loop = asyncio.get_event_loop() | |
session = None | |
try: | |
topic = data.get("topic", "真与善谁更重要?") | |
mode = data.get("mode", "debate") | |
rounds = int(data.get("rounds", 3)) | |
pro_model_name = data.get("pro_model", "deepseek_v31") | |
con_model_name = data.get("con_model", "qwen_instruct") | |
initial_prompt = data.get("initial_prompt", "").strip() | |
initial_prompt_mode = data.get("initial_prompt_mode", "append") | |
save_enabled = bool(data.get("save_records", False)) | |
session = ConversationSession( | |
topic=topic, mode=mode, max_rounds=rounds, | |
pro_model=pro_model_name, con_model=con_model_name, | |
initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode | |
) | |
# 将新创建的session和当前任务关联到这个websocket连接 | |
manager.set_conversation(websocket, session, asyncio.current_task()) | |
# 刷新对话流ID | |
conv_stream_id = str(uuid.uuid4()) | |
manager.set_conv_id(websocket, conv_stream_id) | |
await websocket.send_text(json.dumps({ | |
"type": "conversation_started", "message": "对话已开始", | |
"topic": topic, "mode": mode, "rounds": rounds, | |
"pro_model": pro_model_name, "con_model": con_model_name, | |
"debate_id": session.debate_id | |
})) | |
model_a = model_manager.get_model(pro_model_name) | |
model_b = model_manager.get_model(con_model_name) | |
session.start_time = datetime.now() | |
speakers = [(pro_model_name, model_a), (con_model_name, model_b)] | |
for i in range(rounds * 2): | |
round_num = (i // 2) + 1 | |
if i % 2 == 0: | |
session.current_round = round_num | |
await websocket.send_text(json.dumps({ "type": "round_info", "message": f"—— 第({round_num}/{rounds})轮 ——" })) | |
speaker_name, speaker_model = speakers[i % 2] | |
is_pro = (i % 2 == 0) | |
role = ("正方" if mode == 'debate' else "AI 1") if is_pro else ("反方" if mode == 'debate' else "AI 2") | |
await websocket.send_text(json.dumps({ "type": "model_speaking", "model": speaker_name, "role": role })) | |
prompt = session.generate_prompt(speaker_name) | |
response_content = "" | |
def stream_callback(content): | |
nonlocal response_content | |
# 丢弃旧流输出 | |
if manager.get_conv_id(websocket) != conv_stream_id: | |
return | |
response_content += content | |
asyncio.run_coroutine_threadsafe( | |
websocket.send_text(json.dumps({"type": "stream_content", "content": content})), loop | |
) | |
await loop.run_in_executor(None, speaker_model.chat_stream, session.get_messages_for_model(speaker_name) + [{"role": "user", "content": prompt}], stream_callback) | |
await websocket.send_text(json.dumps({"type": "stream_end", "model": speaker_name})) | |
session.add_message(ConversationMessage("user", prompt, "system")) | |
session.add_message(ConversationMessage("assistant", response_content, speaker_name)) | |
save_conversation_record(session, save_enabled) | |
# 新增:实时更新内存备份 | |
recent_sessions[session.debate_id] = session | |
if session.is_active: | |
session.end_time = datetime.now() | |
session.is_active = False | |
await websocket.send_text(json.dumps({ "type": "conversation_ended", "message": "=== 对话结束 ===" })) | |
save_conversation_record(session, save_enabled) | |
# 新增:结束时更新内存备份 | |
recent_sessions[session.debate_id] = session | |
logger.info(f"对话 {session.debate_id} 正常结束。") | |
except asyncio.CancelledError: | |
if session: | |
logger.info(f"对话任务 {session.debate_id} 被取消。") | |
session.is_active = False | |
session.end_time = datetime.now() | |
save_conversation_record(session, bool(data.get("save_records", False))) | |
# 新增:取消时也更新内存备份 | |
recent_sessions[session.debate_id] = session | |
await websocket.send_text(json.dumps({ | |
"type": "conversation_stopped", "message": "对话已停止" | |
})) | |
raise | |
except Exception as e: | |
logger.error(f"对话过程中出错: {str(e)}", exc_info=True) | |
await websocket.send_text(json.dumps({ "type": "error", "message": f"对话过程中出错: {str(e)}" })) | |
finally: | |
manager.clear_conversation(websocket) | |
logger.info("连接的对话会话清理完毕。") | |
async def judge_debate(websocket: WebSocket, data: dict): | |
"""评判一场指定的对话""" | |
loop = asyncio.get_event_loop() | |
try: | |
judge_model_name = data.get("judge_model", "qwen_instruct") | |
debate_id = data.get("debate_id") | |
save_enabled = bool(data.get("save_records", False)) | |
if not debate_id: | |
await websocket.send_text(json.dumps({"type": "error", "message": "缺少对话ID无法评判。"})) | |
return | |
file_path = os.path.join(OUTPUT_DIR, "对话记录", f"{debate_id}.json") | |
session_to_judge = None | |
if os.path.exists(file_path): | |
session_to_judge = ConversationSession.load_from_file(file_path) | |
else: | |
# 新增:优先使用内存备份,避免必须写磁盘 | |
session_to_judge = recent_sessions.get(debate_id) | |
if not session_to_judge: | |
await websocket.send_text(json.dumps({"type": "error", "message": f"找不到对话记录: {debate_id}(内存与磁盘均不存在)"})) | |
return # 修复:此 return 必须在 if 块内部 | |
judge_model = model_manager.get_model(judge_model_name) | |
assistant_messages = [msg for msg in session_to_judge.messages if msg.role == 'assistant'] | |
# 构造带轮次的实录 | |
total_msgs = len(assistant_messages) | |
actual_rounds = math.ceil(total_msgs / 2) if total_msgs > 0 else 0 | |
transcript_parts = [] | |
total_rounds = session_to_judge.max_rounds | |
for r in range(actual_rounds): | |
transcript_parts.append(f"—— 第({r+1}/{total_rounds})轮 ——") | |
pro_idx = r * 2 | |
con_idx = r * 2 + 1 | |
if pro_idx < total_msgs: | |
pro_model = session_to_judge.pro_model | |
transcript_parts.append(f"正方 ({pro_model}): {assistant_messages[pro_idx].content}") | |
if con_idx < total_msgs: | |
con_model = session_to_judge.con_model | |
transcript_parts.append(f"反方 ({con_model}): {assistant_messages[con_idx].content}") | |
debate_transcript = "\n\n".join(transcript_parts) | |
if session_to_judge.mode == 'debate': | |
judge_prompt = ( | |
f"你是一位专业的辩论评审。请基于以下对话实录,对双方的辩论表现进行专业评判。\n\n" | |
f"对话模式:辩论\n" | |
f"辩论话题:{session_to_judge.topic}\n" | |
f"正方(AI 1):{session_to_judge.pro_model}\n" | |
f"反方(AI 2):{session_to_judge.con_model}\n\n" | |
f"实际轮次:{actual_rounds}(禁止杜撰额外轮次)\n\n" | |
f"对话实录:\n{debate_transcript}\n\n" | |
f"请严格按照以下步骤进行分析并输出:\n" | |
f"1. **结论先行(1-2句)**:直接指出哪一方更胜一筹,并给出最关键的1-2条理由。\n" | |
f"2. **维度对比表(Markdown 表格)**:从至少六个维度对双方评分/评述:立场清晰度、论据扎实度、反驳力度、逻辑结构、证据引用/事实性、聚焦度(针对性)。最后一列写明该维度的优势方。\n" | |
f"3. **证据引用**:逐点引用原文并标注\"第X轮-正/反\"来支撑判定,禁止引用不存在的轮次。\n" | |
f"4. **改进建议**:分别给正反双方各2-3条可执行的改进建议。" | |
) | |
else: | |
judge_prompt = ( | |
f"你是一位专业的协作评审。请基于以下对话实录,评估两位 AI 的协作质量与产出。\n\n" | |
f"对话模式:协作讨论\n" | |
f"协作任务:{session_to_judge.topic}\n" | |
f"AI 1:{session_to_judge.pro_model}\n" | |
f"AI 2:{session_to_judge.con_model}\n\n" | |
f"实际轮次:{actual_rounds}(禁止杜撰额外轮次)\n\n" | |
f"对话实录:\n{debate_transcript}\n\n" | |
f"请严格按照以下步骤进行分析并输出:\n" | |
f"1. **总体评估(1-2句)**:先给出任务完成度与协作有效性的总体判断。\n" | |
f"2. **协作维度表(Markdown 表格)**:从至少六个维度评述:目标对齐、信息共享/互补、方案可行性、风险识别、推进计划(时间/里程碑)、个人贡献度。最后一列说明哪一方贡献更关键。\n" | |
f"3. **行动计划**:给出一份精炼的下一步行动清单(里程碑+负责人+时间节点)。引用原文时请标注\"第X轮-参与者\"。\n" | |
f"4. **改进建议**:指出影响协作效率的关键瓶颈,并给出2-3条可落地的改进建议。" | |
) | |
# 刷新评判流ID | |
judge_stream_id = str(uuid.uuid4()) | |
manager.set_judge_id(websocket, judge_stream_id) | |
await websocket.send_text(json.dumps({"type": "judge_started", "model": judge_model_name})) | |
response_content = "" | |
def stream_callback(content): | |
nonlocal response_content | |
# 丢弃旧流输出 | |
if manager.get_current_judge_id(websocket) != judge_stream_id: | |
return | |
response_content += content | |
asyncio.run_coroutine_threadsafe( | |
websocket.send_text(json.dumps({"type": "judge_stream_content", "content": content})), loop | |
) | |
await loop.run_in_executor(None, judge_model.chat_stream, [{"role": "user", "content": judge_prompt}], stream_callback) | |
await websocket.send_text(json.dumps({"type": "judge_stream_end", "model": judge_model_name})) | |
logger.info(f"模型 {judge_model_name} 已完成评判。") | |
# 可选:保存评判结果 | |
if save_enabled: | |
try: | |
judge_dir = os.path.join(OUTPUT_DIR, "评判记录") | |
os.makedirs(judge_dir, exist_ok=True) | |
judge_file = os.path.join(judge_dir, f"{debate_id}_judge_{judge_model_name}.md") | |
with open(judge_file, "w", encoding="utf-8") as jf: | |
jf.write(f"# 评判结果\n\n对话ID: {debate_id}\n\n评判模型: {judge_model_name}\n\n---\n\n") | |
jf.write(response_content) | |
logger.info(f"评判结果已保存: {judge_file}") | |
except Exception as e: | |
logger.error(f"保存评判结果失败: {str(e)}") | |
# 新增:写入内存缓存,便于导出 | |
recent_judges[debate_id] = response_content | |
except Exception as e: | |
logger.error(f"评判过程中出错: {str(e)}", exc_info=True) | |
await websocket.send_text(json.dumps({"type": "error", "message": f"评判过程中出错: {str(e)}"})) | |
async def summarize_collaboration(websocket: WebSocket, data: dict): | |
"""总结协作任务的对话内容""" | |
loop = asyncio.get_event_loop() | |
try: | |
summary_model_name = data.get("summary_model", "qwen_instruct") | |
debate_id = data.get("debate_id") | |
save_enabled = bool(data.get("save_records", False)) | |
if not debate_id: | |
await websocket.send_text(json.dumps({"type": "error", "message": "缺少对话ID无法总结。"})) | |
return | |
file_path = os.path.join(OUTPUT_DIR, "对话记录", f"{debate_id}.json") | |
session_to_summarize = None | |
if os.path.exists(file_path): | |
session_to_summarize = ConversationSession.load_from_file(file_path) | |
else: | |
# 优先使用内存备份,避免必须写磁盘 | |
session_to_summarize = recent_sessions.get(debate_id) | |
if not session_to_summarize: | |
await websocket.send_text(json.dumps({"type": "error", "message": f"找不到对话记录: {debate_id}(内存与磁盘均不存在)"})) | |
return | |
summary_model = model_manager.get_model(summary_model_name) | |
assistant_messages = [msg for msg in session_to_summarize.messages if msg.role == 'assistant'] | |
# 构造带轮次的实录 | |
total_msgs = len(assistant_messages) | |
actual_rounds = math.ceil(total_msgs / 2) if total_msgs > 0 else 0 | |
transcript_parts = [] | |
total_rounds = session_to_summarize.max_rounds | |
for r in range(actual_rounds): | |
transcript_parts.append(f"—— 第({r+1}/{total_rounds})轮 ——") | |
pro_idx = r * 2 | |
con_idx = r * 2 + 1 | |
if pro_idx < total_msgs: | |
pro_model = session_to_summarize.pro_model | |
transcript_parts.append(f"AI 1 ({pro_model}): {assistant_messages[pro_idx].content}") | |
if con_idx < total_msgs: | |
con_model = session_to_summarize.con_model | |
transcript_parts.append(f"AI 2 ({con_model}): {assistant_messages[con_idx].content}") | |
collaboration_transcript = "\n\n".join(transcript_parts) | |
summary_prompt = ( | |
f"你是一位专业的会议记录员和内容总结专家。请基于以下协作对话实录,对两位AI的讨论内容进行全面总结。\n\n" | |
f"对话模式:协作讨论\n" | |
f"协作任务:{session_to_summarize.topic}\n" | |
f"AI 1:{session_to_summarize.pro_model}\n" | |
f"AI 2:{session_to_summarize.con_model}\n\n" | |
f"实际轮次:{actual_rounds}(禁止杜撰额外轮次)\n\n" | |
f"对话实录:\n{collaboration_transcript}\n\n" | |
f"请严格按照以下步骤进行分析并输出:\n" | |
f"1. **任务概述**:简要描述协作任务的目标和背景。\n" | |
f"2. **主要观点**:总结两位AI在讨论中提出的主要观点和想法,按主题分类。\n" | |
f"3. **达成共识**:列出两位AI在讨论中达成的一致意见和共识。\n" | |
f"4. **分歧点**:指出两位AI在讨论中存在的不同意见或分歧。\n" | |
f"5. **最终结论**:总结协作讨论的最终结论或成果。\n" | |
f"6. **后续建议**:提出基于当前讨论的后续行动建议或需要进一步探讨的问题。" | |
) | |
# 刷新总结流ID | |
summary_stream_id = str(uuid.uuid4()) | |
manager.set_judge_id(websocket, summary_stream_id) | |
await websocket.send_text(json.dumps({"type": "summary_started", "model": summary_model_name})) | |
response_content = "" | |
def stream_callback(content): | |
nonlocal response_content | |
# 丢弃旧流输出 | |
if manager.get_current_judge_id(websocket) != summary_stream_id: | |
return | |
response_content += content | |
asyncio.run_coroutine_threadsafe( | |
websocket.send_text(json.dumps({"type": "summary_stream_content", "content": content})), loop | |
) | |
await loop.run_in_executor(None, summary_model.chat_stream, [{"role": "user", "content": summary_prompt}], stream_callback) | |
await websocket.send_text(json.dumps({"type": "summary_stream_end", "model": summary_model_name})) | |
logger.info(f"模型 {summary_model_name} 已完成协作总结。") | |
# 可选:保存总结结果 | |
if save_enabled: | |
try: | |
summary_dir = os.path.join(OUTPUT_DIR, "总结记录") | |
os.makedirs(summary_dir, exist_ok=True) | |
summary_file = os.path.join(summary_dir, f"{debate_id}_summary_{summary_model_name}.md") | |
with open(summary_file, "w", encoding="utf-8") as sf: | |
sf.write(f"# 协作总结\n\n对话ID: {debate_id}\n\n总结模型: {summary_model_name}\n\n---\n\n") | |
sf.write(response_content) | |
logger.info(f"协作总结已保存: {summary_file}") | |
except Exception as e: | |
logger.error(f"保存协作总结失败: {str(e)}") | |
# 写入内存缓存,便于导出 | |
recent_judges[debate_id] = response_content | |
except Exception as e: | |
logger.error(f"总结过程中出错: {str(e)}", exc_info=True) | |
await websocket.send_text(json.dumps({"type": "error", "message": f"总结过程中出错: {str(e)}"})) | |
async def stop_conversation(websocket: WebSocket): | |
"""停止当前连接的对话""" | |
task = manager.get_task(websocket) | |
if task and not task.done(): | |
task.cancel() | |
logger.info("发送取消请求到对话任务。") | |
else: | |
logger.warning("请求停止对话,但没有活动的任务。") | |
await websocket.send_text(json.dumps({"type": "info", "message": "没有正在进行的对话可供停止。"})) | |
def save_conversation_record(session: ConversationSession, save_enabled: bool): | |
"""保存指定的对话记录(按需)""" | |
if not save_enabled: | |
return | |
if session: | |
try: | |
output_dir = os.path.join(OUTPUT_DIR, "对话记录") | |
os.makedirs(output_dir, exist_ok=True) | |
file_path = os.path.join(output_dir, f"{session.debate_id}.json") | |
session.save_to_file(file_path) | |
logger.info(f"对话记录已保存: {file_path}") | |
except Exception as e: | |
logger.error(f"保存对话记录时出错: {str(e)}") | |
def create_templates(): | |
"""创建HTML模板文件""" | |
template_path = os.path.join(templates_dir, "index.html") | |
# 始终覆盖模板文件,确保其与代码同步 | |
index_html = """ | |
<!DOCTYPE html> | |
<html lang="zh-CN"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>AI大模型对话系统</title> | |
<link rel="preconnect" href="https://fonts.googleapis.com"> | |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> | |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Noto+Sans+SC:wght@400;500;700&display=swap" rel="stylesheet"> | |
<link rel="stylesheet" href="{{ url_for('static', path='css/style.css') }}?v=20250908a"> | |
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script> | |
<script src="https://cdn.jsdelivr.net/npm/dompurify/dist/purify.min.js"></script> | |
</head> | |
<body> | |
<div class="container"> | |
<div class="header"> | |
<div class="header-title"> | |
<h1>AI大模型对话系统</h1> | |
<p>观看两个AI大模型实时对话</p> | |
</div> | |
<div class="header-controls"> | |
<select id="exportFormat"> | |
<option value="md" selected>Markdown (.md)</option> | |
<option value="json">JSON (.json)</option> | |
<option value="txt">文本 (.txt)</option> | |
</select> | |
<button id="exportAllBtn" disabled>导出记录</button> | |
</div> | |
</div> | |
<div class="main-layout"> | |
<div class="sidebar"> | |
<h3 class="panel-title">设置</h3> | |
<div class="control-group"> | |
<label>对话模式</label> | |
<div class="radio-group"> | |
<input type="radio" id="modeDebate" name="mode" value="debate" checked> | |
<label for="modeDebate">辩论</label> | |
<input type="radio" id="modeDiscussion" name="mode" value="discussion"> | |
<label for="modeDiscussion">协作讨论</label> | |
</div> | |
</div> | |
<div class="control-group"><label for="topic">对话任务/话题</label><input type="text" id="topic" value="真与善谁更重要?"></div> | |
<div class="control-group"> | |
<label for="initialPrompt">自定义初始提示 (可选)</label> | |
<textarea id="initialPrompt" rows="6" placeholder="默认提示示例:'你将作为正方,就[话题]进行辩论...'。你可以在此输入额外指示(默认追加),或选择覆盖默认提示。"></textarea> | |
</div> | |
<div class="control-group prompt-mode-group"> | |
<label>提示词模式</label> | |
<div class="radio-group"> | |
<input type="radio" id="promptAppend" name="promptMode" value="append" checked> | |
<label for="promptAppend">追加</label> | |
<input type="radio" id="promptOverride" name="promptMode" value="override"> | |
<label for="promptOverride">覆盖</label> | |
</div> | |
</div> | |
<div class="control-group"><label for="rounds">轮数</label><input type="number" id="rounds" min="1" max="10" value="3"></div> | |
<div class="control-group"> | |
<label for="proModel">AI 1 (正方)</label> | |
<select id="proModel"> | |
<option value="deepseek_v31" selected>deepseek-ai/DeepSeek-V3.1</option> | |
<option value="qwen_instruct">Qwen/Qwen3-235B-Instruct</option> | |
<option value="qwen">Qwen/Qwen3-235B-Thinking</option> | |
<option value="glm45">ZhipuAI/GLM-4.5</option> | |
</select> | |
</div> | |
<div class="control-group"> | |
<label for="conModel">AI 2 (反方)</label> | |
<select id="conModel"> | |
<option value="deepseek_v31">deepseek-ai/DeepSeek-V3.1</option> | |
<option value="qwen_instruct" selected>Qwen/Qwen3-235B-Instruct</option> | |
<option value="qwen">Qwen/Qwen3-235B-Thinking</option> | |
<option value="glm45">ZhipuAI/GLM-4.5</option> | |
</select> | |
</div> | |
<div class="controls"> | |
<button id="startBtn" disabled>开始对话</button> | |
<button id="stopBtn" disabled>停止对话</button> | |
</div> | |
</div> | |
<div class="chat-area"> | |
<div class="conversation-wrapper"> | |
<h3 class="panel-title">对话区</h3> | |
<div id="output" class="output-container"></div> | |
</div> | |
<div id="judge-section" class="judge-section"> | |
<h3 class="panel-title" id="judgeSectionTitle">评判区</h3> | |
<div class="judge-controls"> | |
<label for="judgeModel">选择总结模型</label> | |
<select id="judgeModel"> | |
<option value="deepseek_v31">deepseek-ai/DeepSeek-V3.1</option> | |
<option value="qwen_instruct" selected>Qwen/Qwen3-235B-Instruct</option> | |
<option value="qwen">Qwen/Qwen3-235B-Thinking</option> | |
<option value="glm45">ZhipuAI/GLM-4.5</option> | |
</select> | |
<button id="judgeBtn" disabled>评判双方辩论表现</button> | |
<button id="summaryBtn" disabled style="display:none;">总结对话</button> | |
</div> | |
<div id="judge-output" class="output-container judge-output"></div> | |
</div> | |
</div> | |
</div> | |
</div> | |
<script src="{{ url_for('static', path='js/script.js') }}?v=20250908a"></script> | |
</body> | |
</html>""" | |
with open(template_path, "w", encoding="utf-8") as f: | |
f.write(index_html) | |
logger.info("Web应用模板文件 'index.html' 已被强制更新。") | |
def get_local_ip(): | |
"""获取本机局域网IP地址""" | |
try: | |
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: | |
# 连接到一个公共DNS服务器的IP(不会真的发送数据) | |
s.connect(("8.8.8.8", 80)) | |
return s.getsockname()[0] | |
except Exception: | |
return "127.0.0.1" # 如果获取失败,返回本地回环地址 | |
if __name__ == "__main__": | |
# 本地运行时,FastAPI的启动事件也会触发,所以模板创建是安全的 | |
# 智能端口切换:优先使用环境变量PORT,否则默认为8000 | |
port = int(os.environ.get("PORT", 8000)) | |
local_ip = get_local_ip() | |
logger.info("="*50) | |
logger.info("AI大模型对话系统已启动") | |
logger.info(f" - 本机访问: http://localhost:{port}") | |
logger.info(f" - 局域网访问: http://{local_ip}:{port}") | |
logger.info("="*50) | |
uvicorn.run(app, host="0.0.0.0", port=port) | |
async def export_records(debate_id: str = Query(...), format: str = Query("md")): | |
"""导出指定对话的记录(对话+最近一次评判),不落盘,直接下载。""" | |
fmt = (format or "md").lower() | |
session = recent_sessions.get(debate_id) | |
if not session: | |
# 兜底:尝试磁盘 | |
file_path = os.path.join(OUTPUT_DIR, "对话记录", f"{debate_id}.json") | |
if os.path.exists(file_path): | |
session = ConversationSession.load_from_file(file_path) | |
if not session: | |
return HTMLResponse(content=f"<pre>找不到对话记录: {debate_id}</pre>", status_code=404) | |
judge_md = recent_judges.get(debate_id, "") | |
# 仅保留双方模型的回答(assistant),不导出提示词/用户消息 | |
assistant_messages = [m for m in session.messages if m.role == 'assistant'] | |
# 导出时间(优先用会话开始时间),统一转换为东八区,精确到分钟 | |
base_dt = session.start_time or datetime.utcnow() | |
if base_dt.tzinfo is None: | |
beijing_dt = base_dt + timedelta(hours=8) | |
else: | |
beijing_dt = base_dt.astimezone(timezone(timedelta(hours=8))) | |
ts = beijing_dt.strftime('%Y-%m-%d %H:%M (UTC+8)') | |
filename_ts = beijing_dt.strftime('%Y%m%d_%H%M%S') | |
# 清理话题作文文件名安全的部分 | |
topic = session.topic or "未命名对话" # 修复:为None或空话题提供默认值 | |
sanitized_topic = re.sub(r'[\\/*?:"<>|]', "_", topic).replace(" ", "_") | |
sanitized_topic = (sanitized_topic[:50] + '...') if len(sanitized_topic) > 50 else sanitized_topic | |
if fmt == 'json': | |
payload = { | |
'debate_id': debate_id, | |
'topic': session.topic, | |
'export_time': ts, | |
'messages': [ | |
{ | |
'role': m.role, | |
'content': m.content, | |
'model': m.model_name, | |
'round': (idx // 2) + 1, | |
'round_total': session.max_rounds, | |
'round_label': f"{(idx // 2) + 1}/{session.max_rounds}" | |
} for idx, m in enumerate(assistant_messages) | |
], | |
'judge_markdown': judge_md | |
} | |
content = json.dumps(payload, ensure_ascii=False, indent=2) | |
filename = f"{filename_ts}_{sanitized_topic}.json" | |
media = "application/json; charset=utf-8" | |
elif fmt == 'txt': | |
parts = [f"对话ID: {debate_id}", f"话题: {session.topic}", f"导出时间: {ts}", "---"] | |
total_rounds = session.max_rounds | |
for i, m in enumerate(assistant_messages): | |
round_no = (i // 2) + 1 | |
# 每轮开始插入分隔 | |
if i % 2 == 0: | |
parts.append(f"—— 第({round_no}/{total_rounds})轮 ——") | |
# 推断角色与模型 | |
is_pro = (i % 2 == 0) | |
model_name = session.pro_model if is_pro else session.con_model | |
role_cn = "正方" if is_pro else "反方" | |
parts.append(f"{role_cn} ({model_name}):\n{m.content}\n") | |
if judge_md: | |
parts.append("\n==== 评判结果 ====") | |
parts.append(md_to_text(judge_md)) | |
content = "\n".join(parts) | |
filename = f"{filename_ts}_{sanitized_topic}.txt" | |
media = "text/plain; charset=utf-8" | |
else: # md | |
parts = ["# 对话记录", f"对话ID: {debate_id}", f"话题: {session.topic}", f"导出时间: {ts}", ""] | |
total_rounds = session.max_rounds | |
for i, m in enumerate(assistant_messages): | |
round_no = (i // 2) + 1 | |
if i % 2 == 0: | |
parts.append(f"### 第({round_no}/{total_rounds})轮") | |
parts.append("") | |
is_pro = (i % 2 == 0) | |
model_name = session.pro_model if is_pro else session.con_model | |
role_cn = "正方" if is_pro else "反方" | |
parts.append(f"**{role_cn} ({model_name})**:") | |
parts.append("") | |
parts.append(m.content) | |
parts.append("") | |
if judge_md: | |
parts += ["---", "# 评判结果", "", judge_md] | |
content = "\n".join(parts) | |
filename = f"{filename_ts}_{sanitized_topic}.md" | |
media = "text/markdown; charset=utf-8" | |
encoded_filename = quote(filename) | |
headers = {"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"} | |
return StreamingResponse(iter([content.encode('utf-8')]), media_type=media, headers=headers) | |
# 简单将Markdown转为纯文本,用于TXT导出 | |
def md_to_text(md: str) -> str: | |
if not md: | |
return "" | |
text = md | |
# 代码块三反引号去除 | |
text = re.sub(r"```[\s\S]*?```", lambda m: re.sub(r"^```.*\n|\n```$", "", m.group(0)), text) | |
# 行内代码 | |
text = re.sub(r"`([^`]+)`", r"\1", text) | |
# 粗体/斜体 | |
text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text) | |
text = re.sub(r"__([^_]+)__", r"\1", text) | |
text = re.sub(r"\*([^*]+)\*", r"\1", text) | |
text = re.sub(r"_([^_]+)_", r"\1", text) | |
# 链接与图片 | |
text = re.sub(r"!\[([^\]]*)\]\([^)]*\)", r"\1", text) | |
text = re.sub(r"\[([^\]]+)\]\([^)]*\)", r"\1", text) | |
# 标题、引用、水平线 | |
text = re.sub(r"^>\s*", "", text, flags=re.MULTILINE) | |
text = re.sub(r"^#{1,6}\s*", "", text, flags=re.MULTILINE) | |
text = re.sub(r"^\s*-{3,}\s*$", "", text, flags=re.MULTILINE) | |
# 表格竖线与分隔行 | |
text = re.sub(r"^\|?\s*-+\s*(\|\s*-+\s*)+\|?\s*$", "", text, flags=re.MULTILINE) | |
text = text.replace("|", " |") | |
return text |