wli3221134 commited on
Commit
53d1d01
·
verified ·
1 Parent(s): 719b808

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +9 -1
model.py CHANGED
@@ -5,14 +5,22 @@ from llama_nar import LlamaNAREmb
5
  from transformers import LlamaConfig
6
  import time
7
  import torch.nn.functional as F
 
8
 
9
 
10
  class Wav2Vec2BERT_Llama(nn.Module):
11
  def __init__(self):
12
  super().__init__()
13
 
 
 
 
 
14
  # 1. 加载预训练模型
15
- self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained("/mntcephfs/lab_data/wangli/pretrain/w2v-bert-2.0/", output_hidden_states=True)
 
 
 
16
 
17
  # 2. 选择性冻结参数
18
  for name, param in self.wav2vec2bert.named_parameters():
 
5
  from transformers import LlamaConfig
6
  import time
7
  import torch.nn.functional as F
8
+ from huggingface_hub import hf_hub_download
9
 
10
 
11
  class Wav2Vec2BERT_Llama(nn.Module):
12
  def __init__(self):
13
  super().__init__()
14
 
15
+ ckpt_path = hf_hub_download(
16
+ repo_id="amphion/deepfake_detection",
17
+ filename="w2vbert2"
18
+ )
19
  # 1. 加载预训练模型
20
+ self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained(
21
+ ckpt_path,
22
+ output_hidden_states=True
23
+ )
24
 
25
  # 2. 选择性冻结参数
26
  for name, param in self.wav2vec2bert.named_parameters():