yeq6x commited on
Commit
6fabdaf
·
1 Parent(s): 7c9164c

Refactor app.py to update prefix/suffix naming conventions for metadata creation and enhance UI with new training hyperparameter inputs. Modify train_QIE.sh to utilize dynamic hyperparameter values for training execution, improving configurability and user experience.

Browse files
Files changed (2) hide show
  1. app.py +20 -8
  2. train_QIE.sh +12 -5
app.py CHANGED
@@ -322,9 +322,9 @@ def _prepare_script(
322
  # Inject prefix/suffix flags for metadata creation
323
  extra_lines: List[str] = []
324
  if (target_prefix or ""):
325
- extra_lines.append(f" --target_prefix {_bash_quote(target_prefix)} \\")
326
  if (target_suffix or ""):
327
- extra_lines.append(f" --target_suffix {_bash_quote(target_suffix)} \\")
328
  for i in range(8):
329
  pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
330
  suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
@@ -489,6 +489,9 @@ def run_training(
489
  control7_uploads: Any,
490
  ctrl7_prefix: str,
491
  ctrl7_suffix: str,
 
 
 
492
  max_epochs: int,
493
  save_every: int,
494
  ) -> Iterable[tuple]:
@@ -593,6 +596,9 @@ def run_training(
593
  target_suffix=(target_suffix or ""),
594
  control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
595
  control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
 
 
 
596
  )
597
 
598
 
@@ -653,6 +659,14 @@ def build_ui() -> gr.Blocks:
653
  output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
654
  caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
655
 
 
 
 
 
 
 
 
 
656
  with gr.Row():
657
  with gr.Column(scale=3):
658
  images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath")
@@ -728,12 +742,10 @@ def build_ui() -> gr.Blocks:
728
  logs = gr.Textbox(label="Logs", lines=20)
729
  ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
730
 
731
- with gr.Row():
732
- max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
733
- save_every = gr.Number(label="Save every N epochs", value=5, precision=0)
734
 
735
  # Wire previews
736
- images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=target_gallery)
737
  ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
738
  ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
739
  ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
@@ -746,7 +758,7 @@ def build_ui() -> gr.Blocks:
746
  run_btn.click(
747
  fn=run_training,
748
  inputs=[
749
- output_name, caption, images_input, target_prefix, target_suffix,
750
  ctrl0_files, ctrl0_prefix, ctrl0_suffix,
751
  ctrl1_files, ctrl1_prefix, ctrl1_suffix,
752
  ctrl2_files, ctrl2_prefix, ctrl2_suffix,
@@ -755,7 +767,7 @@ def build_ui() -> gr.Blocks:
755
  ctrl5_files, ctrl5_prefix, ctrl5_suffix,
756
  ctrl6_files, ctrl6_prefix, ctrl6_suffix,
757
  ctrl7_files, ctrl7_prefix, ctrl7_suffix,
758
- max_epochs, save_every,
759
  ],
760
  outputs=[logs, ckpt_files],
761
  )
 
322
  # Inject prefix/suffix flags for metadata creation
323
  extra_lines: List[str] = []
324
  if (target_prefix or ""):
325
+ extra_lines.append(f" --main_prefix {_bash_quote(target_prefix)} \\")
326
  if (target_suffix or ""):
327
+ extra_lines.append(f" --main_suffix {_bash_quote(target_suffix)} \\")
328
  for i in range(8):
329
  pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
330
  suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
 
489
  control7_uploads: Any,
490
  ctrl7_prefix: str,
491
  ctrl7_suffix: str,
492
+ learning_rate: str,
493
+ network_dim: int,
494
+ seed: int,
495
  max_epochs: int,
496
  save_every: int,
497
  ) -> Iterable[tuple]:
 
596
  target_suffix=(target_suffix or ""),
597
  control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
598
  control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
599
+ override_learning_rate=(learning_rate or None),
600
+ override_network_dim=int(network_dim) if network_dim is not None else None,
601
+ override_seed=int(seed) if seed is not None else None,
602
  )
603
 
604
 
 
659
  output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
660
  caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
661
 
662
+ # Training options near OUTPUT NAME
663
+ with gr.Row():
664
+ lr_input = gr.Textbox(label="Learning rate", value="1e-3")
665
+ dim_input = gr.Number(label="Network dim", value=4, precision=0)
666
+ seed_input = gr.Number(label="Seed", value=42, precision=0)
667
+ max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
668
+ save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
669
+
670
  with gr.Row():
671
  with gr.Column(scale=3):
672
  images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath")
 
742
  logs = gr.Textbox(label="Logs", lines=20)
743
  ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
744
 
745
+ # moved max_epochs/save_every above next to OUTPUT NAME
 
 
746
 
747
  # Wire previews
748
+ images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery)
749
  ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
750
  ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
751
  ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
 
758
  run_btn.click(
759
  fn=run_training,
760
  inputs=[
761
+ output_name, caption, images_input, main_prefix, main_suffix,
762
  ctrl0_files, ctrl0_prefix, ctrl0_suffix,
763
  ctrl1_files, ctrl1_prefix, ctrl1_suffix,
764
  ctrl2_files, ctrl2_prefix, ctrl2_suffix,
 
767
  ctrl5_files, ctrl5_prefix, ctrl5_suffix,
768
  ctrl6_files, ctrl6_prefix, ctrl6_suffix,
769
  ctrl7_files, ctrl7_prefix, ctrl7_suffix,
770
+ lr_input, dim_input, seed_input, max_epochs, save_every,
771
  ],
772
  outputs=[logs, ckpt_files],
773
  )
train_QIE.sh CHANGED
@@ -39,6 +39,13 @@ OUTPUT_DIR_BASE="/workspace/auto/train_LoRA"
39
  DATASET_CONFIG="/workspace/auto/dataset_QIE.toml"
40
  OUTPUT_JSON="${DATASET_DIR%/}/metadata.jsonl"
41
 
 
 
 
 
 
 
 
42
  # Build control args from folder names with auto-detect fallback
43
  CONTROL_ARGS=()
44
  for i in {0..7}; do
@@ -121,15 +128,15 @@ accelerate launch src/musubi_tuner/qwen_image_train_network.py \
121
  --weighting_scheme none \
122
  --discrete_flow_shift 2.0 \
123
  --optimizer_type adamw8bit \
124
- --learning_rate 1e-3 \
125
  --gradient_checkpointing \
126
  --max_data_loader_n_workers 2 \
127
  --persistent_data_loader_workers \
128
  --network_module networks.lora_qwen_image \
129
- --network_dim 4 \
130
- --max_train_epochs 100 \
131
- --save_every_n_epochs 10 \
132
- --seed 42 \
133
  --output_dir "${OUTPUT_DIR_BASE}/${RUN_NAME}" \
134
  --output_name "${RUN_NAME}" \
135
  --ddp_gradient_as_bucket_view \
 
39
  DATASET_CONFIG="/workspace/auto/dataset_QIE.toml"
40
  OUTPUT_JSON="${DATASET_DIR%/}/metadata.jsonl"
41
 
42
+ # Training hyperparameters (can be overridden by app)
43
+ LEARNING_RATE="1e-3"
44
+ NETWORK_DIM=4
45
+ SEED=42
46
+ MAX_TRAIN_EPOCHS=100
47
+ SAVE_EVERY_N_EPOCHS=10
48
+
49
  # Build control args from folder names with auto-detect fallback
50
  CONTROL_ARGS=()
51
  for i in {0..7}; do
 
128
  --weighting_scheme none \
129
  --discrete_flow_shift 2.0 \
130
  --optimizer_type adamw8bit \
131
+ --learning_rate "$LEARNING_RATE" \
132
  --gradient_checkpointing \
133
  --max_data_loader_n_workers 2 \
134
  --persistent_data_loader_workers \
135
  --network_module networks.lora_qwen_image \
136
+ --network_dim "$NETWORK_DIM" \
137
+ --max_train_epochs "$MAX_TRAIN_EPOCHS" \
138
+ --save_every_n_epochs "$SAVE_EVERY_N_EPOCHS" \
139
+ --seed "$SEED" \
140
  --output_dir "${OUTPUT_DIR_BASE}/${RUN_NAME}" \
141
  --output_name "${RUN_NAME}" \
142
  --ddp_gradient_as_bucket_view \