# Spaces: Runtime error
import gradio as gr | |
import base64 | |
import requests | |
import secrets | |
import os | |
import argparse | |
from io import BytesIO | |
from pydub import AudioSegment | |
# Base URLs of the chat backend. API_ENDPOINT is the one every request in this
# file is sent to; it defaults to the public URL and is rebound to the local
# one when the script is started with --local.
LOCAL_API_ENDPOINT = "http://localhost:5000"
PUBLIC_API_ENDPOINT = "http://121.176.153.117:5000"
API_ENDPOINT = PUBLIC_API_ENDPOINT
# Module-wide state: id of the active backend chat session and a copy of the
# chat history that is handed back to the Chatbot when the page loads.
session_id = ""
chat_history = []
# Custom stylesheet passed to gr.Blocks(css=...). It shrinks the microphone
# recorder into a compact button and adjusts the layout on narrow screens.
#
# The two margin declarations used to read "-30px; !important;": the stray
# semicolon ended the declaration early and left "!important" as an invalid
# fragment of its own, so the flag was never applied.
#
# NOTE(review): `flex-grow` takes a unitless number, so the two
# `flex-grow: NN%` declarations in the media query are invalid CSS. They are
# left untouched because making them valid would change the mobile layout.
css = """
#audio_input {
margin-top: -30px !important;
margin-left: -15px !important;
width: 100% !important;
}
#audio_input button {
height:50px !important;
font-size: 0px !important;
width: 110% !important;
}
#audio_input button:after {
content: '🎤' !important;
font-size: 16px !important;
}
audio {
min-width: 200px !important;
}
@media (max-width : 480px) {
#audio_input {
width: 120% !important;
}
#audio_input button:after {
content: '' !important;
}
#txt_input_container {
flex-grow: 70% !important;
}
#audio_input_container {
flex-grow: 30% !important;
}
}
"""
# JavaScript run in the browser after each bot reply: marks the last <audio>
# element on the page for autoplay. The early return avoids a TypeError when
# the page has no <audio> element at all.
js_audio_auto_play = """
() => {
    // select last audio element
    const audio = document.getElementsByTagName('audio');
    if (audio.length === 0) {
        // nothing to play
        return;
    }
    const last_audio = audio[audio.length - 1];
    // set autoplay attribute
    last_audio.setAttribute('autoplay', true);
}
"""
def create_chat_session():
    """Create a chat session on the backend and its local audio folder.

    Sends a POST to ``/create`` on ``API_ENDPOINT`` and creates
    ``./temp_audio/<id>`` for the audio clips of that session.

    Returns:
        The ``id`` field of the JSON body returned by the backend.

    Raises:
        Exception: If the backend does not answer with HTTP 201.
    """
    r = requests.post(API_ENDPOINT + "/create")
    if r.status_code != 201:
        raise Exception("Failed to create chat session")
    new_session_id = r.json()["id"]
    # exist_ok: a folder left over from an earlier run must not abort
    # session creation with FileExistsError.
    os.makedirs(f"./temp_audio/{new_session_id}", exist_ok=True)
    return new_session_id
def create_new_or_change_session(history, id):
    """Start a fresh session, or switch to the session named by ``id``.

    Args:
        history: Current Chatbot history (forwarded to ``change_session``).
        id: Text typed into the session box. Empty or whitespace-only text
            means "create a new session"; anything else is treated as the id
            of an existing session.

    Returns:
        A ``(history, update)`` pair: the history to display and a Gradio
        update that clears and temporarily disables the session textbox.
    """
    global session_id
    global chat_history
    # Surrounding whitespace is never part of a session id; a blank box
    # (including one that only contains spaces) requests a new session.
    requested_id = (id or "").strip()
    if not requested_id:
        session_id = create_chat_session()
        history = []
    else:
        history, _ = change_session(history, requested_id)
    # Keep a separate list so the cached history is not aliased to the list
    # handed to the UI (the other handlers store a copy as well).
    chat_history = list(history)
    return history, gr.update(value="", interactive=False)
def add_text(history, text):
    """Append the user's text as a new, still unanswered chat turn.

    Returns the extended history together with a Gradio update that empties
    the textbox and disables it.
    """
    pending_turn = (text, None)
    locked_box = gr.update(value="", interactive=False)
    return history + [pending_turn], locked_box
def add_audio(history, audio):
    """Store a recorded clip, transcribe it, and append both to the history.

    Args:
        history: Current Chatbot history.
        audio: Raw front-end payload of the audio component (the event is
            wired with ``preprocess=False``); its ``'data'`` entry holds the
            base64 text after the last comma.

    Returns:
        A ``(history, update)`` pair: the history extended by the audio clip
        and its transcription, and a Gradio update for the audio component.

    Raises:
        Exception: If the ``/transcribe`` endpoint does not answer HTTP 200.
    """
    # The change event may also fire when the recorder is cleared (the event
    # chain resets its value afterwards); there is nothing to process then.
    if not audio or not audio.get('data'):
        return history, gr.update()

    audio_bytes = base64.b64decode(audio['data'].split(',')[-1].encode('utf-8'))
    # Encode into a fresh buffer: writing the MP3 back into the buffer it was
    # decoded from can leave stale bytes of the original recording behind the
    # encoded data.
    mp3_buffer = BytesIO()
    AudioSegment.from_file(BytesIO(audio_bytes)).export(mp3_buffer, format="mp3")
    mp3_bytes = mp3_buffer.getvalue()

    # Save the clip to disk so the Chatbot can reference it by path. The
    # folder may be missing when the active session was not created by this
    # process, so make sure it exists.
    audio_id = secrets.token_hex(8)
    os.makedirs(f"temp_audio/{session_id}", exist_ok=True)
    audio_path = f"temp_audio/{session_id}/audio_input_{audio_id}.mp3"
    with open(audio_path, "wb") as mp3_file:
        # Write the bytes encoded above instead of decoding and re-encoding
        # the clip a second time.
        mp3_file.write(mp3_bytes)
    history = history + [((audio_path,), None)]

    response = requests.post(
        API_ENDPOINT + "/transcribe",
        files={'audio': mp3_bytes}
    )
    if response.status_code != 200:
        raise Exception(response.text)
    text = response.json()['text']
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)
def reset_chat_session(history):
    """Reset the active session on the backend and clear the local history.

    The incoming ``history`` is discarded; an empty history is returned.

    Raises:
        Exception: If the ``/reset/<session_id>`` call does not answer
            HTTP 200 (the response text is used as the message).
    """
    global session_id, chat_history
    reset_url = API_ENDPOINT + f"/reset/{session_id}"
    reply = requests.post(reset_url)
    if reply.status_code != 200:
        raise Exception(reply.text)
    chat_history = []
    return []
def bot(history):
    """Send the latest user text to the backend and append its reply.

    The reply consists of an audio clip (saved under
    ``temp_audio/<session_id>/``) and a text message; both are appended to the
    history as bot turns, and the result is cached in ``chat_history``.

    Args:
        history: Current Chatbot history as a list of ``(user, bot)`` pairs.

    Returns:
        The extended history, or ``history`` unchanged when there is no user
        text waiting for an answer.

    Raises:
        Exception: If ``/send/text/<session_id>`` does not answer HTTP 200.
    """
    global chat_history

    # Look for the user text in the last turn, falling back to the one before.
    message = history[-1][0] if history else None
    if not isinstance(message, str) and len(history) > 1:
        message = history[-2][0]
    if not isinstance(message, str):
        # Empty history or no pending user text: do not post a non-text
        # entry to the backend.
        return history

    response = requests.post(
        API_ENDPOINT + f"/send/text/{session_id}",
        headers={'Content-type': 'application/json'},
        json={
            'message': message,
            'role': 'user'
        }
    )
    if response.status_code != 200:
        raise Exception(f"Failed to send message, {response.text}")
    response = response.json()
    text, audio = response['text'], response['audio']

    audio_bytes = base64.b64decode(audio.encode('utf-8'))
    audio_file = BytesIO(audio_bytes)
    audio_id = secrets.token_hex(8)
    # The folder may be missing when the active session was not created by
    # this process.
    os.makedirs(f"temp_audio/{session_id}", exist_ok=True)
    audio_path = f"temp_audio/{session_id}/audio_input_{audio_id}.mp3"
    AudioSegment.from_file(audio_file).export(audio_path, format="mp3")

    history = history + [(None, (audio_path,))]
    history = history + [(None, text)]
    chat_history = history.copy()
    return history
def _save_session_audio(session, audio_b64):
    """Decode a base64 clip, export it as MP3 for ``session``, return its path."""
    audio_file = BytesIO(base64.b64decode(audio_b64.encode('utf-8')))
    audio_id = secrets.token_hex(8)
    audio_path = f"temp_audio/{session}/audio_input_{audio_id}.mp3"
    AudioSegment.from_file(audio_file).export(audio_path, format="mp3")
    return audio_path


def change_session(history, id):
    """Switch to an existing backend session and rebuild its chat history.

    Fetches the stored messages of session ``id`` from the backend, writes
    their audio clips to ``temp_audio/<id>/`` and converts them into Chatbot
    turns. The incoming ``history`` is discarded.

    Args:
        history: Current Chatbot history (not used for the result).
        id: Session id typed by the user.

    Returns:
        A ``(history, update)`` pair: the rebuilt history and a Gradio update
        that clears and temporarily disables the session textbox.

    Raises:
        Exception: If ``id`` is not a plain name, if the backend does not
            answer HTTP 200, or if the returned messages cannot be converted
            (the original error is chained as the cause).
    """
    global session_id
    global chat_history

    # The id comes from a free-text box and ends up in a filesystem path, so
    # it must be a single path component.
    if not isinstance(id, str) or id in ("", ".", "..") or "/" in id or "\\" in id:
        raise Exception("Invalid session id")

    response = requests.get(
        API_ENDPOINT + f"/{id}"
    )
    if response.status_code != 200:
        raise Exception(response.text)
    response = response.json()

    history = []
    try:
        # The folder only exists for sessions created by this process.
        os.makedirs(f"temp_audio/{id}", exist_ok=True)
        for chat in response:
            if chat['role'] == 'user':
                if chat.get('audio'):
                    history.append(((_save_session_audio(id, chat['audio']),), None))
                history.append((chat['message'], None))
            elif chat['role'] == 'assistant':
                # Text-only replies are restored without an audio turn, the
                # same way user messages are handled.
                if chat.get('audio'):
                    history.append((None, (_save_session_audio(id, chat['audio']),)))
                history.append((None, chat['message']))
            else:
                raise Exception("Invalid chat role")
    except Exception as e:
        # Keep the underlying error attached for debugging.
        raise Exception(f"Response: {response}") from e

    # Switch the active session only after the history was rebuilt, so a
    # failure above leaves the previous session fully intact.
    session_id = id
    chat_history = history.copy()
    print(f"len(chat_history): {len(chat_history)}\nlen(history): {len(history)}\nlen(response): {len(response)}")
    return history, gr.update(value="", interactive=False)
def load_chat_history(history):
    """Return the cached history when it is longer than the displayed one.

    Used on page load: if the module-level ``chat_history`` holds more turns
    than ``history``, it replaces it; otherwise ``history`` is returned as is.
    """
    global chat_history
    cached_is_longer = len(chat_history) > len(history)
    return chat_history if cached_is_longer else history
def main():
    """Create a chat session and build the Gradio Blocks interface.

    The page consists of a session-id box with its button, the Chatbot, a text
    input next to a microphone recorder, and a reset button. Returns the
    ``gr.Blocks`` object; launching it is left to the caller.
    """
    global session_id, chat_history
    session_id = create_chat_session()
    chat_history = []

    with gr.Blocks(css=css) as demo:
        # --- session controls -------------------------------------------
        with gr.Row():
            # The placeholder shows the id of the active session.
            session_box = gr.Textbox(
                show_label=False,
                placeholder=session_id,
            ).style(container=False)
        with gr.Row():
            # NOTE(review): gr.Button documents `variant`, not `type` --
            # confirm whether these `type=` keywords have any effect.
            session_button = gr.Button(
                "Create new or change session", type='success', size="sm"
            ).style(container=False)

        # --- conversation -----------------------------------------------
        chat_view = gr.Chatbot([], elem_id="chatbot").style(height=750)
        # Restore the cached history whenever the page is (re)loaded.
        demo.load(load_chat_history, [chat_view], [chat_view], queue=False)

        # --- inputs -----------------------------------------------------
        with gr.Row():
            with gr.Column(scale=0.85, min_width=0, elem_id="txt_input_container"):
                text_box = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter, or record audio",
                    elem_id="txt_input"
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0, elem_id="audio_input_container"):
                recorder = gr.Audio(
                    source="microphone",
                    type="numpy",
                    show_label=False,
                    format="mp3",
                    min_width=0,
                    container=False,
                    elem_id="audio_input",
                )
        with gr.Row():
            reset_button = gr.Button(
                "Reset Chat Session", type='stop', size="sm"
            ).style(container=False)

        # --- text flow: add turn -> ask bot -> autoplay -> unlock box ---
        text_added = text_box.submit(
            add_text, [chat_view, text_box], [chat_view, text_box], queue=False
        )
        text_answered = text_added.then(bot, chat_view, chat_view)
        txt_msg = text_answered.then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        txt_msg.then(lambda: gr.update(interactive=True), None, [text_box], queue=False)

        # --- audio flow: same steps, fed with the raw recorder payload ---
        audio_added = recorder.change(
            add_audio,
            [chat_view, recorder],
            [chat_view, recorder],
            queue=False,
            preprocess=False,
            postprocess=False,
        )
        audio_answered = audio_added.then(bot, chat_view, chat_view)
        audio_msg = audio_answered.then(
            None, [], [], queue=False, _js=js_audio_auto_play
        )
        audio_msg.then(
            lambda: gr.update(interactive=True, value=None), None, [recorder], queue=False
        )

        # --- reset and session switching ----------------------------------
        reset_button.click(reset_chat_session, [chat_view], [chat_view], queue=False)

        # Re-enable the session box and show the (possibly new) session id.
        refresh_session_box = lambda: gr.update(interactive=True, placeholder=session_id)

        session_submitted = session_box.submit(
            change_session, [chat_view, session_box], [chat_view, session_box], queue=False
        )
        session_submitted.then(refresh_session_box, None, [session_box], queue=False)

        session_clicked = session_button.click(
            create_new_or_change_session,
            [chat_view, session_box],
            [chat_view, session_box],
            queue=False,
        )
        session_clicked.then(refresh_session_box, None, [session_box], queue=False)

    return demo
if __name__ == "__main__":
    # Command line: --local points the app at the local backend instead of
    # the public one.
    cli = argparse.ArgumentParser()
    cli.add_argument("--local", action="store_true", help="Use local API endpoint")
    options = cli.parse_args()
    if options.local:
        API_ENDPOINT = LOCAL_API_ENDPOINT

    # Build the UI and serve it on all network interfaces.
    demo = main()
    demo.launch(show_error=True, server_name="0.0.0.0")