from flask import Flask, request, jsonify, Response
from faster_whisper import WhisperModel
import torch
import time
import datetime
from threading import Semaphore
import os
from werkzeug.utils import secure_filename
import tempfile
from moviepy.editor import VideoFileClip
import logging
import torchaudio
import ffmpeg # ffmpeg-python
# ------------------------------------
# Logging
# ------------------------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
app = Flask(__name__)
# ------------------------------------
# Configuration
# ------------------------------------
MAX_CONCURRENT_REQUESTS = 1
MAX_FILE_DURATION = 60 * 30 # 30 minutes
TEMPORARY_FOLDER = tempfile.gettempdir()
ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'ogg', 'm4a', 'flac', 'aac', 'wma', 'opus', 'aiff'}
ALLOWED_VIDEO_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp'}
ALLOWED_EXTENSIONS = ALLOWED_AUDIO_EXTENSIONS.union(ALLOWED_VIDEO_EXTENSIONS)
API_KEY = os.environ.get("API_KEY") # set in the HF Space repo secrets
MODEL_NAME = os.environ.get("WHISPER_MODEL", "guillaumekln/faster-whisper-large-v2")
# Default initial prompt (can be overridden with ?prompt); it asks the model to output Traditional Chinese
DEFAULT_INITIAL_PROMPT = "請使用繁體中文輸出"
# ------------------------------------
# Device and model
# ------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
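# float16 roughly halves memory use and speeds up inference on GPU; int8 quantization keeps CPU inference practical.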
logging.info(f"使用設備: {device},計算類型: {compute_type}")
beamsize = 2
try:
    wmodel = WhisperModel(
        MODEL_NAME,
        device=device,
        compute_type=compute_type,
        download_root="./model_cache"
    )
    logging.info(f"Model {MODEL_NAME} loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load model {MODEL_NAME}: {e}")
    wmodel = None
# ------------------------------------
# Concurrency control
# ------------------------------------
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
active_requests = 0
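# Note: active_requests is a best-effort gauge for the status endpoints; it is updated without a lock,
# so the reported count may be momentarily off under threaded serving.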
# ------------------------------------
# Helpers
# ------------------------------------
def validate_api_key(req):
    api_key = req.headers.get('X-API-Key')
    return api_key == API_KEY if API_KEY else True  # if API_KEY is unset, allow all requests (adjust as needed)
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
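# Example: allowed_file("talk.MP4") -> True, allowed_file("notes.txt") -> False.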
def cleanup_temp_files(*file_paths):
    for file_path in file_paths:
        try:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)
                logging.info(f"Deleted temporary file: {file_path}")
        except Exception as e:
            logging.error(f"Error deleting temporary file {file_path}: {str(e)}")
def extract_audio_from_video(video_path, output_audio_path):
    """
    Extract PCM WAV audio from a video with ffmpeg and check its duration with moviepy.
    """
    try:
        # Extract the audio track first
        ffmpeg.input(video_path).output(
            output_audio_path,
            acodec='pcm_s16le'
            # optional parameters: ar=44100, ac=2
        ).run(capture_stdout=True, capture_stderr=True)
        # Then check the video duration
        video = VideoFileClip(video_path)
        if video.duration > MAX_FILE_DURATION:
            video.close()
            raise ValueError(f"Video duration exceeds {MAX_FILE_DURATION} seconds")
        video.close()
        return output_audio_path
    except Exception as e:
        logging.exception("Error extracting audio from the video")
        raise Exception(f"Error extracting audio from the video: {str(e)}")
def fmt_mmss_mmm(seconds: float) -> str:
    """
    Format seconds as MM:SS.mmm (e.g. 00:01.000).
    Switch to HH:MM:SS.mmm if an hour field is ever needed.
    """
    if seconds is None:
        seconds = 0.0
    total_ms = int(round(seconds * 1000))
    minutes, ms = divmod(total_ms, 60_000)
    sec, ms = divmod(ms, 1000)
    return f"{minutes:02d}:{sec:02d}.{ms:03d}"
def read_lang_param_with_default_zh():
    """
    Read the ?lang= query parameter; default to Traditional Chinese (zh) when it is missing or set to "auto".
    """
    lang_param = request.args.get("lang", "").strip()
    if not lang_param or lang_param.lower() == "auto":
        return "zh"
    return lang_param
def read_initial_prompt():
    """
    Read the ?prompt= query parameter; fall back to DEFAULT_INITIAL_PROMPT when it is missing.
    """
    prompt = request.args.get("prompt", "").strip()
    return prompt if prompt else DEFAULT_INITIAL_PROMPT
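# Example (a sketch): POST /whisper_transcribe?lang=en&prompt=Use+punctuation transcribes in English
# with a custom initial prompt; with no query parameters the zh / DEFAULT_INITIAL_PROMPT defaults apply.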
def run_transcribe_pipeline(uploaded_file_path: str, file_extension: str):
    """
    Shared transcription pipeline: extract audio from video if needed, enforce the
    duration limit, and call Faster-Whisper.
    Returns: (segments_iterable, is_video, temp_audio_path)
    """
    is_video = file_extension in ALLOWED_VIDEO_EXTENSIONS
    temp_audio_path = None
    if is_video:
        temp_audio_path = os.path.join(TEMPORARY_FOLDER, f"temp_audio_{int(time.time())}.wav")
        extract_audio_from_video(uploaded_file_path, temp_audio_path)
        transcription_file = temp_audio_path
    else:
        transcription_file = uploaded_file_path
    # Load the audio to check its duration; fall back to the soundfile backend if the default load fails
    try:
        waveform, sample_rate = torchaudio.load(transcription_file, format=file_extension)
    except Exception:
        logging.exception(f"Error loading audio file with torchaudio.load: {transcription_file}")
        try:
            torchaudio.set_audio_backend("soundfile")
            waveform, sample_rate = torchaudio.load(transcription_file)
        except Exception as soundfile_err:
            logging.exception(f"Error loading audio file with the soundfile backend: {transcription_file}")
            raise Exception(f'Failed to load the audio file with both backends: {str(soundfile_err)}')
        finally:
            torchaudio.set_audio_backend(None)  # let torchaudio pick its default backend again
    duration = waveform.size(1) / sample_rate
    if duration > MAX_FILE_DURATION:
        raise ValueError(f"Audio duration exceeds {MAX_FILE_DURATION} seconds")
    # Default language zh and initial_prompt (both overridable via ?lang / ?prompt)
    language = read_lang_param_with_default_zh()
    initial_prompt = read_initial_prompt()
    # Transcribe, keeping segment-level timestamps
    segments, info = wmodel.transcribe(
        transcription_file,
        beam_size=beamsize,
        vad_filter=True,
        without_timestamps=False,  # keep timestamps
        compression_ratio_threshold=2.4,
        word_timestamps=False,  # set True for word-level timestamps
        language=language,
        initial_prompt=initial_prompt
    )
    return segments, is_video, temp_audio_path
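# Note: faster-whisper returns `segments` as a lazy generator; decoding only happens while it is
# iterated, so the endpoints below consume it fully inside the request handler.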
# ------------------------------------
# Health check and status
# ------------------------------------
@app.route("/health", methods=["GET"])
def health_check():
return jsonify({
'status': 'API 正在運行',
'timestamp': datetime.datetime.now().isoformat(),
'device': device,
'compute_type': compute_type,
'active_requests': active_requests,
'max_duration_supported': MAX_FILE_DURATION,
'supported_formats': list(ALLOWED_EXTENSIONS),
'model': MODEL_NAME,
'default_language': 'zh',
'default_initial_prompt': DEFAULT_INITIAL_PROMPT
})
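# Example check (a sketch, assuming the service listens locally on port 7860):
#   import requests
#   print(requests.get("http://localhost:7860/health").json())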
@app.route("/status/busy", methods=["GET"])
def server_busy():
is_busy = active_requests >= MAX_CONCURRENT_REQUESTS
return jsonify({
'is_busy': is_busy,
'active_requests': active_requests,
'max_capacity': MAX_CONCURRENT_REQUESTS
})
# ------------------------------------
# Endpoint 1: JSON (start/end as "MM:SS.mmm" strings)
# ------------------------------------
@app.route("/whisper_transcribe", methods=["POST"])
def transcribe_json():
global active_requests
if not validate_api_key(request):
return jsonify({'error': '無效的 API 金鑰'}), 401
if not request_semaphore.acquire(blocking=False):
return jsonify({'error': '伺服器繁忙'}), 503
active_requests += 1
t0 = time.time()
temp_file_path = None
temp_audio_path = None
try:
if wmodel is None:
return jsonify({'error': '模型載入失敗。請檢查伺服器日誌。'}), 500
if 'file' not in request.files:
return jsonify({'error': '未提供檔'}), 400
file = request.files['file']
if not (file and allowed_file(file.filename)):
return jsonify({'error': f'無效的檔案格式。支持:{", ".join(ALLOWED_EXTENSIONS)}'}), 400
# 儲存上傳檔
temp_file_path = os.path.join(TEMPORARY_FOLDER, secure_filename(file.filename))
file.save(temp_file_path)
file_extension = file.filename.rsplit('.', 1)[1].lower()
# 執行轉錄流程
try:
segments_iter, is_video, temp_audio_path = run_transcribe_pipeline(temp_file_path, file_extension)
except Exception as e:
return jsonify({'error': str(e)}), 400
# 組 JSON:start/end 以 "MM:SS.mmm"
results = []
for seg in segments_iter:
start = seg.start or 0.0
end = seg.end or 0.0
text = (seg.text or "").strip()
results.append({
"start": fmt_mmss_mmm(start),
"end": fmt_mmss_mmm(end),
"text": text
})
return jsonify({
'file_type': 'video' if is_video else 'audio',
'segments': results
}), 200
except Exception as e:
logging.exception("轉錄過程中發生異常")
return jsonify({'error': str(e)}), 500
finally:
cleanup_temp_files(temp_file_path, temp_audio_path)
active_requests -= 1
request_semaphore.release()
logging.info(f"/whisper_transcribe 用時:{time.time() - t0:.2f}s (活動請求:{active_requests})")
# ------------------------------------
# Endpoint 2: plain text (all segments merged, no timestamps)
# ------------------------------------
@app.route("/whisper_transcribe_text", methods=["POST"])
def transcribe_text_only():
global active_requests
if not validate_api_key(request):
return jsonify({'error': '無效的 API 金鑰'}), 401
if not request_semaphore.acquire(blocking=False):
return jsonify({'error': '伺服器繁忙'}), 503
active_requests += 1
t0 = time.time()
temp_file_path = None
temp_audio_path = None
try:
if wmodel is None:
return jsonify({'error': '模型載入失敗。請檢查伺服器日誌。'}), 500
if 'file' not in request.files:
return jsonify({'error': '未提供檔'}), 400
file = request.files['file']
if not (file and allowed_file(file.filename)):
return jsonify({'error': f'無效的檔案格式。支持:{", ".join(ALLOWED_EXTENSIONS)}'}), 400
# 儲存上傳檔
temp_file_path = os.path.join(TEMPORARY_FOLDER, secure_filename(file.filename))
file.save(temp_file_path)
file_extension = file.filename.rsplit('.', 1)[1].lower()
# 執行轉錄流程(沿用同一流程,僅輸出不同)
try:
segments_iter, is_video, temp_audio_path = run_transcribe_pipeline(temp_file_path, file_extension)
except Exception as e:
return jsonify({'error': str(e)}), 400
# 合併純文字
full_text = " ".join((seg.text or "").strip() for seg in segments_iter if (seg.text or "").strip())
# 直接回「純文字」
return Response(full_text, mimetype="text/plain; charset=utf-8"), 200
except Exception as e:
logging.exception("轉錄過程中發生異常")
return jsonify({'error': str(e)}), 500
finally:
cleanup_temp_files(temp_file_path, temp_audio_path)
active_requests -= 1
request_semaphore.release()
logging.info(f"/whisper_transcribe_text 用時:{time.time() - t0:.2f}s (活動請求:{active_requests})")
if __name__ == "__main__":
    if not os.path.exists(TEMPORARY_FOLDER):
        os.makedirs(TEMPORARY_FOLDER)
        logging.info(f"Created temporary folder: {TEMPORARY_FOLDER}")
    app.run(host="0.0.0.0", port=7860, threaded=True)