yeq6x committed on
Commit
26b93ae
·
1 Parent(s): 789adb0

Implement initial project structure and setup

Browse files
app.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import re
4
+ import shutil
5
+ import subprocess
6
+ import sys
7
+ import tempfile
8
+ from pathlib import Path
9
+ from typing import Dict, Iterable, List, Optional
10
+
11
+ import gradio as gr
12
+
13
+ # Local modules
14
+ from download_qwen_image_models import download_all_models, DEFAULT_MODELS_DIR
15
+
16
+
17
+ # Defaults matching train_QIE.sh expectations
18
+ DEFAULT_DATA_ROOT = "/workspace/data"
19
+ DEFAULT_IMAGE_FOLDER = "image"
20
+ DEFAULT_OUTPUT_DIR_BASE = "/workspace/auto/train_LoRA"
21
+ DEFAULT_DATASET_CONFIG = "/workspace/auto/dataset_QIE.toml"
22
+ DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR # "/workspace/Qwen-Image_models"
23
+ WORKSPACE_AUTO_DIR = "/workspace/auto"
24
+
25
+
26
+ TRAINING_DIR = Path(__file__).resolve().parent
27
+
28
+
29
def _bash_quote(s: str) -> str:
    """Wrap ``s`` in single quotes so a POSIX shell reads it as one literal word.

    ``None`` maps to the empty literal ``''``. Any other value is converted
    with ``str()``; every embedded single quote is emitted as ``'"'"'``
    (close the quote, add a double-quoted ``'``, reopen the quote).
    """
    if s is None:
        return "''"
    escaped_quote = "'\"'\"'"
    # Splitting on "'" and re-joining with the escape sequence is equivalent
    # to replacing each single quote.
    return "'{}'".format(escaped_quote.join(str(s).split("'")))
34
+
35
+
36
def _ensure_workspace_auto_files() -> None:
    """Copy the helper files train_QIE.sh expects into WORKSPACE_AUTO_DIR.

    ``create_image_caption_json.py`` and ``dataset_QIE.toml`` are copied from
    TRAINING_DIR (the directory containing this file) so that the training
    script can run unmodified. Copying stays best-effort: a failure never
    raises, but it is reported on stderr instead of being silently dropped.
    ``dataset_QIE.toml`` is optional and skipped quietly when absent.

    Raises:
        OSError: if WORKSPACE_AUTO_DIR itself cannot be created.
    """
    dst_dir = Path(WORKSPACE_AUTO_DIR)
    os.makedirs(dst_dir, exist_ok=True)

    # (file name, whether a missing source deserves a warning)
    helpers = (
        ("create_image_caption_json.py", True),
        ("dataset_QIE.toml", False),
    )
    for name, warn_if_missing in helpers:
        src = TRAINING_DIR / name
        if not src.exists():
            if warn_if_missing:
                print(f"[QIE] WARN: helper file not found: {src}", file=sys.stderr)
            continue
        try:
            shutil.copy2(src, dst_dir / name)
        except shutil.SameFileError:
            # Source and destination are the same file: nothing to do.
            pass
        except OSError as exc:
            print(
                f"[QIE] WARN: could not copy {src} to {dst_dir / name}: {exc}",
                file=sys.stderr,
            )
57
+
58
+
59
def _prepare_script(
    dataset_name: str,
    caption: str,
    data_root: str,
    image_folder: str,
    control_folders: List[Optional[str]],
    models_root: str,
    output_dir_base: Optional[str] = None,
    dataset_config: Optional[str] = None,
) -> Path:
    """Write a per-run copy of train_QIE.sh with this run's settings injected.

    Only the variable assignments and model-path arguments that vary per run
    are rewritten; the rest of the script text is left untouched.

    Args:
        dataset_name: Value for ``DATASET_NAME``.
        caption: Value for ``CAPTION``.
        data_root: Value for ``DATA_ROOT``.
        image_folder: Value for ``IMAGE_FOLDER``.
        control_folders: Up to eight values for ``CONTROL_FOLDER_0..7``;
            falsy entries are skipped, entries past index 7 are ignored.
        models_root: Directory that the ``--vae``, ``--text_encoder`` and
            ``--dit`` arguments are pointed at.
        output_dir_base: Optional override for ``OUTPUT_DIR_BASE``.
        dataset_config: Optional override for ``DATASET_CONFIG``.

    Returns:
        Path of the generated script under ``TRAINING_DIR/.gradio_runs``.

    Raises:
        OSError: if the template cannot be read or the copy cannot be written.
    """
    template = TRAINING_DIR / "train_QIE.sh"
    txt = template.read_text(encoding="utf-8")

    def _sub(pattern: str, make_text, text: str, count: int = 0) -> str:
        # ``make_text`` is always a callable: re.sub inserts a callable's
        # result verbatim, whereas a replacement *string* has its backslash
        # escapes and group references processed. User-supplied text such as
        # a caption containing "\d" would otherwise raise or be altered.
        return re.sub(pattern, make_text, text, count=count, flags=re.MULTILINE)

    # Core variable assignments.
    assignments = {
        "DATA_ROOT": data_root,
        "DATASET_NAME": dataset_name,
        "CAPTION": caption,
        "IMAGE_FOLDER": image_folder,
    }
    if output_dir_base:
        assignments["OUTPUT_DIR_BASE"] = output_dir_base
    if dataset_config:
        assignments["DATASET_CONFIG"] = dataset_config

    for name, value in assignments.items():
        line = f"{name}={_bash_quote(value)}"
        txt = _sub(rf"^{name}=\".*\"", lambda _m, line=line: line, txt)

    # CONTROL_FOLDER_i: replace the commented placeholder when the template
    # has one, otherwise insert the assignment right after IMAGE_FOLDER.
    for i, folder in enumerate(control_folders[:8]):
        if not folder:
            continue
        line = f"CONTROL_FOLDER_{i}={_bash_quote(folder)}"
        placeholder = rf"^#\s*CONTROL_FOLDER_{i}=\".*\""
        if re.search(placeholder, txt, flags=re.MULTILINE):
            txt = _sub(placeholder, lambda _m, line=line: line, txt)
        else:
            txt = _sub(
                r"^(IMAGE_FOLDER=.*)$",
                lambda m, line=line: f"{m.group(1)}\n{line}",
                txt,
                count=1,
            )

    # Point the model arguments at models_root. The path is shell-quoted with
    # _bash_quote like every other injected value, so characters such as
    # '"', '$' or '`' in the UI-supplied root cannot break out of the argument.
    models_base = models_root.rstrip("/")
    model_files = {
        "vae": "vae/diffusion_pytorch_model.safetensors",
        "text_encoder": "text_encoder/qwen_2.5_vl_7b.safetensors",
        "dit": "dit/qwen_image_edit_2509_bf16.safetensors",
    }
    for flag, rel in model_files.items():
        arg = f"--{flag} {_bash_quote(models_base + '/' + rel)}"
        txt = _sub(rf"--{flag} \"[^\"]+\"", lambda _m, arg=arg: arg, txt)

    # Keep generated scripts next to the repo for easier inspection. mkstemp
    # gives every call its own file: the PID alone is shared by all requests
    # served by one process, so concurrent runs would overwrite each other.
    run_dir = TRAINING_DIR / ".gradio_runs"
    run_dir.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(
        prefix=f"train_QIE_run_{os.getpid()}_", suffix=".sh", dir=str(run_dir)
    )
    with os.fdopen(fd, "w", encoding="utf-8", newline="\n") as fh:
        fh.write(txt)

    tmp = Path(tmp_name)
    try:
        os.chmod(tmp, 0o755)
    except OSError:
        # Best effort: the caller passes the script to a shell as an
        # argument, which does not need the executable bit.
        pass
    return tmp
142
+
143
+
144
def _pick_shell() -> str:
    """Return the first of ``bash`` / ``sh`` found on PATH.

    Raises:
        RuntimeError: if neither shell can be located.
    """
    found = next((name for name in ("bash", "sh") if shutil.which(name)), None)
    if found is None:
        raise RuntimeError("No POSIX shell found. Please install bash or sh.")
    return found
149
+
150
+
151
def run_training(
    dataset_name: str,
    caption: str,
    data_root: str,
    image_folder: str,
    control0: str,
    control1: str,
    control2: str,
    control3: str,
    control4: str,
    control5: str,
    control6: str,
    control7: str,
    models_root: str,
    output_dir_base: str,
    dataset_config: str,
) -> Iterable[str]:
    """Run train_QIE.sh with the given settings and stream its log.

    This is a generator meant to drive a Gradio output component. Every
    yielded value is the log accumulated so far (newline-joined, limited to
    the most recent lines), because each yield replaces the component's
    value rather than appending to it.

    Args:
        dataset_name: Required; folder name under ``data_root``.
        caption: Required caption text.
        data_root, image_folder, models_root: Blank values fall back to the
            module defaults.
        control0..control7: Optional control folder names; blank ones are
            ignored.
        output_dir_base, dataset_config: Optional overrides; blank means
            "keep what the script defines".

    Yields:
        The accumulated log text; the final value ends with the exit code.
    """
    from collections import deque  # local import keeps the block self-contained

    max_log_lines = 2000  # bound what is re-sent on every update
    log_lines = deque(maxlen=max_log_lines)

    def _emit(line: str) -> str:
        log_lines.append(line)
        return "\n".join(log_lines)

    def _clean(value: Optional[str]) -> str:
        # Defensive: treat a missing value like an empty textbox.
        return (value or "").strip()

    # Basic validation
    if not _clean(dataset_name):
        yield _emit("[ERROR] DATASET_NAME is required.")
        return
    if not _clean(caption):
        yield _emit("[ERROR] CAPTION is required.")
        return

    # Ensure /workspace/auto holds helper files expected by the script
    _ensure_workspace_auto_files()

    # Prepare script with user parameters
    raw_controls = [
        control0, control1, control2, control3,
        control4, control5, control6, control7,
    ]
    control_folders = [c if _clean(c) else None for c in raw_controls]
    tmp_script = _prepare_script(
        dataset_name=_clean(dataset_name),
        caption=caption,
        data_root=_clean(data_root) or DEFAULT_DATA_ROOT,
        image_folder=_clean(image_folder) or DEFAULT_IMAGE_FOLDER,
        control_folders=control_folders,
        models_root=_clean(models_root) or DEFAULT_MODELS_ROOT,
        output_dir_base=(_clean(output_dir_base) or None),
        dataset_config=(_clean(dataset_config) or None),
    )

    shell = _pick_shell()
    yield _emit(f"[QIE] Using shell: {shell}")
    yield _emit(f"[QIE] Running script: {tmp_script}")

    # Run and stream output. errors="replace" keeps the stream alive if the
    # child prints bytes that are not valid in the parent's encoding.
    proc = subprocess.Popen(
        [shell, str(tmp_script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",
        bufsize=1,
    )
    exit_code = None
    try:
        if proc.stdout is not None:
            for raw in proc.stdout:
                yield _emit(raw.rstrip("\n"))
        exit_code = proc.wait()
    finally:
        # Reached with the child still alive only when the consumer closed
        # this generator early or an error occurred. Nobody will drain the
        # pipe any more, so stop the shell instead of waiting on it forever.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        if proc.stdout is not None:
            proc.stdout.close()

    # Yield outside ``finally``: yielding while a generator is being closed
    # raises RuntimeError.
    yield _emit(f"[QIE] Exit code: {exit_code}")
215
+
216
+
217
def build_ui() -> gr.Blocks:
    """Assemble the trainer form and wire its button to ``run_training``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks(title="Qwen-Image-Edit: Trainer") as demo:
        gr.Markdown("""
        # Qwen-Image-Edit Trainer
        - Downloads required models on startup.
        - Generates metadata and trains via `train_QIE.sh`.
        - Paths are POSIX-style (e.g., /workspace/...).
        """)

        # Required inputs
        with gr.Row():
            dataset_name = gr.Textbox(
                label="DATASET_NAME (folder under DATA_ROOT)",
                placeholder="my_dataset",
                lines=1,
            )
            caption = gr.Textbox(label="CAPTION", placeholder="A photo of ...", lines=2)

        # Dataset location
        with gr.Row():
            data_root = gr.Textbox(label="DATA_ROOT", value=DEFAULT_DATA_ROOT)
            image_folder = gr.Textbox(label="IMAGE_FOLDER", value=DEFAULT_IMAGE_FOLDER)

        # One textbox per CONTROL_FOLDER_0..7, created in index order
        with gr.Accordion("Control folders (optional)", open=False):
            control_boxes = [
                gr.Textbox(label=f"CONTROL_FOLDER_{i}") for i in range(8)
            ]

        # Model and output locations
        with gr.Row():
            models_root = gr.Textbox(label="Models root", value=DEFAULT_MODELS_ROOT)
            output_dir_base = gr.Textbox(label="OUTPUT_DIR_BASE", value=DEFAULT_OUTPUT_DIR_BASE)
            dataset_config = gr.Textbox(label="DATASET_CONFIG", value=DEFAULT_DATASET_CONFIG)

        run_btn = gr.Button("Start Training", variant="primary")
        logs = gr.Textbox(label="Logs", lines=20)

        # Input order must match the positional parameters of run_training.
        form_inputs = [
            dataset_name,
            caption,
            data_root,
            image_folder,
            *control_boxes,
            models_root,
            output_dir_base,
            dataset_config,
        ]
        run_btn.click(fn=run_training, inputs=form_inputs, outputs=logs)

    return demo
263
+
264
+
265
def _startup_download_models() -> None:
    """Fetch the models into DEFAULT_MODELS_ROOT before the UI starts.

    A failed download is printed rather than raised, so the caller carries
    on regardless.
    """
    target = DEFAULT_MODELS_ROOT
    print(f"[QIE] Ensuring models in: {target}")
    try:
        download_all_models(target)
    except Exception as err:
        # Startup boundary: report the problem and keep going.
        print(f"[QIE] Model download failed: {err}")
272
+
273
+
274
if __name__ == "__main__":
    # Step 1: fetch the models first; this blocks until the attempt finishes.
    _startup_download_models()

    # Step 2: build the UI and serve it on all interfaces, on $PORT when set
    # (7860 otherwise).
    app = build_ui()
    listen_port = int(os.environ.get("PORT", 7860))
    app.launch(server_name="0.0.0.0", server_port=listen_port)
281
+
create_image_caption_json.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import List, Dict
4
+
5
+
6
+ IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
7
+
8
+
9
def _join_posix(base: str, name: str) -> str:
    """Join ``base`` and ``name`` with exactly one forward slash between them.

    An empty ``base`` returns ``name`` unchanged. Otherwise trailing slashes
    of ``base`` and leading slashes of ``name`` are dropped before joining.
    """
    if base:
        return "/".join((base.rstrip("/"), name.lstrip("/")))
    return name
14
+
15
+
16
def _list_images(input_folder: str) -> List[str]:
    """Return the sorted names of the image files directly inside ``input_folder``.

    A name qualifies when its lower-cased form ends with one of
    IMAGE_EXTENSIONS and it refers to a regular file. Only bare file names
    are returned, not full paths.
    """
    return [
        name
        for name in sorted(os.listdir(input_folder))
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(input_folder, name))
    ]
24
+
25
+
26
def _validate_controls_strict(filenames: List[str], control_dirs: List[str]) -> None:
    """Require every file name to exist in every control directory.

    Returns immediately when ``control_dirs`` is empty. Otherwise all missing
    paths are collected, listed on stdout, and the process is stopped.

    Raises:
        SystemExit: with code 1 when at least one expected file is missing.
    """
    if not control_dirs:
        return

    missing: List[str] = []
    for fname in filenames:
        for idx, cdir in enumerate(control_dirs):
            expected = os.path.join(cdir, fname)
            if os.path.exists(expected):
                continue
            missing.append(f"[control_dir_{idx}] {expected}")

    if not missing:
        return

    print("エラー: 以下のファイルが見つかりませんでした(strict):")
    for item in missing:
        print(" - " + item)
    print(f"合計 {len(missing)} 件の不足が見つかりました。処理を中断します。")
    raise SystemExit(1)
47
+
48
+
49
def _build_entry(
    image_dir: str,
    control_dirs: List[str],
    caption: str,
    image_filename: str,
) -> Dict[str, str]:
    """Build one metadata record for ``image_filename``.

    The record always has ``image_path`` and ``caption``. With exactly one
    control directory it also gets ``control_path``; with several it gets
    ``control_path_0``, ``control_path_1``, ... in list order.
    """
    record: Dict[str, str] = {
        "image_path": _join_posix(image_dir, image_filename),
        "caption": caption,
    }

    control_paths = [_join_posix(cdir, image_filename) for cdir in control_dirs]
    if len(control_paths) == 1:
        record["control_path"] = control_paths[0]
    else:
        # Zero paths adds nothing; two or more get indexed keys.
        for i, path in enumerate(control_paths):
            record[f"control_path_{i}"] = path

    return record
67
+
68
+
69
def create_image_caption_json_unified(
    input_folder: str,
    image_dir: str,
    control_dirs: List[str],
    caption: str,
    output_json: str,
) -> None:
    """Write a JSONL file pairing each image in ``input_folder`` with ``caption``.

    Record shape per line:

    - no control directory: {image_path, caption}
    - one control directory: {image_path, control_path, caption}
    - several control directories: {image_path, control_path_0..N, caption}

    Missing control files are handled strictly: the check runs before
    anything is written and terminates the process when a file is absent.
    """
    filenames = _list_images(input_folder)

    # Strict: every control directory must hold a file with the same name.
    _validate_controls_strict(filenames, control_dirs)

    parent = os.path.dirname(output_json)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(output_json, 'w', encoding='utf-8') as out:
        for fname in filenames:
            record = _build_entry(image_dir, control_dirs, caption, fname)
            out.write(json.dumps(record, ensure_ascii=False) + '\n')

    # One record is written per listed file.
    count = len(filenames)
    print(f"処理が完了しました。{count}件を書き出しました。結果は {output_json} に保存されました。")
102
+
103
+
104
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='画像とコントロールの対応JSONLを生成(厳格チェック)。')
    parser.add_argument('-i', '--input-folder', required=True, help='入力ディレクトリ(画像ファイルを列挙)')
    parser.add_argument('-c', '--caption', required=True, help='キャプション')
    parser.add_argument('-o', '--output-json', default='metadata.jsonl', help='出力JSONLパス(既定: metadata.jsonl)')
    parser.add_argument('--image-dir', default='/workspace/data/image', help='image_pathの親ディレクトリパス(JSON出力用)')

    # Accept up to eight control directories: --control_dir_0 .. --control_dir_7.
    control_dests = [f'control_dir_{i}' for i in range(8)]
    for i, dest in enumerate(control_dests):
        parser.add_argument(
            f'--{dest}',
            dest=dest,
            default=None,
            help=f'control_path_{i}の親ディレクトリパス(JSON出力用)',
        )

    args = parser.parse_args()

    # Keep only the control directories that were given, in 0 -> 7 order.
    given = (getattr(args, dest) for dest in control_dests)
    control_dirs: List[str] = [value for value in given if value is not None]

    create_image_caption_json_unified(
        input_folder=args.input_folder,
        image_dir=args.image_dir,
        control_dirs=control_dirs,
        caption=args.caption,
        output_json=args.output_json,
    )
138
+
dataset_QIE.toml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dataset definition passed as --dataset_config to the caching and training
# commands in train_QIE.sh.
[general]
resolution = [1024, 1024]
caption_extension = ".txt"
batch_size = 1
enable_bucket = true
bucket_no_upscale = false

[[datasets]]
# Placeholder path: train_QIE.sh rewrites the image_jsonl_file line of the
# file its DATASET_CONFIG variable points to, setting it to the run's
# metadata.jsonl, before the caching steps read it.
image_jsonl_file = "/workspace/data/zeke_/metadata.jsonl"
cache_directory = "/cache"
qwen_image_edit_control_resolution = [1024, 1024]
download_qwen_image_models.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Download the model files used for Qwen-Image-Edit training."""
import importlib.util
import os
from typing import Dict

# Opt in to hf_transfer for faster downloads. This must happen BEFORE
# huggingface_hub is imported: the library reads HF_HUB_ENABLE_HF_TRANSFER
# when it is imported, so setting it afterwards has no effect. It is only
# enabled when the hf_transfer package is actually installed (enabling it
# without the package makes downloads fail), and setdefault() keeps an
# explicit user setting.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

from huggingface_hub import hf_hub_download  # noqa: E402  (must follow the env setup)

# Final location of the models; override with QWEN_IMAGE_MODELS_DIR.
DEFAULT_MODELS_DIR = os.environ.get("QWEN_IMAGE_MODELS_DIR",
                                    "/workspace/Qwen-Image_models")
# Temporary download root (requested: use /tmp instead of /workspace);
# override with QWEN_IMAGE_TMP_DIR.
TMP_DOWNLOAD_ROOT = os.environ.get("QWEN_IMAGE_TMP_DIR", "/tmp/qie_downloads")
13
+
14
+
15
def _ensure_dirs(root: str) -> None:
    """Create ``root`` and its ``dit``, ``vae`` and ``text_encoder`` subfolders.

    Directories that already exist are left as they are.
    """
    os.makedirs(root, exist_ok=True)
    for component in ("dit", "vae", "text_encoder"):
        os.makedirs(os.path.join(root, component), exist_ok=True)
20
+
21
+
22
def _download_then_place(*, repo_id: str, filename: str, subfolder: str,
                         component: str, models_dir: str,
                         force: bool = False) -> str:
    """Download one file under TMP_DOWNLOAD_ROOT and move it into models_dir.

    Args:
        repo_id: Hugging Face repository to download from.
        filename: File name inside ``subfolder``; also the final file name.
        subfolder: Folder inside the repository that holds the file.
        component: Sub-directory of ``models_dir`` (and of the temp root)
            the file belongs to.
        models_dir: Root directory of the final model layout.
        force: Download again even when the final file already exists.

    Returns:
        ``models_dir/component/filename`` (absolute only if ``models_dir`` is).

    Raises:
        OSError: if the downloaded file cannot be moved into place.
    """
    import shutil  # local import: this module does not import shutil at top level

    final_dir = os.path.join(models_dir, component)
    final_path = os.path.join(final_dir, filename)

    # The files are moved out of the temp directory after each download, so
    # without this check every start would fetch all of them again.
    if not force and os.path.isfile(final_path) and os.path.getsize(final_path) > 0:
        print(f"[QIE] {component}: already present, skipping download: {final_path}")
        return final_path

    # Ensure temp and final directories
    tmp_dir = os.path.join(TMP_DOWNLOAD_ROOT, component)
    os.makedirs(tmp_dir, exist_ok=True)
    os.makedirs(final_dir, exist_ok=True)

    print(f"[QIE] Downloading {component}: {repo_id}/{subfolder}/{filename}")
    tmp_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=tmp_dir,
    )

    if os.path.abspath(tmp_path) == os.path.abspath(final_path):
        return final_path

    # Stage next to the target, then rename. shutil.move copes with the temp
    # and final directories being on different filesystems and removes the
    # temp copy afterwards; os.replace within one directory swaps the file
    # in a single step, so an existing model is never deleted before its
    # replacement is complete.
    staging_path = final_path + ".part"
    try:
        shutil.move(tmp_path, staging_path)
        os.replace(staging_path, final_path)
    except OSError:
        # Do not leave a partial file behind.
        if os.path.exists(staging_path):
            os.remove(staging_path)
        raise
    return final_path
58
+
59
+
60
def download_all_models(models_dir: str = DEFAULT_MODELS_DIR) -> Dict[str, str]:
    """Fetch every model file the training script needs into ``models_dir``.

    Files are requested in this order: DiT, main VAE, alternative VAE, text
    encoder.

    Returns:
        A dict with keys ``dit``, ``vae_main``, ``vae_alt`` and
        ``text_encoder`` (local file paths) plus ``root`` (``models_dir``).
    """
    _ensure_dirs(models_dir)

    print(f"[QIE] Models directory: {models_dir}")

    print("[QIE] Download dir (tmp):", TMP_DOWNLOAD_ROOT)
    print("[QIE] Final models dir:", models_dir)

    def _fetch(component: str, repo_id: str, subfolder: str, filename: str) -> str:
        # Every file goes through the same download-then-place step.
        return _download_then_place(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            component=component,
            models_dir=models_dir,
        )

    paths: Dict[str, str] = {}

    paths["dit"] = _fetch(
        "dit",
        "Comfy-Org/Qwen-Image-Edit_ComfyUI",
        "split_files/diffusion_models",
        "qwen_image_edit_2509_bf16.safetensors",
    )

    print("[QIE] Downloading VAE model(s)...")
    paths["vae_main"] = _fetch(
        "vae",
        "Qwen/Qwen-Image-Edit",
        "vae",
        "diffusion_pytorch_model.safetensors",
    )
    paths["vae_alt"] = _fetch(
        "vae",
        "Comfy-Org/Qwen-Image_ComfyUI",
        "split_files/vae",
        "qwen_image_vae.safetensors",
    )

    print("[QIE] Downloading Text Encoder...")
    paths["text_encoder"] = _fetch(
        "text_encoder",
        "Comfy-Org/Qwen-Image_ComfyUI",
        "split_files/text_encoders",
        "qwen_2.5_vl_7b.safetensors",
    )

    print("[QIE] All models downloaded successfully!")
    paths["root"] = models_dir
    return paths
114
+
115
+
116
+ if __name__ == "__main__":
117
+ download_all_models()
train_QIE.sh ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Stop at the first failing step. Without this, a failed metadata generation
# (the generator exits with status 1 when control files are missing) or a
# failed `cd` would be ignored and training would start anyway.
# `-u` is deliberately not enabled: the CONTROL_FOLDER_* variables are
# optional and may be unset.
set -eo pipefail

# ==============================
# Generate metadata.jsonl before training
# Configure variables directly in this file.
# No environment variable overrides are used.
# ==============================

DATA_ROOT="/workspace/data"
DATASET_NAME=""

# Required inputs
CAPTION=""

IMAGE_FOLDER="image"

# CONTROL_FOLDER_0=""
# CONTROL_FOLDER_1=""
# CONTROL_FOLDER_2=""
# CONTROL_FOLDER_3=""
# CONTROL_FOLDER_4=""
# CONTROL_FOLDER_5=""
# CONTROL_FOLDER_6=""
# CONTROL_FOLDER_7=""

# Refuse to run with the required inputs left empty: an empty DATASET_NAME
# would make DATASET_DIR point at DATA_ROOT itself.
if [[ -z "$DATASET_NAME" || -z "$CAPTION" ]]; then
  echo "[QIE] ERROR: DATASET_NAME and CAPTION must be set." >&2
  exit 1
fi

RUN_NAME="${DATASET_NAME%/}"
DATASET_DIR="${DATA_ROOT%/}/${DATASET_NAME}"

OUTPUT_DIR_BASE="/workspace/auto/train_LoRA"
DATASET_CONFIG="/workspace/auto/dataset_QIE.toml"
OUTPUT_JSON="${DATASET_DIR%/}/metadata.jsonl"

# Build control args from folder names with auto-detect fallback
CONTROL_ARGS=()
for i in {0..7}; do
  var="CONTROL_FOLDER_${i}"
  folder_name=${!var:-}
  cpath=""
  if [[ -n "$folder_name" ]]; then
    cpath="${DATASET_DIR%/}/$folder_name"
  elif [[ -d "${DATASET_DIR%/}/control_${i}" ]]; then
    cpath="${DATASET_DIR%/}/control_${i}"
  elif [[ $i -eq 0 && -d "${DATASET_DIR%/}/control" ]]; then
    # Special fallback: allow single control folder named "control" for control_0
    cpath="${DATASET_DIR%/}/control"
  fi
  if [[ -n "$cpath" ]]; then
    CONTROL_ARGS+=("--control_dir_${i}" "$cpath")
  fi
done

# Sync dataset config's image_jsonl_file with OUTPUT_JSON if present
if [[ -f "$DATASET_CONFIG" ]]; then
  python - "$DATASET_CONFIG" "$OUTPUT_JSON" <<'PY'
import sys, re
path, out = sys.argv[1], sys.argv[2]
with open(path, 'r', encoding='utf-8') as fh:
    txt = fh.read()
line = f'image_jsonl_file = "{out}"'
# A callable replacement inserts the line verbatim; a replacement string
# would have backslashes in the path interpreted as escapes.
new = re.sub(r"(?m)^[ \t]*image_jsonl_file[ \t]*=.*$", lambda _m: line, txt)
if new == txt and 'image_jsonl_file' not in txt:
    new = txt.rstrip('\n') + "\n" + line + "\n"
with open(path, 'w', encoding='utf-8') as fh:
    fh.write(new)
print(f"[QIE] Updated {path}: image_jsonl_file -> {out}")
PY
else
  echo "[QIE] WARN: Dataset config not found at $DATASET_CONFIG. Ensure it points to $OUTPUT_JSON"
fi

cd /workspace/auto

echo "[QIE] Generating metadata: $OUTPUT_JSON"
python create_image_caption_json.py \
  -i "${DATASET_DIR%/}/${IMAGE_FOLDER}" \
  -c "$CAPTION" \
  -o "$OUTPUT_JSON" \
  --image-dir "${DATASET_DIR%/}/${IMAGE_FOLDER}" \
  "${CONTROL_ARGS[@]}"

cd /musubi-tuner

python qwen_image_cache_latents.py \
  --dataset_config "$DATASET_CONFIG" \
  --vae "/workspace/Qwen-Image_models/vae/diffusion_pytorch_model.safetensors" \
  --edit_plus \
  --vae_spatial_tile_sample_min_size 16384

python qwen_image_cache_text_encoder_outputs.py \
  --dataset_config "$DATASET_CONFIG" \
  --text_encoder "/workspace/Qwen-Image_models/text_encoder/qwen_2.5_vl_7b.safetensors" \
  --edit_plus \
  --batch_size 16

accelerate launch src/musubi_tuner/qwen_image_train_network.py \
  --edit_plus \
  --dit "/workspace/Qwen-Image_models/dit/qwen_image_edit_2509_bf16.safetensors" \
  --vae "/workspace/Qwen-Image_models/vae/diffusion_pytorch_model.safetensors" \
  --text_encoder "/workspace/Qwen-Image_models/text_encoder/qwen_2.5_vl_7b.safetensors" \
  --dataset_config "$DATASET_CONFIG" \
  --mixed_precision bf16 \
  --sdpa \
  --timestep_sampling shift \
  --weighting_scheme none \
  --discrete_flow_shift 2.0 \
  --optimizer_type adamw8bit \
  --learning_rate 1e-3 \
  --gradient_checkpointing \
  --max_data_loader_n_workers 2 \
  --persistent_data_loader_workers \
  --network_module networks.lora_qwen_image \
  --network_dim 4 \
  --max_train_epochs 100 \
  --save_every_n_epochs 10 \
  --seed 42 \
  --output_dir "${OUTPUT_DIR_BASE}/${RUN_NAME}" \
  --output_name "${RUN_NAME}" \
  --ddp_gradient_as_bucket_view \
  --ddp_static_graph