YUGOROU's picture
Update app.py
17185df verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import spaces
import warnings
warnings.filterwarnings("ignore")
# モデルとトークナイザーの設定
MODEL_NAME = "YUGOROU/TeenEmo-Reasoning-v2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# デフォルト設定
DEFAULT_SYSTEM_PROMPT = """あなたは思いやりのあるカウンセラーです。10代の若者の感情や悩みに寄り添い、親身になって話を聞いてください。適切なアドバイスを提供し、ポジティブな視点を提供してください。"""
# Global variables for model and tokenizer
model = None
tokenizer = None
def load_model():
"""モデルとトークナイザーの読み込み"""
global model, tokenizer
try:
# トークナイザーの読み込み
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
use_fast=True
)
# モデルの読み込み(GPU使用時は量子化を使用)
if DEVICE == "cuda":
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)
print(f"Model loaded successfully on {DEVICE}")
return True
except Exception as e:
print(f"Error loading model: {e}")
return False
# モデルの初期化
print("Loading model...")
model_loaded = load_model()
@spaces.GPU(duration=120) # 最大120秒のGPU使用
def generate_response(
message,
history,
system_prompt=DEFAULT_SYSTEM_PROMPT,
temperature=1.3,
max_new_tokens=512,
top_k=50,
top_p=0.9,
repetition_penalty=1.1
):
"""メッセージに対する応答を生成"""
if not model_loaded:
return "モデルの読み込みに失敗しました。"
try:
# 会話履歴を構築
conversation = []
# システムプロンプトを追加
if system_prompt.strip():
conversation.append({"role": "system", "content": system_prompt})
# 履歴を追加
for user_msg, assistant_msg in history:
conversation.append({"role": "user", "content": user_msg})
if assistant_msg:
conversation.append({"role": "assistant", "content": assistant_msg})
# 現在のメッセージを追加
conversation.append({"role": "user", "content": message})
# トークナイザーでテキストを変換
input_text = tokenizer.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True
)
# トークン化
inputs = tokenizer(input_text, return_tensors="pt").to(DEVICE)
# 生成設定
generation_kwargs = {
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": tokenizer.eos_token_id,
}
# 応答生成
with torch.no_grad():
outputs = model.generate(**generation_kwargs)
# 応答をデコード
response = tokenizer.decode(
outputs[0][len(inputs["input_ids"][0]):],
skip_special_tokens=True
).strip()
return response
except Exception as e:
return f"エラーが発生しました: {str(e)}"
def clear_chat():
"""チャット履歴をクリア"""
return [], ""
# Gradio インターフェースの作成
with gr.Blocks(
title="TeenEmo Reasoning v2 - 10代向けカウンセリングAI",
theme=gr.themes.Soft(),
css=""".gradio-container {max-width: 1000px; margin: auto;}"""
) as demo:
gr.HTML("""
<div style="text-align: center; padding: 20px;">
<h1>🤖 TeenEmo Reasoning v2</h1>
<h3>10代向けカウンセリングAIチャットボット</h3>
<p>このAIは、10代の若者の感情的な悩みや課題に寄り添い、思いやりのあるサポートを提供するように設計されています。</p>
<p><strong>機能:</strong> 感情理解、共感的対話、建設的なアドバイス</p>
<p><em>※ このAIは補助的なツールです。深刻な問題については専門家にご相談ください。</em></p>
</div>
""")
with gr.Row():
with gr.Column(scale=3):
# メインのチャットインターフェース
chatbot = gr.Chatbot(
height=500,
placeholder="👋 こんにちは!何でも気軽に話しかけてくださいね。あなたの気持ちに寄り添います。",
show_label=False,
container=True,
bubble_full_width=False
)
with gr.Row():
msg = gr.Textbox(
placeholder="メッセージを入力してください...",
container=False,
scale=4,
show_label=False
)
submit_btn = gr.Button("送信", variant="primary", scale=1)
clear_btn = gr.Button("クリア", variant="secondary", scale=1)
with gr.Column(scale=1, min_width=300):
# 高度な設定
with gr.Accordion("⚙️ 高度な設定", open=False):
system_prompt = gr.Textbox(
label="システムプロンプト",
value=DEFAULT_SYSTEM_PROMPT,
lines=4,
placeholder="AIの役割や行動指針を設定してください"
)
with gr.Group():
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.3,
step=0.1,
label="Temperature (創造性)",
info="高いほど創造的、低いほど一貫性のある応答"
)
max_new_tokens = gr.Slider(
minimum=50,
maximum=1024,
value=512,
step=50,
label="最大新規トークン数",
info="応答の最大長を制御"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
value=50,
step=1,
label="Top K サンプリング",
info="選択候補の語彙数を制限"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="Top P サンプリング",
info="累積確率でトークンを選択"
)
repetition_penalty = gr.Slider(
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.05,
label="繰り返しペナルティ",
info="同じ表現の繰り返しを抑制"
)
reset_settings = gr.Button("設定をリセット", variant="secondary")
# サンプル質問
with gr.Row():
gr.Examples(
examples=[
["最近、学校で友達関係に悩んでいます。どうすればいいでしょうか?"],
["将来について不安で眠れません。"],
["親との関係がうまくいかなくて困っています。"],
["勉強のプレッシャーでストレスを感じています。"],
["自分に自信が持てません。どうしたら自信をつけられますか?"]
],
inputs=msg,
label="💬 サンプル質問(クリックして試してみてください)"
)
# 注意事項
gr.HTML("""
<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-top: 20px;">
<h4>🔔 ご利用上の注意</h4>
<ul>
<li>このAIは研究・学習目的で開発されたプロトタイプです</li>
<li>深刻な心理的問題や危機的状況では、専門家や信頼できる大人にご相談ください</li>
<li>緊急時は適切な支援機関(いのちの電話: 0570-783-556など)にご連絡ください</li>
<li>生成される応答の正確性は保証されません</li>
</ul>
</div>
""")
# イベントハンドラーの設定
def respond(message, chat_history, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty):
if not message.strip():
return chat_history, ""
# 応答を生成
bot_message = generate_response(
message,
chat_history,
system_prompt,
temperature,
max_new_tokens,
top_k,
top_p,
repetition_penalty
)
# チャット履歴に追加
chat_history.append((message, bot_message))
return chat_history, ""
def reset_advanced_settings():
return (
DEFAULT_SYSTEM_PROMPT,
0.7, # temperature
512, # max_new_tokens
50, # top_k
0.9, # top_p
1.1 # repetition_penalty
)
# イベントの接続
submit_btn.click(
respond,
inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty],
outputs=[chatbot, msg]
)
msg.submit(
respond,
inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty],
outputs=[chatbot, msg]
)
clear_btn.click(clear_chat, outputs=[chatbot, msg])
reset_settings.click(
reset_advanced_settings,
outputs=[system_prompt, temperature, max_new_tokens, top_k, top_p, repetition_penalty]
)
# アプリケーションの起動
if __name__ == "__main__":
demo.queue(max_size=10).launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)