# -*- coding: utf-8 -*-
"""
Model interface module.
Wrappers for interacting with the large-model APIs used in this project.
"""

import os
import sys
import requests
import json
import time
import logging
import random
from typing import Dict, Any, Optional, Callable

# Force the terminal encoding to UTF-8 (Windows only)
if os.name == 'nt':
    os.system('chcp 65001 > nul')

# Absolute path of the directory containing this script
current_script_dir = os.path.dirname(os.path.abspath(__file__))
# Project root (the '20250907_大模型辩论' directory, one level above 'src')
project_root = os.path.dirname(current_script_dir)

# Data input, output, and log directories
DATA_DIR = os.path.join(project_root, 'data')
OUTPUT_DIR = os.path.join(project_root, 'output')
LOGS_DIR = os.path.join(project_root, 'logs')

# Make sure the output and log directories exist
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(LOGS_DIR, '模型接口日志.log'), encoding='utf-8'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Capture warnings and route them into the log
logging.captureWarnings(True)
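
# With the format string above, a record looks roughly like this
# (illustrative example, not actual output):
#   2025-09-07 10:00:00,000 - __main__ - INFO - Model interface initialized, base URL: https://...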


class ModelInterface:
    """Base class for model interfaces."""

    def __init__(self, api_key: str, base_url: str):
        self.api_key = api_key
        self.base_url = base_url
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }
        logger.info(f"Model interface initialized, base URL: {base_url}")

    def _should_retry(self, status_code: Optional[int]) -> bool:
        # Retry on rate limiting (429) and on any 5xx server error
        return status_code == 429 or (status_code is not None and 500 <= status_code < 600)

    def _compute_backoff_seconds(self, attempt: int, retry_after: Optional[str]) -> float:
        # Prefer the server-provided Retry-After header when present
        if retry_after:
            try:
                return float(retry_after)
            except Exception:
                pass
        # Exponential backoff plus jitter: 1, 2, 4, 8 ... capped at 10s, with 0-300ms of jitter
        base = min(10, 2 ** attempt)
        return base + random.random() * 0.3
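
    # For example, without a Retry-After header the successive waits are roughly:
    #   attempt 0 -> ~1s, attempt 1 -> ~2s, attempt 2 -> ~4s, attempt 3 -> ~8s,
    #   attempt 4+ -> ~10s (cap), each plus up to 0.3s of random jitter.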

    def send_request(self, model: str, messages: list, **kwargs) -> Dict[str, Any]:
        payload = {
            'model': model,
            'messages': messages,
            **kwargs
        }
        max_retries = 3
        last_exc = None
        for attempt in range(max_retries + 1):
            try:
                logger.info(f"Sending request to model {model}")
                response = requests.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=60
                )
                if self._should_retry(response.status_code) and attempt < max_retries:
                    wait_s = self._compute_backoff_seconds(attempt, response.headers.get('Retry-After'))
                    logger.warning(f"Model {model} returned {response.status_code}, retrying in {wait_s:.2f}s (attempt={attempt})")
                    time.sleep(wait_s)
                    continue
                response.raise_for_status()
                result = response.json()
                logger.info(f"Model {model} responded successfully")
                return result
            except requests.exceptions.RequestException as e:
                last_exc = e
                status = getattr(e.response, 'status_code', None)
                if self._should_retry(status) and attempt < max_retries:
                    wait_s = self._compute_backoff_seconds(attempt, getattr(e.response, 'headers', {}).get('Retry-After'))
                    logger.warning(f"Request error {status}, retrying in {wait_s:.2f}s (attempt={attempt})")
                    time.sleep(wait_s)
                    continue
                logger.error(f"Request to model {model} failed: {str(e)}")
                raise
        # Reaching this point means every retry failed
        if last_exc:
            raise last_exc
        raise RuntimeError(f"Request to model {model} failed after {max_retries} retries")
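
    # The endpoint follows the OpenAI-compatible chat-completions convention;
    # illustrative request/response shapes (matching what this module reads):
    #   payload:  {"model": "...", "messages": [{"role": "user", "content": "hi"}]}
    #   response: {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}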

    def send_stream_request(self, model: str, messages: list, callback: Callable[[str], None], **kwargs) -> str:
        payload = {
            'model': model,
            'messages': messages,
            'stream': True,
            **kwargs
        }
        max_retries = 3
        last_exc = None
        for attempt in range(max_retries + 1):
            full_response = ""
            try:
                logger.info(f"Sending streaming request to model {model}")
                response = requests.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=60,
                    stream=True
                )
                if self._should_retry(response.status_code) and attempt < max_retries:
                    wait_s = self._compute_backoff_seconds(attempt, response.headers.get('Retry-After'))
                    logger.warning(f"Model {model} stream returned {response.status_code}, retrying in {wait_s:.2f}s (attempt={attempt})")
                    time.sleep(wait_s)
                    continue
                response.raise_for_status()
                # Process the server-sent-event stream
                for line in response.iter_lines():
                    if line:
                        decoded_line = line.decode('utf-8')
                        if decoded_line.startswith("data: "):
                            data = decoded_line[6:]
                            if data != "[DONE]":
                                try:
                                    json_data = json.loads(data)
                                    content = json_data["choices"][0]["delta"].get("content", "")
                                    if content:
                                        full_response += content
                                        callback(content)
                                except (json.JSONDecodeError, KeyError, IndexError):
                                    # Skip malformed or unexpected chunks
                                    pass
                logger.info(f"Model {model} stream completed")
                return full_response
            except requests.exceptions.RequestException as e:
                last_exc = e
                status = getattr(e.response, 'status_code', None)
                if self._should_retry(status) and attempt < max_retries:
                    wait_s = self._compute_backoff_seconds(attempt, getattr(e.response, 'headers', {}).get('Retry-After'))
                    logger.warning(f"Streaming request error {status}, retrying in {wait_s:.2f}s (attempt={attempt})")
                    time.sleep(wait_s)
                    continue
                logger.error(f"Streaming request to model {model} failed: {str(e)}")
                raise
            except Exception as e:
                last_exc = e
                logger.error(f"Error while processing the stream: {str(e)}")
                raise
        if last_exc:
            raise last_exc
        raise RuntimeError(f"Streaming request to model {model} failed after {max_retries} retries")
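
    # Each streamed line is an SSE record; an illustrative chunk looks like:
    #   data: {"choices": [{"delta": {"content": "Hel"}}]}
    # and the stream is terminated by the sentinel:
    #   data: [DONE]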

    def get_response_text(self, response: Dict[str, Any]) -> str:
        try:
            return response['choices'][0]['message']['content']
        except (KeyError, IndexError) as e:
            logger.error(f"Failed to extract response text: {str(e)}")
            raise


class GLM45Interface(ModelInterface):
    """GLM-4.5 model interface."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "https://api-inference.modelscope.cn/v1")
        self.model_name = "ZhipuAI/GLM-4.5"
        logger.info("GLM-4.5 model interface initialized")

    def chat(self, messages: list, temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            response = self.send_request(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return self.get_response_text(response)
        except Exception as e:
            logger.error(f"GLM-4.5 chat failed: {str(e)}")
            return f"GLM-4.5 chat failed: {str(e)}"

    def chat_stream(self, messages: list, callback: Callable[[str], None], temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            return self.send_stream_request(
                model=self.model_name,
                messages=messages,
                callback=callback,
                temperature=temperature,
                max_tokens=max_tokens
            )
        except Exception as e:
            logger.error(f"GLM-4.5 streaming chat failed: {str(e)}")
            error_msg = f"GLM-4.5 streaming chat failed: {str(e)}"
            callback(f"\n{error_msg}\n")
            return error_msg


class DeepSeekV31Interface(ModelInterface):
    """DeepSeek-V3.1 model interface."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "https://api-inference.modelscope.cn/v1")
        self.model_name = "deepseek-ai/DeepSeek-V3.1"
        logger.info("DeepSeek-V3.1 model interface initialized")

    def chat(self, messages: list, temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            response = self.send_request(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return self.get_response_text(response)
        except Exception as e:
            logger.error(f"DeepSeek-V3.1 chat failed: {str(e)}")
            return f"DeepSeek-V3.1 chat failed: {str(e)}"

    def chat_stream(self, messages: list, callback: Callable[[str], None], temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            return self.send_stream_request(
                model=self.model_name,
                messages=messages,
                callback=callback,
                temperature=temperature,
                max_tokens=max_tokens
            )
        except Exception as e:
            logger.error(f"DeepSeek-V3.1 streaming chat failed: {str(e)}")
            error_msg = f"DeepSeek-V3.1 streaming chat failed: {str(e)}"
            callback(f"\n{error_msg}\n")
            return error_msg


class QwenInterface(ModelInterface):
    """Qwen model interface (thinking variant)."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "https://api-inference.modelscope.cn/v1")
        self.model_name = "Qwen/Qwen3-235B-A22B-Thinking-2507"
        logger.info("Qwen model interface initialized")

    def chat(self, messages: list, temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            response = self.send_request(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return self.get_response_text(response)
        except Exception as e:
            logger.error(f"Qwen chat failed: {str(e)}")
            return f"Qwen chat failed: {str(e)}"

    def chat_stream(self, messages: list, callback: Callable[[str], None], temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            return self.send_stream_request(
                model=self.model_name,
                messages=messages,
                callback=callback,
                temperature=temperature,
                max_tokens=max_tokens
            )
        except Exception as e:
            logger.error(f"Qwen streaming chat failed: {str(e)}")
            error_msg = f"Qwen streaming chat failed: {str(e)}"
            callback(f"\n{error_msg}\n")
            return error_msg


# New: Qwen Instruct (non-thinking variant)
class QwenInstructInterface(ModelInterface):
    """Qwen Instruct model interface (non-thinking variant)."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "https://api-inference.modelscope.cn/v1")
        self.model_name = "Qwen/Qwen3-235B-A22B-Instruct-2507"
        logger.info("Qwen Instruct model interface initialized")

    def chat(self, messages: list, temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            response = self.send_request(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return self.get_response_text(response)
        except Exception as e:
            logger.error(f"Qwen Instruct chat failed: {str(e)}")
            return f"Qwen Instruct chat failed: {str(e)}"

    def chat_stream(self, messages: list, callback: Callable[[str], None], temperature: float = 0.7, max_tokens: int = 8000) -> str:
        try:
            return self.send_stream_request(
                model=self.model_name,
                messages=messages,
                callback=callback,
                temperature=temperature,
                max_tokens=max_tokens
            )
        except Exception as e:
            logger.error(f"Qwen Instruct streaming chat failed: {str(e)}")
            error_msg = f"Qwen Instruct streaming chat failed: {str(e)}"
            callback(f"\n{error_msg}\n")
            return error_msg


class ModelManager:
    """Model manager providing unified access to all model interfaces."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.glm45 = GLM45Interface(api_key)
        self.deepseek_v31 = DeepSeekV31Interface(api_key)
        self.qwen = QwenInterface(api_key)
        self.qwen_instruct = QwenInstructInterface(api_key)
        logger.info("Model manager initialized")

    def get_model(self, model_name: str) -> 'ModelInterface':
        name = model_name.lower()
        if name == 'glm45':
            return self.glm45
        elif name == 'deepseek_v31':
            return self.deepseek_v31
        elif name == 'qwen':
            return self.qwen
        elif name == 'qwen_instruct':
            return self.qwen_instruct
        else:
            logger.error(f"Unsupported model name: {model_name}")
            raise ValueError(f"Unsupported model name: {model_name}")
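
# Lookup keys accepted by ModelManager.get_model (matched case-insensitively):
#   'glm45', 'deepseek_v31', 'qwen', 'qwen_instruct'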


# Test code
if __name__ == "__main__":
    # Exercise the model interfaces. Read the key from the environment rather
    # than hard-coding a secret in source; the variable name MODELSCOPE_API_KEY
    # is an assumption of this sketch, not mandated by the API.
    api_key = os.getenv("MODELSCOPE_API_KEY", "")
    manager = ModelManager(api_key)

    # Test GLM-4.5
    glm45 = manager.get_model('glm45')
    messages = [{"role": "user", "content": "Hello, please briefly introduce yourself"}]
    print("=== Testing GLM-4.5 chat ===")
    response = glm45.chat(messages)
    print(f"GLM-4.5 reply: {response}")
    print("\n=== Testing GLM-4.5 streaming chat ===")

    def stream_callback(content):
        print(content, end='', flush=True)

    response = glm45.chat_stream(messages, stream_callback)
    print(f"\nFull reply: {response}")

    # Test DeepSeek-V3.1
    deepseek_v31 = manager.get_model('deepseek_v31')
    messages = [{"role": "user", "content": "Hello, please briefly introduce yourself"}]
    print("\n=== Testing DeepSeek-V3.1 chat ===")
    response = deepseek_v31.chat(messages)
    print(f"DeepSeek-V3.1 reply: {response}")
    print("\n=== Testing DeepSeek-V3.1 streaming chat ===")
    response = deepseek_v31.chat_stream(messages, stream_callback)
    print(f"\nFull reply: {response}")