Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -43,7 +43,7 @@ def load_model():
|
|
43 |
model, threshold = load_model()
|
44 |
|
45 |
# 检测函数
|
46 |
-
def detect(dataset
|
47 |
"""进行音频伪造检测"""
|
48 |
with torch.no_grad():
|
49 |
for batch in dataset:
|
@@ -56,9 +56,19 @@ def detect(dataset, model):
|
|
56 |
'attention_mask': pf['attention_mask'].to(device)
|
57 |
} for pf in batch['prompt_features']]
|
58 |
|
|
|
59 |
# 模型的前向传播逻辑 (需要补充具体实现)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
# 假设 result 是模型返回的结果
|
61 |
-
result = {"is_fake":
|
62 |
return result
|
63 |
|
64 |
# 音频伪造检测主函数
|
@@ -70,11 +80,12 @@ def audio_deepfake_detection(demonstrations, query_audio_path):
|
|
70 |
:return: 检测结果
|
71 |
"""
|
72 |
demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
|
|
|
73 |
print(f"Demonstration audio paths: {demonstration_paths}")
|
74 |
print(f"Query audio path: {query_audio_path}")
|
75 |
|
76 |
# 数据集处理
|
77 |
-
audio_dataset = dataset.DemoDataset(demonstration_paths, query_audio_path)
|
78 |
|
79 |
# 调用检测函数
|
80 |
result = detect(audio_dataset, model)
|
@@ -93,7 +104,7 @@ def gradio_ui():
|
|
93 |
(demonstration_audio2, label2),
|
94 |
(demonstration_audio3, label3),
|
95 |
]
|
96 |
-
return audio_deepfake_detection(demonstrations,
|
97 |
|
98 |
interface = gr.Interface(
|
99 |
fn=detection_wrapper,
|
|
|
43 |
model, threshold = load_model()
|
44 |
|
45 |
# 检测函数
|
46 |
+
def detect(dataset):
|
47 |
"""进行音频伪造检测"""
|
48 |
with torch.no_grad():
|
49 |
for batch in dataset:
|
|
|
56 |
'attention_mask': pf['attention_mask'].to(device)
|
57 |
} for pf in batch['prompt_features']]
|
58 |
|
59 |
+
prompt_labels = batch['prompt_labels'].to(device)
|
60 |
# 模型的前向传播逻辑 (需要补充具体实现)
|
61 |
+
outputs = model({
|
62 |
+
'main_features': main_features,
|
63 |
+
'prompt_features': prompt_features,
|
64 |
+
'prompt_labels': prompt_labels
|
65 |
+
})
|
66 |
+
|
67 |
+
avg_scores = outputs['avg_logits'].softmax(dim=-1) # [batch_size, 2]
|
68 |
+
deepfake_scores = avg_scores[:, 1].cpu() # [batch_size]
|
69 |
+
is_fake = True if deepfake_scores[0] > threshold else False
|
70 |
# 假设 result 是模型返回的结果
|
71 |
+
result = {"is_fake": is_fake, "confidence": deepfake_scores[0]} # 示例返回值
|
72 |
return result
|
73 |
|
74 |
# 音频伪造检测主函数
|
|
|
80 |
:return: 检测结果
|
81 |
"""
|
82 |
demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
|
83 |
+
demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
|
84 |
print(f"Demonstration audio paths: {demonstration_paths}")
|
85 |
print(f"Query audio path: {query_audio_path}")
|
86 |
|
87 |
# 数据集处理
|
88 |
+
audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
|
89 |
|
90 |
# 调用检测函数
|
91 |
result = detect(audio_dataset, model)
|
|
|
104 |
(demonstration_audio2, label2),
|
105 |
(demonstration_audio3, label3),
|
106 |
]
|
107 |
+
return audio_deepfake_detection(demonstrations,query_audio)
|
108 |
|
109 |
interface = gr.Interface(
|
110 |
fn=detection_wrapper,
|