KomalNaseem committed on
Commit
aa5c4a9
·
verified ·
1 Parent(s): 392653e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
5
# Fine-tuned causal LM used for both chatting and the depression estimate.
model_id = "komal/depression-chatbot-model"  # replace with your model path

# Run on GPU when one is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load tokenizer and model once at startup; the model is inference-only.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model.eval()

# Module-level conversation state.
# NOTE(review): this state is shared by every visitor of the app — confirm
# whether per-session history is intended before deploying multi-user.
user_history = []
turn_counter = 0
MAX_TURNS_FOR_PREDICTION = 8
16
+
17
def chat(user_input):
    """Generate one chatbot reply; after the 8th turn, also estimate depression.

    Appends the user's message to the shared ``user_history``, prompts the
    model with the last few turns, and returns the model's reply. Exactly on
    turn ``MAX_TURNS_FOR_PREDICTION`` it asks the model for a probability
    (parsed as a decimal in [0, 1]); on every other turn the probability
    slot is ``None``.

    Parameters
    ----------
    user_input : str
        The user's latest chat message.

    Returns
    -------
    tuple[str, float | None]
        (bot response, depression probability or None).
    """
    # NOTE(review): globals are shared across all users/sessions of the app;
    # fine for a single-user demo, not for concurrent visitors.
    global user_history, turn_counter
    turn_counter += 1

    user_history.append(f"Human: {user_input}")
    # Keep the prompt short: only the last 4 history lines (≈2 exchanges).
    prompt = "\n".join(user_history[-4:]) + "\nAI:"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Greedy decoding; no_grad avoids building an autograd graph at inference.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    response_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the last "AI:".
    response = response_text.split("AI:")[-1].strip()
    user_history.append(f"AI: {response}")

    depression_prob = None
    if turn_counter == MAX_TURNS_FOR_PREDICTION:
        prediction_prompt = (
            "\n".join(user_history[-8:]) +
            "\nAI: Based on this conversation, what is the probability that the human has depression? "
            "Please answer with a number between 0 and 1."
        )
        inputs_pred = tokenizer(prediction_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output_pred_ids = model.generate(
                **inputs_pred,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        pred_text = tokenizer.decode(output_pred_ids[0], skip_special_tokens=True)

        # Pull the first decimal like ".73" or "0.73" out of the free-form reply.
        match = re.search(r"0?\.\d+", pred_text)
        if match:
            try:
                depression_prob = float(match.group(0))
            except ValueError:  # narrow: was a bare `except:` swallowing everything
                depression_prob = None

    return response, depression_prob
63
+
64
# Wire the chat function into a simple two-output Gradio UI: the reply plus
# the (usually empty) depression-probability field.
chat_interface = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(label="User Input"),
    outputs=[
        gr.Textbox(label="Bot Response"),
        gr.Textbox(label="Depression Probability (after 8 turns)"),
    ],
    title="Depression Detection Chatbot",
    description=(
        "This chatbot detects depression by conversing for multiple turns. "
        "After 8 turns, it estimates the probability of depression."
    ),
)

# Launch only when executed as a script (Gradio's documented entry-point
# pattern); importing this module no longer starts a server as a side effect.
if __name__ == "__main__":
    chat_interface.launch()