quantumiracle-git commited on
Commit
9d0eab4
·
1 Parent(s): 332e8c3

optimize code

Browse files
Files changed (1) hide show
  1. app.py +45 -26
app.py CHANGED
@@ -6,6 +6,7 @@ import gdown
6
  import base64
7
  from time import gmtime, strftime
8
  from csv import writer
 
9
 
10
  from datasets import load_dataset
11
  from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver
@@ -45,6 +46,7 @@ if LOAD_DATA_GOOGLE_DRIVE: # download data from google drive
45
  else: # local data
46
  VIDEO_PATH = 'robotinder-data'
47
 
 
48
 
49
  def inference(video_path):
50
  with open(video_path, "rb") as f:
@@ -123,15 +125,8 @@ def update(user_choice, left, right, choose_env, data_folder=VIDEO_PATH, flag_to
123
  else:
124
  env_name = choose_env
125
  # choose video
126
- videos = os.listdir(os.path.join(data_folder, env_name))
127
- video_files = []
128
- for f in videos:
129
- if f.endswith(f'.{FORMAT}'):
130
- video_files.append(os.path.join(data_folder, env_name, f))
131
- # randomly choose two videos
132
- selected_video_ids = np.random.choice(len(video_files), 2, replace=False)
133
- left = video_files[selected_video_ids[0]]
134
- right = video_files[selected_video_ids[1]]
135
  last_left_video_path = left
136
  last_right_video_path = right
137
  last_infer_left_video_path = inference(left)
@@ -143,11 +138,41 @@ def replay(left, right):
143
  return left, right
144
 
145
  def parse_envs(folder=VIDEO_PATH):
146
- envs = []
147
- for f in os.listdir(folder):
148
- if os.path.isdir(os.path.join(folder, f)):
149
- envs.append(f)
150
- return envs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  def build_interface(iter=3, data_folder=VIDEO_PATH):
153
  import sys
@@ -162,26 +187,20 @@ def build_interface(iter=3, data_folder=VIDEO_PATH):
162
  # callback = gr.CSVLogger()
163
  callback = hf_writer
164
 
 
 
 
165
  # build gradio interface
166
  with gr.Blocks() as demo:
167
  gr.Markdown("## Here is <span style=color:cyan>RoboTinder</span>!")
168
  gr.Markdown("### Select the best robot behaviour in your choice!")
169
  # some initial values
170
- envs = parse_envs()
171
- env_name = envs[random.randint(0, len(envs)-1)] # random pick an env
172
  with gr.Row():
173
  str_env_name = gr.Markdown(f"{env_name}")
174
 
175
  # choose video
176
- videos = os.listdir(os.path.join(data_folder, env_name))
177
- video_files = []
178
- for f in videos:
179
- if f.endswith(f'.{FORMAT}'):
180
- video_files.append(os.path.join(data_folder, env_name, f))
181
- # randomly choose two videos
182
- selected_video_ids = np.random.choice(len(video_files), 2, replace=False)
183
- left_video_path = video_files[selected_video_ids[0]]
184
- right_video_path = video_files[selected_video_ids[1]]
185
 
186
  with gr.Row():
187
  if FORMAT == 'mp4':
@@ -219,7 +238,7 @@ def build_interface(iter=3, data_folder=VIDEO_PATH):
219
  btn2.click(fn=update, inputs=[user_choice, left, right, choose_env], outputs=[left, right, str_env_name])
220
 
221
  # We can choose which components to flag -- in this case, we'll flag all of them
222
- btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False)
223
 
224
  return demo
225
 
 
6
  import base64
7
  from time import gmtime, strftime
8
  from csv import writer
9
+ import json
10
 
11
  from datasets import load_dataset
12
  from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver
 
46
  else: # local data
47
  VIDEO_PATH = 'robotinder-data'
48
 
49
+ VIDEO_INFO = os.path.join(VIDEO_PATH, 'video_info.json')
50
 
51
  def inference(video_path):
52
  with open(video_path, "rb") as f:
 
125
  else:
126
  env_name = choose_env
127
  # choose video
128
+ left, right = randomly_select_videos(env_name)
129
+
 
 
 
 
 
 
 
130
  last_left_video_path = left
131
  last_right_video_path = right
132
  last_infer_left_video_path = inference(left)
 
138
  return left, right
139
 
140
  def parse_envs(folder=VIDEO_PATH):
141
+ """
142
+ return a dict of env_name: video_paths
143
+ """
144
+ files = {}
145
+ for env_name in os.listdir(folder):
146
+ env_path = os.path.join(folder, env_name)
147
+ if os.path.isdir(env_path):
148
+ videos = os.listdir(env_path)
149
+ video_files = []
150
+ for video in videos:
151
+ if video.endswith(f'.{FORMAT}'):
152
+ video_path = os.path.join(folder, env_name, video)
153
+ video_files.append(video_path)
154
+ files[env_name] = video_files
155
+
156
+ with open(VIDEO_INFO, 'w') as fp:
157
+ json.dump(files, fp)
158
+
159
+ return files
160
+
161
+ def get_env_names():
162
+ with open(VIDEO_INFO, 'r') as fp:
163
+ files = json.load(fp)
164
+ return list(files.keys())
165
+
166
+ def randomly_select_videos(env_name):
167
+ # load the parsed video info
168
+ with open(VIDEO_INFO, 'r') as fp:
169
+ files = json.load(fp)
170
+ env_files = files[env_name]
171
+ # randomly choose two videos
172
+ selected_video_ids = np.random.choice(len(env_files), 2, replace=False)
173
+ left_video_path = env_files[selected_video_ids[0]]
174
+ right_video_path = env_files[selected_video_ids[1]]
175
+ return left_video_path, right_video_path
176
 
177
  def build_interface(iter=3, data_folder=VIDEO_PATH):
178
  import sys
 
187
  # callback = gr.CSVLogger()
188
  callback = hf_writer
189
 
190
+ # parse the video folder
191
+ files = parse_envs()
192
+
193
  # build gradio interface
194
  with gr.Blocks() as demo:
195
  gr.Markdown("## Here is <span style=color:cyan>RoboTinder</span>!")
196
  gr.Markdown("### Select the best robot behaviour in your choice!")
197
  # some initial values
198
+ env_name = list(files.keys())[random.randint(0, len(files)-1)] # random pick an env
 
199
  with gr.Row():
200
  str_env_name = gr.Markdown(f"{env_name}")
201
 
202
  # choose video
203
+ left_video_path, right_video_path = randomly_select_videos(env_name)
 
 
 
 
 
 
 
 
204
 
205
  with gr.Row():
206
  if FORMAT == 'mp4':
 
238
  btn2.click(fn=update, inputs=[user_choice, left, right, choose_env], outputs=[left, right, str_env_name])
239
 
240
  # We can choose which components to flag -- in this case, we'll flag all of them
241
+ # btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False) # not using the gradio flagging anymore
242
 
243
  return demo
244