Commit
- bot/bot.py +13 -0
- bot/bot_factory.py +26 -0
- bot/chatgpt/chat_gpt_bot.py +130 -0
bot/bot.py
ADDED
@@ -0,0 +1,13 @@
+"""
+Auto-reply chat robot abstract class
+"""
+
+
+class Bot(object):
+    def reply(self, query, context=None):
+        """
+        bot auto-reply content
+        :param query: received message
+        :return: reply content
+        """
+        raise NotImplementedError
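For reference, a minimal sketch of a concrete subclass (a hypothetical EchoBot, not part of this commit) that satisfies the interface by overriding reply():

    from bot.bot import Bot

    class EchoBot(Bot):
        def reply(self, query, context=None):
            # trivially echo the received message back as the reply content
            return query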
bot/bot_factory.py
ADDED
@@ -0,0 +1,26 @@
+"""
+bot factory
+"""
+
+
+def create_bot(bot_type):
+    """
+    create a bot instance
+    :param bot_type: bot type code
+    :return: bot instance
+    """
+    if bot_type == 'baidu':
+        # Baidu UNIT conversation API
+        from bot.baidu.baidu_unit_bot import BaiduUnitBot
+        return BaiduUnitBot()
+
+    elif bot_type == 'chatGPT':
+        # ChatGPT web-session API
+        from bot.chatgpt.chat_gpt_bot import ChatGPTBot
+        return ChatGPTBot()
+
+    elif bot_type == 'openAI':
+        # official OpenAI chat completion API
+        from bot.openai.open_ai_bot import OpenAIBot
+        return OpenAIBot()
+    raise RuntimeError
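A usage sketch for the factory, assuming the caller passes one of the three type codes handled above; any other value falls through to the bare RuntimeError:

    from bot.bot_factory import create_bot

    bot = create_bot('chatGPT')    # 'baidu', 'chatGPT' or 'openAI'
    # create_bot('unknown')        # would raise RuntimeError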
bot/chatgpt/chat_gpt_bot.py
ADDED
@@ -0,0 +1,130 @@
+# encoding:utf-8
+
+from bot.bot import Bot
+from config import conf
+from common.log import logger
+import openai
+import time
+
+user_session = dict()
+
+# OpenAI chat completion API (working)
+class ChatGPTBot(Bot):
+    def __init__(self):
+        openai.api_key = "sk-R3HlMsYBk0NpAlLu2aA4B19054Ea4884A2Cf93D25662243d"
+        openai.api_base = "https://apai.zyai.online/v1"
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if not context or not context.get('type') or context.get('type') == 'TEXT':
+            logger.info("[OPEN_AI] query={}".format(query))
+            from_user_id = context['from_user_id']
+            if query == '#清除记忆':  # clear-memory command
+                Session.clear_session(from_user_id)
+                return '记忆已清除'  # "memory cleared"
+
+            new_query = Session.build_session_query(query, from_user_id)
+            logger.debug("[OPEN_AI] session query={}".format(new_query))
+
+            # if context.get('stream'):
+            #     # reply in stream
+            #     return self.reply_text_stream(query, new_query, from_user_id)
+
+            reply_content = self.reply_text(new_query, from_user_id, 0)
+            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
+            if reply_content:
+                Session.save_session(query, reply_content, from_user_id)
+            return reply_content
+
+        elif context.get('type', None) == 'IMAGE_CREATE':
+            return self.create_img(query, 0)
+
+    def reply_text(self, query, user_id, retry_count=0):
+        try:
+            response = openai.ChatCompletion.create(
+                model="gpt-3.5-turbo",   # chat model name
+                messages=query,
+                temperature=1,           # higher values make replies more random
+                max_tokens=600,          # maximum number of tokens in the reply
+                top_p=1,
+                frequency_penalty=0,     # in [-2, 2]; larger values penalize repeated content
+                presence_penalty=0,      # in [-2, 2]; larger values encourage new topics
+            )
+            # res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
+            logger.info(response.choices[0]['message']['content'])
+            # log.info("[OPEN_AI] reply={}".format(res_content))
+            return response.choices[0]['message']['content']
+        except openai.error.RateLimitError as e:
+            # rate limit exception
+            logger.warn(e)
+            if retry_count < 3:
+                time.sleep(5)
+                logger.warn("[OPEN_AI] RateLimit exceed, retry {}".format(retry_count + 1))
+                return self.reply_text(query, user_id, retry_count + 1)
+            else:
+                return "问太快了,慢点行不行"  # "asking too fast, please slow down"
+        except Exception as e:
+            # unknown exception
+            logger.exception(e)
+            Session.clear_session(user_id)
+            return "没听懂"  # "didn't understand"
+
+    def create_img(self, query, retry_count=0):
+        try:
+            logger.info("[OPEN_AI] image_query={}".format(query))
+            response = openai.Image.create(
+                prompt=query,      # image description
+                n=1,               # number of images per request
+                size="1024x1024"   # image size: 256x256, 512x512 or 1024x1024
+            )
+            image_url = response['data'][0]['url']
+            logger.info("[OPEN_AI] image_url={}".format(image_url))
+            return image_url
+        except openai.error.RateLimitError as e:
+            logger.warn(e)
+            if retry_count < 3:
+                time.sleep(5)
+                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, retry {}".format(retry_count + 1))
+                return self.create_img(query, retry_count + 1)
+            else:
+                return "问太快了,慢点行不行"  # "asking too fast, please slow down"
+        except Exception as e:
+            logger.exception(e)
+            return None
+
+class Session(object):
+    @staticmethod
+    def build_session_query(query, user_id):
+        '''
+        build query with conversation history
+        e.g. [
+            {"role": "system", "content": "You are a helpful assistant, let's think step by step in multiple different ways."},
+            {"role": "user", "content": "Who won the world series in 2020?"},
+            {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
+            {"role": "user", "content": "Where was it played?"}
+        ]
+        :param query: query content
+        :param user_id: from user id
+        :return: query content with conversation history
+        '''
+        session = user_session.get(user_id, [])
+        if len(session) == 0:
+            system_prompt = conf().get("character_desc", "")
+            system_item = {'role': 'system', 'content': system_prompt}
+            session.append(system_item)
+            user_session[user_id] = session
+        user_item = {'role': 'user', 'content': query}
+        session.append(user_item)
+        return session
+
+    @staticmethod
+    def save_session(query, answer, user_id):
+        session = user_session.get(user_id)
+        if session:
+            # append conversation
+            gpt_item = {'role': 'assistant', 'content': answer}
+            session.append(gpt_item)
+
+    @staticmethod
+    def clear_session(user_id):
+        user_session[user_id] = []
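A usage sketch for ChatGPTBot, assuming a context dict carrying the from_user_id and type fields that reply() reads; the user id value is hypothetical:

    from bot.chatgpt.chat_gpt_bot import ChatGPTBot

    bot = ChatGPTBot()
    ctx = {'type': 'TEXT', 'from_user_id': 'wx_user_1'}           # hypothetical user id
    print(bot.reply('Who won the world series in 2020?', ctx))    # history is kept in user_session
    print(bot.reply('Where was it played?', ctx))                 # follow-up reuses the saved session

    img_ctx = {'type': 'IMAGE_CREATE', 'from_user_id': 'wx_user_1'}
    print(bot.reply('a cat wearing sunglasses', img_ctx))         # returns an image URL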