Commit e9402b5 (parent: 89ed0ae) by YingxuHe

add mic button
src/content/agent.py CHANGED
@@ -1,19 +1,18 @@
-import copy
-import base64
-
+import numpy as np
 import streamlit as st
 
-from src.generation import MAX_AUDIO_LENGTH
 from src.retrieval import STANDARD_QUERIES, retrieve_relevant_docs
-from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
     MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
     AGENT_DIALOGUE_STATES,
+    reset_states,
+    update_voice_instruction_state,
     init_state_section,
     header_section,
     sidebar_fragment,
-    reset_states,
+    successful_example_section,
+    audio_attach_dialogue,
     retrive_response_with_ui
 )
 
@@ -42,103 +41,29 @@ However, the audio analysis may or may not contain relevant information to the u
 AUDIO_ANALYSIS_STATUS = "MERaLiON-AudioLLM Analysis"
 
 
-def _update_audio(audio_bytes):
-    origin_audio_array = bytes_to_array(audio_bytes)
-    truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
-    truncated_audio_bytes = array_to_bytes(truncated_audio_array)
-
-    st.session_state.ag_audio_array = origin_audio_array
-    st.session_state.ag_audio_base64 = base64.b64encode(truncated_audio_bytes).decode('utf-8')
-
-
-@st.fragment
-def successful_example_section():
-    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys() if "Paral" in name]
-
-    st.markdown(":fire: **Successful Tasks and Examples**")
-
-    sample_name = st.selectbox(
-        label="**Select Audio:**",
-        label_visibility="collapsed",
-        options=audio_sample_names,
-        format_func=lambda o: AUDIO_SAMPLES_W_INSTRUCT[o]["apperance"],
-        index=None,
-        placeholder="Select an audio sample:",
-        on_change=lambda: st.session_state.update(
-            on_select=True,
-            ag_messages=[],
-            ag_model_messages=[],
-            ag_visited_query_indices=[],
-            disprompt=True
-        ),
-        key='select')
-
-    if sample_name and st.session_state.on_select:
-        audio_bytes = open(f"audio_samples/{sample_name}.wav", "rb").read()
-        st.session_state.update(
-            on_select=False,
-            new_prompt=AUDIO_SAMPLES_W_INSTRUCT[sample_name]["instructions"][0]
-        )
-        _update_audio(audio_bytes)
-        st.rerun(scope="app")
-
-
-@st.dialog("Specify Audio")
-def audio_attach_dialogue():
-    st.markdown("**Upload**")
-
-    uploaded_file = st.file_uploader(
-        label="**Upload Audio:**",
-        label_visibility="collapsed",
-        type=['wav', 'mp3'],
-        on_change=lambda: st.session_state.update(
-            on_upload=True,
-            ag_messages=[],
-            ag_model_messages=[],
-            ag_visited_query_indices=[]
-        ),
-        key='upload'
-    )
-
-    if uploaded_file and st.session_state.on_upload:
-        audio_bytes = uploaded_file.read()
-        _update_audio(audio_bytes)
-        st.session_state.on_upload = False
-        st.rerun()
-
-    st.markdown("**Record**")
-
-    uploaded_file = st.audio_input(
-        label="**Record Audio:**",
-        label_visibility="collapsed",
-        on_change=lambda: st.session_state.update(
-            on_record=True,
-            ag_messages=[],
-            ag_model_messages=[],
-            ag_visited_query_indices=[]
-        ),
-        key='record'
-    )
-
-    if uploaded_file and st.session_state.on_record:
-        audio_bytes = uploaded_file.read()
-        _update_audio(audio_bytes)
-        st.session_state.on_record = False
-        st.rerun()
+AG_CONVERSATION_STATES = dict(
+    ag_messages=[],
+    ag_model_messages=[],
+    ag_visited_query_indices=[],
+)
 
 
 def bottom_input_section():
-    bottom_cols = st.columns([0.03, 0.03, 0.94])
+    bottom_cols = st.columns([0.03, 0.03, 0.91, 0.03])
     with bottom_cols[0]:
         st.button(
-            'Clear',
+            ':material/delete:',
            disabled=st.session_state.disprompt,
            on_click=lambda: reset_states(AGENT_DIALOGUE_STATES)
        )
 
     with bottom_cols[1]:
-        if st.button("\+ Audio", disabled=st.session_state.disprompt):
-            audio_attach_dialogue()
+        if st.button(":material/add:", disabled=st.session_state.disprompt):
+            audio_attach_dialogue(
+                audio_array_state="ag_audio_array",
+                audio_base64_state="ag_audio_base64",
+                restore_state=AG_CONVERSATION_STATES
+            )
 
     with bottom_cols[2]:
         if chat_input := st.chat_input(
@@ -148,6 +73,23 @@ def bottom_input_section():
         ):
             st.session_state.new_prompt = chat_input
 
+    with bottom_cols[3]:
+        uploaded_voice = st.audio_input(
+            label="voice_instruction",
+            label_visibility="collapsed",
+            disabled=st.session_state.disprompt,
+            on_change=lambda: st.session_state.update(
+                disprompt=True,
+                on_record_voice_instruction=True
+            ),
+            key='voice_instruction'
+        )
+
+        if uploaded_voice and st.session_state.on_record_voice_instruction:
+            voice_bytes = uploaded_voice.read()
+            update_voice_instruction_state(voice_bytes)
+            st.session_state.on_record_voice_instruction = False
+
 
 def _prepare_final_prompt_with_ui(one_time_prompt):
     if st.session_state.ag_audio_array.shape[0] == 0:
@@ -216,9 +158,7 @@ def conversation_section():
     st.audio(st.session_state.ag_audio_array, format="audio/wav", sample_rate=16000)
 
     for message in st.session_state.ag_messages:
-        message_name = "assistant" if "assistant" in message["role"] else message["role"]
-
-        with chat_message_container.chat_message(name=message_name):
+        with chat_message_container.chat_message(name=message["role"]):
             if message.get("error"):
                 st.error(message["error"])
             for warning_msg in message.get("warnings", []):
@@ -238,38 +178,73 @@ def conversation_section():
     with st._bottom:
         bottom_input_section()
 
-    if one_time_prompt := st.session_state.new_prompt:
-        st.session_state.update(new_prompt="")
+    if (not st.session_state.new_prompt) and (not st.session_state.new_vi_base64):
+        return
+
+    one_time_prompt = st.session_state.new_prompt
+    one_time_vi_array = st.session_state.new_vi_array
+    one_time_vi_base64 = st.session_state.new_vi_base64
+
+    st.session_state.update(
+        new_prompt="",
+        new_vi_array=np.array([]),
+        new_vi_base64="",
+    )
 
-        with chat_message_container.chat_message("user"):
+    with chat_message_container.chat_message("user"):
+        if one_time_vi_base64:
+            with st.spinner("Transcribing..."):
+                error_msg, warnings, one_time_prompt = retrive_response_with_ui(
+                    model_name=MODEL_NAMES["audiollm"]["vllm_name"],
+                    text_input="Write out the dialogue as text.",
+                    array_audio_input=one_time_vi_array,
+                    base64_audio_input=one_time_vi_base64,
+                    stream=False,
+                    normalise_response=True
+                )
+        else:
+            error_msg, warnings = "", []
         st.write(one_time_prompt)
-        st.session_state.ag_messages.append({"role": "user", "content": one_time_prompt})
-
-        with chat_message_container.chat_message("assistant"):
-            assistant_message = {"role": "assistant", "process": []}
-            st.session_state.ag_messages.append(assistant_message)
 
-            final_prompt = _prepare_final_prompt_with_ui(one_time_prompt)
-
-            error_msg, warnings, response = retrive_response_with_ui(
-                model_name=MODEL_NAMES["llm"]["vllm_name"],
-                text_input=final_prompt,
-                array_audio_input=st.session_state.ag_audio_array,
-                base64_audio_input="",
-                prefix=f"**{MODEL_NAMES['llm']['ui_name']}**: ",
-                stream=True,
-                history=st.session_state.ag_model_messages,
-                show_warning=False
-            )
-
-            assistant_message.update({"error": error_msg, "warnings": warnings, "content": response})
-            st.session_state.ag_model_messages.extend([
-                {"role": "user", "content": final_prompt},
-                {"role": "assistant", "content": response}
-            ])
+    st.session_state.ag_messages.append({
+        "role": "user",
+        "error": error_msg,
+        "warnings": warnings,
+        "content": one_time_prompt
+    })
+
+    with chat_message_container.chat_message("assistant"):
+        assistant_message = {"role": "assistant", "process": []}
+        st.session_state.ag_messages.append(assistant_message)
+
+        final_prompt = _prepare_final_prompt_with_ui(one_time_prompt)
+
+        llm_response_prefix = f"**{MODEL_NAMES['llm']['ui_name']}**: "
+        error_msg, warnings, response = retrive_response_with_ui(
+            model_name=MODEL_NAMES["llm"]["vllm_name"],
+            text_input=final_prompt,
+            array_audio_input=st.session_state.ag_audio_array,
+            base64_audio_input="",
+            prefix=llm_response_prefix,
+            stream=True,
+            history=st.session_state.ag_model_messages,
+            show_warning=False
+        )
+
+        assistant_message.update({
+            "error": error_msg,
+            "warnings": warnings,
+            "content": response
+        })
+
+        pure_response = response.replace(llm_response_prefix, "")
+        st.session_state.ag_model_messages.extend([
+            {"role": "user", "content": final_prompt},
+            {"role": "assistant", "content": pure_response}
        ])
 
-        st.session_state.disprompt=False
-        st.rerun(scope="app")
+    st.session_state.disprompt=False
+    st.rerun(scope="app")
 
 
 def agent_page():
@@ -286,5 +261,12 @@ def agent_page():
     with st.sidebar:
         sidebar_fragment()
 
-    successful_example_section()
+    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys() if "Paral" in name]
+
+    successful_example_section(
+        audio_sample_names,
+        audio_array_state="ag_audio_array",
+        audio_base64_state="ag_audio_base64",
+        restore_state=AG_CONVERSATION_STATES
+    )
     conversation_section()
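
The rewritten conversation_section above replaces the old walrus guard (if one_time_prompt := st.session_state.new_prompt:) with a consume-once pattern, so that either a typed prompt or a recorded voice instruction can start a turn. A minimal standalone sketch of that pattern, with illustrative names rather than repo code, assuming a Streamlit release that ships st.audio_input:

# consume_once_demo.py - minimal sketch of the consume-once pattern
# (illustrative names, not repo code; run with `streamlit run`).
import streamlit as st

st.session_state.setdefault("new_prompt", "")
st.session_state.setdefault("new_voice_bytes", b"")

if prompt := st.chat_input("Instruction..."):
    st.session_state.new_prompt = prompt

if voice := st.audio_input("Voice instruction"):
    # The real code also gates this behind an on_record_voice_instruction
    # flag so the same recording is not re-read on every rerun.
    st.session_state.new_voice_bytes = voice.read()

if (not st.session_state.new_prompt) and (not st.session_state.new_voice_bytes):
    st.stop()  # nothing pending on this script run

# Take the pending input, then clear it *before* rendering, so a later
# st.rerun() cannot replay the same turn.
one_time_prompt = st.session_state.new_prompt
one_time_voice = st.session_state.new_voice_bytes
st.session_state.update(new_prompt="", new_voice_bytes=b"")

with st.chat_message("user"):
    if one_time_voice:
        st.write(f"(voice instruction of {len(one_time_voice)} bytes; transcription would go here)")
    else:
        st.write(one_time_prompt)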
src/content/common.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import copy
+import base64
 import itertools
 from collections import OrderedDict
 from typing import List, Optional
@@ -8,9 +9,15 @@ import numpy as np
 import streamlit as st
 
 from src.tunnel import start_server
-from src.generation import FIXED_GENERATION_CONFIG, load_model, retrive_response
 from src.retrieval import load_retriever
 from src.logger import load_logger
+from src.utils import array_to_bytes, bytes_to_array, postprocess_voice_transcription
+from src.generation import (
+    FIXED_GENERATION_CONFIG,
+    MAX_AUDIO_LENGTH,
+    load_model,
+    retrive_response
+)
 
 
 PLAYGROUND_DIALOGUE_STATES = dict(
@@ -40,10 +47,13 @@ AGENT_DIALOGUE_STATES = dict(
 COMMON_DIALOGUE_STATES = dict(
     disprompt=False,
     new_prompt="",
+    new_vi_array=np.array([]),
+    new_vi_base64="",
     on_select=False,
     on_upload=False,
     on_record=False,
-    on_select_quick_action=False
+    on_select_quick_action=False,
+    on_record_voice_instruction=False
 )
 
 
@@ -319,6 +329,26 @@ AUDIO_SAMPLES_W_INSTRUCT = {
 exec(os.getenv('APP_CONFIGS'))
 
 
+def reset_states(*state_dicts):
+    for states in state_dicts:
+        st.session_state.update(copy.deepcopy(states))
+    st.session_state.update(copy.deepcopy(COMMON_DIALOGUE_STATES))
+
+
+def process_audio_bytes(audio_bytes):
+    origin_audio_array = bytes_to_array(audio_bytes)
+    truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
+    truncated_audio_bytes = array_to_bytes(truncated_audio_array)
+    audio_base64 = base64.b64encode(truncated_audio_bytes).decode('utf-8')
+
+    return origin_audio_array, audio_base64
+
+
+def update_voice_instruction_state(voice_bytes):
+    st.session_state.new_vi_array, st.session_state.new_vi_base64 = \
+        process_audio_bytes(voice_bytes)
+
+
 def init_state_section():
     st.set_page_config(page_title='MERaLiON-AudioLLM', page_icon = "🔥", layout='wide')
 
@@ -397,10 +427,75 @@ def sidebar_fragment():
     st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
 
 
-def reset_states(*state_dicts):
-    for states in state_dicts:
-        st.session_state.update(copy.deepcopy(states))
-    st.session_state.update(copy.deepcopy(COMMON_DIALOGUE_STATES))
+@st.fragment
+def successful_example_section(audio_sample_names, audio_array_state, audio_base64_state, restore_state={}):
+    st.markdown(":fire: **Successful Tasks and Examples**")
+
+    sample_name = st.selectbox(
+        label="**Select Audio:**",
+        label_visibility="collapsed",
+        options=audio_sample_names,
+        format_func=lambda o: AUDIO_SAMPLES_W_INSTRUCT[o]["apperance"],
+        index=None,
+        placeholder="Select an audio sample:",
+        on_change=lambda: st.session_state.update(
+            on_select=True,
+            disprompt=True,
+            **copy.deepcopy(restore_state)
+        ),
+        key='select')
+
+    if sample_name and st.session_state.on_select:
+        audio_bytes = open(f"audio_samples/{sample_name}.wav", "rb").read()
+        st.session_state.update(
+            on_select=False,
+            new_prompt=AUDIO_SAMPLES_W_INSTRUCT[sample_name]["instructions"][0]
+        )
+        st.session_state[audio_array_state], st.session_state[audio_base64_state] = \
+            process_audio_bytes(audio_bytes)
+        st.rerun(scope="app")
+
+
+@st.dialog("Specify audio context for analysis")
+def audio_attach_dialogue(audio_array_state, audio_base64_state, restore_state={}):
+    st.markdown("**Upload**")
+
+    uploaded_file = st.file_uploader(
+        label="**Upload Audio:**",
+        label_visibility="collapsed",
+        type=['wav', 'mp3'],
+        on_change=lambda: st.session_state.update(
+            on_upload=True,
+            **copy.deepcopy(restore_state)
+        ),
+        key='upload'
+    )
+
+    if uploaded_file and st.session_state.on_upload:
+        audio_bytes = uploaded_file.read()
+        st.session_state[audio_array_state], st.session_state[audio_base64_state] = \
+            process_audio_bytes(audio_bytes)
+        st.session_state.on_upload = False
+        st.rerun()
+
+    st.markdown("**Record**")
+
+    uploaded_file = st.audio_input(
+        label="**Record Audio:**",
+        label_visibility="collapsed",
+        on_change=lambda: st.session_state.update(
+            on_record=True,
+            **copy.deepcopy(restore_state)
+        ),
+        key='record'
+    )
+
+    if uploaded_file and st.session_state.on_record:
+        audio_bytes = uploaded_file.read()
+        st.session_state[audio_array_state], st.session_state[audio_base64_state] = \
+            process_audio_bytes(audio_bytes)
+        st.session_state.on_record = False
+        st.rerun()
 
 
 def retrive_response_with_ui(
@@ -410,6 +505,7 @@ def retrive_response_with_ui(
     base64_audio_input: str,
     prefix: str = "",
     stream: bool = True,
+    normalise_response: bool = False,
     history: Optional[List] = None,
     show_warning: bool = True,
     **kwargs
@@ -455,7 +551,10 @@ def retrive_response_with_ui(
         response = st.write_stream(response_obj)
     else:
         response = response_obj.choices[0].message.content
-        st.write(prefix+response)
+        if normalise_response:
+            response = postprocess_voice_transcription(response)
+        response = prefix + response
+        st.write(response)
 
     st.session_state.logger.register_query(
         session_id=st.session_state.session_id,
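
For context on the numbers in process_audio_bytes: the full decoded array is kept for analysis, but only the first MAX_AUDIO_LENGTH seconds are base64-encoded, since requests carry the clip inline. A standalone sketch of the same truncate-and-encode step, assuming MAX_AUDIO_LENGTH is a duration in seconds (the value below is hypothetical; the real constant lives in src.generation) and 16 kHz mono WAV input, with scipy's wavfile module standing in for the repo's bytes_to_array/array_to_bytes helpers:

import io
import base64

import numpy as np
from scipy.io.wavfile import read, write

MAX_AUDIO_LENGTH = 30  # hypothetical; the real constant lives in src.generation


def process_audio_bytes(audio_bytes):
    # Decode incoming WAV bytes into a numpy array (16 kHz mono assumed).
    _, origin_audio_array = read(io.BytesIO(audio_bytes))
    # Base64-encode only the first MAX_AUDIO_LENGTH seconds (16,000 samples/s).
    truncated = origin_audio_array[: MAX_AUDIO_LENGTH * 16000]
    byte_io = io.BytesIO()
    write(byte_io, 16000, truncated)
    return origin_audio_array, base64.b64encode(byte_io.getvalue()).decode('utf-8')


# A 60-second tone keeps its full array but yields a 30-second payload.
tone = (32767 * np.sin(2 * np.pi * 440 * np.arange(16000 * 60) / 16000)).astype(np.int16)
buf = io.BytesIO()
write(buf, 16000, tone)
full_array, clip_b64 = process_audio_bytes(buf.getvalue())
assert full_array.shape[0] == 16000 * 60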
src/content/playground.py CHANGED
@@ -1,18 +1,17 @@
-import copy
-import base64
-
+import numpy as np
 import streamlit as st
 
-from src.generation import MAX_AUDIO_LENGTH
-from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
     MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
     PLAYGROUND_DIALOGUE_STATES,
+    reset_states,
+    update_voice_instruction_state,
     init_state_section,
     header_section,
     sidebar_fragment,
-    reset_states,
+    successful_example_section,
+    audio_attach_dialogue,
     retrive_response_with_ui
 )
 
@@ -31,86 +30,22 @@ QUICK_ACTIONS = [
 ]
 
 
-def _update_audio(audio_bytes):
-    origin_audio_array = bytes_to_array(audio_bytes)
-    truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
-    truncated_audio_bytes = array_to_bytes(truncated_audio_array)
-
-    st.session_state.pg_audio_array = origin_audio_array
-    st.session_state.pg_audio_base64 = base64.b64encode(truncated_audio_bytes).decode('utf-8')
-
-
-@st.fragment
-def successful_example_section():
-    audio_sample_names = [audio_sample_name for audio_sample_name in AUDIO_SAMPLES_W_INSTRUCT.keys()]
-
-    st.markdown(":fire: **Successful Tasks and Examples**")
-
-    sample_name = st.selectbox(
-        label="**Select Audio:**",
-        label_visibility="collapsed",
-        options=audio_sample_names,
-        format_func=lambda o: AUDIO_SAMPLES_W_INSTRUCT[o]["apperance"],
-        index=None,
-        placeholder="Select an audio sample:",
-        on_change=lambda: st.session_state.update(
-            on_select=True,
-            pg_messages=[],
-            disprompt=True
-        ),
-        key='select')
-
-    if sample_name and st.session_state.on_select:
-        audio_bytes = open(f"audio_samples/{sample_name}.wav", "rb").read()
-        st.session_state.update(
-            on_select=False,
-            new_prompt=AUDIO_SAMPLES_W_INSTRUCT[sample_name]["instructions"][0]
-        )
-        _update_audio(audio_bytes)
-        st.rerun(scope="app")
-
-
-@st.dialog("Specify Audio")
-def audio_attach_dialogue():
-    st.markdown("**Upload**")
-
-    uploaded_file = st.file_uploader(
-        label="**Upload Audio:**",
-        label_visibility="collapsed",
-        type=['wav', 'mp3'],
-        on_change=lambda: st.session_state.update(on_upload=True, pg_messages=[]),
-        key='upload'
-    )
-
-    if uploaded_file and st.session_state.on_upload:
-        audio_bytes = uploaded_file.read()
-        _update_audio(audio_bytes)
-        st.session_state.on_upload = False
-        st.rerun()
-
-    st.markdown("**Record**")
-
-    uploaded_file = st.audio_input(
-        label="**Record Audio:**",
-        label_visibility="collapsed",
-        on_change=lambda: st.session_state.update(on_record=True, pg_messages=[]),
-        key='record'
-    )
-
-    if uploaded_file and st.session_state.on_record:
-        audio_bytes = uploaded_file.read()
-        _update_audio(audio_bytes)
-        st.session_state.on_record = False
-        st.rerun()
+PG_CONVERSATION_STATES = dict(
+    pg_messages=[],
+)
 
 
 @st.fragment
 def select_model_variants_fradment():
-    display_mapper = {value["vllm_name"]: value["ui_name"] for value in MODEL_NAMES.values()}
+    display_mapper = {
+        value["vllm_name"]: value["ui_name"]
+        for key, value in MODEL_NAMES.items()
+        if "audiollm" in key
+    }
 
     st.selectbox(
         label=":fire: Explore more MERaLiON-AudioLLM variants!",
-        options=[value["vllm_name"] for value in MODEL_NAMES.values()],
+        options=list(display_mapper.keys()),
         index=0,
         format_func=lambda o: display_mapper[o],
         key="pg_model_name",
@@ -122,27 +57,52 @@ def select_model_variants_fradment():
 def bottom_input_section():
     select_model_variants_fradment()
 
-    bottom_cols = st.columns([0.03, 0.03, 0.94])
+    bottom_cols = st.columns([0.03, 0.03, 0.91, 0.03])
     with bottom_cols[0]:
         st.button(
-            'Clear',
+            ':material/delete:',
             disabled=st.session_state.disprompt,
             on_click=lambda: reset_states(PLAYGROUND_DIALOGUE_STATES)
         )
 
     with bottom_cols[1]:
-        if st.button("\+ Audio", disabled=st.session_state.disprompt):
-            audio_attach_dialogue()
+        if st.button(":material/add:", disabled=st.session_state.disprompt):
+            audio_attach_dialogue(
+                audio_array_state="pg_audio_array",
+                audio_base64_state="pg_audio_base64",
+                restore_state=PG_CONVERSATION_STATES
+            )
 
     with bottom_cols[2]:
         if chat_input := st.chat_input(
             placeholder="Instruction...",
             disabled=st.session_state.disprompt,
-            on_submit=lambda: st.session_state.update(disprompt=True, pg_messages=[])
+            on_submit=lambda: st.session_state.update(
+                disprompt=True,
+                **PG_CONVERSATION_STATES
+            )
         ):
             st.session_state.new_prompt = chat_input
 
+    with bottom_cols[3]:
+        uploaded_voice = st.audio_input(
+            label="voice_instruction",
+            label_visibility="collapsed",
+            disabled=st.session_state.disprompt,
+            on_change=lambda: st.session_state.update(
+                disprompt=True,
+                on_record_voice_instruction=True,
+                **PG_CONVERSATION_STATES
+            ),
+            key='voice_instruction'
+        )
+
+        if uploaded_voice and st.session_state.on_record_voice_instruction:
+            voice_bytes = uploaded_voice.read()
+            update_voice_instruction_state(voice_bytes)
+            st.session_state.on_record_voice_instruction = False
+
+
 @st.fragment
 def quick_actions_fragment():
     action_cols_spec = [_["width"] for _ in QUICK_ACTIONS]
@@ -184,32 +144,61 @@ def conversation_section():
     with st._bottom:
         bottom_input_section()
 
-    if one_time_prompt := st.session_state.new_prompt:
-        st.session_state.update(new_prompt="", pg_messages=[])
-
-        with st.chat_message("user"):
-            st.write(one_time_prompt)
-        st.session_state.pg_messages.append({"role": "user", "content": one_time_prompt})
+    if (not st.session_state.new_prompt) and (not st.session_state.new_vi_base64):
+        return
 
-        with st.chat_message("assistant"):
-            with st.spinner("Thinking..."):
-                error_msg, warnings, response = retrive_response_with_ui(
-                    model_name=st.session_state.pg_model_name,
-                    text_input=one_time_prompt,
-                    array_audio_input=st.session_state.pg_audio_array,
-                    base64_audio_input=st.session_state.pg_audio_base64,
-                    stream=True
+    one_time_prompt = st.session_state.new_prompt
+    one_time_vi_array = st.session_state.new_vi_array
+    one_time_vi_base64 = st.session_state.new_vi_base64
+
+    st.session_state.update(
+        new_prompt="",
+        new_vi_array=np.array([]),
+        new_vi_base64="",
+        pg_messages=[]
+    )
+
+    with st.chat_message("user"):
+        if one_time_vi_base64:
+            with st.spinner("Transcribing..."):
+                error_msg, warnings, one_time_prompt = retrive_response_with_ui(
+                    model_name=MODEL_NAMES["audiollm"]["vllm_name"],
+                    text_input="Write out the dialogue as text.",
+                    array_audio_input=one_time_vi_array,
+                    base64_audio_input=one_time_vi_base64,
+                    stream=False,
+                    normalise_response=True
                 )
+        else:
+            error_msg, warnings = "", []
+        st.write(one_time_prompt)
 
-        st.session_state.pg_messages.append({
-            "role": "assistant",
-            "error": error_msg,
-            "warnings": warnings,
-            "content": response
-        })
+    st.session_state.pg_messages.append({
+        "role": "user",
+        "error": error_msg,
+        "warnings": warnings,
+        "content": one_time_prompt
+    })
+
+    with st.chat_message("assistant"):
+        with st.spinner("Thinking..."):
+            error_msg, warnings, response = retrive_response_with_ui(
+                model_name=st.session_state.pg_model_name,
+                text_input=one_time_prompt,
+                array_audio_input=st.session_state.pg_audio_array,
+                base64_audio_input=st.session_state.pg_audio_base64,
+                stream=True
+            )
 
-        st.session_state.disprompt=False
-        st.rerun(scope="app")
+    st.session_state.pg_messages.append({
+        "role": "assistant",
+        "error": error_msg,
+        "warnings": warnings,
+        "content": response
+    })
+
+    st.session_state.disprompt=False
+    st.rerun(scope="app")
 
 
 def playground_page():
@@ -223,14 +212,18 @@ def playground_page():
         <strong>Spoken Question Answering</strong>,
         <strong>Spoken Dialogue Summarization</strong>,
         <strong>Speech Instruction</strong>, and
-        <strong>Paralinguistics</strong> tasks.
-        This playground currently only support <strong>single-round</strong> conversation.
-        """,
-        concise_description=" This playground currently only support <strong>single-round</strong> conversation."
+        <strong>Paralinguistics</strong> tasks.""",
+        concise_description=""
     )
 
     with st.sidebar:
         sidebar_fragment()
 
-    successful_example_section()
+    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys()]
+    successful_example_section(
+        audio_sample_names,
+        audio_array_state="pg_audio_array",
+        audio_base64_state="pg_audio_base64",
+        restore_state=PG_CONVERSATION_STATES
+    )
     conversation_section()
src/content/voice_chat.py CHANGED
@@ -1,24 +1,20 @@
-import copy
-import base64
-
 import numpy as np
 import streamlit as st
 
 from src.generation import (
-    MAX_AUDIO_LENGTH,
     prepare_multimodal_content,
     change_multimodal_content
 )
 from src.content.common import (
     MODEL_NAMES,
     VOICE_CHAT_DIALOGUE_STATES,
+    reset_states,
+    process_audio_bytes,
     init_state_section,
     header_section,
     sidebar_fragment,
-    reset_states,
     retrive_response_with_ui
 )
-from src.utils import bytes_to_array, array_to_bytes
 
 
 # TODO: change this.
@@ -26,20 +22,11 @@ DEFAULT_PROMPT = "Based on the information in this user’s voice, please reply
 MAX_VC_ROUNDS = 5
 
 
-def _update_audio(audio_bytes):
-    origin_audio_array = bytes_to_array(audio_bytes)
-    truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
-    truncated_audio_bytes = array_to_bytes(truncated_audio_array)
-
-    st.session_state.vc_audio_array = origin_audio_array
-    st.session_state.vc_audio_base64 = base64.b64encode(truncated_audio_bytes).decode('utf-8')
-
-
 def bottom_input_section():
     bottom_cols = st.columns([0.03, 0.97])
     with bottom_cols[0]:
         st.button(
-            'Clear',
+            ':material/delete:',
             disabled=st.session_state.disprompt,
             on_click=lambda: reset_states(VOICE_CHAT_DIALOGUE_STATES)
         )
@@ -48,6 +35,7 @@ def bottom_input_section():
     uploaded_file = st.audio_input(
         label="record audio",
         label_visibility="collapsed",
+        disabled=st.session_state.disprompt,
        on_change=lambda: st.session_state.update(
            on_record=True,
            disprompt=True
@@ -57,7 +45,8 @@
 
     if uploaded_file and st.session_state.on_record:
         audio_bytes = uploaded_file.read()
-        _update_audio(audio_bytes)
+        st.session_state.vc_audio_array, st.session_state.vc_audio_base64 = \
+            process_audio_bytes(audio_bytes)
         st.session_state.update(
             on_record=False,
         )
@@ -69,6 +58,7 @@ def system_prompt_fragment():
     st.text_area(
         label="Insert system instructions or background knowledge here.",
         label_visibility="collapsed",
+        disabled=st.session_state.disprompt,
         max_chars=5000,
         key="system_prompt",
         value=DEFAULT_PROMPT,
@@ -151,9 +141,9 @@ def voice_chat_page():
     init_state_section()
     header_section(
         component_name="Voice Chat",
-        description=""" It currently only support up to <strong>5 rounds</strong> of conversations.
+        description=""" Currently support up to <strong>5 rounds</strong> of conversations.
         Feel free to talk about anything.""",
-        concise_description=" It currently only support up to <strong>5 rounds</strong> of conversations.",
+        concise_description=" Currently support up to <strong>5 rounds</strong> of conversations.",
         icon="🗣️"
     )
 
src/utils.py CHANGED
@@ -1,4 +1,5 @@
 import io
+import re
 from datetime import datetime
 from scipy.io.wavfile import write
 
@@ -21,4 +22,10 @@ def array_to_bytes(audio_array):
     bytes_wav = bytes()
     byte_io = io.BytesIO(bytes_wav)
     write(byte_io, 16000, audio_array)
-    return byte_io.read()
+    return byte_io.read()
+
+
+def postprocess_voice_transcription(text):
+    text = re.sub("<.*>:?|\(.*\)|\[.*\]", "", text)
+    text = re.sub("\s+", " ", text).strip()
+    return text
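
A quick illustration of the new postprocess_voice_transcription helper on a hypothetical transcript: angle-bracket speaker tags (with an optional trailing colon) and parenthesised or square-bracketed annotations are stripped, then runs of whitespace collapse to single spaces:

import re


def postprocess_voice_transcription(text):
    # Drop <tag>, <tag>:, (annotation), and [annotation] spans.
    text = re.sub("<.*>:?|\(.*\)|\[.*\]", "", text)
    # Collapse the leftover whitespace.
    text = re.sub("\s+", " ", text).strip()
    return text


print(postprocess_voice_transcription(
    "<Speaker1>: Please summarise (pause) the [inaudible] meeting notes."
))
# -> Please summarise the meeting notes.

Since each .* is greedy, a single substitution spans from the first < to the last > on a line, which is fine for one leading speaker tag but would also swallow any text between two tags on the same line.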
style/app_style.css CHANGED
@@ -88,15 +88,15 @@ div[data-testid="stBottomBlockContainer"] div[data-testid="stHorizontalBlock"]:h
 }
 
 div[data-testid="stBottomBlockContainer"] div[data-testid="stColumn"]:has( div[data-testid="stButton"]):first-of-type {
-    width: 61px;
-    min-width: 61px;
-    flex: 0 0 61px;
+    width: 42px;
+    min-width: 42px;
+    flex: 0 0 42px;
 }
 
 div[data-testid="stBottomBlockContainer"] div[data-testid="stColumn"]:has( div[data-testid="stButton"]):nth-of-type(2) {
-    width: 76px;
-    min-width: 76px;
-    flex: 0 0 76px;
+    width: 42px;
+    min-width: 42px;
+    flex: 0 0 42px;
 }
 
 div[data-testid="stBottomBlockContainer"] div[data-testid="stColumn"]:has( div[data-testid="stChatInput"]) {
@@ -113,4 +113,38 @@ div[data-testid="stBottomBlockContainer"] div[data-testid="stColumn"]:has( div[d
 
 div[data-testid="stBottomBlockContainer"] div[data-testid="stAudioInput"]>div {
     max-height: 40px;
+}
+
+/* Mic Button */
+
+div[data-testid="stBottomBlockContainer"]:has( div[data-testid="stChatInput"]) div[data-testid="stAudioInput"]>div {
+    display: block;
+    padding: 0;
+    margin: auto;
+}
+
+div[data-testid="stBottomBlockContainer"]:has( div[data-testid="stChatInput"]) div[data-testid="stAudioInput"]>div>div:last-of-type {
+    display:none;
+}
+
+div[data-testid="stBottomBlockContainer"]:has( div[data-testid="stChatInput"]) div[data-testid="stAudioInput"]>div>div:nth-of-type(2) {
+    margin:auto;
+}
+
+div[data-testid="stBottomBlockContainer"]:has( div[data-testid="stChatInput"]) div[data-testid="stAudioInput"]>div>div:nth-of-type(2)>span:last-of-type {
+    display:none;
+}
+
+div[data-testid="stBottomBlockContainer"]:has( div[data-testid="stChatInput"]) div[data-testid="stAudioInput"]>div>div:nth-of-type(2)>span:only-of-type {
+    display:block;
+}
+
+div[data-testid="stBottomBlockContainer"]:has( div[data-testid="stChatInput"]) div[data-testid="stAudioInput"]>div>span {
+    display:none;
+}
+
+div[data-testid="stBottomBlockContainer"]:has( div[data-testid="stChatInput"]) div[data-testid="stColumn"]:has( div[data-testid="stAudioInput"]) {
+    width: 24px;
+    min-width: 24px;
+    flex: 0 0 24px;
 }