Sync from GitHub repo
This Space is synced from the GitHub repo https://github.com/SWivid/F5-TTS; please submit contributions there.
    	
        src/f5_tts/eval/eval_infer_batch.py
    CHANGED
    
```diff
@@ -189,13 +189,13 @@ def main():
                 gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                 gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                 if mel_spec_type == "vocos":
-                    generated_wave = vocoder.decode(gen_mel_spec)
+                    generated_wave = vocoder.decode(gen_mel_spec).cpu()
                 elif mel_spec_type == "bigvgan":
-                    generated_wave = vocoder(gen_mel_spec)
+                    generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
 
                 if ref_rms_list[i] < target_rms:
                     generated_wave = generated_wave * ref_rms_list[i] / target_rms
-                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave
+                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
 
         accelerator.wait_for_everyone()
         if accelerator.is_main_process:
```
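All three files converge on the same fix: move the vocoder output to the CPU, normalize its shape, and pass an explicit sample rate to `torchaudio.save`. A minimal sketch of the pattern, assuming the output conventions implied by the diff (a vocos-style `decode` returning `(batch, time)` and a BigVGAN-style forward returning `(batch, 1, time)`); the tensors below are random stand-ins:

```python
import torch
import torchaudio

target_sample_rate = 24000  # illustrative; the scripts take this from their mel config

# Random stand-ins for the two vocoder output conventions assumed here.
vocos_out = torch.randn(1, target_sample_rate)       # decode() -> (batch, time)
bigvgan_out = torch.randn(1, 1, target_sample_rate)  # forward() -> (batch, 1, time)

# torchaudio.save expects a CPU tensor shaped (channels, time) plus an explicit
# sample rate, so GPU outputs must be moved over and BigVGAN's extra dim dropped.
torchaudio.save("vocos_sample.wav", vocos_out.cpu(), target_sample_rate)
torchaudio.save("bigvgan_sample.wav", bigvgan_out.squeeze(0).cpu(), target_sample_rate)
```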
        src/f5_tts/infer/speech_edit.py
    CHANGED
    
```diff
@@ -181,13 +181,13 @@ with torch.inference_mode():
    generated = generated[:, ref_audio_len:, :]
    gen_mel_spec = generated.permute(0, 2, 1)
    if mel_spec_type == "vocos":
-        generated_wave = vocoder.decode(gen_mel_spec)
+        generated_wave = vocoder.decode(gen_mel_spec).cpu()
    elif mel_spec_type == "bigvgan":
-        generated_wave = vocoder(gen_mel_spec)
+        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
 
    if rms < target_rms:
        generated_wave = generated_wave * rms / target_rms
 
    save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
-    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
    print(f"Generated wav: {generated_wave.shape}")
```
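The surrounding `rms < target_rms` guard is a loudness round-trip: a quiet reference clip is boosted to `target_rms` before synthesis, so the synthesized wave is scaled back down to the reference's original level. A small sketch of that round-trip (the helper name and the `0.1` value are illustrative, not taken from the diff):

```python
import torch

target_rms = 0.1  # illustrative normalization target

def restore_loudness(generated_wave: torch.Tensor, ref_rms: float) -> torch.Tensor:
    # Mirrors the diff's `generated_wave * rms / target_rms`: if the reference
    # was quieter than target_rms (and thus boosted before inference), scale
    # the output back down so it matches the reference's original loudness.
    if ref_rms < target_rms:
        generated_wave = generated_wave * ref_rms / target_rms
    return generated_wave

wave = torch.randn(1, 24000) * 0.02  # a quiet synthetic "output"
ref_rms = 0.05                       # RMS measured from the reference clip
print(restore_loudness(wave, ref_rms).abs().max())
```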
        src/f5_tts/model/trainer.py
    CHANGED
    
```diff
@@ -324,26 +324,31 @@ class Trainer:
                    self.save_checkpoint(global_step)
 
                    if self.log_samples and self.accelerator.is_local_main_process:
-                        [4 removed lines truncated in the rendered diff]
                        with torch.inference_mode():
                            generated, _ = self.accelerator.unwrap_model(self.model).sample(
                                cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
-                                text=
                                duration=ref_audio_len * 2,
                                steps=nfe_step,
                                cfg_strength=cfg_strength,
                                sway_sampling_coef=sway_sampling_coef,
                            )
-                        [7 removed lines truncated in the rendered diff]
+                        ref_audio_len = mel_lengths[0]
+                        infer_text = [
+                            text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
+                        ]
+                                text=infer_text,
+                            generated = generated.to(torch.float32)
+                            gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                            ref_mel_spec = batch["mel"][0].unsqueeze(0)
+                            if self.vocoder_name == "vocos":
+                                gen_audio = vocoder.decode(gen_mel_spec).cpu()
+                                ref_audio = vocoder.decode(ref_mel_spec).cpu()
+                            elif self.vocoder_name == "bigvgan":
+                                gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
+                                ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
+
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
 
                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)
```
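The new `infer_text` lines double the batch's first transcript so the model is asked to continue the reference audio with the same words. Because `text_inputs[0]` may be either a raw string or a pre-tokenized list, the separator has to match the container type. A standalone sketch of just that branch (the `double_text` name is ours, for illustration):

```python
def double_text(first):
    # String inputs get a " " separator; tokenized (list) inputs get [" "],
    # so concatenation stays type-consistent either way.
    return [first + ([" "] if isinstance(first, list) else " ") + first]

print(double_text("hello world"))  # ['hello world hello world']
print(double_text(["ni", "hao"]))  # [['ni', 'hao', ' ', 'ni', 'hao']]
```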
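Before vocoding, the sampled mel still contains the reference prompt frames and is laid out `(batch, time, n_mel)`; the diff trims the prompt, transposes to the `(batch, n_mel, time)` layout the vocoders consume, and keeps the tensor on the accelerator's device. A toy walk-through of that bookkeeping, with illustrative shapes:

```python
import torch

batch, total_len, n_mel = 1, 400, 100
ref_audio_len = 160  # frames belonging to the reference prompt

generated = torch.randn(batch, total_len, n_mel)  # sampler output: (B, T, n_mel)
generated = generated.to(torch.float32)           # match the diff's dtype cast

gen_only = generated[:, ref_audio_len:, :]        # drop prompt frames -> (1, 240, 100)
gen_mel_spec = gen_only.permute(0, 2, 1)          # vocoder layout -> (1, 100, 240)
print(gen_mel_spec.shape)                         # torch.Size([1, 100, 240])
```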