ASesYusuf1 committed on
Commit 3c66c3b · verified · 1 Parent(s): 0de9a65

Update utils.py

Files changed (1): utils.py +42 -9
utils.py CHANGED
@@ -132,11 +132,22 @@ def apply_tta(
     mix: torch.Tensor,
     waveforms_orig: Dict[str, torch.Tensor],
     device: str,
-    model_type: str
+    model_type: str,
+    progress=None  # Gradio progress object
 ) -> Dict[str, torch.Tensor]:
     track_proc_list = [mix[::-1].clone(), -mix.clone()]
+    total_steps = len(track_proc_list)
+    processed_steps = 0
+
     for i, augmented_mix in enumerate(track_proc_list):
-        waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False)
+        # Progress update for this TTA step
+        processed_steps += 1
+        progress_value = round((processed_steps / total_steps) * 50)  # 0-50% range reserved for TTA
+        if progress is not None and callable(getattr(progress, '__call__', None)):
+            progress(progress_value / 100, desc=f"Applying TTA step {processed_steps}/{total_steps}")
+        update_progress_html(f"Applying TTA step {processed_steps}/{total_steps}", progress_value)
+
+        waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False, progress=progress)
         for el in waveforms:
             if i == 0:
                 waveforms_orig[el] += waveforms[el][::-1].clone()
@@ -146,8 +157,15 @@ def apply_tta(
     gc.collect()
     if device.startswith('cuda'):
         torch.cuda.empty_cache()
+
     for el in waveforms_orig:
         waveforms_orig[el] /= (len(track_proc_list) + 1)
+
+    # TTA finished
+    if progress is not None and callable(getattr(progress, '__call__', None)):
+        progress(0.5, desc="TTA completed")
+    update_progress_html("TTA completed", 50)
+
     return waveforms_orig
 
 def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
@@ -164,7 +182,8 @@ def demix(
     mix: torch.Tensor,
     device: str,
     model_type: str,
-    pbar: bool = False
+    pbar: bool = False,
+    progress=None  # Gradio progress object
 ) -> Dict[str, np.ndarray]:
     logging.info(f"Starting demix for model_type: {model_type}, chunk_size: {config.audio.chunk_size}")
 
@@ -196,6 +215,10 @@
     model = model.to(device)
     model.eval()
 
+    # Compute the total number of chunks
+    total_chunks = (mix.shape[1] + step - 1) // step
+    processed_chunks = 0
+
     with torch.no_grad():  # No gradients needed for inference
         with torch.cuda.amp.autocast(enabled=device.startswith('cuda'), dtype=torch.float16):
             req_shape = (num_instruments,) + mix.shape
@@ -205,7 +228,7 @@
             i = 0
             batch_data = []
             batch_locations = []
-            progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None
+            start_time = time.time()
 
             while i < mix.shape[1]:
                 part = mix[:, i:i + chunk_size]
@@ -240,6 +263,13 @@
                         result[..., start:start + seg_len] += x[j, ..., :seg_len]
                         counter[..., start:start + seg_len] += 1.0
 
+                    # Progress update
+                    processed_chunks += len(batch_data)
+                    progress_value = min(round((processed_chunks / total_chunks) * 100), 100)  # 1% granularity
+                    if progress is not None and callable(getattr(progress, '__call__', None)):
+                        progress(progress_value / 100, desc=f"Processing chunk {processed_chunks}/{total_chunks}")
+                    update_progress_html(f"Processing chunk {processed_chunks}/{total_chunks}", progress_value)
+
                     del arr, x
                     batch_data.clear()
                     batch_locations.clear()
@@ -248,11 +278,8 @@
                     torch.cuda.empty_cache()
                     logging.info("Cleared CUDA cache")
 
-                    if progress_bar:
-                        progress_bar.update(step)
-
-            if progress_bar:
-                progress_bar.close()
+            elapsed_time = time.time() - start_time
+            logging.info(f"Demix completed in {elapsed_time:.2f} seconds")
 
             estimated_sources = result / (counter + 1e-8)
             estimated_sources = estimated_sources.numpy().astype(np.float32)
@@ -264,6 +291,12 @@
     instruments = config.training.instruments if mode == "demucs" else prefer_target_instrument(config)
     ret_data = {k: v for k, v in zip(instruments, estimated_sources)}
     logging.info("Demix completed successfully")
+
+    # Final progress update
+    if progress is not None and callable(getattr(progress, '__call__', None)):
+        progress(1.0, desc="Demix completed")
+    update_progress_html("Demix completed", 100)
+
     return ret_data
 
 def prefer_target_instrument(config: ConfigDict) -> List[str]:
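Note: the updated code calls update_progress_html(message, percent) next to the optional progress callback, but that helper is not defined in this diff. The sketch below is only an assumption about its call shape; the PROGRESS_HTML buffer and the HTML markup are illustrative placeholders, not this repository's actual implementation.

PROGRESS_HTML = ""  # assumed module-level buffer that a UI component reads

def update_progress_html(message: str, percent: int) -> str:
    # Render a simple inline-styled progress bar matching the (message, percent) call shape
    global PROGRESS_HTML
    PROGRESS_HTML = (
        f"<div style='width:100%;background:#eee;border-radius:4px;'>"
        f"<div style='width:{percent}%;background:#4caf50;text-align:center;"
        f"padding:4px 0;border-radius:4px;'>{message} ({percent}%)</div></div>"
    )
    return PROGRESS_HTML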
 
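Because every call is guarded by progress is not None and callable(...), both functions still run with progress=None; inside a Gradio event handler a gr.Progress instance, which is callable as progress(fraction, desc=...), can be passed straight through. A minimal wiring sketch, assuming hypothetical load_model and load_audio helpers that are not part of this diff:

import gradio as gr

def separate(audio_path, progress=gr.Progress()):
    # load_model / load_audio are hypothetical placeholders for the app's own loaders
    config, model, device, model_type = load_model()
    mix = load_audio(audio_path)
    progress(0.0, desc="Starting separation")
    waveforms = demix(config, model, mix, device,
                      model_type=model_type, pbar=False, progress=progress)
    # Optional test-time augmentation, reusing the same progress object
    waveforms = apply_tta(config, model, mix, waveforms, device,
                          model_type=model_type, progress=progress)
    return list(waveforms.keys())

demo = gr.Interface(fn=separate, inputs=gr.Audio(type="filepath"), outputs=gr.JSON())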