CaioMartins1 committed
Commit c64f408 · 1 Parent(s): d5cc24a

Update app.py

Files changed (1)
  1. app.py +18 -11
app.py CHANGED
@@ -1,18 +1,17 @@
 import gradio as gr
-from transformers import pipeline
+from transformers import pipeline, BertTokenizer, BertForQuestionAnswering
 from datasets import load_dataset
 
 # Load the dataset
 advice_dataset = load_dataset("ziq/depression_advice")
-depression_dataset = load_dataset("ShreyaR/DepressionDetection")
 
-# Load the model
-model_name = "mrm8488/distilroberta-base-finetuned-suicide-depression"
-qa_pipeline = pipeline("question-answering", model=model_name)
+# Load the fine-tuned BERT model and tokenizer
+model_dir = "./bert-finetuned-depression"
+model = BertForQuestionAnswering.from_pretrained(model_dir)
+tokenizer = BertTokenizer.from_pretrained(model_dir)
 
 # Extract context and messages
 contexts = advice_dataset["train"]["text"]
-messages = depression_dataset["train"]["clean_text"]
 
 # Define a function to generate answers
 def generate_answer(messages):
@@ -20,11 +19,19 @@ def generate_answer(messages):
     if isinstance(messages, list):
         messages = messages[0]
 
-    # Use the QA model to generate the answer for the single message
-    results = qa_pipeline(question=messages, context=contexts)
-
-    # Return the answer
-    return results["answer"] if results["answer"] else "No answer found."
+    # Tokenize the input message
+    inputs = tokenizer(messages, return_tensors="pt")
+
+    # Use the fine-tuned BERT model to generate the answer for the single message
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Decode the output and return the answer
+    answer_start = torch.argmax(outputs.start_logits)
+    answer_end = torch.argmax(outputs.end_logits) + 1
+    answer = tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end])
+
+    return answer if answer else "No answer found."
 
 # Create a Gradio interface
 iface = gr.Interface(
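Note that the committed code uses torch without importing it, and tokenizes only the incoming message, so the extractive span is picked out of the question itself rather than out of the advice contexts. Below is a minimal sketch of the standard extractive-QA pattern this code follows, assuming the local checkpoint ./bert-finetuned-depression exists; the answer_question helper and the Gradio wiring at the end are illustrative, not part of the commit.

import gradio as gr
import torch
from transformers import BertForQuestionAnswering, BertTokenizer
from datasets import load_dataset

# Dataset and checkpoint path taken from the commit above
advice_dataset = load_dataset("ziq/depression_advice")
contexts = advice_dataset["train"]["text"]

model_dir = "./bert-finetuned-depression"  # local fine-tuned checkpoint assumed to exist
model = BertForQuestionAnswering.from_pretrained(model_dir)
tokenizer = BertTokenizer.from_pretrained(model_dir)

def answer_question(question, context):
    # Extractive QA encodes question and context as one sequence:
    # [CLS] question [SEP] context [SEP]
    inputs = tokenizer(question, context, return_tensors="pt",
                       truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # The model scores every token as a possible span start/end;
    # take the argmax of each and decode the tokens in between
    start = torch.argmax(outputs.start_logits)
    end = torch.argmax(outputs.end_logits) + 1
    answer = tokenizer.decode(inputs["input_ids"][0][start:end],
                              skip_special_tokens=True)
    return answer if answer else "No answer found."

# Hypothetical wiring: answer against the first advice text as context
iface = gr.Interface(fn=lambda msg: answer_question(msg, contexts[0]),
                     inputs="text", outputs="text")

Decoding with skip_special_tokens=True avoids returning bare [CLS]/[SEP] tokens when the argmax span is empty or degenerate.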