mrfakename committed on
Commit
069a328
·
verified ·
1 Parent(s): 924a364

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.

src/f5_tts/runtime/triton_trtllm/client_grpc.py CHANGED
@@ -220,8 +220,8 @@ def get_args():
220
  return parser.parse_args()
221
 
222
 
223
- def load_audio(wav_path, target_sample_rate=16000):
224
- assert target_sample_rate == 16000, "hard coding in server"
225
  if isinstance(wav_path, dict):
226
  waveform = wav_path["array"]
227
  sample_rate = wav_path["sampling_rate"]
@@ -244,7 +244,7 @@ async def send(
244
  model_name: str,
245
  padding_duration: int = None,
246
  audio_save_dir: str = "./",
247
- save_sample_rate: int = 16000,
248
  ):
249
  total_duration = 0.0
250
  latency_data = []
@@ -254,7 +254,7 @@ async def send(
254
  for i, item in enumerate(manifest_item_list):
255
  if i % log_interval == 0:
256
  print(f"{name}: {i}/{len(manifest_item_list)}")
257
- waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
258
  duration = len(waveform) / sample_rate
259
  lengths = np.array([[len(waveform)]], dtype=np.int32)
260
 
@@ -417,7 +417,7 @@ async def main():
417
  model_name=args.model_name,
418
  audio_save_dir=args.log_dir,
419
  padding_duration=1,
420
- save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
421
  )
422
  )
423
  tasks.append(task)
 
220
  return parser.parse_args()
221
 
222
 
223
+ def load_audio(wav_path, target_sample_rate=24000):
224
+ assert target_sample_rate == 24000, "hard coding in server"
225
  if isinstance(wav_path, dict):
226
  waveform = wav_path["array"]
227
  sample_rate = wav_path["sampling_rate"]
 
244
  model_name: str,
245
  padding_duration: int = None,
246
  audio_save_dir: str = "./",
247
+ save_sample_rate: int = 24000,
248
  ):
249
  total_duration = 0.0
250
  latency_data = []
 
254
  for i, item in enumerate(manifest_item_list):
255
  if i % log_interval == 0:
256
  print(f"{name}: {i}/{len(manifest_item_list)}")
257
+ waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
258
  duration = len(waveform) / sample_rate
259
  lengths = np.array([[len(waveform)]], dtype=np.int32)
260
 
 
417
  model_name=args.model_name,
418
  audio_save_dir=args.log_dir,
419
  padding_duration=1,
420
+ save_sample_rate=24000,
421
  )
422
  )
423
  tasks.append(task)
src/f5_tts/runtime/triton_trtllm/client_http.py CHANGED
@@ -82,7 +82,7 @@ def prepare_request(
82
  samples,
83
  reference_text,
84
  target_text,
85
- sample_rate=16000,
86
  audio_save_dir: str = "./",
87
  ):
88
  assert len(samples.shape) == 1, "samples should be 1D"
@@ -106,8 +106,8 @@ def prepare_request(
106
  return data
107
 
108
 
109
- def load_audio(wav_path, target_sample_rate=16000):
110
- assert target_sample_rate == 16000, "hard coding in server"
111
  if isinstance(wav_path, dict):
112
  samples = wav_path["array"]
113
  sample_rate = wav_path["sampling_rate"]
@@ -129,7 +129,7 @@ if __name__ == "__main__":
129
 
130
  url = f"{server_url}/v2/models/{args.model_name}/infer"
131
  samples, sr = load_audio(args.reference_audio)
132
- assert sr == 16000, "sample rate hardcoded in server"
133
 
134
  samples = np.array(samples, dtype=np.float32)
135
  data = prepare_request(samples, args.reference_text, args.target_text)
 
82
  samples,
83
  reference_text,
84
  target_text,
85
+ sample_rate=24000,
86
  audio_save_dir: str = "./",
87
  ):
88
  assert len(samples.shape) == 1, "samples should be 1D"
 
106
  return data
107
 
108
 
109
+ def load_audio(wav_path, target_sample_rate=24000):
110
+ assert target_sample_rate == 24000, "hard coding in server"
111
  if isinstance(wav_path, dict):
112
  samples = wav_path["array"]
113
  sample_rate = wav_path["sampling_rate"]
 
129
 
130
  url = f"{server_url}/v2/models/{args.model_name}/infer"
131
  samples, sr = load_audio(args.reference_audio)
132
+ assert sr == 24000, "sample rate hardcoded in server"
133
 
134
  samples = np.array(samples, dtype=np.float32)
135
  data = prepare_request(samples, args.reference_text, args.target_text)
src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt CHANGED
@@ -33,7 +33,7 @@ parameters [
33
  },
34
  {
35
  key: "reference_audio_sample_rate",
36
- value: {string_value:"16000"}
37
  },
38
  {
39
  key: "vocoder",
 
33
  },
34
  {
35
  key: "reference_audio_sample_rate",
36
+ value: {string_value:"24000"}
37
  },
38
  {
39
  key: "vocoder",