deploy

- app.py +30 -38
- requirements.txt +10 -10

app.py CHANGED
@@ -1,3 +1,6 @@
+import asyncio
+import logging
+import time
 from fastapi import FastAPI, File, UploadFile
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import torch
@@ -5,9 +8,7 @@ import io
 import soundfile as sf
 import numpy as np
 import torchaudio
-import logging
-import timeout_decorator
-import time
+import psutil
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -30,80 +31,71 @@ except Exception as e:
 model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ur", task="transcribe")
 logger.info("Set forced_decoder_ids for Urdu transcription")
 
-# Move model to CPU
+# Move model to CPU
 device = "cpu"
 model.to(device)
 logger.info(f"Model loaded and moved to {device}")
 
+# Log memory usage
+def log_memory_usage():
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    logger.info(f"Memory usage: {mem_info.rss / 1024**2:.2f} MB")
+
 @app.post("/transcribe")
 async def transcribe_audio(file: UploadFile = File(...)):
     try:
         start_time = time.time()
-
+        log_memory_usage()
+
+        # Read audio file
         logger.info("Reading audio file")
-        try:
-            audio_data, sample_rate = sf.read(io.BytesIO(await file.read()))
-        except Exception as e:
-            logger.error(f"Failed to read audio file: {str(e)}")
-            return {"error": f"Invalid or unsupported audio file: {str(e)}. Supported formats: WAV, MP3, FLAC."}
+        audio_data, sample_rate = sf.read(io.BytesIO(await file.read()))
         logger.info(f"Audio read in {time.time() - start_time:.2f} seconds")
 
         # Ensure audio is mono
         if len(audio_data.shape) > 1:
-            audio_data = np.mean(audio_data, axis=1)
+            audio_data = np.mean(audio_data, axis=1)
 
         # Resample to 16kHz if necessary
         target_sample_rate = 16000
         if sample_rate != target_sample_rate:
             logger.info(f"Resampling audio from {sample_rate} Hz to {target_sample_rate} Hz")
-            step_time = time.time()
             audio_tensor = torch.from_numpy(audio_data).float()
             resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
-            audio_tensor = resampler(audio_tensor)
-            audio_data = audio_tensor.numpy()
+            audio_data = resampler(audio_tensor).numpy()
             sample_rate = target_sample_rate
-            logger.info(f"Resampling completed in {time.time() - step_time:.2f} seconds")
-
-        # Trim silence (simplified for torchaudio 2.0.2)
-        logger.info("Trimming silence")
-        step_time = time.time()
-        audio_tensor = torch.from_numpy(audio_data).float()
-        vad = torchaudio.transforms.Vad(sample_rate=sample_rate)  # No threshold
-        audio_tensor = vad(audio_tensor)
-        audio_data = audio_tensor.numpy()
-        logger.info(f"Silence trimming completed in {time.time() - step_time:.2f} seconds")
 
         # Process audio input
         logger.info("Processing audio input")
-        step_time = time.time()
         inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
         input_features = inputs.input_features.to(device)
-        logger.info(f"Input processing completed in {time.time() - step_time:.2f} seconds")
 
-        # Generate transcription with timeout
+        # Generate transcription with async timeout
         logger.info("Generating transcription")
-
-        @timeout_decorator.timeout(15, timeout_exception=TimeoutError)  # 15-second timeout
-        def generate_transcription():
+        async def generate_transcription():
             with torch.no_grad():
                 generated_ids = model.generate(
                     input_features,
                     max_new_tokens=225,
-                    num_beams=1,
-                    length_penalty=0.0
+                    num_beams=1,
+                    length_penalty=0.0
                 )
             return generated_ids
 
-        step_time = time.time()
-        generated_ids = generate_transcription()
-        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        try:
+            async with asyncio.timeout(60):  # 60-second timeout
+                generated_ids = await generate_transcription()
+        except asyncio.TimeoutError:
+            logger.error("Transcription timed out after 60 seconds")
+            return {"error": "Transcription took too long. Try a smaller model or upgrade your Space."}
 
+        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         total_time = time.time() - start_time
         logger.info(f"Total transcription time: {total_time:.2f} seconds")
+        log_memory_usage()
         return {"transcription": transcription}
-    except TimeoutError:
-        logger.error("Transcription timed out after 15 seconds")
-        return {"error": "Transcription took too long. Try a faster model or check Space performance."}
+
     except Exception as e:
         logger.error(f"Error during transcription: {str(e)}")
         return {"error": str(e)}
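Note on the new timeout: asyncio.timeout requires Python 3.11+, and because generate_transcription awaits nothing and calls the CPU-bound model.generate directly, the event loop stays blocked during generation, so the timeout can only fire after the call returns anyway. A minimal alternative sketch that also runs on Python 3.9+, using a hypothetical generate_with_timeout helper around the same model and input_features objects from app.py:

    import asyncio
    import torch

    async def generate_with_timeout(input_features, seconds=60):
        # Run the blocking generate() call in a worker thread so the event loop
        # stays responsive and the timeout can interrupt the wait. The thread
        # itself still runs to completion; only the caller stops waiting.
        def _generate():
            with torch.no_grad():
                return model.generate(input_features, max_new_tokens=225, num_beams=1)
        return await asyncio.wait_for(asyncio.to_thread(_generate), timeout=seconds)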
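For reference, a client-side sketch for exercising the endpoint once the Space is up; the URL and filename are placeholders:

    import requests

    url = "https://YOUR-SPACE.hf.space/transcribe"  # placeholder Space URL
    with open("sample_ur.wav", "rb") as f:
        resp = requests.post(url, files={"file": ("sample_ur.wav", f, "audio/wav")})
    print(resp.json())  # {"transcription": "..."} on success, {"error": "..."} otherwise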
requirements.txt CHANGED
@@ -1,11 +1,11 @@
-transformers==4.
-torch==2.
-
-
-
-
-
-
-
-
+transformers==4.44.2
+torch==2.4.1
+torchaudio==2.4.1
+fastapi==0.103.0
+uvicorn==0.23.2
+pydantic==2.3.0
+soundfile==0.12.1
+python-multipart==0.0.9
+numpy==1.26.4
+psutil==6.0.0
 accelerate==0.30.1
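With fastapi and uvicorn now pinned, the Space presumably launches app.py through uvicorn. A minimal launcher sketch, assuming the FastAPI instance is named app in app.py and the Space listens on the conventional port:

    import uvicorn

    if __name__ == "__main__":
        uvicorn.run("app:app", host="0.0.0.0", port=7860)  # 7860: Hugging Face Spaces default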