Update app.py
Browse files
app.py
CHANGED
@@ -6,10 +6,14 @@ import joblib
|
|
6 |
import numpy as np
|
7 |
import librosa
|
8 |
import gradio as gr
|
|
|
|
|
|
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
from deepface import DeepFace
|
11 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
12 |
|
|
|
13 |
# --- 1. 下載並載入 SVM 模型 ---
|
14 |
# 這裡 repo_id 填你的模型倉庫路徑,例如 "GCLing/emotion-svm-model"
|
15 |
# filename 填上傳到該倉庫的檔案名,例如 "svm_emotion_model.joblib"
|
@@ -20,12 +24,21 @@ svm_model = joblib.load(model_path)
|
|
20 |
print("SVM model loaded.")
|
21 |
|
22 |
# --- 2. 載入文字情緒分析模型 ---
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
# --- 3. 聲音特徵擷取函式 ---
|
31 |
def extract_feature(signal: np.ndarray, sr: int) -> np.ndarray:
|
@@ -89,22 +102,44 @@ def predict_voice(audio_path: str):
|
|
89 |
|
90 |
|
91 |
|
92 |
-
def
|
93 |
-
|
|
|
|
|
|
|
94 |
if not text or text.strip() == "":
|
95 |
return {}
|
96 |
-
#
|
97 |
-
|
98 |
-
if
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
try:
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
104 |
return result
|
105 |
except Exception as e:
|
106 |
-
print("
|
107 |
-
return {}
|
|
|
|
|
|
|
108 |
|
109 |
|
110 |
# --- 5. 建立 Gradio 介面 ---
|
@@ -128,10 +163,11 @@ with gr.Blocks() as demo:
|
|
128 |
audio.change(fn=predict_voice, inputs=audio, outputs=audio_output)
|
129 |
|
130 |
with gr.TabItem("文字情緒"):
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
135 |
|
136 |
|
137 |
|
|
|
6 |
import numpy as np
|
7 |
import librosa
|
8 |
import gradio as gr
|
9 |
+
import time
|
10 |
+
import re
|
11 |
+
from transformers import pipeline
|
12 |
from huggingface_hub import hf_hub_download
|
13 |
from deepface import DeepFace
|
14 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
15 |
|
16 |
+
|
17 |
# --- 1. 下載並載入 SVM 模型 ---
|
18 |
# 這裡 repo_id 填你的模型倉庫路徑,例如 "GCLing/emotion-svm-model"
|
19 |
# filename 填上傳到該倉庫的檔案名,例如 "svm_emotion_model.joblib"
|
|
|
24 |
print("SVM model loaded.")
|
25 |
|
26 |
# --- 2. 載入文字情緒分析模型 ---
|
27 |
+
zero_shot = pipeline("zero-shot-classification", model="joeddav/xlm-roberta-large-xnli")
|
28 |
+
candidate_labels = ["joy", "sadness", "anger", "fear", "surprise", "disgust"]
|
29 |
+
label_map_en2cn = {
|
30 |
+
"joy": "高興", "sadness": "悲傷", "anger": "憤怒",
|
31 |
+
"fear": "恐懼", "surprise": "驚訝", "disgust": "厭惡"
|
32 |
+
}
|
33 |
+
emo_keywords = {
|
34 |
+
"happy": ["開心","快樂","愉快","喜悦","喜悅","歡喜","興奮","高興"],
|
35 |
+
"angry": ["生氣","憤怒","不爽","發火","火大","氣憤"],
|
36 |
+
"sad": ["傷心","難過","哭","難受","心酸","憂","悲","哀","痛苦","慘","愁"],
|
37 |
+
"surprise": ["驚訝","意外","嚇","驚詫","詫異","訝異","好奇"],
|
38 |
+
"fear": ["怕","恐懼","緊張","懼","膽怯","畏"]
|
39 |
+
}
|
40 |
+
# 简单否定词列表
|
41 |
+
negations = ["不","沒","沒有","別","勿","非"]
|
42 |
|
43 |
# --- 3. 聲音特徵擷取函式 ---
|
44 |
def extract_feature(signal: np.ndarray, sr: int) -> np.ndarray:
|
|
|
102 |
|
103 |
|
104 |
|
105 |
+
def predict_text_mixed(text: str):
|
106 |
+
"""
|
107 |
+
先用 keyword_emotion 规则;若未命中再用 zero-shot 分类,
|
108 |
+
返回 {中文标签: float_score} 的 dict,供 gr.Label 显示。
|
109 |
+
"""
|
110 |
if not text or text.strip() == "":
|
111 |
return {}
|
112 |
+
# 规则优先
|
113 |
+
res = keyword_emotion(text)
|
114 |
+
if res:
|
115 |
+
# 只返回最高那一项及其比例,也可返回完整分布
|
116 |
+
top_emo = max(res, key=res.get)
|
117 |
+
# 可将英文 key 转成中文,若需要
|
118 |
+
# mapping: happy->高兴, angry->愤怒, etc.
|
119 |
+
mapping = {
|
120 |
+
"happy": "高兴",
|
121 |
+
"angry": "愤怒",
|
122 |
+
"sad": "悲伤",
|
123 |
+
"surprise": "惊讶",
|
124 |
+
"fear": "恐惧"
|
125 |
+
}
|
126 |
+
cn = mapping.get(top_emo, top_emo)
|
127 |
+
return {cn: res[top_emo]}
|
128 |
+
# 规则未命中,zero-shot fallback
|
129 |
try:
|
130 |
+
out = zero_shot(text, candidate_labels=candidate_labels,
|
131 |
+
hypothesis_template="这句话表达了{}情绪")
|
132 |
+
result = {}
|
133 |
+
for lab, sc in zip(out["labels"], out["scores"]):
|
134 |
+
cn = label_map_en2cn.get(lab.lower(), lab)
|
135 |
+
result[cn] = float(sc)
|
136 |
return result
|
137 |
except Exception as e:
|
138 |
+
print("zero-shot error:", e)
|
139 |
+
return {"中性": 1.0}
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
|
144 |
|
145 |
# --- 5. 建立 Gradio 介面 ---
|
|
|
163 |
audio.change(fn=predict_voice, inputs=audio, outputs=audio_output)
|
164 |
|
165 |
with gr.TabItem("文字情緒"):
|
166 |
+
gr.Markdown("### 文字情緒 分析 (规则+zero-shot)")
|
167 |
+
with gr.Row():
|
168 |
+
text = gr.Textbox(lines=3, placeholder="請輸入中文文字…")
|
169 |
+
text_out = gr.Label(label="文字情緒結果")
|
170 |
+
text.submit(fn=predict_text_mixed, inputs=text, outputs=text_out)
|
171 |
|
172 |
|
173 |
|