# CricAnnotate / app.py
# Author: ashu1069 — commit 624e18c ("updates"), 16 kB
import gradio as gr
import os
import json
from huggingface_hub import hf_hub_download, list_repo_files, upload_file, HfApi
from datasets import load_dataset, Dataset
import logging
import tempfile
# Log every message as "<time> - <LEVEL> - <text>" at INFO level and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Label vocabulary for the annotation form. Key order matters: the UI builds one
# radio group per key in this order, and VideoAnnotator.save_annotation zips the
# submitted values back onto these keys positionally.
ANNOTATION_CATEGORIES = {
    "Bowler's Run Up": [
        "Fast",
        "Spin",
    ],
    "Delivery Type": [
        "Yorker",
        "Bouncer",
        "Length Ball",
        "Slower ball",
        "Googly",
        "Arm Ball",
        "Other",
    ],
    "Ball's trajectory": [
        "In Swing",
        "Out Swing",
        "Off spin",
        "Leg spin",
        "Straight",
    ],
    "Shot Played": [
        "Cover Drive",
        "Straight Drive",
        "On Drive",
        "Pull",
        "Square Cut",
        "Defensive Block",
    ],
    "Outcome of the shot": [
        "Four (4)",
        "Six (6)",
        "Wicket",
        "Single (1)",
        "Double (2)",
        "Triple (3)",
        "Dot (0)",
    ],
    "Shot direction": [
        "Long On",
        "Long Off",
        "Cover",
        "Point",
        "Midwicket",
        "Square Leg",
        "Third Man",
        "Fine Leg",
    ],
    "Fielder's Action": [
        "Catch taken",
        "Catch dropped",
        "Misfield",
        "Run-out attempt",
        "Fielder fields",
    ],
}

# Source dataset on the Hugging Face Hub that holds the videos to annotate.
HF_REPO_ID = "cricverse/CricBench"
HF_REPO_TYPE = "dataset"
class VideoAnnotator:
    """State and Hugging Face I/O behind the cricket video annotation UI.

    Tracks which source videos still need annotating and which one is on
    screen. Downloads videos from ``HF_REPO_ID`` and uploads one JSONL file
    per annotated video to ``annotation_repo_id`` under ``annotations/``.

    NOTE(review): create_interface builds a single instance and binds its
    methods to the UI events, so every browser session appears to share
    ``current_video_idx``. Concurrent annotators would move each other's
    position; confirm that this is acceptable.
    """

    def __init__(self):
        """Read tokens from the environment and create the upload client.

        Raises:
            ValueError: if the ``upload_token`` environment variable is unset
                or empty, since annotations cannot be uploaded without it.
        """
        self.video_files = []  # source-repo paths still queued for annotation
        self.current_video_idx = 0  # index into self.video_files
        self.annotations = {}
        self.hf_token = os.environ.get("HF_TOKEN")  # token for the source dataset; may be None
        self.dataset = None
        self.dataset_split = None
        # Maps each source video path to its row in dataset_split. This is needed
        # because video_files is filtered, so its indices are not dataset rows.
        self.video_row_indices = {}
        self.annotation_repo_id = "cricverse/CricBench_Annotations"
        self.annotation_repo_type = "dataset"
        self.annotation_hf_token = os.environ.get("upload_token")
        if not self.annotation_hf_token:
            raise ValueError(
                "Annotation upload token not found: set the 'upload_token' environment variable"
            )
        self.api = HfApi(token=self.annotation_hf_token)

    @staticmethod
    def _annotated_basenames(repo_files):
        """Return the video basenames that already have an ``annotations/<name>.jsonl`` file.

        For example, ``"annotations/video1.mp4.jsonl"`` yields ``"video1.mp4"``.
        Only the trailing ``.jsonl`` is stripped, so a name that contains
        ``.jsonl`` elsewhere is kept intact.
        """
        suffix = ".jsonl"
        return {
            os.path.basename(path)[: -len(suffix)]
            for path in repo_files
            if path.startswith("annotations/") and path.endswith(suffix)
        }

    def _filter_unannotated(self, all_video_files):
        """Return the videos that have no annotation file yet.

        If the annotation repo cannot be listed, the error is logged and every
        video is returned, so already-annotated videos may be shown again.
        """
        logger.info(f"Checking for existing annotations in {self.annotation_repo_id}")
        try:
            # list_repo_files takes no path filter, so list the whole repo and
            # keep only the annotations/*.jsonl entries.
            repo_files = self.api.list_repo_files(
                repo_id=self.annotation_repo_id,
                repo_type=self.annotation_repo_type,
            )
            annotated = self._annotated_basenames(repo_files)
            logger.info(f"Found {len(annotated)} existing annotation files.")
            remaining = [
                vf for vf in all_video_files
                if os.path.basename(vf) not in annotated
            ]
            logger.info(f"Filtered list: {len(remaining)} videos remaining to be annotated.")
            return remaining
        except Exception as e:
            logger.error(f"Could not list or process annotation files: {e}. Proceeding with all videos, but conflicts may occur.")
            return all_video_files  # Fallback: load all if check fails

    def load_videos_from_hf(self):
        """Load the source dataset and queue every video that is not yet annotated.

        Returns:
            True if at least one video is queued. False if none are queued, or
            if the source dataset could not be loaded (the error is logged).
        """
        try:
            logger.info(f"Loading dataset from HuggingFace: {HF_REPO_ID}")
            self.dataset = load_dataset(HF_REPO_ID, token=self.hf_token)
            # Use the first split (usually 'train').
            split = list(self.dataset.keys())[0]
            self.dataset_split = self.dataset[split]
            all_video_files = [
                item['video'] if 'video' in item else item['path']
                for item in self.dataset_split
            ]
            self.video_row_indices = {
                path: row for row, path in enumerate(all_video_files)
            }
            logger.info(f"Found {len(all_video_files)} potential video files in source dataset.")
            self.video_files = self._filter_unannotated(all_video_files)
            # A reload may shrink the queue; start again at its first video.
            self.current_video_idx = 0
            if not self.video_files:
                logger.warning("No videos left to annotate!")
            return len(self.video_files) > 0
        except Exception as e:
            logger.error(f"Error accessing HuggingFace dataset: {e}")
            return False

    def get_current_video(self):
        """Download the current video and return its local path.

        Returns:
            The path returned by ``hf_hub_download``. Returns None when no
            videos are queued, or when the download fails (the error is logged).
        """
        if not self.video_files:
            logger.warning("No video files available")
            return None
        video_path = self.video_files[self.current_video_idx]
        logger.info(f"Loading video: {video_path}")
        try:
            # Pass the same token as load_dataset explicitly, for consistency.
            local_path = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=video_path,
                repo_type=HF_REPO_TYPE,
                token=self.hf_token,
            )
        except Exception as e:
            logger.error(f"Error downloading video: {e}")
            return None
        logger.info(f"Video downloaded to: {local_path}")
        return local_path

    def save_annotation(self, *annotations):
        """Upload the selected labels for the current video as a one-line JSONL file.

        The file is written to ``annotations/<video basename>.jsonl`` in the
        annotation repo, overwriting any earlier upload for the same video.

        Args:
            *annotations: one value per ANNOTATION_CATEGORIES key, in key
                order. ``None`` means the category was not selected.

        Returns:
            A human-readable status string for the UI. Upload errors are
            caught and reported in this string rather than raised.
        """
        # Pair each positional value with its category and drop unselected ones.
        annotations_dict = {
            category: value
            for category, value in zip(ANNOTATION_CATEGORIES.keys(), annotations)
            if value is not None
        }
        if not annotations_dict:
            logger.warning("No annotations to save")
            return "No annotations to save"
        if not self.video_files:
            # Without this guard, the indexing below raises IndexError.
            logger.warning("No video loaded; nothing to save")
            return "No video loaded; nothing to save"

        video_name = os.path.basename(self.video_files[self.current_video_idx])
        logger.info(f"Saving annotations for {video_name}: {annotations_dict}")
        try:
            annotation_entry = {
                "video_id": video_name,
                "annotations": annotations_dict,
            }
            jsonl_content = json.dumps(annotation_entry) + "\n"
            logger.info(f"Uploading annotations to Hugging Face: {self.annotation_repo_id}")
            # upload_file accepts raw bytes, so no temporary file is needed
            # and nothing is left behind on disk.
            self.api.upload_file(
                path_or_fileobj=jsonl_content.encode('utf-8'),
                path_in_repo=f"annotations/{video_name}.jsonl",
                repo_id=self.annotation_repo_id,
                repo_type=self.annotation_repo_type,
            )
            return f"Annotations saved and uploaded for {video_name}"
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            return f"Error saving: {str(e)}"

    def load_existing_annotation(self):
        """Return the annotation stored in the source dataset for the current video.

        Reads the ``annotations`` column of the row that holds the current
        video. A JSON string is decoded; a dict is returned as-is.

        Returns:
            The annotation, or None when there is no dataset, no queued video,
            no matching row, no annotation value, or a decoding error (logged).
        """
        if not self.dataset or not self.video_files:
            return None
        try:
            current_path = self.video_files[self.current_video_idx]
            # current_video_idx indexes the filtered queue, not the dataset,
            # so look up the real row for this video.
            row_idx = self.video_row_indices.get(current_path)
            if row_idx is None:
                return None
            annotation = self.dataset_split[row_idx].get('annotations')
            if not annotation:
                return None
            if isinstance(annotation, dict):
                return annotation
            return json.loads(annotation)
        except Exception as e:
            logger.error(f"Error loading existing annotation: {e}")
            return None

    def _save_before_navigation(self, direction, current_annotations):
        """Save any selected labels before leaving the current video.

        Args:
            direction: "next" or "previous"; used only in log messages.
            current_annotations: radio values in ANNOTATION_CATEGORIES order.

        Returns:
            The status string to show in the UI.
        """
        if not self.video_files:
            return "No annotations provided to save."
        if any(ann is not None for ann in current_annotations):
            save_status = self.save_annotation(*current_annotations)
            logger.info(f"Save status before moving {direction}: {save_status}")
            return save_status
        logger.info(f"No annotations selected, skipping save before moving {direction}.")
        return "Skipped saving - no annotations selected."

    def _navigation_outputs(self, status):
        """Build the Gradio outputs: video path, one None per radio (clears it), status."""
        return (self.get_current_video(), *[None] * len(ANNOTATION_CATEGORIES), status)

    def next_video(self, *current_annotations):
        """Save the current selections, then advance one video (clamped at the last)."""
        save_status = self._save_before_navigation("next", current_annotations)
        if self.current_video_idx < len(self.video_files) - 1:
            self.current_video_idx += 1
            logger.info(f"Moving to next video (index: {self.current_video_idx})")
            return self._navigation_outputs(save_status)
        logger.info("Already at the last video")
        return self._navigation_outputs("Already at the last video. " + save_status)

    def prev_video(self, *current_annotations):
        """Save the current selections, then go back one video (clamped at the first)."""
        save_status = self._save_before_navigation("previous", current_annotations)
        if self.current_video_idx > 0:
            self.current_video_idx -= 1
            logger.info(f"Moving to previous video (index: {self.current_video_idx})")
            return self._navigation_outputs(save_status)
        logger.info("Already at the first video")
        return self._navigation_outputs("Already at the first video. " + save_status)
def create_interface():
    """Build the Gradio Blocks UI for annotating cricket videos.

    Creates one VideoAnnotator, which raises ValueError when the
    ``upload_token`` environment variable is missing, and loads the queue of
    videos to annotate. Then lays out the video player, navigation buttons,
    one radio group per ANNOTATION_CATEGORIES key, a progress line, and a save
    button, and wires them to the annotator.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    annotator = VideoAnnotator()
    success = annotator.load_videos_from_hf()
    if not success:
        logger.error("Failed to load videos. Interface might not function correctly.")

    total_categories = len(ANNOTATION_CATEGORIES)

    def update_progress(*annotation_values):
        """Return a Markdown line counting how many categories have a selection."""
        filled_count = sum(1 for val in annotation_values if val is not None)
        return f"**Progress:** {filled_count} / {total_categories} categories selected"

    # Show a failed load in the UI itself, not only in the server log.
    initial_status = None if success else (
        "No videos loaded: either every video is already annotated or the "
        "dataset could not be accessed (see server logs)."
    )

    with gr.Blocks() as demo:
        gr.Markdown("# Cricket Video Annotation Tool")
        with gr.Row():  # Video + navigation on the left, annotation form on the right
            with gr.Column(scale=2):
                # Pass the bound method, not its result: Gradio calls it to get
                # the initial value, including on each page load.
                video_player = gr.Video(
                    value=annotator.get_current_video,
                    label="Current Video",
                    height=350,
                )
                # Save / navigation messages.
                status_display = gr.Textbox(
                    label="Status",
                    interactive=False,
                    value=initial_status,
                )
                with gr.Row():
                    prev_btn = gr.Button("Previous Video")
                    next_btn = gr.Button("Next Video")
            with gr.Column(scale=2):
                gr.Markdown("### Annotations")
                # One radio group per category, in ANNOTATION_CATEGORIES order.
                # The handlers receive these values positionally in that order.
                annotation_components = [
                    gr.Radio(choices=options, label=category)
                    for category, options in ANNOTATION_CATEGORIES.items()
                ]
                progress_display = gr.Markdown(
                    value=update_progress(*[None] * total_categories)
                )
                # Any change to a radio, whether a user click or the reset done
                # by next/prev, recomputes the progress line from all radios.
                for radio in annotation_components:
                    radio.change(
                        fn=update_progress,
                        inputs=annotation_components,
                        outputs=progress_display,
                    )
                save_btn = gr.Button("Save Annotations", variant="primary")

        save_btn.click(
            fn=annotator.save_annotation,
            inputs=annotation_components,
            outputs=status_display,
        )
        # Navigation returns: new video, None for every radio (clears it), status.
        navigation_outputs = [video_player] + annotation_components + [status_display]
        next_btn.click(
            fn=annotator.next_video,
            inputs=annotation_components,
            outputs=navigation_outputs,
        )
        prev_btn.click(
            fn=annotator.prev_video,
            inputs=annotation_components,
            outputs=navigation_outputs,
        )
    return demo
if __name__ == "__main__":
    # Local import: only needed here, to locate the Hub download cache.
    from huggingface_hub.constants import HF_HUB_CACHE

    demo = create_interface()
    # Videos reach the player as paths returned by hf_hub_download, which live
    # in the Hub cache, so only that directory needs to be served. The previous
    # allowed_paths=["/"] would let any visitor request arbitrary server files,
    # potentially including /proc/self/environ, which holds HF_TOKEN and
    # upload_token.
    demo.launch(allowed_paths=[HF_HUB_CACHE])