import gradio as gr
print("Gradio version:", gr.__version__)
import os, time, re
import numpy as np
import joblib
import librosa
from huggingface_hub import hf_hub_download
from deepface import DeepFace
from transformers import pipeline
# AutoTokenizer / AutoModelForSequenceClassification are only needed when loading the model manually, so they are not imported here

# --- 1. Load the SVM speech-emotion model ---
print("Downloading SVM model from Hugging Face Hub...")
model_path = hf_hub_download(repo_id="GCLing/emotion-svm-model", filename="svm_emotion_model.joblib")
print(f"SVM model downloaded to: {model_path}")
svm_model = joblib.load(model_path)
print("SVM model loaded.")

# --- 2. Text emotion analysis: rules + zero-shot ---
try:
    zero_shot = pipeline("zero-shot-classification", model="joeddav/xlm-roberta-large-xnli")
except Exception as e:
    print("加载 zero-shot pipeline 失败:", e)
    zero_shot = None

candidate_labels = ["joy", "sadness", "anger", "fear", "surprise", "disgust"]
label_map_en2cn = {
    "joy": "高興", "sadness": "悲傷", "anger": "憤怒",
    "fear": "恐懼", "surprise": "驚訝", "disgust": "厭惡"
}

# Keyword lists: keep Traditional/Simplified variants consistent, or include both forms
emo_keywords = {
    "happy": ["開心","快樂","愉快","喜悦","喜悅","歡喜","興奮","高興"],
    "angry": ["生氣","憤怒","不爽","發火","火大","氣憤"],
    "sad": ["傷心","難過","哭","難受","心酸","憂","悲","哀","痛苦","慘","愁"],
    "surprise": ["驚訝","意外","嚇","驚詫","詫異","訝異","好奇"],
    "fear": ["怕","恐懼","緊張","懼","膽怯","畏"],
    "disgust": ["噁心","厭惡","反感"]  # 如需“厭惡”等
}
# Negation words
negations = ["不","沒","沒有","別","勿","非"]

def keyword_emotion(text: str):
    """
    Rule-based method: scan emo_keywords and check for a preceding negation word.
    Returns None (or {}) when no keyword matches; otherwise a non-empty dict,
    e.g. {'angry': 2, 'sad': 1}, or normalized, {'angry': 0.67, 'sad': 0.33}.
    """
    if not text or text.strip() == "":
        return None
    text_proc = text.strip()  # no lowercasing needed for Chinese
    counts = {emo: 0 for emo in emo_keywords}
    for emo, kws in emo_keywords.items():
        for w in kws:
            idx = text_proc.find(w)
            if idx != -1:
                # check whether the characters immediately before the match form a negation word
                neg = False
                for neg_word in negations:
                    plen = len(neg_word)
                    if idx - plen >= 0 and text_proc[idx-plen:idx] == neg_word:
                        neg = True
                        break
                if not neg:
                    counts[emo] += 1
                else:
                    # if negated, one could subtract instead; here the match is simply ignored
                    pass
    total = sum(counts.values())
    if total > 0:
        # normalize
        return {emo: counts[emo] / total for emo in counts if counts[emo] > 0}
    else:
        return None
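
# Illustrative examples (assumed inputs, not part of the app):
#   keyword_emotion("我今天很生氣") -> {"angry": 1.0}   ("生氣" matched, no negation)
#   keyword_emotion("我不生氣")     -> None              ("生氣" preceded by "不", so ignored)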

def predict_text_mixed(text: str):
    """
    Text emotion analysis: try the rule-based method first; if it matches, return
    the top emotion and its share, otherwise fall back to zero-shot and return the
    full class distribution.
    Returns dict[str, float] for display in a Gradio Label.
    """
    print("predict_text_mixed called, text:", repr(text))
    if not text or text.strip() == "":
        print("輸入為空,返回空")
        return {}
    # rules first
    res = keyword_emotion(text)
    print("keyword_emotion result:", res)
    if res:
        # return only the top emotion; the full distribution res could also be returned
        top_emo = max(res, key=res.get)  # e.g. "angry"
        mapping = {
            "happy": "高興",
            "angry": "憤怒",
            "sad": "悲傷",
            "surprise": "驚訝",
            "fear": "恐懼",
            "disgust": "厭惡"
        }
        cn = mapping.get(top_emo, top_emo)
        prob = float(res[top_emo])
        print(f"使用規則方法,返回: {{'{cn}': {prob}}}")
        return {cn: prob}
    # no rule hit: fall back to zero-shot
    if zero_shot is None:
        print("zero_shot pipeline 未加载,返回中性")
        return {"中性": 1.0}
    try:
        out = zero_shot(text, candidate_labels=candidate_labels,
                        hypothesis_template="这句話表達了{}情緒")
        print("zero-shot 返回:", out)
        result = {}
        for lab, sc in zip(out["labels"], out["scores"]):
            cn = label_map_en2cn.get(lab.lower(), lab)
            result[cn] = float(sc)
        print("zero-shot 结果映射中文:", result)
        return result
    except Exception as e:
        print("zero-shot error:", e)
        return {"中性": 1.0}

# --- 3. Speech emotion prediction ---
def extract_feature(signal: np.ndarray, sr: int) -> np.ndarray:
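    # 13 MFCC means + 13 MFCC variances -> a 26-dim feature vector; assumed to
    # match the feature layout the SVM in svm_emotion_model.joblib was trained on.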
    mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13)
    return np.concatenate([np.mean(mfcc, axis=1), np.var(mfcc, axis=1)])

def predict_voice(audio_path: str):
    if not audio_path:
        print("predict_voice: 无 audio_path,跳过")
        return {}
    try:
        signal, sr = librosa.load(audio_path, sr=None)
        feat = extract_feature(signal, sr)
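        # predict_proba column order follows svm_model.classes_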
        probs = svm_model.predict_proba([feat])[0]
        labels = svm_model.classes_
        return {labels[i]: float(probs[i]) for i in range(len(labels))}
    except Exception as e:
        print("predict_voice error:", e)
        return {}

# --- 4. Facial emotion prediction ---
def predict_face(img: np.ndarray):
    print("predict_face called, img is None?", img is None)
    if img is None:
        return {}
    try:
        res = DeepFace.analyze(img, actions=["emotion"], detector_backend="opencv")
        if isinstance(res, list):
            first = res[0] if res else {}
            emo = first.get("emotion", {}) if isinstance(first, dict) else {}
        else:
            emo = res.get("emotion", {}) if isinstance(res, dict) else {}
        # cast to float so the result is JSON-serializable
        emo_fixed = {k: float(v) for k, v in emo.items()}
        print("predict_face result:", emo_fixed)
        return emo_fixed
    except Exception as e:
        print("DeepFace.analyze error:", e)
        return {}

# --- 5. Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown("## 多模態情緒分析示例")
    with gr.Tabs():
        # Facial emotion tab
        with gr.TabItem("臉部情緒"):
            gr.Markdown("### 臉部情緒 (即時 Webcam Streaming 分析)")
            with gr.Row():
                webcam = gr.Image(sources=["webcam"], streaming=True, type="numpy", label="攝像頭畫面")
                face_out = gr.JSON(label="情緒原始結果")
            webcam.stream(fn=predict_face, inputs=webcam, outputs=face_out)
        # Speech emotion tab
        with gr.TabItem("語音情緒"):
            gr.Markdown("### 語音情緒 分析")
            with gr.Row():
                audio = gr.Audio(sources=["microphone"], streaming=False, type="filepath", label="錄音")
                voice_out = gr.Label(label="語音情緒結果")
            audio.change(fn=predict_voice, inputs=audio, outputs=voice_out)
        # Text emotion tab
        with gr.TabItem("文字情緒"):
                gr.Markdown("### 文字情緒 分析 (規則+zero-shot)")
                with gr.Row():
                    text = gr.Textbox(lines=3, placeholder="請輸入中文文字…")
                    text_out = gr.Label(label="文字情緒結果")
                # 可以用回車提交
                btn = gr.Button("分析")
                btn.click(fn=predict_text_mixed, inputs=text, outputs=text_out)
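            # also trigger analysis on textbox submit (gr.Textbox.submit), so Enter works too
            text.submit(fn=predict_text_mixed, inputs=text, outputs=text_out)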
    

if __name__ == "__main__":
    demo.launch()