wli3221134 commited on
Commit
340d6c0
·
verified ·
1 Parent(s): 5db4ee9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -43,7 +43,7 @@ def load_model():
43
  model, threshold = load_model()
44
 
45
  # 检测函数
46
- def detect(dataset, model):
47
  """进行音频伪造检测"""
48
  with torch.no_grad():
49
  for batch in dataset:
@@ -56,9 +56,19 @@ def detect(dataset, model):
56
  'attention_mask': pf['attention_mask'].to(device)
57
  } for pf in batch['prompt_features']]
58
 
 
59
  # 模型的前向传播逻辑 (需要补充具体实现)
 
 
 
 
 
 
 
 
 
60
  # 假设 result 是模型返回的结果
61
- result = {"is_fake": True, "confidence": 85.5} # 示例返回值
62
  return result
63
 
64
  # 音频伪造检测主函数
@@ -70,11 +80,12 @@ def audio_deepfake_detection(demonstrations, query_audio_path):
70
  :return: 检测结果
71
  """
72
  demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
 
73
  print(f"Demonstration audio paths: {demonstration_paths}")
74
  print(f"Query audio path: {query_audio_path}")
75
 
76
  # 数据集处理
77
- audio_dataset = dataset.DemoDataset(demonstration_paths, query_audio_path)
78
 
79
  # 调用检测函数
80
  result = detect(audio_dataset, model)
@@ -93,7 +104,7 @@ def gradio_ui():
93
  (demonstration_audio2, label2),
94
  (demonstration_audio3, label3),
95
  ]
96
- return audio_deepfake_detection(demonstrations, query_audio)
97
 
98
  interface = gr.Interface(
99
  fn=detection_wrapper,
 
43
  model, threshold = load_model()
44
 
45
  # 检测函数
46
+ def detect(dataset):
47
  """进行音频伪造检测"""
48
  with torch.no_grad():
49
  for batch in dataset:
 
56
  'attention_mask': pf['attention_mask'].to(device)
57
  } for pf in batch['prompt_features']]
58
 
59
+ prompt_labels = batch['prompt_labels'].to(device)
60
  # 模型的前向传播逻辑 (需要补充具体实现)
61
+ outputs = model({
62
+ 'main_features': main_features,
63
+ 'prompt_features': prompt_features,
64
+ 'prompt_labels': prompt_labels
65
+ })
66
+
67
+ avg_scores = outputs['avg_logits'].softmax(dim=-1) # [batch_size, 2]
68
+ deepfake_scores = avg_scores[:, 1].cpu() # [batch_size]
69
+ is_fake = True if deepfake_scores[0] > threshold else False
70
  # 假设 result 是模型返回的结果
71
+ result = {"is_fake": is_fake, "confidence": deepfake_scores[0]} # 示例返回值
72
  return result
73
 
74
  # 音频伪造检测主函数
 
80
  :return: 检测结果
81
  """
82
  demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
83
+ demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
84
  print(f"Demonstration audio paths: {demonstration_paths}")
85
  print(f"Query audio path: {query_audio_path}")
86
 
87
  # 数据集处理
88
+ audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
89
 
90
  # 调用检测函数
91
  result = detect(audio_dataset, model)
 
104
  (demonstration_audio2, label2),
105
  (demonstration_audio3, label3),
106
  ]
107
+ return audio_deepfake_detection(demonstrations,query_audio)
108
 
109
  interface = gr.Interface(
110
  fn=detection_wrapper,