Gregniuki committed
Commit caa0b3d · verified · 1 Parent(s): 1e88d8f

Update inference_cli.py

Files changed (1)
  1. inference_cli.py +118 -878
inference_cli.py CHANGED
@@ -1,895 +1,135 @@
 
 
1
  import argparse
2
- import codecs
3
- import re
4
- import tempfile
5
- from pathlib import Path
6
- import logging
7
- import numpy as np
8
  import soundfile as sf
9
- import tomli
10
- import torch
11
- import torchaudio
12
- from tqdm import tqdm
13
- from einops import rearrange
14
- from pydub import AudioSegment, silence
15
- from transformers import pipeline
16
- from huggingface_hub import login
17
- from cached_path import cached_path
18
- import matplotlib.pyplot as plt # Needed for save_spectrogram
19
 
20
- # --- Import Model Architectures ---
21
- # !! Ensure these models are defined in your project's 'model' module !!
22
  try:
23
- from model import UNetT, DiT
24
- except ImportError:
25
- print("Warning: Could not import UNetT, DiT from 'model'. Using placeholders.")
26
- # Placeholder classes if import fails (script might not work correctly)
27
- class MockModel:
28
- def __init__(self, *args, **kwargs): pass
29
- def to(self, device): return self
30
- def eval(self): pass
31
- def sample(self, *args, **kwargs):
32
- duration = kwargs.get('duration', 500); mel_dim = 100
33
- return torch.randn(1, duration, mel_dim), None
34
- @property
35
- def device(self): return torch.device("cpu")
36
- DiT = MockModel
37
- UNetT = MockModel
38
-
39
- # --- Import/Define Utility Functions ---
40
-
41
- from tokenizers import Tokenizer
42
- from phonemizer import phonemize
43
-
44
- # --- Functions copied/adapted from app.py ---
45
-
46
- # Function to load vocoder (from app.py context, may need adjustment)
47
- def load_vocoder(device='cpu'):
48
- """Loads the Vocos vocoder."""
49
- print("Loading Vocos vocoder (charactr/vocos-mel-24khz)...")
50
- try:
51
- # Ensure vocos library is installed: pip install vocos
52
- from vocos import Vocos
53
- # Determine torch dtype based on device for potential efficiency
54
- # Note: Vocos might internally cast, but being explicit can help.
55
- # Using float32 as a safe default unless on CUDA where float16 might work.
56
- vocos_dtype = torch.float16 if str(device) == 'cuda' else torch.float32
57
-
58
- vocos_model = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
59
- # Cast to appropriate dtype if needed, although Vocos might handle this.
60
- # vocos_model = vocos_model.to(dtype=vocos_dtype) # Optional casting
61
- vocos_model.eval()
62
- print("Vocos vocoder loaded successfully.")
63
- return vocos_model
64
- except ImportError:
65
- print("Error: 'vocos' library not found. Please install it: pip install vocos")
66
- raise
67
- except Exception as e:
68
- print(f"Error loading Vocos model: {e}")
69
- raise
70
-
71
- # Function to remove silence from edges (from app.py)
72
- def remove_silence_edges(aseg):
73
- """Removes silence from the beginning and end of an AudioSegment."""
74
- print("Removing silence from audio edges...")
75
- start_trim = silence.detect_leading_silence(aseg)
76
- end_trim = silence.detect_leading_silence(aseg.reverse())
77
- duration = len(aseg)
78
- trimmed_aseg = aseg[start_trim:duration-end_trim]
79
- print(f"Removed {start_trim}ms from start, {end_trim}ms from end.")
80
- return trimmed_aseg
81
-
82
- # Function to save spectrogram (from app.py)
83
- def save_spectrogram(spectrogram, file_path):
84
- """Saves a spectrogram visualization to a file."""
85
- if spectrogram is None:
86
- print("Spectrogram data is None, cannot save.")
87
- return
88
- try:
89
- print(f"Saving spectrogram to {file_path}...")
90
- plt.figure(figsize=(10, 4))
91
- plt.imshow(spectrogram, aspect='auto', origin='lower', cmap='viridis')
92
- plt.colorbar(label='Mel power')
93
- plt.xlabel('Frames')
94
- plt.ylabel('Mel bins')
95
- plt.title('Generated Mel Spectrogram')
96
- plt.tight_layout()
97
- plt.savefig(file_path)
98
- plt.close() # Close the figure to free memory
99
- print("Spectrogram saved.")
100
- except Exception as e:
101
- print(f"Error saving spectrogram: {e}")
102
-
103
- # Helper function to load checkpoint (from app.py, slightly modified for CLI)
104
- def load_checkpoint(model, ckpt_path, device, use_ema=False):
105
- """Loads model weights from a checkpoint file (.pt or .safetensors)."""
106
- print(f"Loading checkpoint from {ckpt_path}...")
107
  try:
108
- if ckpt_path.endswith(".safetensors"):
109
- # Ensure safetensors is installed: pip install safetensors
110
- from safetensors.torch import load_file
111
- state_dict = load_file(ckpt_path, device="cpu")
112
- elif ckpt_path.endswith(".pt"):
113
- state_dict = torch.load(ckpt_path, map_location="cpu")
114
- else:
115
- raise ValueError(f"Unsupported checkpoint format: {ckpt_path}. Must be .pt or .safetensors")
116
-
117
- # Standardize state_dict format (e.g., remove 'state_dict' key if present)
118
- if "state_dict" in state_dict:
119
- state_dict = state_dict["state_dict"]
120
-
121
- # Handle EMA weights
122
- ema_key_prefix = "ema_model." # Adjust if your EMA keys have a different prefix
123
- final_state_dict = {}
124
- has_ema = any(k.startswith(ema_key_prefix) for k in state_dict.keys())
125
-
126
- if use_ema:
127
- if has_ema:
128
- print("Attempting to load EMA weights.")
129
- ema_state_dict = {k[len(ema_key_prefix):]: v for k, v in state_dict.items() if k.startswith(ema_key_prefix)}
130
- if ema_state_dict:
131
- final_state_dict = ema_state_dict
132
- print("Using EMA weights.")
133
- else:
134
- # This case shouldn't happen if has_ema is true, but as a safeguard:
135
- print("Warning: EMA weights requested but none found starting with prefix. Using regular weights.")
136
- final_state_dict = {k: v for k, v in state_dict.items() if not k.startswith(ema_key_prefix)}
137
- else:
138
- print("Warning: EMA weights requested but no keys found with EMA prefix. Using regular weights.")
139
- final_state_dict = state_dict # Use the original dict if no EMA keys exist
140
- else:
141
- print("Loading non-EMA weights.")
142
- # Filter out EMA weights if they exist and we explicitly don't want them
143
- final_state_dict = {k: v for k, v in state_dict.items() if not k.startswith(ema_key_prefix)}
144
-
145
-
146
- # Load into model, handling potential 'module.' prefix from DDP
147
- model_state_dict = model.state_dict()
148
- processed_state_dict = {}
149
- for k, v in final_state_dict.items():
150
- if k.startswith("module."):
151
- k_proc = k[len("module."):]
152
- else:
153
- k_proc = k
154
-
155
- if k_proc in model_state_dict:
156
- if model_state_dict[k_proc].shape == v.shape:
157
- processed_state_dict[k_proc] = v
158
- else:
159
- print(f"Warning: Shape mismatch for key {k_proc}. Checkpoint: {v.shape}, Model: {model_state_dict[k_proc].shape}. Skipping.")
160
- # else: # Optional: Log unexpected keys
161
- # print(f"Warning: Key {k_proc} from checkpoint not found in model. Skipping.")
162
-
163
- missing_keys, unexpected_keys = model.load_state_dict(processed_state_dict, strict=False)
164
-
165
- if missing_keys:
166
- print(f"Warning: Missing keys in model not found in checkpoint: {missing_keys}")
167
- if unexpected_keys:
168
- # This should ideally be empty if we filter correctly, but good to check.
169
- print(f"Warning: Unexpected keys (should not happen with filtering): {unexpected_keys}")
170
-
171
- print(f"Checkpoint loaded successfully from {ckpt_path}")
172
-
173
- except FileNotFoundError:
174
- print(f"Error: Checkpoint file not found at {ckpt_path}")
175
- raise
176
- except Exception as e:
177
- print(f"Error loading checkpoint from {ckpt_path}: {e}")
178
- raise # Re-raise the exception
179
-
180
- model.eval()
181
- return model.to(device)
182
-
183
- # Primary model loading function (from app.py)
184
- def load_custom(model_cls, model_cfg, ckpt_path: str, vocab_size: int, device='cpu', use_ema=True):
185
- """Loads a custom TTS model (DiT or UNetT) with specified config and checkpoint."""
186
- ckpt_path = ckpt_path.strip()
187
-
188
- if ckpt_path.startswith("hf://"):
189
- print(f"Downloading checkpoint from Hugging Face Hub: {ckpt_path}")
190
- try:
191
- ckpt_path = str(cached_path(ckpt_path))
192
- print(f"Checkpoint downloaded to: {ckpt_path}")
193
- except Exception as e:
194
- print(f"Error downloading checkpoint {ckpt_path}: {e}")
195
- raise
196
-
197
- if not Path(ckpt_path).exists():
198
- raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")
199
-
200
- # Ensure necessary config keys are present (add defaults if missing)
201
- if 'mel_dim' not in model_cfg:
202
- model_cfg['mel_dim'] = 100 # Default mel channels
203
- print(f"Warning: 'mel_dim' not in model_cfg, defaulting to {model_cfg['mel_dim']}")
204
- if 'text_num_embeds' not in model_cfg:
205
- model_cfg['text_num_embeds'] = vocab_size
206
- print(f"Setting 'text_num_embeds' in model_cfg to vocab size: {vocab_size}")
207
-
208
- print(f"Instantiating model: {model_cls.__name__} with config: {model_cfg}")
209
  try:
210
- model = model_cls(**model_cfg).to(device) # Instantiate the model
211
- except Exception as e:
212
- print(f"Error instantiating model {model_cls.__name__} with config {model_cfg}: {e}")
213
- raise
214
-
215
- # Load weights using the helper function
216
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
217
- model.eval() # Ensure model is in evaluation mode
218
- return model
219
-
220
-
221
- # Text chunking function (from app.py)
222
- def chunk_text(text, max_chars):
223
- """
224
- Splits the input text into chunks based on punctuation and length limits.
225
- (Copied from previous answer, assumed correct)
226
- """
227
- if not isinstance(text, str):
228
- print("Warning: Input to chunk_text is not a string. Returning empty list.")
229
- return []
230
-
231
- if max_chars > 135:
232
- print(f"Warning: Calculated max_chars ({max_chars}) > 135. Capping at 135.")
233
- max_chars = 135
234
- if max_chars < 50:
235
- print(f"Warning: Calculated max_chars ({max_chars}) < 50. Setting to 50.")
236
- max_chars = 50
237
-
238
- split_after_space_chars = max_chars + int(max_chars * 0.33)
239
- chunks = []
240
- current_chunk = ""
241
-
242
- # Split the text into sentences based on punctuation followed by whitespace
243
- sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])\s*", text) # Added \s* after CJK punc
244
-
245
- for sentence in sentences:
246
- sentence = sentence.strip()
247
- if not sentence:
248
- continue
249
-
250
- # Estimate potential length increase due to space
251
- estimated_len = len(current_chunk) + len(sentence) + (1 if current_chunk else 0)
252
-
253
- if estimated_len <= max_chars:
254
- current_chunk += (" " + sentence) if current_chunk else sentence
255
- else:
256
- # Process the current_chunk if adding the new sentence exceeds max_chars
257
- while len(current_chunk) > split_after_space_chars:
258
- split_index = current_chunk.rfind(" ", 0, split_after_space_chars)
259
- if split_index == -1: split_index = split_after_space_chars
260
- chunks.append(current_chunk[:split_index].strip())
261
- current_chunk = current_chunk[split_index:].strip()
262
-
263
- if current_chunk:
264
- chunks.append(current_chunk)
265
-
266
- # Start new chunk, handle if sentence itself is too long
267
- while len(sentence) > split_after_space_chars:
268
- split_index = sentence.rfind(" ", 0, split_after_space_chars)
269
- if split_index == -1: split_index = split_after_space_chars
270
- chunks.append(sentence[:split_index].strip())
271
- sentence = sentence[split_index:].strip()
272
- current_chunk = sentence
273
-
274
- # Handle the last chunk
275
- while len(current_chunk) > split_after_space_chars:
276
- split_index = current_chunk.rfind(" ", 0, split_after_space_chars)
277
- if split_index == -1: split_index = split_after_space_chars
278
- chunks.append(current_chunk[:split_index].strip())
279
- current_chunk = current_chunk[split_index:].strip()
280
-
281
- if current_chunk:
282
- chunks.append(current_chunk.strip())
283
-
284
- return [c for c in chunks if c] # Filter empty chunks
285
-
286
-
287
- # Text to IPA function (from app.py)
288
- def text_to_ipa(text, language):
289
- """Converts text to IPA using phonemizer with espeak backend."""
290
- if not isinstance(text, str) or not text.strip():
291
- print(f"Warning: Invalid input text for IPA conversion: {text}")
292
- return "" # Return empty string for invalid input
293
- try:
294
- # Ensure phonemizer is installed: pip install phonemizer
295
- # Ensure espeak-ng is installed: sudo apt-get install espeak-ng (or equivalent)
296
- ipa_text = phonemize(
297
- text,
298
- language=language,
299
- backend='espeak',
300
- strip=False, # Keep punctuation
301
- preserve_punctuation=True,
302
- with_stress=True,
303
- language_switch='remove-flags', # Use this instead of regex removal
304
- njobs=1 # Set njobs=1 for potentially better stability/simpler debugging
305
  )
306
- # Specific removals (might be redundant with remove-flags, but kept for consistency)
307
- ipa_text = re.sub(r'tʃˈaɪniːzlˈe̞tə', '', ipa_text)
308
- ipa_text = re.sub(r'tʃˈaɪniːzɭˈetə', '', ipa_text)
309
- ipa_text = re.sub(r'dʒˈapəniːzlˈe̞tə', '', ipa_text)
310
- ipa_text = re.sub(r'dʒˈapəniːzɭˈetə', '', ipa_text)
311
-
312
- ipa_text = ipa_text.strip()
313
- # Replace multiple spaces with single space
314
- ipa_text = re.sub(r'\s+', ' ', ipa_text)
315
 
316
- print(f"Text: '{text}' | Lang: {language} | IPA: '{ipa_text}'")
317
- return ipa_text
318
- except ImportError:
319
- print("Error: 'phonemizer' library not found. Please install it: pip install phonemizer")
320
- raise
321
  except Exception as e:
322
- # Check if it's an espeak error (often happens if language is unsupported)
323
- if "espeak" in str(e).lower():
324
- print(f"Error: Espeak backend failed for language '{language}'. Is the language code valid and espeak-ng installed/supporting it?")
325
- print(f" Original error: {e}")
326
- else:
327
- print(f"Error phonemizing text: '{text}' with language '{language}'. Error: {e}")
328
- # Decide how to handle error
329
- raise ValueError(f"Phonemization failed for '{text}' ({language})") from e
330
-
331
-
332
- # --- End of functions from app.py ---
333
-
334
- # --- Argument Parser Setup ---
335
- # (Parser definition remains the same as previous refactored version)
336
- parser = argparse.ArgumentParser(
337
- prog="python3 inference-cli.py",
338
- description="Commandline interface for F5/E2 TTS.",
339
- )
340
- parser.add_argument(
341
- "-c", "--config", type=str, default="inference-cli.toml",
342
- help="Path to configuration file (TOML format). Default: inference-cli.toml"
343
- )
344
- # --- Arguments overriding config or providing inputs ---
345
- parser.add_argument( "--ckpt_path", type=str, default=None, help="Path or Hub ID (hf://...) to the TTS model checkpoint (.pt/.safetensors). Overrides config.")
346
- parser.add_argument( "--ref_audio", type=str, default=None, help="Path to the reference audio file (<10s recommended). Overrides config.")
347
- parser.add_argument( "--ref_text", type=str, default=None, help="Reference text. If omitted, Whisper transcription is used. Overrides config.")
348
- parser.add_argument( "--gen_text", type=str, default=None, help="Text to synthesize. Overrides config.")
349
- parser.add_argument( "--gen_file", type=str, default=None, help="File containing text to synthesize (overrides --gen_text and config).")
350
- parser.add_argument( "--output_dir", type=str, default=None, help="Directory to save output audio and spectrogram. Overrides config.")
351
- parser.add_argument( "--output_name", type=str, default="out", help="Base name for output files (e.g., 'my_speech' -> my_speech.wav, my_speech.png). Default: out.")
352
- # --- Parameter Arguments ---
353
- parser.add_argument( "--ref_language", type=str, default=None, help="Language code for reference text phonemization (e.g., 'en-us', 'pl', 'de'). Overrides config.")
354
- parser.add_argument( "--language", type=str, default=None, help="Language code for generated text phonemization (e.g., 'en-us', 'pl', 'de'). Overrides config.")
355
- parser.add_argument( "--speed", type=float, default=None, help="Speech speed multiplier. Overrides config.")
356
- parser.add_argument( "--nfe", type=int, default=None, help="Number of function evaluations (sampling steps). Overrides config.")
357
- parser.add_argument( "--cfg", type=float, default=None, help="Classifier-Free Guidance strength. Overrides config.")
358
- parser.add_argument( "--sway", type=float, default=None, help="Sway sampling coefficient. Overrides config.")
359
- parser.add_argument( "--cross_fade", type=float, default=None, help="Cross-fade duration between batches (seconds). Overrides config.")
360
- parser.add_argument( "--remove_silence", action=argparse.BooleanOptionalAction, default=None, help="Enable/disable final silence removal. Overrides config.")
361
- parser.add_argument( "--hf_token", type=str, default=None, help="Hugging Face API token (for downloading private models/checkpoints).")
362
- parser.add_argument( "--tokenizer_path", type=str, default=None, help="Path to the tokenizer.json file. Overrides config.")
363
- parser.add_argument( "--device", type=str, default=None, help="Device to use ('cuda', 'cpu', 'mps'). Auto-detects if not set.")
364
- parser.add_argument( "--dtype", type=str, default=None, help="Data type ('float16', 'bfloat16', 'float32'). Auto-selects if not set.")
365
-
366
- args = parser.parse_args()
367
 
368
- # --- Load Configuration ---
369
- config = {}
370
- if Path(args.config).exists():
371
  try:
372
- with open(args.config, "rb") as f:
373
- config = tomli.load(f)
374
- print(f"Loaded configuration from {args.config}")
375
- except Exception as e:
376
- print(f"Warning: Could not load config file {args.config}. Error: {e}")
377
- else:
378
- print(f"Warning: Config file {args.config} not found. Using defaults and CLI args.")
379
-
380
- # --- Determine Parameters (CLI > Config > Defaults) ---
381
- # (Parameter determination remains the same)
382
- ckpt_path = args.ckpt_path or config.get("ckpt_path", "hf://Gregniuki/F5-tts_English_German_Polish/multi3/model_900000.pt")
383
- ref_audio_path = args.ref_audio or config.get("ref_audio")
384
- ref_text = args.ref_text if args.ref_text is not None else config.get("ref_text", "")
385
- gen_text = args.gen_text or config.get("gen_text")
386
- gen_file = args.gen_file or config.get("gen_file")
387
- output_dir = Path(args.output_dir or config.get("output_dir", "."))
388
- output_name = args.output_name or config.get("output_name", "out")
389
-
390
- ref_language = args.ref_language or config.get("ref_language", "en-us")
391
- language = args.language or config.get("language", "en-us")
392
- speed = args.speed if args.speed is not None else config.get("speed", 1.0)
393
- nfe_step = args.nfe if args.nfe is not None else config.get("nfe", 32)
394
- cfg_strength = args.cfg if args.cfg is not None else config.get("cfg", 2.0)
395
- sway_sampling_coef = args.sway if args.sway is not None else config.get("sway", -1.0)
396
- cross_fade_duration = args.cross_fade if args.cross_fade is not None else config.get("cross_fade", 0.15)
397
- remove_silence_flag = args.remove_silence if args.remove_silence is not None else config.get("remove_silence", False)
398
- hf_token = args.hf_token or config.get("hf_token")
399
- tokenizer_path = args.tokenizer_path or config.get("tokenizer_path", "data/Emilia_ZH_EN_pinyin/tokenizer.json")
400
-
401
-
402
- # --- Validate Required Arguments ---
403
- if not ckpt_path: raise ValueError("Missing required argument/config: --ckpt_path")
404
- if not ref_audio_path: raise ValueError("Missing required argument/config: --ref_audio")
405
- if not gen_text and not gen_file: raise ValueError("Missing required argument/config: --gen_text or --gen_file")
406
-
407
- # --- Read gen_text from file if provided ---
408
- if gen_file:
409
- try:
410
- with codecs.open(gen_file, "r", "utf-8") as f: gen_text = f.read()
411
- print(f"Loaded generation text from {gen_file}")
412
- except Exception as e: raise ValueError(f"Error reading generation text file {gen_file}: {e}")
413
-
414
- # --- Setup Device and Dtype ---
415
- # (Device/Dtype setup remains the same)
416
- cli_device = args.device or config.get("device")
417
- if cli_device:
418
- device = torch.device(cli_device)
419
- else:
420
- device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
421
-
422
- cli_dtype = args.dtype or config.get("dtype")
423
- if cli_dtype:
424
- dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
425
- if cli_dtype in dtype_map: dtype = dtype_map[cli_dtype]
426
- else: raise ValueError(f"Unsupported dtype: {cli_dtype}")
427
- else:
428
- if device.type == "cuda": dtype = torch.float16
429
- elif device.type == "cpu" and hasattr(torch.backends, 'cpu') and torch.backends.cpu.supports_bfloat16: dtype = torch.bfloat16
430
- else: dtype = torch.float32
431
 
432
- print(f"Using device: {device}, dtype: {dtype}")
433
-
434
- # --- Hugging Face Login ---
435
- if hf_token:
436
- print("Logging in to Hugging Face Hub...")
437
- try:
438
- login(token=hf_token)
439
- print("Logged in successfully.")
440
  except Exception as e:
441
- print(f"Warning: Hugging Face login failed: {e}")
442
-
443
-
444
- # --- Create Output Directory ---
445
- output_dir.mkdir(parents=True, exist_ok=True)
446
- wave_path = output_dir / f"{output_name}.wav"
447
- spectrogram_path = output_dir / f"{output_name}.png"
448
 
449
- # --- Load Models and Tokenizer ---
450
- print("Loading Tokenizer...")
451
- try:
452
- if not Path(tokenizer_path).exists():
453
- raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
454
- tokenizer = Tokenizer.from_file(tokenizer_path)
455
- vocab_size = tokenizer.get_vocab_size()
456
- print(f"Tokenizer loaded successfully. Vocab size: {vocab_size}")
457
- except Exception as e:
458
- raise ValueError(f"Error loading tokenizer from {tokenizer_path}: {e}")
459
 
460
- print("Loading Vocoder...")
461
- # Pass device to load_vocoder
462
- vocos = load_vocoder(device=device) # Already includes .to(device).eval()
463
-
464
- print("Loading ASR Model (Whisper)...")
465
- try:
466
- whisper_dtype = torch.float16 if device.type == 'cuda' else torch.float32
467
- # Reduce default batch_size for Whisper CLI use
468
- pipe = pipeline(
469
- "automatic-speech-recognition",
470
- model="openai/whisper-large-v3-turbo",
471
- torch_dtype=whisper_dtype,
472
- device=device,
473
- model_kwargs={"attn_implementation": "sdpa"} # Use SDPA if available
474
- )
475
- print("Whisper model loaded.")
476
- except Exception as e:
477
- print(f"Warning: Could not load Whisper ASR model: {e}. Transcription will not be available.")
478
- pipe = None
479
-
480
- print("Loading TTS Model...")
481
- # --- Determine Model Class and Config ---
482
- # Example configs (ensure they match your actual model requirements)
483
- F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
484
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) # Add mel_dim/text_num_embeds if needed by class
485
-
486
- # Heuristic to determine model class (improve if needed)
487
- if "E2TTS" in ckpt_path or "UNetT" in ckpt_path:
488
- model_cls = UNetT
489
- model_cfg = E2TTS_model_cfg
490
- print(f"Assuming E2-TTS (UNetT) architecture for {ckpt_path}.")
491
- elif "F5TTS" in ckpt_path or "DiT" in ckpt_path:
492
- model_cls = DiT
493
- model_cfg = F5TTS_model_cfg
494
- print(f"Assuming F5-TTS (DiT) architecture for {ckpt_path}.")
495
- else:
496
- # Default or raise error if model type cannot be inferred
497
- print(f"Warning: Cannot infer model type from '{ckpt_path}'. Defaulting to DiT/F5TTS.")
498
- model_cls = DiT
499
- model_cfg = F5TTS_model_cfg
500
-
501
-
502
- try:
503
- # Pass vocab_size needed by load_custom
504
- ema_model = load_custom(model_cls, model_cfg, ckpt_path, vocab_size=vocab_size, device=device, use_ema=True)
505
- # Ensure model is using the target runtime dtype
506
- ema_model = ema_model.to(dtype=dtype)
507
- print(f"TTS Model loaded successfully ({model_cls.__name__}).")
508
- except Exception as e:
509
- print(f"Critical Error: Failed to load TTS model from {ckpt_path}: {e}")
510
- raise
511
-
512
- # --- Settings from app.py ---
513
- target_sample_rate = 24000
514
- n_mel_channels = model_cfg.get('mel_dim', 100) # Use mel_dim from config if available
515
- hop_length = 256
516
- target_rms = 0.1
517
-
518
- # --- Main Inference Logic ---
519
-
520
- def infer_batch(ref_audio_tuple, ref_text_ipa, gen_text_ipa_batches,
521
- ema_model, vocos, tokenizer,
522
- remove_silence_post, cross_fade_duration,
523
- nfe_step, cfg_strength, sway_sampling_coef, speed,
524
- target_sample_rate, hop_length, target_rms, device, dtype):
525
- """
526
- Generates audio batches based on reference and text inputs.
527
- (Function body remains the same as previous refactored version)
528
- """
529
- audio, sr = ref_audio_tuple
530
- audio = audio.to(device, dtype=dtype)
531
-
532
- # Preprocess reference audio (resample, RMS norm)
533
- if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True)
534
- current_rms = torch.sqrt(torch.mean(torch.square(audio)))
535
- rms_applied_factor = 1.0 # Track scaling factor applied to ref
536
- if current_rms < target_rms and current_rms > 1e-5: # Add safety check for near-silent audio
537
- print(f"Reference audio RMS ({current_rms:.3f}) below target ({target_rms}). Normalizing.")
538
- rms_applied_factor = target_rms / current_rms
539
- audio = audio * rms_applied_factor
540
- elif current_rms <= 1e-5:
541
- print("Warning: Reference audio is near silent. Skipping RMS normalization.")
542
- else:
543
- print(f"Reference audio RMS ({current_rms:.3f}) >= target ({target_rms}). No normalization.")
544
-
545
- if sr != target_sample_rate:
546
- print(f"Resampling reference audio from {sr} Hz to {target_sample_rate} Hz.")
547
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
548
- audio = resampler(audio)
549
-
550
- ref_audio_len_frames = audio.shape[-1] // hop_length
551
- print(f"Reference audio length: {audio.shape[-1]/target_sample_rate:.2f}s ({ref_audio_len_frames} frames)")
552
-
553
- generated_waves = []
554
- spectrograms = []
555
-
556
- progress_bar = tqdm(gen_text_ipa_batches, desc="Generating Batches")
557
- for i, gen_text_ipa in enumerate(progress_bar):
558
- progress_bar.set_postfix({"Batch": f"{i+1}/{len(gen_text_ipa_batches)}"})
559
-
560
- # Combine reference and generated IPA text
561
- combined_ipa_text = ref_text_ipa + " " + gen_text_ipa
562
- # print(f"Batch {i+1} Combined IPA: {combined_ipa_text}") # Debug
563
-
564
- # Tokenize
565
- try:
566
- # Tokenizer expects single string or list of strings
567
- encoding = tokenizer.encode(combined_ipa_text)
568
- tokens = encoding.ids
569
- token_str = encoding.tokens # For logging/debug
570
-
571
- # --- Model Input Formatting ---
572
- # Check how your specific model's `sample` method expects the 'text' input.
573
- # Option 1 (like app.py): String of space-separated tokens
574
- # token_input_string = ' '.join(map(str, token_str))
575
- # final_text_list = [token_input_string]
576
-
577
- # Option 2: List of token IDs (might be more common)
578
- # final_text_list = [tokens] # List containing the list/tensor of IDs
579
-
580
- # Option 3: Tensor of token IDs (check model docs)
581
- # Assuming model expects Option 1 based on app.py:
582
- token_input_string = ' '.join(map(str, token_str))
583
- final_text_list = [token_input_string]
584
- # print(f"Batch {i+1} Input Text List for Model: {final_text_list}")
585
-
586
- except Exception as e:
587
- print(f"Error tokenizing batch {i+1}: '{combined_ipa_text}'. Error: {e}")
588
- continue
589
-
590
- # Calculate duration
591
- ref_ipa_len = len(ref_text_ipa)
592
- gen_ipa_len = len(gen_text_ipa)
593
- if ref_ipa_len == 0: ref_ipa_len = 1 # Avoid division by zero
594
-
595
- duration_frames = ref_audio_len_frames + int(((ref_audio_len_frames / ref_ipa_len) * gen_ipa_len) / speed)
596
- min_duration_frames = max(10, target_sample_rate // hop_length // 4) # Shorter min duration (e.g. 0.25s)
597
- duration_frames = max(min_duration_frames, duration_frames)
598
- max_duration_frames = 40 * target_sample_rate // hop_length # Increase max duration slightly?
599
- if duration_frames > max_duration_frames:
600
- print(f"Warning: Calculated duration {duration_frames} frames exceeds max {max_duration_frames}. Capping.")
601
- duration_frames = max_duration_frames
602
-
603
- # print(f"Batch {i+1}: Duration={duration_frames} frames")
604
-
605
- # Inference
606
- try:
607
- with torch.inference_mode():
608
- cond_audio = audio.to(ema_model.device, dtype=dtype) # Match model device/dtype
609
- # print(f"Model device: {ema_model.device}, Cond audio device: {cond_audio.device}, dtype: {cond_audio.dtype}")
610
-
611
- generated_mel, _ = ema_model.sample(
612
- cond=cond_audio,
613
- text=final_text_list, # Pass formatted text input
614
- duration=duration_frames,
615
- steps=nfe_step,
616
- cfg_strength=cfg_strength,
617
- sway_sampling_coef=sway_sampling_coef,
618
- )
619
-
620
- # Process generated mel
621
- generated_mel = generated_mel.to(device, dtype=dtype) # Back to main device/dtype
622
- generated_mel = generated_mel[:, ref_audio_len_frames:, :]
623
- generated_mel_spec = rearrange(generated_mel, "1 n d -> 1 d n")
624
-
625
- # Vocoding
626
- # Vocos usually expects float32
627
- vocos_input_mel = generated_mel_spec.to(vocos.device, dtype=torch.float32)
628
- generated_wave = vocos.decode(vocos_input_mel)
629
- generated_wave = generated_wave.to(device, dtype=torch.float32)
630
-
631
- # Adjust RMS (Scale generated audio by the same factor applied to reference)
632
- generated_wave = generated_wave * rms_applied_factor
633
-
634
- # Convert to numpy
635
- generated_wave_np = generated_wave.squeeze().cpu().numpy()
636
- generated_waves.append(generated_wave_np)
637
- spectrograms.append(generated_mel_spec[0].cpu().to(torch.float32).numpy())
638
-
639
- except Exception as e:
640
- logging.exception(f"Error during inference/processing for batch {i+1}:") # Log traceback
641
- print(f"Error details: {e}")
642
- continue
643
-
644
- if not generated_waves:
645
- print("No audio waves were generated.")
646
- return None, None
647
-
648
- # Combine batches
649
- print(f"Combining {len(generated_waves)} generated batches...")
650
- if cross_fade_duration <= 0 or len(generated_waves) == 1:
651
- final_wave = np.concatenate(generated_waves)
652
- else:
653
- # (Cross-fading logic remains the same)
654
- final_wave = generated_waves[0]
655
- for i in range(1, len(generated_waves)):
656
- prev_wave = final_wave; next_wave = generated_waves[i]
657
- cf_samples = min(int(cross_fade_duration * target_sample_rate), len(prev_wave), len(next_wave))
658
- if cf_samples <= 0: final_wave = np.concatenate([prev_wave, next_wave]); continue
659
- p_olap = prev_wave[-cf_samples:]; n_olap = next_wave[:cf_samples]
660
- f_out = np.linspace(1, 0, cf_samples, dtype=p_olap.dtype); f_in = np.linspace(0, 1, cf_samples, dtype=n_olap.dtype)
661
- cf_olap = p_olap * f_out + n_olap * f_in
662
- final_wave = np.concatenate([prev_wave[:-cf_samples], cf_olap, next_wave[cf_samples:]])
663
- print(f"Applied cross-fade of {cross_fade_duration:.2f}s between batches.")
664
-
665
- # Optional: Remove silence post-combination
666
- if remove_silence_post:
667
- print("Removing silence from final output...")
668
- try:
669
- final_wave_float32 = final_wave.astype(np.float32)
670
- with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as tmp_wav:
671
- sf.write(tmp_wav.name, final_wave_float32, target_sample_rate)
672
- aseg = AudioSegment.from_file(tmp_wav.name)
673
- non_silent_segs = silence.split_on_silence(
674
- aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
675
- )
676
- if not non_silent_segs:
677
- print("Warning: Silence removal resulted in empty audio. Keeping original.")
678
- else:
679
- non_silent_wave = sum(non_silent_segs, AudioSegment.silent(duration=0))
680
- non_silent_wave.export(tmp_wav.name, format="wav")
681
- final_wave_tensor, _ = torchaudio.load(tmp_wav.name)
682
- final_wave = final_wave_tensor.squeeze().cpu().numpy()
683
- print("Silence removal applied.")
684
- except Exception as e:
685
- print(f"Warning: Failed to remove silence: {e}. Using original.")
686
-
687
- # Combine spectrograms
688
- print("Combining spectrograms...")
689
- try:
690
- if spectrograms:
691
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
692
- else:
693
- combined_spectrogram = None
694
- except ValueError as e:
695
- print(f"Warning: Could not concatenate spectrograms: {e}. Skipping.")
696
- combined_spectrogram = None
697
-
698
- return final_wave, combined_spectrogram
699
-
700
-
701
- def main_infer(ref_audio_orig_path, ref_text_input, gen_text_full,
702
- ema_model, vocos, tokenizer, pipe_asr, # Loaded models/utils
703
- ref_language, language, # Languages
704
- speed, nfe_step, cfg_strength, sway_sampling_coef, # Sampling params
705
- remove_silence_flag, cross_fade_duration, # Postprocessing
706
- target_sample_rate, hop_length, target_rms, # Audio params
707
- device, dtype): # System params
708
- """
709
- Main inference function coordinating preprocessing, batching, and generation.
710
- (Function body remains the same as previous refactored version)
711
- """
712
- print(f"Starting inference for text: '{gen_text_full[:100]}...'")
713
-
714
- # --- Reference Audio Preprocessing ---
715
- print("Processing reference audio...")
716
- processed_ref_path = None
717
- try:
718
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_wav:
719
- processed_ref_path = temp_ref_wav.name # Store path for potential use
720
- aseg = AudioSegment.from_file(ref_audio_orig_path)
721
- print(f"Original ref duration: {len(aseg)/1000:.2f}s")
722
-
723
- # Edge silence removal + padding
724
- aseg = remove_silence_edges(aseg)
725
- aseg += AudioSegment.silent(duration=150)
726
-
727
- # Split/recombine on silence
728
- non_silent_segs = silence.split_on_silence(
729
- aseg, min_silence_len=700, silence_thresh=-50, keep_silence=700
730
- )
731
- if non_silent_segs:
732
- aseg = sum(non_silent_segs, AudioSegment.silent(duration=0)) # Use sum for conciseness
733
- else:
734
- print("Warning: Silence splitting/recombining resulted in empty audio. Using edge-trimmed.")
735
-
736
- # Clip to 10s
737
- max_ref_duration_ms = 10000
738
- if len(aseg) > max_ref_duration_ms:
739
- print(f"Reference audio exceeds {max_ref_duration_ms/1000}s. Clipping...")
740
- aseg = aseg[:max_ref_duration_ms]
741
-
742
- aseg.export(processed_ref_path, format="wav")
743
- print(f"Processed ref duration: {len(aseg)/1000:.2f}s. Saved to temp file: {processed_ref_path}")
744
-
745
- # Load processed audio tensor
746
- ref_audio_tensor, sr_ref = torchaudio.load(processed_ref_path)
747
-
748
- except Exception as e:
749
- print(f"Error processing reference audio {ref_audio_orig_path}: {e}")
750
- if processed_ref_path and Path(processed_ref_path).exists():
751
- Path(processed_ref_path).unlink() # Clean up temp file on error
752
- raise
753
-
754
- # --- Reference Text Handling ---
755
- ref_text_processed = ""
756
- if not ref_text_input or ref_text_input.strip() == "":
757
- print("No reference text provided. Transcribing reference audio...")
758
- if pipe_asr is None:
759
- raise ValueError("Whisper ASR model not loaded. Cannot transcribe. Please provide --ref_text.")
760
- if not processed_ref_path:
761
- raise ValueError("Processed reference audio path is missing for transcription.")
762
- try:
763
- # Ensure Whisper input dtype matches its loaded dtype
764
- whisper_input_dtype = pipe_asr.model.dtype
765
-
766
- # Load audio specifically for Whisper if dtypes differ significantly
767
- # Or rely on pipeline handling. Assuming pipeline handles it for now.
768
- print(f"Transcribing: {processed_ref_path}")
769
- transcription_result = pipe_asr(
770
- processed_ref_path,
771
- chunk_length_s=15,
772
- batch_size=8, # Smaller batch size for CLI
773
- generate_kwargs={"task": "transcribe", "language": None}, # Whisper language detection
774
- return_timestamps=False,
775
- )
776
- ref_text_processed = transcription_result["text"].strip()
777
- print(f"Transcription finished: '{ref_text_processed}'")
778
- if not ref_text_processed:
779
- print("Warning: Transcription resulted in empty text. Using placeholder.")
780
- ref_text_processed = "Reference audio"
781
- except Exception as e:
782
- logging.exception("Error during transcription:")
783
- raise ValueError("Transcription failed. Please provide --ref_text.")
784
- else:
785
- print("Using provided reference text.")
786
- ref_text_processed = ref_text_input
787
-
788
- # Clean up the temporary processed reference audio file
789
- if processed_ref_path and Path(processed_ref_path).exists():
790
- try:
791
- Path(processed_ref_path).unlink()
792
- # print(f"Cleaned up temp ref file: {processed_ref_path}") # Debug
793
- except OSError as e:
794
- print(f"Warning: Could not delete temp ref file {processed_ref_path}: {e}")
795
-
796
-
797
- # Ensure reference text ends with ". "
798
- if not ref_text_processed.endswith(". "):
799
- ref_text_processed = ref_text_processed.rstrip('. ') + ". " # More robust way
800
- print(f"Final Reference Text: '{ref_text_processed}'")
801
-
802
- # --- Phonemize Reference Text ---
803
- print(f"Phonemizing reference text with language: {ref_language}")
804
- ref_text_ipa = text_to_ipa(ref_text_processed, language=ref_language)
805
- if not ref_text_ipa: raise ValueError("Reference text phonemization failed.")
806
-
807
- # --- Chunk and Phonemize Generation Text ---
808
- ref_audio_duration_sec = ref_audio_tensor.shape[-1] / sr_ref if sr_ref > 0 else 1.0
809
- if ref_audio_duration_sec <= 0: ref_audio_duration_sec = 1.0
810
- chars_per_sec = len(ref_text_processed.encode('utf-8')) / ref_audio_duration_sec if ref_audio_duration_sec > 0 else 10.0
811
- if chars_per_sec <= 0: chars_per_sec = 10.0
812
- target_chunk_duration_sec = max(5.0, 20.0 - ref_audio_duration_sec)
813
- max_chars = int(chars_per_sec * target_chunk_duration_sec)
814
-
815
- print(f"Ref duration: {ref_audio_duration_sec:.2f}s => Calculated max_chars/batch: {max_chars}")
816
- gen_text_batches_plain = chunk_text(gen_text_full, max_chars=max_chars)
817
- if not gen_text_batches_plain: raise ValueError("Text chunking resulted in zero batches.")
818
- print(f"Split generation text into {len(gen_text_batches_plain)} batches.")
819
-
820
- print(f"Phonemizing generation text batches with language: {language}")
821
- gen_text_ipa_batches = []
822
- for i, batch_text in enumerate(gen_text_batches_plain):
823
- # print(f" Phonemizing batch {i+1}/{len(gen_text_batches_plain)}...") # Verbose
824
- batch_ipa = text_to_ipa(batch_text, language=language)
825
- if batch_ipa: gen_text_ipa_batches.append(batch_ipa)
826
- else: print(f"Warning: Skipping batch {i+1} due to phonemization failure.")
827
-
828
- if not gen_text_ipa_batches: raise ValueError("Phonemization failed for all generation text batches.")
829
-
830
- # --- Run Batched Inference ---
831
- print(f"Starting batch inference process ({len(gen_text_ipa_batches)} batches)...")
832
- final_wave, combined_spectrogram = infer_batch(
833
- (ref_audio_tensor, sr_ref), ref_text_ipa, gen_text_ipa_batches,
834
- ema_model, vocos, tokenizer,
835
- remove_silence_flag, cross_fade_duration,
836
- nfe_step, cfg_strength, sway_sampling_coef, speed,
837
- target_sample_rate, hop_length, target_rms,
838
- device, dtype
839
- )
840
-
841
- return final_wave, combined_spectrogram
842
-
843
-
844
- # --- Execution ---
845
  if __name__ == "__main__":
846
- # Setup logging
847
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
848
-
849
- try:
850
- final_wave_np, combined_spectrogram_np = main_infer(
851
- ref_audio_path, ref_text, gen_text,
852
- ema_model, vocos, tokenizer, pipe,
853
- ref_language, language,
854
- speed, nfe_step, cfg_strength, sway_sampling_coef,
855
- remove_silence_flag, cross_fade_duration,
856
- target_sample_rate, hop_length, target_rms,
857
- device, dtype
858
- )
859
-
860
- # --- Save Outputs ---
861
- output_saved = False
862
- if final_wave_np is not None and len(final_wave_np) > 0:
863
- print(f"Saving final audio ({len(final_wave_np)/target_sample_rate:.2f}s) to {wave_path}...")
864
- final_wave_float32 = final_wave_np.astype(np.float32) # Ensure float32 for sf
865
- sf.write(str(wave_path), final_wave_float32, target_sample_rate)
866
- print("Audio saved successfully.")
867
- output_saved = True
868
- else:
869
- print("Inference did not produce a valid audio wave.")
870
-
871
- if combined_spectrogram_np is not None:
872
- print(f"Saving combined spectrogram to {spectrogram_path}...")
873
- save_spectrogram(combined_spectrogram_np, str(spectrogram_path))
874
- print("Spectrogram saved successfully.")
875
- output_saved = True
876
- # else: # No need to print if spectrogram was None
877
- # print("Spectrogram generation failed or was skipped.")
878
-
879
- if not output_saved:
880
- print("No output files were generated.")
881
-
882
- except FileNotFoundError as e:
883
- logging.error(f"File not found: {e}")
884
- print(f"\nError: A required file was not found. Please check paths. Details: {e}")
885
- exit(1)
886
- except ValueError as e:
887
- logging.error(f"Value error: {e}")
888
- print(f"\nError: An invalid value or configuration was encountered. Details: {e}")
889
- exit(1)
890
- except Exception as e:
891
- logging.exception("An unexpected error occurred during inference:") # Log traceback
892
- print(f"\nAn unexpected error occurred: {e}")
893
- exit(1)
894
 
895
- print("\nInference completed.")
 
+ # --- START OF FILE inference_cli.py ---
+
  import argparse
+ import shutil
  import soundfile as sf
+ import os # For path manipulation if needed
+ import sys # To potentially add app.py directory to path

+ # Try to import app.py - assumes it's in the same directory or Python path
  try:
+     # If app.py is not directly importable, you might need to add its directory to the path
+     # Example: sys.path.append(os.path.dirname(os.path.abspath(__file__))) # Add current dir
+     import app
+     from app import infer # Import the main inference function
+ except ImportError as e:
+     print(f"Error: Could not import 'app.py'. Make sure it's in the Python path.")
+     print(f"Details: {e}")
+     sys.exit(1)
+ except Exception as e:
+     print(f"An unexpected error occurred during 'app.py' import: {e}")
+     sys.exit(1)
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="F5 TTS - Simplified CLI Interface using app.py")
+
+     # --- Input Arguments ---
+     parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file (wav, mp3, etc.)")
+     parser.add_argument("--ref_text", default="", help="Reference text. If empty, audio transcription will be performed by app.py's infer function.")
+     parser.add_argument("--gen_text", required=True, help="Text to generate")
+
+     # --- Model & Generation Parameters ---
+     # Note: app.py seems hardcoded to load the "Multi" model at the top level.
+     # This argument might not change the loaded model unless app.py's infer logic uses it internally.
+     parser.add_argument("--exp_name", default="Multi", help="Experiment name / model selection (default: Multi - effectiveness depends on app.py)")
+     parser.add_argument("--language", default="en-us", help="Synthesized language code (e.g., en-us, pl, de) (default: en-us)")
+     parser.add_argument("--ref_language", default="en-us", help="Reference language code (e.g., en-us, pl, de) (default: en-us)")
+     parser.add_argument("--speed", type=float, default=1.0, help="Audio speed factor (default: 1.0)")
+
+     # --- Postprocessing ---
+     parser.add_argument("--remove_silence", action="store_true", help="Remove silence from the output audio (uses app.py logic)")
+     parser.add_argument("--cross_fade_duration", type=float, default=0.15, help="Cross-fade duration between batches (s)")
+
+     # --- Output Arguments ---
+     parser.add_argument("--output_audio", default="output.wav", help="Path to save the output WAV file")
+     parser.add_argument("--output_spectrogram", default="spectrogram.png", help="Path to save the spectrogram image (PNG)")
+
+     args = parser.parse_args()
+
+     print("--- Configuration ---")
+     print(f"Reference Audio: {args.ref_audio}")
+     print(f"Reference Text: '{args.ref_text if args.ref_text else '<Automatic Transcription>'}'")
+     print(f"Generation Text: '{args.gen_text[:100]}...'")
+     print(f"Model (exp_name): {args.exp_name}")
+     print(f"Synth Language: {args.language}")
+     print(f"Ref Language: {args.ref_language}")
+     print(f"Speed: {args.speed}")
+     print(f"Remove Silence: {args.remove_silence}")
+     print(f"Cross-Fade: {args.cross_fade_duration}s")
+     print(f"Output Audio: {args.output_audio}")
+     print(f"Output Spectrogram: {args.output_spectrogram}")
+     print("--------------------")
+
+     # --- Set Global Variables in app.py ---
+     # The 'infer' function in app.py relies on these globals being set.
      try:
+         print(f"Setting language in app module to: {args.language}")
+         app.language = args.language
+         print(f"Setting ref_language in app module to: {args.ref_language}")
+         app.ref_language = args.ref_language
+         print(f"Setting speed in app module to: {args.speed}")
+         app.speed = args.speed
+     except AttributeError as e:
+         print(f"Error: Could not set global variable in 'app.py'. Does it exist? Details: {e}")
+         sys.exit(1)
+
+     # --- Run Inference ---
+     print("\nStarting inference process (will load models if not already loaded)...")
      try:
+         # Call the infer function directly from the imported app module
+         (sr, audio_data), temp_spectrogram_path = infer(
+             ref_audio_orig=args.ref_audio,
+             ref_text=args.ref_text,
+             gen_text=args.gen_text,
+             exp_name=args.exp_name,
+             remove_silence=args.remove_silence,
+             cross_fade_duration=args.cross_fade_duration
+             # Note: language, ref_language, speed are used globally within app.py's functions
          )
+         print("Inference completed.")

      except Exception as e:
+         print(f"\nError during inference: {e}")
+         import traceback
+         traceback.print_exc() # Print detailed traceback
+         sys.exit(1)

+     # --- Save Outputs ---
      try:
+         # Save audio
+         print(f"Saving audio to: {args.output_audio}")
+         # Ensure directory exists
+         os.makedirs(os.path.dirname(os.path.abspath(args.output_audio)) or '.', exist_ok=True)
+         # Ensure data is float32 for soundfile
+         if audio_data.dtype != "float32":
+             audio_data = audio_data.astype("float32")
+         sf.write(args.output_audio, audio_data, sr)
+
+         # Copy spectrogram from the temporary path returned by infer
+         print(f"Copying spectrogram from {temp_spectrogram_path} to: {args.output_spectrogram}")
+         # Ensure directory exists
+         os.makedirs(os.path.dirname(os.path.abspath(args.output_spectrogram)) or '.', exist_ok=True)
+         shutil.copy(temp_spectrogram_path, args.output_spectrogram)
+
+         print("\n--- Success ---")
+         print(f"Audio saved in: {args.output_audio}")
+         print(f"Spectrogram saved in: {args.output_spectrogram}")
+         print("---------------")

      except Exception as e:
+         print(f"\nError saving output files: {e}")
+         sys.exit(1)

+     # Optional: Clean up the temporary spectrogram file if needed,
+     # but NamedTemporaryFile usually handles this if delete=True was used in app.py
+     # try:
+     #     if os.path.exists(temp_spectrogram_path):
+     #         os.remove(temp_spectrogram_path)
+     # except Exception as e:
+     #     print(f"Warning: Could not clean up temporary spectrogram file {temp_spectrogram_path}: {e}")

  if __name__ == "__main__":
+     main()

+ # --- END OF FILE inference_cli.py ---
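
Note on the import block at the top of the new file: if the CLI is launched from a different working directory, the sys.path hint in the comments can be made concrete roughly as follows (a sketch, assuming app.py sits next to inference_cli.py):

import os
import sys

# Make the directory containing this script (and, by assumption, app.py) importable first.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import app
from app import infer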
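
Note on the "Set Global Variables in app.py" step: app.language, app.ref_language and app.speed are rebound as attributes of the imported app module before infer is called. This works even though infer is imported with "from app import infer", because a Python function always resolves its globals in the module that defined it. A minimal, self-contained illustration of that mechanism (the stand-in module below is hypothetical, not the real app.py):

import types

# Hypothetical stand-in for app.py: a module-level `language` global read by infer().
app = types.ModuleType("app")
exec(
    "language = 'en-us'\n"
    "def infer():\n"
    "    return language\n",
    app.__dict__,
)

infer = app.infer      # same effect as `from app import infer`
app.language = "pl"    # rebind the module attribute, as the CLI does
print(infer())         # prints 'pl' -- infer() sees the updated global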
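
For reference, a typical invocation of the simplified CLI might look like the following (paths and texts are placeholders; all flags are defined by the argument parser above):

python inference_cli.py --ref_audio ref.wav --ref_text "This is the reference transcript." --gen_text "Text to synthesize." --language en-us --ref_language en-us --speed 1.0 --remove_silence --output_audio out/output.wav --output_spectrogram out/spectrogram.png

If --ref_text is left empty, the reference audio is transcribed automatically inside app.py's infer function, as noted in the argument help.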