YingxuHe committed on
Commit 8b1f711 · 2 Parent(s): 1ae3a00 52b1f13

apply new backend

.gitignore ADDED
@@ -0,0 +1,2 @@
+.venv/
+__pycache__/
src/content/agent.py CHANGED
@@ -1,7 +1,10 @@
+import os
+import requests
+
 import numpy as np
 import streamlit as st
 
-from src.retrieval import STANDARD_QUERIES, retrieve_relevant_docs
+from src.retrieval import STANDARD_QUERIES
 from src.content.common import (
     MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
@@ -17,6 +20,9 @@ from src.content.common import (
 )
 
 
+API_BASE_URL = os.getenv('API_BASE_URL')
+
+
 LLM_NO_AUDIO_PROMPT_TEMPLATE = """{user_question}"""
 
 
@@ -96,7 +102,12 @@ def _prepare_final_prompt_with_ui(one_time_prompt):
         return LLM_NO_AUDIO_PROMPT_TEMPLATE.format(user_question=one_time_prompt)
 
     with st.spinner("Searching appropriate querys..."):
-        relevant_query_indices = retrieve_relevant_docs(one_time_prompt)
+        response = requests.get(
+            f"{API_BASE_URL}retrieve_relevant_docs",
+            params={"user_question": one_time_prompt}
+        )
+        relevant_query_indices = response.json()
+
         if len(st.session_state.ag_messages) <= 2:
             relevant_query_indices.append(0)
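Note: retrieval no longer runs inside the Streamlit process; the agent page now issues a plain GET against an external service. Because the URL is built by string concatenation, API_BASE_URL is evidently expected to end with a trailing slash. The service itself is not part of this commit; the following is only a minimal sketch of the contract the client code assumes (endpoint path, query parameter, JSON list of indices into STANDARD_QUERIES), using FastAPI as a hypothetical server framework:

    # Hypothetical server-side counterpart (not in this repo): the client expects
    # GET {API_BASE_URL}retrieve_relevant_docs?user_question=... to return a JSON
    # list of integer indices into STANDARD_QUERIES.
    from fastapi import FastAPI

    from src.retrieval import STANDARD_QUERIES  # same table the client still imports

    app = FastAPI()


    @app.get("/retrieve_relevant_docs")
    def retrieve_relevant_docs(user_question: str) -> list[int]:
        # Placeholder scoring: the real service presumably reuses the BGE-M3
        # QueryRetriever that this commit deletes from src/retrieval.py.
        question = user_question.lower()
        return [
            i for i, doc in enumerate(STANDARD_QUERIES)
            if any(word in question for word in doc["ui_text"].split())
        ]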
 
src/content/common.py CHANGED
@@ -1,6 +1,8 @@
 import os
+import re
 import copy
 import base64
+import requests
 import itertools
 from collections import OrderedDict
 from typing import List, Optional
@@ -8,17 +10,11 @@ from typing import List, Optional
 import numpy as np
 import streamlit as st
 
-from src.tunnel import start_server
-from src.retrieval import load_retriever
 from src.logger import load_logger
 from src.utils import array_to_bytes, bytes_to_array, postprocess_voice_transcription
-from src.generation import (
-    FIXED_GENERATION_CONFIG,
-    MAX_AUDIO_LENGTH,
-    load_model,
-    retrive_response
-)
+from src.generation import FIXED_GENERATION_CONFIG, MAX_AUDIO_LENGTH
 
+API_BASE_URL = os.getenv('API_BASE_URL')
 
 PLAYGROUND_DIALOGUE_STATES = dict(
     pg_audio_base64='',
@@ -65,46 +61,26 @@ DEFAULT_DIALOGUE_STATE_DICTS = [
 ]
 
 
-MODEL_NAMES = OrderedDict({})
+MODEL_NAMES = OrderedDict({
+    "llm": {
+        "vllm_name": "MERaLiON-Gemma",
+        "model_name": "MERaLiON-Gemma",
+        "ui_name": "MERaLiON-Gemma"
+    },
+    "audiollm": {
+        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION",
+        "ui_name": "MERaLiON-AudioLLM"
+    },
+    "audiollm-it": {
+        "vllm_name": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION-it",
+        "model_name": "MERaLiON-AudioLLM-Whisper-SEA-LION-it",
+        "ui_name": "MERaLiON-AudioLLM-Instruction-Tuning"
+    }
+})
 
 
 AUDIO_SAMPLES_W_INSTRUCT = {
-    "female_pilot#1": {
-        "apperance": "Female Pilot Interview: Transcription",
-        "instructions": [
-            "Please transcribe the speech"
-        ]
-    },
-    "female_pilot#2": {
-        "apperance": "Female Pilot Interview: Aircraft name",
-        "instructions": [
-            "What does 大力士 mean in the conversation"
-        ]
-    },
-    "female_pilot#3": {
-        "apperance": "Female Pilot Interview: Air Force Personnel Count",
-        "instructions": [
-            "How many air force personnel are there?"
-        ]
-    },
-    "female_pilot#4": {
-        "apperance": "Female Pilot Interview: Air Force Personnel Name",
-        "instructions": [
-            "Can you tell me the names of the two pilots?"
-        ]
-    },
-    "female_pilot#5": {
-        "apperance": "Female Pilot Interview: Pilot Seat Restriction",
-        "instructions": [
-            "What is the concern of having a big butt for pilot?"
-        ]
-    },
-    "female_pilot#6": {
-        "apperance": "Female Pilot Interview: Conversation Mood",
-        "instructions": [
-            "What is the mood of the conversation?"
-        ]
-    },
     "7_ASR_IMDA_PART3_30_ASR_v2_2269": {
         "apperance": "7. Automatic Speech Recognition task: conversation in Singapore accent",
         "instructions": [
@@ -358,13 +334,40 @@ AUDIO_SAMPLES_W_INSTRUCT = {
         "instructions": [
             "Please follow the instruction in the speech."
         ]
+    },
+    "female_pilot#1": {
+        "apperance": "Female Pilot Interview: Transcription",
+        "instructions": [
+            "Please transcribe the speech"
+        ]
+    },
+    "female_pilot#2": {
+        "apperance": "Female Pilot Interview: Aircraft name",
+        "instructions": [
+            "What does 大力士 mean in the conversation"
+        ]
+    },
+    "female_pilot#3": {
+        "apperance": "Female Pilot Interview: Air Force Personnel Count",
+        "instructions": [
+            "How many air force personnel are there?"
+        ]
+    },
+    "female_pilot#4": {
+        "apperance": "Female Pilot Interview: Air Force Personnel Name",
+        "instructions": [
+            "Can you tell me the names of the two pilots?"
+        ]
+    },
+    "female_pilot#5": {
+        "apperance": "Female Pilot Interview: Conversation Mood",
+        "instructions": [
+            "What is the mood of the conversation?"
+        ]
     }
 }
 
 
-exec(os.getenv('APP_CONFIGS'))
-
-
 def reset_states(*state_dicts):
     for states in state_dicts:
         st.session_state.update(copy.deepcopy(states))
@@ -403,14 +406,6 @@ def init_state_section():
         st.session_state.logger = load_logger()
         st.session_state.session_id = st.session_state.logger.register_session()
 
-    if "server" not in st.session_state:
-        st.session_state.server = start_server()
-
-    if "client_mapper" not in st.session_state:
-        st.session_state.client_mapper = load_model()
-
-    if "retriever" not in st.session_state:
-        st.session_state.retriever = load_retriever()
 
     for key, value in FIXED_GENERATION_CONFIG.items():
         if key not in st.session_state:
@@ -551,54 +546,68 @@ def retrive_response_with_ui(
     if history is None:
         history = []
 
-    generation_params = dict(
-        model=model_name,
-        max_completion_tokens=st.session_state.max_completion_tokens,
-        temperature=st.session_state.temperature,
-        top_p=st.session_state.top_p,
-        extra_body={
-            "repetition_penalty": st.session_state.repetition_penalty,
-            "top_k": st.session_state.top_k,
-            "length_penalty": st.session_state.length_penalty
-        },
-        stream=stream,
-        seed=st.session_state.seed
-    )
-
-    error_msg, warnings, response_obj = retrive_response(
-        text_input,
-        array_audio_input,
-        base64_audio_input=base64_audio_input,
-        history=history,
-        **generation_params,
-        **kwargs
-    )
-
-    if error_msg:
-        st.error(error_msg)
+    # Prepare request data
+    request_data = {
+        "text_input": str(text_input),
+        "model_name": str(model_name),
+        "array_audio_input": array_audio_input.tolist(),  # Convert numpy array to list
+        "base64_audio_input": str(base64_audio_input) if base64_audio_input else None,
+        "history": list(history) if history else None,
+        "stream": bool(stream),
+        "max_completion_tokens": int(st.session_state.max_completion_tokens),
+        "temperature": float(st.session_state.temperature),
+        "top_p": float(st.session_state.top_p),
+        "repetition_penalty": float(st.session_state.repetition_penalty),
+        "top_k": int(st.session_state.top_k),
+        "length_penalty": float(st.session_state.length_penalty),
+        "seed": int(st.session_state.seed),
+        "extra_params": {}
+    }
 
-    if show_warning:
-        for warning_msg in warnings:
-            st.warning(warning_msg)
+    # print(request_data)
+    # print(model_name)
 
+    error_msg = ""
+    warnings = []
     response = ""
-    if response_obj is not None:
+
+    try:
         if stream:
-            response_obj = itertools.chain([prefix], response_obj)
+            # Streaming response
+            response_stream = requests.post(f"{API_BASE_URL}chat", json=request_data, stream=True)
+            response_stream.raise_for_status()
+
+            response_obj = itertools.chain([prefix], (chunk.decode() for chunk in response_stream))
            response = st.write_stream(response_obj)
         else:
-            response = response_obj.choices[0].message.content
+            # Non-streaming response
+            api_response = requests.post(f"{API_BASE_URL}chat", json=request_data)
+            api_response.raise_for_status()
+            result = api_response.json()
+
+            if "warnings" in result:
+                warnings = result["warnings"]
+
+            response = result.get("response", "")
             if normalise_response:
                 response = postprocess_voice_transcription(response)
            response = prefix + response
            st.write(response)
 
+    except requests.exceptions.RequestException as e:
+        error_msg = f"API request failed: {str(e)}"
+        st.error(error_msg)
+
+    if show_warning:
+        for warning_msg in warnings:
+            st.warning(warning_msg)
+
     st.session_state.logger.register_query(
         session_id=st.session_state.session_id,
         base64_audio=base64_audio_input,
         text_input=text_input,
        history=history,
-        params=generation_params,
+        params=request_data["extra_params"],
        response=response,
        warnings=warnings,
        error_msg=error_msg
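Generation is likewise remote now: a single POST to the `chat` endpoint carries the prompt, the raw audio array, and all sampling parameters, and the reply is either a stream of text chunks or a JSON object. The service is outside this repo; the following is a sketch of the request/response schema the client code implies, with FastAPI-style pydantic models as assumed tooling:

    # Hypothetical schema for POST {API_BASE_URL}chat, inferred from request_data
    # and from how the client consumes the reply (not part of this commit).
    from typing import Dict, List, Optional

    from pydantic import BaseModel


    class ChatRequest(BaseModel):
        text_input: str
        model_name: str
        array_audio_input: List[float]          # 16 kHz mono samples, sent as a JSON list
        base64_audio_input: Optional[str] = None
        history: Optional[List[dict]] = None    # prior chat turns, OpenAI message style
        stream: bool = False
        max_completion_tokens: int = 1024
        temperature: float = 1.0
        top_p: float = 1.0
        repetition_penalty: float = 1.0
        top_k: int = 50
        length_penalty: float = 1.0
        seed: int = 0
        extra_params: Dict = {}


    class ChatResponse(BaseModel):
        """Non-streaming reply; when stream=True the body is raw text chunks."""
        response: str
        warnings: List[str] = []

Since the client simply decodes each chunk and hands it to st.write_stream, the streaming variant is presumably plain UTF-8 text rather than SSE frames. Note also that the logger's params field now only records extra_params, which is always empty here, so the sampling settings are no longer logged per query.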
src/exceptions.py CHANGED
@@ -1,6 +1,2 @@
 class NoAudioException(Exception):
-    pass
-
-
-class TunnelNotRunningException(Exception):
     pass
src/generation.py CHANGED
@@ -1,15 +1,3 @@
-import os
-import re
-import time
-from typing import List, Dict, Optional
-
-import numpy as np
-import streamlit as st
-from openai import OpenAI, APIConnectionError
-
-from src.exceptions import TunnelNotRunningException
-
-
 FIXED_GENERATION_CONFIG = dict(
     max_completion_tokens=1024,
     top_k=50,
@@ -20,25 +8,6 @@ FIXED_GENERATION_CONFIG = dict(
 MAX_AUDIO_LENGTH = 120
 
 
-def load_model() -> Dict:
-    """
-    Create an OpenAI client with connection to vllm server.
-    """
-    openai_api_key = os.getenv('API_KEY')
-    local_ports = os.getenv('LOCAL_PORTS').split(" ")
-
-    name_to_client_mapper = {}
-    for port in local_ports:
-        client = OpenAI(
-            api_key=openai_api_key,
-            base_url=f"http://localhost:{port}/v1",
-        )
-
-        for model in client.models.list().data:
-            name_to_client_mapper[model.id] = client
-
-    return name_to_client_mapper
-
 
 def prepare_multimodal_content(text_input, base64_audio_input):
     return [
@@ -76,112 +45,3 @@ def change_multimodal_content(
     }
 
     return original_content
-
-
-
-def _retrive_response(
-    model: str,
-    text_input: str,
-    base64_audio_input: str,
-    history: Optional[List] = None,
-    **kwargs):
-    """
-    Send request through OpenAI client.
-    """
-    if history is None:
-        history = []
-
-    if base64_audio_input:
-        content = [
-            {
-                "type": "text",
-                "text": f"Text instruction: {text_input}"
-            },
-            {
-                "type": "audio_url",
-                "audio_url": {
-                    "url": f"data:audio/ogg;base64,{base64_audio_input}"
-                },
-            },
-        ]
-    else:
-        content = text_input
-
-    current_client = st.session_state.client_mapper[model]
-
-    return current_client.chat.completions.create(
-        messages=history + [{"role": "user", "content": content}],
-        model=model,
-        **kwargs
-    )
-
-
-def _retry_retrive_response_throws_exception(retry=3, **kwargs):
-    try:
-        response_object = _retrive_response(**kwargs)
-    except APIConnectionError as e:
-        if not st.session_state.server.is_running():
-            if retry == 0:
-                raise TunnelNotRunningException()
-
-            st.toast(f":warning: Internet connection is down. Trying to re-establish connection ({retry}).")
-
-            if st.session_state.server.is_down():
-                st.session_state.server.restart()
-            elif st.session_state.server.is_starting():
-                time.sleep(2)
-
-            return _retry_retrive_response_throws_exception(retry-1, **kwargs)
-        raise e
-
-    return response_object
-
-
-def _validate_input(text_input, array_audio_input) -> List[str]:
-    """
-    TODO: improve the input validation regex.
-    """
-    warnings = []
-    if re.search("tool|code|python|java|math|calculate", text_input):
-        warnings.append("WARNING: MERaLiON-AudioLLM is not intended for use in tool calling, math, and coding tasks.")
-
-    if re.search(r'[\u4e00-\u9fff]+', text_input):
-        warnings.append("NOTE: Please try to prompt in English for the best performance.")
-
-    if array_audio_input.shape[0] == 0:
-        warnings.append("NOTE: Please specify audio from examples or local files.")
-
-    if array_audio_input.shape[0] / 16000 > 30.0:
-        warnings.append((
-            "WARNING: MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
-            f" Audio longer than **{MAX_AUDIO_LENGTH} seconds** will be truncated."
-        ))
-
-    return warnings
-
-
-def retrive_response(
-    text_input: str,
-    array_audio_input: np.ndarray,
-    **kwargs
-):
-    warnings = _validate_input(text_input, array_audio_input)
-
-    response_object, error_msg = None, ""
-    try:
-        response_object = _retry_retrive_response_throws_exception(
-            text_input=text_input,
-            **kwargs
-        )
-    except TunnelNotRunningException:
-        error_msg = "Internet connection cannot be established. Please contact the administrator."
-    except Exception as e:
-        error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."
-
-    return error_msg, warnings, response_object
-
-
-def postprocess_voice_transcription(text):
-    text = re.sub("<.*>:?|\(.*\)|\[.*\]", "", text)
-    text = re.sub("\s+", " ", text).strip()
-    return text
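After this purge, src/generation.py keeps only the shared constants and the multimodal message helpers; the OpenAI client, retry loop, input validation, and transcription post-processing all move behind the API. For reference, the message shape prepare_multimodal_content builds is visible in the deleted _retrive_response above; a small illustration (the base64 payload is elided):

    # Shape of the OpenAI-style mixed text/audio message, as seen in the deleted
    # _retrive_response; "..." stands for a real base64 audio payload.
    content = [
        {
            "type": "text",
            "text": "Text instruction: Please transcribe the speech"
        },
        {
            "type": "audio_url",
            "audio_url": {"url": "data:audio/ogg;base64,..."},
        },
    ]
    message = {"role": "user", "content": content}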
src/retrieval.py CHANGED
@@ -1,10 +1,3 @@
-from typing import List
-
-import numpy as np
-import streamlit as st
-from FlagEmbedding import BGEM3FlagModel
-
-
 STANDARD_QUERIES = [
     {
         "query_text": "Please transcribe this speech.",
@@ -43,71 +36,3 @@ STANDARD_QUERIES = [
         "ui_text": "emotion recognition"
     },
 ]
-
-
-def _colbert_score(q_reps, p_reps):
-    """Compute colbert scores of input queries and passages.
-
-    Args:
-        q_reps (np.ndarray): Multi-vector embeddings for queries.
-        p_reps (np.ndarray): Multi-vector embeddings for passages/corpus.
-
-    Returns:
-        torch.Tensor: Computed colbert scores.
-    """
-    # q_reps, p_reps = torch.from_numpy(q_reps), torch.from_numpy(p_reps)
-    token_scores = np.einsum('in,jn->ij', q_reps, p_reps)
-    scores = token_scores.max(-1)
-    scores = np.sum(scores) / q_reps.shape[0]
-    return scores
-
-class QueryRetriever:
-    def __init__(self, docs):
-        self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
-        self.docs = docs
-        self.doc_vectors = self.model.encode(
-            [d["doc_text"] for d in self.docs],
-            return_sparse=True,
-            return_colbert_vecs=True
-        )
-        self.scorer_attrs = {
-            "lexical_weights": {
-                "method": self.model.compute_lexical_matching_score,
-                "weight": 0.2
-            },
-            "colbert_vecs": {
-                "method": _colbert_score,
-                "weight": 0.8
-            },
-        }
-
-    def get_relevant_doc_indices(self, prompt, normalize=False) -> np.ndarray:
-        scores = np.zeros(len(self.docs))
-
-        if not prompt:
-            return scores
-
-        prompt_vector = self.model.encode(
-            prompt,
-            return_sparse=True,
-            return_colbert_vecs=True
-        )
-
-        for scorer_name, scorer_attrs in self.scorer_attrs.items():
-            for i, doc_vec in enumerate(self.doc_vectors[scorer_name]):
-                scores[i] += scorer_attrs["method"](prompt_vector[scorer_name], doc_vec)
-
-        if normalize:
-            scores = scores / np.sum(scores)
-        return scores
-
-
-@st.cache_resource()
-def load_retriever():
-    return QueryRetriever(docs=STANDARD_QUERIES)
-
-
-def retrieve_relevant_docs(user_question: str) -> List[int]:
-    scores = st.session_state.retriever.get_relevant_doc_indices(user_question, normalize=True)
-    selected_indices = np.where(scores > 0.2)[0]
-    return selected_indices.tolist()
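Only the STANDARD_QUERIES table stays on the client; the BGE-M3 scoring moves to the backend. The indices returned by the retrieve_relevant_docs endpoint are still positions into this table, so the UI can resolve them locally. A small illustration (the index values and printed labels are hypothetical, chosen only to show the lookup):

    # Resolving endpoint output to UI labels; [0, 6] is a made-up example reply.
    from src.retrieval import STANDARD_QUERIES

    relevant_query_indices = [0, 6]  # e.g. body of GET .../retrieve_relevant_docs
    suggested_tasks = [STANDARD_QUERIES[i]["ui_text"] for i in relevant_query_indices]
    print(suggested_tasks)  # ui_text labels such as "emotion recognition"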
src/tunnel.py DELETED
@@ -1,72 +0,0 @@
-import io
-import os
-
-import paramiko
-import streamlit as st
-from sshtunnel import SSHTunnelForwarder
-
-
-DEFAULT_LOCAL_PORTS = "8000 8001"
-DEFAULT_REMOTE_PORTS = "8000 8001"
-
-
-@st.cache_resource()
-def start_server():
-    server = SSHTunnelManager()
-    server.start()
-    return server
-
-
-class SSHTunnelManager:
-    def __init__(self):
-        pkey = paramiko.RSAKey.from_private_key(io.StringIO(os.getenv('PRIVATE_KEY')))
-
-        self.server = SSHTunnelForwarder(
-            ssh_address_or_host=os.getenv('SERVER_DNS_NAME'),
-            ssh_username="ec2-user",
-            ssh_pkey=pkey,
-            local_bind_addresses=[
-                ("127.0.0.1", int(port))
-                for port in os.getenv('LOCAL_PORTS', DEFAULT_LOCAL_PORTS).split(" ")
-            ],
-            remote_bind_addresses=[
-                ("127.0.0.1", int(port))
-                for port in os.getenv('REMOTE_PORTS', DEFAULT_REMOTE_PORTS).split(" ")
-            ]
-        )
-
-        self._is_starting = False
-        self._is_running = False
-
-    def update_status(self):
-        if not self._is_starting:
-            self.server.check_tunnels()
-            self._is_running = all(
-                list(self.server.tunnel_is_up.values())
-            )
-        else:
-            self._is_running = False
-
-    def is_starting(self):
-        self.update_status()
-        return self._is_starting
-
-    def is_running(self):
-        self.update_status()
-        return self._is_running
-
-    def is_down(self):
-        self.update_status()
-        return (not self._is_running) and (not self._is_starting)
-
-    def start(self, *args, **kwargs):
-        if not self._is_starting:
-            self._is_starting = True
-            self.server.start(*args, **kwargs)
-            self._is_starting = False
-
-    def restart(self, *args, **kwargs):
-        if not self._is_starting:
-            self._is_starting = True
-            self.server.restart(*args, **kwargs)
-            self._is_starting = False