Spaces:
Paused
Paused
#coding: utf-8 | |
import os | |
import tempfile | |
#from typing import Any | |
#from typing import Dict | |
#from typing import IO | |
#from typing import List | |
from typing import Optional | |
from typing import Tuple | |
#from typing import Union | |
from base64 import b64encode | |
from openai import OpenAI | |
from pydub import AudioSegment | |
import streamlit as st | |
#from dotenv import load_dotenv | |
# Charger les variables d'environnement depuis le fichier .env | |
#load_dotenv() | |
class openai_tts(object):
    """Small wrapper around the OpenAI text-to-speech endpoint.

    An instance holds an OpenAI client plus the synthesis settings (voice,
    model, response format, speed). Every setter validates its argument,
    raises ``ValueError`` on an invalid value and returns ``self`` so calls
    can be chained.
    """

    # Values accepted by the setters below.
    VALID_VOICES = ("alloy", "echo", "fable", "onyx", "nova", "shimmer")
    VALID_MODELS = ("tts-1", "tts-1-hd")
    SUPPORTED_FORMATS = ("mp3", "opus", "aac", "flac", "wav", "pcm")

    # Layout of the headerless "pcm" output, used to compute its duration.
    # NOTE(review): 24 kHz / 16-bit follows the OpenAI TTS documentation;
    # mono is assumed -- confirm against the API reference.
    _PCM_FRAME_RATE = 24000
    _PCM_SAMPLE_WIDTH = 2
    _PCM_CHANNELS = 1

    # Class-level defaults, so the settings always exist even when the
    # constructor is given ``None`` for one of them.
    tts_voice = "nova"
    tts_model = "tts-1"
    response_format = "mp3"
    speed = 1.0

    def __init__(self,
                 tts_voice: Optional[str] = "nova",
                 tts_model: Optional[str] = "tts-1",
                 response_format: Optional[str] = "mp3",
                 speed: Optional[float] = 1.0
                 ):
        """Create the API client and apply the given settings.

        Args:
            tts_voice: Voice name; ``None`` keeps the default ("nova").
            tts_model: Model name; ``None`` keeps the default ("tts-1").
            response_format: Audio format; ``None`` keeps the default ("mp3").
            speed: Playback speed; ``None`` keeps the default (1.0).

        Raises:
            ValueError: If any provided setting is invalid.
        """
        self.client = None
        self.init_supported_formats__()
        self.init_api_client()

        # Compare against None rather than relying on truthiness, so that
        # invalid falsy values (speed=0, voice="") are reported instead of
        # being silently skipped.
        if response_format is not None:
            self.set_response_format(response_format)
        if tts_voice is not None:
            self.set_tts_voice(tts_voice)
        if tts_model is not None:
            self.set_tts_model(tts_model)
        if speed is not None:
            self.set_tts_speed(speed)

    def set_tts_speed(self, speed):
        """Set the speech speed; must lie within [0.25, 4.0]."""
        if not (0.25 <= speed <= 4.0):
            raise ValueError(f"[TTS] - Speed must be between 0.25 and 4.0. Provided value: {speed}")
        self.speed = speed
        return self

    def set_tts_voice(self, voice):
        """Set the voice; must be one of ``VALID_VOICES``."""
        if voice not in self.VALID_VOICES:
            raise ValueError(f"[TTS] - Invalid TTS voice: {voice}. Valid voices are: {', '.join(self.VALID_VOICES)}.")
        self.tts_voice = voice
        return self

    def set_tts_model(self, model):
        """Set the model; must be one of ``VALID_MODELS``."""
        if model not in self.VALID_MODELS:
            raise ValueError(f"[TTS] - Invalid TTS model: {model}. Valid models are 'tts-1' and 'tts-1-hd'.")
        self.tts_model = model
        return self

    def init_supported_formats__(self):
        """Populate ``self.supported_formats`` with the accepted formats."""
        self.supported_formats = list(self.SUPPORTED_FORMATS)
        return self

    def set_response_format(self, format: str):
        """Set the audio response format; must be in ``supported_formats``."""
        if format not in self.supported_formats:
            raise ValueError(f"[TTS] - Unsupported format: {format}. Supported formats are: {', '.join(self.supported_formats)}")
        self.response_format = format
        return self

    def init_api_client(self):
        """Create the OpenAI client once, if it does not exist yet."""
        if not self.client:
            # OpenAI client configuration with API key
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return self

    def _audio_duration_seconds(self, audio_bytes: bytes) -> float:
        """Return the duration, in seconds, of audio in the current format."""
        if self.response_format == "pcm":
            # Raw PCM has no header for a decoder to read, so derive the
            # duration directly from the byte count.
            bytes_per_second = (self._PCM_FRAME_RATE
                                * self._PCM_SAMPLE_WIDTH
                                * self._PCM_CHANNELS)
            return len(audio_bytes) / bytes_per_second

        # The context manager closes the temp file even if decoding fails.
        with tempfile.TemporaryFile() as tmp_file:
            tmp_file.write(audio_bytes)
            tmp_file.seek(0)
            audio = AudioSegment.from_file(tmp_file, format=self.response_format)
        # len() of an AudioSegment is in milliseconds.
        return len(audio) / 1000

    def text_to_speech(self,
                       input_text: str) -> dict:
        """Convert text to speech using the OpenAI API.

        Args:
            input_text: The text to convert to speech.

        Returns:
            A dictionary containing:
              - 'audio_duration' (float): audio duration in seconds.
              - 'data_bytes' (str): the audio data, base64-encoded.
        """
        response = self.client.audio.speech.create(
            model=self.tts_model,
            voice=self.tts_voice,
            input=input_text,
            response_format=self.response_format,
            speed=self.speed
        )
        data_output = response.read()
        duration = self._audio_duration_seconds(data_output)
        return {
            "audio_duration": duration,
            "data_bytes": b64encode(data_output).decode()
        }
def process_tts_message(text_response: str) -> Tuple[Optional[str], Optional[float]]:
    """Synthesize speech for a text reply and report failures in the UI.

    The voice is read from ``st.session_state.tts_voice``; the model, format
    and speed are fixed to "tts-1", "mp3" and 1.0.

    Args:
        text_response: The text to convert to speech.

    Returns:
        A ``(data_bytes, audio_duration)`` pair, where ``data_bytes`` is the
        base64-encoded audio as a string and ``audio_duration`` is in
        seconds. On any error, a message is shown with ``st.error`` and
        ``(None, None)`` is returned.
    """
    try:
        tts_output_ = openai_tts(
            tts_voice=st.session_state.tts_voice,
            tts_model="tts-1",
            response_format="mp3",
            speed=1.0
        ).text_to_speech(text_response)
        return tts_output_["data_bytes"], tts_output_["audio_duration"]
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        st.error(f"Une erreur s'est produite lors de la conversion texte-parole : {e}")
        return None, None
""" | |
if __name__ == "__main__": | |
openai_tts().text_to_speech("Hello, I am an AI assistant. How can I help you?") | |
""" |