kakuguo commited on
Commit
cd9a59b
1 Parent(s): 4dd3cc3
Files changed (3) hide show
  1. bot/bot.py +13 -0
  2. bot/bot_factory.py +26 -0
  3. bot/chatgpt/chat_gpt_bot.py +130 -0
bot/bot.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Auto-replay chat robot abstract class
3
+ """
4
+
5
+
6
+ class Bot(object):
7
+ def reply(self, query, context=None):
8
+ """
9
+ bot auto-reply content
10
+ :param req: received message
11
+ :return: reply content
12
+ """
13
+ raise NotImplementedError
bot/bot_factory.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ channel factory
3
+ """
4
+
5
+
6
+ def create_bot(bot_type):
7
+ """
8
+ create a channel instance
9
+ :param channel_type: channel type code
10
+ :return: channel instance
11
+ """
12
+ if bot_type == 'baidu':
13
+ # Baidu Unit对话接口
14
+ from bot.baidu.baidu_unit_bot import BaiduUnitBot
15
+ return BaiduUnitBot()
16
+
17
+ elif bot_type == 'chatGPT':
18
+ # ChatGPT 网页端web接口
19
+ from bot.chatgpt.chat_gpt_bot import ChatGPTBot
20
+ return ChatGPTBot()
21
+
22
+ elif bot_type == 'openAI':
23
+ # OpenAI 官方对话模型API
24
+ from bot.openai.open_ai_bot import OpenAIBot
25
+ return OpenAIBot()
26
+ raise RuntimeError
bot/chatgpt/chat_gpt_bot.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # encoding:utf-8
2
+
3
+ from bot.bot import Bot
4
+ from config import conf
5
+ from common.log import logger
6
+ import openai
7
+ import time
8
+
9
+ user_session = dict()
10
+
11
+ # OpenAI对话模型API (可用)
12
+ class ChatGPTBot(Bot):
13
+ def __init__(self):
14
+ openai.api_key = "sk-R3HlMsYBk0NpAlLu2aA4B19054Ea4884A2Cf93D25662243d"
15
+ openai.api_base="https://apai.zyai.online/v1"
16
+
17
+ def reply(self, query, context=None):
18
+ # acquire reply content
19
+ if not context or not context.get('type') or context.get('type') == 'TEXT':
20
+ logger.info("[OPEN_AI] query={}".format(query))
21
+ from_user_id = context['from_user_id']
22
+ if query == '#清除记忆':
23
+ Session.clear_session(from_user_id)
24
+ return '记忆已清除'
25
+
26
+ new_query = Session.build_session_query(query, from_user_id)
27
+ logger.debug("[OPEN_AI] session query={}".format(new_query))
28
+
29
+ # if context.get('stream'):
30
+ # # reply in stream
31
+ # return self.reply_text_stream(query, new_query, from_user_id)
32
+
33
+ reply_content = self.reply_text(new_query, from_user_id, 0)
34
+ logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
35
+ if reply_content:
36
+ Session.save_session(query, reply_content, from_user_id)
37
+ return reply_content
38
+
39
+ elif context.get('type', None) == 'IMAGE_CREATE':
40
+ return self.create_img(query, 0)
41
+
42
+ def reply_text(self, query, user_id, retry_count=0):
43
+ try:
44
+ response = openai.ChatCompletion.create(
45
+ model="gpt-3.5-turbo", # 对话模型的名称
46
+ messages=query,
47
+ temperature=1, # 值在[0,1]之间,越大表示回复越具有不确定性
48
+ max_tokens=600, # 回复最大的字符数
49
+ top_p=1,
50
+ frequency_penalty=0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
51
+ presence_penalty=0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
52
+ )
53
+ # res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
54
+ logger.info(response.choices[0]['message']['content'])
55
+ # log.info("[OPEN_AI] reply={}".format(res_content))
56
+ return response.choices[0]['message']['content']
57
+ except openai.error.RateLimitError as e:
58
+ # rate limit exception
59
+ logger.warn(e)
60
+ if retry_count < 3:
61
+ time.sleep(5)
62
+ logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
63
+ return self.reply_text(query, user_id, retry_count+1)
64
+ else:
65
+ return "问太快了,慢点行不行"
66
+ except Exception as e:
67
+ # unknown exception
68
+ logger.exception(e)
69
+ Session.clear_session(user_id)
70
+ return "没听懂"
71
+
72
+ def create_img(self, query, retry_count=0):
73
+ try:
74
+ logger.info("[OPEN_AI] image_query={}".format(query))
75
+ response = openai.Image.create(
76
+ prompt=query, #图片描述
77
+ n=1, #每次生成图片的数量
78
+ size="1024x1024" #图片大小,可选有 256x256, 512x512, 1024x1024
79
+ )
80
+ image_url = response['data'][0]['url']
81
+ logger.info("[OPEN_AI] image_url={}".format(image_url))
82
+ return image_url
83
+ except openai.error.RateLimitError as e:
84
+ logger.warn(e)
85
+ if retry_count < 3:
86
+ time.sleep(5)
87
+ logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
88
+ return self.reply_text(query, retry_count+1)
89
+ else:
90
+ return "问太快了,慢点行不行"
91
+ except Exception as e:
92
+ logger.exception(e)
93
+ return None
94
+
95
+ class Session(object):
96
+ @staticmethod
97
+ def build_session_query(query, user_id):
98
+ '''
99
+ build query with conversation history
100
+ e.g. [
101
+ {"role": "system", "content": "You are a helpful assistant,let's think step by step in multiple different ways."},
102
+ {"role": "user", "content": "Who won the world series in 2020?"},
103
+ {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
104
+ {"role": "user", "content": "Where was it played?"}
105
+ ]
106
+ :param query: query content
107
+ :param user_id: from user id
108
+ :return: query content with conversaction
109
+ '''
110
+ session = user_session.get(user_id, [])
111
+ if len(session) == 0:
112
+ system_prompt = conf().get("character_desc", "")
113
+ system_item = {'role': 'system', 'content': system_prompt}
114
+ session.append(system_item)
115
+ user_session[user_id] = session
116
+ user_item = {'role': 'user', 'content': query}
117
+ session.append(user_item)
118
+ return session
119
+
120
+ @staticmethod
121
+ def save_session(query, answer, user_id):
122
+ session = user_session.get(user_id)
123
+ if session:
124
+ # append conversation
125
+ gpt_item = {'role': 'assistant', 'content': answer}
126
+ session.append(gpt_item)
127
+
128
+ @staticmethod
129
+ def clear_session(user_id):
130
+ user_session[user_id] = []