qiuhuachuan commited on
Commit
3e93681
1 Parent(s): 3257e85

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +128 -0
README.md CHANGED
@@ -17,6 +17,134 @@ tags:
17
  这是因为训练语料中并没有人设相关的训练样本。
18
  ```
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  ## 免责声明
21
 
22
  我们的心理健康支持对话机器人(以下简称“机器人”)旨在为用户提供情感支持和心理健康建议。然而,机器人不是医疗保健专业人员,不能替代医生、心理医生或其他专业人士的意见、诊断、建议或治疗。
 
17
  这是因为训练语料中并没有人设相关的训练样本。
18
  ```
19
 
20
+ ## 体验地址: http://47.97.220.53:8080/
21
+
22
+ ## Code
23
+ ```
24
+ import os
25
+ import ujson
26
+ from typing import Optional
27
+
28
+ os.environ['CUDA_VISIBLE_DEVICES'] = '2'
29
+
30
+ import uvicorn
31
+ import torch
32
+ from transformers import AutoTokenizer, AutoModel
33
+ from fastapi import FastAPI
34
+ from fastapi.middleware.cors import CORSMiddleware
35
+ from pydantic import BaseModel
36
+
37
+ from peft import PeftModel
38
+
39
+ model = AutoModel.from_pretrained('THUDM/chatglm-6b',
40
+ revision='v0.1.0',
41
+ trust_remote_code=True)
42
+ LaRA_PATH = 'qiuhuachuan/MeChat'
43
+ model = PeftModel.from_pretrained(model, LaRA_PATH)
44
+ model = model.float().to(device='cuda')
45
+
46
+ tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b',
47
+ trust_remote_code=True)
48
+
49
+
50
+ class ChatInfo(BaseModel):
51
+ owner: str
52
+ msg: str
53
+ unique_id: str
54
+
55
+
56
+ class RatingInfo(BaseModel):
57
+ thumb_up: Optional[bool]
58
+ thumb_down: Optional[bool]
59
+ unique_id: str
60
+ idx: int
61
+
62
+
63
+ app = FastAPI()
64
+ app.add_middleware(CORSMiddleware,
65
+ allow_origins=['*'],
66
+ allow_credentials=True,
67
+ allow_methods=['*'],
68
+ allow_headers=['*'])
69
+
70
+
71
+ def format_example(example: dict) -> dict:
72
+ context = f'''Input: {example['input']}\n'''
73
+
74
+ return {'context': context, 'target': ''}
75
+
76
+
77
+ def generate_response(data: dict):
78
+ with torch.no_grad():
79
+ feature = format_example(data)
80
+ input_text = feature['context']
81
+ ids = tokenizer.encode(input_text)
82
+ input_length = len(ids)
83
+ input_ids = torch.LongTensor([ids]).to(device='cuda')
84
+
85
+ out = model.generate(input_ids=input_ids,
86
+ max_length=2040,
87
+ do_sample=True,
88
+ temperature=0.9,
89
+ top_p=0.9)
90
+
91
+ raw_out_text = tokenizer.decode(out[0])
92
+ true_out_text = tokenizer.decode(out[0][input_length:])
93
+
94
+ answer = true_out_text.replace('\nEND', '').strip()
95
+ return answer
96
+
97
+
98
+ @app.post('/v1/chat')
99
+ async def chat(ChatInfo: ChatInfo):
100
+ unique_id = ChatInfo.unique_id
101
+ # './dialogues'用于存储聊天数据
102
+ existing_files = os.listdir('./dialogues')
103
+
104
+ target_file = f'{unique_id}.json'
105
+ if target_file in existing_files:
106
+ with open(f'./dialogues/{unique_id}.json', 'r', encoding='utf-8') as f:
107
+ data: list = ujson.load(f)
108
+ else:
109
+ data = []
110
+ data.append({
111
+ 'owner': ChatInfo.owner,
112
+ 'msg': ChatInfo.msg,
113
+ 'unique_id': ChatInfo.unique_id
114
+ })
115
+ input_str = ''
116
+ for item in data:
117
+ if item['owner'] == 'seeker':
118
+ input_str += '求助者:' + item['msg']
119
+ else:
120
+ input_str += '支持者:' + item['msg']
121
+ input_str += '支持者:'
122
+ while len(input_str) > 2000:
123
+ if input_str.index('求助者:') > input_str.index('支持者:'):
124
+ start_idx = input_str.index('求助者:')
125
+ else:
126
+ start_idx = input_str.index('支持者:')
127
+ input_str = input_str[start_idx:]
128
+
129
+ wrapped_data = {'input': input_str}
130
+
131
+ response = generate_response(data=wrapped_data)
132
+ supporter_msg = {
133
+ 'owner': 'supporter',
134
+ 'msg': response,
135
+ 'unique_id': unique_id
136
+ }
137
+ data.append(supporter_msg)
138
+ with open(f'./dialogues/{unique_id}.json', 'w', encoding='utf-8') as f:
139
+ ujson.dump(data, f, ensure_ascii=False, indent=2)
140
+ return {'item': supporter_msg, 'responseCode': 200}
141
+
142
+
143
+ if __name__ == '__main__':
144
+ uvicorn.run(app, host='0.0.0.0', port=8000)
145
+ ```
146
+
147
+
148
  ## 免责声明
149
 
150
  我们的心理健康支持对话机器人(以下简称“机器人”)旨在为用户提供情感支持和心理健康建议。然而,机器人不是医疗保健专业人员,不能替代医生、心理医生或其他专业人士的意见、诊断、建议或治疗。