HoneyTian commited on
Commit
cb69fb3
·
1 Parent(s): d32c7e7
Files changed (1) hide show
  1. main.py +28 -1
main.py CHANGED
@@ -22,6 +22,7 @@ import shutil
22
  import tempfile
23
  import time
24
  from typing import Dict, Tuple
 
25
  import zipfile
26
 
27
  import gradio as gr
@@ -30,11 +31,11 @@ import librosa
30
  import librosa.display
31
  import matplotlib.pyplot as plt
32
  import numpy as np
 
33
 
34
  import log
35
  from project_settings import environment, project_path, log_directory
36
  from toolbox.os.command import Command
37
- from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet
38
  from toolbox.torchaudio.models.dfnet2.inference_dfnet2 import InferenceDfNet2
39
  from toolbox.torchaudio.models.dtln.inference_dtln import InferenceDTLN
40
  from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
@@ -79,6 +80,28 @@ def get_args():
79
  return args
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def shell(cmd: str):
83
  return Command.popen(cmd)
84
 
@@ -131,6 +154,10 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_
131
  noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t
132
 
133
  sample_rate, signal = noisy_audio_t
 
 
 
 
134
  audio_duration = signal.shape[-1] // 8000
135
 
136
  # Test: 使用 microphone 时,显示采样率是 44100,但 signal 实际是按 8000 的采样率的。
 
22
  import tempfile
23
  import time
24
  from typing import Dict, Tuple
25
+ import uuid
26
  import zipfile
27
 
28
  import gradio as gr
 
31
  import librosa.display
32
  import matplotlib.pyplot as plt
33
  import numpy as np
34
+ from scipy.io import wavfile
35
 
36
  import log
37
  from project_settings import environment, project_path, log_directory
38
  from toolbox.os.command import Command
 
39
  from toolbox.torchaudio.models.dfnet2.inference_dfnet2 import InferenceDfNet2
40
  from toolbox.torchaudio.models.dtln.inference_dtln import InferenceDTLN
41
  from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
 
80
  return args
81
 
82
 
83
+ def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
84
+ if signal.dtype != np.int16:
85
+ raise AssertionError(f"only support dtype np.int16, however: {signal.dtype}")
86
+ temp_audio_dir = Path(tempfile.gettempdir()) / "input_audio"
87
+ temp_audio_dir.mkdir(parents=True, exist_ok=True)
88
+ filename = temp_audio_dir / f"{uuid.uuid4()}.wav"
89
+ filename = filename.as_posix()
90
+ wavfile.write(
91
+ filename,
92
+ sample_rate, signal
93
+ )
94
+ return filename
95
+
96
+
97
+ def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int):
98
+ filename = save_input_audio(sample_rate, signal)
99
+
100
+ signal, _ = librosa.load(filename, sr=target_sample_rate)
101
+ signal = np.array(signal * (1 << 15), dtype=np.int16)
102
+ return signal
103
+
104
+
105
  def shell(cmd: str):
106
  return Command.popen(cmd)
107
 
 
154
  noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t
155
 
156
  sample_rate, signal = noisy_audio_t
157
+ if sample_rate != 8000:
158
+ signal = convert_sample_rate(signal, sample_rate, 8000)
159
+ sample_rate = 8000
160
+
161
  audio_duration = signal.shape[-1] // 8000
162
 
163
  # Test: 使用 microphone 时,显示采样率是 44100,但 signal 实际是按 8000 的采样率的。