apply new backend
- .gitignore +2 -0
- src/content/agent.py +13 -2
- src/content/common.py +97 -88
- src/exceptions.py +0 -4
- src/generation.py +0 -140
- src/retrieval.py +0 -75
- src/tunnel.py +0 -72
.gitignore
ADDED
@@ -0,0 +1,2 @@
+.venv/
+__pycache__/
src/content/agent.py
CHANGED
@@ -1,7 +1,10 @@
+import os
+import requests
+
 import numpy as np
 import streamlit as st
 
 from src.retrieval import STANDARD_QUERIES
 from src.content.common import (
     MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
@@ -17,6 +20,9 @@ from src.content.common import (
 )
 
 
+API_BASE_URL = os.getenv('API_BASE_URL')
+
+
 LLM_NO_AUDIO_PROMPT_TEMPLATE = """{user_question}"""
 
 
@@ -96,7 +102,12 @@ def _prepare_final_prompt_with_ui(one_time_prompt):
         return LLM_NO_AUDIO_PROMPT_TEMPLATE.format(user_question=one_time_prompt)
 
     with st.spinner("Searching appropriate querys..."):
+        response = requests.get(
+            f"{API_BASE_URL}retrieve_relevant_docs",
+            params={"user_question": one_time_prompt}
+        )
+        relevant_query_indices = response.json()
+
         if len(st.session_state.ag_messages) <= 2:
             relevant_query_indices.append(0)
 
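The retrieval step above now depends on a separate backend service answering GET {API_BASE_URL}retrieve_relevant_docs with a JSON list of indices into STANDARD_QUERIES. That service is not part of this Space, and the committed call assumes the request always succeeds. A minimal client-side sketch with a timeout and a fallback, reusing the same endpoint and parameter name (the helper name and default URL are illustrative only):

# Hypothetical helper, not part of this commit: same endpoint and params as
# agent.py uses above, plus a timeout and a fallback when the backend is down.
import os
import requests

API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000/")  # assumed default for local testing


def fetch_relevant_query_indices(user_question: str) -> list:
    try:
        response = requests.get(
            f"{API_BASE_URL}retrieve_relevant_docs",
            params={"user_question": user_question},
            timeout=10,
        )
        response.raise_for_status()
        return response.json()  # expected: a JSON list of indices into STANDARD_QUERIES
    except requests.exceptions.RequestException:
        return []  # treat an unreachable backend as "no standard query matched"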
src/content/common.py
CHANGED
@@ -1,6 +1,8 @@
 import os
+import re
 import copy
 import base64
+import requests
 import itertools
 from collections import OrderedDict
 from typing import List, Optional
@@ -8,17 +10,11 @@ from typing import List, Optional
 import numpy as np
 import streamlit as st
 
-from src.tunnel import start_server
-from src.retrieval import load_retriever
 from src.logger import load_logger
 from src.utils import array_to_bytes, bytes_to_array, postprocess_voice_transcription
-from src.generation import (
-    FIXED_GENERATION_CONFIG,
-    MAX_AUDIO_LENGTH,
-    load_model,
-    retrive_response
-)
+from src.generation import FIXED_GENERATION_CONFIG, MAX_AUDIO_LENGTH
 
+API_BASE_URL = os.getenv('API_BASE_URL')
 
 PLAYGROUND_DIALOGUE_STATES = dict(
     pg_audio_base64='',
@@ -65,46 +61,26 @@ DEFAULT_DIALOGUE_STATE_DICTS = [
 ]
 
 
-MODEL_NAMES = OrderedDict({
+MODEL_NAMES = OrderedDict({
+    "llm": {
+        "vllm_name": "MERaLiON-Gemma",
+        "model_name": "MERaLiON-Gemma",
+        "ui_name": "MERaLiON-Gemma"
+    },
+    "audiollm": {
+        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "ui_name": "MERaLiON-AudioLLM"
+    },
+    "audiollm-it": {
+        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION-it",
+        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION-it",
+        "ui_name": "MERaLiON-AudioLLM-Instruction-Tuning"
+    }
+})
 
 
 AUDIO_SAMPLES_W_INSTRUCT = {
-    "female_pilot#1": {
-        "apperance": "Female Pilot Interview: Transcription",
-        "instructions": [
-            "Please transcribe the speech"
-        ]
-    },
-    "female_pilot#2": {
-        "apperance": "Female Pilot Interview: Aircraft name",
-        "instructions": [
-            "What does 大力士 mean in the conversation"
-        ]
-    },
-    "female_pilot#3": {
-        "apperance": "Female Pilot Interview: Air Force Personnel Count",
-        "instructions": [
-            "How many air force personnel are there?"
-        ]
-    },
-    "female_pilot#4": {
-        "apperance": "Female Pilot Interview: Air Force Personnel Name",
-        "instructions": [
-            "Can you tell me the names of the two pilots?"
-        ]
-    },
-    "female_pilot#5": {
-        "apperance": "Female Pilot Interview: Pilot Seat Restriction",
-        "instructions": [
-            "What is the concern of having a big butt for pilot?"
-        ]
-    },
-    "female_pilot#6": {
-        "apperance": "Female Pilot Interview: Conversation Mood",
-        "instructions": [
-            "What is the mood of the conversation?"
-        ]
-    },
     "7_ASR_IMDA_PART3_30_ASR_v2_2269": {
         "apperance": "7. Automatic Speech Recognition task: conversation in Singapore accent",
         "instructions": [
@@ -358,13 +334,40 @@ AUDIO_SAMPLES_W_INSTRUCT = {
         "instructions": [
             "Please follow the instruction in the speech."
         ]
+    },
+    "female_pilot#1": {
+        "apperance": "Female Pilot Interview: Transcription",
+        "instructions": [
+            "Please transcribe the speech"
+        ]
+    },
+    "female_pilot#2": {
+        "apperance": "Female Pilot Interview: Aircraft name",
+        "instructions": [
+            "What does 大力士 mean in the conversation"
+        ]
+    },
+    "female_pilot#3": {
+        "apperance": "Female Pilot Interview: Air Force Personnel Count",
+        "instructions": [
+            "How many air force personnel are there?"
+        ]
+    },
+    "female_pilot#4": {
+        "apperance": "Female Pilot Interview: Air Force Personnel Name",
+        "instructions": [
+            "Can you tell me the names of the two pilots?"
+        ]
+    },
+    "female_pilot#5": {
+        "apperance": "Female Pilot Interview: Conversation Mood",
+        "instructions": [
+            "What is the mood of the conversation?"
+        ]
     }
 }
 
 
-exec(os.getenv('APP_CONFIGS'))
-
-
 def reset_states(*state_dicts):
     for states in state_dicts:
         st.session_state.update(copy.deepcopy(states))
@@ -403,14 +406,6 @@ def init_state_section():
         st.session_state.logger = load_logger()
         st.session_state.session_id = st.session_state.logger.register_session()
 
-    if "server" not in st.session_state:
-        st.session_state.server = start_server()
-
-    if "client_mapper" not in st.session_state:
-        st.session_state.client_mapper = load_model()
-
-    if "retriever" not in st.session_state:
-        st.session_state.retriever = load_retriever()
 
     for key, value in FIXED_GENERATION_CONFIG.items():
         if key not in st.session_state:
@@ -551,54 +546,68 @@ def retrive_response_with_ui(
     if history is None:
         history = []
 
-        base64_audio_input=base64_audio_input,
-        history=history,
-        **generation_params,
-        **kwargs
-    )
-
-    if error_msg:
-        st.error(error_msg)
+    # Prepare request data
+    request_data = {
+        "text_input": str(text_input),
+        "model_name": str(model_name),
+        "array_audio_input": array_audio_input.tolist(),  # Convert numpy array to list
+        "base64_audio_input": str(base64_audio_input) if base64_audio_input else None,
+        "history": list(history) if history else None,
+        "stream": bool(stream),
+        "max_completion_tokens": int(st.session_state.max_completion_tokens),
+        "temperature": float(st.session_state.temperature),
+        "top_p": float(st.session_state.top_p),
+        "repetition_penalty": float(st.session_state.repetition_penalty),
+        "top_k": int(st.session_state.top_k),
+        "length_penalty": float(st.session_state.length_penalty),
+        "seed": int(st.session_state.seed),
+        "extra_params": {}
+    }
 
-        st.warning(warning_msg)
+    # print(request_data)
+    # print(model_name)
 
+    error_msg = ""
+    warnings = []
     response = ""
+
+    try:
         if stream:
+            # Streaming response
+            response_stream = requests.post(f"{API_BASE_URL}chat", json=request_data, stream=True)
+            response_stream.raise_for_status()
+
+            response_obj = itertools.chain([prefix], (chunk.decode() for chunk in response_stream))
             response = st.write_stream(response_obj)
         else:
+            # Non-streaming response
+            api_response = requests.post(f"{API_BASE_URL}chat", json=request_data)
+            api_response.raise_for_status()
+            result = api_response.json()
+
+            if "warnings" in result:
+                warnings = result["warnings"]
+
+            response = result.get("response", "")
             if normalise_response:
                 response = postprocess_voice_transcription(response)
             response = prefix + response
             st.write(response)
 
+    except requests.exceptions.RequestException as e:
+        error_msg = f"API request failed: {str(e)}"
+        st.error(error_msg)
+
+    if show_warning:
+        for warning_msg in warnings:
+            st.warning(warning_msg)
+
     st.session_state.logger.register_query(
         session_id=st.session_state.session_id,
         base64_audio=base64_audio_input,
         text_input=text_input,
         history=history,
-        params=
+        params=request_data["extra_params"],
         response=response,
         warnings=warnings,
         error_msg=error_msg
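After this change, retrive_response_with_ui routes every generation call through POST {API_BASE_URL}chat: it sends the request_data payload shown above and expects either a stream of text chunks or a JSON body carrying "response" and "warnings" fields. A standalone sketch for smoke-testing that assumed contract outside Streamlit (the parameter values below are placeholders, and the backend service itself is not part of this repository):

# Standalone, non-streaming smoke test against the assumed /chat contract.
# Field names mirror the request_data dict built in retrive_response_with_ui;
# the concrete values are placeholders for illustration only.
import os
import requests

API_BASE_URL = os.getenv("API_BASE_URL")

payload = {
    "text_input": "Please transcribe the speech",
    "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION",
    "array_audio_input": [],            # no audio for this text-only check
    "base64_audio_input": None,
    "history": None,
    "stream": False,
    "max_completion_tokens": 1024,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
    "top_k": 50,
    "length_penalty": 1.0,
    "seed": 42,
    "extra_params": {},
}

resp = requests.post(f"{API_BASE_URL}chat", json=payload, timeout=60)
resp.raise_for_status()
result = resp.json()
print(result.get("response", ""))
print(result.get("warnings", []))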
src/exceptions.py
CHANGED
@@ -1,6 +1,2 @@
 class NoAudioException(Exception):
-    pass
-
-
-class TunnelNotRunningException(Exception):
     pass
src/generation.py
CHANGED
@@ -1,15 +1,3 @@
-import os
-import re
-import time
-from typing import List, Dict, Optional
-
-import numpy as np
-import streamlit as st
-from openai import OpenAI, APIConnectionError
-
-from src.exceptions import TunnelNotRunningException
-
-
 FIXED_GENERATION_CONFIG = dict(
     max_completion_tokens=1024,
     top_k=50,
@@ -20,25 +8,6 @@ FIXED_GENERATION_CONFIG = dict(
 MAX_AUDIO_LENGTH = 120
 
 
-def load_model() -> Dict:
-    """
-    Create an OpenAI client with connection to vllm server.
-    """
-    openai_api_key = os.getenv('API_KEY')
-    local_ports = os.getenv('LOCAL_PORTS').split(" ")
-
-    name_to_client_mapper = {}
-    for port in local_ports:
-        client = OpenAI(
-            api_key=openai_api_key,
-            base_url=f"http://localhost:{port}/v1",
-        )
-
-        for model in client.models.list().data:
-            name_to_client_mapper[model.id] = client
-
-    return name_to_client_mapper
-
 
 def prepare_multimodal_content(text_input, base64_audio_input):
     return [
@@ -76,112 +45,3 @@ def change_multimodal_content(
     }
 
     return original_content
-
-
-
-def _retrive_response(
-    model: str,
-    text_input: str,
-    base64_audio_input: str,
-    history: Optional[List] = None,
-    **kwargs):
-    """
-    Send request through OpenAI client.
-    """
-    if history is None:
-        history = []
-
-    if base64_audio_input:
-        content = [
-            {
-                "type": "text",
-                "text": f"Text instruction: {text_input}"
-            },
-            {
-                "type": "audio_url",
-                "audio_url": {
-                    "url": f"data:audio/ogg;base64,{base64_audio_input}"
-                },
-            },
-        ]
-    else:
-        content = text_input
-
-    current_client = st.session_state.client_mapper[model]
-
-    return current_client.chat.completions.create(
-        messages=history + [{"role": "user", "content": content}],
-        model=model,
-        **kwargs
-    )
-
-
-def _retry_retrive_response_throws_exception(retry=3, **kwargs):
-    try:
-        response_object = _retrive_response(**kwargs)
-    except APIConnectionError as e:
-        if not st.session_state.server.is_running():
-            if retry == 0:
-                raise TunnelNotRunningException()
-
-            st.toast(f":warning: Internet connection is down. Trying to re-establish connection ({retry}).")
-
-            if st.session_state.server.is_down():
-                st.session_state.server.restart()
-            elif st.session_state.server.is_starting():
-                time.sleep(2)
-
-            return _retry_retrive_response_throws_exception(retry-1, **kwargs)
-        raise e
-
-    return response_object
-
-
-def _validate_input(text_input, array_audio_input) -> List[str]:
-    """
-    TODO: improve the input validation regex.
-    """
-    warnings = []
-    if re.search("tool|code|python|java|math|calculate", text_input):
-        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")
-
-    if re.search(r'[\u4e00-\u9fff]+', text_input):
-        warnings.append("NOTE: Please try to prompt in English for the best performance.")
-
-    if array_audio_input.shape[0] == 0:
-        warnings.append("NOTE: Please specify audio from examples or local files.")
-
-    if array_audio_input.shape[0] / 16000 > 30.0:
-        warnings.append((
-            "WARNING: MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
-            f" Audio longer than **{MAX_AUDIO_LENGTH} seconds** will be truncated."
-        ))
-
-    return warnings
-
-
-def retrive_response(
-    text_input: str,
-    array_audio_input: np.ndarray,
-    **kwargs
-):
-    warnings = _validate_input(text_input, array_audio_input)
-
-    response_object, error_msg = None, ""
-    try:
-        response_object = _retry_retrive_response_throws_exception(
-            text_input=text_input,
-            **kwargs
-        )
-    except TunnelNotRunningException:
-        error_msg = "Internet connection cannot be established. Please contact the administrator."
-    except Exception as e:
-        error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."
-
-    return error_msg, warnings, response_object
-
-
-def postprocess_voice_transcription(text):
-    text = re.sub("<.*>:?|\(.*\)|\[.*\]", "", text)
-    text = re.sub("\s+", " ", text).strip()
-    return text
src/retrieval.py
CHANGED
@@ -1,10 +1,3 @@
-from typing import List
-
-import numpy as np
-import streamlit as st
-from FlagEmbedding import BGEM3FlagModel
-
-
 STANDARD_QUERIES = [
     {
         "query_text": "Please transcribe this speech.",
@@ -43,71 +36,3 @@ STANDARD_QUERIES = [
         "ui_text": "emotion recognition"
     },
 ]
-
-
-def _colbert_score(q_reps, p_reps):
-    """Compute colbert scores of input queries and passages.
-
-    Args:
-        q_reps (np.ndarray): Multi-vector embeddings for queries.
-        p_reps (np.ndarray): Multi-vector embeddings for passages/corpus.
-
-    Returns:
-        torch.Tensor: Computed colbert scores.
-    """
-    # q_reps, p_reps = torch.from_numpy(q_reps), torch.from_numpy(p_reps)
-    token_scores = np.einsum('in,jn->ij', q_reps, p_reps)
-    scores = token_scores.max(-1)
-    scores = np.sum(scores) / q_reps.shape[0]
-    return scores
-
-
-class QueryRetriever:
-    def __init__(self, docs):
-        self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
-        self.docs = docs
-        self.doc_vectors = self.model.encode(
-            [d["doc_text"] for d in self.docs],
-            return_sparse=True,
-            return_colbert_vecs=True
-        )
-        self.scorer_attrs = {
-            "lexical_weights": {
-                "method": self.model.compute_lexical_matching_score,
-                "weight": 0.2
-            },
-            "colbert_vecs": {
-                "method": _colbert_score,
-                "weight": 0.8
-            },
-        }
-
-    def get_relevant_doc_indices(self, prompt, normalize=False) -> np.ndarray:
-        scores = np.zeros(len(self.docs))
-
-        if not prompt:
-            return scores
-
-        prompt_vector = self.model.encode(
-            prompt,
-            return_sparse=True,
-            return_colbert_vecs=True
-        )
-
-        for scorer_name, scorer_attrs in self.scorer_attrs.items():
-            for i, doc_vec in enumerate(self.doc_vectors[scorer_name]):
-                scores[i] += scorer_attrs["method"](prompt_vector[scorer_name], doc_vec)
-
-        if normalize:
-            scores = scores / np.sum(scores)
-        return scores
-
-
-@st.cache_resource()
-def load_retriever():
-    return QueryRetriever(docs=STANDARD_QUERIES)
-
-
-def retrieve_relevant_docs(user_question: str) -> List[int]:
-    scores = st.session_state.retriever.get_relevant_doc_indices(user_question, normalize=True)
-    selected_indices = np.where(scores > 0.2)[0]
-    return selected_indices.tolist()
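With the ranking code gone, src/retrieval.py only exports STANDARD_QUERIES; the scoring is assumed to run behind the backend's retrieve_relevant_docs route instead. A rough server-side sketch that reproduces the deleted behaviour (the endpoint path and the BGE-M3 scoring come from the removed code, while FastAPI and the module layout are assumptions):

# Hypothetical backend-side endpoint, not part of this repository. It mirrors
# the QueryRetriever / _colbert_score logic deleted above; the route name
# matches what src/content/agent.py now calls, everything else is assumed.
from typing import List

import numpy as np
from fastapi import FastAPI
from FlagEmbedding import BGEM3FlagModel

from src.retrieval import STANDARD_QUERIES

app = FastAPI()

# Encode the standard queries once at startup, as QueryRetriever.__init__ did.
model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)
doc_vectors = model.encode(
    [d["doc_text"] for d in STANDARD_QUERIES],
    return_sparse=True,
    return_colbert_vecs=True,
)


def _colbert_score(q_reps, p_reps) -> float:
    # Max-sim ColBERT score, copied from the removed helper.
    token_scores = np.einsum("in,jn->ij", q_reps, p_reps)
    return float(np.sum(token_scores.max(-1)) / q_reps.shape[0])


@app.get("/retrieve_relevant_docs")
def retrieve_relevant_docs(user_question: str) -> List[int]:
    if not user_question:
        return []
    query = model.encode(user_question, return_sparse=True, return_colbert_vecs=True)
    scores = np.zeros(len(STANDARD_QUERIES))
    for i in range(len(STANDARD_QUERIES)):
        # Sum lexical and ColBERT scores per standard query, as the removed code did.
        scores[i] += model.compute_lexical_matching_score(
            query["lexical_weights"], doc_vectors["lexical_weights"][i]
        )
        scores[i] += _colbert_score(query["colbert_vecs"], doc_vectors["colbert_vecs"][i])
    scores = scores / np.sum(scores)  # normalize, then keep indices above the 0.2 threshold
    return np.where(scores > 0.2)[0].tolist()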
src/tunnel.py
DELETED
@@ -1,72 +0,0 @@
-import io
-import os
-
-import paramiko
-import streamlit as st
-from sshtunnel import SSHTunnelForwarder
-
-
-DEFAULT_LOCAL_PORTS = "8000 8001"
-DEFAULT_REMOTE_PORTS = "8000 8001"
-
-
-@st.cache_resource()
-def start_server():
-    server = SSHTunnelManager()
-    server.start()
-    return server
-
-
-class SSHTunnelManager:
-    def __init__(self):
-        pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))
-
-        self.server = SSHTunnelForwarder(
-            ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
-            ssh_username="ec2-user",
-            ssh_pkey=pkey,
-            local_bind_addresses=[
-                ("127.0.0.1", int(port))
-                for port in os.getenv('LOCAL_PORTS', DEFAULT_LOCAL_PORTS).split(" ")
-            ],
-            remote_bind_addresses=[
-                ("127.0.0.1", int(port))
-                for port in os.getenv('REMOTE_PORTS', DEFAULT_REMOTE_PORTS).split(" ")
-            ]
-        )
-
-        self._is_starting = False
-        self._is_running = False
-
-    def update_status(self):
-        if not self._is_starting:
-            self.server.check_tunnels()
-            self._is_running = all(
-                list(self.server.tunnel_is_up.values())
-            )
-        else:
-            self._is_running = False
-
-    def is_starting(self):
-        self.update_status()
-        return self._is_starting
-
-    def is_running(self):
-        self.update_status()
-        return self._is_running
-
-    def is_down(self):
-        self.update_status()
-        return (not self._is_running) and (not self._is_starting)
-
-    def start(self, *args, **kwargs):
-        if not self._is_starting:
-            self._is_starting = True
-            self.server.start(*args, **kwargs)
-            self._is_starting = False
-
-    def restart(self, *args, **kwargs):
-        if not self._is_starting:
-            self._is_starting = True
-            self.server.restart(*args, **kwargs)
-            self._is_starting = False