quantumiracle-git committed
Commit 9d0eab4 · Parent(s): 332e8c3
optimize code

app.py CHANGED
@@ -6,6 +6,7 @@ import gdown
 import base64
 from time import gmtime, strftime
 from csv import writer
+import json
 
 from datasets import load_dataset
 from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver
@@ -45,6 +46,7 @@ if LOAD_DATA_GOOGLE_DRIVE: # download data from google drive
 else: # local data
     VIDEO_PATH = 'robotinder-data'
 
+VIDEO_INFO = os.path.join(VIDEO_PATH, 'video_info.json')
 
 def inference(video_path):
     with open(video_path, "rb") as f:
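parse_envs() (added further down in this diff) writes an index of the local videos to VIDEO_INFO. An illustrative sketch of that file's shape; the environment and clip names below are assumptions, not from the repository:

# robotinder-data/video_info.json (illustrative contents)
# {
#     "halfcheetah": ["robotinder-data/halfcheetah/0.mp4", "robotinder-data/halfcheetah/1.mp4"],
#     "hopper": ["robotinder-data/hopper/0.mp4", "robotinder-data/hopper/1.mp4"]
# }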
@@ -123,15 +125,8 @@ def update(user_choice, left, right, choose_env, data_folder=VIDEO_PATH, flag_to
     else:
         env_name = choose_env
     # choose video
-
-
-    for f in videos:
-        if f.endswith(f'.{FORMAT}'):
-            video_files.append(os.path.join(data_folder, env_name, f))
-    # randomly choose two videos
-    selected_video_ids = np.random.choice(len(video_files), 2, replace=False)
-    left = video_files[selected_video_ids[0]]
-    right = video_files[selected_video_ids[1]]
+    left, right = randomly_select_videos(env_name)
+
     last_left_video_path = left
     last_right_video_path = right
     last_infer_left_video_path = inference(left)
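update() now delegates the pair selection to randomly_select_videos(). A minimal standalone sketch of the draw it performs (the paths are illustrative):

import numpy as np

env_files = ['robotinder-data/hopper/0.mp4',
             'robotinder-data/hopper/1.mp4',
             'robotinder-data/hopper/2.mp4']  # illustrative clip paths for one environment
ids = np.random.choice(len(env_files), 2, replace=False)  # two distinct indices
left, right = env_files[ids[0]], env_files[ids[1]]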
@@ -143,11 +138,41 @@ def replay(left, right):
     return left, right
 
 def parse_envs(folder=VIDEO_PATH):
-
-
-
-
-
+    """
+    return a dict of env_name: video_paths
+    """
+    files = {}
+    for env_name in os.listdir(folder):
+        env_path = os.path.join(folder, env_name)
+        if os.path.isdir(env_path):
+            videos = os.listdir(env_path)
+            video_files = []
+            for video in videos:
+                if video.endswith(f'.{FORMAT}'):
+                    video_path = os.path.join(folder, env_name, video)
+                    video_files.append(video_path)
+            files[env_name] = video_files
+
+    with open(VIDEO_INFO, 'w') as fp:
+        json.dump(files, fp)
+
+    return files
+
+def get_env_names():
+    with open(VIDEO_INFO, 'r') as fp:
+        files = json.load(fp)
+    return list(files.keys())
+
+def randomly_select_videos(env_name):
+    # load the parsed video info
+    with open(VIDEO_INFO, 'r') as fp:
+        files = json.load(fp)
+    env_files = files[env_name]
+    # randomly choose two videos
+    selected_video_ids = np.random.choice(len(env_files), 2, replace=False)
+    left_video_path = env_files[selected_video_ids[0]]
+    right_video_path = env_files[selected_video_ids[1]]
+    return left_video_path, right_video_path
 
 def build_interface(iter=3, data_folder=VIDEO_PATH):
     import sys
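Taken together, parse_envs(), get_env_names() and randomly_select_videos() replace the old inline directory scans with a cached JSON index. A hedged usage sketch, assuming robotinder-data/<env>/*.mp4 already exists locally:

files = parse_envs()                           # scan VIDEO_PATH once and write video_info.json
envs = get_env_names()                         # environment folder names recorded in the index
left, right = randomly_select_videos(envs[0])  # two distinct clips from the same environment

Note that np.random.choice(..., 2, replace=False) raises a ValueError if an environment folder contains fewer than two files with the expected FORMAT extension.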
@@ -162,26 +187,20 @@ def build_interface(iter=3, data_folder=VIDEO_PATH):
     # callback = gr.CSVLogger()
     callback = hf_writer
 
+    # parse the video folder
+    files = parse_envs()
+
     # build gradio interface
     with gr.Blocks() as demo:
         gr.Markdown("## Here is <span style=color:cyan>RoboTinder</span>!")
         gr.Markdown("### Select the best robot behaviour in your choice!")
         # some initial values
-
-        env_name = envs[random.randint(0, len(envs)-1)] # random pick an env
+        env_name = list(files.keys())[random.randint(0, len(files)-1)] # random pick an env
         with gr.Row():
             str_env_name = gr.Markdown(f"{env_name}")
 
         # choose video
-
-        video_files = []
-        for f in videos:
-            if f.endswith(f'.{FORMAT}'):
-                video_files.append(os.path.join(data_folder, env_name, f))
-        # randomly choose two videos
-        selected_video_ids = np.random.choice(len(video_files), 2, replace=False)
-        left_video_path = video_files[selected_video_ids[0]]
-        right_video_path = video_files[selected_video_ids[1]]
+        left_video_path, right_video_path = randomly_select_videos(env_name)
 
         with gr.Row():
             if FORMAT == 'mp4':
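The interface now seeds its initial environment from the parsed index instead of a pre-built envs list. An equivalent draw (a sketch, not part of this commit) would be:

env_name = random.choice(list(files))  # random.choice over the dict's keys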
@@ -219,7 +238,7 @@ def build_interface(iter=3, data_folder=VIDEO_PATH):
         btn2.click(fn=update, inputs=[user_choice, left, right, choose_env], outputs=[left, right, str_env_name])
 
         # We can choose which components to flag -- in this case, we'll flag all of them
-        btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False)
+        # btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False) # not using the gradio flagging anymore
 
     return demo
 
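The gradio flagging callback is commented out rather than removed, so feedback presumably flows through another path. A hedged sketch of manual CSV logging that would be consistent with the imports already in app.py (csv.writer, gmtime/strftime); the helper name and file path are illustrative, not from this commit:

from csv import writer
from time import gmtime, strftime

def log_choice(user_choice, left, right, csv_path='robotinder_choices.csv'):  # illustrative helper
    # append one row per comparison: timestamp, chosen side, and the two video paths
    with open(csv_path, 'a', newline='') as f:
        writer(f).writerow([strftime("%Y-%m-%d %H:%M:%S", gmtime()), user_choice, left, right])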