feat: Implement stereo audio to MIDI transcription
This commit introduces a stereo processing workflow for audio-to-MIDI transcription, preserving spatial information from stereo recordings. The previous implementation was limited to mono processing.
Scale MIDI velocities by 0.8 in Stereo Transcription to avoid loudness/clipping after merge

Applied `scale_instrument_velocity(scale=0.8)` during stereo transcription to prevent the excessive loudness caused by summing the left- and right-channel MIDI tracks. This preserves a more natural dynamic range, avoiding clipping and keeping perceived volume consistent after rendering to WAV/FLAC (see the sketch after the file summary below).
- app.py +289 -36
- requirements.txt +2 -0
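
As a rough illustration of the headroom problem the 0.8 factor addresses (a hypothetical example, not code from this commit):

import pretty_midi

# Transcribing the left and right channels of a near-mono recording tends to
# produce two instruments playing the same note at full strength; rendered
# together, the coincident notes roughly sum their amplitudes.
left = pretty_midi.Instrument(program=0, name="Left - Piano")
right = pretty_midi.Instrument(program=0, name="Right - Piano")
for inst in (left, right):
    inst.notes.append(pretty_midi.Note(velocity=110, pitch=60, start=0.0, end=1.0))

# Scaling each velocity by 0.8 (110 -> 88) leaves headroom before the
# renderer normalizes or clips the summed signal.
for inst in (left, right):
    for note in inst.notes:
        note.velocity = max(1, min(127, int(note.velocity * 0.8)))

print(left.notes[0].velocity, right.notes[0].velocity)  # -> 88 88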
app.py (CHANGED)
@@ -1,14 +1,15 @@
 # =================================================================
 #
-# Merged and Integrated Script for Audio/MIDI Processing and Rendering
+# Merged and Integrated Script for Audio/MIDI Processing and Rendering (Stereo Enhanced)
 #
 # This script combines two functionalities:
 # 1. Transcribing audio to MIDI using two methods:
 #    a) A general-purpose model (basic-pitch by Spotify).
 #    b) A model specialized for solo piano (ByteDance).
+#    - Includes stereo processing by splitting channels, transcribing independently, and merging MIDI.
 # 2. Applying advanced transformations and re-rendering MIDI files using:
-#    a) Standard SoundFonts via FluidSynth.
-#    b) A custom 8-bit style synthesizer for a chiptune sound.
+#    a) Standard SoundFonts via FluidSynth (produces stereo audio).
+#    b) A custom 8-bit style synthesizer for a chiptune sound (updated for stereo output).
 #
 # The user can upload an audio (e.g., WAV, MP3) or MIDI file.
 # - If an audio file is uploaded, it is first transcribed to MIDI using the selected method.
@@ -29,7 +30,7 @@
 #
 # pip install gradio torch pytz numpy scipy matplotlib networkx scikit-learn
 # pip install piano_transcription_inference huggingface_hub
-# pip install basic-pitch pretty_midi librosa
+# pip install basic-pitch pretty_midi librosa soundfile
 #
 # =================================================================
 # Core modules:
@@ -42,6 +43,9 @@ import os
 import hashlib
 import time as reqtime
 import copy
+import librosa
+import pyloudnorm as pyln
+import soundfile as sf
 
 import torch
 import gradio as gr
@@ -60,7 +64,7 @@ import basic_pitch
 from basic_pitch.inference import predict
 from basic_pitch import ICASSP_2022_MODEL_PATH
 
-# --- Imports for 8-bit Synthesizer ---
+# --- Imports for 8-bit Synthesizer & MIDI Merging ---
 import pretty_midi
 import numpy as np
 from scipy import signal
@@ -158,18 +162,36 @@ def prepare_soundfonts():
     return ordered_soundfont_map
 
 # =================================================================================================
-# === 8-bit Style Synthesizer ===
+# === 8-bit Style Synthesizer (Stereo Enabled) ===
 # =================================================================================================
 def synthesize_8bit_style(midi_data, waveform_type, envelope_type, decay_time_s, pulse_width, vibrato_rate, vibrato_depth, bass_boost_level, fs=44100):
     """
     Synthesizes an 8-bit style audio waveform from a PrettyMIDI object.
     This function generates waveforms manually instead of using a synthesizer like FluidSynth.
     Includes an optional sub-octave bass booster with adjustable level.
+    Instruments are panned based on their order in the MIDI file.
+    Instrument 1 -> Left, Instrument 2 -> Right.
     """
     total_duration = midi_data.get_end_time()
-    waveform = np.zeros(int(total_duration * fs) + fs)
+    # Initialize a stereo waveform buffer (2 channels: Left, Right)
+    waveform = np.zeros((2, int(total_duration * fs) + fs))
 
-    for instrument in midi_data.instruments:
+    num_instruments = len(midi_data.instruments)
+
+    for i, instrument in enumerate(midi_data.instruments):
+        # --- Panning Logic ---
+        # Default to center-panned mono
+        pan_l, pan_r = 0.707, 0.707
+        if num_instruments == 2:
+            if i == 0:  # First instrument panned left
+                pan_l, pan_r = 1.0, 0.0
+            elif i == 1:  # Second instrument panned right
+                pan_l, pan_r = 0.0, 1.0
+        elif num_instruments > 2:
+            if i == 0: pan_l, pan_r = 1.0, 0.0   # Left
+            elif i == 1: pan_l, pan_r = 0.0, 1.0 # Right
+            # Other instruments remain centered
 
         for note in instrument.notes:
             freq = pretty_midi.note_number_to_hz(note.pitch)
             note_duration = note.end - note.start
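
A note on the 0.707 default above: it is 1/sqrt(2), the equal-power center pan, so a center-panned instrument contributes the same total power as one panned hard left or right. A small sketch of the underlying identity (illustrative only; pan_gains is a hypothetical helper, not part of this commit):

import numpy as np

# Equal-power panning chooses gains with g_l**2 + g_r**2 == 1, keeping the
# summed power constant as a source moves across the stereo field.
def pan_gains(pan):
    # pan in [-1.0 (hard left), +1.0 (hard right)]
    theta = (pan + 1.0) * np.pi / 4.0
    return np.cos(theta), np.sin(theta)

for p in (-1.0, 0.0, 1.0):
    g_l, g_r = pan_gains(p)
    print(f"pan={p:+.1f}  L={g_l:.3f}  R={g_r:.3f}  power={g_l**2 + g_r**2:.3f}")
# pan=0.0 gives L=R=0.707, matching the pan_l, pan_r = 0.707, 0.707 default above.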
@@ -222,13 +244,162 @@ def synthesize_8bit_style(midi_data, waveform_type, envelope_type, decay_time_s,
 
             start_sample = int(note.start * fs)
             end_sample = start_sample + num_samples
-            if end_sample > len(waveform):
-                end_sample = len(waveform)
+            if end_sample > waveform.shape[1]:
+                end_sample = waveform.shape[1]
             note_waveform = note_waveform[:end_sample-start_sample]
 
-            waveform[start_sample:end_sample] += note_waveform
+            # Add the mono note waveform to the stereo buffer with panning
+            waveform[0, start_sample:end_sample] += note_waveform * pan_l
+            waveform[1, start_sample:end_sample] += note_waveform * pan_r
 
-    return waveform
+    return waveform # Returns a (2, N) numpy array
+
+
+def analyze_midi_velocity(midi_path):
+    midi = pretty_midi.PrettyMIDI(midi_path)
+    all_velocities = []
+
+    print(f"Analyzing velocity for MIDI: {midi_path}")
+    for i, instrument in enumerate(midi.instruments):
+        velocities = [note.velocity for note in instrument.notes]
+        all_velocities.extend(velocities)
+
+        if velocities:
+            print(f"Instrument {i} ({instrument.name}):")
+            print(f"  Notes count: {len(velocities)}")
+            print(f"  Velocity min: {min(velocities)}")
+            print(f"  Velocity max: {max(velocities)}")
+            print(f"  Velocity mean: {np.mean(velocities):.2f}")
+        else:
+            print(f"Instrument {i} ({instrument.name}): no notes found.")
+
+    if all_velocities:
+        print("\nOverall MIDI velocity stats:")
+        print(f"  Total notes: {len(all_velocities)}")
+        print(f"  Velocity min: {min(all_velocities)}")
+        print(f"  Velocity max: {max(all_velocities)}")
+        print(f"  Velocity mean: {np.mean(all_velocities):.2f}")
+    else:
+        print("No notes found in this MIDI.")
+
+
+def scale_instrument_velocity(instrument, scale=0.8):
+    for note in instrument.notes:
+        note.velocity = max(1, min(127, int(note.velocity * scale)))
+
+
+def normalize_loudness(audio_data, sample_rate, target_lufs=-23.0):
+    """
+    Normalizes the audio data to a target integrated loudness (LUFS).
+    This provides more consistent perceived volume than peak normalization.
+
+    Args:
+        audio_data (np.ndarray): The audio signal.
+        sample_rate (int): The sample rate of the audio.
+        target_lufs (float): The target loudness in LUFS. Defaults to -23.0,
+                             a common standard for broadcast.
+
+    Returns:
+        np.ndarray: The loudness-normalized audio data.
+    """
+    try:
+        # 1. Measure the integrated loudness of the input audio
+        meter = pyln.Meter(sample_rate) # create meter
+        loudness = meter.integrated_loudness(audio_data)
+
+        # 2. Calculate the gain needed to reach the target loudness
+        #    The gain is applied in the linear domain, so we convert from dB
+        loudness_gain_db = target_lufs - loudness
+        loudness_gain_linear = 10.0 ** (loudness_gain_db / 20.0)
+
+        # 3. Apply the gain
+        normalized_audio = audio_data * loudness_gain_linear
+
+        # 4. Final safety check: peak normalize to prevent clipping, just in case
+        #    the loudness normalization results in peaks > 1.0
+        peak_val = np.max(np.abs(normalized_audio))
+        if peak_val > 1.0:
+            normalized_audio /= peak_val
+            print("Warning: Loudness normalization resulted in clipping. Audio was peak-normalized as a safeguard.")
+
+        print(f"Audio normalized from {loudness:.2f} LUFS to target {target_lufs} LUFS.")
+        return normalized_audio
+
+    except Exception as e:
+        print(f"Loudness normalization failed: {e}. Falling back to original audio.")
+        return audio_data
+
+
+# =================================================================================================
+# === MIDI Merging Function ===
+# =================================================================================================
+def merge_midis(midi_path_left, midi_path_right, output_path):
+    """
+    Merges two MIDI files into a single MIDI file. This robust version iterates
+    through ALL instruments in both MIDI files, ensuring no data is lost if the
+    source files are multi-instrumental.
+
+    It applies hard-left panning (Pan=0) to every instrument from the left MIDI
+    and hard-right panning (Pan=127) to every instrument from the right MIDI.
+    """
+    try:
+        analyze_midi_velocity(midi_path_left)
+        analyze_midi_velocity(midi_path_right)
+        midi_left = pretty_midi.PrettyMIDI(midi_path_left)
+        midi_right = pretty_midi.PrettyMIDI(midi_path_right)
+
+        merged_midi = pretty_midi.PrettyMIDI()
+
+        # --- Process ALL instruments from the left channel MIDI ---
+        if midi_left.instruments:
+            print(f"Found {len(midi_left.instruments)} instrument(s) in the left channel MIDI.")
+            # Use a loop to iterate through every instrument
+            for instrument in midi_left.instruments:
+                scale_instrument_velocity(instrument, scale=0.8)
+                # To avoid confusion, we can prefix the instrument name
+                instrument.name = f"Left - {instrument.name if instrument.name else 'Instrument'}"
+
+                # Create a Control Change event for Pan (controller number 10).
+                # Set its value to 0 for hard left panning.
+                # Add it at the very beginning of the track (time=0.0).
+                pan_left = pretty_midi.ControlChange(number=10, value=0, time=0.0)
+                # Use insert() to ensure the pan event is the very first one
+                instrument.control_changes.insert(0, pan_left)
+
+                # Append the fully processed instrument to the merged MIDI
+                merged_midi.instruments.append(instrument)
+
+        # --- Process ALL instruments from the right channel MIDI ---
+        if midi_right.instruments:
+            print(f"Found {len(midi_right.instruments)} instrument(s) in the right channel MIDI.")
+            # Use a loop here as well
+            for instrument in midi_right.instruments:
+                scale_instrument_velocity(instrument, scale=0.8)
+                instrument.name = f"Right - {instrument.name if instrument.name else 'Instrument'}"
+
+                # Create a Control Change event for Pan (controller number 10).
+                # Set its value to 127 for hard right panning.
+                # Add it at the very beginning of the track (time=0.0).
+                pan_right = pretty_midi.ControlChange(number=10, value=127, time=0.0)
+                instrument.control_changes.insert(0, pan_right)
+
+                merged_midi.instruments.append(instrument)
+
+        merged_midi.write(output_path)
+        print(f"Successfully merged all instruments and panned into '{os.path.basename(output_path)}'")
+        analyze_midi_velocity(output_path)
+        return output_path
+
+    except Exception as e:
+        print(f"Error merging MIDI files: {e}")
+        # Fallback logic remains the same
+        if os.path.exists(midi_path_left):
+            print("Fallback: Using only the left channel MIDI.")
+            return midi_path_left
+        return None
+
 
 # =================================================================================================
 # === Stage 1: Audio to MIDI Transcription Functions ===
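
To make the gain computation in normalize_loudness() concrete (the -17.3 LUFS meter reading below is an assumed example value):

# target_lufs - measured loudness gives the correction in dB;
# dividing by 20 and exponentiating converts dB to a linear amplitude factor.
measured_lufs = -17.3                    # assumed pyln meter reading
target_lufs = -23.0
gain_db = target_lufs - measured_lufs    # -5.7 dB (turn the signal down)
gain_linear = 10.0 ** (gain_db / 20.0)   # ~0.519x amplitude
print(f"{gain_db:.1f} dB -> x{gain_linear:.3f}")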
@@ -254,7 +425,7 @@ def TranscribePianoAudio(input_file):
     # Use os.path.join to create a platform-independent directory path
     output_dir = os.path.join("output", "transcribed_piano_")
     out_mid_path = os.path.join(output_dir, fn1 + '.mid')
-
+
     # Check for the directory's existence and create it if necessary
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
@@ -412,7 +583,7 @@ def Render_MIDI(input_midi_path,
     escore = TMIDIX.merge_escore_notes(escore, merge_threshold=merge_misaligned_notes)
 
     escore = TMIDIX.augment_enhanced_score_notes(escore, timings_divider=1)
-
+
     first_note_index = [e[0] for e in raw_score[1]].index('note')
     cscore = TMIDIX.chordify_score([1000, escore])
 
@@ -420,7 +591,7 @@ def Render_MIDI(input_midi_path,
 
     aux_escore_notes = TMIDIX.augment_enhanced_score_notes(escore, sort_drums_last=True)
     song_description = TMIDIX.escore_notes_to_text_description(aux_escore_notes)
-
+
     print('Done!')
     print('=' * 70)
     print('Input MIDI metadata:', meta_data[:5])
@@ -472,7 +643,7 @@ def Render_MIDI(input_midi_path,
 
     if render_transpose_to_C4:
         output_score = TMIDIX.transpose_escore_notes_to_pitch(output_score, 60) # C4 is MIDI pitch 60
-
+
     if render_align == "Start Times":
         output_score = TMIDIX.recalculate_score_timings(output_score)
         output_score = TMIDIX.align_escore_notes_to_bars(output_score)
@@ -573,11 +744,12 @@ def Render_MIDI(input_midi_path,
                     s8bit_bass_boost_level,
                     fs=srate
                 )
-                # Normalize
+                # Normalize and prepare for Gradio
                 peak_val = np.max(np.abs(audio))
                 if peak_val > 0:
                     audio /= peak_val
-
+                # Transpose from (2, N) to (N, 2) and convert to int16 for Gradio
+                audio_out = (audio.T * 32767).astype(np.int16)
             except Exception as e:
                 print(f"Error during 8-bit synthesis: {e}")
                 return [None] * 7
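
Why the transpose and 32767 scaling: Gradio's Audio component takes a (sample_rate, array) tuple with samples along the first axis, i.e. (N, 2) for stereo, while the synthesizer produces a (2, N) float buffer in [-1.0, 1.0]. A minimal sketch of the conversion:

import numpy as np

audio = np.zeros((2, 44100))                     # placeholder synth output, (channels, samples)
audio_out = (audio.T * 32767).astype(np.int16)   # 32767 = 2**15 - 1, int16 full scale
print(audio_out.shape, audio_out.dtype)          # (44100, 2) int16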
@@ -603,7 +775,7 @@ def Render_MIDI(input_midi_path,
         with open(midi_to_render_path, 'rb') as f:
             midi_file_content = f.read()
 
-        audio = midi_to_colab_audio(midi_file_content,
+        audio_out = midi_to_colab_audio(midi_file_content,
                        soundfont_path=soundfont_path, # Use the dynamically found path
                        sample_rate=srate,
                        output_for_gradio=True
@@ -619,7 +791,7 @@ def Render_MIDI(input_midi_path,
 
     output_midi_summary = str(meta_data)
 
-    return new_md5_hash, fn1, output_midi_summary, midi_to_render_path, (srate, audio), output_plot, song_description
+    return new_md5_hash, fn1, output_midi_summary, midi_to_render_path, (srate, audio_out), output_plot, song_description
 
 # =================================================================================================
 # === Main Application Logic ===
@@ -627,6 +799,7 @@ def Render_MIDI(input_midi_path,
 
 def process_and_render_file(input_file,
                             # --- Transcription params ---
+                            enable_stereo_processing,
                             transcription_method,
                             onset_thresh, frame_thresh, min_note_len, min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool,
                             # --- MIDI rendering params ---
@@ -645,14 +818,19 @@ def process_and_render_file(input_file,
     start_time = reqtime.time()
     if input_file is None:
         # Return a list of updates to clear all output fields
-
-        return [gr.update(value=None)] * num_outputs
+        return [gr.update(value=None)] * 7
 
     # The input_file from gr.Audio(type="filepath") is now the direct path (a string),
     # not a temporary file object. We no longer need to access the .name attribute.
     input_file_path = input_file
     filename = os.path.basename(input_file_path)
     print(f"Processing new file: {filename}")
+
+    if not filename.lower().endswith(('.mid', '.midi', '.kar')):
+        # Only decode audio uploads; librosa cannot open MIDI files.
+        try:
+            audio_data, native_sample_rate = librosa.load(input_file_path, sr=None, mono=False)
+        except Exception as e:
+            raise gr.Error(f"Failed to load audio file: {e}")
 
     # --- Step 1: Check file type and transcribe if necessary ---
     if filename.lower().endswith(('.mid', '.midi', '.kar')):
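
The stereo branch below relies on librosa's shape contract: with mono=False a stereo file loads as a (2, N) array while a mono file stays 1-D with shape (N,), which is exactly what the ndim/shape check tests. A small sketch ("input.wav" is a placeholder path):

import librosa

audio_data, sr = librosa.load("input.wav", sr=None, mono=False)
if audio_data.ndim == 2 and audio_data.shape[0] == 2:
    print(f"stereo: {audio_data.shape[1]} samples per channel at {sr} Hz")
else:
    print(f"mono: {audio_data.shape[0]} samples at {sr} Hz")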
@@ -660,17 +837,86 @@ def process_and_render_file(input_file,
         midi_path_for_rendering = input_file_path
     else: #if filename.lower().endswith(('.wav', '.mp3'))
         print("Audio file detected. Starting transcription...")
-
-
-
-
-
-
-
-
-
-
-
+
+        base_name = os.path.splitext(filename)[0]
+        temp_dir = "output/temp_normalized"
+        os.makedirs(temp_dir, exist_ok=True)
+
+        # === STEREO PROCESSING LOGIC ===
+        if enable_stereo_processing:
+            if audio_data.ndim != 2 or audio_data.shape[0] != 2:
+                print("Warning: Audio is not stereo or could not be loaded as stereo. Falling back to mono transcription.")
+                enable_stereo_processing = False # Disable stereo processing if audio is not stereo
+
+        if enable_stereo_processing:
+            print("Stereo processing enabled. Splitting channels...")
+            try:
+                left_channel = audio_data[0]
+                right_channel = audio_data[1]
+
+                normalized_left = normalize_loudness(left_channel, native_sample_rate)
+                normalized_right = normalize_loudness(right_channel, native_sample_rate)
+
+                temp_left_wav_path = os.path.join(temp_dir, f"{base_name}_left.wav")
+                temp_right_wav_path = os.path.join(temp_dir, f"{base_name}_right.wav")
+
+                sf.write(temp_left_wav_path, normalized_left, native_sample_rate)
+                sf.write(temp_right_wav_path, normalized_right, native_sample_rate)
+
+                print(f"Saved left channel to: {temp_left_wav_path}")
+                print(f"Saved right channel to: {temp_right_wav_path}")
+
+                print("Transcribing left channel...")
+                if transcription_method == "General Purpose":
+                    midi_path_left = TranscribeGeneralAudio(temp_left_wav_path, onset_thresh, frame_thresh, min_note_len, min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool)
+                else:
+                    midi_path_left = TranscribePianoAudio(temp_left_wav_path)
+
+                print("Transcribing right channel...")
+                if transcription_method == "General Purpose":
+                    midi_path_right = TranscribeGeneralAudio(temp_right_wav_path, onset_thresh, frame_thresh, min_note_len, min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool)
+                else:
+                    midi_path_right = TranscribePianoAudio(temp_right_wav_path)
+
+                if midi_path_left and midi_path_right:
+                    merged_midi_path = os.path.join(temp_dir, f"{base_name}_merged.mid")
+                    midi_path_for_rendering = merge_midis(midi_path_left, midi_path_right, merged_midi_path)
+                elif midi_path_left:
+                    print("Warning: Right channel transcription failed. Using left channel only.")
+                    midi_path_for_rendering = midi_path_left
+                elif midi_path_right:
+                    print("Warning: Left channel transcription failed. Using right channel only.")
+                    midi_path_for_rendering = midi_path_right
+                else:
+                    raise gr.Error("Both left and right channel transcriptions failed.")
+
+            except Exception as e:
+                print(f"An error occurred during stereo processing: {e}")
+                raise gr.Error(f"Stereo Processing Failed: {e}")
+        else:
+            print("Stereo processing disabled. Using standard mono transcription.")
+            if audio_data.ndim == 1:
+                mono_signal = audio_data
+            else:
+                mono_signal = np.mean(audio_data, axis=0)
+
+            normalized_mono = normalize_loudness(mono_signal, native_sample_rate)
+
+            temp_mono_wav_path = os.path.join(temp_dir, f"{base_name}_mono.wav")
+            sf.write(temp_mono_wav_path, normalized_mono, native_sample_rate)
+
+            try:
+                if transcription_method == "General Purpose":
+                    midi_path_for_rendering = TranscribeGeneralAudio(
+                        temp_mono_wav_path, onset_thresh, frame_thresh, min_note_len,
+                        min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool
+                    )
+                else: # Piano-Specific
+                    midi_path_for_rendering = TranscribePianoAudio(temp_mono_wav_path)
+                analyze_midi_velocity(midi_path_for_rendering)
+            except Exception as e:
+                print(f"An error occurred during transcription: {e}")
+                raise gr.Error(f"Transcription Failed: {e}")
 
     # --- Step 2: Render the MIDI file with selected options ---
     print(f"Proceeding to render MIDI file: {os.path.basename(midi_path_for_rendering)}")
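
Condensed, the stereo branch above follows a split -> normalize -> write -> transcribe-per-channel -> merge pipeline. A schematic sketch (assumes the normalize_loudness and merge_midis helpers from this commit are in scope; transcribe stands in for either transcription backend):

import os
import soundfile as sf

def transcribe_stereo(audio_data, sr, base_name, temp_dir, transcribe):
    # audio_data: (2, N) float array; transcribe: wav path -> MIDI path or None
    midi_paths = {}
    for name, channel in (("left", audio_data[0]), ("right", audio_data[1])):
        wav_path = os.path.join(temp_dir, f"{base_name}_{name}.wav")
        sf.write(wav_path, normalize_loudness(channel, sr), sr)
        midi_paths[name] = transcribe(wav_path)

    if midi_paths["left"] and midi_paths["right"]:
        merged = os.path.join(temp_dir, f"{base_name}_merged.mid")
        return merge_midis(midi_paths["left"], midi_paths["right"], merged)
    # Fall back to whichever channel transcribed successfully
    return midi_paths["left"] or midi_paths["right"]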
@@ -696,7 +942,7 @@ def update_ui_visibility(transcription_method, soundfont_choice):
     """
    is_general = (transcription_method == "General Purpose")
    is_8bit = (soundfont_choice == SYNTH_8_BIT_LABEL)
-
+
    return {
        general_transcription_settings: gr.update(visible=is_general),
        synth_8bit_settings: gr.update(visible=is_8bit),
@@ -751,8 +997,14 @@ if __name__ == "__main__":
                     value="General Purpose",
                     info="Choose 'General Purpose' for most music (vocals, etc.). Choose 'Piano-Specific' only for solo piano recordings."
                 )
+
+                # --- Stereo Processing Checkbox ---
+                enable_stereo_processing = gr.Checkbox(
+                    label="Enable Stereo Transcription",
+                    value=False,
+                    info="If checked, left/right audio channels are transcribed separately and merged. Doubles processing time."
+                )
 
-                # --- General Purpose (basic-pitch) Settings ---
                 with gr.Accordion("General Purpose Transcription Settings", open=True) as general_transcription_settings:
                     onset_threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="On-set Threshold", info="Sensitivity for detecting note beginnings. Higher is stricter.")
                     frame_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Frame Threshold", info="Sensitivity for detecting active notes. Higher is stricter.")
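
For context on how the checkbox value reaches process_and_render_file: Gradio passes each component in the inputs list as a positional argument, in order. A stripped-down wiring sketch (illustrative only; the real app passes many more components):

import gradio as gr

with gr.Blocks() as demo:
    input_file = gr.Audio(type="filepath", label="Upload Audio or MIDI")
    enable_stereo_processing = gr.Checkbox(label="Enable Stereo Transcription", value=False)
    out = gr.Textbox()
    gr.Button("Render").click(
        lambda path, stereo: f"{path} (stereo={stereo})",
        inputs=[input_file, enable_stereo_processing],
        outputs=out,
    )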
@@ -775,7 +1027,7 @@ if __name__ == "__main__":
         # --- Dynamically create the list of choices ---
         soundfont_choices = [SYNTH_8_BIT_LABEL] + list(soundfonts_dict.keys())
         # Set a safe default value
-        default_sf_choice = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" if "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" in soundfonts_dict else soundfont_choices[0]
+        default_sf_choice = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" if "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" in soundfonts_dict else (soundfont_choices[0] if soundfont_choices else "")
 
         soundfont_bank = gr.Dropdown(
             soundfont_choices,
@@ -831,6 +1083,7 @@ if __name__ == "__main__":
         # --- Define all input components for the click event ---
         all_inputs = [
             input_file,
+            enable_stereo_processing,
             transcription_method,
             onset_threshold, frame_threshold, minimum_note_length, minimum_frequency, maximum_frequency,
             infer_onsets, melodia_trick, multiple_pitch_bends,
requirements.txt (CHANGED)
@@ -16,6 +16,8 @@ networkx
 scikit-learn
 psutil
 pretty_midi
+soundfile
+pyloudnorm
 piano_transcription_inference
 
 basic-pitch @ git+https://github.com/avan06/basic-pitch; sys_platform != 'linux'