File size: 4,711 Bytes
9a089a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import gradio as gr
import os
import json
from PIL import Image
import tempfile
from datasets import load_dataset
DATASET_NAME = "Kyunnilee/visual-puzzles"
SPLIT_NAME = "train"
IMAGE_FOLDER = "images"
os.makedirs(IMAGE_FOLDER, exist_ok=True)

# Fetch the puzzle split and write every image to IMAGE_FOLDER as a PNG so the
# UI can show it by file path.
# NOTE(review): assumes the "image" column decodes to PIL images (anything
# exposing .save(path)) — confirm against the dataset features.
hf_dataset = load_dataset(DATASET_NAME, split=SPLIT_NAME)
image_paths = []  # PNG file names relative to IMAGE_FOLDER, in dataset order
for idx, record in enumerate(hf_dataset):
    file_name = f"image_{idx}.png"
    record["image"].save(os.path.join(IMAGE_FOLDER, file_name))
    image_paths.append(file_name)

ANNOTATION_FILE = "annotations.json"  # NOTE(review): not referenced by the visible code

# Mutable session state read and rewritten by the callback functions.
annotations = []  # entries shaped {"image": <file name>, "answer": <text or "<SKIP>">}
current_index = 0  # index into image_paths of the image currently displayed
# Helper functions
def load_annotations(file_obj):
    """Replace the in-memory annotations with the contents of a JSON file.

    Args:
        file_obj: Value delivered by the ``gr.File`` upload widget. Depending
            on the Gradio version this is a path string or a tempfile-like
            object exposing ``.name``. ``None`` (the widget was cleared)
            leaves the current state untouched.

    The file must hold a list of ``{"image": ..., "answer": ...}`` entries.
    After loading, the cursor jumps to the first dataset image that has no
    entry at all (skipped images count as handled). It falls back to index 0
    when every image already has an entry.

    Returns:
        The UI tuple produced by ``update_interface()``.
    """
    global annotations, current_index
    # The change event also fires when the file is removed, passing None;
    # opening None would raise, so just refresh the current view.
    if file_obj is None:
        return update_interface()
    file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
    # Explicit UTF-8 so non-ASCII answers load the same on every platform.
    with open(file_path, "r", encoding="utf-8") as f:
        annotations = json.load(f)
    done_images = {a["image"] for a in annotations}
    # Single pass instead of building a list and re-scanning with .index().
    current_index = next(
        (i for i, img in enumerate(image_paths) if img not in done_images),
        0,
    )
    return update_interface()
def save_annotations():
    """Dump the current annotations to a new temporary ``.json`` file.

    The file is created with ``delete=False`` so it survives being closed and
    can be handed to the download widget.

    Returns:
        Filesystem path of the written JSON file.
    """
    handle = tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json")
    with handle:
        json.dump(annotations, handle, indent=2)
    return handle.name
def update_interface():
    """Compute the widget values for the image under the cursor.

    Progress counts each dataset image at most once. Only the first entry per
    image is used (matching the lookup in ``submit_answer``), and entries for
    images outside ``image_paths`` are ignored. This keeps imported files with
    duplicates or foreign entries from pushing progress past 100%.

    Returns:
        A 4-tuple ``(image_path, answer_text, status_text, percent)``.
        - ``image_path`` is ``None`` once the cursor is past the last image.
        - ``answer_text`` is the stored answer, or ``""`` when there is none
          or it is a skip.
        - ``percent`` is an int in [0, 100].
    """
    total = len(image_paths)
    known = set(image_paths)
    answers = {}
    for entry in annotations:
        if entry["image"] in known:
            answers.setdefault(entry["image"], entry["answer"])
    completed = len(answers)
    skipped_count = sum(1 for ans in answers.values() if ans == "<SKIP>")

    if current_index >= total:
        if completed >= total:
            return None, "", f"All {total} images annotated. Skipped: {skipped_count}", 100
        # submit_answer only searches forward, so the end of the list can be
        # reached while earlier images are still pending: report the real count.
        return (
            None,
            "",
            f"{completed} / {total} completed. Skipped: {skipped_count}",
            int((completed / total) * 100),
        )

    img_name = image_paths[current_index]
    existing = answers.get(img_name, "")
    return (
        os.path.join(IMAGE_FOLDER, img_name),
        existing if existing != "<SKIP>" else "",
        f"{completed} / {total} completed. Skipped: {skipped_count}",
        int((completed / total) * 100),
    )
def submit_answer(answer):
    """Record the answer for the current image and advance the cursor.

    Args:
        answer: Text from the answer box. Surrounding whitespace is stripped
            before storing. A blank (or ``None``) answer records ``"<SKIP>"``,
            but only when the image has no entry yet, so an existing answer is
            never overwritten by a skip.

    After recording, the cursor moves forward to the next image without a
    real (non-skip) answer. If none remain after the current position, it
    moves past the last image.

    Returns:
        The UI tuple produced by ``update_interface()``.
    """
    global current_index
    total = len(image_paths)
    if current_index < total:
        img_name = image_paths[current_index]
        existing = next((a for a in annotations if a["image"] == img_name), None)
        text = (answer or "").strip()
        if text:
            if existing is not None:
                existing["answer"] = text
            else:
                annotations.append({"image": img_name, "answer": text})
        elif existing is None:
            annotations.append({"image": img_name, "answer": "<SKIP>"})

    # Build the answered set once instead of rescanning every annotation for
    # each candidate image.
    answered = {a["image"] for a in annotations if a["answer"] not in ("", "<SKIP>")}
    current_index = next(
        (i for i in range(current_index + 1, total) if image_paths[i] not in answered),
        total,
    )
    return update_interface()
def go_previous():
    """Move the cursor back by one image and refresh the view.

    Does nothing to the cursor when it is already on the first image.

    Returns:
        The UI tuple produced by ``update_interface()``.
    """
    global current_index
    at_first = current_index <= 0
    if not at_first:
        current_index = current_index - 1
    return update_interface()
def go_next():
    """Move the cursor forward by one image and refresh the view.

    The cursor never advances beyond the last image (or beyond its current
    position if it is already past the end).

    Returns:
        The UI tuple produced by ``update_interface()``.
    """
    global current_index
    last_index = len(image_paths) - 1
    if current_index >= last_index:
        return update_interface()
    current_index += 1
    return update_interface()
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Rebus Puzzle Annotator")

    # Puzzle image and answer entry.
    with gr.Row():
        image_display = gr.Image(label="Rebus Image", type="filepath")
        textbox = gr.Textbox(label="Your Answer", placeholder="Type the rebus answer here")

    # Annotation / navigation controls.
    with gr.Row():
        submit_btn = gr.Button("Submit Answer")
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")

    # Progress read-outs (text summary and a non-interactive 0-100 slider).
    with gr.Row():
        status = gr.Textbox(label="Progress")
        progress_bar = gr.Slider(label="Progress", minimum=0, maximum=100, step=1, interactive=False)

    # Export / import of the annotation JSON.
    with gr.Row():
        download_btn = gr.Button("Export Annotations")
        upload_btn = gr.File(label="Import Annotations", file_types=[".json"])
        output_file = gr.File(label="Download JSON")

    # Every view-changing callback returns values for these four widgets.
    view_outputs = [image_display, textbox, status, progress_bar]

    # Pressing Enter in the answer box behaves exactly like the Submit button.
    for trigger in (submit_btn.click, textbox.submit):
        trigger(submit_answer, textbox, view_outputs)
    prev_btn.click(go_previous, None, view_outputs)
    next_btn.click(go_next, None, view_outputs)
    download_btn.click(save_annotations, None, output_file)
    upload_btn.change(load_annotations, upload_btn, view_outputs)
    demo.load(update_interface, None, view_outputs)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|