derekl35 HF Staff committed on
Commit
e9b7b43
·
verified ·
1 Parent(s): 4fa756e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -241
app.py CHANGED
@@ -35,24 +35,6 @@ def _save_agg_stats(stats: dict) -> None:
35
  with open(AGG_FILE, "w") as f:
36
  json.dump(stats, f, indent=2)
37
 
38
- USER_STATS_FILE = Path(__file__).parent / "user_stats.json"
39
- USER_STATS_LOCK_FILE = USER_STATS_FILE.with_suffix(".lock")
40
-
41
- def _load_user_stats() -> dict:
42
- if USER_STATS_FILE.exists():
43
- with open(USER_STATS_FILE, "r") as f:
44
- try:
45
- return json.load(f)
46
- except json.JSONDecodeError:
47
- print(f"Warning: {USER_STATS_FILE} is corrupted. Starting with empty user stats.")
48
- return {}
49
- return {}
50
-
51
- def _save_user_stats(stats: dict) -> None:
52
- with InterProcessLock(str(USER_STATS_LOCK_FILE)):
53
- with open(USER_STATS_FILE, "w") as f:
54
- json.dump(stats, f, indent=2)
55
-
56
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
  print(f"Using device: {DEVICE}")
58
 
@@ -62,7 +44,7 @@ DEFAULT_GUIDANCE_SCALE = 3.5
62
  DEFAULT_NUM_INFERENCE_STEPS = 15
63
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
64
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
65
- HF_DATASET_REPO_ID = "derekl35/flux-quant-challenge-submissions"
66
 
67
  CACHED_PIPES = {}
68
  def load_bf16_pipeline():
@@ -99,7 +81,6 @@ def load_bnb_8bit_pipeline():
99
  torch_dtype=torch.bfloat16
100
  )
101
  pipe.to(DEVICE)
102
- # pipe.enable_model_cpu_offload()
103
  end_time = time.time()
104
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
105
  print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
@@ -121,7 +102,6 @@ def load_bnb_4bit_pipeline():
121
  torch_dtype=torch.bfloat16
122
  )
123
  pipe.to(DEVICE)
124
- # pipe.enable_model_cpu_offload()
125
  end_time = time.time()
126
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
127
  print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
@@ -134,10 +114,10 @@ def load_bnb_4bit_pipeline():
134
  @spaces.GPU(duration=240)
135
  def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
136
  if not prompt:
137
- return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
138
 
139
  if not quantization_choice:
140
- return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
141
 
142
  if quantization_choice == "8-bit bnb":
143
  quantized_load_func = load_bnb_8bit_pipeline
@@ -146,7 +126,7 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
146
  quantized_load_func = load_bnb_4bit_pipeline
147
  quantized_label = "Quantized (4-bit bnb)"
148
  else:
149
- return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
150
 
151
  model_configs = [
152
  ("Original", load_bf16_pipeline),
@@ -188,11 +168,11 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
188
 
189
  except Exception as e:
190
  print(f"Error during {label} model processing: {e}")
191
- return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
192
 
193
 
194
  if len(results) != len(model_configs):
195
- return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
196
 
197
  shuffled_results = results.copy()
198
  random.shuffle(shuffled_results)
@@ -263,13 +243,6 @@ def _accuracy_string(correct: int, attempts: int) -> tuple[str, float]:
263
  return f"{pct:.1f}%", pct
264
  return "N/A", -1.0
265
 
266
- def _add_medals(user_rows):
267
- MEDALS = {0: "🥇 ", 1: "🥈 ", 2: "🥉 "}
268
- return [
269
- [MEDALS.get(i, "") + row[0], *row[1:]]
270
- for i, row in enumerate(user_rows)
271
- ]
272
-
273
  def update_leaderboards_data():
274
  agg = _load_agg_stats()
275
  quant_rows = []
@@ -282,50 +255,12 @@ def update_leaderboards_data():
282
  acc_str
283
  ])
284
  quant_rows.sort(key=lambda r: r[1]/r[2] if r[2] != 0 else 1e9)
285
-
286
- user_stats_all = _load_user_stats()
287
-
288
- overall_user_rows = []
289
- for user, per_method_stats_dict in user_stats_all.items():
290
- user_total_correct = 0
291
- user_total_attempts = 0
292
- for method_stats in per_method_stats_dict.values():
293
- user_total_correct += method_stats.get("correct", 0)
294
- user_total_attempts += method_stats.get("attempts", 0)
295
-
296
- if user_total_attempts >= 1:
297
- acc_str, _ = _accuracy_string(user_total_correct, user_total_attempts)
298
- overall_user_rows.append([user, user_total_correct, user_total_attempts, acc_str])
299
-
300
- overall_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
301
- overall_user_rows_medaled = _add_medals(overall_user_rows)
302
-
303
- user_leaderboards_per_method = {}
304
- quant_method_names = list(agg.keys())
305
-
306
- for method_name in quant_method_names:
307
- method_specific_user_rows = []
308
- for user, per_user_method_stats_dict in user_stats_all.items():
309
- if method_name in per_user_method_stats_dict:
310
- st = per_user_method_stats_dict[method_name]
311
- if st.get("attempts", 0) >= 1: # Only include users who have attempted this method
312
- acc_str, _ = _accuracy_string(st["correct"], st["attempts"])
313
- method_specific_user_rows.append([user, st["correct"], st["attempts"], acc_str])
314
-
315
- method_specific_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
316
- method_specific_user_rows_medaled = _add_medals(method_specific_user_rows)
317
- user_leaderboards_per_method[method_name] = method_specific_user_rows_medaled
318
-
319
- return quant_rows, overall_user_rows_medaled, user_leaderboards_per_method
320
 
321
  quant_df = gr.DataFrame(
322
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
323
  interactive=False, col_count=(4, "fixed")
324
  )
325
- user_df = gr.DataFrame(
326
- headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
327
- interactive=False, col_count=(4, "fixed")
328
- )
329
 
330
  with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
331
  gr.Markdown("# FLUX Model Quantization Challenge")
@@ -372,26 +307,16 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
372
 
373
  with gr.Row():
374
  session_score_box = gr.Textbox(label="Your accuracy this session", interactive=False)
375
-
376
- with gr.Row(equal_height=False):
377
- username_input = gr.Textbox(
378
- label="Enter Your Name for Leaderboard",
379
- placeholder="YourName",
380
- visible=False,
381
- interactive=True,
382
- scale=2
383
- )
384
- add_score_button = gr.Button(
385
- "Add My Score to Leaderboard",
386
- visible=False,
387
- variant="secondary",
388
- scale=1
389
- )
390
- add_score_feedback = gr.Textbox(
391
- label="Leaderboard Update",
392
- visible=False,
393
- interactive=False,
394
- lines=1
395
  )
396
 
397
  correct_mapping_state = gr.State({})
@@ -400,29 +325,26 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
400
  "4-bit bnb": {"attempts": 0, "correct": 0}}
401
  )
402
  is_example_state = gr.State(False)
403
- has_added_score_state = gr.State(False)
404
  prompt_state = gr.State("")
405
  seed_state = gr.State(None)
406
  results_state = gr.State([])
407
 
408
  def _load_example_and_update_dfs(sel_summary):
409
- # Find the index of the selected example by its summary
410
  idx = next((i for i, ex in enumerate(EXAMPLES) if ex["summary"] == sel_summary), -1)
411
  if idx == -1:
412
- # Fallback or error handling if summary not found
413
  print(f"Error: Example with summary '{sel_summary}' not found.")
414
- return (gr.update(), gr.update(), gr.update(), False, gr.update(), gr.update(), "", None, [])
415
 
416
  ex = EXAMPLES[idx]
417
  gallery_items, mapping, prompt = load_example(idx)
418
- quant_data, overall_user_data, _ = update_leaderboards_data()
419
- return gallery_items, mapping, prompt, True, quant_data, overall_user_data, "", None, []
420
 
421
  ex_selector.change(
422
  fn=_load_example_and_update_dfs,
423
  inputs=ex_selector,
424
- outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df, user_df,
425
- prompt_state, seed_state, results_state],
426
  ).then(
427
  lambda: (gr.update(interactive=True), gr.update(interactive=True)),
428
  outputs=[image1_btn, image2_btn],
@@ -432,50 +354,39 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
432
  fn=generate_images,
433
  inputs=[prompt_input, quantization_choice_radio],
434
  outputs=[output_gallery, correct_mapping_state, prompt_state, seed_state, results_state,
435
- feedback_box] #, quantization_choice_radio, generate_button, prompt_input]
436
  ).then(
437
- lambda: (False, # for is_example_state
438
- False, # for has_added_score_state
439
- gr.update(visible=False, value="", interactive=True), # username_input reset
440
- gr.update(visible=False), # add_score_button reset
441
- gr.update(visible=False, value="")), # add_score_feedback reset
442
- outputs=[is_example_state,
443
- has_added_score_state,
444
- username_input,
445
- add_score_button,
446
- add_score_feedback]
447
  ).then(
448
  lambda: (gr.update(interactive=True),
449
- gr.update(interactive=True),
450
- ""),
451
  outputs=[image1_btn, image2_btn, feedback_box],
452
  )
453
 
454
- def choose(choice_string, mapping, session_stats, is_example, has_added_score_curr,
455
- prompt, seed, results, username):
456
  feedback = check_guess(choice_string, mapping)
457
 
458
  if not mapping:
459
- return feedback, gr.update(), gr.update(), "", session_stats, [], [], gr.update(), gr.update(), gr.update()
460
 
461
  quant_label_from_mapping = next((label for label in mapping.values() if "Quantized" in label), None)
462
  if not quant_label_from_mapping:
463
  print("Error: Could not determine quantization label from mapping:", mapping)
464
  return ("Internal Error: Could not process results.", gr.update(interactive=False), gr.update(interactive=False),
465
- "", session_stats, [], [], gr.update(), gr.update(), gr.update())
466
 
467
  quant_key = "8-bit bnb" if "8-bit bnb" in quant_label_from_mapping else "4-bit bnb"
468
-
469
  got_it_right = "Correct!" in feedback
470
-
471
  sess = session_stats.copy()
472
- should_log_and_update_stats = not is_example and not has_added_score_curr
473
 
474
- if should_log_and_update_stats:
475
  sess[quant_key]["attempts"] += 1
476
  if got_it_right:
477
  sess[quant_key]["correct"] += 1
478
- session_stats = sess
479
 
480
  AGG_STATS = _load_agg_stats()
481
  AGG_STATS[quant_key]["attempts"] += 1
@@ -487,6 +398,8 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
487
  print("Warning: HF_TOKEN not set. Skipping dataset logging.")
488
  elif not results:
489
  print("Warning: Results state is empty. Skipping dataset logging.")
 
 
490
  else:
491
  print(f"Logging guess to HF Dataset: {HF_DATASET_REPO_ID}")
492
  original_image = None
@@ -525,32 +438,22 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
525
  "quantized_image_displayed_position": [f"Image {quantized_image_pos + 1}"],
526
  "user_guess_displayed_position": [choice_string],
527
  "correct_guess": [got_it_right],
528
- "username": [username.strip() if username else None],
529
  }
530
-
531
  try:
532
- # Attempt to load existing dataset
533
  existing_ds = load_dataset(
534
  HF_DATASET_REPO_ID,
535
  split="train",
536
  token=HF_TOKEN,
537
  features=expected_features,
538
- # verification_mode="no_checks" # Consider removing or using default
539
- # download_mode="force_redownload" # For debugging cache issues
540
  )
541
- # Create a new dataset from the new item, casting to the expected features
542
  new_row_ds = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
543
- # Concatenate
544
  combined_ds = concatenate_datasets([existing_ds, new_row_ds])
545
- # Push the combined dataset
546
  combined_ds.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
547
  print(f"Successfully appended guess to {HF_DATASET_REPO_ID} (train split)")
548
-
549
  except Exception as e:
550
  print(f"Could not load or append to existing dataset/split. Creating 'train' split with the new item. Error: {e}")
551
- # Create dataset from only the new item, with explicit features
552
  ds_new = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
553
- # Push this new dataset as the 'train' split
554
  ds_new.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
555
  print(f"Successfully created and logged new 'train' split to {HF_DATASET_REPO_ID}")
556
  else:
@@ -564,136 +467,45 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
564
  session_msg = ", ".join(
565
  f"{k}: {_fmt(v)}" for k, v in sess.items()
566
  )
567
- current_agg_stats = _load_agg_stats()
568
-
569
- username_input_update = gr.update(visible=False, interactive=True)
570
- add_score_button_update = gr.update(visible=False)
571
- current_feedback_text = add_score_feedback.value if hasattr(add_score_feedback, 'value') and add_score_feedback.value else ""
572
- add_score_feedback_update = gr.update(visible=has_added_score_curr, value=current_feedback_text)
573
-
574
- session_total_attempts = sum(stats["attempts"] for stats in sess.values())
575
-
576
- if not is_example and not has_added_score_curr:
577
- if session_total_attempts >= 1 :
578
- username_input_update = gr.update(visible=True, interactive=True)
579
- add_score_button_update = gr.update(visible=True, interactive=True)
580
- add_score_feedback_update = gr.update(visible=False, value="")
581
- else:
582
- username_input_update = gr.update(visible=False, value=username_input.value if hasattr(username_input, 'value') else "")
583
- add_score_button_update = gr.update(visible=False)
584
- add_score_feedback_update = gr.update(visible=False, value="")
585
- elif has_added_score_curr:
586
- username_input_update = gr.update(visible=True, interactive=False, value=username_input.value if hasattr(username_input, 'value') else "")
587
- add_score_button_update = gr.update(visible=True, interactive=False)
588
- add_score_feedback_update = gr.update(visible=True)
589
-
590
- quant_data, overall_user_data, _ = update_leaderboards_data()
591
  return (feedback,
592
  gr.update(interactive=False),
593
  gr.update(interactive=False),
594
  session_msg,
595
- session_stats,
596
- quant_data,
597
- overall_user_data,
598
- username_input_update,
599
- add_score_button_update,
600
- add_score_feedback_update)
601
 
602
  image1_btn.click(
603
- fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 1", mapping, sess, is_ex, has_added, p, s, r, uname),
604
- inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
605
- prompt_state, seed_state, results_state, username_input],
606
  outputs=[feedback_box, image1_btn, image2_btn,
607
- session_score_box, session_stats_state,
608
- quant_df, user_df,
609
- username_input, add_score_button, add_score_feedback],
610
  )
611
  image2_btn.click(
612
- fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 2", mapping, sess, is_ex, has_added, p, s, r, uname),
613
- inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
614
- prompt_state, seed_state, results_state, username_input],
615
  outputs=[feedback_box, image1_btn, image2_btn,
616
- session_score_box, session_stats_state,
617
- quant_df, user_df,
618
- username_input, add_score_button, add_score_feedback],
619
  )
620
 
621
- def handle_add_score_to_leaderboard(username_str, current_session_stats_dict):
622
- if not username_str or not username_str.strip():
623
- return ("Username is required.",
624
- gr.update(interactive=True),
625
- gr.update(interactive=True),
626
- False,
627
- None, None)
628
-
629
- user_stats = _load_user_stats()
630
- user_key = username_str.strip()
631
-
632
- session_total_session_attempts = sum(stats["attempts"] for stats in current_session_stats_dict.values())
633
- if session_total_session_attempts == 0:
634
- return ("No attempts made in this session to add to leaderboard.",
635
- gr.update(interactive=True),
636
- gr.update(interactive=True),
637
- False, None, None)
638
-
639
- if user_key not in user_stats:
640
- user_stats[user_key] = {}
641
-
642
- for method, stats in current_session_stats_dict.items():
643
- session_method_correct = stats["correct"]
644
- session_method_attempts = stats["attempts"]
645
-
646
- if session_method_attempts == 0:
647
- continue
648
-
649
- if method not in user_stats[user_key]:
650
- user_stats[user_key][method] = {"correct": 0, "attempts": 0}
651
-
652
- user_stats[user_key][method]["correct"] += session_method_correct
653
- user_stats[user_key][method]["attempts"] += session_method_attempts
654
-
655
- _save_user_stats(user_stats)
656
-
657
- new_quant_data, new_overall_user_data, _ = update_leaderboards_data()
658
- feedback_msg = f"Score for '{user_key}' submitted to leaderboard!"
659
- return (feedback_msg,
660
- gr.update(interactive=False),
661
- gr.update(interactive=False),
662
- True,
663
- new_quant_data,
664
- new_overall_user_data)
665
- add_score_button.click(
666
- fn=handle_add_score_to_leaderboard,
667
- inputs=[username_input, session_stats_state],
668
- outputs=[add_score_feedback, username_input, add_score_button, has_added_score_state, quant_df, user_df]
669
- )
670
  with gr.TabItem("Leaderboard"):
671
  gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*")
672
  leaderboard_tab_quant_df = gr.DataFrame(
673
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
674
  interactive=False, col_count=(4, "fixed"), label="Quantization Method Leaderboard"
675
  )
676
- gr.Markdown("---")
677
-
678
- leaderboard_tab_user_df_8bit = gr.DataFrame(
679
- headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
680
- interactive=False, col_count=(4, "fixed"), label="8-bit bnb User Leaderboard"
681
- )
682
- leaderboard_tab_user_df_4bit = gr.DataFrame(
683
- headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
684
- interactive=False, col_count=(4, "fixed"), label="4-bit bnb User Leaderboard"
685
- )
686
 
687
  def update_all_leaderboards_for_tab():
688
- q_rows, _, per_method_u_dict = update_leaderboards_data()
689
- user_rows_8bit = per_method_u_dict.get("8-bit bnb", [])
690
- user_rows_4bit = per_method_u_dict.get("4-bit bnb", [])
691
- return q_rows, user_rows_8bit, user_rows_4bit
692
 
693
  demo.load(update_all_leaderboards_for_tab, outputs=[
694
- leaderboard_tab_quant_df,
695
- leaderboard_tab_user_df_8bit,
696
- leaderboard_tab_user_df_4bit
697
  ])
698
 
699
  if __name__ == "__main__":
 
35
  with open(AGG_FILE, "w") as f:
36
  json.dump(stats, f, indent=2)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
  print(f"Using device: {DEVICE}")
40
 
 
44
  DEFAULT_NUM_INFERENCE_STEPS = 15
45
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
46
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
47
+ HF_DATASET_REPO_ID = "diffusers/flux-quant-challenge-submissions"
48
 
49
  CACHED_PIPES = {}
50
  def load_bf16_pipeline():
 
81
  torch_dtype=torch.bfloat16
82
  )
83
  pipe.to(DEVICE)
 
84
  end_time = time.time()
85
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
86
  print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
 
102
  torch_dtype=torch.bfloat16
103
  )
104
  pipe.to(DEVICE)
 
105
  end_time = time.time()
106
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
107
  print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
 
114
  @spaces.GPU(duration=240)
115
  def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
116
  if not prompt:
117
+ return None, {}, gr.update(value="Please enter a prompt.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
118
 
119
  if not quantization_choice:
120
+ return None, {}, gr.update(value="Please select a quantization method.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
121
 
122
  if quantization_choice == "8-bit bnb":
123
  quantized_load_func = load_bnb_8bit_pipeline
 
126
  quantized_load_func = load_bnb_4bit_pipeline
127
  quantized_label = "Quantized (4-bit bnb)"
128
  else:
129
+ return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
130
 
131
  model_configs = [
132
  ("Original", load_bf16_pipeline),
 
168
 
169
  except Exception as e:
170
  print(f"Error during {label} model processing: {e}")
171
+ return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
172
 
173
 
174
  if len(results) != len(model_configs):
175
+ return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
176
 
177
  shuffled_results = results.copy()
178
  random.shuffle(shuffled_results)
 
243
  return f"{pct:.1f}%", pct
244
  return "N/A", -1.0
245
 
 
 
 
 
 
 
 
246
  def update_leaderboards_data():
247
  agg = _load_agg_stats()
248
  quant_rows = []
 
255
  acc_str
256
  ])
257
  quant_rows.sort(key=lambda r: r[1]/r[2] if r[2] != 0 else 1e9)
258
+ return quant_rows
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  quant_df = gr.DataFrame(
261
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
262
  interactive=False, col_count=(4, "fixed")
263
  )
 
 
 
 
264
 
265
  with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
266
  gr.Markdown("# FLUX Model Quantization Challenge")
 
307
 
308
  with gr.Row():
309
  session_score_box = gr.Textbox(label="Your accuracy this session", interactive=False)
310
+
311
+ gr.Markdown("""
312
+ ### Dataset Information
313
+ Unless you opt out below, your submissions will be recorded in a dataset. This dataset contains anonymized challenge results including prompts, images, quantization methods,
314
+ and whether guesses were correct.
315
+ """)
316
+
317
+ opt_out_checkbox = gr.Checkbox(
318
+ label="Opt out of data collection (don't record my submissions to the dataset)",
319
+ value=False
 
 
 
 
 
 
 
 
 
 
320
  )
321
 
322
  correct_mapping_state = gr.State({})
 
325
  "4-bit bnb": {"attempts": 0, "correct": 0}}
326
  )
327
  is_example_state = gr.State(False)
 
328
  prompt_state = gr.State("")
329
  seed_state = gr.State(None)
330
  results_state = gr.State([])
331
 
332
  def _load_example_and_update_dfs(sel_summary):
 
333
  idx = next((i for i, ex in enumerate(EXAMPLES) if ex["summary"] == sel_summary), -1)
334
  if idx == -1:
 
335
  print(f"Error: Example with summary '{sel_summary}' not found.")
336
+ return (gr.update(), gr.update(), gr.update(), False, gr.update(), "", None, [])
337
 
338
  ex = EXAMPLES[idx]
339
  gallery_items, mapping, prompt = load_example(idx)
340
+ quant_data = update_leaderboards_data()
341
+ return gallery_items, mapping, prompt, True, quant_data, "", None, []
342
 
343
  ex_selector.change(
344
  fn=_load_example_and_update_dfs,
345
  inputs=ex_selector,
346
+ outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df,
347
+ prompt_state, seed_state, results_state],
348
  ).then(
349
  lambda: (gr.update(interactive=True), gr.update(interactive=True)),
350
  outputs=[image1_btn, image2_btn],
 
354
  fn=generate_images,
355
  inputs=[prompt_input, quantization_choice_radio],
356
  outputs=[output_gallery, correct_mapping_state, prompt_state, seed_state, results_state,
357
+ feedback_box]
358
  ).then(
359
+ lambda: False, # for is_example_state
360
+ outputs=[is_example_state]
 
 
 
 
 
 
 
 
361
  ).then(
362
  lambda: (gr.update(interactive=True),
363
+ gr.update(interactive=True),
364
+ ""),
365
  outputs=[image1_btn, image2_btn, feedback_box],
366
  )
367
 
368
+ def choose(choice_string, mapping, session_stats, is_example,
369
+ prompt, seed, results, opt_out):
370
  feedback = check_guess(choice_string, mapping)
371
 
372
  if not mapping:
373
+ return feedback, gr.update(), gr.update(), "", session_stats, gr.update()
374
 
375
  quant_label_from_mapping = next((label for label in mapping.values() if "Quantized" in label), None)
376
  if not quant_label_from_mapping:
377
  print("Error: Could not determine quantization label from mapping:", mapping)
378
  return ("Internal Error: Could not process results.", gr.update(interactive=False), gr.update(interactive=False),
379
+ "", session_stats, gr.update())
380
 
381
  quant_key = "8-bit bnb" if "8-bit bnb" in quant_label_from_mapping else "4-bit bnb"
 
382
  got_it_right = "Correct!" in feedback
 
383
  sess = session_stats.copy()
 
384
 
385
+ if not is_example: # Only log and update stats if it's not an example run
386
  sess[quant_key]["attempts"] += 1
387
  if got_it_right:
388
  sess[quant_key]["correct"] += 1
389
+ session_stats = sess # Update the state for the UI
390
 
391
  AGG_STATS = _load_agg_stats()
392
  AGG_STATS[quant_key]["attempts"] += 1
 
398
  print("Warning: HF_TOKEN not set. Skipping dataset logging.")
399
  elif not results:
400
  print("Warning: Results state is empty. Skipping dataset logging.")
401
+ elif opt_out:
402
+ print("User opted out of dataset logging. Skipping.")
403
  else:
404
  print(f"Logging guess to HF Dataset: {HF_DATASET_REPO_ID}")
405
  original_image = None
 
438
  "quantized_image_displayed_position": [f"Image {quantized_image_pos + 1}"],
439
  "user_guess_displayed_position": [choice_string],
440
  "correct_guess": [got_it_right],
441
+ "username": [None], # Log None for username
442
  }
 
443
  try:
 
444
  existing_ds = load_dataset(
445
  HF_DATASET_REPO_ID,
446
  split="train",
447
  token=HF_TOKEN,
448
  features=expected_features,
 
 
449
  )
 
450
  new_row_ds = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
 
451
  combined_ds = concatenate_datasets([existing_ds, new_row_ds])
 
452
  combined_ds.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
453
  print(f"Successfully appended guess to {HF_DATASET_REPO_ID} (train split)")
 
454
  except Exception as e:
455
  print(f"Could not load or append to existing dataset/split. Creating 'train' split with the new item. Error: {e}")
 
456
  ds_new = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
 
457
  ds_new.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
458
  print(f"Successfully created and logged new 'train' split to {HF_DATASET_REPO_ID}")
459
  else:
 
467
  session_msg = ", ".join(
468
  f"{k}: {_fmt(v)}" for k, v in sess.items()
469
  )
470
+
471
+ quant_data = update_leaderboards_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  return (feedback,
473
  gr.update(interactive=False),
474
  gr.update(interactive=False),
475
  session_msg,
476
+ session_stats, # Return the potentially updated session_stats
477
+ quant_data)
 
 
 
 
478
 
479
  image1_btn.click(
480
+ fn=lambda mapping, sess, is_ex, p, s, r, opt_out: choose("Image 1", mapping, sess, is_ex, p, s, r, opt_out),
481
+ inputs=[correct_mapping_state, session_stats_state, is_example_state,
482
+ prompt_state, seed_state, results_state, opt_out_checkbox],
483
  outputs=[feedback_box, image1_btn, image2_btn,
484
+ session_score_box, session_stats_state,
485
+ quant_df],
 
486
  )
487
  image2_btn.click(
488
+ fn=lambda mapping, sess, is_ex, p, s, r, opt_out: choose("Image 2", mapping, sess, is_ex, p, s, r, opt_out),
489
+ inputs=[correct_mapping_state, session_stats_state, is_example_state,
490
+ prompt_state, seed_state, results_state, opt_out_checkbox],
491
  outputs=[feedback_box, image1_btn, image2_btn,
492
+ session_score_box, session_stats_state,
493
+ quant_df],
 
494
  )
495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  with gr.TabItem("Leaderboard"):
497
  gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*")
498
  leaderboard_tab_quant_df = gr.DataFrame(
499
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
500
  interactive=False, col_count=(4, "fixed"), label="Quantization Method Leaderboard"
501
  )
 
 
 
 
 
 
 
 
 
 
502
 
503
  def update_all_leaderboards_for_tab():
504
+ q_rows = update_leaderboards_data()
505
+ return q_rows # Only return quantization method data
 
 
506
 
507
  demo.load(update_all_leaderboards_for_tab, outputs=[
508
+ leaderboard_tab_quant_df,
 
 
509
  ])
510
 
511
  if __name__ == "__main__":