OpenBiDexHand / app.py
quantumiracle-git's picture
Update app.py
808ed62
raw
history blame
5.56 kB
import gradio as gr
import os
import random
import numpy as np
import gdown
from time import gmtime, strftime
from csv import writer
from datasets import load_dataset
from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver
# download data from huggingface dataset
# dataset = load_dataset("quantumiracle-git/robotinder-data")
# download data from google drive
# url = 'https://drive.google.com/drive/folders/10UmNM2YpvNSkdLMgYiIAxk5IbS4dUezw?usp=sharing'
# output = './'
# id = url.split('/')[-1]
# os.system(f"gdown --id {id} -O {output} --folder --no-cookies")
def video_identity(video):
    """Return *video* exactly as received (pass-through)."""
    result = video
    return result
def nan():
    """Return ``None`` (note: despite the name, not a float NaN)."""
    placeholder = None
    return placeholder
# File extension of the comparison clips; the other supported value is 'mp4'.
FORMAT = 'gif'
# Default root folder that `update` searches for per-environment clip folders.
VIDEO_PATH = 'robotinder-data'
def get_huggingface_dataset():
    """Clone (or reuse) the crowdsourcing dataset repo and locate its CSV log.

    Creates the public dataset repo ``crowdsourced-robotinder-demo`` on the
    Hugging Face Hub if it does not exist yet (``exist_ok=True``), clones it
    into ``crowdsourced-robotinder-demo/flag/`` and pulls the latest LFS data.

    The access token is read from the ``HF_TOKEN`` environment variable
    rather than being hard-coded in source. When the variable is unset,
    ``None`` is passed to ``create_repo`` and ``True`` to ``Repository`` so
    that ``huggingface_hub`` uses its locally saved token.

    Returns:
        tuple: ``(repo, log_file)``, where ``repo`` is the
        ``huggingface_hub.Repository`` handle and ``log_file`` is the path of
        ``flag_data.csv`` inside the local clone.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
    """
    try:
        import huggingface_hub
    except ImportError as err:  # also covers ModuleNotFoundError (a subclass)
        raise ImportError(
            "Package `huggingface_hub` not found is needed "
            "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
        ) from err
    # Never commit credentials; the token previously hard-coded here should
    # be considered leaked and revoked.
    hf_token = os.getenv('HF_TOKEN') or None
    dataset_name = 'crowdsourced-robotinder-demo'
    flagging_dir = 'flag/'
    path_to_dataset_repo = huggingface_hub.create_repo(
        repo_id=dataset_name,
        token=hf_token,
        private=False,
        repo_type="dataset",
        exist_ok=True,
    )
    dataset_dir = os.path.join(dataset_name, flagging_dir)
    repo = huggingface_hub.Repository(
        local_dir=dataset_dir,
        clone_from=path_to_dataset_repo,
        # True explicitly asks huggingface_hub to use its saved token.
        use_auth_token=hf_token if hf_token is not None else True,
    )
    repo.git_pull(lfs=True)
    log_file = os.path.join(dataset_dir, "flag_data.csv")
    return repo, log_file
def update(user_choice, data_folder=VIDEO_PATH):
    """Log the user's choice and sample two new clips to compare.

    Picks a random environment subfolder of *data_folder* and samples two
    distinct ``.{FORMAT}`` clips from it. It then appends the row
    ``[env_name, user_choice, left, right, timestamp]`` to the CSV log in the
    Hugging Face dataset repo and pushes that repo.

    Args:
        user_choice: value of the "Which one is your favorite?" radio
            (``"Left"``, ``"Right"``, ``"Not Sure"`` or ``None``).
        data_folder: root folder containing one subfolder per environment.

    Returns:
        tuple: ``(left, right)`` paths of the newly selected clips.

    Raises:
        ValueError: if *data_folder* has no environment subfolders, or the
            chosen environment holds fewer than two matching clips.
    """
    # Discover environments in the same folder the clips are read from, so the
    # chosen env directory is guaranteed to exist there. Previously this
    # scanned './videos' while listing clips from *data_folder*.
    envs = parse_envs(data_folder)
    if not envs:
        raise ValueError(f"No environment folders found in {data_folder!r}")
    env_name = random.choice(envs)
    env_dir = os.path.join(data_folder, env_name)
    video_files = [
        os.path.join(env_dir, name)
        for name in os.listdir(env_dir)
        if name.endswith(f'.{FORMAT}')
    ]
    if len(video_files) < 2:
        raise ValueError(
            f"Need at least two .{FORMAT} clips in {env_dir!r}, "
            f"found {len(video_files)}"
        )
    # Two distinct clips from the same environment.
    left_idx, right_idx = np.random.choice(len(video_files), 2, replace=False)
    left = video_files[left_idx]
    right = video_files[right_idx]
    current_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    # NOTE(review): user_choice refers to the clips shown *before* this click,
    # but left/right here are the newly sampled clips. Confirm this pairing
    # is intended for the log.
    info = [env_name, user_choice, left, right, current_time]
    print(info)
    repo, log_file = get_huggingface_dataset()  # flag without using gradio flagging
    # newline='' is required by the csv module to avoid spurious blank rows.
    with open(log_file, 'a', newline='') as fh:  # incremental append
        writer(fh).writerow(info)
    repo.push_to_hub(commit_message=f"Flagged sample at {current_time}")
    return left, right
def replay(left, right):
    """Hand both clips straight back; wired to "Replay" so they re-render."""
    pair = (left, right)
    return pair
def parse_envs(folder='./videos'):
    """Return the names of the immediate subdirectories of *folder*.

    Each subdirectory is treated as one environment. Plain files are
    skipped, and the order is whatever ``os.listdir`` yields.
    """
    return [
        entry
        for entry in os.listdir(folder)
        if os.path.isdir(os.path.join(folder, entry))
    ]
def build_interface(iter=3, data_folder='./videos'):
    """Build the RoboTinder gradio Blocks app (not launched).

    The app shows two clips side by side, a "Replay" button, a
    Left/Right/Not Sure radio and a "Next" button. "Next" calls ``update`` to
    sample a new pair. It also flags ``[user_choice, left, right]`` through a
    ``HuggingFaceDatasetSaver`` targeting ``crowdsourced-robotinder-demo``.

    Args:
        iter: unused; kept for backward compatibility.
        data_folder: unused; kept for backward compatibility. The initial
            clips are always loaded from ``videos/`` next to this file.

    Returns:
        The ``gr.Blocks`` demo.
    """
    # Read the token from the environment only. Never hard-code or print it;
    # the token previously committed here should be revoked.
    hf_token = os.getenv('HF_TOKEN')
    if not hf_token:
        # NOTE(review): the saver then receives None/''. Confirm that
        # hfserver.HuggingFaceDatasetSaver falls back to cached credentials.
        print("Warning: HF_TOKEN is not set; dataset flagging may fail to authenticate.")
    # HuggingFace logger instead of gradio's local CSVLogger:
    # https://github.com/gradio-app/gradio/blob/master/gradio/flagging.py
    callback = HuggingFaceDatasetSaver(hf_token, "crowdsourced-robotinder-demo")
    base_dir = os.path.dirname(__file__)
    # build gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("Here is RoboTinder!")
        gr.Markdown("Select the best robot behaviour in your choice!")
        with gr.Row():
            # some initial videos
            if FORMAT == 'mp4':
                left_video_path = os.path.join(base_dir, "videos/rl-video-episode-0.mp4")
                right_video_path = os.path.join(base_dir, "videos/rl-video-episode-1.mp4")
                left = gr.PlayableVideo(left_video_path, label="left_video")
                right = gr.PlayableVideo(right_video_path, label="right_video")
            else:
                left_video_path = os.path.join(base_dir, "videos/rl-video-episode-0.gif")
                right_video_path = os.path.join(base_dir, "videos/rl-video-episode-1.gif")
                left = gr.Image(left_video_path, shape=(1024, 768), label="left_video")
                right = gr.Image(right_video_path, label="right_video")
        btn1 = gr.Button("Replay")
        user_choice = gr.Radio(["Left", "Right", "Not Sure"], label="Which one is your favorite?")
        btn2 = gr.Button("Next")
        # This needs to be called at some point prior to the first call to callback.flag()
        callback.setup([user_choice, left, right], "flagged_data_points")
        btn1.click(fn=replay, inputs=[left, right], outputs=[left, right])
        btn2.click(fn=update, inputs=[user_choice], outputs=[left, right])
        # Flag all three components on "Next".
        # NOTE(review): `update` also appends its own CSV row and pushes the
        # dataset repo, so each click is logged twice. Confirm both are wanted.
        btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False)
    return demo
if __name__ == "__main__":
    # Build the app and serve it locally only. Pass share=True to launch() to
    # additionally request a public gradio share link.
    demo = build_interface()
    demo.launch(share=False)