Commit 45381f2 by admin · 1 Parent(s): 8389616
Files changed (3):
  1. app.py           +78 -80
  2. model.py         +39 -3
  3. requirements.txt +5 -3
app.py CHANGED
@@ -11,22 +11,7 @@ import torchvision.transforms as transforms
 from collections import Counter
 from PIL import Image
 from tqdm import tqdm
-from model import net, MODEL_DIR
-
-
-MODEL = net()
-TRANS = {
-    "PearlRiver": "Pearl River",
-    "YoungChang": "YOUNG CHANG",
-    "Steinway-T": "STEINWAY Theater",
-    "Hsinghai": "HSINGHAI",
-    "Kawai": "KAWAI",
-    "Steinway": "STEINWAY",
-    "Kawai-G": "KAWAI Grand",
-    "Yamaha": "YAMAHA",
-}
-CLASSES = list(TRANS.keys())
-CACHE_DIR = "./__pycache__/tmp"
+from model import net, _L, MODEL_DIR, TMP_DIR


 def most_common_element(input_list):
@@ -36,30 +21,26 @@ def most_common_element(input_list):


 def wav_to_mel(audio_path: str, width=0.18):
-    os.makedirs(CACHE_DIR, exist_ok=True)
-    try:
-        y, sr = librosa.load(audio_path, sr=48000)
-        non_silent = y
-        mel_spec = librosa.feature.melspectrogram(y=non_silent, sr=sr)
-        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
-        dur = librosa.get_duration(y=non_silent, sr=sr)
-        total_frames = log_mel_spec.shape[1]
-        step = int(width * total_frames / dur)
-        count = int(total_frames / step)
-        begin = int(0.5 * (total_frames - count * step))
-        end = begin + step * count
-        for i in tqdm(range(begin, end, step), desc="Converting wav to jpgs..."):
-            librosa.display.specshow(log_mel_spec[:, i : i + step])
-            plt.axis("off")
-            plt.savefig(
-                f"{CACHE_DIR}/{os.path.basename(audio_path)[:-4]}_{i}.jpg",
-                bbox_inches="tight",
-                pad_inches=0.0,
-            )
-            plt.close()
-
-    except Exception as e:
-        print(f"Error converting {audio_path} : {e}")
+    os.makedirs(TMP_DIR, exist_ok=True)
+    y, sr = librosa.load(audio_path, sr=48000)
+    non_silent = y
+    mel_spec = librosa.feature.melspectrogram(y=non_silent, sr=sr)
+    log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+    dur = librosa.get_duration(y=non_silent, sr=sr)
+    total_frames = log_mel_spec.shape[1]
+    step = int(width * total_frames / dur)
+    count = int(total_frames / step)
+    begin = int(0.5 * (total_frames - count * step))
+    end = begin + step * count
+    for i in tqdm(range(begin, end, step), desc="Converting wav to jpgs..."):
+        librosa.display.specshow(log_mel_spec[:, i : i + step])
+        plt.axis("off")
+        plt.savefig(
+            f"{TMP_DIR}/{os.path.basename(audio_path)[:-4]}_{i}.jpg",
+            bbox_inches="tight",
+            pad_inches=0.0,
+        )
+        plt.close()


 def embed_img(img_path, input_size=224):
@@ -74,65 +55,82 @@ def embed_img(img_path, input_size=224):
     return transform(img).unsqueeze(0)


-def inference(wav_path, folder_path=CACHE_DIR):
-    if os.path.exists(folder_path):
-        shutil.rmtree(folder_path)
-
-    if not wav_path:
-        return None, "Please input an audio!"
-
-    wav_to_mel(wav_path)
-    outputs = []
-    all_files = os.listdir(folder_path)
-    for file_name in all_files:
-        if file_name.lower().endswith(".jpg"):
-            file_path = os.path.join(folder_path, file_name)
-            input = embed_img(file_path)
-            output: torch.Tensor = MODEL(input)
-            pred_id = torch.max(output.data, 1)[1]
-            outputs.append(pred_id)
-
-    max_count_item = most_common_element(outputs)
-    shutil.rmtree(folder_path)
-    return os.path.basename(wav_path), TRANS[CLASSES[max_count_item]]
+def infer(wav_path, folder_path=TMP_DIR):
+    status = "Success"
+    filename = result = None
+    try:
+        if os.path.exists(folder_path):
+            shutil.rmtree(folder_path)
+
+        if not wav_path:
+            raise ValueError("Please input an audio!")
+
+        wav_to_mel(wav_path)
+        outputs = []
+        all_files = os.listdir(folder_path)
+        for file_name in all_files:
+            if file_name.lower().endswith(".jpg"):
+                file_path = os.path.join(folder_path, file_name)
+                input = embed_img(file_path)
+                output: torch.Tensor = net()(input)
+                pred_id = torch.max(output.data, 1)[1]
+                outputs.append(pred_id)
+
+        max_count_item = most_common_element(outputs)
+        filename = os.path.basename(wav_path)
+        result = translate[classes[max_count_item]]
+
+    except Exception as e:
+        status = f"{e}"
+
+    return status, filename, result


 if __name__ == "__main__":
     warnings.filterwarnings("ignore")
+    translate = {
+        "PearlRiver": _L("珠江"),
+        "YoungChang": _L("英昌"),
+        "Steinway-T": _L("施坦威剧场"),
+        "Hsinghai": _L("星海"),
+        "Kawai": _L("卡瓦依"),
+        "Steinway": _L("施坦威"),
+        "Kawai-G": _L("卡瓦依三角"),
+        "Yamaha": _L("雅马哈"),
+    }
+    classes = list(translate.keys())
     example_wavs = []
-    for cls in CLASSES:
+    for cls in classes:
         example_wavs.append(f"{MODEL_DIR}/examples/{cls}.wav")

     with gr.Blocks() as demo:
         gr.Interface(
-            fn=inference,
-            inputs=gr.Audio(type="filepath", label="Upload a piano recording"),
+            fn=infer,
+            inputs=gr.Audio(type="filepath", label=_L("上传钢琴录音")),
             outputs=[
-                gr.Textbox(label="Audio filename", show_copy_button=True),
-                gr.Textbox(
-                    label="Piano classification result",
-                    show_copy_button=True,
-                ),
+                gr.Textbox(label=_L("状态栏"), show_copy_button=True),
+                gr.Textbox(label=_L("音频文件名"), show_copy_button=True),
+                gr.Textbox(label=_L("钢琴分类结果"), show_copy_button=True),
             ],
             examples=example_wavs,
             cache_examples=False,
             allow_flagging="never",
-            title="It is recommended to keep the duration of recording around 3s, too long will affect the recognition efficiency.",
+            title=_L("建议录音时长保持在 3s 左右, 过长会影响识别效率"),
         )

         gr.Markdown(
-            """
-# Cite
-```bibtex
-@inproceedings{zhou2023holistic,
-    title = {A Holistic Evaluation of Piano Sound Quality},
-    author = {Monan Zhou and Shangda Wu and Shaohua Ji and Zijin Li and Wei Li},
-    booktitle = {National Conference on Sound and Music Technology},
-    pages = {3--17},
-    year = {2023},
-    organization = {Springer}
-}
-```"""
+            f"# {_L('引用')}"
+            + """
+```bibtex
+@inproceedings{zhou2023holistic,
+    title = {A Holistic Evaluation of Piano Sound Quality},
+    author = {Monan Zhou and Shangda Wu and Shaohua Ji and Zijin Li and Wei Li},
+    booktitle = {National Conference on Sound and Music Technology},
+    pages = {3--17},
+    year = {2023},
+    organization = {Springer}
+}
+```"""
         )

     demo.launch()
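The windowing arithmetic in `wav_to_mel` is worth unpacking: `step` converts the `width` parameter from seconds into mel frames, `count` is how many whole windows fit, and `begin`/`end` center that run of windows in the spectrogram so the clipped frames are split evenly between both edges. A worked sketch with illustrative numbers (a 3 s clip at the app's `sr=48000` and librosa's default hop length; the concrete frame count is an assumption, not taken from the commit):

```python
# Illustrative values: a 3 s clip at sr=48000, librosa default hop_length=512
dur = 3.0
total_frames = 281  # ~ dur * 48000 / 512; librosa's exact count differs slightly
width = 0.18        # window length in seconds

step = int(width * total_frames / dur)            # 16 frames ~ 0.18 s of audio
count = int(total_frames / step)                  # 17 whole windows fit
begin = int(0.5 * (total_frames - count * step))  # 4: center the leftover frames
end = begin + step * count                        # 276

# range(begin, end, step) -> slices [4:20], [20:36], ..., [260:276],
# i.e. 17 spectrogram crops saved as jpgs and classified one by one
print(step, count, begin, end)  # 16 17 4 276
```

Each crop then passes through `embed_img` and the classifier, and `most_common_element` majority-votes the per-crop predictions into the final result.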
model.py CHANGED
@@ -1,9 +1,45 @@
+import os
 import torch
 import torch.nn as nn
-from huggingface_hub import snapshot_download
+import huggingface_hub
+import modelscope
 from torchvision.models import squeezenet1_1

-MODEL_DIR = snapshot_download("ccmusic-database/pianos", cache_dir="./__pycache__")
+TMP_DIR = "./__pycache__/tmp"
+EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
+
+ZH2EN = {
+    "上传钢琴录音": "Upload a piano recording",
+    "状态栏": "Status",
+    "音频文件名": "Audio filename",
+    "钢琴分类结果": "Piano classification result",
+    "建议录音时长保持在 3s 左右, 过长会影响识别效率": "It is recommended to keep the recording around 3s; audio that is too long will reduce recognition efficiency.",
+    "引用": "Cite",
+    "珠江": "Pearl River",
+    "英昌": "YOUNG CHANG",
+    "施坦威剧场": "STEINWAY Theater",
+    "星海": "HSINGHAI",
+    "卡瓦依": "KAWAI",
+    "施坦威": "STEINWAY",
+    "卡瓦依三角": "KAWAI Grand",
+    "雅马哈": "YAMAHA",
+}
+
+MODEL_DIR = (
+    huggingface_hub.snapshot_download(
+        "ccmusic-database/pianos",
+        cache_dir="./__pycache__",
+    )
+    if EN_US
+    else modelscope.snapshot_download(
+        "ccmusic-database/pianos",
+        cache_dir="./__pycache__",
+    )
+)
+
+
+def _L(zh_txt: str):
+    return ZH2EN[zh_txt] if EN_US else zh_txt


 def Classifier(cls_num=8, output_size=512, linear_output=False):
@@ -41,7 +77,7 @@ def Classifier(cls_num=8, output_size=512, linear_output=False):
     )


-def net(weights=f"{MODEL_DIR}/save.pt"):
+def net(weights=MODEL_DIR + "/save.pt"):
     model = squeezenet1_1(pretrained=False)
     model.classifier = Classifier()
     model.load_state_dict(torch.load(weights, map_location=torch.device("cpu")))
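The localization scheme keys every user-facing string by its Chinese text: `_L` returns the key unchanged when `LANG` is `zh_CN.UTF-8` and otherwise maps it through `ZH2EN`, and the same `EN_US` flag decides whether the `ccmusic-database/pianos` snapshot is fetched from the Hugging Face Hub or from ModelScope. A minimal usage sketch (the `os.environ` line is illustrative; `EN_US` is evaluated once at import time, so `LANG` must be set before `model` is imported):

```python
import os

# Illustrative: force the Chinese UI (and the ModelScope download path).
# EN_US is fixed when model.py is imported, so set LANG first.
os.environ["LANG"] = "zh_CN.UTF-8"

from model import _L, MODEL_DIR

print(_L("状态栏"))  # "状态栏" here; "Status" under any other LANG value
print(MODEL_DIR)     # local snapshot directory under ./__pycache__
```

Because the Chinese strings double as lookup keys, they must stay verbatim at every call site; adding a new label means adding a matching `ZH2EN` entry, or `_L` raises a `KeyError` in the English locale.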
requirements.txt CHANGED
@@ -1,5 +1,7 @@
-torch
-pillow
+torch==2.6.0+cu118
+-f https://download.pytorch.org/whl/torch
+torchvision==0.21.0+cu118
+-f https://download.pytorch.org/whl/torchvision
 librosa
 matplotlib
-torchvision
+modelscope[framework]==1.21.0
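pip honors `-f`/`--find-links` options embedded in a requirements file, which is what lets the `+cu118` local-version pins resolve against the PyTorch wheel index during a plain `pip install -r requirements.txt`. A quick post-install sanity check (a sketch, not part of the commit):

```python
import torch
import torchvision

# The CUDA build tag should survive into the installed versions
print(torch.__version__)          # expected: 2.6.0+cu118
print(torchvision.__version__)    # expected: 0.21.0+cu118
print(torch.cuda.is_available())  # False on CPU-only hosts; the app loads on CPU anyway
```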