v-e-n-o-m committed
Commit 04cf987 · 1 Parent(s): 8542679

Switch to whisper-medium for CPU

Files changed (1): app.py (+23 -24)
app.py CHANGED
@@ -1,7 +1,6 @@
 from fastapi import FastAPI, File, UploadFile, Form, HTTPException
 from transformers import pipeline
 import soundfile as sf
-import io
 import numpy as np
 import torch
 from pydub import AudioSegment
@@ -27,25 +26,25 @@ async def root():
 @app.get("/health")
 async def health():
     logger.info("Health check accessed")
-    return {"status": "ok", "model": "whisper-large-v3"}
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    return {"status": "ok", "model": "whisper-medium", "device": device}
 
 @app.on_event("startup")
 async def startup_event():
     print("Uvicorn started successfully")
 
-    print("Loading Whisper-large-v3...")
+    print("Loading Whisper-medium...")
     try:
         pipe = pipeline(
             "automatic-speech-recognition",
-            model="openai/whisper-large-v3",
-            torch_dtype=torch.float16,
-            device="cuda" if torch.cuda.is_available() else "cpu",
-            model_kwargs={"use_safetensors": True},
-            chunk_length_s=30  # Process 30s chunks
+            model="openai/whisper-medium",
+            torch_dtype=torch.float32,
+            device="cpu",
+            model_kwargs={"use_safetensors": True}
         )
-        print("Model loaded successfully")
+        logger.info("Model loaded successfully")
     except Exception as e:
-        print(f"Model loading failed: {str(e)}")
+        logger.error(f"Model loading failed: {str(e)}")
         raise e
 
 @contextmanager
@@ -56,9 +55,12 @@ def temp_file(suffix):
     finally:
         os.unlink(temp.name)
 
-@timeout(120, use_signals=False)  # Timeout after 120s
+@timeout(30, use_signals=False)  # 30s timeout
 def transcribe_audio(audio_data, language):
-    return pipe(audio_data, generate_kwargs={"language": language, "task": "transcribe"}, batch_size=1)
+    logger.info("Starting transcription pipeline...")
+    result = pipe(audio_data, generate_kwargs={"language": language, "task": "transcribe"})
+    logger.info("Transcription pipeline completed")
+    return result
 
 @app.post("/transcribe")
 async def transcribe(audio: UploadFile = File(...), language: str = Form(...)):
@@ -79,32 +81,29 @@ async def transcribe(audio: UploadFile = File(...), language: str = Form(...)):
     with open(temp_audio_path, "wb") as f:
         f.write(audio_bytes)
 
-    # Check duration
     duration = librosa.get_duration(path=temp_audio_path)
     logger.info(f"Audio duration: {duration} seconds")
-    if duration > 300:  # 5min max
-        raise HTTPException(400, detail="Audio too long, max 300s")
+    if duration > 60:
+        raise HTTPException(400, detail="Audio too long, max 60s")
 
     with temp_file(".wav") as temp_wav_path:
         if ext != ".wav":
             logger.info(f"Converting {temp_audio_path} to WAV...")
-            try:
-                audio_segment = AudioSegment.from_file(temp_audio_path)
-                audio_segment = audio_segment.set_frame_rate(16000).set_channels(1)
-                audio_segment.export(temp_wav_path, format="wav")
-            except Exception as e:
-                logger.error(f"Conversion failed: {str(e)}")
-                raise HTTPException(500, detail=f"Audio conversion failed: {str(e)}")
+            audio_segment = AudioSegment.from_file(temp_audio_path)
+            audio_segment = audio_segment.set_frame_rate(16000).set_channels(1)
+            audio_segment.export(temp_wav_path, format="wav")
         else:
             logger.info("Skipping conversion for WAV input")
            temp_wav_path = temp_audio_path
 
        audio_data, sample_rate = sf.read(temp_wav_path)
+        logger.info(f"Audio data shape: {audio_data.shape}, sample rate: {sample_rate}")
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
-
        if sample_rate != 16000:
-            raise HTTPException(500, detail="Converted audio is not 16kHz")
+            raise HTTPException(500, detail="Audio is not 16kHz")
+        if np.max(np.abs(audio_data)) < 1e-5:
+            raise HTTPException(400, detail="Audio is silent")
 
        logger.info("Transcribing...")
        try:
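
For reference, a minimal client for the updated service might look like the sketch below. The base URL, port, and input file name are assumptions (7860 is the usual Hugging Face Spaces port); only the /health and /transcribe routes and their parameters come from the diff.

import requests

BASE = "http://localhost:7860"  # assumed host/port; adjust to your deployment

# Confirm the service is up and see which model/device it reports
print(requests.get(f"{BASE}/health").json())

# POST a short clip; the endpoint now rejects audio longer than 60s
with open("sample.wav", "rb") as f:  # hypothetical input file
    resp = requests.post(
        f"{BASE}/transcribe",
        files={"audio": ("sample.wav", f, "audio/wav")},
        data={"language": "en"},
    )
print(resp.status_code, resp.json())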