JotunnBurton committed on
Commit
53d08de
·
verified ·
1 Parent(s): bc45e1c

Upload clap_wrapper.py

Browse files
Files changed (1) hide show
  1. clap_wrapper.py +56 -0
clap_wrapper.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import ClapModel, ClapProcessor
5
+
6
+ from config import config
7
+
8
# Cache of loaded CLAP models, keyed by the device each one was moved to.
models = {}
# Local directory the pretrained CLAP checkpoint is loaded from.
LOCAL_PATH = "./emotional/clap-htsat-fused"
# Shared processor, loaded once at import time and reused by every call.
processor = ClapProcessor.from_pretrained(LOCAL_PATH)
11
+
12
+
13
def get_clap_audio_feature(audio_data, device=config.bert_gen_config.device):
    """Return the transposed CLAP audio embedding for ``audio_data``.

    Args:
        audio_data: Audio passed to the processor as ``audios=``; it is
            processed with ``sampling_rate=48000``.
        device: Device the model and inputs are placed on. ``"cpu"`` is
            replaced by ``"mps"`` on macOS when MPS is available, and a falsy
            value falls back to ``"cuda"``.

    Returns:
        The float32 output of ``get_audio_features``, transposed with ``.T``
        (presumably ``(dim, batch)`` -- TODO confirm against callers).
    """
    if (
        sys.platform == "darwin"
        and torch.backends.mps.is_available()
        and device == "cpu"
    ):
        device = "mps"
    if not device:
        device = "cuda"

    # Load the model lazily, once per device, and keep it in the module cache.
    if device not in models:
        if config.webui_config.fp16_run:
            models[device] = ClapModel.from_pretrained(
                LOCAL_PATH, torch_dtype=torch.float16
            ).to(device)
        else:
            models[device] = ClapModel.from_pretrained(LOCAL_PATH).to(device)
    model = models[device]

    with torch.no_grad():
        inputs = processor(
            audios=audio_data, return_tensors="pt", sampling_rate=48000
        ).to(device)
        # The processor emits float32 features, while the model may have been
        # loaded in float16 (fp16_run). Cast only the floating-point tensors
        # to the model's dtype so the forward pass does not hit a dtype
        # mismatch; non-float tensors are passed through untouched. With a
        # float32 model this is a no-op.
        model_dtype = model.dtype
        inputs = {
            name: (
                value.to(model_dtype)
                if torch.is_tensor(value) and value.is_floating_point()
                else value
            )
            for name, value in inputs.items()
        }
        emb = model.get_audio_features(**inputs).float()
    return emb.T
35
+
36
+
37
def get_clap_text_feature(text, device=config.bert_gen_config.device):
    """Return the transposed CLAP text embedding for ``text``.

    Args:
        text: Text passed to the processor as ``text=``.
        device: Device the model and inputs are placed on. ``"cpu"`` is
            replaced by ``"mps"`` on macOS when MPS is available, and a falsy
            value falls back to ``"cuda"``.

    Returns:
        The float32 output of ``get_text_features``, transposed with ``.T``
        (presumably ``(dim, batch)`` -- TODO confirm against callers).
    """
    # Prefer the Apple GPU over the CPU when running on macOS with MPS.
    mps_usable = sys.platform == "darwin" and torch.backends.mps.is_available()
    if mps_usable and device == "cpu":
        device = "mps"
    device = device or "cuda"

    # Load the model lazily, once per device, and keep it in the module cache.
    if device not in models:
        load_kwargs = {}
        if config.webui_config.fp16_run:
            load_kwargs["torch_dtype"] = torch.float16
        loaded = ClapModel.from_pretrained(LOCAL_PATH, **load_kwargs)
        models[device] = loaded.to(device)
    clap = models[device]

    with torch.no_grad():
        encoded = processor(text=text, return_tensors="pt").to(device)
        features = clap.get_text_features(**encoded)
        emb = features.float()
    return emb.T