Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import base64 | |
| import requests | |
| import secrets | |
| import os | |
| import argparse | |
| from io import BytesIO | |
| from pydub import AudioSegment | |
LOCAL_API_ENDPOINT = "http://localhost:5000"
PUBLIC_API_ENDPOINT = "http://121.176.153.117:5000"
# Base URL of the chat backend. It can be overridden with the API_ENDPOINT
# environment variable (an empty value falls back to the public endpoint);
# the --local command-line flag in the __main__ guard still overrides both.
API_ENDPOINT = os.environ.get("API_ENDPOINT") or PUBLIC_API_ENDPOINT
# Id of the backend chat session this UI process is currently bound to.
session_id = ""
# Server-side copy of the chatbot history; restored into the page on load.
chat_history = []
# Custom CSS for the Blocks app. It does the following:
# - pulls the microphone recorder (#audio_input) up and left next to the text box;
# - hides the record button's text and shows a microphone emoji instead;
# - on screens <= 480px, drops the emoji and splits the input row 70/30
#   between the text and audio columns.
# `!important` must come before the semicolon: "-30px; !important;" made the
# margin non-important and added an invalid declaration. flex-grow takes a
# unitless number, so the former "70%"/"30%" values were ignored.
css = """
#audio_input {
    margin-top: -30px !important;
    margin-left: -15px !important;
    width: 100% !important;
}
#audio_input button {
    height:50px !important;
    font-size: 0px !important;
    width: 110% !important;
}
#audio_input button:after {
    content: '🎤' !important;
    font-size: 16px !important;
}
audio {
    min-width: 200px !important;
}
@media (max-width : 480px) {
    #audio_input {
        width: 120% !important;
    }
    #audio_input button:after {
        content: '' !important;
    }
    #txt_input_container {
        flex-grow: 7 !important;
    }
    #audio_input_container {
        flex-grow: 3 !important;
    }
}
"""
# Browser-side hook run after each bot reply. It sets the `autoplay` attribute
# on the last <audio> element in document order, presumably the newest bot
# reply. It does nothing when the page has no <audio> element yet; without that
# guard, `last_audio` was undefined and setAttribute threw a TypeError.
js_audio_auto_play = """
() => {
    // select last audio element
    const audio = document.getElementsByTagName('audio');
    if (audio.length === 0) {
        return;
    }
    const last_audio = audio[audio.length - 1];
    // set autoplay attribute
    last_audio.setAttribute('autoplay', true);
}
"""
def create_chat_session():
    """Create a new chat session on the backend and its local audio folder.

    POSTs to ``{API_ENDPOINT}/create`` and expects HTTP 201 with a JSON body
    holding the new session ``id``. It then creates ``./temp_audio/<id>`` so
    audio clips for the session can be written to disk.

    Returns:
        The new session id, as returned by the backend.

    Raises:
        Exception: If the backend does not answer with HTTP 201. The message
            includes the status code and the response body.
    """
    response = requests.post(API_ENDPOINT + "/create")
    if response.status_code != 201:
        raise Exception(
            f"Failed to create chat session ({response.status_code}): {response.text}"
        )
    # Named new_id so it does not shadow the module-level session_id.
    new_id = response.json()["id"]
    # exist_ok: a folder with this name may already be on disk (e.g. left over
    # from an earlier run); that must not abort session creation.
    os.makedirs(f"./temp_audio/{new_id}", exist_ok=True)
    return new_id
def create_new_or_change_session(history, id):
    """Start a fresh backend session, or switch to an existing one.

    Args:
        history: Current chatbot history. It is replaced in both cases.
        id: Text from the session textbox. An empty or whitespace-only value
            creates a new session. Anything else, with surrounding whitespace
            stripped, is loaded as an existing session id via change_session.

    Returns:
        ``(history, update)``: the new chatbot history, and a Gradio update that
        clears and disables the session textbox. A follow-up event in main()
        re-enables the textbox.
    """
    global session_id
    global chat_history
    # `id` shadows the builtin but is kept so the signature stays unchanged.
    requested_id = (id or "").strip()
    if not requested_id:
        session_id = create_chat_session()
        history = []
    else:
        history, _ = change_session(history, requested_id)
    chat_history = history
    return history, gr.update(value="", interactive=False)
def add_text(history, text):
    """Queue the typed message as a pending user turn in the chat.

    The new entry is ``(text, None)``; the empty bot slot is filled later by
    ``bot``. The second return value clears the textbox and disables it. It is
    re-enabled once the reply chain wired up in main() has finished.
    """
    pending_turn = (text, None)
    extended = history + [pending_turn]
    textbox_state = gr.update(value="", interactive=False)
    return extended, textbox_state
def add_audio(history, audio):
    """Add a microphone recording to the chat and transcribe it.

    ``audio`` is the raw, unpreprocessed payload of the ``gr.Audio`` component.
    Its ``'data'`` field is assumed to be a base64 data URL
    (``…,<base64 payload>``); the payload is whatever follows the last comma.

    The recording is converted to MP3 once. That MP3 is saved under
    ``temp_audio/<session_id>/`` for playback, and the same bytes are sent to
    ``{API_ENDPOINT}/transcribe``. Two user entries are appended: the audio
    file, then the transcribed text (which ``bot`` sends on).

    Returns:
        ``(history, update)``. When ``audio`` is empty, the history is returned
        unchanged with a no-op update. This happens, for example, when main()
        resets the component to ``None`` after a reply. NOTE(review): this
        relies on Gradio firing ``change`` for programmatic updates.

    Raises:
        Exception: With the response body, if transcription does not return 200.
    """
    if not audio:
        return history, gr.update()
    encoded = audio['data'].split(',')[-1]
    recording = base64.b64decode(encoded.encode('utf-8'))
    # Export into a fresh buffer. The old code exported back into the BytesIO
    # it had just decoded from. A BytesIO is never truncated, so the MP3 bytes
    # ended up mixed with bytes of the original recording.
    mp3_buffer = BytesIO()
    AudioSegment.from_file(BytesIO(recording)).export(mp3_buffer, format="mp3")
    mp3_bytes = mp3_buffer.getvalue()
    # Save the already-encoded bytes directly. This avoids a second
    # decode/encode pass, and the handle is closed by the with-block.
    audio_id = secrets.token_hex(8)
    audio_path = f"temp_audio/{session_id}/audio_input_{audio_id}.mp3"
    with open(audio_path, "wb") as audio_out:
        audio_out.write(mp3_bytes)
    history = history + [((audio_path,), None)]
    response = requests.post(
        API_ENDPOINT + "/transcribe",
        files={'audio': mp3_bytes}
    )
    if response.status_code != 200:
        raise Exception(response.text)
    text = response.json()['text']
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)
def reset_chat_session(history):
    """Wipe the current session's conversation on the backend and in the UI.

    POSTs to ``{API_ENDPOINT}/reset/<session_id>``. On HTTP 200 it empties both
    the module-level ``chat_history`` and the returned chatbot value. The
    incoming ``history`` (the chatbot value passed by the click handler) is
    not used.

    Raises:
        Exception: With the response body, if the backend does not return 200.
    """
    global chat_history
    reset_url = f"{API_ENDPOINT}/reset/{session_id}"
    reply = requests.post(reset_url)
    if reply.status_code == 200:
        chat_history = []
        return []
    raise Exception(reply.text)
def bot(history):
    """Send the latest user message to the backend and append its reply.

    The message is the first element of the last history entry. If that is not
    a string, the entry before it is used instead. The backend replies with
    JSON containing ``text`` and base64 ``audio``. The audio is saved as MP3
    under ``temp_audio/<session_id>/``, and two bot entries (audio, then text)
    are appended. ``chat_history`` is set to a copy of the result.

    If there is no pending user text, the history is returned unchanged instead
    of raising IndexError or posting ``None`` to the backend. That covers an
    empty history, and the chain being re-triggered after the bot has already
    answered.

    Raises:
        Exception: If the send request does not return HTTP 200.
    """
    global chat_history
    message = None
    if history:
        if isinstance(history[-1][0], str):
            message = history[-1][0]
        elif len(history) > 1 and isinstance(history[-2][0], str):
            message = history[-2][0]
    if message is None:
        return history
    response = requests.post(
        API_ENDPOINT + f"/send/text/{session_id}",
        headers={'Content-type': 'application/json'},
        json={
            'message': message,
            'role': 'user'
        }
    )
    if response.status_code != 200:
        raise Exception(f"Failed to send message, {response.text}")
    payload = response.json()
    text, audio = payload['text'], payload['audio']
    audio_bytes = base64.b64decode(audio.encode('utf-8'))
    audio_id = secrets.token_hex(8)
    audio_path = f"temp_audio/{session_id}/audio_input_{audio_id}.mp3"
    # export() returns the file object it opened for the path. Close it now
    # rather than leaving the handle to the garbage collector.
    AudioSegment.from_file(BytesIO(audio_bytes)).export(audio_path, format="mp3").close()
    history = history + [(None, (audio_path,)), (None, text)]
    chat_history = history.copy()
    return history
def _export_session_audio(session, encoded_audio):
    """Decode base64 audio, save it as MP3 under temp_audio/<session>/, and return the path."""
    audio_bytes = base64.b64decode(encoded_audio.encode('utf-8'))
    audio_path = f"temp_audio/{session}/audio_input_{secrets.token_hex(8)}.mp3"
    # export() returns the file object it opened for the path; close it right away.
    AudioSegment.from_file(BytesIO(audio_bytes)).export(audio_path, format="mp3").close()
    return audio_path


def change_session(history, id):
    """Switch the UI to an existing backend session and rebuild its history.

    GETs ``{API_ENDPOINT}/<id>``. The response is expected to be a JSON list of
    messages, each with ``role`` ('user' or 'assistant'), ``message``, and
    base64 ``audio`` (an empty string when the message has no audio). Each
    audio clip is saved as MP3 under ``temp_audio/<id>/`` and shown before its
    message text.

    Args:
        history: Current chatbot history. It is ignored and rebuilt from the backend.
        id: Session id to switch to.

    Returns:
        ``(history, update)``: the rebuilt history, and a Gradio update that
        clears and disables the session textbox.

    Raises:
        Exception: With the response body if the GET does not return 200, or
            with the parsed response if a message cannot be converted. In the
            second case the underlying error is chained as the cause.
    """
    global session_id
    global chat_history
    response = requests.get(
        API_ENDPOINT + f"/{id}"
    )
    if response.status_code != 200:
        raise Exception(response.text)
    messages = response.json()
    # create_chat_session() only creates folders for sessions it creates.
    # An id from another run or process may not have one yet, and the MP3
    # export below would fail without it.
    os.makedirs(f"temp_audio/{id}", exist_ok=True)
    history = []
    try:
        for chat in messages:
            if chat['role'] == 'user':
                if chat['audio'] != "":
                    history.append(((_export_session_audio(id, chat['audio']),), None))
                history.append((chat['message'], None))
            elif chat['role'] == 'assistant':
                # Same as the user branch: skip the audio entry when none is stored.
                if chat['audio'] != "":
                    history.append((None, (_export_session_audio(id, chat['audio']),)))
                history.append((None, chat['message']))
            else:
                raise Exception("Invalid chat role")
    except Exception as e:
        raise Exception(f"Response: {messages}") from e
    # Switch the active session only after the whole history was rebuilt.
    session_id = id
    chat_history = history.copy()
    print(f"len(chat_history): {len(chat_history)}\nlen(history): {len(history)}\nlen(response): {len(messages)}")
    return history, gr.update(value="", interactive=False)
def load_chat_history(history):
    """Pick the longer history: the server-side ``chat_history`` or the page's own.

    Used as the page-load handler, so a reloaded page gets back the
    conversation kept on the server. On a tie, the incoming ``history`` wins.
    """
    return chat_history if len(chat_history) > len(history) else history
def main():
    """Create the startup chat session and build the Gradio Blocks UI.

    Binds the module-level ``session_id`` to a freshly created backend session.
    The backend must therefore be reachable when this runs; otherwise
    create_chat_session() raises. It also resets ``chat_history`` and wires
    every UI event. The app is returned without being launched.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    global session_id
    global chat_history
    session_id = create_chat_session()
    chat_history = []
    with gr.Blocks(css=css) as demo:
        with gr.Row():
            # Change session id. The placeholder shows the id currently in use.
            change_session_txt = gr.Textbox(
                show_label=False,
                placeholder=session_id,
            ).style(container=False)
        with gr.Row():
            # Button to create a new session or change the session id.
            # Button styling is chosen with `variant` ("primary" / "secondary" /
            # "stop"). The previous type='success' is not a Button variant.
            change_session_button = gr.Button(
                "Create new or change session", variant="primary", size="sm"
            ).style(container=False)
        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
        # On page (re)load, restore the server-side history if it is longer.
        demo.load(load_chat_history, [chatbot], [chatbot], queue=False)
        with gr.Row():
            with gr.Column(scale=0.85, min_width=0, elem_id="txt_input_container"):
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter, or record audio",
                    elem_id="txt_input"
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0, elem_id="audio_input_container"):
                audio = gr.Audio(
                    source="microphone", type="numpy", show_label=False, format="mp3", min_width=0, container=False, elem_id="audio_input"
                )
        with gr.Row():
            reset_button = gr.Button(
                "Reset Chat Session", variant="stop", size="sm"
            ).style(container=False)
        # Text flow: add the user turn, get the bot reply, then run the autoplay
        # hook. After that, re-enable the textbox that add_text disabled.
        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, chatbot
        ).then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
        # Audio flow: add_audio receives the raw browser payload (pre/postprocess
        # off), then bot replies, then the autoplay hook runs. Finally the
        # recorder is cleared to None and unlocked.
        audio_msg = audio.change(add_audio, [chatbot, audio], [chatbot, audio], queue=False, preprocess=False, postprocess=False).then(
            bot, chatbot, chatbot
        ).then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        audio_msg.then(lambda: gr.update(interactive=True, value=None), None, [audio], queue=False)
        reset_button.click(reset_chat_session, [chatbot], [chatbot], queue=False)
        # The lambdas read the global session_id when they run, so the
        # placeholder shows the session that is active at that moment.
        chgn_msg = change_session_txt.submit(change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
        chgn_msg.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
        create_new_or_change_session_btn = change_session_button.click(create_new_or_change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
        create_new_or_change_session_btn.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
    return demo
if __name__ == "__main__":
    # Command line: --local points the UI at a backend on localhost instead of
    # the configured API_ENDPOINT, and overrides it before main() runs.
    cli = argparse.ArgumentParser()
    cli.add_argument("--local", action="store_true", help="Use local API endpoint")
    options = cli.parse_args()
    if options.local:
        API_ENDPOINT = LOCAL_API_ENDPOINT
    # Listen on all interfaces; show_error surfaces handler exceptions in the UI.
    main().launch(show_error=True, server_name="0.0.0.0")