qiuhuachuan
committed on
Commit
•
3e93681
1
Parent(s):
3257e85
Update README.md
Browse files
README.md
CHANGED
@@ -17,6 +17,134 @@ tags:
|
|
17 |
这是因为训练语料中并没有人设相关的训练样本。
|
18 |
```
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
## 免责声明
|
21 |
|
22 |
我们的心理健康支持对话机器人(以下简称“机器人”)旨在为用户提供情感支持和心理健康建议。然而,机器人不是医疗保健专业人员,不能替代医生、心理医生或其他专业人士的意见、诊断、建议或治疗。
|
|
|
17 |
这是因为训练语料中并没有人设相关的训练样本。
|
18 |
```
|
19 |
|
20 |
+
## 体验地址: http://47.97.220.53:8080/
|
21 |
+
|
22 |
+
## Code
|
23 |
+
```
|
24 |
+
# --- Service bootstrap -------------------------------------------------------
# Select the GPU, load the base ChatGLM-6B model, apply the MeChat LoRA
# adapter, and load the matching tokenizer.
import os
from typing import Optional

import ujson

# Restrict the process to GPU 2. This is set before torch is imported so CUDA
# only ever sees (and initializes) that single device.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from peft import PeftModel

# Base ChatGLM-6B weights, pinned to a fixed revision for reproducibility.
model = AutoModel.from_pretrained('THUDM/chatglm-6b',
                                  revision='v0.1.0',
                                  trust_remote_code=True)

# Apply the MeChat LoRA adapter on top of the base model.
# (Renamed from the original's misspelled `LaRA_PATH`.)
LORA_PATH = 'qiuhuachuan/MeChat'
model = PeftModel.from_pretrained(model, LORA_PATH)

# NOTE(review): .float() forces fp32 before moving to the GPU — presumably to
# avoid half-precision issues with the merged LoRA weights; confirm.
model = model.float().to(device='cuda')

tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b',
                                          trust_remote_code=True)
|
48 |
+
|
49 |
+
|
50 |
+
class ChatInfo(BaseModel):
    """Request payload for one chat turn posted to /v1/chat."""

    owner: str      # message author: 'seeker' or 'supporter'
    msg: str        # the message text
    unique_id: str  # dialogue id; also the filename stem under ./dialogues
|
54 |
+
|
55 |
+
|
56 |
+
class RatingInfo(BaseModel):
    """User feedback (thumb up/down) on one message of a dialogue."""

    # NOTE(review): whether omitted Optional fields default to None depends on
    # the installed pydantic major version — confirm against requirements.
    thumb_up: Optional[bool]
    thumb_down: Optional[bool]
    unique_id: str  # dialogue id the rating belongs to
    idx: int        # index of the rated message within that dialogue
|
61 |
+
|
62 |
+
|
63 |
+
app = FastAPI()

# Open up CORS completely — presumably the demo frontend is served from a
# different host/port than this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
|
69 |
+
|
70 |
+
|
71 |
+
def format_example(example: dict) -> dict:
    """Wrap a raw input dict into the prompt layout expected by the model.

    Returns a dict with the prompt text under 'context' and an always-empty
    'target' placeholder.
    """
    prompt = 'Input: {}\n'.format(example['input'])
    return {'context': prompt, 'target': ''}
|
75 |
+
|
76 |
+
|
77 |
+
def generate_response(data: dict) -> str:
    """Generate the supporter's next reply for a dialogue prompt.

    ``data`` must contain an 'input' key holding the concatenated dialogue
    history (built by the /v1/chat handler). Returns the decoded continuation
    with the trailing newline+END marker removed and whitespace stripped.
    """
    with torch.no_grad():
        feature = format_example(data)
        input_text = feature['context']

        ids = tokenizer.encode(input_text)
        input_length = len(ids)
        input_ids = torch.LongTensor([ids]).to(device='cuda')

        out = model.generate(input_ids=input_ids,
                             max_length=2040,
                             do_sample=True,
                             temperature=0.9,
                             top_p=0.9)

        # Decode only the newly generated tokens, skipping the prompt.
        # (The original also decoded the full sequence into an unused
        # variable; that dead code is removed.)
        true_out_text = tokenizer.decode(out[0][input_length:])

        answer = true_out_text.replace('\nEND', '').strip()
        return answer
|
96 |
+
|
97 |
+
|
98 |
+
@app.post('/v1/chat')
async def chat(chat_info: ChatInfo):
    """Append the user's turn, generate the supporter reply, persist, return.

    Dialogue history is stored as ./dialogues/<unique_id>.json. The parameter
    was renamed from the original's `ChatInfo` (which shadowed the model
    class); with a single pydantic body parameter this does not change the
    request schema.
    """
    unique_id = chat_info.unique_id

    # './dialogues' stores the chat histories; create it so a fresh
    # deployment does not crash on the first request.
    os.makedirs('./dialogues', exist_ok=True)

    history_path = f'./dialogues/{unique_id}.json'
    # EAFP: open directly instead of listdir-then-open (avoids a TOCTOU gap).
    try:
        with open(history_path, 'r', encoding='utf-8') as f:
            data: list = ujson.load(f)
    except FileNotFoundError:
        data = []

    data.append({
        'owner': chat_info.owner,
        'msg': chat_info.msg,
        'unique_id': chat_info.unique_id
    })

    # Serialize the history into the role-tagged prompt format used in
    # training: 求助者 = seeker, 支持者 = supporter.
    input_str = ''
    for item in data:
        if item['owner'] == 'seeker':
            input_str += '求助者:' + item['msg']
        else:
            input_str += '支持者:' + item['msg']
    input_str += '支持者:'

    # Trim the oldest turns until the prompt fits the ~2000-char budget by
    # cutting at the later of the two first role markers. Using find() (not
    # index()) tolerates a history that contains only one of the markers,
    # and the progress guard prevents an infinite loop when no cut helps.
    while len(input_str) > 2000:
        start_idx = max(input_str.find('求助者:'), input_str.find('支持者:'))
        if start_idx <= 0:
            break
        input_str = input_str[start_idx:]

    wrapped_data = {'input': input_str}

    response = generate_response(data=wrapped_data)
    supporter_msg = {
        'owner': 'supporter',
        'msg': response,
        'unique_id': unique_id
    }
    data.append(supporter_msg)

    with open(history_path, 'w', encoding='utf-8') as f:
        ujson.dump(data, f, ensure_ascii=False, indent=2)

    return {'item': supporter_msg, 'responseCode': 200}
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == '__main__':
    # Script entry point: serve the API on all interfaces, port 8000.
    uvicorn.run(app, host='0.0.0.0', port=8000)
|
145 |
+
```
|
146 |
+
|
147 |
+
|
148 |
## 免责声明
|
149 |
|
150 |
我们的心理健康支持对话机器人(以下简称“机器人”)旨在为用户提供情感支持和心理健康建议。然而,机器人不是医疗保健专业人员,不能替代医生、心理医生或其他专业人士的意见、诊断、建议或治疗。
|