Spaces:
Running
on
Zero
Running
on
Zero
Refactor app.py to update prefix/suffix naming conventions for metadata creation and enhance UI with new training hyperparameter inputs. Modify train_QIE.sh to utilize dynamic hyperparameter values for training execution, improving configurability and user experience.
Browse files- app.py +20 -8
- train_QIE.sh +12 -5
app.py
CHANGED
|
@@ -322,9 +322,9 @@ def _prepare_script(
|
|
| 322 |
# Inject prefix/suffix flags for metadata creation
|
| 323 |
extra_lines: List[str] = []
|
| 324 |
if (target_prefix or ""):
|
| 325 |
-
extra_lines.append(f" --
|
| 326 |
if (target_suffix or ""):
|
| 327 |
-
extra_lines.append(f" --
|
| 328 |
for i in range(8):
|
| 329 |
pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
|
| 330 |
suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
|
|
@@ -489,6 +489,9 @@ def run_training(
|
|
| 489 |
control7_uploads: Any,
|
| 490 |
ctrl7_prefix: str,
|
| 491 |
ctrl7_suffix: str,
|
|
|
|
|
|
|
|
|
|
| 492 |
max_epochs: int,
|
| 493 |
save_every: int,
|
| 494 |
) -> Iterable[tuple]:
|
|
@@ -593,6 +596,9 @@ def run_training(
|
|
| 593 |
target_suffix=(target_suffix or ""),
|
| 594 |
control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
|
| 595 |
control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
|
|
|
|
|
|
|
|
|
|
| 596 |
)
|
| 597 |
|
| 598 |
|
|
@@ -653,6 +659,14 @@ def build_ui() -> gr.Blocks:
|
|
| 653 |
output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
|
| 654 |
caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
|
| 655 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
with gr.Row():
|
| 657 |
with gr.Column(scale=3):
|
| 658 |
images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath")
|
|
@@ -728,12 +742,10 @@ def build_ui() -> gr.Blocks:
|
|
| 728 |
logs = gr.Textbox(label="Logs", lines=20)
|
| 729 |
ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
|
| 730 |
|
| 731 |
-
|
| 732 |
-
max_epochs = gr.Number(label="Max epochs (this run)", value=10, precision=0)
|
| 733 |
-
save_every = gr.Number(label="Save every N epochs", value=5, precision=0)
|
| 734 |
|
| 735 |
# Wire previews
|
| 736 |
-
images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=
|
| 737 |
ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
|
| 738 |
ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
|
| 739 |
ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
|
|
@@ -746,7 +758,7 @@ def build_ui() -> gr.Blocks:
|
|
| 746 |
run_btn.click(
|
| 747 |
fn=run_training,
|
| 748 |
inputs=[
|
| 749 |
-
output_name, caption, images_input,
|
| 750 |
ctrl0_files, ctrl0_prefix, ctrl0_suffix,
|
| 751 |
ctrl1_files, ctrl1_prefix, ctrl1_suffix,
|
| 752 |
ctrl2_files, ctrl2_prefix, ctrl2_suffix,
|
|
@@ -755,7 +767,7 @@ def build_ui() -> gr.Blocks:
|
|
| 755 |
ctrl5_files, ctrl5_prefix, ctrl5_suffix,
|
| 756 |
ctrl6_files, ctrl6_prefix, ctrl6_suffix,
|
| 757 |
ctrl7_files, ctrl7_prefix, ctrl7_suffix,
|
| 758 |
-
max_epochs, save_every,
|
| 759 |
],
|
| 760 |
outputs=[logs, ckpt_files],
|
| 761 |
)
|
|
|
|
| 322 |
# Inject prefix/suffix flags for metadata creation
|
| 323 |
extra_lines: List[str] = []
|
| 324 |
if (target_prefix or ""):
|
| 325 |
+
extra_lines.append(f" --main_prefix {_bash_quote(target_prefix)} \\")
|
| 326 |
if (target_suffix or ""):
|
| 327 |
+
extra_lines.append(f" --main_suffix {_bash_quote(target_suffix)} \\")
|
| 328 |
for i in range(8):
|
| 329 |
pre = control_prefixes[i] if (control_prefixes and i < len(control_prefixes)) else None
|
| 330 |
suf = control_suffixes[i] if (control_suffixes and i < len(control_suffixes)) else None
|
|
|
|
| 489 |
control7_uploads: Any,
|
| 490 |
ctrl7_prefix: str,
|
| 491 |
ctrl7_suffix: str,
|
| 492 |
+
learning_rate: str,
|
| 493 |
+
network_dim: int,
|
| 494 |
+
seed: int,
|
| 495 |
max_epochs: int,
|
| 496 |
save_every: int,
|
| 497 |
) -> Iterable[tuple]:
|
|
|
|
| 596 |
target_suffix=(target_suffix or ""),
|
| 597 |
control_prefixes=[ctrl0_prefix, ctrl1_prefix, ctrl2_prefix, ctrl3_prefix, ctrl4_prefix, ctrl5_prefix, ctrl6_prefix, ctrl7_prefix],
|
| 598 |
control_suffixes=[ctrl0_suffix, ctrl1_suffix, ctrl2_suffix, ctrl3_suffix, ctrl4_suffix, ctrl5_suffix, ctrl6_suffix, ctrl7_suffix],
|
| 599 |
+
override_learning_rate=(learning_rate or None),
|
| 600 |
+
override_network_dim=int(network_dim) if network_dim is not None else None,
|
| 601 |
+
override_seed=int(seed) if seed is not None else None,
|
| 602 |
)
|
| 603 |
|
| 604 |
|
|
|
|
| 659 |
output_name = gr.Textbox(label="OUTPUT NAME", placeholder="my_lora_output", lines=1)
|
| 660 |
caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)
|
| 661 |
|
| 662 |
+
# Training options near OUTPUT NAME
|
| 663 |
+
with gr.Row():
|
| 664 |
+
lr_input = gr.Textbox(label="Learning rate", value="1e-3")
|
| 665 |
+
dim_input = gr.Number(label="Network dim", value=4, precision=0)
|
| 666 |
+
seed_input = gr.Number(label="Seed", value=42, precision=0)
|
| 667 |
+
max_epochs = gr.Number(label="Max epochs", value=100, precision=0)
|
| 668 |
+
save_every = gr.Number(label="Save every N epochs", value=10, precision=0)
|
| 669 |
+
|
| 670 |
with gr.Row():
|
| 671 |
with gr.Column(scale=3):
|
| 672 |
images_input = gr.File(label="Upload target images", file_count="multiple", type="filepath")
|
|
|
|
| 742 |
logs = gr.Textbox(label="Logs", lines=20)
|
| 743 |
ckpt_files = gr.Files(label="Checkpoints (live)", interactive=False)
|
| 744 |
|
| 745 |
+
# moved max_epochs/save_every above next to OUTPUT NAME
|
|
|
|
|
|
|
| 746 |
|
| 747 |
# Wire previews
|
| 748 |
+
images_input.change(fn=_files_to_gallery, inputs=images_input, outputs=main_gallery)
|
| 749 |
ctrl0_files.change(fn=_files_to_gallery, inputs=ctrl0_files, outputs=ctrl0_gallery)
|
| 750 |
ctrl1_files.change(fn=_files_to_gallery, inputs=ctrl1_files, outputs=ctrl1_gallery)
|
| 751 |
ctrl2_files.change(fn=_files_to_gallery, inputs=ctrl2_files, outputs=ctrl2_gallery)
|
|
|
|
| 758 |
run_btn.click(
|
| 759 |
fn=run_training,
|
| 760 |
inputs=[
|
| 761 |
+
output_name, caption, images_input, main_prefix, main_suffix,
|
| 762 |
ctrl0_files, ctrl0_prefix, ctrl0_suffix,
|
| 763 |
ctrl1_files, ctrl1_prefix, ctrl1_suffix,
|
| 764 |
ctrl2_files, ctrl2_prefix, ctrl2_suffix,
|
|
|
|
| 767 |
ctrl5_files, ctrl5_prefix, ctrl5_suffix,
|
| 768 |
ctrl6_files, ctrl6_prefix, ctrl6_suffix,
|
| 769 |
ctrl7_files, ctrl7_prefix, ctrl7_suffix,
|
| 770 |
+
lr_input, dim_input, seed_input, max_epochs, save_every,
|
| 771 |
],
|
| 772 |
outputs=[logs, ckpt_files],
|
| 773 |
)
|
train_QIE.sh
CHANGED
|
@@ -39,6 +39,13 @@ OUTPUT_DIR_BASE="/workspace/auto/train_LoRA"
|
|
| 39 |
DATASET_CONFIG="/workspace/auto/dataset_QIE.toml"
|
| 40 |
OUTPUT_JSON="${DATASET_DIR%/}/metadata.jsonl"
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# Build control args from folder names with auto-detect fallback
|
| 43 |
CONTROL_ARGS=()
|
| 44 |
for i in {0..7}; do
|
|
@@ -121,15 +128,15 @@ accelerate launch src/musubi_tuner/qwen_image_train_network.py \
|
|
| 121 |
--weighting_scheme none \
|
| 122 |
--discrete_flow_shift 2.0 \
|
| 123 |
--optimizer_type adamw8bit \
|
| 124 |
-
--learning_rate
|
| 125 |
--gradient_checkpointing \
|
| 126 |
--max_data_loader_n_workers 2 \
|
| 127 |
--persistent_data_loader_workers \
|
| 128 |
--network_module networks.lora_qwen_image \
|
| 129 |
-
--network_dim
|
| 130 |
-
--max_train_epochs
|
| 131 |
-
--save_every_n_epochs
|
| 132 |
-
--seed
|
| 133 |
--output_dir "${OUTPUT_DIR_BASE}/${RUN_NAME}" \
|
| 134 |
--output_name "${RUN_NAME}" \
|
| 135 |
--ddp_gradient_as_bucket_view \
|
|
|
|
| 39 |
DATASET_CONFIG="/workspace/auto/dataset_QIE.toml"
|
| 40 |
OUTPUT_JSON="${DATASET_DIR%/}/metadata.jsonl"
|
| 41 |
|
| 42 |
+
# Training hyperparameters (can be overridden by app)
|
| 43 |
+
LEARNING_RATE="1e-3"
|
| 44 |
+
NETWORK_DIM=4
|
| 45 |
+
SEED=42
|
| 46 |
+
MAX_TRAIN_EPOCHS=100
|
| 47 |
+
SAVE_EVERY_N_EPOCHS=10
|
| 48 |
+
|
| 49 |
# Build control args from folder names with auto-detect fallback
|
| 50 |
CONTROL_ARGS=()
|
| 51 |
for i in {0..7}; do
|
|
|
|
| 128 |
--weighting_scheme none \
|
| 129 |
--discrete_flow_shift 2.0 \
|
| 130 |
--optimizer_type adamw8bit \
|
| 131 |
+
--learning_rate "$LEARNING_RATE" \
|
| 132 |
--gradient_checkpointing \
|
| 133 |
--max_data_loader_n_workers 2 \
|
| 134 |
--persistent_data_loader_workers \
|
| 135 |
--network_module networks.lora_qwen_image \
|
| 136 |
+
--network_dim "$NETWORK_DIM" \
|
| 137 |
+
--max_train_epochs "$MAX_TRAIN_EPOCHS" \
|
| 138 |
+
--save_every_n_epochs "$SAVE_EVERY_N_EPOCHS" \
|
| 139 |
+
--seed "$SEED" \
|
| 140 |
--output_dir "${OUTPUT_DIR_BASE}/${RUN_NAME}" \
|
| 141 |
--output_name "${RUN_NAME}" \
|
| 142 |
--ddp_gradient_as_bucket_view \
|