|
import gradio as gr |
|
import csv |
|
import uuid |
|
import threading |
|
from datasets import load_dataset |
|
from collections import OrderedDict |
|
import random |
|
import time |
|
from huggingface_hub import CommitScheduler, HfApi, snapshot_download |
|
import os |
|
import shutil |
|
from pathlib import Path |
|
import pandas as pd |
|
|
|
# Hub client authenticated with the token read from the HF_TOKEN environment
# variable (a missing variable raises KeyError at import time).
api = HfApi(token=os.environ["HF_TOKEN"])

# Source dataset shown to annotators, and the dataset repo that receives the
# collected annotations.
DATASET_NAME = "taesiri/HumanHandsDataset"
BACKUP_REPO = "taesiri/HumanHandsDatasetFingerCounts"

# Make sure the local data directory exists before anything reads or writes it.
Path("./data").mkdir(exist_ok=True)
|
|
|
|
|
def sync_with_hub():
    """
    Synchronize local data with the hub by downloading the latest dataset.

    Steps:
      1. If ./data exists, copy it to ./data_backup (replacing any previous
         backup).
      2. Download a snapshot of BACKUP_REPO into ./hub_data.
      3. Copy every directory under the snapshot's data/ folder that does not
         exist locally, and merge finger_count_results.csv with the local
         copy (rows are concatenated and exact duplicates dropped).
      4. Remove ./hub_data, even when the download or merge fails, so a
         failed sync does not leave a stale snapshot behind.

    Any exception from the download or the merge propagates to the caller.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    if data_dir.exists():
        # Keep a copy of the current local state before touching it.
        backup_dir = Path("./data_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    try:
        repo_path = snapshot_download(
            repo_id=BACKUP_REPO, repo_type="dataset", local_dir="hub_data"
        )

        hub_data_dir = Path(repo_path) / "data"
        if hub_data_dir.exists():
            os.makedirs(data_dir, exist_ok=True)
            for item in hub_data_dir.glob("*"):
                if item.is_dir():
                    # Only add directories that are missing locally; existing
                    # local directories are never overwritten.
                    dest = data_dir / item.name
                    if not dest.exists():
                        shutil.copytree(item, dest)
                elif item.name == "finger_count_results.csv":
                    hub_csv = pd.read_csv(item) if item.exists() else pd.DataFrame()
                    local_csv_path = data_dir / "finger_count_results.csv"
                    local_csv = (
                        pd.read_csv(local_csv_path)
                        if local_csv_path.exists()
                        else pd.DataFrame()
                    )
                    # Local rows first, then hub rows; identical rows collapse.
                    merged_csv = pd.concat(
                        [local_csv, hub_csv], ignore_index=True
                    ).drop_duplicates()
                    merged_csv.to_csv(local_csv_path, index=False)
    finally:
        # Always discard the temporary snapshot, including on failure.
        if Path("hub_data").exists():
            shutil.rmtree("hub_data")
    print("Finished syncing with hub!")
|
|
|
|
|
|
|
# Background scheduler that uploads the contents of ./data to the "data"
# folder of BACKUP_REPO.
# NOTE(review): `every` is presumably expressed in minutes (huggingface_hub
# CommitScheduler) — confirm against the installed version.
scheduler = CommitScheduler(
    repo_id=BACKUP_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)

# Pull the latest annotations from the hub before the app starts serving.
sync_with_hub()

# CSV file that stores one row per annotation
# (session_id, uuid, prompt, choice).
RESULT_CSV = "./data/finger_count_results.csv"

# Full training split; get_record indexes into it to fetch image/prompt/uuid.
ds = load_dataset(DATASET_NAME, split="train")

# Second load restricted to the "uuid" column, converted to a DataFrame so
# CSV uuids can be mapped back to row indices.
# NOTE(review): assumes the dataset builder accepts `columns=` — confirm.
uuid_df = load_dataset(DATASET_NAME, split="train", columns=["uuid"])
uuid_df = pd.DataFrame(uuid_df)

# Serializes writes to RESULT_CSV.
write_lock = threading.Lock()

# Dataset indices that already have an annotation recorded.
annotated_samples = set()

# Dataset index -> (reservation timestamp, session id), oldest entry first.
in_progress_samples = OrderedDict()
# Seconds after which a reservation is considered expired
# (compared against time.time() in cleanup_in_progress).
IN_PROGRESS_TTL = 300
# Upper bound on simultaneous reservations; get_random_sample evicts the
# oldest entry once this size is reached.
MAX_IN_PROGRESS = 1000
|
|
|
|
|
|
|
def load_annotated_samples():
    """
    Populate ``annotated_samples`` from the annotations stored in RESULT_CSV.

    The first CSV row is treated as the header and skipped. The second column
    of every remaining row is read as a record UUID; for each UUID that occurs
    in ``uuid_df["uuid"]`` the index of its first occurrence is added to
    ``annotated_samples``.

    A missing or empty file leaves ``annotated_samples`` untouched, and rows
    with fewer than two columns (for example blank lines) are ignored.
    """
    try:
        with open(RESULT_CSV, "r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            # Skip the header; the default keeps an empty file from raising
            # StopIteration.
            next(reader, None)
            recorded_uuids = {row[1] for row in reader if len(row) > 1}
    except FileNotFoundError:
        # Nothing has been annotated yet.
        return

    if not recorded_uuids:
        return

    # Single pass over the dataset UUIDs instead of one DataFrame scan per CSV
    # row. Only the first index of a repeated UUID is recorded.
    already_matched = set()
    for row_index, dataset_uuid in zip(
        uuid_df.index.tolist(), uuid_df["uuid"].tolist()
    ):
        if dataset_uuid in recorded_uuids and dataset_uuid not in already_matched:
            already_matched.add(dataset_uuid)
            annotated_samples.add(row_index)
|
|
|
|
|
|
|
# Bootstrap the results file: create it with a header row when it does not
# exist yet, otherwise restore the set of already-annotated samples from it.
with write_lock:
    try:
        new_csv = open(RESULT_CSV, "x", newline="", encoding="utf-8")
    except FileExistsError:
        load_annotated_samples()
    else:
        with new_csv:
            header = ["session_id", "uuid", "prompt", "choice"]
            csv.writer(new_csv).writerow(header)
|
|
|
|
|
def cleanup_in_progress():
    """
    Remove expired in-progress samples.

    Entries of ``in_progress_samples`` are ``(timestamp, session_id)`` tuples
    kept in insertion order. Entries are dropped from the front for as long as
    the oldest one was reserved more than IN_PROGRESS_TTL seconds ago.
    """
    cutoff = time.time() - IN_PROGRESS_TTL
    while in_progress_samples:
        # Peek at the oldest entry without copying the whole mapping.
        oldest_timestamp = next(iter(in_progress_samples.values()))[0]
        if oldest_timestamp >= cutoff:
            break
        in_progress_samples.popitem(last=False)
|
|
|
|
|
def get_random_sample(session_id):
    """
    Reserve and return a random dataset index that is free to annotate.

    Expired reservations are purged first. An index is free when it is neither
    in ``annotated_samples`` nor currently reserved in ``in_progress_samples``.
    The chosen index is reserved for ``session_id`` with the current time;
    when the reservation table is already at MAX_IN_PROGRESS entries, the
    oldest reservation is evicted to make room.

    Returns None when no free index remains.
    """
    cleanup_in_progress()

    taken = annotated_samples.union(in_progress_samples.keys())
    candidates = list(set(range(len(ds))) - taken)
    if not candidates:
        return None

    picked = random.choice(candidates)

    # Keep the reservation table bounded by dropping its oldest entry.
    if len(in_progress_samples) >= MAX_IN_PROGRESS:
        in_progress_samples.popitem(last=False)
    in_progress_samples[picked] = (time.time(), session_id)

    return picked
|
|
|
|
|
def get_record(index):
    """
    Look up dataset row ``index`` and return a 3-tuple of its fields:

    - the "image" field (displayed as a PIL image by the UI)
    - the "prompt" text
    - the row's "uuid"
    """
    row = ds[index]
    return tuple(row[field] for field in ("image", "prompt", "uuid"))
|
|
|
|
|
def update_session(choice, session_id, index):
    """
    Record the user's choice for the current sample and serve the next one.

    Called whenever a user presses an answer button.

    - Appends ``(session_id, uuid, prompt, choice)`` for sample ``index`` to
      the CSV file while holding ``write_lock``.
    - Marks ``index`` as annotated and releases its in-progress reservation.
    - Picks a new random sample for the session.

    Args:
        choice: Label chosen by the user; written verbatim to the CSV.
        session_id: Identifier of the annotating session.
        index: Dataset index of the sample that was on screen, or None when
            the session has already run out of samples.

    Returns:
        ``(image, prompt, new_index, uuid_text)`` for the next sample, or
        ``(None, "No more images to label. Thank you!", None, "")`` when there
        is nothing left to label. Nothing is written when ``index`` is None.
    """
    # After the "done" placeholder has been served the stored index is None;
    # further clicks must not try to look up or record a sample.
    if index is None:
        return (None, "No more images to label. Thank you!", None, "")

    _, prompt, record_uuid = get_record(index)

    with write_lock:
        with open(RESULT_CSV, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow([session_id, record_uuid, prompt, choice])

    annotated_samples.add(index)
    in_progress_samples.pop(index, None)

    new_index = get_random_sample(session_id)
    if new_index is None:
        return (None, "No more images to label. Thank you!", new_index, "")

    next_image, next_prompt, next_uuid = get_record(new_index)

    return (next_image, next_prompt, new_index, f"UUID: {next_uuid}")
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Human Hands Finger Counting App")

    # Per-session state. The argument to gr.State is evaluated once, when the
    # UI is built, so these values are only placeholders: start_app replaces
    # both of them on every page load.
    session_id = gr.State(str(uuid.uuid4()))
    current_index = gr.State(0)

    image_display = gr.Image(type="pil", label="Image to Review")
    prompt_display = gr.Markdown()
    uuid_display = gr.Markdown()

    NO_MORE_IMAGES = "No more images to label. Thank you!"

    def start_app(session_id, index):
        """
        Initialise a freshly loaded page.

        Generates a session id that is unique to this page load, reserves a
        random sample for it, and returns
        ``(session_id, index, image, prompt, uuid_text)``. The index is
        returned so that the first answer is recorded against the sample that
        is actually displayed.
        """
        session_id = str(uuid.uuid4())
        if index == 0:
            index = get_random_sample(session_id)
        if index is None:
            return session_id, None, None, NO_MORE_IMAGES, ""
        img, prompt, uuid_str = get_record(index)
        return session_id, index, img, prompt, f"UUID: {uuid_str}"

    # (button label, value written to the CSV, handler name suffix,
    #  extra gr.Button keyword arguments)
    button_specs = [
        ("Three", "three", "three", {}),
        ("Four", "four", "four", {}),
        ("Five", "five", "five", {}),
        ("Six", "six", "six", {}),
        ("Seven", "seven", "seven", {}),
        ("Eight", "eight", "eight", {}),
        ("Nine", "nine", "nine", {}),
        ("Ten", "ten", "ten", {}),
        ("More than 11", "more_than_11", "more", {}),
        ("Cannot identify", "cannot_identify", "cannot", {"variant": "stop"}),
    ]

    def make_choice_handler(choice, name_suffix):
        """Build the click handler that records ``choice`` for the current sample."""

        def handler(session_id, index):
            # Once the dataset is exhausted the stored index is None; keep
            # showing the placeholder instead of recording anything.
            if index is None:
                return None, NO_MORE_IMAGES, None, ""
            return update_session(choice, session_id, index)

        # Keep the historical function names (choose_three, choose_four, ...).
        handler.__name__ = f"choose_{name_suffix}"
        return handler

    with gr.Row():
        buttons = [
            (gr.Button(label, **button_kwargs), make_choice_handler(choice, suffix))
            for label, choice, suffix, button_kwargs in button_specs
        ]

    # Every button shares the same inputs and outputs; only the recorded
    # choice differs.
    for button, handler in buttons:
        button.click(
            fn=handler,
            inputs=[session_id, current_index],
            outputs=[image_display, prompt_display, current_index, uuid_display],
        )

    demo.load(
        fn=start_app,
        inputs=[session_id, current_index],
        outputs=[
            session_id,
            current_index,
            image_display,
            prompt_display,
            uuid_display,
        ],
    )

demo.launch()
|
|