# Hugging Face Spaces status: Running
import base64
import contextlib
import io
import logging
import os
import tempfile

import gradio as gr
import librosa
import torch
import torchaudio
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Hugging Face Hub repository of the speech-to-text model served by this app.
MODEL_NAME = "Ronaldodev/speech-to-text-fongbe"
# Hub access token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Module-level handles assigned by load_model(); None until then.
model = None
processor = None

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_model():
    """Load the private model and its processor into the module globals.

    Authenticates against the Hugging Face Hub with ``HF_TOKEN`` and then
    loads ``MODEL_NAME``. The module-level ``model`` and ``processor`` are
    only assigned once both objects have been loaded, so a failure half-way
    through never leaves one of them set without the other.

    Returns:
        bool: True when loading succeeded, False otherwise. Failures are
        logged (with traceback) instead of being raised.
    """
    global model, processor
    try:
        logger.info("🔄 Chargement du modèle privé...")
        if not HF_TOKEN:
            raise ValueError("HF_TOKEN non configuré dans les secrets")

        login(token=HF_TOKEN)
        logger.info("✅ Authentification HF réussie")

        loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
        loaded_processor = AutoProcessor.from_pretrained(MODEL_NAME)
        # Publish both handles together so callers never see a partial load.
        model, processor = loaded_model, loaded_processor

        logger.info("✅ Modèle chargé avec succès!")
        return True
    except Exception as e:
        # Startup boundary: report the failure and let the caller decide.
        logger.exception(f"❌ Erreur chargement: {e}")
        return False
def process_audio_data(audio_data, sample_rate=None, target_sample_rate=16000):
    """Transcribe an in-memory audio signal with the loaded model.

    Args:
        audio_data: Audio samples, either a ``torch.Tensor`` or an array-like
            object exposing ``.shape`` and ``.mean(axis=...)``. Inputs with
            more than one dimension are averaged over axis 0 to obtain a
            mono signal (assumes a channels-first layout — TODO confirm for
            any caller that does not pass torchaudio/librosa output).
        sample_rate: Sampling rate of ``audio_data`` in Hz. When falsy, the
            signal is not resampled and is treated as already being at
            ``target_sample_rate``.
        target_sample_rate: Sampling rate, in Hz, that the signal is
            resampled to and that is announced to the processor.

    Returns:
        str: The decoded transcription with surrounding whitespace removed.

    Raises:
        RuntimeError: If the model or the processor has not been loaded.
        ValueError: If the audio contains no samples.
    """
    if model is None or processor is None:
        raise RuntimeError("Modèle non chargé")

    # Down-mix to mono when the input has more than one dimension.
    if len(audio_data.shape) > 1:
        audio_data = audio_data.mean(axis=0)

    # Normalise to a PyTorch tensor with a leading channel dimension.
    if not isinstance(audio_data, torch.Tensor):
        waveform = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0)
    else:
        waveform = audio_data.unsqueeze(0) if audio_data.dim() == 1 else audio_data

    if waveform.numel() == 0:
        raise ValueError("Audio vide")

    if sample_rate and sample_rate != target_sample_rate:
        logger.info(f"🔄 Resampling {sample_rate}Hz → {target_sample_rate}Hz")
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        waveform = resampler(waveform)

    # Drop only the leading channel dimension (a bare squeeze() would turn a
    # one-sample signal into a 0-d value) and hand the processor a NumPy array.
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=target_sample_rate,
        return_tensors="pt",
    )

    logger.info("🔄 Génération de la transcription...")
    # Greedy decoding (no sampling, single beam); no gradients are needed.
    with torch.no_grad():
        result = model.generate(
            **inputs,
            max_length=500,
            do_sample=False,
            num_beams=1,
        )

    transcription = processor.batch_decode(result, skip_special_tokens=True)[0]
    return transcription.strip()
def _load_audio_file(path):
    """Load an audio file, trying torchaudio first and librosa as a fallback.

    Returns:
        tuple: ``(waveform, sample_rate)`` where ``waveform`` is a tensor
        with a leading channel dimension.
    """
    try:
        waveform, sample_rate = torchaudio.load(path)
        logger.info(f"✅ Audio chargé avec torchaudio: {sample_rate}Hz")
    except Exception as e:
        # Deliberate best-effort: any torchaudio failure triggers the fallback.
        logger.warning(f"⚠️ Torchaudio échoué, essai librosa: {e}")
        samples, sample_rate = librosa.load(path, sr=None)
        waveform = torch.tensor(samples).unsqueeze(0)
        logger.info(f"✅ Audio chargé avec librosa: {sample_rate}Hz")
    return waveform, sample_rate


def transcribe(audio):
    """Transcribe an audio file for the Gradio interface.

    Args:
        audio: Path of the audio file to transcribe, or None when nothing
            was provided.

    Returns:
        str: The transcription on success, otherwise a human-readable error
        message prefixed with "❌". Errors are never raised to the UI.
    """
    if audio is None:
        return "❌ Aucun fichier audio fourni"

    try:
        logger.info(f"🎵 Traitement audio: {audio}")
        waveform, sample_rate = _load_audio_file(audio)
        transcription = process_audio_data(waveform, sample_rate)
        logger.info(f"✅ Transcription réussie: '{transcription}'")
        return transcription
    except Exception as e:
        error_msg = f"❌ Erreur de transcription: {str(e)}"
        # logger.exception keeps the traceback alongside the message.
        logger.exception(error_msg)
        return error_msg
def _transcribe_audio_bytes(audio_bytes):
    """Transcribe raw audio bytes by round-tripping them through a temp file.

    The temporary file is removed whether or not decoding or transcription
    succeeds, including when writing the bytes fails.
    """
    temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
    temp_path = temp_file.name
    try:
        # Close the file before handing its path to librosa.
        with temp_file:
            temp_file.write(audio_bytes)
        samples, sample_rate = librosa.load(temp_path, sr=None)
        return process_audio_data(torch.tensor(samples), sample_rate)
    finally:
        # Cleanup is best-effort: a failed delete must not mask the result
        # (or the original error) of the transcription itself.
        with contextlib.suppress(OSError):
            os.unlink(temp_path)


def transcribe_api_base64(audio_base64):
    """API handler: transcribe base64-encoded audio.

    Args:
        audio_base64: Base64 text, optionally prefixed with a data-URI
            header of the form ``data:audio/wav;base64,``.

    Returns:
        dict: ``{"success": True, "transcription": str}`` on success, or
        ``{"success": False, "error": str}`` on any failure.
    """
    try:
        logger.info("🔄 API: Traitement base64...")
        if audio_base64.startswith('data:'):
            # Keep only the payload after the "data:...;base64," header.
            audio_base64 = audio_base64.split(',', 1)[1]
        audio_bytes = base64.b64decode(audio_base64)

        transcription = _transcribe_audio_bytes(audio_bytes)
        logger.info(f"✅ API Transcription: '{transcription}'")
        return {"success": True, "transcription": transcription}
    except Exception as e:
        error_msg = f"Erreur API base64: {str(e)}"
        logger.error(error_msg)
        return {"success": False, "error": error_msg}


def transcribe_api_file(audio_file):
    """API handler: transcribe an uploaded audio file.

    Args:
        audio_file: Object exposing ``read()`` that returns the audio bytes.

    Returns:
        dict: ``{"success": True, "transcription": str}`` on success, or
        ``{"success": False, "error": str}`` on any failure.
    """
    try:
        logger.info("🔄 API: Traitement fichier...")
        audio_bytes = audio_file.read()

        transcription = _transcribe_audio_bytes(audio_bytes)
        logger.info(f"✅ API Transcription: '{transcription}'")
        return {"success": True, "transcription": transcription}
    except Exception as e:
        error_msg = f"Erreur API fichier: {str(e)}"
        logger.error(error_msg)
        return {"success": False, "error": error_msg}
print("🚀 DÉMARRAGE API STT FONGBÉ - RONALDODEV")
print("=" * 50)

# Load the model once at import time; model_status is interpolated into the
# header of the Gradio page built below.
if load_model():
    print("✅ Modèle chargé - Interface prête!")
    model_status = "✅ Modèle chargé et prêt"
else:
    print("❌ Erreur de chargement du modèle")
    model_status = "❌ Erreur de chargement"
# Main Gradio interface: a header, an interactive transcription tab and a
# tab holding the HTTP API documentation.
with gr.Blocks(theme=gr.themes.Soft(), title="🎤 API STT Fongbé") as demo:
    # Markdown text is kept at column 0 so that it renders the same whether
    # or not the component dedents its value.
    gr.Markdown(f"""
# 🎤 API STT Fongbé - Ronaldodev

**Reconnaissance vocale pour la langue Fongbé**

**Statut:** {model_status}

**Modèle:** `{MODEL_NAME}`
""")

    with gr.Tab("🎵 Interface Utilisateur"):
        # type="filepath" means transcribe() receives a path, not raw samples.
        audio_input = gr.Audio(
            sources=["upload", "microphone"],
            type="filepath",
            label="🎤 Uploadez un fichier ou enregistrez directement",
        )
        transcription_output = gr.Textbox(
            label="📝 Transcription en Fongbé",
            placeholder="La transcription apparaîtra ici...",
            lines=3,
        )
        transcribe_btn = gr.Button("🔄 Transcrire", variant="primary")
        transcribe_btn.click(
            fn=transcribe,
            inputs=audio_input,
            outputs=transcription_output,
        )

    with gr.Tab("🔌 API Documentation"):
        gr.Markdown("""
## 📡 Endpoints API Disponibles

### 1. **POST** `/api/transcribe_base64`

Pour envoyer de l'audio en base64

**Headers:**
```
Content-Type: application/json
```

**Body:**
```json
{
  "audio_base64": "data:audio/wav;base64,UklGRnoAAABXQVZF..."
}
```

**Réponse:**
```json
{
  "success": true,
  "transcription": "votre transcription ici"
}
```

### 2. **POST** `/api/transcribe_file`

Pour envoyer un fichier audio directement

**Headers:**
```
Content-Type: multipart/form-data
```

**Body:**
- `audio_file`: votre fichier audio (WAV, MP3, M4A...)

**Réponse:**
```json
{
  "success": true,
  "transcription": "votre transcription ici"
}
```

### 📱 Exemple d'utilisation

**Python:**
```python
import requests
import base64

# Méthode 1: Base64
with open("audio.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode()

response = requests.post(
    "https://ronaldodev-stt-fongbe.hf.space/api/transcribe_base64",
    json={"audio_base64": f"data:audio/wav;base64,{audio_b64}"}
)

# Méthode 2: Fichier direct
with open("audio.wav", "rb") as f:
    response = requests.post(
        "https://ronaldodev-stt-fongbe.hf.space/api/transcribe_file",
        files={"audio_file": f}
    )

result = response.json()
print(result["transcription"])
```

**Flutter:**
```dart
// Fichier direct
var request = http.MultipartRequest(
  'POST',
  Uri.parse('https://ronaldodev-stt-fongbe.hf.space/api/transcribe_file')
);
request.files.add(await http.MultipartFile.fromPath('audio_file', audioPath));
var response = await request.send();
```
""")
# Register the custom API endpoints, then start the app.
# NOTE(review): `add_api_route` is a FastAPI application method; confirm that
# the installed Gradio version actually exposes it on `gr.Blocks`. If it does
# not, these calls raise AttributeError before launch() is reached.
# NOTE(review): also verify that, once registered, the handlers receive the
# JSON field / uploaded file in the form they expect (a string for
# transcribe_api_base64, an object with read() for transcribe_api_file).
demo.add_api_route(
    "/api/transcribe_base64",
    transcribe_api_base64,
    methods=["POST"],
)
demo.add_api_route(
    "/api/transcribe_file",
    transcribe_api_file,
    methods=["POST"],
)

demo.launch()