JotunnBurton committed
Commit 8a53174 · verified · 1 Parent(s): 9fdff96

Update clap_wrapper.py

Files changed (1)
clap_wrapper.py +9 -34
clap_wrapper.py CHANGED
@@ -1,35 +1,14 @@
 import sys
-import os
-
 import torch
 from transformers import ClapModel, ClapProcessor
-from huggingface_hub import hf_hub_download
-from config import config
-
-# Define the model name and path
-HF_REPO_ID = "laion/clap-htsat-fused"
-LOCAL_PATH = "./emotional/clap-htsat-fused"
 
-# Check whether the model files already exist in LOCAL_PATH; if not, download them
-def ensure_model_downloaded():
-    os.makedirs(LOCAL_PATH, exist_ok=True)
-    required_files = ["pytorch_model.bin", "config.json", "preprocessor_config.json"]
-    for file in required_files:
-        local_file_path = os.path.join(LOCAL_PATH, file)
-        if not os.path.isfile(local_file_path):
-            print(f"Downloading {file} from {HF_REPO_ID}...")
-            hf_hub_download(
-                repo_id=HF_REPO_ID,
-                filename=file,
-                cache_dir=LOCAL_PATH,
-                force_download=False
-            )
-
-ensure_model_downloaded()
+from config import config
 
-# Load the processor
+# Use the model from the Hugging Face Hub
+REPO_NAME = "laion/clap-htsat-fused"
 models = dict()
-processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
+processor = ClapProcessor.from_pretrained(REPO_NAME)
+
 
 def get_clap_audio_feature(audio_data, device=config.bert_gen_config.device):
     if (
@@ -43,12 +22,10 @@ def get_clap_audio_feature(audio_data, device=config.bert_gen_config.device):
     if device not in models.keys():
         if config.webui_config.fp16_run:
             models[device] = ClapModel.from_pretrained(
-                LOCAL_PATH, torch_dtype=torch.float16, local_files_only=True
+                REPO_NAME, torch_dtype=torch.float16
             ).to(device)
         else:
-            models[device] = ClapModel.from_pretrained(
-                LOCAL_PATH, local_files_only=True
-            ).to(device)
+            models[device] = ClapModel.from_pretrained(REPO_NAME).to(device)
     with torch.no_grad():
         inputs = processor(
             audios=audio_data, return_tensors="pt", sampling_rate=48000
@@ -69,12 +46,10 @@ def get_clap_text_feature(text, device=config.bert_gen_config.device):
     if device not in models.keys():
         if config.webui_config.fp16_run:
             models[device] = ClapModel.from_pretrained(
-                LOCAL_PATH, torch_dtype=torch.float16, local_files_only=True
+                REPO_NAME, torch_dtype=torch.float16
             ).to(device)
         else:
-            models[device] = ClapModel.from_pretrained(
-                LOCAL_PATH, local_files_only=True
-            ).to(device)
+            models[device] = ClapModel.from_pretrained(REPO_NAME).to(device)
     with torch.no_grad():
         inputs = processor(text=text, return_tensors="pt").to(device)
         emb = models[device].get_text_features(**inputs).float()
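
With this commit the checkpoint is no longer mirrored into ./emotional/clap-htsat-fused; ClapProcessor.from_pretrained and ClapModel.from_pretrained resolve laion/clap-htsat-fused through the standard Hugging Face cache on first use. A minimal standalone sketch of the new loading path follows; the fp16_run flag and get_clap_model helper are stand-ins for the project-specific config.webui_config.fp16_run and the per-device caching done inside clap_wrapper.py, and the example text is arbitrary:

import torch
from transformers import ClapModel, ClapProcessor

REPO_NAME = "laion/clap-htsat-fused"

# Shared processor, one lazily created model per device, mirroring clap_wrapper.py.
processor = ClapProcessor.from_pretrained(REPO_NAME)
models = {}


def get_clap_model(device="cpu", fp16_run=False):
    # fp16_run is a hypothetical stand-in for config.webui_config.fp16_run.
    if device not in models:
        if fp16_run:
            models[device] = ClapModel.from_pretrained(
                REPO_NAME, torch_dtype=torch.float16
            ).to(device)
        else:
            models[device] = ClapModel.from_pretrained(REPO_NAME).to(device)
    return models[device]


model = get_clap_model("cpu")
with torch.no_grad():
    inputs = processor(text="a calm, soft voice", return_tensors="pt")
    emb = model.get_text_features(**inputs).float()
print(emb.shape)  # projection-sized embedding, e.g. torch.Size([1, 512]) for this checkpoint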
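
The removed ensure_model_downloaded() helper fetched pytorch_model.bin, config.json and preprocessor_config.json into ./emotional/clap-htsat-fused by hand. If a pre-provisioned or offline setup is still wanted after this change, one hedged alternative (not part of this commit) is to prefetch the whole repo into the regular Hugging Face cache with huggingface_hub.snapshot_download and then load with local_files_only=True:

from huggingface_hub import snapshot_download
from transformers import ClapModel, ClapProcessor

REPO_NAME = "laion/clap-htsat-fused"

# One-time prefetch on a machine with network access; files land in the
# standard HF cache, so no custom LOCAL_PATH bookkeeping is required.
snapshot_download(repo_id=REPO_NAME)

# Later, possibly offline: resolve strictly from the local cache.
processor = ClapProcessor.from_pretrained(REPO_NAME, local_files_only=True)
model = ClapModel.from_pretrained(REPO_NAME, local_files_only=True)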