import spaces
import gradio as gr
import os
import torch
from model import Wav2Vec2BERT_Llama  # custom model module
import dataset  # custom dataset module
from huggingface_hub import hf_hub_download

@spaces.GPU
def dummy(): # just a dummy
    pass
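
# Note: on Hugging Face Spaces with ZeroGPU hardware, only functions decorated
# with @spaces.GPU are granted GPU access; the placeholder above appears to exist
# solely so that GPU usage is registered when the Space starts (an assumption
# based on common Spaces setups, not stated in this file).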

# Download the model checkpoint from the Hugging Face Hub
def load_model():
    checkpoint_path = hf_hub_download(
        repo_id="amphion/deepfake_detection", 
        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
        repo_type="model"
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return checkpoint_path
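
# hf_hub_download returns the path of the locally cached file (downloading it
# first if necessary), so the existence check above acts only as a safety net.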

checkpoint_path = load_model()

# Run the detection itself under the GPU decorator
@spaces.GPU
def detect_on_gpu(audio_dataset):
    """Run audio deepfake detection on the GPU."""
    print("\n=== Starting audio detection ===")
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    print("正在初始化模型...")
    model = Wav2Vec2BERT_Llama().to(device)
    
    print(f"正在加载模型权重: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']
    threshold = 0.8
    print(f"检测阈值设置为: {threshold}")

    # Normalize 'module.' prefixes in the state dict keys (handles DataParallel checkpoints)
    if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
        print("添加 'module.' 前缀到状态字典的 key")
        model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
    elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
        print("移除状态字典 key 中的 'module.' 前缀")
        model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}

    model.load_state_dict(model_state_dict)
    model.eval()
    print("模型加载完成,进入评估模式")

    print("\n开始处理音频数据...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(audio_dataset):
            print(f"\nProcessing batch {batch_idx + 1}")
            
            print("准备主特征...")
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }
            print(f"主特征形状: {main_features['input_features'].shape}")
            
            if len(batch['prompt_features']) > 0:
                print("\n准备提示特征...")
                prompt_features = [{
                    'input_features': pf['input_features'].to(device),
                    'attention_mask': pf['attention_mask'].to(device)
                } for pf in batch['prompt_features']]
                print(f"提示特征数量: {len(prompt_features)}")
                print(f"第一个提示特征形状: {prompt_features[0]['input_features'].shape}")

                print("\n准备提示标签...")
                prompt_labels = batch['prompt_labels'].to(device)
                print(f"提示标签形状: {prompt_labels.shape}")
                print(f"提示标签值: {prompt_labels}")
            else:
                prompt_features = []
                prompt_labels = []

            print("\n执行模型推理...")
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            print("\n处理模型输出...")
            avg_scores = outputs['avg_logits'].softmax(dim=-1)
            deepfake_scores = avg_scores[:, 1].cpu()
            is_fake = deepfake_scores[0].item() > threshold
            
            result = {"is_fake": is_fake, "confidence": deepfake_scores[0] if is_fake else 1-deepfake_scores[0]}

            break  # a single query audio yields only one batch, so stop here
    
    print("\n=== 检测完成 ===")
    return result
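
# For reference, detect_on_gpu assumes each batch yielded by dataset.DemoDataset
# looks roughly like the sketch below (inferred from the accesses above; the real
# DemoDataset may differ in details such as tensor shapes and batching):
#
#   batch = {
#       'main_features':   {'input_features': Tensor, 'attention_mask': Tensor},
#       'prompt_features': [{'input_features': Tensor, 'attention_mask': Tensor}, ...],
#       'prompt_labels':   Tensor,   # empty list when no demonstration audios are given
#   }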

# Main audio deepfake detection function (earlier version with demonstration audios, kept for reference)
# def audio_deepfake_detection(demonstrations, query_audio_path):
#     demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
#     demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
#     if len(demonstration_paths) != len(demonstration_labels):
#         demonstration_labels = demonstration_labels[:len(demonstration_paths)]
    
#     # Build the dataset
#     audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
    
#     # Run detection on the GPU
#     result = detect_on_gpu(audio_dataset)
    
#     return {
#         "Is AI Generated": result["is_fake"],
#         "Confidence": f"{100*result['confidence']:.2f}%"
#     }
# 0 demonstrations
def audio_deepfake_detection(query_audio_path):
    
    # Build the dataset (query audio only, no demonstrations)
    audio_dataset = dataset.DemoDataset([], [], query_audio_path)
    
    # Run detection on the GPU
    result = detect_on_gpu(audio_dataset)
    is_fake = "是/Yes" if result["is_fake"] else "否/No"
    confidence = f"{100*result['confidence']:.2f}%"
    
    return {
        "是否为AI生成/Is AI Generated": is_fake,
        "检测可信度/Confidence": confidence
    }
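
# Illustrative usage from a script ("test.wav" is a placeholder path, not a file
# bundled with this repo):
#
#   result = audio_deepfake_detection("test.wav")
#   print(result)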

# Gradio interface
def gradio_ui():
    # def detection_wrapper(demonstration_audio1, label1, demonstration_audio2, label2, demonstration_audio3, label3, query_audio):
    #     demonstrations = [
    #         (demonstration_audio1, label1),
    #         (demonstration_audio2, label2),
    #         (demonstration_audio3, label3),
    #     ]
    #     return audio_deepfake_detection(demonstrations,query_audio)

    # interface = gr.Interface(
    #     fn=detection_wrapper,
    #     inputs=[
    #         gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 1"),
    #         gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
    #         gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 2"),
    #         gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
    #         gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 3"),
    #         gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
    #         gr.Audio(sources=["upload"], type="filepath", label="Query Audio (Audio for Detection)")
    #     ],
    #     outputs=gr.JSON(label="Detection Results"),
    #     title="Audio Deepfake Detection System",
    #     description="Upload demonstration audios and a query audio to detect whether the query is AI-generated.",
    # )
    # return interface

    def detection_wrapper(query_audio):
        return audio_deepfake_detection(query_audio)

    interface = gr.Interface(
        fn=detection_wrapper,
        inputs=[
            gr.Audio(sources=["upload"], type="filepath", label="测试音频 / Test Audio")
        ],
        outputs=gr.JSON(label="检测结果 / Detection Results"),
        title="音频伪造检测系统 / Audio Deepfake Detection System",
        description="上传一个测试音频以检测该音频是否为AI生成。/ Upload a test audio to detect whether the audio is AI-generated.",
        article=(
            "由香港中文大学(深圳)武执政教授团队开发。"
            "Developed by a team led by Prof Zhizheng Wu from the Chinese University of Hong Kong, Shenzhen."
            "\n\n"
            "本系统用于检测音频是否为AI生成,适用于研究和教育目的。"
            "This system is designed to detect whether an audio is AI-generated, "
            "and is intended for research and educational purposes."
        )
    )
    return interface

if __name__ == "__main__":
    demo = gradio_ui()
    demo.launch()
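    # When running locally rather than on Spaces, a temporary public link can be
    # created via Gradio's built-in tunneling, e.g. demo.launch(share=True).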