taesiri commited on
Commit
945e8ce
·
1 Parent(s): 2920ba0
Files changed (1) hide show
  1. app.py +97 -61
app.py CHANGED
@@ -16,11 +16,12 @@ api = HfApi(token=os.environ["HF_TOKEN"])
16
 
17
 
18
  DATASET_NAME = "taesiri/HumanHandsDataset"
19
- BACKUP_REPO= "taesiri/HumanHandsDatasetFingerCounts"
20
 
21
  # Create data directory
22
  os.makedirs("./data", exist_ok=True)
23
 
 
24
  def sync_with_hub():
25
  """
26
  Synchronize local data with the hub by downloading latest dataset
@@ -56,7 +57,9 @@ def sync_with_hub():
56
  if local_csv_path.exists()
57
  else pd.DataFrame()
58
  )
59
- merged_csv = pd.concat([local_csv, hub_csv], ignore_index=True).drop_duplicates()
 
 
60
  merged_csv.to_csv(local_csv_path, index=False)
61
 
62
  # Clean up downloaded repo
@@ -64,6 +67,7 @@ def sync_with_hub():
64
  shutil.rmtree("hub_data")
65
  print("Finished syncing with hub!")
66
 
 
67
  # Set up commit scheduler
68
  scheduler = CommitScheduler(
69
  repo_id=BACKUP_REPO,
@@ -81,6 +85,9 @@ RESULT_CSV = "./data/finger_count_results.csv"
81
 
82
  # Load the dataset
83
  ds = load_dataset(DATASET_NAME, split="train")
 
 
 
84
 
85
  # A thread lock to avoid concurrent writes
86
  write_lock = threading.Lock()
@@ -94,61 +101,68 @@ in_progress_samples = OrderedDict()
94
  IN_PROGRESS_TTL = 300 # 5 minutes in seconds
95
  MAX_IN_PROGRESS = 1000 # Maximum number of in-progress samples to track
96
 
 
97
  # Load previously annotated samples from CSV
98
  def load_annotated_samples():
99
  try:
100
- with open(RESULT_CSV, 'r', newline='', encoding='utf-8') as f:
101
  reader = csv.reader(f)
102
  next(reader) # Skip header
103
  for row in reader:
104
  record_uuid = row[1]
105
- # Find index for this UUID
106
- for idx, item in enumerate(ds):
107
- if item['uuid'] == record_uuid:
108
- annotated_samples.add(idx)
109
- break
110
  except FileExistsError:
111
  pass
112
 
 
113
  # Prepare the CSV file and load annotated samples
114
  with write_lock:
115
  try:
116
- with open(RESULT_CSV, 'x', newline='', encoding='utf-8') as f:
117
  writer = csv.writer(f)
118
  writer.writerow(["session_id", "uuid", "prompt", "choice"])
119
  except FileExistsError:
120
  load_annotated_samples()
121
 
 
122
  def cleanup_in_progress():
123
  """Remove expired in-progress samples"""
124
  current_time = time.time()
125
- while in_progress_samples and list(in_progress_samples.items())[0][1][0] < current_time - IN_PROGRESS_TTL:
 
 
 
126
  in_progress_samples.popitem(last=False)
127
 
 
128
  def get_random_sample(session_id):
129
  """Get a random sample that's neither annotated nor in progress"""
130
  cleanup_in_progress()
131
-
132
  # Get all possible indices
133
  all_indices = set(range(len(ds)))
134
  # Get unavailable indices (annotated + in-progress)
135
  unavailable = annotated_samples.union(in_progress_samples.keys())
136
  # Get available indices
137
  available = list(all_indices - unavailable)
138
-
139
  if not available:
140
  return None
141
-
142
  # Select random index from available ones
143
  index = random.choice(available)
144
-
145
  # Add to in-progress samples
146
  if len(in_progress_samples) >= MAX_IN_PROGRESS:
147
  in_progress_samples.popitem(last=False) # Remove oldest item
148
  in_progress_samples[index] = (time.time(), session_id)
149
-
150
  return index
151
 
 
152
  def get_record(index):
153
  """
154
  Given an index, return:
@@ -159,6 +173,7 @@ def get_record(index):
159
  record = ds[index]
160
  return record["image"], record["prompt"], record["uuid"]
161
 
 
162
  def update_session(choice, session_id, index):
163
  """
164
  This function is called whenever a user presses a button.
@@ -172,7 +187,7 @@ def update_session(choice, session_id, index):
172
 
173
  # Write to CSV
174
  with write_lock:
175
- with open(RESULT_CSV, 'a', newline='', encoding='utf-8') as f:
176
  writer = csv.writer(f)
177
  writer.writerow([session_id, record_uuid, prompt, choice])
178
 
@@ -190,18 +205,19 @@ def update_session(choice, session_id, index):
190
 
191
  return (next_image, next_prompt, new_index, f"UUID: {next_uuid}")
192
 
 
193
  # Create a Gradio interface
194
  with gr.Blocks() as demo:
195
  gr.Markdown("## Human Hands Finger Counting App")
196
-
197
  # State: each user has a unique session ID and current index
198
  session_id = gr.State(str(uuid.uuid4()))
199
  current_index = gr.State(0)
200
-
201
  image_display = gr.Image(type="pil", label="Image to Review")
202
  prompt_display = gr.Markdown()
203
  uuid_display = gr.Markdown() # Add UUID display
204
-
205
  # Initialize with the first record
206
  def start_app(session_id, index):
207
  if index == 0: # Only get random sample for new sessions
@@ -210,7 +226,7 @@ with gr.Blocks() as demo:
210
  return None, "No more images to label. Thank you!", ""
211
  img, prompt, uuid_str = get_record(index)
212
  return img, prompt, f"UUID: {uuid_str}"
213
-
214
  with gr.Row():
215
  # Buttons for finger count
216
  btn_three = gr.Button("Three")
@@ -223,7 +239,7 @@ with gr.Blocks() as demo:
223
  btn_ten = gr.Button("Ten")
224
  btn_more = gr.Button("More than 11")
225
  btn_cannot = gr.Button("Cannot identify", variant="stop") # Red background
226
-
227
  # Define partial functions to specify each choice
228
  def choose_three(session_id, index):
229
  return update_session("three", session_id, index)
@@ -256,51 +272,71 @@ with gr.Blocks() as demo:
256
  return update_session("cannot_identify", session_id, index)
257
 
258
  # Link button clicks to functions
259
- btn_three.click(fn=choose_three,
260
- inputs=[session_id, current_index],
261
- outputs=[image_display, prompt_display, current_index, uuid_display])
262
-
263
- btn_four.click(fn=choose_four,
264
- inputs=[session_id, current_index],
265
- outputs=[image_display, prompt_display, current_index, uuid_display])
266
-
267
- btn_five.click(fn=choose_five,
268
- inputs=[session_id, current_index],
269
- outputs=[image_display, prompt_display, current_index, uuid_display])
270
-
271
- btn_six.click(fn=choose_six,
272
- inputs=[session_id, current_index],
273
- outputs=[image_display, prompt_display, current_index, uuid_display])
274
-
275
- btn_seven.click(fn=choose_seven,
276
- inputs=[session_id, current_index],
277
- outputs=[image_display, prompt_display, current_index, uuid_display])
278
-
279
- btn_eight.click(fn=choose_eight,
280
- inputs=[session_id, current_index],
281
- outputs=[image_display, prompt_display, current_index, uuid_display])
282
-
283
- btn_nine.click(fn=choose_nine,
284
- inputs=[session_id, current_index],
285
- outputs=[image_display, prompt_display, current_index, uuid_display])
286
-
287
- btn_ten.click(fn=choose_ten,
288
- inputs=[session_id, current_index],
289
- outputs=[image_display, prompt_display, current_index, uuid_display])
290
-
291
- btn_more.click(fn=choose_more,
292
- inputs=[session_id, current_index],
293
- outputs=[image_display, prompt_display, current_index, uuid_display])
294
-
295
- btn_cannot.click(fn=choose_cannot,
296
- inputs=[session_id, current_index],
297
- outputs=[image_display, prompt_display, current_index, uuid_display])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
  # Load the first image/prompt on launch
300
  demo.load(
301
  fn=start_app,
302
  inputs=[session_id, current_index],
303
- outputs=[image_display, prompt_display, uuid_display]
304
  )
305
 
306
  demo.launch()
 
16
 
17
 
18
  DATASET_NAME = "taesiri/HumanHandsDataset"
19
+ BACKUP_REPO = "taesiri/HumanHandsDatasetFingerCounts"
20
 
21
  # Create data directory
22
  os.makedirs("./data", exist_ok=True)
23
 
24
+
25
  def sync_with_hub():
26
  """
27
  Synchronize local data with the hub by downloading latest dataset
 
57
  if local_csv_path.exists()
58
  else pd.DataFrame()
59
  )
60
+ merged_csv = pd.concat(
61
+ [local_csv, hub_csv], ignore_index=True
62
+ ).drop_duplicates()
63
  merged_csv.to_csv(local_csv_path, index=False)
64
 
65
  # Clean up downloaded repo
 
67
  shutil.rmtree("hub_data")
68
  print("Finished syncing with hub!")
69
 
70
+
71
  # Set up commit scheduler
72
  scheduler = CommitScheduler(
73
  repo_id=BACKUP_REPO,
 
85
 
86
  # Load the dataset
87
  ds = load_dataset(DATASET_NAME, split="train")
88
+ # Get UUID lookup dataframe for efficient searching
89
+ uuid_df = load_dataset(DATASET_NAME, split="train", columns=["uuid"])
90
+ uuid_df = pd.DataFrame(uuid_df)
91
 
92
  # A thread lock to avoid concurrent writes
93
  write_lock = threading.Lock()
 
101
  IN_PROGRESS_TTL = 300 # 5 minutes in seconds
102
  MAX_IN_PROGRESS = 1000 # Maximum number of in-progress samples to track
103
 
104
+
105
  # Load previously annotated samples from CSV
106
  def load_annotated_samples():
107
  try:
108
+ with open(RESULT_CSV, "r", newline="", encoding="utf-8") as f:
109
  reader = csv.reader(f)
110
  next(reader) # Skip header
111
  for row in reader:
112
  record_uuid = row[1]
113
+ # Find index for this UUID using efficient dataframe lookup
114
+ idx = uuid_df.index[uuid_df["uuid"] == record_uuid].tolist()
115
+ if idx:
116
+ annotated_samples.add(idx[0])
 
117
  except FileExistsError:
118
  pass
119
 
120
+
121
  # Prepare the CSV file and load annotated samples
122
  with write_lock:
123
  try:
124
+ with open(RESULT_CSV, "x", newline="", encoding="utf-8") as f:
125
  writer = csv.writer(f)
126
  writer.writerow(["session_id", "uuid", "prompt", "choice"])
127
  except FileExistsError:
128
  load_annotated_samples()
129
 
130
+
131
  def cleanup_in_progress():
132
  """Remove expired in-progress samples"""
133
  current_time = time.time()
134
+ while (
135
+ in_progress_samples
136
+ and list(in_progress_samples.items())[0][1][0] < current_time - IN_PROGRESS_TTL
137
+ ):
138
  in_progress_samples.popitem(last=False)
139
 
140
+
141
  def get_random_sample(session_id):
142
  """Get a random sample that's neither annotated nor in progress"""
143
  cleanup_in_progress()
144
+
145
  # Get all possible indices
146
  all_indices = set(range(len(ds)))
147
  # Get unavailable indices (annotated + in-progress)
148
  unavailable = annotated_samples.union(in_progress_samples.keys())
149
  # Get available indices
150
  available = list(all_indices - unavailable)
151
+
152
  if not available:
153
  return None
154
+
155
  # Select random index from available ones
156
  index = random.choice(available)
157
+
158
  # Add to in-progress samples
159
  if len(in_progress_samples) >= MAX_IN_PROGRESS:
160
  in_progress_samples.popitem(last=False) # Remove oldest item
161
  in_progress_samples[index] = (time.time(), session_id)
162
+
163
  return index
164
 
165
+
166
  def get_record(index):
167
  """
168
  Given an index, return:
 
173
  record = ds[index]
174
  return record["image"], record["prompt"], record["uuid"]
175
 
176
+
177
  def update_session(choice, session_id, index):
178
  """
179
  This function is called whenever a user presses a button.
 
187
 
188
  # Write to CSV
189
  with write_lock:
190
+ with open(RESULT_CSV, "a", newline="", encoding="utf-8") as f:
191
  writer = csv.writer(f)
192
  writer.writerow([session_id, record_uuid, prompt, choice])
193
 
 
205
 
206
  return (next_image, next_prompt, new_index, f"UUID: {next_uuid}")
207
 
208
+
209
  # Create a Gradio interface
210
  with gr.Blocks() as demo:
211
  gr.Markdown("## Human Hands Finger Counting App")
212
+
213
  # State: each user has a unique session ID and current index
214
  session_id = gr.State(str(uuid.uuid4()))
215
  current_index = gr.State(0)
216
+
217
  image_display = gr.Image(type="pil", label="Image to Review")
218
  prompt_display = gr.Markdown()
219
  uuid_display = gr.Markdown() # Add UUID display
220
+
221
  # Initialize with the first record
222
  def start_app(session_id, index):
223
  if index == 0: # Only get random sample for new sessions
 
226
  return None, "No more images to label. Thank you!", ""
227
  img, prompt, uuid_str = get_record(index)
228
  return img, prompt, f"UUID: {uuid_str}"
229
+
230
  with gr.Row():
231
  # Buttons for finger count
232
  btn_three = gr.Button("Three")
 
239
  btn_ten = gr.Button("Ten")
240
  btn_more = gr.Button("More than 11")
241
  btn_cannot = gr.Button("Cannot identify", variant="stop") # Red background
242
+
243
  # Define partial functions to specify each choice
244
  def choose_three(session_id, index):
245
  return update_session("three", session_id, index)
 
272
  return update_session("cannot_identify", session_id, index)
273
 
274
  # Link button clicks to functions
275
+ btn_three.click(
276
+ fn=choose_three,
277
+ inputs=[session_id, current_index],
278
+ outputs=[image_display, prompt_display, current_index, uuid_display],
279
+ )
280
+
281
+ btn_four.click(
282
+ fn=choose_four,
283
+ inputs=[session_id, current_index],
284
+ outputs=[image_display, prompt_display, current_index, uuid_display],
285
+ )
286
+
287
+ btn_five.click(
288
+ fn=choose_five,
289
+ inputs=[session_id, current_index],
290
+ outputs=[image_display, prompt_display, current_index, uuid_display],
291
+ )
292
+
293
+ btn_six.click(
294
+ fn=choose_six,
295
+ inputs=[session_id, current_index],
296
+ outputs=[image_display, prompt_display, current_index, uuid_display],
297
+ )
298
+
299
+ btn_seven.click(
300
+ fn=choose_seven,
301
+ inputs=[session_id, current_index],
302
+ outputs=[image_display, prompt_display, current_index, uuid_display],
303
+ )
304
+
305
+ btn_eight.click(
306
+ fn=choose_eight,
307
+ inputs=[session_id, current_index],
308
+ outputs=[image_display, prompt_display, current_index, uuid_display],
309
+ )
310
+
311
+ btn_nine.click(
312
+ fn=choose_nine,
313
+ inputs=[session_id, current_index],
314
+ outputs=[image_display, prompt_display, current_index, uuid_display],
315
+ )
316
+
317
+ btn_ten.click(
318
+ fn=choose_ten,
319
+ inputs=[session_id, current_index],
320
+ outputs=[image_display, prompt_display, current_index, uuid_display],
321
+ )
322
+
323
+ btn_more.click(
324
+ fn=choose_more,
325
+ inputs=[session_id, current_index],
326
+ outputs=[image_display, prompt_display, current_index, uuid_display],
327
+ )
328
+
329
+ btn_cannot.click(
330
+ fn=choose_cannot,
331
+ inputs=[session_id, current_index],
332
+ outputs=[image_display, prompt_display, current_index, uuid_display],
333
+ )
334
 
335
  # Load the first image/prompt on launch
336
  demo.load(
337
  fn=start_app,
338
  inputs=[session_id, current_index],
339
+ outputs=[image_display, prompt_display, uuid_display],
340
  )
341
 
342
  demo.launch()