Spaces:
Runtime error
Runtime error
Update models/resunet.py
Browse files- models/resunet.py +60 -0
models/resunet.py
CHANGED
|
@@ -652,4 +652,64 @@ class ResUNet30(nn.Module):
|
|
| 652 |
|
| 653 |
return output_dict
|
| 654 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
|
|
|
|
|
|
| 652 |
|
| 653 |
return output_dict
|
| 654 |
|
| 655 |
+
|
| 656 |
+
@torch.no_grad()
|
| 657 |
+
def chunk_inference(self, input_dict):
|
| 658 |
+
chunk_config = {
|
| 659 |
+
'NL': 1.0,
|
| 660 |
+
'NC': 3.0,
|
| 661 |
+
'NR': 1.0,
|
| 662 |
+
'RATE': self.sampling_rate
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
mixtures = input_dict['mixture']
|
| 666 |
+
conditions = input_dict['condition']
|
| 667 |
+
|
| 668 |
+
film_dict = self.film(
|
| 669 |
+
conditions=conditions,
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
NL = int(chunk_config['NL'] * chunk_config['RATE'])
|
| 673 |
+
NC = int(chunk_config['NC'] * chunk_config['RATE'])
|
| 674 |
+
NR = int(chunk_config['NR'] * chunk_config['RATE'])
|
| 675 |
+
|
| 676 |
+
L = mixtures.shape[2]
|
| 677 |
+
|
| 678 |
+
out_np = np.zeros([1, L])
|
| 679 |
+
|
| 680 |
+
WINDOW = NL + NC + NR
|
| 681 |
+
current_idx = 0
|
| 682 |
+
|
| 683 |
+
while current_idx + WINDOW < L:
|
| 684 |
+
chunk_in = mixtures[:, :, current_idx:current_idx + WINDOW]
|
| 685 |
+
|
| 686 |
+
chunk_out = self.base(
|
| 687 |
+
mixtures=chunk_in,
|
| 688 |
+
film_dict=film_dict,
|
| 689 |
+
)['waveform']
|
| 690 |
+
|
| 691 |
+
chunk_out_np = chunk_out.squeeze(0).cpu().data.numpy()
|
| 692 |
+
|
| 693 |
+
if current_idx == 0:
|
| 694 |
+
out_np[:, current_idx:current_idx+WINDOW-NR] = \
|
| 695 |
+
chunk_out_np[:, :-NR] if NR != 0 else chunk_out_np
|
| 696 |
+
else:
|
| 697 |
+
out_np[:, current_idx+NL:current_idx+WINDOW-NR] = \
|
| 698 |
+
chunk_out_np[:, NL:-NR] if NR != 0 else chunk_out_np[:, NL:]
|
| 699 |
+
|
| 700 |
+
current_idx += NC
|
| 701 |
+
|
| 702 |
+
if current_idx < L:
|
| 703 |
+
chunk_in = mixtures[:, :, current_idx:current_idx + WINDOW]
|
| 704 |
+
chunk_out = self.base(
|
| 705 |
+
mixtures=chunk_in,
|
| 706 |
+
film_dict=film_dict,
|
| 707 |
+
)['waveform']
|
| 708 |
+
|
| 709 |
+
chunk_out_np = chunk_out.squeeze(0).cpu().data.numpy()
|
| 710 |
+
|
| 711 |
+
seg_len = chunk_out_np.shape[1]
|
| 712 |
+
out_np[:, current_idx + NL:current_idx + seg_len] = \
|
| 713 |
+
chunk_out_np[:, NL:]
|
| 714 |
|
| 715 |
+
return out_np
|