frankenliu committed on
Commit
f48eacc
·
verified ·
1 Parent(s): 526c360

clean up audio caches automatically (#18)

Browse files

- clean up audio caches automatically (d69879ca9279d55ddad0b7a06e499cd6e4f2e81a)

Files changed (1) hide show
  1. app.py +52 -16
app.py CHANGED
@@ -5,6 +5,7 @@ from datetime import datetime
5
  import time
6
  import json
7
  import logging
 
8
  from typing import List, Dict, Tuple
9
  import gradio as gr
10
 
@@ -15,6 +16,29 @@ import soundfile as sf
15
  from huggingface_hub import CommitScheduler
16
  from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def setup_logger(log_dir: str, log_file: str, level: int = logging.INFO):
19
  os.makedirs(log_dir, exist_ok=True)
20
 
@@ -32,8 +56,11 @@ def setup_logger(log_dir: str, log_file: str, level: int = logging.INFO):
32
  console_handler.setFormatter(formatter)
33
  logger.addHandler(console_handler)
34
 
35
-
36
- file_handler = logging.FileHandler(os.path.join(log_dir, log_file), encoding="utf-8")
 
 
 
37
  file_handler.setFormatter(formatter)
38
  logger.addHandler(file_handler)
39
 
@@ -64,7 +91,7 @@ def infer(text_input: str, audio_input: str, sr: int = 16000) -> str:
64
  logger.error("Invalid inputs! Please check your prompt or reupload the audio file.")
65
  raise gr.Error("Invalid inputs! Please check your prompt or reupload the audio file.")
66
 
67
- audio_path, request_id = handle_request(text_input, audio_input, sr=sr)
68
  message = construct_message(text_input, audio_path)
69
  now = datetime.now(time_zone)
70
  now_timestamp = now.timestamp()
@@ -74,7 +101,10 @@ def infer(text_input: str, audio_input: str, sr: int = 16000) -> str:
74
  # response = _infer_test(message)
75
  end_time = time.perf_counter()
76
  cost_time = round(end_time - start_time, 2)
77
- log_data_str = handle_logging(text_input, audio_path, response, request_id, now_timestamp, cost_time)
 
 
 
78
  logger.info(log_data_str)
79
  gr.Info(
80
  message=f"Inference's been done, took {cost_time}s.",
@@ -106,24 +136,23 @@ def is_valid_inputs(text_input: str, audio_input: str) -> bool:
106
 
107
  def handle_request(
108
  text_input: str, audio_input: str,
109
- usr_dir: str = "resources/usr", sr: int = 16000):
110
  request_id = os.urandom(16).hex()
111
- out_dir = os.path.join(usr_dir, request_id)
112
- os.makedirs(out_dir, exist_ok=True)
113
 
114
- audio_path = os.path.join(out_dir, f"{request_id}.wav")
115
  try:
116
  audio = preprocess_audio(audio_input, target_sr=sr)
117
  sf.write(audio_path, audio, sr, format="WAV")
118
 
119
  return audio_path, request_id
120
  except Exception as e:
121
- shutil.rmtree(out_dir, ignore_errors=True)
122
  raise gr.Error(f"{e}")
123
 
124
  def handle_logging(
125
  text_input: str, audio_path: str, response: str,
126
- request_id: str, timestamp: float, cost_time: float,
127
  log_dir: str = "resources/logs", log_file : str = "response.log"):
128
  log_file_path = os.path.join(log_dir, log_file)
129
  audio_log_dir = os.path.join(log_dir, "audio")
@@ -140,11 +169,9 @@ def handle_logging(
140
  log_data_str = json.dumps(log_data, ensure_ascii=False)
141
  if log_scheduler is not None:
142
  with log_scheduler.lock:
143
- shutil.copy2(audio_path, audio_log_path)
144
  with open(log_file_path, mode='a', encoding="utf-8") as writer:
145
  writer.write(log_data_str + "\n")
146
  else:
147
- shutil.copy2(audio_path, audio_log_path)
148
  with open(log_file_path, mode='a', encoding="utf-8") as writer:
149
  writer.write(log_data_str + "\n")
150
 
@@ -182,11 +209,13 @@ def construct_message(
182
  def create_log_scheduler(
183
  repo_id: str, repo_type: str,
184
  folder_path: str, path_in_repo: str,
185
- every: int = 3600):
186
- scheduler = CommitScheduler(
 
187
  repo_id=repo_id, repo_type=repo_type,
188
  folder_path=folder_path, path_in_repo=path_in_repo,
189
- every=every)
 
190
  return scheduler
191
 
192
  def update_audio_input(choice):
@@ -208,9 +237,15 @@ def enable_button():
208
 
209
  if __name__ == "__main__":
210
  time_zone = ZoneInfo("Asia/Shanghai")
 
 
 
 
211
  model_name = "mispeech/MiDashengLM-7B"
212
  json_path = "resources/examples.json"
213
  log_dir = "resources/logs"
 
 
214
  log_file = "gradio.log" # DEBUG logs
215
 
216
  model, processor, tokenizer = load_model(model_name)
@@ -220,7 +255,8 @@ if __name__ == "__main__":
220
  log_scheduler = create_log_scheduler(
221
  repo_id="mispeech/MiDashengLM-logs", repo_type="dataset",
222
  folder_path="resources/logs", path_in_repo="logs",
223
- every=15) # commit every 15 minute
 
224
 
225
  with gr.Blocks() as demo:
226
  gr.Markdown("#🪄 Select an example or upload your own audio")
 
5
  import time
6
  import json
7
  import logging
8
+ from logging.handlers import RotatingFileHandler
9
  from typing import List, Dict, Tuple
10
  import gradio as gr
11
 
 
16
  from huggingface_hub import CommitScheduler
17
  from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
18
 
19
class MyCommitScheduler(CommitScheduler):
    """CommitScheduler that empties a local cache directory after each push.

    Used so the per-request audio files accumulated under ``cleanup_dir`` are
    deleted once they have been committed to the Hub, keeping the local cache
    from growing without bound.
    """

    def __init__(self, *args, cleanup_dir: str, **kwargs):
        super().__init__(*args, **kwargs)
        # Directory whose contents are purged after every successful push.
        self.cleanup_dir = cleanup_dir

    def push_to_hub(self):
        super().push_to_hub()
        # Hold the scheduler lock while purging: a request thread may write a
        # new audio file between the push above and the cleanup below, and
        # deleting it here would lose it before it was ever committed. The
        # rest of the app already synchronizes on this same lock
        # (see handle_logging's `with log_scheduler.lock:`).
        with self.lock:
            self.clean_up()
        logger.info(f"Clean up directory: {self.cleanup_dir}.")

    def clean_up(self):
        """Delete every file, symlink, and subdirectory inside cleanup_dir.

        Best-effort: failures on individual entries are logged and skipped so
        one undeletable file does not abort the whole sweep.
        """
        if not os.path.isdir(self.cleanup_dir):
            return
        for filename in os.listdir(self.cleanup_dir):
            file_path = os.path.join(self.cleanup_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.remove(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                logger.error(f"{e}")
42
  def setup_logger(log_dir: str, log_file: str, level: int = logging.INFO):
43
  os.makedirs(log_dir, exist_ok=True)
44
 
 
56
  console_handler.setFormatter(formatter)
57
  logger.addHandler(console_handler)
58
 
59
+ file_handler = RotatingFileHandler(
60
+ os.path.join(log_dir, log_file),
61
+ encoding="utf-8",
62
+ maxBytes= 5 * 1024 * 1024, # max size is 5 MB
63
+ backupCount=1)
64
  file_handler.setFormatter(formatter)
65
  logger.addHandler(file_handler)
66
 
 
91
  logger.error("Invalid inputs! Please check your prompt or reupload the audio file.")
92
  raise gr.Error("Invalid inputs! Please check your prompt or reupload the audio file.")
93
 
94
+ audio_path, request_id = handle_request(text_input, audio_input, audio_dir=cleanup_dir, sr=sr)
95
  message = construct_message(text_input, audio_path)
96
  now = datetime.now(time_zone)
97
  now_timestamp = now.timestamp()
 
101
  # response = _infer_test(message)
102
  end_time = time.perf_counter()
103
  cost_time = round(end_time - start_time, 2)
104
+ log_data_str = handle_logging(
105
+ text_input, audio_path, response,
106
+ request_id, now_str, cost_time,
107
+ log_dir=log_dir, log_file=response_log_file)
108
  logger.info(log_data_str)
109
  gr.Info(
110
  message=f"Inference's been done, took {cost_time}s.",
 
136
 
137
def handle_request(
        text_input: str, audio_input: str,
        audio_dir: str = "resources/logs/audio", sr: int = 16000):
    """Resample the uploaded audio and persist it into the shared cache dir.

    Args:
        text_input: The user's text prompt (unused here; kept for interface
            parity with the caller).
        audio_input: Path to the uploaded audio file.
        audio_dir: Shared directory where per-request WAV files are cached
            until the CommitScheduler uploads and purges them.
        sr: Target sample rate for the resampled audio.

    Returns:
        Tuple of (audio_path, request_id) for the written WAV file.

    Raises:
        gr.Error: If preprocessing or writing the audio fails.
    """
    request_id = os.urandom(16).hex()
    os.makedirs(audio_dir, exist_ok=True)

    audio_path = os.path.join(audio_dir, f"{request_id}.wav")
    try:
        audio = preprocess_audio(audio_input, target_sr=sr)
        sf.write(audio_path, audio, sr, format="WAV")
        return audio_path, request_id
    except Exception as e:
        # Remove only THIS request's file. audio_dir is shared by all
        # in-flight requests, so rmtree(audio_dir) here would delete other
        # users' pending audio before the scheduler ever commits it.
        try:
            os.remove(audio_path)
        except OSError:
            pass  # file may never have been created
        raise gr.Error(f"{e}")
152
 
153
  def handle_logging(
154
  text_input: str, audio_path: str, response: str,
155
+ request_id: str, timestamp: str, cost_time: float,
156
  log_dir: str = "resources/logs", log_file : str = "response.log"):
157
  log_file_path = os.path.join(log_dir, log_file)
158
  audio_log_dir = os.path.join(log_dir, "audio")
 
169
  log_data_str = json.dumps(log_data, ensure_ascii=False)
170
  if log_scheduler is not None:
171
  with log_scheduler.lock:
 
172
  with open(log_file_path, mode='a', encoding="utf-8") as writer:
173
  writer.write(log_data_str + "\n")
174
  else:
 
175
  with open(log_file_path, mode='a', encoding="utf-8") as writer:
176
  writer.write(log_data_str + "\n")
177
 
 
209
def create_log_scheduler(
        repo_id: str, repo_type: str,
        folder_path: str, path_in_repo: str,
        every: int = 30,
        cleanup_dir: str = "resources/logs/audio"):
    """Create a MyCommitScheduler that periodically syncs logs to the Hub.

    Args:
        repo_id: Target Hub repository id.
        repo_type: Hub repository type (e.g. "dataset").
        folder_path: Local folder to upload.
        path_in_repo: Destination path inside the repository.
        every: Commit interval in minutes.
        cleanup_dir: Local directory the scheduler purges after each push.

    Returns:
        The configured MyCommitScheduler instance.
    """
    return MyCommitScheduler(
        repo_id=repo_id,
        repo_type=repo_type,
        folder_path=folder_path,
        path_in_repo=path_in_repo,
        every=every,
        cleanup_dir=cleanup_dir)
220
 
221
  def update_audio_input(choice):
 
237
 
238
  if __name__ == "__main__":
239
  time_zone = ZoneInfo("Asia/Shanghai")
240
+ appstart_time = datetime.now(time_zone)
241
+ appstart_timestamp = appstart_time.timestamp()
242
+ appstart_str = appstart_time.strftime("%Y-%m-%d_%H-%M-%S")
243
+
244
  model_name = "mispeech/MiDashengLM-7B"
245
  json_path = "resources/examples.json"
246
  log_dir = "resources/logs"
247
+ cleanup_dir = "resources/logs/audio"
248
+ response_log_file = f"response_{appstart_str}.log"
249
  log_file = "gradio.log" # DEBUG logs
250
 
251
  model, processor, tokenizer = load_model(model_name)
 
255
  log_scheduler = create_log_scheduler(
256
  repo_id="mispeech/MiDashengLM-logs", repo_type="dataset",
257
  folder_path="resources/logs", path_in_repo="logs",
258
+ every=15,
259
+ cleanup_dir=cleanup_dir) # commit every 15 minute
260
 
261
  with gr.Blocks() as demo:
262
  gr.Markdown("#🪄 Select an example or upload your own audio")