# mindvridge / app.py
# MindVR — Update app.py (commit 9e8233a, verified)
import os
import torch
from typing import List
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from transformers import AutoTokenizer as SummarizerTokenizer, AutoModelForSeq2SeqLM
# Compute device for inference; prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Summarization model: BART fine-tuned on CNN/DailyMail, used to compress
# long chat histories before they are fed to the chat model.
summarizer_model_id = "facebook/bart-large-cnn"
summarizer_tokenizer = SummarizerTokenizer.from_pretrained(summarizer_model_id)
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(
    summarizer_model_id,
    # fp16 is only reliably supported on GPU; fall back to fp32 on CPU,
    # where half-precision ops are slow or unsupported.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto",
)
# NOTE: with device_map="auto" accelerate has already placed the weights;
# calling .to(device) on a dispatched model raises a RuntimeError, so we
# deliberately do NOT move the model here (the original code did, which fails).
def summarize_text(text: str, max_length: int = 150) -> str:
    """Abstractively summarize ``text`` with the BART summarizer.

    Args:
        text: Arbitrary input text; tokenized with truncation at 1024 tokens
            (BART's encoder limit).
        max_length: Maximum length (in tokens) of the generated summary.

    Returns:
        The decoded summary string, special tokens stripped.
    """
    inputs = summarizer_tokenizer(
        [text], return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        summary_ids = summarizer_model.generate(
            inputs["input_ids"],
            # Pass the attention mask explicitly so generate() does not have to
            # infer it from pad tokens (avoids a warning and edge-case errors).
            attention_mask=inputs["attention_mask"],
            num_beams=4,
            max_length=max_length,
            early_stopping=True,
        )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Authenticate with the Hugging Face Hub when a token is provided so the
# large (possibly gated) checkpoint below can be downloaded.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Chat model: DeepSeek-R1 distilled into a Llama-70B backbone.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # fp16 only makes sense on GPU; use fp32 on CPU where half precision
    # is slow/unsupported (same policy as the summarizer above).
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",       # let accelerate shard/place the weights
    low_cpu_mem_usage=True,  # stream weights in instead of materializing twice
    token=HF_TOKEN,
)
def build_prompt(prompt: str, histories: List[str], new_message: str) -> str:
    """Assemble the full text prompt sent to the chat model.

    Layout: optional system prompt, then the chat history (one message per
    line), then ``User: <new_message>`` followed by an ``AI:`` cue for the
    model to complete. Histories longer than 3000 characters are compressed
    with summarize_text so the context stays manageable.
    """
    parts: List[str] = []
    if prompt:
        parts.append(prompt.strip() + "\n")
    history_text = "\n".join(histories) if histories else ""
    # Summarize oversized histories (threshold is tunable).
    if len(history_text) > 3000:
        history_text = summarize_text(history_text, max_length=180)
    if history_text:
        parts.append(history_text + "\n")
    parts.append(f"User: {new_message}\nAI:")
    return "".join(parts)
def chat(
    prompt: str,
    histories: List[str],
    new_message: str
) -> str:
    """Generate one assistant reply for the given system prompt, history and
    new user message, using sampling (top-p 0.95, temperature 0.7)."""
    full_prompt = build_prompt(prompt, histories, new_message)
    encoded = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        generated = model.generate(
            encoded,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text echoes the whole prompt; keep only what follows the
    # final "AI:" marker.
    if "AI:" in decoded:
        reply = decoded.split("AI:")[-1].strip()
        # Trim any hallucinated follow-up user turn the model may append.
        if "User:" in reply:
            reply = reply.split("User:")[0].strip()
    else:
        reply = decoded.strip()
    return reply
# Gradio UI: three textboxes (system prompt, multiline history, new message)
# wired through _chat_ui into chat(); also exposes chat() as a REST endpoint.
with gr.Blocks() as demo:
    gr.Markdown("# MindVR Therapy Chatbot\n\nDùng thử UI hoặc gọi API!")
    prompt_box = gr.Textbox(lines=2, label="Prompt (System Prompt, chỉ dẫn context cho AI, có thể bỏ trống)")
    histories_box = gr.Textbox(lines=8, label="Histories (mỗi dòng là một message, ví dụ: User: Xin chào)")
    new_message_box = gr.Textbox(label="New message")
    output = gr.Textbox(label="AI Response")
    def _chat_ui(prompt, histories, new_message):
        # The UI sends histories as a single multiline string -> split into a
        # list of messages, dropping blank lines, before calling chat().
        histories_list = [line.strip() for line in histories.split('\n') if line.strip()]
        return chat(prompt, histories_list, new_message)
    btn = gr.Button("Gửi")
    btn.click(_chat_ui, inputs=[prompt_box, histories_box, new_message_box], outputs=output)
    # REST-style API endpoint taking prompt, histories, new_message.
    gr.api(chat, api_name="chat_ai")
demo.launch()