v-e-n-o-m committed on
Commit
7ded65d
·
1 Parent(s): 7379e3f
Files changed (3) hide show
  1. Dockerfile +8 -0
  2. app.py +112 -48
  3. requirements.txt +5 -5
Dockerfile CHANGED
@@ -9,6 +9,11 @@ RUN apt-get update && apt-get install -y \
9
  # Set working directory
10
  WORKDIR /app
11
 
 
 
 
 
 
12
  # Copy requirements and install
13
  COPY requirements.txt .
14
  RUN pip install --no-cache-dir -r requirements.txt
@@ -16,6 +21,9 @@ RUN pip install --no-cache-dir -r requirements.txt
16
  # Copy application code
17
  COPY app.py .
18
 
 
 
 
19
  # Expose port
20
  EXPOSE 8000
21
 
 
9
# Set working directory
WORKDIR /app

# Create cache and logs directories writable by the non-root runtime user (uid 1000)
RUN mkdir -p /app/cache /app/logs && \
    chown -R 1000:1000 /app/cache /app/logs && \
    chmod -R 775 /app/cache /app/logs

# Copy requirements first so the pip layer is cached across code-only rebuilds
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app.py .

# Point the Hugging Face model cache at the pre-created directory.
# TRANSFORMERS_CACHE is deprecated in transformers v4.x in favour of
# HF_HOME -- set both so the image keeps working across library upgrades.
ENV TRANSFORMERS_CACHE=/app/cache \
    HF_HOME=/app/cache

# Expose port
EXPOSE 8000
29
 
app.py CHANGED
@@ -1,72 +1,136 @@
1
- from fastapi import FastAPI, File, UploadFile
 
 
2
  from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
3
  import torch
4
  import soundfile as sf
5
- import io
6
  import subprocess
7
  import tempfile
8
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  app = FastAPI(title="Quran Transcription API")
11
 
12
  # Load model and processor
13
- model_id = "tarteel-ai/whisper-base-ar-quran"
14
- processor = WhisperProcessor.from_pretrained(model_id)
15
- model = WhisperForConditionalGeneration.from_pretrained(model_id)
16
- model.generation_config.no_timestamps_token_id = processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
 
 
 
 
 
 
17
 
18
  # Initialize ASR pipeline
19
- asr = pipeline(
20
- "automatic-speech-recognition",
21
- model=model,
22
- tokenizer=processor.tokenizer,
23
- feature_extractor=processor.feature_extractor,
24
- device=0 if torch.cuda.is_available() else -1
25
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  @app.post("/transcribe")
28
  async def transcribe_audio(file: UploadFile = File(...)):
 
 
29
  # Validate file type
30
  if not file.filename.lower().endswith(".mp3"):
31
- return {"error": "Only MP3 files are supported"}
 
32
 
33
  # Read MP3 file
34
- mp3_data = await file.read()
 
 
 
 
 
35
 
36
- # Create temporary files for conversion
37
- with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_mp3:
38
- temp_mp3.write(mp3_data)
39
- temp_mp3_path = temp_mp3.name
 
 
40
 
41
- temp_wav_path = temp_mp3_path.replace(".mp3", ".wav")
42
-
43
- try:
44
- # Convert MP3 to 16 kHz mono WAV using ffmpeg
45
- subprocess.run(
46
- [
47
- "ffmpeg",
48
- "-i", temp_mp3_path,
49
- "-ar", "16000",
50
- "-ac", "1",
51
- "-y", # Overwrite output file if exists
52
- temp_wav_path
53
- ],
54
- check=True,
55
- capture_output=True
56
- )
 
 
 
 
 
 
57
 
58
  # Read WAV file
59
- audio, sample_rate = sf.read(temp_wav_path)
60
- if sample_rate != 16000:
61
- return {"error": "Converted audio is not 16 kHz"}
 
 
 
 
 
 
62
 
63
  # Transcribe
64
- transcription = asr(audio, return_timestamps=False)["text"]
65
- return {"transcription": transcription}
66
-
67
- finally:
68
- # Clean up temporary files
69
- if os.path.exists(temp_mp3_path):
70
- os.unlink(temp_mp3_path)
71
- if os.path.exists(temp_wav_path):
72
- os.unlink(temp_wav_path)
 
1
import logging
import os

# BUG FIX: the cache environment variables must be configured *before*
# `transformers` is imported -- the library reads them at import time, so
# setting them after the import (as the original did) has no effect.
os.makedirs("/app/cache", exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

import subprocess
import tempfile
from contextlib import contextmanager

import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor

# Configure logging: file (persistent) + stream (container stdout logs).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("/app/logs/app.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Quran Transcription API")

# Load model and processor.
# BUG FIX: the original raised fastapi.HTTPException here, but this code runs
# at import/startup time, outside any request context, so FastAPI can never
# turn it into an HTTP response. Fail fast with RuntimeError instead.
model_id = "tarteel-ai/whisper-base-ar-quran"
try:
    logger.info("Loading processor for model: %s", model_id)
    processor = WhisperProcessor.from_pretrained(model_id)
    logger.info("Loading model: %s", model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id)
    # Suppress timestamp tokens so generate() emits plain text only.
    model.generation_config.no_timestamps_token_id = processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
except Exception as e:
    logger.exception("Failed to load model")
    raise RuntimeError(f"Model loading failed: {e}") from e

# Initialize ASR pipeline (GPU device 0 when CUDA is available, else CPU).
try:
    logger.info("Initializing ASR pipeline")
    asr = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=0 if torch.cuda.is_available() else -1,
    )
except Exception as e:
    logger.exception("Failed to initialize ASR pipeline")
    raise RuntimeError(f"Pipeline initialization failed: {e}") from e
54
@contextmanager
def temporary_files():
    """Yield an open temporary MP3 file and a matching WAV path; delete both on exit.

    Yields:
        (temp_mp3, temp_wav_path): an open ``NamedTemporaryFile`` for the MP3
        payload, and the sibling path the converted WAV should be written to.
    """
    log = logging.getLogger(__name__)
    temp_mp3 = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
    # BUG FIX: splitext only touches the final extension; the original
    # str.replace(".mp3", ".wav") would corrupt any path that happens to
    # contain ".mp3" earlier in it.
    temp_wav_path = os.path.splitext(temp_mp3.name)[0] + ".wav"
    try:
        yield temp_mp3, temp_wav_path
    finally:
        # Close the handle before unlinking -- on some platforms an open
        # file cannot be removed. close() is idempotent, so this is safe
        # even when the caller already closed it.
        try:
            temp_mp3.close()
        except Exception:
            pass
        for path in (temp_mp3.name, temp_wav_path):
            if os.path.exists(path):
                try:
                    os.unlink(path)
                    log.debug("Deleted temporary file: %s", path)
                except OSError as e:
                    log.warning("Failed to delete temporary file %s: %s", path, e)
69
 
70
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded MP3 Quran recitation to text.

    Pipeline: validate extension -> save upload -> ffmpeg-convert to 16 kHz
    mono WAV -> run the Whisper ASR pipeline.

    Raises:
        HTTPException: 400 for invalid input (non-MP3, empty upload, bad
            sample rate), 500 for conversion/read/transcription failures.
    """
    logger.info("Received file: %s", file.filename)

    # Validate file type
    if not file.filename.lower().endswith(".mp3"):
        logger.error("Invalid file type: %s. Only MP3 is supported", file.filename)
        raise HTTPException(status_code=400, detail="Only MP3 files are supported")

    # Read MP3 file
    try:
        mp3_data = await file.read()
        logger.debug("Read %d bytes from MP3 file", len(mp3_data))
    except Exception as e:
        logger.error("Failed to read MP3 file: %s", e)
        raise HTTPException(status_code=500, detail="Failed to read audio file")

    # Robustness: reject empty uploads before invoking ffmpeg.
    if not mp3_data:
        logger.error("Uploaded file is empty: %s", file.filename)
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Convert MP3 to WAV
    with temporary_files() as (temp_mp3, temp_wav_path):
        try:
            temp_mp3.write(mp3_data)
            temp_mp3.close()
            logger.info("Saved MP3 to temporary file: %s", temp_mp3.name)

            # Convert to 16 kHz mono WAV using ffmpeg (argument list, no shell).
            logger.info("Converting MP3 to WAV: %s", temp_wav_path)
            result = subprocess.run(
                [
                    "ffmpeg",
                    "-i", temp_mp3.name,
                    "-ar", "16000",
                    "-ac", "1",
                    "-y",  # Overwrite output file if exists
                    temp_wav_path,
                ],
                check=True,
                capture_output=True,
                text=True,
            )
            logger.debug("ffmpeg output: %s", result.stdout)
        except subprocess.CalledProcessError as e:
            logger.error("ffmpeg conversion failed: %s", e.stderr)
            raise HTTPException(status_code=500, detail="Audio conversion failed")
        except Exception as e:
            logger.error("Unexpected error during conversion: %s", e)
            raise HTTPException(status_code=500, detail="Unexpected error during conversion")

        # Read WAV file
        try:
            audio, sample_rate = sf.read(temp_wav_path)
            logger.info("Read WAV file: %s, sample rate: %s", temp_wav_path, sample_rate)
        except Exception as e:
            logger.error("Failed to read WAV file: %s", e)
            raise HTTPException(status_code=500, detail="Failed to read converted audio")

        # BUG FIX: this validation lived inside the try above, so its 400
        # HTTPException was caught by `except Exception` and re-raised as a
        # misleading 500 "Failed to read converted audio". Validate outside.
        if sample_rate != 16000:
            logger.error("Invalid sample rate: %s. Expected 16000 Hz", sample_rate)
            raise HTTPException(status_code=400, detail="Converted audio is not 16 kHz")

        # Transcribe
        try:
            logger.info("Starting transcription")
            transcription = asr(audio, return_timestamps=False)["text"]
            logger.info("Transcription completed: %s", transcription)
            return {"transcription": transcription}
        except Exception as e:
            logger.error("Transcription failed: %s", e)
            raise HTTPException(status_code=500, detail="Transcription failed")
 
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- fastapi==0.115.0
2
- uvicorn==0.30.6
3
- transformers==4.44.2
4
- torch==2.4.1
5
- soundfile==0.12.1
6
  python-multipart==0.0.9
 
1
+ fastapi==0.115.0
2
+ uvicorn==0.30.6
3
+ transformers==4.44.2
4
+ torch==2.4.1
5
+ soundfile==0.12.1
6
  python-multipart==0.0.9