"""Gradio tool for annotating cricket video clips hosted on the Hugging Face Hub.

Clips are read from the ``HF_REPO_ID`` dataset; every saved annotation is
uploaded as ``annotations/<video basename>.jsonl`` to a separate annotation
dataset repository.
"""

import json
import logging
import os
import random
import tempfile

import gradio as gr
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, upload_file

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Cricket annotation categories: radio-button label -> selectable options.
# The insertion order defines the order of the annotation values passed to
# VideoAnnotator.save_annotation / next_video / prev_video.
ANNOTATION_CATEGORIES = {
    "Bowler's Run Up": ["Fast", "Spin"],
    "Delivery Type": ["Yorker", "Bouncer", "Length Ball", "Slower ball", "Googly", "Arm Ball", "Half volley", "Full Toss", "Other"],
    "Ball's trajectory": ["In Swing", "Out Swing", "Off spin", "Leg spin", "Straight", "Other"],
    "Shot Played": ["Cover Drive", "Straight Drive", "On Drive", "Pull", "Square Cut", "Defensive Block", "Missed", "Slog", "Sweep", "Reverse Sweep", "Upper Cut", "Hook", "Other"],
    "Shot type": ["Grounded", "Airborne"],
    "Outcome of the shot": ["Four (4)", "Six (6)", "Wicket", "Single (1)", "Double (2)", "Triple (3)", "Dot (0)", "Wide", "No Ball", "Other"],
    "Shot direction": ["Long On", "Long Off", "Cover", "Point", "Midwicket", "Square Leg", "Third Man", "Fine Leg", "Straight", "Square", "Other"],
    "Batsman's Action": ["Defensive", "Aggressive", "Neutral"],
    "Fielder's Action": ["Catch taken", "Catch dropped", "Misfield", "Run-out attempt", "Fielder fields", "Cannot field", "Other"],
}

HF_REPO_ID = "cricverse/CricBench"
HF_REPO_TYPE = "dataset"

_ANNOTATION_DIR_PREFIX = "annotations/"
_ANNOTATION_SUFFIX = ".jsonl"


class VideoAnnotator:
    """Tracks the videos still to be annotated and uploads annotations.

    Attributes:
        video_files: Paths (inside ``HF_REPO_ID``) of videos still to annotate.
        current_video_idx: Index into ``video_files`` of the video being shown.
        dataset: The loaded source dataset, or ``None`` before loading.
        dataset_split: First split of ``dataset``, or ``None`` before loading.
    """

    def __init__(self):
        """Read tokens from the environment and create the upload client.

        Raises:
            ValueError: If no annotation upload token is configured.
        """
        self.video_files = []
        self.current_video_idx = 0
        self.annotations = {}
        self.hf_token = os.environ.get("HF_TOKEN")
        self.dataset = None
        self.dataset_split = None
        # Maps a video path to its row index in the *unfiltered* dataset split,
        # because indices into the filtered ``video_files`` list do not match.
        self._dataset_index_by_video = {}
        self.annotation_repo_id = "cricverse/CricBench_Annotations"
        self.annotation_repo_type = "dataset"
        # ``upload_token`` is the variable the app has always read;
        # ``HF_ANNOTATION_TOKEN`` is accepted too because the original error
        # message told operators to set it.
        self.annotation_hf_token = (
            os.environ.get("upload_token") or os.environ.get("HF_ANNOTATION_TOKEN")
        )
        if not self.annotation_hf_token:
            raise ValueError(
                "Annotation upload token not found: set the 'upload_token' "
                "(or 'HF_ANNOTATION_TOKEN') environment variable"
            )
        self.api = HfApi(token=self.annotation_hf_token)

    def load_videos_from_hf(self):
        """Load the source dataset and keep only videos without an annotation.

        Returns:
            bool: True when at least one video is left to annotate, False when
            none are left or the source dataset could not be loaded.
        """
        try:
            logger.info(f"Loading dataset from HuggingFace: {HF_REPO_ID}")
            self.dataset = load_dataset(HF_REPO_ID, token=self.hf_token)
            # Usually the "train" split
            split = list(self.dataset.keys())[0]
            self.dataset_split = self.dataset[split]

            # Collect every video path and remember its row in the split
            all_video_files = []
            index_by_video = {}
            for row_idx, item in enumerate(self.dataset_split):
                video_path = item['video'] if 'video' in item else item['path']
                all_video_files.append(video_path)
                index_by_video.setdefault(video_path, row_idx)
            self._dataset_index_by_video = index_by_video
            logger.info(f"Found {len(all_video_files)} potential video files in source dataset.")

            # Filter out already annotated videos
            logger.info(f"Checking for existing annotations in {self.annotation_repo_id}")
            try:
                annotated_files = self.api.list_repo_files(
                    repo_id=self.annotation_repo_id,
                    repo_type=self.annotation_repo_type,
                )
                # "annotations/video1.mp4.jsonl" -> "video1.mp4"; slice off the
                # suffix only, so a ".jsonl" elsewhere in the name is untouched.
                annotated_video_basenames = {
                    os.path.basename(f)[:-len(_ANNOTATION_SUFFIX)]
                    for f in annotated_files
                    if f.startswith(_ANNOTATION_DIR_PREFIX) and f.endswith(_ANNOTATION_SUFFIX)
                }
                logger.info(f"Found {len(annotated_video_basenames)} existing annotation files.")

                self.video_files = [
                    vf for vf in all_video_files
                    if os.path.basename(vf) not in annotated_video_basenames
                ]
                logger.info(f"Filtered list: {len(self.video_files)} videos remaining to be annotated.")
            except Exception as e:
                # Best effort: fall back to every video if the check fails
                logger.error(
                    f"Could not list or process annotation files: {e}. "
                    "Proceeding with all videos, but conflicts may occur."
                )
                self.video_files = all_video_files

            if not self.video_files:
                logger.warning("No videos left to annotate!")
            return len(self.video_files) > 0
        except Exception as e:
            logger.error(f"Error accessing HuggingFace dataset: {e}")
            return False

    def get_current_video(self):
        """Download the current video and return its local path.

        Returns:
            The local file path, or ``None`` when there is no video or the
            download fails.
        """
        if not self.video_files:
            logger.warning("No video files available")
            return None

        video_path = self.video_files[self.current_video_idx]
        logger.info(f"Loading video: {video_path}")

        try:
            local_path = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=video_path,
                repo_type=HF_REPO_TYPE,
                # Same token as load_dataset, so private/gated repos work too
                token=self.hf_token,
            )
            logger.info(f"Video downloaded to: {local_path}")
            return local_path
        except Exception as e:
            logger.error(f"Error downloading video: {e}")
            return None

    def save_annotation(self, *annotations):
        """Upload the selected annotations for the current video.

        Args:
            *annotations: One value per ``ANNOTATION_CATEGORIES`` entry, in
                order; ``None`` means the category was left unselected.

        Returns:
            str: A human-readable status message (never raises).
        """
        # Only categories the user actually selected are stored
        annotations_dict = {
            category: value
            for category, value in zip(ANNOTATION_CATEGORIES, annotations)
            if value is not None
        }

        if not annotations_dict:
            logger.warning("No annotations to save")
            return "No annotations to save"

        if not self.video_files:
            # Without this guard the index lookup below raised IndexError
            logger.warning("No video files available")
            return "No video available to annotate"

        video_name = os.path.basename(self.video_files[self.current_video_idx])
        logger.info(f"Saving annotations for {video_name}: {annotations_dict}")

        try:
            annotation_entry = {
                "video_id": video_name,
                "annotations": annotations_dict,
            }
            jsonl_content = json.dumps(annotation_entry) + "\n"

            logger.info(f"Uploading annotations to Hugging Face: {self.annotation_repo_id}")
            # Upload the bytes directly: the previous temp file was created
            # with delete=False and never removed, leaking one file per save.
            self.api.upload_file(
                path_or_fileobj=jsonl_content.encode('utf-8'),
                path_in_repo=f"annotations/{video_name}.jsonl",
                repo_id=self.annotation_repo_id,
                repo_type=self.annotation_repo_type,
            )
            return f"Annotations saved and uploaded for {video_name}"
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            return f"Error saving: {str(e)}"

    def load_existing_annotation(self):
        """Return the annotation stored in the source dataset for the current video.

        Returns:
            The JSON-decoded ``annotations`` field of the matching dataset
            row, or ``None`` when there is none or it cannot be read.
        """
        if not self.dataset or not self.video_files:
            return None

        try:
            video_path = self.video_files[self.current_video_idx]
            # video_files is filtered, so translate to the dataset row index
            row_idx = self._dataset_index_by_video.get(video_path)
            if row_idx is None:
                return None

            split = list(self.dataset.keys())[0]
            row = self.dataset[split][row_idx]
            annotation_str = row['annotations'] if 'annotations' in row else None
            if annotation_str:
                return json.loads(annotation_str)
            return None
        except Exception as e:
            logger.error(f"Error loading existing annotation: {e}")
            return None

    def _save_before_navigation(self, direction, annotations):
        """Save the annotations (if any were selected) before moving ``direction``.

        Returns the status message to show to the user.
        """
        save_status = "No annotations provided to save."
        if self.video_files:
            if any(ann is not None for ann in annotations):
                save_status = self.save_annotation(*annotations)
                logger.info(f"Save status before moving {direction}: {save_status}")
            else:
                logger.info(f"No annotations selected, skipping save before moving {direction}.")
                save_status = "Skipped saving - no annotations selected."
        return save_status

    def next_video(self, *current_annotations):
        """Save the current annotations, then jump to a random other video.

        Args:
            *current_annotations: Current value of every annotation radio.

        Returns:
            tuple: ``(video path or None, None for each category, status)``.
        """
        save_status = self._save_before_navigation("next", current_annotations)

        # Pick a random index different from the current one
        available_indices = [
            i for i in range(len(self.video_files)) if i != self.current_video_idx
        ]
        if available_indices:
            self.current_video_idx = random.choice(available_indices)
            logger.info(f"Moving to random video (index: {self.current_video_idx})")
            status = save_status
        else:
            logger.info("No other videos available")
            status = "No other videos available. " + save_status

        return (self.get_current_video(), *[None] * len(ANNOTATION_CATEGORIES), status)

    def prev_video(self, *current_annotations):
        """Save the current annotations, then step back one index in the list.

        Args:
            *current_annotations: Current value of every annotation radio.

        Returns:
            tuple: ``(video path or None, None for each category, status)``.
        """
        save_status = self._save_before_navigation("previous", current_annotations)

        if self.current_video_idx > 0:
            self.current_video_idx -= 1
            logger.info(f"Moving to previous video (index: {self.current_video_idx})")
            status = save_status
        else:
            logger.info("Already at the first video")
            status = "Already at the first video. " + save_status

        return (self.get_current_video(), *[None] * len(ANNOTATION_CATEGORIES), status)


def create_interface():
    """Build and return the Gradio Blocks app for the annotation tool."""

    def get_session_annotator():
        """Create an annotator and load its list of videos to annotate."""
        annotator = VideoAnnotator()
        success = annotator.load_videos_from_hf()
        if not success:
            logger.error("Failed to load videos. Interface might not function correctly.")
        return annotator

    with gr.Blocks() as demo:
        # Session state holding the annotator; replaced by a fresh instance in
        # on_load so the video shown and the annotator that saves always agree.
        annotator_state = gr.State(value=get_session_annotator())

        gr.Markdown("# Cricket Video Annotation Tool")

        total_categories = len(ANNOTATION_CATEGORIES)

        def update_progress(*annotation_values):
            """Return the markdown progress line for the current selections."""
            filled_count = sum(1 for val in annotation_values if val is not None)
            return f"**Progress:** {filled_count} / {total_categories} categories selected"

        with gr.Row():  # Video and controls side by side
            with gr.Column(scale=2):  # Video player and navigation buttons
                video_player = gr.Video(
                    value=None,  # Filled in by on_load
                    label="Current Video",
                    height=350,
                )
                status_display = gr.Textbox(label="Status", interactive=False)

                with gr.Row():
                    prev_btn = gr.Button("Previous Video")
                    next_btn = gr.Button("Next Video")

            with gr.Column(scale=2):  # Annotations and save button
                annotation_components = []
                gr.Markdown("## Annotations")

                # One radio group per category, stacked vertically
                for category, options in ANNOTATION_CATEGORIES.items():
                    radio = gr.Radio(
                        choices=options,
                        label=category,
                    )
                    annotation_components.append(radio)

                progress_display = gr.Markdown(
                    value=update_progress(*[None] * total_categories)
                )

                # Refresh the progress line whenever any radio changes
                for radio in annotation_components:
                    radio.change(
                        fn=update_progress,
                        inputs=annotation_components,
                        outputs=progress_display,
                    )
                save_btn = gr.Button("Save Annotations", variant="primary")

        @demo.load(
            outputs=[annotator_state, video_player] + annotation_components + [status_display]
        )
        def on_load():
            """Create a fresh annotator for this session and show its first video."""
            annotator = get_session_annotator()
            return (
                annotator,                               # session state
                annotator.get_current_video(),           # video player
                *[None] * len(ANNOTATION_CATEGORIES),    # annotation components
                "",                                      # status display
            )

        # Event handlers operate on the session-specific annotator
        save_btn.click(
            fn=lambda annotator, *args: annotator.save_annotation(*args),
            inputs=[annotator_state] + annotation_components,
            outputs=status_display,
        )

        next_btn.click(
            fn=lambda annotator, *args: annotator.next_video(*args),
            inputs=[annotator_state] + annotation_components,
            outputs=[video_player] + annotation_components + [status_display],
        )

        prev_btn.click(
            fn=lambda annotator, *args: annotator.prev_video(*args),
            inputs=[annotator_state] + annotation_components,
            outputs=[video_player] + annotation_components + [status_display],
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.queue()
    # NOTE(review): allowed_paths=["/"] lets Gradio serve any file readable by
    # this process. Consider restricting it to the Hugging Face cache
    # directory that hf_hub_download writes to - confirm the deployment's
    # cache location before changing.
    demo.launch(allowed_paths=["/"])