Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
'''
|
4 |
+
@Time : 2023/09/22 17:43:35
|
5 |
+
@Author : zoeyxiong
|
6 |
+
@File : chatgpt_bot.py
|
7 |
+
@Desc : 调用chatGPT类
|
8 |
+
'''
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
import openai
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
default_model = 'text-davinci-003'
|
15 |
+
|
16 |
+
class ChatGPT:
|
17 |
+
def __init__(self, model ,init_system={"role": "system", "content": "你是一个AI助手"}, save_message=False, ):
|
18 |
+
self.messages = []
|
19 |
+
self.init_system = init_system
|
20 |
+
self.model = model
|
21 |
+
self.messages.append(init_system)
|
22 |
+
# 开启此项,须告知用户
|
23 |
+
self.save_message = save_message
|
24 |
+
self.filename="./user_messages.json"
|
25 |
+
|
26 |
+
def ask_gpt(self):
|
27 |
+
rsp = openai.ChatCompletion.create(
|
28 |
+
model=self.model,
|
29 |
+
messages=self.messages
|
30 |
+
)
|
31 |
+
return rsp.get("choices")[0]["message"]["content"]
|
32 |
+
|
33 |
+
def get_response(self, question):
|
34 |
+
""" 调用openai接口, 获取回答
|
35 |
+
"""
|
36 |
+
# 用户的问题加入到message
|
37 |
+
self.messages.append({"role": "user", "content": question})
|
38 |
+
# 问chatgpt问题的答案
|
39 |
+
rsp = openai.ChatCompletion.create(
|
40 |
+
model=self.model,
|
41 |
+
messages=self.messages,
|
42 |
+
)
|
43 |
+
answer = rsp.get("choices")[0]["message"]["content"]
|
44 |
+
# 得到的答案加入message,多轮对话的历史信息
|
45 |
+
self.messages.append({"role": "assistant", "content": answer})
|
46 |
+
return answer
|
47 |
+
|
48 |
+
def clean_history(self):
|
49 |
+
""" 清空历史信息
|
50 |
+
"""
|
51 |
+
self.messages.clear()
|
52 |
+
self.messages.append(self.init_system)
|
53 |
+
|
54 |
+
#!/usr/bin/env python
|
55 |
+
# -*- encoding: utf-8 -*-
|
56 |
+
'''
|
57 |
+
@Time : 2023/09/22 17:43:37
|
58 |
+
@Author : zoeyxiong
|
59 |
+
@File : gradio_chatgpt_v2.py
|
60 |
+
@Desc : 使用gradio调用chatgpt
|
61 |
+
'''
|
62 |
+
|
63 |
+
openai.api_key = "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"
|
64 |
+
MODEL_NAME = 'gpt-3.5-turbo'
|
65 |
+
# 自定义system
|
66 |
+
INIT_MSG = {"role": "system", "content": "你是一个资深算法工程师."}
|
67 |
+
# 设置端口号,默认7560,遇冲突可自定义
|
68 |
+
SERVER_PORT = 7560
|
69 |
+
# 调用gpt的bot
|
70 |
+
chatgpt = ChatGPT(MODEL_NAME, INIT_MSG)
|
71 |
+
|
72 |
+
def predict(input, chatbot):
|
73 |
+
""" 调用openai接口,获取答案
|
74 |
+
"""
|
75 |
+
chatbot.append((input, ""))
|
76 |
+
# 找chatgpt要答案
|
77 |
+
response = chatgpt.get_response(input)
|
78 |
+
chatbot[-1] = (input, response)
|
79 |
+
return chatbot
|
80 |
+
|
81 |
+
def reset_user_input():
|
82 |
+
return gr.update(value='')
|
83 |
+
|
84 |
+
def reset_state():
|
85 |
+
chatgpt.clean_history()
|
86 |
+
return []
|
87 |
+
|
88 |
+
|
89 |
+
def main():
|
90 |
+
with gr.Blocks() as demo:
|
91 |
+
gr.HTML("""<h1 align="center">{}</h1>""".format(MODEL_NAME))
|
92 |
+
# gradio的chatbot
|
93 |
+
chatbot = gr.Chatbot()
|
94 |
+
with gr.Row():
|
95 |
+
with gr.Column(scale=4):
|
96 |
+
with gr.Column(scale=50):
|
97 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
|
98 |
+
container=False)
|
99 |
+
with gr.Column(min_width=32, scale=1):
|
100 |
+
submitBtn = gr.Button("Submit", variant="primary")
|
101 |
+
with gr.Column(scale=1):
|
102 |
+
emptyBtn = gr.Button("Clear History")
|
103 |
+
# 提交问题
|
104 |
+
submitBtn.click(predict, [user_input, chatbot],
|
105 |
+
[chatbot], show_progress=True)
|
106 |
+
submitBtn.click(reset_user_input, [], [user_input])
|
107 |
+
# 清空历史对话
|
108 |
+
emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)
|
109 |
+
|
110 |
+
|
111 |
+
demo.queue().launch(share=False, inbrowser=True, server_port=SERVER_PORT)
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == '__main__':
|
115 |
+
main()
|