import os
import re
import json
import shutil
from urllib.parse import urlparse

import torch
import requests
import modelscope
import huggingface_hub
import gradio as gr
from tqdm import tqdm
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate

from convert import midi2xml, xml2abc, xml2mxl, xml2jpg

# The UI is English unless the process locale is exactly zh_CN.UTF-8.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Chinese UI label -> English translation, looked up by _L().
ZH2EN = {
    "上传模式": "Uploading Mode",
    "上传音频": "Upload an audio",
    "下载 MIDI": "Download MIDI",
    "下载 PDF 乐谱": "Download PDF score",
    "下载 MusicXML": "Download MusicXML",
    "下载 MXL": "Download MXL",
    "ABC 记谱": "ABC notation",
    "五线谱": "Staff",
    "状态栏": "Status",
    "请上传音频 100% 后再点提交": "Please make sure the audio is completely uploaded before clicking Submit",
    "直链模式": "Direct Link Mode",
    "输入音频 URL 直链": "Input audio direct link",
    "下载音频": "Download audio",
    "网易云音乐可直接输入非 VIP 歌曲页面链接自动解析": "For Netease Cloud music, you can directly input the non-VIP song page link",
    "# 钢琴转谱工具": "# Piano Transcription Tool",
}

# Fetch the transcription checkpoint at import time: from Hugging Face for the
# English UI, from ModelScope for the Chinese UI.
WEIGHTS_PATH = (
    huggingface_hub.snapshot_download(
        "Genius-Society/piano_trans",
        cache_dir="./__pycache__",
    )
    if EN_US
    else modelscope.snapshot_download(
        "Genius-Society/piano_trans",
        cache_dir="./__pycache__",
    )
) + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"

# Seconds to wait for a network connect/read before giving up, so a stalled
# server cannot block a request handler indefinitely.
REQUEST_TIMEOUT = 30


def _L(zh_txt: str) -> str:
    """Return the UI label for the active locale.

    ``zh_txt`` is the Chinese label. It is translated through ZH2EN when EN_US
    is set (KeyError if no translation exists) and returned unchanged otherwise.
    """
    return ZH2EN[zh_txt] if EN_US else zh_txt


def clean_cache(cache_dir: str) -> None:
    """Delete ``cache_dir`` if it exists, then recreate it empty.

    ``makedirs`` also creates any missing parent directories.
    """
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)


def download_audio(url: str, save_path: str, timeout: float = REQUEST_TIMEOUT) -> None:
    """Stream ``url`` into ``save_path`` while showing a tqdm progress bar.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.Timeout: connecting or reading stalled for ``timeout`` seconds.
    """
    # Streaming request; the ``with`` releases the connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Total size in bytes; 0 when the server sends no Content-Length.
        total = int(response.headers.get("content-length", 0))
        with open(save_path, "wb") as file, tqdm(
            desc=save_path,  # progress-bar prefix
            total=total,
            unit="B",
            unit_scale=True,  # auto-scale to KB/MB/...
            unit_divisor=1024,  # 1024 bytes = 1 KB
        ) as pbar:
            # Write the body in 8 KiB chunks.
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    file.write(chunk)
                    pbar.update(len(chunk))


def is_url(s: str) -> bool:
    """Return True if ``s`` parses as a URL with both a scheme and a host."""
    try:
        result = urlparse(s)
    except (AttributeError, TypeError, ValueError):
        # ValueError: malformed URL (e.g. an invalid IPv6 host);
        # AttributeError/TypeError: non-string input.
        return False
    return bool(result.scheme and result.netloc)


def audio2midi(audio_path: str, cache_dir: str) -> tuple[str, str]:
    """Transcribe a piano recording into ``{cache_dir}/output.mid``.

    Returns:
        (midi_path, title), where title is the audio file's base name without
        its extension, capitalized.
    """
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )
    midi_path = f"{cache_dir}/output.mid"
    transcriptor.transcribe(audio, midi_path)
    # splitext keeps dots inside the name ("a.b.mp3" -> "a.b") and handles
    # extension-less names, where split(".")[-2] would raise IndexError.
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, stem.capitalize()


def extract_fst_int(input_string: str) -> str:
    """Return the first run of digits in ``input_string``, normalized via int().

    Leading zeros are dropped ("007" -> "7"); returns "" if there is no digit.
    """
    match = re.search(r"\d+", input_string)
    return str(int(match.group())) if match else ""


def music163_song_info(id: str, timeout: float = REQUEST_TIMEOUT) -> tuple[str, bool]:
    """Look up a NetEase Cloud Music song by its ID.

    The parameter name shadows the builtin ``id``; it is kept for backward
    compatibility with keyword callers.

    Returns:
        (song_name, free). ``free`` is True when the song's ``fee`` is 0 or 8,
        which this tool treats as non-VIP. When the API returns no song,
        the result is ("歌曲不存在", False).

    Raises:
        ConnectionError: the detail API answered with a non-200 status.
    """
    detail_api = "https://music.163.com/api/v3/song/detail"
    parm_dict = {
        "id": id,
        # Encode as real JSON; str() would emit Python-repr single quotes.
        "c": json.dumps([{"id": id}]),
        "csrf_token": "",
    }
    response = requests.get(detail_api, params=parm_dict, timeout=timeout)
    if response.status_code != 200:
        raise ConnectionError(f"错误: {response.status_code}, {response.text}")

    data = json.loads(response.text)
    if not (data and "songs" in data and data["songs"]):
        return "歌曲不存在", False

    song = data["songs"][0]
    fee = int(song["fee"])
    return str(song["name"]), fee in (0, 8)


def upl_infer(audio_path: str, cache_dir="./__pycache__/mode1"):
    """Gradio handler for upload mode.

    Converts an uploaded audio file into MIDI, MusicXML, MXL, ABC notation, and
    a PDF score plus staff image. Exceptions are not propagated; their text is
    returned as the status, and outputs not yet produced stay None.

    Returns:
        (status, midi, pdf, xml, mxl, abc, jpg)
    """
    status = "Success"
    midi = pdf = xml = mxl = abc = jpg = None
    try:
        # Gradio passes an empty value when no audio has been uploaded.
        if not audio_path:
            raise ValueError(_L("请上传音频 100% 后再点提交"))
        clean_cache(cache_dir)
        midi, title = audio2midi(audio_path, cache_dir)
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        status = f"{e}"
    return status, midi, pdf, xml, mxl, abc, jpg


def url_infer(song: str, cache_dir="./__pycache__/mode2"):
    """Gradio handler for direct-link mode.

    ``song`` may be any direct audio URL, a NetEase Cloud Music song page URL
    containing ``?id=``, or a bare numeric NetEase song ID. NetEase songs are
    rewritten to their outer media URL, and songs the info API does not mark
    as free are rejected. Exceptions are not propagated; their text is
    returned as the status.

    Returns:
        (status, audio, midi, pdf, xml, mxl, abc, jpg)
    """
    song_name = ""
    status = "Success"
    audio = midi = pdf = xml = mxl = abc = jpg = None
    try:
        # Tolerate surrounding whitespace from copy-pasted links or IDs.
        song = (song or "").strip()
        clean_cache(cache_dir)
        download_path = f"{cache_dir}/output.mp3"
        if (is_url(song) and "163" in song and "?id=" in song) or song.isdigit():
            song_id = extract_fst_int(song.split("?id=")[-1])
            if not song_id:
                # e.g. the bare placeholder URL ending in "?id=".
                raise ValueError("无法解析歌曲 ID")
            song = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
            song_name, free = music163_song_info(song_id)
            if not free:
                raise ValueError("付费歌曲无法解析")

        download_audio(song, download_path)
        if not os.path.exists(download_path):
            raise FileNotFoundError(f"{download_path} not exist")

        midi, title = audio2midi(download_path, cache_dir)
        # Prefer the NetEase song name over the generic "output" file name.
        if song_name:
            title = song_name

        audio = download_path
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        status = f"{e}"
    return status, audio, midi, pdf, xml, mxl, abc, jpg


if __name__ == "__main__":
    with gr.Blocks() as iface:
        gr.Markdown(_L("# 钢琴转谱工具"))
        with gr.Tab(_L("上传模式")):
            gr.Interface(
                fn=upl_infer,
                inputs=gr.Audio(label=_L("上传音频"), type="filepath"),
                outputs=[
                    gr.Textbox(label=_L("状态栏"), show_copy_button=True),
                    gr.File(label=_L("下载 MIDI")),
                    gr.File(label=_L("下载 PDF 乐谱")),
                    gr.File(label=_L("下载 MusicXML")),
                    gr.File(label=_L("下载 MXL")),
                    gr.Textbox(label=_L("ABC 记谱"), show_copy_button=True),
                    gr.Image(
                        label=_L("五线谱"),
                        type="filepath",
                        show_share_button=False,
                    ),
                ],
                title=_L("请上传音频 100% 后再点提交"),
                flagging_mode="never",
            )

        # Link mode (NetEase resolution) is only offered in the Chinese UI.
        if not EN_US:
            with gr.Tab(_L("直链模式")):
                gr.Interface(
                    fn=url_infer,
                    inputs=gr.Textbox(
                        label=_L("输入音频 URL 直链"),
                        placeholder="https://music.163.com/#/song?id=",
                    ),
                    outputs=[
                        gr.Textbox(label=_L("状态栏"), show_copy_button=True),
                        gr.Audio(label=_L("下载音频"), type="filepath"),
                        gr.File(label=_L("下载 MIDI")),
                        gr.File(label=_L("下载 PDF 乐谱")),
                        gr.File(label=_L("下载 MusicXML")),
                        gr.File(label=_L("下载 MXL")),
                        gr.Textbox(label=_L("ABC 记谱"), show_copy_button=True),
                        gr.Image(
                            label=_L("五线谱"),
                            type="filepath",
                            show_share_button=False,
                        ),
                    ],
                    title=_L("网易云音乐可直接输入非 VIP 歌曲页面链接自动解析"),
                    examples=["1945798894", "1945798973", "1946098771"],
                    flagging_mode="never",
                    cache_examples=False,
                )

    iface.launch()