v-e-n-o-m committed on
Commit e65b477
· 1 Parent(s): f3b9613
Files changed (2):
  1. app.py +30 -38
  2. requirements.txt +10 -10
app.py CHANGED
@@ -1,3 +1,6 @@
+import asyncio
+import logging
+import time
 from fastapi import FastAPI, File, UploadFile
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import torch
@@ -5,9 +8,7 @@ import io
 import soundfile as sf
 import numpy as np
 import torchaudio
-import logging
-import timeout_decorator
-import time
+import psutil
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -30,80 +31,71 @@ except Exception as e:
 model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ur", task="transcribe")
 logger.info("Set forced_decoder_ids for Urdu transcription")
 
-# Move model to CPU (free Spaces don’t have GPU)
+# Move model to CPU
 device = "cpu"
 model.to(device)
 logger.info(f"Model loaded and moved to {device}")
 
+# Log memory usage
+def log_memory_usage():
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    logger.info(f"Memory usage: {mem_info.rss / 1024**2:.2f} MB")
+
 @app.post("/transcribe")
 async def transcribe_audio(file: UploadFile = File(...)):
     try:
         start_time = time.time()
-        # Read audio file (supports WAV, MP3, etc.)
+        log_memory_usage()
+
+        # Read audio file
         logger.info("Reading audio file")
-        try:
-            audio_data, sample_rate = sf.read(io.BytesIO(await file.read()))
-        except Exception as e:
-            logger.error(f"Failed to read audio file: {str(e)}")
-            return {"error": f"Invalid or unsupported audio file: {str(e)}. Supported formats: WAV, MP3, FLAC."}
+        audio_data, sample_rate = sf.read(io.BytesIO(await file.read()))
         logger.info(f"Audio read in {time.time() - start_time:.2f} seconds")
 
         # Ensure audio is mono
         if len(audio_data.shape) > 1:
-            audio_data = np.mean(audio_data, axis=1)  # Convert to mono
+            audio_data = np.mean(audio_data, axis=1)
 
         # Resample to 16kHz if necessary
         target_sample_rate = 16000
         if sample_rate != target_sample_rate:
             logger.info(f"Resampling audio from {sample_rate} Hz to {target_sample_rate} Hz")
-            step_time = time.time()
             audio_tensor = torch.from_numpy(audio_data).float()
             resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
-            audio_tensor = resampler(audio_tensor)
-            audio_data = audio_tensor.numpy()
+            audio_data = resampler(audio_tensor).numpy()
             sample_rate = target_sample_rate
-            logger.info(f"Resampling completed in {time.time() - step_time:.2f} seconds")
-
-        # Trim silence (simplified for torchaudio 2.0.2)
-        logger.info("Trimming silence")
-        step_time = time.time()
-        audio_tensor = torch.from_numpy(audio_data).float()
-        vad = torchaudio.transforms.Vad(sample_rate=sample_rate)  # No threshold
-        audio_tensor = vad(audio_tensor)
-        audio_data = audio_tensor.numpy()
-        logger.info(f"Silence trimming completed in {time.time() - step_time:.2f} seconds")
 
         # Process audio input
         logger.info("Processing audio input")
-        step_time = time.time()
         inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
         input_features = inputs.input_features.to(device)
-        logger.info(f"Input processing completed in {time.time() - step_time:.2f} seconds")
 
-        # Generate transcription with timeout
+        # Generate transcription with async timeout
        logger.info("Generating transcription")
-        step_time = time.time()
-        @timeout_decorator.timeout(15, timeout_exception=TimeoutError)  # 15-second timeout
-        def generate_transcription():
+        async def generate_transcription():
            with torch.no_grad():
                generated_ids = model.generate(
                    input_features,
                    max_new_tokens=225,
-                    num_beams=1,  # Disable beam search
-                    length_penalty=0.0  # Faster decoding
+                    num_beams=1,
+                    length_penalty=0.0
                )
            return generated_ids
 
-        generated_ids = generate_transcription()
-        logger.info(f"Transcription generated in {time.time() - step_time:.2f} seconds")
-        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        try:
+            async with asyncio.timeout(60):  # 60-second timeout
+                generated_ids = await generate_transcription()
+        except asyncio.TimeoutError:
+            logger.error("Transcription timed out after 60 seconds")
+            return {"error": "Transcription took too long. Try a smaller model or upgrade your Space."}
 
+        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        total_time = time.time() - start_time
        logger.info(f"Total transcription time: {total_time:.2f} seconds")
+        log_memory_usage()
        return {"transcription": transcription}
-    except TimeoutError:
-        logger.error("Transcription timed out after 15 seconds")
-        return {"error": "Transcription took too long. Try a faster model or check Space performance."}
+
    except Exception as e:
        logger.error(f"Error during transcription: {str(e)}")
        return {"error": str(e)}
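Note on the new timeout: asyncio.timeout requires Python 3.11+, and because model.generate is a blocking, CPU-bound call, the coroutine never yields to the event loop, so the 60-second deadline cannot actually fire until generation has already finished. A minimal sketch of a timeout that can cancel the wait, assuming the same model and input_features as in app.py above, is to run generation in a worker thread:

import asyncio
import torch

def generate_sync(model, input_features):
    # Plain synchronous generation; safe to run in a worker thread.
    with torch.no_grad():
        return model.generate(
            input_features,
            max_new_tokens=225,
            num_beams=1,
            length_penalty=0.0,
        )

async def generate_with_timeout(model, input_features, seconds=60):
    # asyncio.to_thread keeps the event loop responsive while the model
    # runs; asyncio.wait_for enforces the deadline. Caveat: the worker
    # thread keeps running after a timeout; only the awaiting request
    # is released.
    return await asyncio.wait_for(
        asyncio.to_thread(generate_sync, model, input_features),
        timeout=seconds,
    )

This variant also drops the implicit Python 3.11 requirement, since asyncio.to_thread and asyncio.wait_for are available from Python 3.9.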
requirements.txt CHANGED
@@ -1,11 +1,11 @@
-transformers==4.38.2
-torch==2.0.1
-fastapi==0.103.0
-uvicorn==0.23.2
-pydantic==2.3.0
-soundfile==0.12.1
-python-multipart==0.0.9
-numpy==1.26.4
-timeout-decorator==0.5.0
-torchaudio==2.0.2
+transformers==4.44.2
+torch==2.4.1
+torchaudio==2.4.1
+fastapi==0.103.0
+uvicorn==0.23.2
+pydantic==2.3.0
+soundfile==0.12.1
+python-multipart==0.0.9
+numpy==1.26.4
+psutil==6.0.0
 accelerate==0.30.1
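For a quick smoke test of the updated endpoint, something like the following should work; the host, port, and sample file are assumptions (Spaces typically expose port 7860), not part of the commit:

import requests

# Hypothetical local test; adjust the URL and file path to your Space.
url = "http://localhost:7860/transcribe"
with open("sample_ur.wav", "rb") as f:
    response = requests.post(url, files={"file": ("sample_ur.wav", f, "audio/wav")})
print(response.json())  # {"transcription": "..."} or {"error": "..."}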