ollieollie KingNish committed on
Commit
9386371
·
verified ·
1 Parent(s): 4ce9740

Refactored Code (#3)

Browse files

- Refactored Code (0f266a9994c6faa8fbd5f0acb3ab0578d8bd43ee)


Co-authored-by: Nishith Jain <[email protected]>

Files changed (1) hide show
  1. app.py +63 -41
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import random
2
  import numpy as np
3
  import torch
4
- from chatterbox.src.chatterbox.tts import ChatterboxTTS # Assuming this path is correct
5
  import gradio as gr
6
  import spaces
7
 
@@ -9,38 +9,32 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
  print(f"🚀 Running on device: {DEVICE}")
10
 
11
  # --- Global Model Initialization ---
12
- # Load the model once when the application starts.
13
- # This model will be accessible by the @spaces.GPU decorated function.
14
  MODEL = None
15
 
16
  def get_or_load_model():
 
 
17
  global MODEL
18
  if MODEL is None:
19
- print("Global MODEL is None, loading...")
20
  try:
21
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
22
- # Ensure model is on the correct device if not handled by from_pretrained
23
- if DEVICE == "cuda" and hasattr(MODEL, 'to'):
24
  MODEL.to(DEVICE)
25
- print(f"Global MODEL loaded. Device: {DEVICE}")
26
- if hasattr(MODEL, 'device'): # If the model object has a device attribute
27
- print(f"Model internal device attribute: {MODEL.device}")
28
  except Exception as e:
29
- print(f"Error loading global model: {e}")
30
  raise
31
  return MODEL
32
 
33
  # Attempt to load the model at startup.
34
- # If this fails, the app will likely fail to start, which is informative.
35
  try:
36
  get_or_load_model()
37
  except Exception as e:
38
- # Handle critical model loading failure if necessary, or let it propagate
39
- print(f"CRITICAL: Failed to load model on startup. Error: {e}")
40
- # You might want to display an error in Gradio if this happens,
41
- # but for now, a print is fine for debugging.
42
 
43
  def set_seed(seed: int):
 
44
  torch.manual_seed(seed)
45
  if DEVICE == "cuda":
46
  torch.cuda.manual_seed(seed)
@@ -48,46 +42,78 @@ def set_seed(seed: int):
48
  random.seed(seed)
49
  np.random.seed(seed)
50
 
51
- @spaces.GPU # Your GPU-accelerated function
52
- def generate_tts_audio(text_input, audio_prompt_path_input, exaggeration_input, temperature_input, seed_num_input, cfgw_input):
53
- current_model = get_or_load_model() # Access the global model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  if current_model is None:
56
- # This should ideally not happen if startup loading was successful
57
- # Or, it indicates an issue with the global model pattern in this specific env.
58
- raise RuntimeError("Model could not be loaded or accessed.")
59
 
60
  if seed_num_input != 0:
61
  set_seed(int(seed_num_input))
62
 
63
- print(f"Generating audio for text: '{text_input}'")
64
  wav = current_model.generate(
65
- text_input[:300],
66
  audio_prompt_path=audio_prompt_path_input,
67
  exaggeration=exaggeration_input,
68
  temperature=temperature_input,
69
  cfg_weight=cfgw_input,
70
  )
71
  print("Audio generation complete.")
72
- # ONLY return pickleable data
73
  return (current_model.sr, wav.squeeze(0).numpy())
74
 
75
-
76
  with gr.Blocks() as demo:
77
- # No gr.State needed for the model object if it's managed globally
78
- # and not passed back and forth.
79
-
 
 
 
80
  with gr.Row():
81
  with gr.Column():
82
- text = gr.Textbox(value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.", label="Text to synthesize (max chars 300)")
83
- ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart.flac")
84
- exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5)
85
- cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  with gr.Accordion("More options", open=False):
88
  seed_num = gr.Number(value=0, label="Random seed (0 for random)")
89
- temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
90
-
91
 
92
  run_btn = gr.Button("Generate", variant="primary")
93
 
@@ -95,9 +121,8 @@ with gr.Blocks() as demo:
95
  audio_output = gr.Audio(label="Output Audio")
96
 
97
  run_btn.click(
98
- fn=generate_tts_audio, # Use the new function name
99
  inputs=[
100
- # model_state, # Removed: model is now global
101
  text,
102
  ref_wav,
103
  exaggeration,
@@ -105,10 +130,7 @@ with gr.Blocks() as demo:
105
  seed_num,
106
  cfg_weight,
107
  ],
108
- outputs=[audio_output], # Only outputting the audio data
109
  )
110
 
111
- demo.queue(
112
- max_size=50,
113
- default_concurrency_limit=1, # Important for a single global model
114
- ).launch() # share=True is not needed and causes a warning on Spaces
 
1
  import random
2
  import numpy as np
3
  import torch
4
+ from chatterbox.src.chatterbox.tts import ChatterboxTTS
5
  import gradio as gr
6
  import spaces
7
 
 
9
  print(f"🚀 Running on device: {DEVICE}")
10
 
11
  # --- Global Model Initialization ---
 
 
12
  MODEL = None
13
 
def get_or_load_model():
    """Return the shared ChatterboxTTS model, loading it on first use.

    The model is created with ``ChatterboxTTS.from_pretrained(DEVICE)``. If the
    resulting object exposes ``to`` and does not already report ``DEVICE`` as
    its device, it is moved there. The instance is cached in the module-level
    ``MODEL`` only after both steps succeed, so a failed load never leaves a
    half-initialised model behind for later calls to pick up.

    Returns:
        The cached model instance.

    Raises:
        Exception: Whatever loading or moving the model raised. The error is
            printed and re-raised, and ``MODEL`` stays ``None``.
    """
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            model = ChatterboxTTS.from_pretrained(DEVICE)
            # Use getattr so a model without a `device` attribute is simply
            # moved instead of raising AttributeError on the comparison.
            if hasattr(model, "to") and str(getattr(model, "device", None)) != DEVICE:
                model.to(DEVICE)
            print(f"Model loaded successfully. Internal device: {getattr(model, 'device', 'N/A')}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        # Publish to the global only once loading and device placement worked.
        MODEL = model
    return MODEL
29
 
# Eagerly load the model when the module is imported. A failure is reported
# on stdout but not re-raised, so the rest of the module keeps executing.
try:
    get_or_load_model()
except Exception as startup_error:
    print(
        "CRITICAL: Failed to load model on startup. Application may not function. "
        f"Error: {startup_error}"
    )
 
 
 
35
 
36
  def set_seed(seed: int):
37
+ """Sets the random seed for reproducibility across torch, numpy, and random."""
38
  torch.manual_seed(seed)
39
  if DEVICE == "cuda":
40
  torch.cuda.manual_seed(seed)
 
42
  random.seed(seed)
43
  np.random.seed(seed)
44
 
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """
    Generates TTS audio using the ChatterboxTTS model.

    Args:
        text_input: The text to synthesize; only the first 300 characters are
            passed to the model.
        audio_prompt_path_input: Path to the reference audio file, forwarded
            to the model as ``audio_prompt_path``.
        exaggeration_input: Exaggeration parameter for the model.
        temperature_input: Temperature parameter for the model.
        seed_num_input: Random seed. ``0`` (or an empty value) leaves the
            random number generators unseeded.
        cfgw_input: CFG/Pace weight, forwarded as ``cfg_weight``.

    Returns:
        A tuple containing the sample rate (``current_model.sr``) and the audio
        waveform as a numpy.ndarray.

    Raises:
        RuntimeError: If no model instance is available.
    """
    current_model = get_or_load_model()

    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # A cleared numeric input may arrive as None (TODO confirm against the UI
    # component); treat it like 0 instead of failing in int(None).
    if seed_num_input:
        set_seed(int(seed_num_input))

    print(f"Generating audio for text: '{text_input[:50]}...'")
    wav = current_model.generate(
        text_input[:300],  # Truncate text to max chars
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    print("Audio generation complete.")
    # Tensor.numpy() only accepts CPU tensors that do not require grad;
    # detach().cpu() is a no-op when the tensor already satisfies that.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
86
 
 
87
  with gr.Blocks() as demo:
88
+ gr.Markdown(
89
+ """
90
+ # Chatterbox TTS Demo
91
+ Generate high-quality speech from text with reference audio styling.
92
+ """
93
+ )
94
  with gr.Row():
95
  with gr.Column():
96
+ text = gr.Textbox(
97
+ value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
98
+ label="Text to synthesize (max chars 300)",
99
+ max_lines=5
100
+ )
101
+ ref_wav = gr.Audio(
102
+ sources=["upload", "microphone"],
103
+ type="filepath",
104
+ label="Reference Audio File (Optional)",
105
+ value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart.flac"
106
+ )
107
+ exaggeration = gr.Slider(
108
+ 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
109
+ )
110
+ cfg_weight = gr.Slider(
111
+ 0.2, 1, step=.05, label="CFG/Pace", value=0.5
112
+ )
113
 
114
  with gr.Accordion("More options", open=False):
115
  seed_num = gr.Number(value=0, label="Random seed (0 for random)")
116
+ temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
 
117
 
118
  run_btn = gr.Button("Generate", variant="primary")
119
 
 
121
  audio_output = gr.Audio(label="Output Audio")
122
 
123
  run_btn.click(
124
+ fn=generate_tts_audio,
125
  inputs=[
 
126
  text,
127
  ref_wav,
128
  exaggeration,
 
130
  seed_num,
131
  cfg_weight,
132
  ],
133
+ outputs=[audio_output],
134
  )
135
 
136
+ demo.launch()