# Fingers / app.py
# taesiri's picture
# backup
# 45036b5
import gradio as gr
import csv
import uuid
import threading
from datasets import load_dataset
from collections import OrderedDict
import random
import time
from huggingface_hub import CommitScheduler, HfApi, snapshot_download
import os
import shutil
from pathlib import Path
import pandas as pd
# Hugging Face Hub client authenticated with the HF_TOKEN environment variable.
# Indexing os.environ raises KeyError at import time when HF_TOKEN is not set.
api = HfApi(token=os.environ["HF_TOKEN"])
# Dataset repo whose "train" split is loaded for labelling (see the
# load_dataset calls further down).
DATASET_NAME = "taesiri/HumanHandsDataset"
# Dataset repo that ./data is merged from (snapshot_download in sync_with_hub)
# and committed to (CommitScheduler).
BACKUP_REPO = "taesiri/HumanHandsDatasetFingerCounts"
# Create data directory
os.makedirs("./data", exist_ok=True)
def _read_results_csv(csv_path):
    """Return the CSV at *csv_path* as a DataFrame, or an empty frame if it is missing or blank."""
    try:
        return pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return pd.DataFrame()


def sync_with_hub():
    """
    Synchronize local data with the hub by downloading latest dataset

    Steps:
      1. Copy ``./data`` (if present) to ``./data_backup``, replacing any
         previous backup.
      2. Download a snapshot of ``BACKUP_REPO`` into ``hub_data``.
      3. Merge the snapshot's ``data`` folder into ``./data``:
         sub-directories are copied only when they do not exist locally, and
         ``finger_count_results.csv`` is concatenated with the local CSV and
         de-duplicated. Other files are ignored.
      4. Remove ``hub_data`` again, even when step 2 or 3 raised.

    Errors from the download or the merge are not swallowed; they propagate
    to the caller after the temporary download directory has been removed.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    download_dir = Path("hub_data")

    if data_dir.exists():
        # Backup existing data
        backup_dir = Path("./data_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    try:
        # Download latest data from hub
        repo_path = snapshot_download(
            repo_id=BACKUP_REPO, repo_type="dataset", local_dir="hub_data"
        )

        # Merge hub data with local data
        hub_data_dir = Path(repo_path) / "data"
        if hub_data_dir.exists():
            os.makedirs(data_dir, exist_ok=True)
            for item in hub_data_dir.glob("*"):
                if item.is_dir():
                    dest = data_dir / item.name
                    # Never overwrite a directory that already exists locally.
                    if not dest.exists():
                        shutil.copytree(item, dest)
                elif item.name == "finger_count_results.csv":
                    local_csv_path = data_dir / item.name
                    frames = [
                        frame
                        for frame in (
                            _read_results_csv(local_csv_path),
                            _read_results_csv(item),
                        )
                        if not frame.empty
                    ]
                    # Nothing to merge: leave the local file untouched rather
                    # than overwriting it with a header-less empty frame.
                    if not frames:
                        continue
                    merged_csv = pd.concat(
                        frames, ignore_index=True
                    ).drop_duplicates()
                    merged_csv.to_csv(local_csv_path, index=False)
    finally:
        # Clean up downloaded repo, also when the download or merge failed.
        if download_dir.exists():
            shutil.rmtree(download_dir)

    print("Finished syncing with hub!")
# Set up commit scheduler
# Commits the contents of ./data to the "data" folder of BACKUP_REPO.
# NOTE(review): per the huggingface_hub docs `every` is in minutes and the
# scheduler runs in a background thread — confirm a one-minute cadence is
# intended.
scheduler = CommitScheduler(
    repo_id=BACKUP_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)
# Sync with hub before starting
# Runs at import time, before the results CSV is created or read below.
sync_with_hub()
# Update RESULT_CSV path to be in data directory
RESULT_CSV = "./data/finger_count_results.csv"
# Load the dataset
ds = load_dataset(DATASET_NAME, split="train")
# Get UUID lookup dataframe for efficient searching
# Only the "uuid" column is requested here. The DataFrame's index values are
# what load_annotated_samples stores in annotated_samples.
uuid_df = load_dataset(DATASET_NAME, split="train", columns=["uuid"])
uuid_df = pd.DataFrame(uuid_df)
# A thread lock to avoid concurrent writes
write_lock = threading.Lock()
# Set to store annotated sample indices
annotated_samples = set()
# OrderedDict to act as a TTL cache for in-progress samples
# Format: {index: (timestamp, session_id)}
in_progress_samples = OrderedDict()
IN_PROGRESS_TTL = 300  # 5 minutes in seconds
MAX_IN_PROGRESS = 1000  # Maximum number of in-progress samples to track
# Load previously annotated samples from CSV
def load_annotated_samples():
    """Add the dataset index of every UUID recorded in RESULT_CSV to ``annotated_samples``.

    The first CSV row is treated as the header and skipped. Each following
    row's second column is looked up in ``uuid_df["uuid"]``; when the UUID
    occurs more than once in the dataframe, the first matching index is used.
    Rows with fewer than two columns (for example blank lines) and UUIDs that
    are not in the dataframe are ignored.

    A missing or empty results file is not an error: the function returns
    without modifying ``annotated_samples``.
    """
    try:
        with open(RESULT_CSV, "r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            # Default of None keeps an empty file from raising StopIteration.
            next(reader, None)  # Skip header

            # Build the uuid -> index map once instead of scanning the whole
            # dataframe for every CSV row. setdefault keeps the first match.
            uuid_to_index = {}
            for row_index, row_uuid in zip(
                uuid_df.index.tolist(), uuid_df["uuid"].tolist()
            ):
                uuid_to_index.setdefault(row_uuid, row_index)

            for row in reader:
                if len(row) < 2:
                    continue
                idx = uuid_to_index.get(row[1])
                if idx is not None:
                    annotated_samples.add(idx)
    except FileNotFoundError:
        # Opening for reading fails with FileNotFoundError (not
        # FileExistsError) when there are no results yet.
        pass
# Prepare the CSV file and load annotated samples
with write_lock:
    try:
        # Mode "x" creates the file and fails if it is already present.
        new_results_file = open(RESULT_CSV, "x", newline="", encoding="utf-8")
    except FileExistsError:
        # Results already exist: recover the annotated indices from them.
        load_annotated_samples()
    else:
        # Brand-new file: it only needs the header row.
        with new_results_file:
            header = ["session_id", "uuid", "prompt", "choice"]
            csv.writer(new_results_file).writerow(header)
def cleanup_in_progress():
    """Remove expired in-progress samples

    Entries are popped from the front of ``in_progress_samples`` while the
    front entry's timestamp is more than ``IN_PROGRESS_TTL`` seconds old.
    The loop stops at the first entry that is still fresh, so it relies on
    the dict being ordered oldest-first.
    """
    cutoff = time.time() - IN_PROGRESS_TTL
    while in_progress_samples:
        # Peek at the oldest entry without copying the whole dict.
        oldest_timestamp = next(iter(in_progress_samples.values()))[0]
        if oldest_timestamp >= cutoff:
            break
        in_progress_samples.popitem(last=False)
def get_random_sample(session_id):
    """Reserve and return a random dataset index that is still open for labelling.

    Expired reservations are purged first. An index is open when it is
    neither in ``annotated_samples`` nor currently reserved in
    ``in_progress_samples``. The chosen index is reserved for ``session_id``
    with the current time; when the reservation table already holds
    ``MAX_IN_PROGRESS`` entries, the oldest one is dropped to make room.

    Returns ``None`` when no index is open.
    """
    cleanup_in_progress()

    taken = annotated_samples.union(in_progress_samples.keys())
    candidates = list(set(range(len(ds))) - taken)
    if not candidates:
        return None

    picked = random.choice(candidates)

    # At capacity: evict the oldest reservation before adding the new one.
    if len(in_progress_samples) >= MAX_IN_PROGRESS:
        in_progress_samples.popitem(last=False)
    in_progress_samples[picked] = (time.time(), session_id)
    return picked
def get_record(index):
    """Look up dataset row ``index`` and return ``(image, prompt, uuid)``.

    The three values are the row's ``"image"``, ``"prompt"`` and ``"uuid"``
    fields, in that order.
    """
    row = ds[index]
    return tuple(row[field] for field in ("image", "prompt", "uuid"))
def update_session(choice, session_id, index):
    """Record the user's choice for the current sample and move on to the next one.

    This function is called whenever a user presses a button.

    - Appends ``[session_id, uuid, prompt, choice]`` for sample ``index`` to
      RESULT_CSV while holding ``write_lock``.
    - Marks ``index`` as annotated and releases its in-progress reservation.
    - Picks the next sample at random via ``get_random_sample``.

    Args:
        choice: Value written to the CSV "choice" column.
        session_id: Identifier of the labelling session.
        index: Dataset index of the sample being answered, or ``None`` when
            the session has no current sample.

    Returns:
        ``(image, prompt, new_index, uuid_text)`` for the next sample. When
        there is nothing left to label, or ``index`` is ``None``, returns
        ``(None, "No more images to label. Thank you!", None, "")``.
    """
    exhausted = (None, "No more images to label. Thank you!", None, "")

    # After the last sample the stored index is None; a further click must
    # not try to look up or record a non-existent sample.
    if index is None:
        return exhausted

    # Get the current record
    _image, prompt, record_uuid = get_record(index)

    # Write to CSV
    with write_lock:
        with open(RESULT_CSV, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow([session_id, record_uuid, prompt, choice])

    # Add to annotated samples and remove from in-progress
    annotated_samples.add(index)
    in_progress_samples.pop(index, None)

    # Get next random sample
    new_index = get_random_sample(session_id)
    if new_index is None:
        return exhausted

    # Get the next record
    next_image, next_prompt, next_uuid = get_record(new_index)
    return (next_image, next_prompt, new_index, f"UUID: {next_uuid}")
# Create a Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Human Hands Finger Counting App")

    # State: each user has a unique session ID and current index.
    # Both start as None and are filled in by start_app on page load. A
    # value computed here would be evaluated once at build time and shared
    # by every visitor.
    session_id = gr.State(None)
    current_index = gr.State(None)

    image_display = gr.Image(type="pil", label="Image to Review")
    prompt_display = gr.Markdown()
    uuid_display = gr.Markdown()  # Add UUID display

    # Initialize with the first record
    def start_app():
        """Start a session: create its ID, reserve a random sample and show it.

        Returns values for (session_id, current_index, image_display,
        prompt_display, uuid_display). The reserved index is written back to
        current_index so the first button click is recorded against the
        image that is actually displayed.
        """
        new_session_id = str(uuid.uuid4())
        index = get_random_sample(new_session_id)
        if index is None:
            return (
                new_session_id,
                None,
                None,
                "No more images to label. Thank you!",
                "",
            )
        img, prompt, uuid_str = get_record(index)
        return new_session_id, index, img, prompt, f"UUID: {uuid_str}"

    # (button label, value stored in the CSV "choice" column, API endpoint name)
    finger_choices = [
        ("Three", "three", "choose_three"),
        ("Four", "four", "choose_four"),
        ("Five", "five", "choose_five"),
        ("Six", "six", "choose_six"),
        ("Seven", "seven", "choose_seven"),
        ("Eight", "eight", "choose_eight"),
        ("Nine", "nine", "choose_nine"),
        ("Ten", "ten", "choose_ten"),
        ("More than 11", "more_than_11", "choose_more"),
    ]

    with gr.Row():
        # Buttons for finger count
        choice_buttons = [
            (gr.Button(label), value, endpoint)
            for label, value, endpoint in finger_choices
        ]
        choice_buttons.append(
            (
                gr.Button("Cannot identify", variant="stop"),  # Red background
                "cannot_identify",
                "choose_cannot",
            )
        )

    def make_choice_handler(choice):
        """Return a click handler that records ``choice`` for the current sample."""

        def handle_click(active_session_id, index):
            # No current sample (dataset exhausted, or the page has not
            # finished loading): there is nothing to record.
            if index is None:
                return None, "No more images to label. Thank you!", None, ""
            return update_session(choice, active_session_id, index)

        return handle_click

    # Link button clicks to functions
    for button, choice_value, endpoint in choice_buttons:
        button.click(
            fn=make_choice_handler(choice_value),
            inputs=[session_id, current_index],
            outputs=[image_display, prompt_display, current_index, uuid_display],
            api_name=endpoint,
        )

    # Load the first image/prompt on launch
    demo.load(
        fn=start_app,
        inputs=None,
        outputs=[
            session_id,
            current_index,
            image_display,
            prompt_display,
            uuid_display,
        ],
    )

demo.launch()