Spaces:
Running
on
L40S
Running
on
L40S
| import argparse | |
| import base64 | |
| import wave | |
| import ormsgpack | |
| import pyaudio | |
| import requests | |
| from pydub import AudioSegment | |
| from pydub.playback import play | |
| from tools.file import audio_to_bytes, read_ref_text | |
| from tools.schema import ServeReferenceAudio, ServeTTSRequest | |
def _str2bool(value):
    """Convert a command-line flag value to ``bool``.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so flag values must be parsed
    explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the CLI parser for the TTS client and parse ``sys.argv``.

    Boolean options (``--play``, ``--normalize``, ``--streaming``) accept an
    explicit value such as ``true``/``false``/``1``/``0``/``yes``/``no``;
    given bare (e.g. ``--streaming``), they mean True.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description="Send a WAV file and text to a server and receive synthesized audio.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--url",
        "-u",
        type=str,
        default="http://127.0.0.1:8080/v1/tts",
        help="URL of the server",
    )
    parser.add_argument(
        "--text", "-t", type=str, required=True, help="Text to be synthesized"
    )
    parser.add_argument(
        "--reference_id",
        "-id",
        type=str,
        default=None,
        help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
    )
    parser.add_argument(
        "--reference_audio",
        "-ra",
        type=str,
        nargs="+",
        default=None,
        help="Path to the audio file",
    )
    parser.add_argument(
        "--reference_text",
        "-rt",
        type=str,
        nargs="+",
        default=None,
        help="Reference text for voice synthesis",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="generated_audio",
        help="Output audio file name",
    )
    # Boolean flags: _str2bool fixes `--play False` being parsed as True;
    # nargs="?" + const=True additionally lets the bare flag mean True.
    parser.add_argument(
        "--play",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to play audio after receiving data",
    )
    parser.add_argument(
        "--normalize", type=_str2bool, nargs="?", const=True, default=True
    )
    parser.add_argument(
        "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
    )
    parser.add_argument(
        "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kbps"
    )
    parser.add_argument("--opus_bitrate", type=int, default=-1000)
    parser.add_argument(
        "--latency",
        type=str,
        default="normal",
        choices=["normal", "balanced"],
        help="Used in api.fish.audio/v1/tts",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=0,
        help="Maximum new tokens to generate. \n0 means no limit.",
    )
    parser.add_argument(
        "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
    )
    parser.add_argument(
        "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty for synthesis",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )
    parser.add_argument(
        "--streaming",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Enable streaming response",
    )
    parser.add_argument(
        "--channels", type=int, default=1, help="Number of audio channels"
    )
    parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
    parser.add_argument(
        "--use_memory_cache",
        type=str,
        default="never",
        choices=["on-demand", "never"],
        help="Cache encoded references codes in memory.\n"
        "If `on-demand`, the server will use cached encodings\n "
        "instead of encoding reference audio again.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="`None` means randomized inference, otherwise deterministic.\n"
        "It can't be used for fixing a timbre.",
    )
    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()

    idstr: str | None = args.reference_id
    # Priority: ref_id > [{text, audio}, ...]. When an ID is given, no local
    # reference audio/text is sent; the ID is expected to be resolved by the
    # server (original note: "in api.py").
    if idstr is None:
        audio_paths = args.reference_audio or []
        text_args = args.reference_text or []
        # Audios and texts are paired positionally below; zip() would silently
        # drop any unmatched tail, so reject mismatched counts up front.
        if len(audio_paths) != len(text_args):
            raise SystemExit(
                f"Got {len(audio_paths)} --reference_audio value(s) but "
                f"{len(text_args)} --reference_text value(s); they must match."
            )
        byte_audios = [audio_to_bytes(path) for path in audio_paths]
        ref_texts = [read_ref_text(text) for text in text_args]
    else:
        byte_audios = []
        ref_texts = []

    data = {
        "text": args.text,
        "references": [
            ServeReferenceAudio(audio=ref_audio, text=ref_text)
            for ref_text, ref_audio in zip(ref_texts, byte_audios)
        ],
        "reference_id": idstr,
        "normalize": args.normalize,
        "format": args.format,
        "mp3_bitrate": args.mp3_bitrate,
        "opus_bitrate": args.opus_bitrate,
        "max_new_tokens": args.max_new_tokens,
        "chunk_length": args.chunk_length,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "temperature": args.temperature,
        "streaming": args.streaming,
        "use_memory_cache": args.use_memory_cache,
        "seed": args.seed,
    }

    pydantic_data = ServeTTSRequest(**data)

    response = requests.post(
        args.url,
        data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
        stream=args.streaming,
        headers={
            "authorization": "Bearer YOUR_API_KEY",
            "content-type": "application/msgpack",
        },
    )

    try:
        if response.status_code != 200:
            print(f"Request failed with status code {response.status_code}")
            try:
                print(response.json())
            except ValueError:
                # Error bodies are not guaranteed to be JSON (e.g. a proxy's
                # HTML error page); fall back to the raw text.
                print(response.text)
            raise SystemExit(1)

        if args.streaming:
            p = pyaudio.PyAudio()
            audio_format = pyaudio.paInt16  # Assuming 16-bit PCM format
            sample_width = p.get_sample_size(audio_format)
            frame_size = sample_width * args.channels
            stream = None
            wf = None
            try:
                if args.play:
                    stream = p.open(
                        format=audio_format,
                        channels=args.channels,
                        rate=args.rate,
                        output=True,
                    )
                wf = wave.open(f"{args.output}.wav", "wb")
                wf.setnchannels(args.channels)
                wf.setsampwidth(sample_width)
                wf.setframerate(args.rate)

                # Network chunks need not align to whole PCM frames, and
                # PyAudio's write() drops a trailing partial frame, which would
                # shift every later sample. Carry the remainder forward.
                pending = b""
                for chunk in response.iter_content(chunk_size=1024):
                    if not chunk:
                        continue  # empty read: nothing to play or save
                    pending += chunk
                    usable = len(pending) - len(pending) % frame_size
                    if not usable:
                        continue
                    frames, pending = pending[:usable], pending[usable:]
                    if stream is not None:
                        stream.write(frames)
                    wf.writeframesraw(frames)
                # Any bytes still pending form an incomplete frame and are dropped.
            finally:
                if stream is not None:
                    # stop_stream() lets queued audio finish playing; closing
                    # an active stream would discard it.
                    if not stream.is_stopped():
                        stream.stop_stream()
                    stream.close()
                p.terminate()
                if wf is not None:
                    wf.close()
        else:
            audio_path = f"{args.output}.{args.format}"
            with open(audio_path, "wb") as audio_file:
                audio_file.write(response.content)

            if args.play:
                # Decode only when playback is actually requested.
                play(AudioSegment.from_file(audio_path, format=args.format))
            print(f"Audio has been saved to '{audio_path}'.")
    finally:
        # Release the connection (important when stream=True was used).
        response.close()