YUGOROU commited on
Commit
f7bfd42
·
verified ·
1 Parent(s): 9d386a1

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (2).py +326 -0
  2. requirements (2).txt +19 -0
app (2).py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
+ import spaces
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ # モデルとトークナイザーの設定
9
+ MODEL_NAME = "YUGOROU/TeenEmo-Reasoning-v2"
10
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ # デフォルト設定
13
+ DEFAULT_SYSTEM_PROMPT = """あなたは思いやりのあるカウンセラーです。10代の若者の感情や悩みに寄り添い、親身になって話を聞いてください。適切なアドバイスを提供し、ポジティブな視点を提供してください。"""
14
+
15
+ # Global variables for model and tokenizer
16
+ model = None
17
+ tokenizer = None
18
+
19
+ def load_model():
20
+ """モデルとトークナイザーの読み込み"""
21
+ global model, tokenizer
22
+
23
+ try:
24
+ # トークナイザーの読み込み
25
+ tokenizer = AutoTokenizer.from_pretrained(
26
+ MODEL_NAME,
27
+ trust_remote_code=True,
28
+ use_fast=True
29
+ )
30
+
31
+ # モデルの読み込み(GPU使用時は量子化を使用)
32
+ if DEVICE == "cuda":
33
+ # 量子化設定(メモリ使用量を削減)
34
+ quantization_config = BitsAndBytesConfig(
35
+ load_in_4bit=True,
36
+ bnb_4bit_quant_type="nf4",
37
+ bnb_4bit_compute_dtype=torch.float16,
38
+ bnb_4bit_use_double_quant=True,
39
+ )
40
+
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ MODEL_NAME,
43
+ quantization_config=quantization_config,
44
+ device_map="auto",
45
+ trust_remote_code=True,
46
+ torch_dtype=torch.float16,
47
+ low_cpu_mem_usage=True,
48
+ )
49
+ else:
50
+ model = AutoModelForCausalLM.from_pretrained(
51
+ MODEL_NAME,
52
+ trust_remote_code=True,
53
+ torch_dtype=torch.float32,
54
+ low_cpu_mem_usage=True,
55
+ )
56
+ model = model.to(DEVICE)
57
+
58
+ print(f"Model loaded successfully on {DEVICE}")
59
+ return True
60
+
61
+ except Exception as e:
62
+ print(f"Error loading model: {e}")
63
+ return False
64
+
65
+ # モデルの初期化
66
+ print("Loading model...")
67
+ model_loaded = load_model()
68
+
69
+ @spaces.GPU(duration=120) # 最大120秒のGPU使用
70
+ def generate_response(
71
+ message,
72
+ history,
73
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
74
+ temperature=0.7,
75
+ max_new_tokens=512,
76
+ top_k=50,
77
+ top_p=0.9,
78
+ repetition_penalty=1.1
79
+ ):
80
+ """メッセージに対する応答を生成"""
81
+
82
+ if not model_loaded:
83
+ return "モデルの読み込みに失敗しました。"
84
+
85
+ try:
86
+ # 会話履歴を構築
87
+ conversation = []
88
+
89
+ # システムプロンプトを追加
90
+ if system_prompt.strip():
91
+ conversation.append({"role": "system", "content": system_prompt})
92
+
93
+ # 履歴を追加
94
+ for user_msg, assistant_msg in history:
95
+ conversation.append({"role": "user", "content": user_msg})
96
+ if assistant_msg:
97
+ conversation.append({"role": "assistant", "content": assistant_msg})
98
+
99
+ # 現在のメッセージを追加
100
+ conversation.append({"role": "user", "content": message})
101
+
102
+ # トークナイザーでテキストを変換
103
+ input_text = tokenizer.apply_chat_template(
104
+ conversation,
105
+ tokenize=False,
106
+ add_generation_prompt=True
107
+ )
108
+
109
+ # トークン化
110
+ inputs = tokenizer(input_text, return_tensors="pt").to(DEVICE)
111
+
112
+ # 生成設定
113
+ generation_kwargs = {
114
+ "input_ids": inputs["input_ids"],
115
+ "attention_mask": inputs["attention_mask"],
116
+ "max_new_tokens": max_new_tokens,
117
+ "temperature": temperature,
118
+ "top_k": top_k,
119
+ "top_p": top_p,
120
+ "repetition_penalty": repetition_penalty,
121
+ "do_sample": True,
122
+ "pad_token_id": tokenizer.eos_token_id,
123
+ "eos_token_id": tokenizer.eos_token_id,
124
+ }
125
+
126
+ # 応答生成
127
+ with torch.no_grad():
128
+ outputs = model.generate(**generation_kwargs)
129
+
130
+ # 応答をデコード
131
+ response = tokenizer.decode(
132
+ outputs[0][len(inputs["input_ids"][0]):],
133
+ skip_special_tokens=True
134
+ ).strip()
135
+
136
+ return response
137
+
138
+ except Exception as e:
139
+ return f"エラーが発生しました: {str(e)}"
140
+
141
+ def clear_chat():
142
+ """チャット履歴をクリア"""
143
+ return [], ""
144
+
145
+ # Gradio インターフェースの作成
146
+ with gr.Blocks(
147
+ title="TeenEmo Reasoning v2 - 10代向けカウンセリングAI",
148
+ theme=gr.themes.Soft(),
149
+ css=""".gradio-container {max-width: 1000px; margin: auto;}"""
150
+ ) as demo:
151
+
152
+ gr.HTML("""
153
+ <div style="text-align: center; padding: 20px;">
154
+ <h1>🤖 TeenEmo Reasoning v2</h1>
155
+ <h3>10代向けカウンセリン���AIチャットボット</h3>
156
+ <p>このAIは、10代の若者の感情的な悩みや課題に寄り添い、思いやりのあるサポートを提供するように設計されています。</p>
157
+ <p><strong>機能:</strong> 感情理解、共感的対話、建設的なアドバイス</p>
158
+ <p><em>※ このAIは補助的なツールです。深刻な問題については専門家にご相談ください。</em></p>
159
+ </div>
160
+ """)
161
+
162
+ with gr.Row():
163
+ with gr.Column(scale=3):
164
+ # メインのチャットインターフェース
165
+ chatbot = gr.Chatbot(
166
+ height=500,
167
+ placeholder="👋 こんにちは!何でも気軽に話しかけてくださいね。あなたの気持ちに寄り添います。",
168
+ show_label=False,
169
+ container=True,
170
+ bubble_full_width=False
171
+ )
172
+
173
+ with gr.Row():
174
+ msg = gr.Textbox(
175
+ placeholder="メッセージを入力してください...",
176
+ container=False,
177
+ scale=4,
178
+ show_label=False
179
+ )
180
+ submit_btn = gr.Button("送信", variant="primary", scale=1)
181
+ clear_btn = gr.Button("クリア", variant="secondary", scale=1)
182
+
183
+ with gr.Column(scale=1, min_width=300):
184
+ # 高度な設定
185
+ with gr.Accordion("⚙️ 高度な設定", open=False):
186
+ system_prompt = gr.Textbox(
187
+ label="システムプロンプト",
188
+ value=DEFAULT_SYSTEM_PROMPT,
189
+ lines=4,
190
+ placeholder="AIの役割や行動指針を設定してください"
191
+ )
192
+
193
+ with gr.Group():
194
+ temperature = gr.Slider(
195
+ minimum=0.1,
196
+ maximum=2.0,
197
+ value=0.7,
198
+ step=0.1,
199
+ label="Temperature (創造性)",
200
+ info="高いほど創造的、低いほど一貫性のある応答"
201
+ )
202
+
203
+ max_new_tokens = gr.Slider(
204
+ minimum=50,
205
+ maximum=1024,
206
+ value=512,
207
+ step=50,
208
+ label="最大新規トークン数",
209
+ info="応答の最大長を制御"
210
+ )
211
+
212
+ top_k = gr.Slider(
213
+ minimum=1,
214
+ maximum=100,
215
+ value=50,
216
+ step=1,
217
+ label="Top K サンプリング",
218
+ info="選択候補の語彙数を制限"
219
+ )
220
+
221
+ top_p = gr.Slider(
222
+ minimum=0.1,
223
+ maximum=1.0,
224
+ value=0.9,
225
+ step=0.05,
226
+ label="Top P サンプリング",
227
+ info="累積確率でトークンを選択"
228
+ )
229
+
230
+ repetition_penalty = gr.Slider(
231
+ minimum=1.0,
232
+ maximum=2.0,
233
+ value=1.1,
234
+ step=0.05,
235
+ label="繰り返しペナルティ",
236
+ info="同じ表現の繰り返しを抑制"
237
+ )
238
+
239
+ reset_settings = gr.Button("設定をリセット", variant="secondary")
240
+
241
+ # サンプル質問
242
+ with gr.Row():
243
+ gr.Examples(
244
+ examples=[
245
+ ["最近、学校で友達関係に悩んでいます。どうすればいいでしょうか?"],
246
+ ["将来について不安で眠れません。"],
247
+ ["親との関係がうまくいかなくて困っています。"],
248
+ ["勉強のプレッシャーでストレスを感じています。"],
249
+ ["自分に自信が持てません。どうしたら自信をつけられますか?"]
250
+ ],
251
+ inputs=msg,
252
+ label="💬 サンプル質問(クリックして試してみてください)"
253
+ )
254
+
255
+ # 注意事項
256
+ gr.HTML("""
257
+ <div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-top: 20px;">
258
+ <h4>🔔 ご利用上の注意</h4>
259
+ <ul>
260
+ <li>このAIは研究・学習目的で開発されたプロトタイプです</li>
261
+ <li>深刻な心理的問題や危機的状況では、専門家や信頼できる大人にご相談ください</li>
262
+ <li>緊急時は適切な支援機関(いのちの電話: 0570-783-556など)にご連絡ください</li>
263
+ <li>生成される応答の正確性は保証されません</li>
264
+ </ul>
265
+ </div>
266
+ """)
267
+
268
+ # イベントハンドラーの設定
269
+ def respond(message, chat_history, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty):
270
+ if not message.strip():
271
+ return chat_history, ""
272
+
273
+ # 応答を生成
274
+ bot_message = generate_response(
275
+ message,
276
+ chat_history,
277
+ system_prompt,
278
+ temperature,
279
+ max_new_tokens,
280
+ top_k,
281
+ top_p,
282
+ repetition_penalty
283
+ )
284
+
285
+ # チャット履歴に追加
286
+ chat_history.append((message, bot_message))
287
+ return chat_history, ""
288
+
289
+ def reset_advanced_settings():
290
+ return (
291
+ DEFAULT_SYSTEM_PROMPT,
292
+ 0.7, # temperature
293
+ 512, # max_new_tokens
294
+ 50, # top_k
295
+ 0.9, # top_p
296
+ 1.1 # repetition_penalty
297
+ )
298
+
299
+ # イベントの接続
300
+ submit_btn.click(
301
+ respond,
302
+ inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty],
303
+ outputs=[chatbot, msg]
304
+ )
305
+
306
+ msg.submit(
307
+ respond,
308
+ inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty],
309
+ outputs=[chatbot, msg]
310
+ )
311
+
312
+ clear_btn.click(clear_chat, outputs=[chatbot, msg])
313
+
314
+ reset_settings.click(
315
+ reset_advanced_settings,
316
+ outputs=[system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty]
317
+ )
318
+
319
+ # アプリケーションの起動
320
+ if __name__ == "__main__":
321
+ demo.queue(max_size=10).launch(
322
+ server_name="0.0.0.0",
323
+ server_port=7860,
324
+ share=False,
325
+ show_error=True
326
+ )
requirements (2).txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ gradio==5.35.0
3
+ torch>=2.4.0
4
+ transformers>=4.48.0
5
+ accelerate>=1.8.0
6
+ huggingface-hub>=0.33.0
7
+
8
+ # GPU optimization
9
+ bitsandbytes>=0.43.0
10
+ spaces>=0.28.3
11
+
12
+ # Additional utilities
13
+ numpy>=1.24.0
14
+ sentencepiece>=0.1.99
15
+ protobuf>=3.20.0
16
+ typing-extensions>=4.8.0
17
+
18
+ # Optional performance enhancements
19
+ optimum[onnxruntime]>=1.17.0