Spaces:
Running
Running
from flask import Flask, request, jsonify | |
import torch | |
import torchaudio | |
import librosa | |
import os | |
import base64 | |
import tempfile | |
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq | |
from huggingface_hub import login | |
import logging | |
# Configuration du logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
app = Flask(__name__) | |
# Configuration du modèle - EXACTEMENT comme dans votre Gradio qui marchait | |
MODEL_NAME = "Ronaldodev/speech-to-text-fongbe" | |
HF_TOKEN = os.environ.get("HF_TOKEN") | |
# Variables globales pour le modèle | |
model = None | |
processor = None | |
model_loaded = False | |
def load_model(): | |
"""Charger le modèle privé au démarrage - EXACTEMENT votre code qui marchait""" | |
global model, processor, model_loaded | |
try: | |
logger.info("🔄 Chargement du modèle privé...") | |
if not HF_TOKEN: | |
raise ValueError("HF_TOKEN non configuré dans les secrets") | |
login(token=HF_TOKEN) | |
logger.info("✅ Authentification HF réussie") | |
# EXACTEMENT comme dans votre Gradio - pas d'optimisation qui casse | |
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME) | |
processor = AutoProcessor.from_pretrained(MODEL_NAME) | |
logger.info("✅ Modèle chargé avec succès!") | |
model_loaded = True | |
return True | |
except Exception as e: | |
logger.error(f"❌ Erreur chargement: {e}") | |
model_loaded = False | |
return False | |
def transcribe_audio_file(audio_path): | |
"""Fonction principale de transcription - EXACTEMENT votre code qui marchait""" | |
if model is None or processor is None: | |
return {"success": False, "error": "Modèle non chargé. Vérifiez les logs."} | |
if audio_path is None: | |
return {"success": False, "error": "Aucun fichier audio fourni"} | |
try: | |
logger.info(f"🎵 Traitement audio: {audio_path}") | |
# EXACTEMENT votre logique qui marchait | |
try: | |
waveform, sample_rate = torchaudio.load(audio_path) | |
logger.info(f"✅ Audio chargé avec torchaudio: {sample_rate}Hz") | |
except Exception as e: | |
logger.warning(f"⚠️ Torchaudio échoué, essai librosa: {e}") | |
waveform, sample_rate = librosa.load(audio_path, sr=None) | |
waveform = torch.tensor(waveform).unsqueeze(0) | |
logger.info(f"✅ Audio chargé avec librosa: {sample_rate}Hz") | |
# EXACTEMENT votre traitement qui marchait | |
if waveform.shape[0] > 1: | |
waveform = waveform.mean(dim=0, keepdim=True) | |
logger.info("🔄 Conversion stéréo → mono") | |
if sample_rate != 16000: | |
logger.info(f"🔄 Resampling {sample_rate}Hz → 16000Hz") | |
resampler = torchaudio.transforms.Resample(sample_rate, 16000) | |
waveform = resampler(waveform) | |
# EXACTEMENT votre processing qui marchait | |
inputs = processor( | |
waveform.squeeze(), | |
sampling_rate=16000, | |
return_tensors="pt" | |
) | |
logger.info("🔄 Génération de la transcription...") | |
with torch.no_grad(): | |
# EXACTEMENT vos paramètres qui marchaient | |
result = model.generate( | |
**inputs, | |
max_length=500, | |
do_sample=False, | |
num_beams=1 | |
) | |
transcription = processor.batch_decode(result, skip_special_tokens=True)[0] | |
logger.info(f"✅ Transcription réussie: '{transcription}'") | |
return { | |
"success": True, | |
"transcription": transcription.strip(), | |
"model_name": MODEL_NAME | |
} | |
except Exception as e: | |
error_msg = f"❌ Erreur de transcription: {str(e)}" | |
logger.error(error_msg) | |
return {"success": False, "error": error_msg} | |
# ENDPOINTS FLASK - Juste l'enrobage, le cœur reste intact | |
def health_check(): | |
"""Point d'entrée principal de l'API""" | |
return jsonify({ | |
"status": "OK", | |
"message": "🎤 API STT Fongbé - Reconnaissance vocale pour la langue Fongbé", | |
"model_name": MODEL_NAME, | |
"model_loaded": model_loaded, | |
"version": "1.0.0" | |
}) | |
def health(): | |
"""Vérification de l'état de santé de l'API""" | |
return jsonify({ | |
"status": "healthy" if model_loaded else "model_not_loaded", | |
"model_loaded": model_loaded, | |
"model_name": MODEL_NAME, | |
"message": "Modèle chargé et prêt" if model_loaded else "Modèle sera chargé à la première utilisation" | |
}) | |
def load_model_endpoint(): | |
"""Charger le modèle manuellement""" | |
success = load_model() | |
return jsonify({ | |
"success": success, | |
"model_loaded": model_loaded, | |
"message": "Modèle chargé avec succès" if success else "Erreur lors du chargement" | |
}) | |
def transcribe_base64(): | |
"""Transcription audio à partir de données base64""" | |
try: | |
# Charger le modèle si pas encore fait | |
if not model_loaded: | |
if not load_model(): | |
return jsonify({"success": False, "error": "Impossible de charger le modèle"}), 503 | |
data = request.get_json() | |
if not data or "audio_base64" not in data: | |
return jsonify({"success": False, "error": "Paramètre 'audio_base64' requis"}), 400 | |
logger.info("🎵 Transcription via base64...") | |
audio_base64 = data["audio_base64"].strip() | |
remove_prefix = data.get("remove_prefix", True) | |
# Supprimer le préfixe data:audio/... si présent | |
if remove_prefix and audio_base64.startswith('data:'): | |
audio_base64 = audio_base64.split(',')[1] | |
# Décoder le base64 | |
try: | |
audio_bytes = base64.b64decode(audio_base64) | |
except Exception as e: | |
return jsonify({"success": False, "error": f"Données base64 invalides: {str(e)}"}), 400 | |
# Créer un fichier temporaire | |
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: | |
temp_file.write(audio_bytes) | |
temp_path = temp_file.name | |
try: | |
# Utiliser EXACTEMENT votre fonction qui marchait | |
result = transcribe_audio_file(temp_path) | |
return jsonify(result) | |
finally: | |
# Nettoyer le fichier temporaire | |
if os.path.exists(temp_path): | |
os.unlink(temp_path) | |
except Exception as e: | |
logger.error(f"❌ Erreur transcription base64: {e}") | |
return jsonify({"success": False, "error": str(e)}), 500 | |
def transcribe_file(): | |
"""Transcription audio à partir d'un fichier uploadé""" | |
try: | |
# Charger le modèle si pas encore fait | |
if not model_loaded: | |
if not load_model(): | |
return jsonify({"success": False, "error": "Impossible de charger le modèle"}), 503 | |
# Vérifier qu'un fichier est présent | |
if 'audio_file' not in request.files: | |
return jsonify({"success": False, "error": "Aucun fichier 'audio_file' fourni"}), 400 | |
audio_file = request.files['audio_file'] | |
if audio_file.filename == '': | |
return jsonify({"success": False, "error": "Aucun fichier sélectionné"}), 400 | |
logger.info(f"🎵 Transcription du fichier: {audio_file.filename}") | |
# Sauvegarder le fichier temporairement | |
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: | |
audio_file.save(temp_file.name) | |
temp_path = temp_file.name | |
try: | |
# Utiliser EXACTEMENT votre fonction qui marchait | |
result = transcribe_audio_file(temp_path) | |
if result["success"]: | |
result["filename"] = audio_file.filename | |
return jsonify(result) | |
finally: | |
# Nettoyer le fichier temporaire | |
if os.path.exists(temp_path): | |
os.unlink(temp_path) | |
except Exception as e: | |
logger.error(f"❌ Erreur transcription fichier: {e}") | |
return jsonify({"success": False, "error": str(e)}), 500 | |
def transcribe_url(): | |
"""Transcription audio à partir d'une URL""" | |
try: | |
# Charger le modèle si pas encore fait | |
if not load_model(): | |
return jsonify({"success": False, "error": "Impossible de charger le modèle"}), 503 | |
data = request.get_json() | |
if not data or "url" not in data: | |
return jsonify({"success": False, "error": "Paramètre 'url' requis"}), 400 | |
url = data["url"] | |
logger.info(f"🌐 Téléchargement depuis URL: {url}") | |
import requests | |
# Télécharger le fichier | |
response = requests.get(url, timeout=30) | |
response.raise_for_status() | |
# Créer un fichier temporaire | |
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: | |
temp_file.write(response.content) | |
temp_path = temp_file.name | |
try: | |
# Utiliser EXACTEMENT votre fonction qui marchait | |
result = transcribe_audio_file(temp_path) | |
if result["success"]: | |
result["url"] = url | |
return jsonify(result) | |
finally: | |
# Nettoyer le fichier temporaire | |
if os.path.exists(temp_path): | |
os.unlink(temp_path) | |
except Exception as e: | |
logger.error(f"❌ Erreur transcription URL: {e}") | |
return jsonify({"success": False, "error": str(e)}), 500 | |
def test(): | |
"""Endpoint de test simple""" | |
return jsonify({ | |
"status": "API fonctionnelle", | |
"message": "Test réussi ✅", | |
"model_loaded": model_loaded, | |
"timestamp": "2025-01-04" | |
}) | |
if __name__ == "__main__": | |
print("🚀 DÉMARRAGE API STT FONGBÉ - FLASK") | |
print("=" * 50) | |
print("🌐 Port: 7860") | |
print("📖 Endpoints disponibles:") | |
print(" GET / - Statut de l'API") | |
print(" GET /health - Santé de l'API") | |
print(" GET /test - Test simple") | |
print(" POST /load-model - Charger le modèle") | |
print(" POST /transcribe/base64 - Transcription base64") | |
print(" POST /transcribe/file - Transcription fichier") | |
print(" POST /transcribe/url - Transcription URL") | |
print("=" * 50) | |
# Essayer de charger le modèle au démarrage (optionnel) | |
print("🔄 Tentative de chargement du modèle au démarrage...") | |
if load_model(): | |
print("✅ Modèle chargé au démarrage") | |
else: | |
print("⚠️ Modèle sera chargé à la première utilisation") | |
app.run(host="0.0.0.0", port=7860, debug=True) |