CricAnnotate / app.py
ashu1069's picture
updates
f764e8a
raw
history blame
10.6 kB
import gradio as gr
import os
import json
from huggingface_hub import hf_hub_download, list_repo_files, upload_file, HfApi
from datasets import load_dataset, Dataset
import logging
import tempfile
# Root logging setup: INFO and above, prefixed with timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
# Hugging Face Hub location of the dataset that holds the videos and their
# annotations.
HF_REPO_TYPE = "dataset"
HF_REPO_ID = "srrthk/CricBench"

# Annotation schema: category label -> list of allowed choices.
# The UI builds one radio group per category, in this insertion order.
ANNOTATION_CATEGORIES = {
    "Bowler's Run Up": ["Fast", "Slow"],
    "Delivery Type": [
        "Yorker",
        "Bouncer",
        "Length Ball",
        "Slower ball",
        "Googly",
        "Arm Ball",
        "Other",
    ],
    "Ball's trajectory": ["In Swing", "Out Swing", "Off spin", "Leg spin"],
    "Shot Played": [
        "Cover Drive",
        "Straight Drive",
        "On Drive",
        "Pull",
        "Square Cut",
        "Defensive Block",
    ],
    "Outcome of the shot": [
        "Four (4)",
        "Six (6)",
        "Wicket",
        "Single (1)",
        "Double (2)",
        "Triple (3)",
        "Dot (0)",
    ],
    "Shot direction": [
        "Long On",
        "Long Off",
        "Cover",
        "Point",
        "Midwicket",
        "Square Leg",
        "Third Man",
        "Fine Leg",
    ],
    "Fielder's Action": [
        "Catch taken",
        "Catch dropped",
        "Misfield",
        "Run-out attempt",
        "Fielder fields",
    ],
}
class VideoAnnotator:
    """Stateful helper behind the annotation UI.

    Holds the dataset loaded from ``HF_REPO_ID``, the list of video file
    names found in its first split, and the index of the video currently
    being annotated. Annotations are stored as one JSON string per row in
    the dataset's ``annotations`` column.
    """

    def __init__(self):
        self.video_files = []
        self.current_video_idx = 0
        self.annotations = {}
        # Taken from the environment; when absent, saves are kept locally
        # and nothing is pushed to the Hub.
        self.hf_token = os.environ.get("HF_TOKEN")
        self.dataset = None
        # Populated by load_videos_from_hf(); declared here so the attribute
        # always exists, even when loading fails.
        self.dataset_split = None

    def load_videos_from_hf(self):
        """Load the dataset and collect the video file name of every row.

        Returns:
            True when at least one video was found, False when the dataset
            is empty or could not be loaded (the error is logged).
        """
        try:
            logger.info(f"Loading dataset from HuggingFace: {HF_REPO_ID}")
            self.dataset = load_dataset(HF_REPO_ID, token=self.hf_token)
            # Get the split (usually 'train')
            split = list(self.dataset.keys())[0]
            self.dataset_split = self.dataset[split]
            # Each row names its file under 'video', or 'path' as a fallback.
            self.video_files = [
                item['video'] if 'video' in item else item['path']
                for item in self.dataset_split
            ]
            logger.info(f"Found {len(self.video_files)} video files")
            return len(self.video_files) > 0
        except Exception as e:
            logger.error(f"Error accessing HuggingFace dataset: {e}")
            return False

    def get_current_video(self):
        """Download the current video and return its local path.

        Returns:
            The local file path, or None when no videos are loaded or the
            download fails (the error is logged).
        """
        if not self.video_files:
            logger.warning("No video files available")
            return None
        video_path = self.video_files[self.current_video_idx]
        logger.info(f"Loading video: {video_path}")
        try:
            local_path = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=video_path,
                repo_type=HF_REPO_TYPE,
                # Same token as load_dataset/push_to_hub, so private or
                # gated repositories can be downloaded too.
                token=self.hf_token,
            )
            logger.info(f"Video downloaded to: {local_path}")
            return local_path
        except Exception as e:
            logger.error(f"Error downloading video: {e}")
            return None

    def save_annotation(self, annotations_dict):
        """Store ``annotations_dict`` as JSON on the current video's row.

        The first split is converted to pandas, the ``annotations`` cell of
        the current row is overwritten, and the result is converted back to
        a ``Dataset``. With an HF token the whole split is pushed to the
        Hub; without one only the in-memory copy is updated.

        Args:
            annotations_dict: Mapping of category label to selected choice.

        Returns:
            A human-readable status message. Errors are reported in the
            message rather than raised.
        """
        if not annotations_dict:
            logger.warning("No annotations to save")
            return "No annotations to save"
        if not self.video_files:
            # Without this guard the index lookup below raises IndexError.
            logger.error("No video files available, cannot save annotations")
            return "Error: No videos loaded"
        video_name = os.path.basename(self.video_files[self.current_video_idx])
        logger.info(f"Saving annotations for {video_name}")
        if self.dataset is None:
            logger.error("Dataset not loaded, cannot save annotations")
            return "Error: Dataset not loaded"
        try:
            # Get the split name (e.g., 'train')
            split = list(self.dataset.keys())[0]
            # Work on a pandas copy of the split
            updated_dataset = self.dataset[split].to_pandas()
            # Annotations are stored as a JSON string
            annotation_json = json.dumps(annotations_dict)
            updated_dataset.loc[self.current_video_idx, 'annotations'] = annotation_json
            # Convert back to Hugging Face dataset
            new_dataset = Dataset.from_pandas(updated_dataset)
            if self.hf_token:
                logger.info(f"Uploading updated dataset to Hugging Face: {HF_REPO_ID}")
                new_dataset.push_to_hub(
                    HF_REPO_ID,
                    split=split,
                    token=self.hf_token,
                )
                # Update our local copy only after a successful upload
                self.dataset[split] = new_dataset
                return f"Annotations saved for {video_name} and uploaded to Hugging Face dataset"
            logger.warning("HF_TOKEN not found. Dataset updated locally only.")
            self.dataset[split] = new_dataset
            return f"Annotations saved locally for {video_name} (no HF upload)"
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            return f"Error saving: {str(e)}"

    def load_existing_annotation(self):
        """Try to load existing annotation for the current video from the dataset.

        Returns:
            The decoded JSON annotation, or None when there is no dataset,
            no videos, no stored annotation, or decoding fails.
        """
        if not self.dataset or not self.video_files:
            return None
        try:
            # Get the split name (e.g., 'train')
            split = list(self.dataset.keys())[0]
            # Fetch the row once rather than once per access.
            row = self.dataset[split][self.current_video_idx]
            if 'annotations' in row:
                annotation_str = row['annotations']
                if annotation_str:
                    return json.loads(annotation_str)
            return None
        except Exception as e:
            logger.error(f"Error loading existing annotation: {e}")
            return None

    @staticmethod
    def _selected_annotations(values):
        """Pair radio values with their categories, dropping empty ones."""
        # zip() stops at the shorter input, so a short value list is safe.
        return {
            category: value
            for category, value in zip(ANNOTATION_CATEGORIES, values)
            if value
        }

    def _stored_selections(self):
        """Return one radio value per category for the current video."""
        existing = self.load_existing_annotation()
        if not isinstance(existing, dict):
            existing = {}
        selections = []
        for category, options in ANNOTATION_CATEGORIES.items():
            stored = existing.get(category)
            # Only hand back values that are valid choices for this radio.
            selections.append(stored if stored in options else None)
        return selections

    def _step(self, offset, current_annotations):
        """Save the current selections, move by ``offset``, return new UI state."""
        if self.video_files:
            annotations_dict = self._selected_annotations(current_annotations)
            if annotations_dict:
                self.save_annotation(annotations_dict)
        target = self.current_video_idx + offset
        if 0 <= target < len(self.video_files):
            self.current_video_idx = target
            direction = "next" if offset > 0 else "previous"
            logger.info(f"Moving to {direction} video (index: {self.current_video_idx})")
        else:
            edge = "last" if offset > 0 else "first"
            logger.info(f"Already at the {edge} video")
        # Show whatever is stored for the video now on screen instead of
        # blanking every radio group.
        return (self.get_current_video(), *self._stored_selections())

    def next_video(self, *current_annotations):
        """Save the current selections and advance to the next video.

        Args:
            *current_annotations: One selected value (or None) per category,
                in ``ANNOTATION_CATEGORIES`` order.

        Returns:
            A tuple of the local video path followed by one radio value per
            category, taken from any stored annotation for that video.
        """
        return self._step(1, current_annotations)

    def prev_video(self, *current_annotations):
        """Save the current selections and go back to the previous video.

        Args:
            *current_annotations: One selected value (or None) per category,
                in ``ANNOTATION_CATEGORIES`` order.

        Returns:
            A tuple of the local video path followed by one radio value per
            category, taken from any stored annotation for that video.
        """
        return self._step(-1, current_annotations)
def create_interface():
    """Build the Gradio Blocks UI for annotating cricket videos.

    Creates a ``VideoAnnotator``, loads the video list, and lays out a video
    player, one radio group per entry of ``ANNOTATION_CATEGORIES`` (first
    four in the left column, the rest in the right), navigation/save
    buttons, and a status box.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    annotator = VideoAnnotator()
    success = annotator.load_videos_from_hf()
    if not success:
        logger.error("Failed to load videos. Using demo mode with sample video.")
        # In real app, you might want to provide a sample video or show an error

    # Resolve the initial state before building components so it can be
    # passed through each constructor's `value=` argument.
    current_video = annotator.get_current_video()
    existing_annotations = annotator.load_existing_annotation()
    if not isinstance(existing_annotations, dict):
        existing_annotations = {}

    def build_radio(category, options):
        # Ignore stored values that are not valid choices for this radio.
        stored = existing_annotations.get(category)
        return gr.Radio(
            choices=options,
            label=category,
            info=f"Select {category}",
            value=stored if stored in options else None,
        )

    def save_current(*selected):
        # Event handlers receive one positional value per input component;
        # save_annotation expects a {category: value} dict.
        annotations_dict = {
            category: value
            for category, value in zip(ANNOTATION_CATEGORIES, selected)
            if value
        }
        return annotator.save_annotation(annotations_dict)

    categories = list(ANNOTATION_CATEGORIES.items())

    with gr.Blocks() as demo:
        gr.Markdown("# Cricket Video Annotation Tool")
        with gr.Row():
            video_player = gr.Video(label="Current Video", value=current_video)
        # Order must match ANNOTATION_CATEGORIES: the handlers pair values
        # with categories by position.
        annotation_components = []
        with gr.Row():
            with gr.Column():
                for category, options in categories[:4]:
                    annotation_components.append(build_radio(category, options))
            with gr.Column():
                for category, options in categories[4:]:
                    annotation_components.append(build_radio(category, options))
        with gr.Row():
            prev_btn = gr.Button("Previous Video")
            save_btn = gr.Button("Save Annotations", variant="primary")
            next_btn = gr.Button("Next Video")
        status_box = gr.Textbox(label="Status")

        # Event handlers
        save_btn.click(
            fn=save_current,
            inputs=annotation_components,
            outputs=status_box,
        )
        next_btn.click(
            fn=annotator.next_video,
            inputs=annotation_components,
            outputs=[video_player] + annotation_components,
        )
        prev_btn.click(
            fn=annotator.prev_video,
            inputs=annotation_components,
            outputs=[video_player] + annotation_components,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, then start the Gradio server.
    demo = create_interface()
    demo.launch()
# TODO: add a local sample video for testing when no videos can be loaded from Hugging Face.