	Update audio_foundation_models.py
audio_foundation_models.py  CHANGED  (+285 -2)
@@ -4,6 +4,8 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NeuralSeq'))
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'text_to_audio/Make_An_Audio'))
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'audio_detection'))
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mono2binaural'))
 import matplotlib
 import librosa
 from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
@@ -40,7 +42,16 @@ from utils.hparams import set_hparams
 from utils.hparams import hparams as hp
 from utils.os_utils import move_file
 import scipy.io.wavfile as wavfile
-
+from audio_infer.utils import config as detection_config
+from audio_infer.pytorch.models import PVT
+from src.models import BinauralNetwork
+from sound_extraction.model.LASSNet import LASSNet
+from sound_extraction.utils.stft import STFT
+from sound_extraction.utils.wav_io import load_wav, save_wav
+from target_sound_detection.src import models as tsd_models
+from target_sound_detection.src.models import event_labels
+from target_sound_detection.src.utils import median_filter, decode_with_timestamps
+import clip
 
 def prompts(name, description):
     def decorator(func):
@@ -520,4 +531,276 @@ class A2T:
     def inference(self, audio_path):
         audio = whisper.load_audio(audio_path)
         caption_text = self.model(audio)
-        return caption_text[0]
+        return caption_text[0]
+
+class SoundDetection:
+    def __init__(self, device):
+        self.device = device
+        self.sample_rate = 32000
+        self.window_size = 1024
+        self.hop_size = 320
+        self.mel_bins = 64
+        self.fmin = 50
+        self.fmax = 14000
+        self.model_type = 'PVT'
+        self.checkpoint_path = 'audio_detection/audio_infer/useful_ckpts/audio_detection.pth'
+        self.classes_num = detection_config.classes_num
+        self.labels = detection_config.labels
+        self.frames_per_second = self.sample_rate // self.hop_size
+        # Model = eval(self.model_type)
+        self.model = PVT(sample_rate=self.sample_rate, window_size=self.window_size,
+            hop_size=self.hop_size, mel_bins=self.mel_bins, fmin=self.fmin, fmax=self.fmax,
+            classes_num=self.classes_num)
+        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model'])
+        self.model.to(device)
+
+    @prompts(name="Detect The Sound Event From The Audio",
+             description="useful for when you want to know what event in the audio and the sound event start or end time, "
+                         "receives audio_path as input. "
+                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
+
+    def inference(self, audio_path):
+        # Forward
+        (waveform, _) = librosa.core.load(audio_path, sr=self.sample_rate, mono=True)
+        waveform = waveform[None, :]    # (1, audio_length)
+        waveform = torch.from_numpy(waveform)
+        waveform = waveform.to(self.device)
+        # Forward
+        with torch.no_grad():
+            self.model.eval()
+            batch_output_dict = self.model(waveform, None)
+        framewise_output = batch_output_dict['framewise_output'].data.cpu().numpy()[0]
+        """(time_steps, classes_num)"""
+        # print('Sound event detection result (time_steps x classes_num): {}'.format(
+        #     framewise_output.shape))
+        import numpy as np
+        import matplotlib.pyplot as plt
+        sorted_indexes = np.argsort(np.max(framewise_output, axis=0))[::-1]
+        top_k = 10  # Show top results
+        top_result_mat = framewise_output[:, sorted_indexes[0 : top_k]]
+        """(time_steps, top_k)"""
+        # Plot result
+        stft = librosa.core.stft(y=waveform[0].data.cpu().numpy(), n_fft=self.window_size,
+            hop_length=self.hop_size, window='hann', center=True)
+        frames_num = stft.shape[-1]
+        fig, axs = plt.subplots(2, 1, sharex=True, figsize=(10, 4))
+        axs[0].matshow(np.log(np.abs(stft)), origin='lower', aspect='auto', cmap='jet')
+        axs[0].set_ylabel('Frequency bins')
+        axs[0].set_title('Log spectrogram')
+        axs[1].matshow(top_result_mat.T, origin='upper', aspect='auto', cmap='jet', vmin=0, vmax=1)
+        axs[1].xaxis.set_ticks(np.arange(0, frames_num, self.frames_per_second))
+        axs[1].xaxis.set_ticklabels(np.arange(0, frames_num / self.frames_per_second))
+        axs[1].yaxis.set_ticks(np.arange(0, top_k))
+        axs[1].yaxis.set_ticklabels(np.array(self.labels)[sorted_indexes[0 : top_k]])
+        axs[1].yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3)
+        axs[1].set_xlabel('Seconds')
+        axs[1].xaxis.set_ticks_position('bottom')
+        plt.tight_layout()
+        image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
+        plt.savefig(image_filename)
+        return image_filename
+
+class SoundExtraction:
+    def __init__(self, device):
+        self.device = device
+        self.model_file = 'sound_extraction/useful_ckpts/LASSNet.pt'
+        self.stft = STFT()
+        import torch.nn as nn
+        self.model = nn.DataParallel(LASSNet(device)).to(device)
+        checkpoint = torch.load(self.model_file)
+        self.model.load_state_dict(checkpoint['model'])
+        self.model.eval()
+
+    @prompts(name="Extract Sound Event From Mixture Audio Based On Language Description",
+             description="useful for when you extract target sound from a mixture audio, you can describe the target sound by text, "
+                         "receives audio_path and text as input. "
+                         "The input to this tool should be a comma seperated string of two, "
+                         "representing mixture audio path and input text.")
+
+    def inference(self, inputs):
+        #key = ['ref_audio', 'text']
+        val = inputs.split(",")
+        audio_path = val[0] # audio_path, text
+        text = val[1]
+        waveform = load_wav(audio_path)
+        waveform = torch.tensor(waveform).transpose(1,0)
+        mixed_mag, mixed_phase = self.stft.transform(waveform)
+        text_query = ['[CLS] ' + text]
+        mixed_mag = mixed_mag.transpose(2,1).unsqueeze(0).to(self.device)
+        est_mask = self.model(mixed_mag, text_query)
+        est_mag = est_mask * mixed_mag
+        est_mag = est_mag.squeeze(1)
+        est_mag = est_mag.permute(0, 2, 1)
+        est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
+        est_wav = est_wav.squeeze(0).squeeze(0).numpy()
+        #est_path = f'output/est{i}.wav'
+        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        print('audio_filename ', audio_filename)
+        save_wav(est_wav, audio_filename)
+        return audio_filename
+
+
+class Binaural:
+    def __init__(self, device):
+        self.device = device
+        self.model_file = 'mono2binaural/useful_ckpts/m2b/binaural_network.net'
+        self.position_file = ['mono2binaural/useful_ckpts/m2b/tx_positions.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions2.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions3.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
+        self.net = BinauralNetwork(view_dim=7,
+                      warpnet_layers=4,
+                      warpnet_channels=64,
+                      )
+        self.net.load_from_file(self.model_file)
+        self.sr = 48000
+
+    @prompts(name="Sythesize Binaural Audio From A Mono Audio Input",
+             description="useful for when you want to transfer your mono audio into binaural audio, "
+                         "receives audio_path as input. "
+                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
+
+    def inference(self, audio_path):
+        mono, sr = librosa.load(path=audio_path, sr=self.sr, mono=True)
+        mono = torch.from_numpy(mono)
+        mono = mono.unsqueeze(0)
+        import numpy as np
+        import random
+        rand_int = random.randint(0,4)
+        view = np.loadtxt(self.position_file[rand_int]).transpose().astype(np.float32)
+        view = torch.from_numpy(view)
+        if not view.shape[-1] * 400 == mono.shape[-1]:
+            mono = mono[:,:(mono.shape[-1]//400)*400]
+            if view.shape[1]*400 > mono.shape[1]:
+                m_a = view.shape[1] - mono.shape[-1]//400
+                rand_st = random.randint(0,m_a)
+                view = view[:,m_a:m_a+(mono.shape[-1]//400)]
+        # binauralize and save output
+        self.net.eval().to(self.device)
+        mono, view = mono.to(self.device), view.to(self.device)
+        chunk_size = 48000  # forward in chunks of 1s
+        rec_field = 1000  # add 1000 samples as "safe bet" since warping has undefined rec. field
+        rec_field -= rec_field % 400  # make sure rec_field is a multiple of 400 to match audio and view frequencies
+        chunks = [
+            {
+                "mono": mono[:, max(0, i-rec_field):i+chunk_size],
+                "view": view[:, max(0, i-rec_field)//400:(i+chunk_size)//400]
+            }
+            for i in range(0, mono.shape[-1], chunk_size)
+        ]
+        for i, chunk in enumerate(chunks):
+            with torch.no_grad():
+                mono = chunk["mono"].unsqueeze(0)
+                view = chunk["view"].unsqueeze(0)
+                binaural = self.net(mono, view).squeeze(0)
+                if i > 0:
+                    binaural = binaural[:, -(mono.shape[-1]-rec_field):]
+                chunk["binaural"] = binaural
+        binaural = torch.cat([chunk["binaural"] for chunk in chunks], dim=-1)
+        binaural = torch.clamp(binaural, min=-1, max=1).cpu()
+        #binaural = chunked_forwarding(net, mono, view)
+        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        import torchaudio
+        torchaudio.save(audio_filename, binaural, sr)
+        #soundfile.write(audio_filename, binaural, samplerate = 48000)
+        print(f"Processed Binaural.run, audio_filename: {audio_filename}")
+        return audio_filename
+
+class TargetSoundDetection:
+    def __init__(self, device):
+        self.device = device
+        self.MEL_ARGS = {
+            'n_mels': 64,
+            'n_fft': 2048,
+            'hop_length': int(22050 * 20 / 1000),
+            'win_length': int(22050 * 40 / 1000)
+        }
+        self.EPS = np.spacing(1)
+        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+        self.event_labels = event_labels
+        self.id_to_event = {i : label for i, label in enumerate(self.event_labels)}
+        config = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth', map_location='cpu')
+        config_parameters = dict(config)
+        config_parameters['tao'] = 0.6
+        if 'thres' not in config_parameters.keys():
+            config_parameters['thres'] = 0.5
+        if 'time_resolution' not in config_parameters.keys():
+            config_parameters['time_resolution'] = 125
+        model_parameters = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_model_7_loss=-0.0724.pt'
+                                        , map_location=lambda storage, loc: storage) # load parameter
+        self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
+                    inputdim=64, outputdim=2, time_resolution=config_parameters['time_resolution'], **config_parameters['model_args'])
+        self.model.load_state_dict(model_parameters)
+        self.model = self.model.to(self.device).eval()
+        self.re_embeds = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
+        self.ref_mel = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/ref_mel.pth')
+
+    def extract_feature(self, fname):
+        import soundfile as sf
+        y, sr = sf.read(fname, dtype='float32')
+        print('y ', y.shape)
+        ti = y.shape[0]/sr
+        if y.ndim > 1:
+            y = y.mean(1)
+        y = librosa.resample(y, sr, 22050)
+        lms_feature = np.log(librosa.feature.melspectrogram(y, **self.MEL_ARGS) + self.EPS).T
+        return lms_feature,ti
+
+    def build_clip(self, text):
+        text = clip.tokenize(text).to(self.device) # ["a diagram with dog", "a dog", "a cat"]
+        text_features = self.clip_model.encode_text(text)
+        return text_features
+
+    def cal_similarity(self, target, retrievals):
+        ans = []
+        for name in retrievals.keys():
+            tmp = retrievals[name]
+            s = torch.cosine_similarity(target.squeeze(), tmp.squeeze(), dim=0)
+            ans.append(s.item())
+        return ans.index(max(ans))
+
+    @prompts(name="Target Sound Detection",
+             description="useful for when you want to know when the target sound event in the audio happens. You can use language descriptions to instruct the model, "
+                         "receives text description and audio_path as input. "
+                         "The input to this tool should be a comma seperated string of two, "
+                         "representing audio path and the text description. ")
+
+    def inference(self, text, audio_path):
+        target_emb = self.build_clip(text) # torch type
+        idx = self.cal_similarity(target_emb, self.re_embeds)
+        target_event = self.id_to_event[idx]
+        embedding = self.ref_mel[target_event]
+        embedding = torch.from_numpy(embedding)
+        embedding = embedding.unsqueeze(0).to(self.device).float()
+        inputs,ti = self.extract_feature(audio_path)
+        inputs = torch.from_numpy(inputs)
+        inputs = inputs.unsqueeze(0).to(self.device).float()
+        decision, decision_up, logit = self.model(inputs, embedding)
+        pred = decision_up.detach().cpu().numpy()
+        pred = pred[:,:,0]
+        frame_num = decision_up.shape[1]
+        time_ratio = ti / frame_num
+        filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
+        time_predictions = []
+        for index_k in range(filtered_pred.shape[0]):
+            decoded_pred = []
+            decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k,:])
+            if len(decoded_pred_) == 0: # neg deal
+                decoded_pred_.append((target_event, 0, 0))
+            decoded_pred.append(decoded_pred_)
+            for num_batch in range(len(decoded_pred)): # when we test our model,the batch_size is 1
+                cur_pred = pred[num_batch]
+                # Save each frame output, for later visualization
+                label_prediction = decoded_pred[num_batch] # frame predict
+                for event_label, onset, offset in label_prediction:
+                    time_predictions.append({
+                        'onset': onset*time_ratio,
+                        'offset': offset*time_ratio,})
+        ans = ''
+        for i,item in enumerate(time_predictions):
+            ans = ans + 'segment' + str(i+1) + ' start_time: ' + str(item['onset']) + '  end_time: ' + str(item['offset']) + '\t'
+        return ans
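
The four classes added by this commit (SoundDetection, SoundExtraction, Binaural, TargetSoundDetection) follow the same tool convention as the existing classes in audio_foundation_models.py: construct with a device, then call inference() with the string inputs described in each @prompts decorator. The sketch below is not part of the commit; it is a minimal usage example that assumes the referenced checkpoints are downloaded, the 'image' and 'audio' output directories exist, and the example .wav paths are placeholders.

# Usage sketch (assumption: run from the repository root with the checkpoints
# referenced in the diff in place; the input paths below are hypothetical).
from audio_foundation_models import (SoundDetection, SoundExtraction,
                                     Binaural, TargetSoundDetection)

device = "cuda:0"  # or "cpu"

# Sound event detection: audio path in, path of the saved event plot out.
sed = SoundDetection(device)
plot_path = sed.inference("assets/example.wav")

# Language-queried source separation: a single "audio_path,text" string in,
# path of the extracted .wav out.
lass = SoundExtraction(device)
extracted_wav = lass.inference("assets/mixture.wav,a dog barking")

# Mono-to-binaural synthesis: mono audio path in, binaural .wav path out.
m2b = Binaural(device)
binaural_wav = m2b.inference("assets/example.wav")

# Target sound detection: note this one takes (text, audio_path) as two
# arguments and returns a string of detected segments with start/end times.
tsd = TargetSoundDetection(device)
segments = tsd.inference("a dog barking", "assets/example.wav")
print(plot_path, extracted_wav, binaural_wav, segments)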