cheesecz committed on
Commit
4874e49
·
verified ·
1 Parent(s): 641bcbe

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +22 -0
  2. app.py +147 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies.
# --no-install-recommends keeps the image small by skipping optional packages;
# the apt lists are removed in the same layer so they never land in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

# Copy only the requirements first so the dependency layer is cached
# independently of changes to app.py.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Download the processor and model at build time so the files are already in
# the image's Hugging Face cache when the container starts.
RUN python -c "from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq; \
processor = AutoProcessor.from_pretrained('nyrahealth/CrisperWhisper'); \
model = AutoModelForSpeechSeq2Seq.from_pretrained('nyrahealth/CrisperWhisper')"

COPY app.py .

ENV PORT=8080
# Number of uvicorn worker processes. Each worker runs the app's startup hook
# and therefore loads its own copy of the model, so memory use scales with
# this value; override it at deploy time (e.g. -e WORKERS=1) on small hosts.
ENV WORKERS=4

# Shell form is required so $PORT and $WORKERS are expanded; `exec` replaces
# the shell so uvicorn receives signals directly.
CMD exec uvicorn app:app --host 0.0.0.0 --port $PORT --workers $WORKERS
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Dict, Any
6
+
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
+ from fastapi.responses import JSONResponse
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ import torch
11
+ import torchaudio
12
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
13
+ import logging
14
+ import uvicorn
15
+
16
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; uvicorn is pointed at it as "app:app".
app = FastAPI(
    title="Speech-to-Text API",
    description="API for speech-to-text transcription using CrisperWhisper model",
    version="1.0.0",
)

# Add CORS middleware.
# Any origin may call this API. Credentials are deliberately NOT allowed:
# a wildcard origin combined with allow_credentials=True is not a valid CORS
# configuration (browsers reject "*" for credentialed requests), and nothing
# in this service relies on cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
34
+
35
+ # Initialize model and processor
36
@app.on_event("startup")
async def load_model():
    """Load the CrisperWhisper processor and model once, at application startup.

    Publishes three module-level globals used by the request handlers:
    ``processor``, ``model`` (moved to the selected device and put in eval
    mode) and ``device`` ("cuda" when available, otherwise "cpu").
    """
    global processor, model, device

    checkpoint = "nyrahealth/CrisperWhisper"
    logger.info("Loading CrisperWhisper model...")

    # Prefer the GPU when torch can see one.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    processor = AutoProcessor.from_pretrained(checkpoint)

    loaded = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint)
    model = loaded.to(device)
    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()

    logger.info(f"Model loaded successfully on {device}")
46
+
47
# Scratch directory for uploaded files, created once when the module is imported.
TEMP_DIR = Path(tempfile.mkdtemp())

# Lower-case file extensions (without the dot) accepted by the upload endpoint.
ALLOWED_EXTENSIONS = {"mp3", "wav", "flac", "ogg", "m4a", "mp4"}


def is_valid_audio_file(filename: str) -> bool:
    """Return True when *filename* ends in one of ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, and the comparison is
    case-insensitive. A name without any dot is rejected.
    """
    _stem, dot, extension = filename.rpartition(".")
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
53
+
54
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an audio file and return word-level timestamps.

    - **file**: Audio file to transcribe (MP3, WAV, FLAC, OGG, M4A, MP4)

    Returns a JSON with transcription and timestamps.
    """
    # Request validation happens outside the try block so these 400s are not
    # converted into 500s by the generic handler below.
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")

    if not is_valid_audio_file(file.filename):
        # sorted() keeps the message stable; set iteration order is arbitrary.
        supported = ', '.join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(
            status_code=400,
            detail=f"File type not allowed. Supported formats: {supported}",
        )

    # Sanitised name, used for logging only; the file on disk gets a unique
    # generated name so concurrent uploads with the same filename cannot
    # overwrite each other.
    safe_filename = ''.join(c if c.isalnum() or c in '._- ' else '_' for c in file.filename)
    # Already validated above; kept as the suffix of the temp file name.
    extension = file.filename.rsplit('.', 1)[1].lower()

    file_path = None
    try:
        # Save the uploaded file under a unique name inside TEMP_DIR.
        fd, tmp_name = tempfile.mkstemp(dir=TEMP_DIR, suffix=f".{extension}")
        file_path = Path(tmp_name)
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await file.read())

        logger.info(f"Processing file: {safe_filename}")

        # Load audio file
        waveform, sample_rate = torchaudio.load(file_path)

        # Down-mix to a single channel by averaging when there are several.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # Turn the waveform into model inputs on the model's device.
        input_features = processor(
            waveform.squeeze().numpy(),
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).to(device)

        # Generate transcription with timestamps
        with torch.no_grad():
            generated_tokens = model.generate(
                **input_features,
                return_timestamps=True,
                task="transcribe",
            )

        # NOTE(review): `decode_timestamps` / `slice_start_indices` do not look
        # like part of the stock transformers Whisper processor API - confirm
        # this call exists for the pinned transformers version; if not, every
        # request will end in the 500 handler below.
        result = processor.decode_timestamps(
            generated_tokens[0].detach().cpu(),
            slice_start_indices=True,
        )

        # Keep only chunks that still contain text after stripping whitespace.
        chunks = []
        for chunk in result['chunks']:
            text = chunk['text'].strip()
            if text:
                chunks.append({
                    "timestamp": [chunk['timestamp'][0], chunk['timestamp'][1]],
                    "text": text,
                })

        # Returned dict is serialised to JSON by FastAPI.
        return {
            "text": result['text'],
            "chunks": chunks,
        }

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Error during transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Always remove the upload, on success and on failure alike, so
        # failed requests do not leave files behind in TEMP_DIR.
        if file_path is not None:
            file_path.unlink(missing_ok=True)
139
+
140
@app.get("/health")
async def health_check():
    """Health check endpoint for Cloud Run"""
    # Constant payload: reaching this handler at all means the app is serving.
    payload = {"status": "healthy"}
    return payload
144
+
145
if __name__ == "__main__":
    # Direct-run entry point: listen on $PORT when set, otherwise on 8080.
    listen_port = int(os.environ.get("PORT", 8080))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=listen_port,
        reload=False,
    )
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ python-multipart==0.0.6
4
+ torch==2.1.0
5
+ torchaudio==2.1.0
6
+ transformers==4.36.0
7
+ accelerate==0.25.0
8
+ soundfile==0.12.1