quantumiracle-git commited on
Commit
1237651
·
1 Parent(s): 3e88bff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -2
app.py CHANGED
@@ -49,6 +49,7 @@ else: # local data
49
  VIDEO_INFO = os.path.join(VIDEO_PATH, 'video_info.json')
50
 
51
  def inference(video_path):
 
52
  with open(video_path, "rb") as f:
53
  data = f.read()
54
  b64 = base64.b64encode(data).decode()
@@ -137,20 +138,49 @@ def update(user_choice, left, right, choose_env, data_folder=VIDEO_PATH, flag_to
137
  def replay(left, right):
138
  return left, right
139
 
140
- def parse_envs(folder=VIDEO_PATH):
141
  """
142
  return a dict of env_name: video_paths
143
  """
144
  files = {}
 
 
 
145
  for env_name in os.listdir(folder):
146
  env_path = os.path.join(folder, env_name)
147
  if os.path.isdir(env_path):
148
  videos = os.listdir(env_path)
149
  video_files = []
150
- for video in videos:
151
  if video.endswith(f'.{FORMAT}'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  video_path = os.path.join(folder, env_name, video)
153
  video_files.append(video_path)
 
 
154
  files[env_name] = video_files
155
 
156
  with open(VIDEO_INFO, 'w') as fp:
 
49
  VIDEO_INFO = os.path.join(VIDEO_PATH, 'video_info.json')
50
 
51
  def inference(video_path):
52
+ # for displaying mp4 with autoplay on Gradio
53
  with open(video_path, "rb") as f:
54
  data = f.read()
55
  b64 = base64.b64encode(data).decode()
 
138
  def replay(left, right):
139
  return left, right
140
 
141
+ def parse_envs(folder=VIDEO_PATH, filter=True, MAX_ITER=20000, DEFAULT_ITER=20000):
142
  """
143
  return a dict of env_name: video_paths
144
  """
145
  files = {}
146
+ if filter:
147
+ df = pd.read_csv('Bidexhands_Video.csv')
148
+ # print(df)
149
  for env_name in os.listdir(folder):
150
  env_path = os.path.join(folder, env_name)
151
  if os.path.isdir(env_path):
152
  videos = os.listdir(env_path)
153
  video_files = []
154
+ for video in videos: # video name rule: EnvName_Alg_Seed_Timestamp_Checkpoint_video-episode_EpisodeID
155
  if video.endswith(f'.{FORMAT}'):
156
+ if filter:
157
+ seed = video.split('_')[2]
158
+ checkpoint = video.split('_')[4]
159
+ try:
160
+ succeed_iteration = df.loc[(df['seed'] == int(seed)) & (df['env_name'] == str(env_name))]['succeed_iteration'].iloc[0]
161
+ except:
162
+ print(f'Env {env_name} with seed {seed} not found in Bidexhands_Video.csv')
163
+
164
+ if 'unsolved' in succeed_iteration:
165
+ continue
166
+ elif pd.isnull(succeed_iteration):
167
+ min_iter = DEFAULT_ITER
168
+ max_iter = MAX_ITER
169
+ elif '-' in succeed_iteration:
170
+ [min_iter, max_iter] = succeed_iteration.split('-')
171
+ else:
172
+ min_iter = succeed_iteration
173
+ max_iter = MAX_ITER
174
+
175
+ # check if the checkpoint is in the valid range
176
+ valid_checkpoints = np.arange(int(min_iter), int(max_iter)+1000, 1000)
177
+ if int(checkpoint) not in valid_checkpoints:
178
+ continue
179
+
180
  video_path = os.path.join(folder, env_name, video)
181
  video_files.append(video_path)
182
+ # print(video_path)
183
+
184
  files[env_name] = video_files
185
 
186
  with open(VIDEO_INFO, 'w') as fp: