import argparse
import base64
import os
import secrets
from io import BytesIO

import gradio as gr
import requests
from pydub import AudioSegment

LOCAL_API_ENDPOINT = "http://localhost:5000"
PUBLIC_API_ENDPOINT = "http://121.176.153.117:5000"
API_ENDPOINT = PUBLIC_API_ENDPOINT

# Root folder (relative to the working directory) for per-session audio clips;
# the Chatbot component is given paths under it as file messages.
TEMP_AUDIO_DIR = "temp_audio"

# NOTE(review): this module-level state is shared by every browser tab
# connected to the app, so all visitors drive the same chat session.
# Per-user isolation would require moving it into gr.State.
session_id = ""
chat_history = []

# Styling for the microphone widget and the narrow-screen layout.
# `!important` must come before the semicolon (otherwise browsers drop it as a
# separate invalid declaration), and flex-grow takes a unitless ratio.
css = """
#audio_input {
    margin-top: -30px !important;
    margin-left: -15px !important;
    width: 100% !important;
}
#audio_input button {
    height: 50px !important;
    font-size: 0px !important;
    width: 110% !important;
}
#audio_input button:after {
    content: '🎤' !important;
    font-size: 16px !important;
}
audio {
    min-width: 200px !important;
}
@media (max-width: 480px) {
    #audio_input {
        width: 120% !important;
    }
    #audio_input button:after {
        content: '' !important;
    }
    #txt_input_container {
        flex-grow: 7 !important;
    }
    #audio_input_container {
        flex-grow: 3 !important;
    }
}
"""

# Client-side hook run after each bot reply: mark the newest <audio> element
# as autoplay. The guard avoids a TypeError when the page has no audio yet.
js_audio_auto_play = """
() => {
    // select last audio element
    const audio = document.getElementsByTagName('audio');
    const last_audio = audio[audio.length - 1];
    if (!last_audio) return;
    // set autoplay attribute
    last_audio.setAttribute('autoplay', true);
}
"""


def _new_temp_audio_path(folder_id):
    """Return a fresh mp3 path under the temp folder for ``folder_id``, creating the folder if needed."""
    folder = f"{TEMP_AUDIO_DIR}/{folder_id}"
    # The folder may be missing, e.g. when switching to a session that was
    # created by another process.
    os.makedirs(folder, exist_ok=True)
    return f"{folder}/audio_input_{secrets.token_hex(8)}.mp3"


def _save_audio_to_temp(audio_bytes, folder_id):
    """Decode ``audio_bytes`` with pydub, save it as mp3 in the session temp folder, and return the path."""
    path = _new_temp_audio_path(folder_id)
    # export() returns the file handle it opened for `path`; close it explicitly.
    AudioSegment.from_file(BytesIO(audio_bytes)).export(path, format="mp3").close()
    return path


def create_chat_session():
    """Ask the API for a new chat session and create its temp audio folder.

    Returns:
        The new session id from the API's ``{"id": ...}`` response.

    Raises:
        Exception: if the API does not answer with HTTP 201.
    """
    r = requests.post(API_ENDPOINT + "/create")
    if r.status_code != 201:
        raise Exception("Failed to create chat session")

    new_session_id = r.json()["id"]
    os.makedirs(f"{TEMP_AUDIO_DIR}/{new_session_id}", exist_ok=True)
    return new_session_id


def create_new_or_change_session(history, id):
    """Start a new session when ``id`` is empty, otherwise switch to session ``id``.

    Returns:
        The chat history to display and an update that clears/locks the
        session textbox.
    """
    global session_id
    global chat_history

    if id == "":
        session_id = create_chat_session()
        history = []
    else:
        history, _ = change_session(history, id)

    chat_history = history
    return history, gr.update(value="", interactive=False)


def add_text(history, text):
    """Append a user text turn and clear/lock the textbox until the bot replies."""
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)


def add_audio(history, audio):
    """Append a recorded clip plus its transcription as user turns.

    ``audio`` is the unprocessed component payload; its ``'data'`` entry holds
    base64 audio, optionally prefixed by a data-URL header. A falsy payload
    (the component being cleared) leaves the history untouched.

    Raises:
        Exception: with the API response text if transcription fails.
    """
    if not audio:
        # Clearing the component re-fires .change(); there is nothing to add.
        return history, gr.update()

    raw_bytes = base64.b64decode(audio['data'].split(',')[-1].encode('utf-8'))

    # Convert to mp3 into a *separate* buffer: exporting into the buffer that
    # was just read would append the mp3 after the original bytes.
    mp3_buffer = BytesIO()
    AudioSegment.from_file(BytesIO(raw_bytes)).export(mp3_buffer, format="mp3")
    mp3_bytes = mp3_buffer.getvalue()

    # save audio file temporary to disk so the chatbot can show it
    audio_path = _new_temp_audio_path(session_id)
    with open(audio_path, "wb") as f:
        f.write(mp3_bytes)
    history = history + [((audio_path,), None)]

    response = requests.post(
        API_ENDPOINT + "/transcribe",
        files={'audio': mp3_bytes},
    )
    if response.status_code != 200:
        raise Exception(response.text)

    text = response.json()['text']
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)


def reset_chat_session(history):
    """Reset the current session on the server and return an empty history.

    Raises:
        Exception: with the API response text on a non-200 answer.
    """
    global session_id
    global chat_history

    response = requests.post(API_ENDPOINT + f"/reset/{session_id}")
    if response.status_code != 200:
        raise Exception(response.text)

    chat_history = []
    return []


def bot(history):
    """Send the pending user message to the API and append the bot's audio and text reply.

    Only a trailing user *text* turn is sent. If the last turn is a bot reply
    (or the history is empty), the history is returned unchanged, so a
    re-triggered event chain does not post ``None`` or a file tuple.

    Raises:
        Exception: if the API does not answer with HTTP 200.
    """
    global chat_history

    if not history or not isinstance(history[-1][0], str):
        return history
    message = history[-1][0]

    response = requests.post(
        API_ENDPOINT + f"/send/text/{session_id}",
        headers={'Content-type': 'application/json'},
        json={
            'message': message,
            'role': 'user'
        }
    )
    if response.status_code != 200:
        raise Exception(f"Failed to send message, {response.text}")

    payload = response.json()
    text, audio = payload['text'], payload['audio']

    if audio:
        audio_path = _save_audio_to_temp(base64.b64decode(audio.encode('utf-8')), session_id)
        history = history + [(None, (audio_path,))]
    history = history + [(None, text)]

    chat_history = history.copy()
    return history


def change_session(history, id):
    """Switch to session ``id`` and rebuild its chat history from the API.

    Each stored turn is expected to be a dict with ``role`` (``'user'`` or
    ``'assistant'``), ``message`` and base64 ``audio`` (may be empty). Audio
    clips are written under the session's temp folder.

    Returns:
        The rebuilt history and an update that clears/locks the session textbox.

    Raises:
        Exception: with the API response text on a non-200 answer, or
            "Response: ..." (chained to the original error) if a stored turn
            cannot be processed.
    """
    global session_id
    global chat_history

    response = requests.get(API_ENDPOINT + f"/{id}")
    if response.status_code != 200:
        raise Exception(response.text)

    turns = response.json()
    session_id = id

    history = []
    try:
        for chat in turns:
            if chat['role'] == 'user':
                if chat['audio']:
                    audio_path = _save_audio_to_temp(base64.b64decode(chat['audio'].encode('utf-8')), id)
                    history.append(((audio_path,), None))
                history.append((chat['message'], None))
            elif chat['role'] == 'assistant':
                if chat['audio']:
                    audio_path = _save_audio_to_temp(base64.b64decode(chat['audio'].encode('utf-8')), id)
                    history.append((None, (audio_path,)))
                history.append((None, chat['message']))
            else:
                raise Exception("Invalid chat role")
    except Exception as e:
        raise Exception(f"Response: {turns}") from e

    chat_history = history.copy()
    print(f"len(chat_history): {len(chat_history)}\nlen(history): {len(history)}\nlen(response): {len(turns)}")
    return history, gr.update(value="", interactive=False)


def load_chat_history(history):
    """On page load, show the server-side history if it is longer than what the browser holds."""
    if len(chat_history) > len(history):
        history = chat_history
    return history


def main():
    """Create a fresh chat session and build the Gradio Blocks UI."""
    global session_id
    global chat_history

    session_id = create_chat_session()
    chat_history = []

    with gr.Blocks(css=css) as demo:
        with gr.Row():
            # change session id
            change_session_txt = gr.Textbox(
                show_label=False,
                placeholder=session_id,
            ).style(container=False)
        with gr.Row():
            # button to create new or change session id
            change_session_button = gr.Button(
                "Create new or change session",
                type='success',
                size="sm"
            ).style(container=False)

        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
        # Restore the server-side history when the page is (re)loaded.
        demo.load(load_chat_history, [chatbot], [chatbot], queue=False)

        with gr.Row():
            with gr.Column(scale=0.85, min_width=0, elem_id="txt_input_container"):
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter, or record audio",
                    elem_id="txt_input"
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0, elem_id="audio_input_container"):
                audio = gr.Audio(
                    source="microphone",
                    type="numpy",
                    show_label=False,
                    format="mp3",
                    min_width=0,
                    container=False,
                    elem_id="audio_input"
                )

        with gr.Row():
            reset_button = gr.Button(
                "Reset Chat Session",
                type='stop',
                size="sm"
            ).style(container=False)

        # Text turn: append the message, ask the bot, autoplay the reply,
        # then unlock the textbox.
        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, chatbot
        ).then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

        # Voice turn: add_audio base64-decodes the raw payload itself, hence
        # preprocess=False.
        audio_msg = audio.change(
            add_audio, [chatbot, audio], [chatbot, audio],
            queue=False, preprocess=False, postprocess=False
        ).then(
            bot, chatbot, chatbot
        ).then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        audio_msg.then(lambda: gr.update(interactive=True, value=None), None, [audio], queue=False)

        reset_button.click(reset_chat_session, [chatbot], [chatbot], queue=False)

        # The lambdas read the module-global session_id at call time, so the
        # placeholder shows the id the preceding handler switched to.
        chgn_msg = change_session_txt.submit(
            change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False
        )
        chgn_msg.then(
            lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False
        )

        create_new_or_change_session_btn = change_session_button.click(
            create_new_or_change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False
        )
        create_new_or_change_session_btn.then(
            lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False
        )

    return demo


if __name__ == "__main__":
    # arguments --local
    parser = argparse.ArgumentParser()
    parser.add_argument("--local", action="store_true", help="Use local API endpoint")
    args = parser.parse_args()

    if args.local:
        API_ENDPOINT = LOCAL_API_ENDPOINT

    demo = main()
    demo.launch(show_error=True, server_name="0.0.0.0")