File size: 4,711 Bytes
9a089a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import os
import json
from PIL import Image
import tempfile
from datasets import load_dataset

# Hugging Face dataset that supplies the puzzle images.
DATASET_NAME = "Kyunnilee/visual-puzzles"
SPLIT_NAME = "train"

# Local directory the dataset images are written into at start-up.
IMAGE_FOLDER = "images"
os.makedirs(IMAGE_FOLDER, exist_ok=True)

# Fetch the requested split.
hf_dataset = load_dataset(DATASET_NAME, split=SPLIT_NAME)

# Write every example's image to IMAGE_FOLDER as image_<index>.png and keep
# only the bare file names; callers join them with IMAGE_FOLDER when needed.
image_paths = []
for position, record in enumerate(hf_dataset):
    file_name = f"image_{position}.png"
    record["image"].save(os.path.join(IMAGE_FOLDER, file_name))
    image_paths.append(file_name)

ANNOTATION_FILE = "annotations.json"

# Mutable session state shared by the callbacks below:
#   annotations   - list of {"image": <file name>, "answer": <text or "<SKIP>">}
#   current_index - position in image_paths of the image being shown
annotations = []
current_index = 0


# Helper functions
def load_annotations(file_obj):
    """Replace the in-memory annotations with those stored in a JSON file.

    After a successful import, ``current_index`` is moved to the first image
    in ``image_paths`` that has no entry in the imported data, or to 0 when
    every image already has one.

    Args:
        file_obj: Either an object exposing the file path through a ``name``
            attribute, a plain path, or ``None``. ``None`` is treated as
            "nothing to import" (presumably what the upload component sends
            when the file is cleared -- confirm against the Gradio version in
            use).

    Returns:
        The tuple produced by ``update_interface()``.

    Raises:
        ValueError: If the file does not contain a JSON list of objects that
            each have both an ``"image"`` and an ``"answer"`` key.
    """
    global annotations, current_index
    if file_obj is None:
        # Nothing was supplied: keep the current state and just redraw.
        return update_interface()

    file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
    with open(file_path, "r", encoding="utf-8") as f:
        loaded = json.load(f)

    # Validate before touching the globals so a malformed file cannot leave
    # the session with an unusable annotation list.
    well_formed = isinstance(loaded, list) and all(
        isinstance(entry, dict) and "image" in entry and "answer" in entry
        for entry in loaded
    )
    if not well_formed:
        raise ValueError(
            "Annotation file must be a JSON list of objects with 'image' and 'answer' keys"
        )

    done_images = {entry["image"] for entry in loaded}
    annotations = loaded
    # First image without an imported entry; fall back to the start when
    # every image is already covered.
    current_index = next(
        (i for i, name in enumerate(image_paths) if name not in done_images),
        0,
    )
    return update_interface()


def save_annotations():
    """Dump the current annotations to a new temporary ``.json`` file.

    The file is created with ``delete=False`` so it still exists after being
    closed.

    Returns:
        The path of the temporary file that was written.
    """
    export_file = tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json")
    with export_file:
        json.dump(annotations, export_file, indent=2)
    return export_file.name


def update_interface():
    """Build the values shown by the UI for the current position.

    Returns:
        A 4-tuple ``(image, answer_text, status_text, progress_percent)``.
        When ``current_index`` is past the last image the image is ``None``,
        the answer text is empty and the progress is 100. Otherwise the image
        is the path of the current file inside ``IMAGE_FOLDER``, the answer
        text is the stored answer for that image (empty when there is none or
        it is the ``"<SKIP>"`` marker), and the progress is the share of
        ``image_paths`` covered by ``annotations``, truncated to an int.
    """
    total = len(image_paths)

    def count_skipped():
        # Entries whose answer is the skip marker.
        return len([entry for entry in annotations if entry["answer"] == "<SKIP>"])

    if current_index >= total:
        return None, "", f"All {total} images annotated. Skipped: {count_skipped()}", 100

    file_name = image_paths[current_index]
    image_path = os.path.join(IMAGE_FOLDER, file_name)

    # Answer of the first annotation recorded for this image, if any.
    stored_answer = ""
    for entry in annotations:
        if entry["image"] == file_name:
            stored_answer = entry["answer"]
            break

    skipped_count = count_skipped()
    # The skip marker is internal bookkeeping and is never shown in the box.
    shown_answer = "" if stored_answer == "<SKIP>" else stored_answer
    status_text = f"{len(annotations)} / {total} completed. Skipped: {skipped_count}"
    percent = int((len(annotations) / total) * 100)
    return image_path, shown_answer, status_text, percent


def submit_answer(answer):
    """Record an answer for the current image and advance to the next open one.

    A non-blank answer is stored for the current image, overwriting any
    earlier entry for it. A blank answer records a ``"<SKIP>"`` entry, but
    only if the image has no entry yet, so an existing answer is never erased
    by an empty submit. ``current_index`` then moves to the next later image
    that has no real answer (images marked ``"<SKIP>"`` or with an empty
    answer count as open), or to ``len(image_paths)`` when none is left.

    Args:
        answer: Text typed by the user. ``None`` is treated like an empty
            string (NOTE(review): the textbox may deliver ``None`` when its
            value is unset -- confirm for the Gradio version in use).

    Returns:
        The tuple produced by ``update_interface()``.
    """
    global current_index
    text = answer or ""
    total = len(image_paths)

    if current_index < total:
        img_name = image_paths[current_index]
        existing = next((a for a in annotations if a["image"] == img_name), None)
        if text.strip():
            if existing is not None:
                existing["answer"] = text
            else:
                annotations.append({"image": img_name, "answer": text})
        elif existing is None:
            annotations.append({"image": img_name, "answer": "<SKIP>"})

    # Build the set of images with a real answer once, instead of rescanning
    # the whole annotation list for every candidate image.
    answered = {a["image"] for a in annotations if a["answer"] not in ("", "<SKIP>")}
    current_index = next(
        (i for i in range(current_index + 1, total) if image_paths[i] not in answered),
        total,  # nothing left after the current position
    )

    return update_interface()


def go_previous():
    """Step back one image, staying put when already at the first one.

    Returns:
        The tuple produced by ``update_interface()``.
    """
    global current_index
    step_back = 1 if current_index > 0 else 0
    current_index = current_index - step_back
    return update_interface()


def go_next():
    """Step forward one image unless the index is at or beyond the last one.

    Returns:
        The tuple produced by ``update_interface()``.
    """
    global current_index
    last_index = len(image_paths) - 1
    if not current_index >= last_index:
        current_index = current_index + 1
    return update_interface()


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Rebus Puzzle Annotator")

    # Puzzle image next to the answer box.
    with gr.Row():
        image_display = gr.Image(label="Rebus Image", type="filepath")
        textbox = gr.Textbox(label="Your Answer", placeholder="Type the rebus answer here")

    # Submit and navigation controls.
    with gr.Row():
        submit_btn = gr.Button("Submit Answer")
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")

    # Read-only progress feedback.
    with gr.Row():
        status = gr.Textbox(label="Progress")
        progress_bar = gr.Slider(label="Progress", minimum=0, maximum=100, step=1, interactive=False)

    # Export / import of the annotation list.
    with gr.Row():
        download_btn = gr.Button("Export Annotations")
        upload_btn = gr.File(label="Import Annotations", file_types=[".json"])
        output_file = gr.File(label="Download JSON")

    # Every callback that changes the position refreshes the same four
    # components, in the order update_interface() returns its values.
    view_outputs = [image_display, textbox, status, progress_bar]

    submit_btn.click(fn=submit_answer, inputs=textbox, outputs=view_outputs)
    textbox.submit(fn=submit_answer, inputs=textbox, outputs=view_outputs)
    prev_btn.click(fn=go_previous, inputs=None, outputs=view_outputs)
    next_btn.click(fn=go_next, inputs=None, outputs=view_outputs)
    download_btn.click(fn=save_annotations, inputs=None, outputs=output_file)
    upload_btn.change(fn=load_annotations, inputs=upload_btn, outputs=view_outputs)

    # Populate the view when the page is first opened.
    demo.load(fn=update_interface, inputs=None, outputs=view_outputs)


# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()