|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
import spaces |
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
MODEL_NAME = "YUGOROU/TeenEmo-Reasoning-v2" |
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
DEFAULT_SYSTEM_PROMPT = """あなたは思いやりのあるカウンセラーです。10代の若者の感情や悩みに寄り添い、親身になって話を聞いてください。適切なアドバイスを提供し、ポジティブな視点を提供してください。""" |
|
|
|
|
|
model = None |
|
tokenizer = None |
|
|
|
def load_model(): |
|
"""モデルとトークナイザーの読み込み""" |
|
global model, tokenizer |
|
|
|
try: |
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
MODEL_NAME, |
|
trust_remote_code=True, |
|
use_fast=True |
|
) |
|
|
|
|
|
if DEVICE == "cuda": |
|
|
|
quantization_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_quant_type="nf4", |
|
bnb_4bit_compute_dtype=torch.float16, |
|
bnb_4bit_use_double_quant=True, |
|
) |
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_NAME, |
|
quantization_config=quantization_config, |
|
device_map="auto", |
|
trust_remote_code=True, |
|
torch_dtype=torch.float16, |
|
low_cpu_mem_usage=True, |
|
) |
|
else: |
|
model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_NAME, |
|
trust_remote_code=True, |
|
torch_dtype=torch.float16, |
|
low_cpu_mem_usage=True, |
|
) |
|
|
|
print(f"Model loaded successfully on {DEVICE}") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"Error loading model: {e}") |
|
return False |
|
|
|
|
|
print("Loading model...") |
|
model_loaded = load_model() |
|
|
|
@spaces.GPU(duration=120) |
|
def generate_response( |
|
message, |
|
history, |
|
system_prompt=DEFAULT_SYSTEM_PROMPT, |
|
temperature=1.3, |
|
max_new_tokens=512, |
|
top_k=50, |
|
top_p=0.9, |
|
repetition_penalty=1.1 |
|
): |
|
"""メッセージに対する応答を生成""" |
|
|
|
if not model_loaded: |
|
return "モデルの読み込みに失敗しました。" |
|
|
|
try: |
|
|
|
conversation = [] |
|
|
|
|
|
if system_prompt.strip(): |
|
conversation.append({"role": "system", "content": system_prompt}) |
|
|
|
|
|
for user_msg, assistant_msg in history: |
|
conversation.append({"role": "user", "content": user_msg}) |
|
if assistant_msg: |
|
conversation.append({"role": "assistant", "content": assistant_msg}) |
|
|
|
|
|
conversation.append({"role": "user", "content": message}) |
|
|
|
|
|
input_text = tokenizer.apply_chat_template( |
|
conversation, |
|
tokenize=False, |
|
add_generation_prompt=True |
|
) |
|
|
|
|
|
inputs = tokenizer(input_text, return_tensors="pt").to(DEVICE) |
|
|
|
|
|
generation_kwargs = { |
|
"input_ids": inputs["input_ids"], |
|
"attention_mask": inputs["attention_mask"], |
|
"max_new_tokens": max_new_tokens, |
|
"temperature": temperature, |
|
"top_k": top_k, |
|
"top_p": top_p, |
|
"repetition_penalty": repetition_penalty, |
|
"do_sample": True, |
|
"pad_token_id": tokenizer.eos_token_id, |
|
"eos_token_id": tokenizer.eos_token_id, |
|
} |
|
|
|
|
|
with torch.no_grad(): |
|
outputs = model.generate(**generation_kwargs) |
|
|
|
|
|
response = tokenizer.decode( |
|
outputs[0][len(inputs["input_ids"][0]):], |
|
skip_special_tokens=True |
|
).strip() |
|
|
|
return response |
|
|
|
except Exception as e: |
|
return f"エラーが発生しました: {str(e)}" |
|
|
|
def clear_chat(): |
|
"""チャット履歴をクリア""" |
|
return [], "" |
|
|
|
|
|
with gr.Blocks( |
|
title="TeenEmo Reasoning v2 - 10代向けカウンセリングAI", |
|
theme=gr.themes.Soft(), |
|
css=""".gradio-container {max-width: 1000px; margin: auto;}""" |
|
) as demo: |
|
|
|
gr.HTML(""" |
|
<div style="text-align: center; padding: 20px;"> |
|
<h1>🤖 TeenEmo Reasoning v2</h1> |
|
<h3>10代向けカウンセリングAIチャットボット</h3> |
|
<p>このAIは、10代の若者の感情的な悩みや課題に寄り添い、思いやりのあるサポートを提供するように設計されています。</p> |
|
<p><strong>機能:</strong> 感情理解、共感的対話、建設的なアドバイス</p> |
|
<p><em>※ このAIは補助的なツールです。深刻な問題については専門家にご相談ください。</em></p> |
|
</div> |
|
""") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=3): |
|
|
|
chatbot = gr.Chatbot( |
|
height=500, |
|
placeholder="👋 こんにちは!何でも気軽に話しかけてくださいね。あなたの気持ちに寄り添います。", |
|
show_label=False, |
|
container=True, |
|
bubble_full_width=False |
|
) |
|
|
|
with gr.Row(): |
|
msg = gr.Textbox( |
|
placeholder="メッセージを入力してください...", |
|
container=False, |
|
scale=4, |
|
show_label=False |
|
) |
|
submit_btn = gr.Button("送信", variant="primary", scale=1) |
|
clear_btn = gr.Button("クリア", variant="secondary", scale=1) |
|
|
|
with gr.Column(scale=1, min_width=300): |
|
|
|
with gr.Accordion("⚙️ 高度な設定", open=False): |
|
system_prompt = gr.Textbox( |
|
label="システムプロンプト", |
|
value=DEFAULT_SYSTEM_PROMPT, |
|
lines=4, |
|
placeholder="AIの役割や行動指針を設定してください" |
|
) |
|
|
|
with gr.Group(): |
|
temperature = gr.Slider( |
|
minimum=0.1, |
|
maximum=2.0, |
|
value=1.3, |
|
step=0.1, |
|
label="Temperature (創造性)", |
|
info="高いほど創造的、低いほど一貫性のある応答" |
|
) |
|
|
|
max_new_tokens = gr.Slider( |
|
minimum=50, |
|
maximum=1024, |
|
value=512, |
|
step=50, |
|
label="最大新規トークン数", |
|
info="応答の最大長を制御" |
|
) |
|
|
|
top_k = gr.Slider( |
|
minimum=1, |
|
maximum=100, |
|
value=50, |
|
step=1, |
|
label="Top K サンプリング", |
|
info="選択候補の語彙数を制限" |
|
) |
|
|
|
top_p = gr.Slider( |
|
minimum=0.1, |
|
maximum=1.0, |
|
value=0.9, |
|
step=0.05, |
|
label="Top P サンプリング", |
|
info="累積確率でトークンを選択" |
|
) |
|
|
|
repetition_penalty = gr.Slider( |
|
minimum=1.0, |
|
maximum=2.0, |
|
value=1.1, |
|
step=0.05, |
|
label="繰り返しペナルティ", |
|
info="同じ表現の繰り返しを抑制" |
|
) |
|
|
|
reset_settings = gr.Button("設定をリセット", variant="secondary") |
|
|
|
|
|
with gr.Row(): |
|
gr.Examples( |
|
examples=[ |
|
["最近、学校で友達関係に悩んでいます。どうすればいいでしょうか?"], |
|
["将来について不安で眠れません。"], |
|
["親との関係がうまくいかなくて困っています。"], |
|
["勉強のプレッシャーでストレスを感じています。"], |
|
["自分に自信が持てません。どうしたら自信をつけられますか?"] |
|
], |
|
inputs=msg, |
|
label="💬 サンプル質問(クリックして試してみてください)" |
|
) |
|
|
|
|
|
gr.HTML(""" |
|
<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-top: 20px;"> |
|
<h4>🔔 ご利用上の注意</h4> |
|
<ul> |
|
<li>このAIは研究・学習目的で開発されたプロトタイプです</li> |
|
<li>深刻な心理的問題や危機的状況では、専門家や信頼できる大人にご相談ください</li> |
|
<li>緊急時は適切な支援機関(いのちの電話: 0570-783-556など)にご連絡ください</li> |
|
<li>生成される応答の正確性は保証されません</li> |
|
</ul> |
|
</div> |
|
""") |
|
|
|
|
|
def respond(message, chat_history, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty): |
|
if not message.strip(): |
|
return chat_history, "" |
|
|
|
|
|
bot_message = generate_response( |
|
message, |
|
chat_history, |
|
system_prompt, |
|
temperature, |
|
max_new_tokens, |
|
top_k, |
|
top_p, |
|
repetition_penalty |
|
) |
|
|
|
|
|
chat_history.append((message, bot_message)) |
|
return chat_history, "" |
|
|
|
def reset_advanced_settings(): |
|
return ( |
|
DEFAULT_SYSTEM_PROMPT, |
|
0.7, |
|
512, |
|
50, |
|
0.9, |
|
1.1 |
|
) |
|
|
|
|
|
submit_btn.click( |
|
respond, |
|
inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty], |
|
outputs=[chatbot, msg] |
|
) |
|
|
|
msg.submit( |
|
respond, |
|
inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty], |
|
outputs=[chatbot, msg] |
|
) |
|
|
|
clear_btn.click(clear_chat, outputs=[chatbot, msg]) |
|
|
|
reset_settings.click( |
|
reset_advanced_settings, |
|
outputs=[system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty] |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
demo.queue(max_size=10).launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
share=False, |
|
show_error=True |
|
) |