"""Gradio app for collecting finger-count labels on a hand-image dataset.

Images and prompts are read from ``DATASET_NAME``. Each button press appends
one row (session id, record UUID, prompt, choice) to a CSV file under
``./data``; a ``CommitScheduler`` is configured to upload that folder to
``BACKUP_REPO``.
"""

import csv
import os
import random
import shutil
import threading
import time
import uuid
from collections import OrderedDict
from pathlib import Path

import gradio as gr
import pandas as pd
from datasets import load_dataset
from huggingface_hub import CommitScheduler, HfApi, snapshot_download

api = HfApi(token=os.environ["HF_TOKEN"])

DATASET_NAME = "taesiri/HumanHandsDataset"
BACKUP_REPO = "taesiri/HumanHandsDatasetFingerCounts"

# Message shown in the prompt area once every sample is annotated or reserved.
NO_MORE_IMAGES_MESSAGE = "No more images to label. Thank you!"

# Create data directory
os.makedirs("./data", exist_ok=True)


def _read_csv_or_empty(csv_path):
    """Return the CSV at *csv_path* as a DataFrame.

    A missing file or a file without any parsable columns yields an empty
    DataFrame instead of raising, so callers can merge unconditionally.
    """
    try:
        return pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return pd.DataFrame()


def sync_with_hub():
    """
    Synchronize local data with the hub by downloading latest dataset.

    Steps:
      1. Copy ``./data`` to ``./data_backup`` (replacing any older backup).
      2. Download a snapshot of ``BACKUP_REPO`` into ``hub_data``.
      3. Copy sub-directories of the snapshot's ``data`` folder that do not
         exist locally, and merge ``finger_count_results.csv`` with the local
         copy (local rows first, exact duplicate rows dropped).
      4. Remove ``hub_data`` again, even if the download or merge failed.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    if data_dir.exists():
        # Backup existing data
        backup_dir = Path("./data_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    download_dir = Path("hub_data")
    try:
        # Download latest data from hub
        repo_path = snapshot_download(
            repo_id=BACKUP_REPO, repo_type="dataset", local_dir="hub_data"
        )

        # Merge hub data with local data
        hub_data_dir = Path(repo_path) / "data"
        if hub_data_dir.exists():
            os.makedirs(data_dir, exist_ok=True)
            for item in hub_data_dir.glob("*"):
                if item.is_dir():
                    dest = data_dir / item.name
                    if not dest.exists():
                        shutil.copytree(item, dest)
                elif item.name == "finger_count_results.csv":
                    local_csv_path = data_dir / "finger_count_results.csv"
                    hub_csv = _read_csv_or_empty(item)
                    local_csv = _read_csv_or_empty(local_csv_path)
                    merged_csv = pd.concat(
                        [local_csv, hub_csv], ignore_index=True
                    ).drop_duplicates()
                    # A frame without columns would be written as a blank
                    # file, which would later prevent the header row from
                    # being created; skip the write in that case.
                    if len(merged_csv.columns) > 0:
                        merged_csv.to_csv(local_csv_path, index=False)
    finally:
        # Clean up downloaded repo
        if download_dir.exists():
            shutil.rmtree(download_dir)

    print("Finished syncing with hub!")


# Set up commit scheduler
scheduler = CommitScheduler(
    repo_id=BACKUP_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)

# Sync with hub before starting
sync_with_hub()

# Update RESULT_CSV path to be in data directory
RESULT_CSV = "./data/finger_count_results.csv"

# Load the dataset
ds = load_dataset(DATASET_NAME, split="train")

# Get UUID lookup dataframe for efficient searching
uuid_df = load_dataset(DATASET_NAME, split="train", columns=["uuid"])
uuid_df = pd.DataFrame(uuid_df)

# Map each UUID to the position of its first occurrence so that loading the
# results CSV needs one dictionary lookup per row instead of a column scan.
uuid_to_index = {}
for _position, _record_uuid in enumerate(uuid_df["uuid"]):
    uuid_to_index.setdefault(_record_uuid, _position)

# A thread lock to avoid concurrent writes
write_lock = threading.Lock()

# Set to store annotated sample indices
annotated_samples = set()

# OrderedDict to act as a TTL cache for in-progress samples
# Format: {index: (timestamp, session_id)}
in_progress_samples = OrderedDict()
IN_PROGRESS_TTL = 300  # 5 minutes in seconds
MAX_IN_PROGRESS = 1000  # Maximum number of in-progress samples to track


# Load previously annotated samples from CSV
def load_annotated_samples():
    """Populate ``annotated_samples`` from the rows already in ``RESULT_CSV``.

    The first row is treated as the header. Rows with fewer than two columns
    and UUIDs that are not present in the dataset are ignored. A missing file
    is treated as "nothing annotated yet".
    """
    try:
        with open(RESULT_CSV, "r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip header (tolerates an empty file)
            for row in reader:
                if len(row) < 2:
                    continue  # Blank or truncated line: no UUID column
                idx = uuid_to_index.get(row[1])
                if idx is not None:
                    annotated_samples.add(idx)
    except FileNotFoundError:
        pass


# Prepare the CSV file and load annotated samples
with write_lock:
    try:
        # Mode "x" fails if the file exists, so the header is written only
        # when the file is created here.
        with open(RESULT_CSV, "x", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["session_id", "uuid", "prompt", "choice"])
    except FileExistsError:
        load_annotated_samples()


def cleanup_in_progress():
    """Remove expired in-progress samples.

    Entries are stored in insertion order, so only the oldest entry has to be
    inspected on each step; the loop stops at the first entry that is still
    within ``IN_PROGRESS_TTL`` seconds.
    """
    cutoff = time.time() - IN_PROGRESS_TTL
    while in_progress_samples:
        oldest_timestamp, _ = next(iter(in_progress_samples.values()))
        if oldest_timestamp >= cutoff:
            break
        in_progress_samples.popitem(last=False)


def get_random_sample(session_id):
    """Get a random sample that's neither annotated nor in progress.

    The chosen index is recorded in ``in_progress_samples`` together with the
    current time and *session_id*. Returns the dataset index, or ``None``
    when no sample is available.
    """
    cleanup_in_progress()

    available = [
        candidate
        for candidate in range(len(ds))
        if candidate not in annotated_samples
        and candidate not in in_progress_samples
    ]
    if not available:
        return None

    # Select random index from available ones
    index = random.choice(available)

    # Add to in-progress samples
    if len(in_progress_samples) >= MAX_IN_PROGRESS:
        in_progress_samples.popitem(last=False)  # Remove oldest item
    in_progress_samples[index] = (time.time(), session_id)
    return index


def get_record(index):
    """
    Given an index, return:
      - PIL image
      - prompt text
      - the UUID for the dataset row
    """
    record = ds[index]
    return record["image"], record["prompt"], record["uuid"]


def update_session(choice, session_id, index):
    """
    This function is called whenever a user presses a button.

    - Writes the user's choice for the record at *index* to the CSV file and
      marks that index as annotated.
    - Picks a new random sample that is neither annotated nor in progress.
    - Returns ``(image, prompt, new_index, uuid_text)`` for the UI.
    - If out of images, returns a "No more images" placeholder with
      ``new_index`` set to ``None``.

    When *index* is ``None`` (nothing is currently displayed, e.g. after the
    placeholder was shown) no row is written; only a new sample is fetched.
    """
    if index is not None:
        # Get the current record
        _image, prompt, record_uuid = get_record(index)

        # Write to CSV
        with write_lock:
            with open(RESULT_CSV, "a", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow([session_id, record_uuid, prompt, choice])
            # Add to annotated samples and remove from in-progress
            annotated_samples.add(index)
            in_progress_samples.pop(index, None)

    # Get next random sample
    new_index = get_random_sample(session_id)
    if new_index is None:
        return (None, NO_MORE_IMAGES_MESSAGE, new_index, "")

    # Get the next record
    next_image, next_prompt, next_uuid = get_record(new_index)
    return (next_image, next_prompt, new_index, f"UUID: {next_uuid}")


def make_choice_handler(choice, handler_name):
    """Build a click handler that records *choice* via ``update_session``.

    *handler_name* is assigned to the handler's ``__name__`` so every button
    keeps a distinct, stable function name.
    """

    def handler(session_id, index):
        return update_session(choice, session_id, index)

    handler.__name__ = handler_name
    return handler


# One entry per button: (label, value written to the CSV, handler name,
# button variant or None for the default).
# NOTE(review): the options jump from "Ten" to "More than 11", so there is no
# button for exactly eleven. Label and stored value are kept unchanged here;
# confirm the intended wording with the dataset owners.
CHOICE_BUTTONS = (
    ("Three", "three", "choose_three", None),
    ("Four", "four", "choose_four", None),
    ("Five", "five", "choose_five", None),
    ("Six", "six", "choose_six", None),
    ("Seven", "seven", "choose_seven", None),
    ("Eight", "eight", "choose_eight", None),
    ("Nine", "nine", "choose_nine", None),
    ("Ten", "ten", "choose_ten", None),
    ("More than 11", "more_than_11", "choose_more", None),
    ("Cannot identify", "cannot_identify", "choose_cannot", "stop"),  # Red
)

# Create a Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Human Hands Finger Counting App")

    # State: each user has a unique session ID and current index.
    # Both start as None and are filled in by start_app on page load. A
    # literal uuid here would be evaluated once at import time and shared by
    # every visitor, and 0 is a valid dataset index so it cannot double as a
    # "not started" marker.
    session_id = gr.State(None)
    current_index = gr.State(None)

    image_display = gr.Image(type="pil", label="Image to Review")
    prompt_display = gr.Markdown()
    uuid_display = gr.Markdown()  # Add UUID display

    # Initialize with the first record
    def start_app(session_id, index):
        """Assign a session id, reserve a first sample and return its data.

        Returns ``(session_id, image, prompt, index, uuid_text)``. The index
        is returned so that the first button press is recorded against the
        image that is actually displayed.
        """
        if session_id is None:
            session_id = str(uuid.uuid4())
        if index is None:  # Only get random sample for new sessions
            index = get_random_sample(session_id)
        if index is None:
            return session_id, None, NO_MORE_IMAGES_MESSAGE, None, ""
        img, prompt, uuid_str = get_record(index)
        return session_id, img, prompt, index, f"UUID: {uuid_str}"

    with gr.Row():
        # Buttons for finger count, each paired with its click handler
        choice_buttons = []
        for label, choice, handler_name, variant in CHOICE_BUTTONS:
            button_kwargs = {"variant": variant} if variant else {}
            choice_buttons.append(
                (
                    gr.Button(label, **button_kwargs),
                    make_choice_handler(choice, handler_name),
                )
            )

    # Link button clicks to functions
    for button, handler in choice_buttons:
        button.click(
            fn=handler,
            inputs=[session_id, current_index],
            outputs=[image_display, prompt_display, current_index, uuid_display],
        )

    # Load the first image/prompt on launch
    demo.load(
        fn=start_app,
        inputs=[session_id, current_index],
        outputs=[
            session_id,
            image_display,
            prompt_display,
            current_index,
            uuid_display,
        ],
    )

demo.launch()