Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- src/f5_tts/model/trainer.py +10 -4
src/f5_tts/model/trainer.py
CHANGED
|
@@ -61,7 +61,7 @@ class Trainer:
|
|
| 61 |
gradient_accumulation_steps=grad_accumulation_steps,
|
| 62 |
**accelerate_kwargs,
|
| 63 |
)
|
| 64 |
-
|
| 65 |
self.logger = logger
|
| 66 |
if self.logger == "wandb":
|
| 67 |
if exists(wandb_resume_id):
|
|
@@ -325,7 +325,9 @@ class Trainer:
|
|
| 325 |
|
| 326 |
if self.log_samples and self.accelerator.is_local_main_process:
|
| 327 |
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
|
| 328 |
-
torchaudio.save(
|
|
|
|
|
|
|
| 329 |
with torch.inference_mode():
|
| 330 |
generated, _ = self.accelerator.unwrap_model(self.model).sample(
|
| 331 |
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
|
@@ -336,8 +338,12 @@ class Trainer:
|
|
| 336 |
sway_sampling_coef=sway_sampling_coef,
|
| 337 |
)
|
| 338 |
generated = generated.to(torch.float32)
|
| 339 |
-
gen_audio = vocoder.decode(
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
if global_step % self.last_per_steps == 0:
|
| 343 |
self.save_checkpoint(global_step, last=True)
|
|
|
|
| 61 |
gradient_accumulation_steps=grad_accumulation_steps,
|
| 62 |
**accelerate_kwargs,
|
| 63 |
)
|
| 64 |
+
|
| 65 |
self.logger = logger
|
| 66 |
if self.logger == "wandb":
|
| 67 |
if exists(wandb_resume_id):
|
|
|
|
| 325 |
|
| 326 |
if self.log_samples and self.accelerator.is_local_main_process:
|
| 327 |
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
|
| 328 |
+
torchaudio.save(
|
| 329 |
+
f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
|
| 330 |
+
)
|
| 331 |
with torch.inference_mode():
|
| 332 |
generated, _ = self.accelerator.unwrap_model(self.model).sample(
|
| 333 |
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
|
|
|
| 338 |
sway_sampling_coef=sway_sampling_coef,
|
| 339 |
)
|
| 340 |
generated = generated.to(torch.float32)
|
| 341 |
+
gen_audio = vocoder.decode(
|
| 342 |
+
generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
|
| 343 |
+
)
|
| 344 |
+
torchaudio.save(
|
| 345 |
+
f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
|
| 346 |
+
)
|
| 347 |
|
| 348 |
if global_step % self.last_per_steps == 0:
|
| 349 |
self.save_checkpoint(global_step, last=True)
|