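"""Gradio app for a depression-detection chatbot.

Wraps a fine-tuned causal language model from the Hugging Face Hub
(KomalNaseem/depression-chatbot-model). It chats with the user and, on the
eighth user turn, prompts the model for a rough depression probability.
"""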
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login  # For authenticating with the Hugging Face Hub
import torch
import os
import re

# Authenticate with the Hub using the HF_TOKEN secret from the environment
login(token=os.getenv("HF_TOKEN"))

model_path = "KomalNaseem/depression-chatbot-model"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned model and tokenizer, then switch to inference mode
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
model.eval()

# Conversation state shared across requests (a single global session)
user_history = []
turn_counter = 0
MAX_TURNS_FOR_PREDICTION = 8


def chat(user_input):
    global user_history, turn_counter
    turn_counter += 1
    user_history.append(f"Human: {user_input}")

    # Build the prompt from the most recent turns, e.g.
    # "Human: ...\nAI: ...\nHuman: ...\nAI:"
    last_turns = user_history[-4:]
    prompt = "\n".join(last_turns) + "\nAI:"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # The decoded text includes the prompt; keep only what follows the last "AI:"
    response_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = response_text.split("AI:")[-1].strip()
    user_history.append(f"AI: {response}")

    depression_prob = None
    if turn_counter == MAX_TURNS_FOR_PREDICTION:
        # On the 8th turn, ask the model once for a depression probability
        prediction_prompt = (
            "\n".join(user_history[-8:]) +
            "\nAI: Based on this conversation, what is the probability that the human has depression? "
            "Please answer with a number between 0 and 1."
        )
        inputs_pred = tokenizer(prediction_prompt, return_tensors="pt").to(device)
        output_pred_ids = model.generate(
            **inputs_pred,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens so the regex cannot match a
        # number that already appears in the conversation history
        new_tokens = output_pred_ids[0][inputs_pred["input_ids"].shape[1]:]
        pred_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        match = re.search(r"0?\.\d+", pred_text)
        if match:
            try:
                depression_prob = float(match.group(0))
            except ValueError:
                depression_prob = None

    return response, depression_prob if depression_prob is not None else "Prediction available after 8 turns"


# Gradio UI: one text input, two text outputs (reply + probability)
iface = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(lines=2, label="Your Message"),
    outputs=[gr.Textbox(label="AI Response"), gr.Textbox(label="Depression Probability")],
    title="Depression Detection Chatbot",
    description="Chat with the AI. After 8 turns it predicts a depression probability.",
)

if __name__ == "__main__":
    iface.launch()
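
# To run locally (a sketch, assuming this file is saved as app.py and HF_TOKEN is
# set in the environment with access to the model repo):
#   pip install gradio transformers torch huggingface_hub
#   python app.py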