oza75 commited on
Commit
3c87043
·
verified ·
1 Parent(s): 21bea8f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -16
handler.py CHANGED
@@ -38,26 +38,15 @@ class EndpointHandler:
38
  torch.Tensor: Preprocessed audio tensor ready for inference
39
  """
40
  try:
 
41
  # Handle different input types
42
- if isinstance(input_data, bytes):
43
  # Load audio from bytes
44
- audio_buffer = BytesIO(input_data)
45
  waveform, sample_rate = torchaudio.load(audio_buffer)
46
- elif isinstance(input_data, dict):
47
- logger.info(f"Input data: {input_data.keys()}")
48
- if 'audio' in input_data:
49
- # Handle numpy array input
50
- audio_array = input_data['audio']
51
- if isinstance(audio_array, list):
52
- audio_array = np.array(audio_array)
53
- waveform = torch.from_numpy(audio_array)
54
- sample_rate = input_data.get('sampling_rate', self.target_sample_rate)
55
- # Ensure 2D tensor [channels, time]
56
- if waveform.dim() == 1:
57
- waveform = waveform.unsqueeze(0)
58
- else:
59
- raise ValueError("Input dictionary must contain 'audio' key")
60
  else:
 
 
61
  raise ValueError("Unsupported input type")
62
 
63
  # Convert to float32
 
38
  torch.Tensor: Preprocessed audio tensor ready for inference
39
  """
40
  try:
41
+ audio = input_data if isinstance(input_data, bytes) else input_data['inputs']
42
  # Handle different input types
43
+ if isinstance(audio, bytes):
44
  # Load audio from bytes
45
+ audio_buffer = BytesIO(audio)
46
  waveform, sample_rate = torchaudio.load(audio_buffer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  else:
48
+ logger.error(f"Unsupported input type: {type(audio)}")
49
+ logger.debug(f"Input data: {input_data.keys()}")
50
  raise ValueError("Unsupported input type")
51
 
52
  # Convert to float32