yiren98 committed on
Commit
36ed92b
·
verified ·
1 Parent(s): f406903

Upload 17 files

.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ id_rsa
37
+ __pycache__
38
+ *.egg-info
39
+ .vscode
40
+ wandb
41
+ Merge
42
+ asy_results
43
+ recraft_results
44
+ drop
45
+ SplitAsy
46
+ example*
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Show Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,239 @@
1
- ---
2
- title: MakeAnything
3
- emoji: 🖼
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.0.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Multi-Domain Procedural Sequence Generation
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # MakeAnything
2
+
3
+ > **MakeAnything: Harnessing Diffusion Transformers for Multi-Domain Procedural Sequence Generation**
4
+ > <br>
5
+ > [Yiren Song](https://scholar.google.com.hk/citations?user=L2YS0jgAAAAJ),
6
+ > [Cheng Liu](https://scholar.google.com.hk/citations?hl=zh-CN&user=TvdVuAYAAAAJ),
7
+ > and
8
+ > [Mike Zheng Shou](https://sites.google.com/view/showlab)
9
+ > <br>
10
+ > [Show Lab](https://sites.google.com/view/showlab), National University of Singapore
11
+ > <br>
12
+
13
+ <a href="https://arxiv.org/abs/2502.01572"><img src="https://img.shields.io/badge/ariXv-2411.15098-A42C25.svg" alt="arXiv"></a>
14
+ <a href="https://huggingface.co/showlab/makeanything"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
15
+ <a href="https://huggingface.co/datasets/showlab/makeanything/"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
16
+
17
+ <br>
18
+
19
+ <img src='./images/teaser.png' width='100%' />
20
+
21
+
22
+ ## Configuration
23
+ ### 1. **Environment setup**
24
+ ```bash
25
+ git clone https://github.com/showlab/MakeAnything.git
26
+ cd MakeAnything
27
+
28
+ conda create -n makeanything python=3.11.10
29
+ conda activate makeanything
30
+ ```
31
+ ### 2. **Requirements installation**
32
+ ```bash
33
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
34
+ pip install --upgrade -r requirements.txt
35
+
36
+ accelerate config
37
+ ```
38
+
39
+ ## Asymmetric LoRA
40
+ ### 1. Weights
41
+ You can download the trained checkpoints of the Asymmetric LoRA and standard LoRA for inference. Below are the details of the available models:
42
+
43
+ | **Model** | **Description** | **Resolution** |
44
+ |:-:|:-:|:-:|
45
+ | [asylora_9f_general](https://huggingface.co/showlab/makeanything/blob/main/asymmetric_lora/asymmetric_lora_9f_general.safetensors) | The Asymmetric LoRA has been fine-tuned on all 9-frame datasets. *Index of lora_up*: `1:LEGO` `2:Cook` `3:Painting` `4:Icon` `5:Landscape illustration` `6:Portrait` `7:Transformer` `8:Sand art` `9:Illustration` `10:Sketch` | 1056,1056 |
46
+ | [asylora_4f_general](https://huggingface.co/showlab/makeanything/blob/main/asymmetric_lora/asymmetric_lora_4f_general.safetensors) | The Asymmetric LoRA has been fine-tuned on all 4-frame datasets. *Index of lora_up: (1~10 same as 9f)* `11:Clay toys` `12:Clay sculpture` `13:Zbrush Modeling` `14:Wood sculpture` `15:Ink painting` `16:Pencil sketch` `17:Fabric toys` `18:Oil painting` `19:Jade Carving` `20:Line draw` `21:Emoji` | 1024,1024 |
47
+
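+ If it helps, the checkpoints in the table can also be fetched programmatically with `huggingface_hub`. The snippet below is a minimal sketch (not one of this repository's scripts) and assumes the repo id and filename shown above:
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Download the 9-frame general Asymmetric LoRA checkpoint listed in the table above.
+ ckpt_path = hf_hub_download(
+     repo_id="showlab/makeanything",
+     filename="asymmetric_lora/asymmetric_lora_9f_general.safetensors",
+ )
+ print(ckpt_path)  # local path of the downloaded .safetensors file
+ ```
+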
48
+ ### 2. Training
49
+ <span id="dataset_setting"></span>
50
+ #### 2.1 Settings for dataset
51
+ The training process relies on a paired dataset consisting of text captions and images. Each dataset folder contains both `.caption` and `.png` files, where the filenames of the caption files correspond directly to the image filenames. Here is an example of an organized dataset.
52
+
53
+ ```
54
+ dataset/
55
+ ├── portrait_001.png
56
+ ├── portrait_001.caption
57
+ ├── portrait_002.png
58
+ ├── portrait_002.caption
59
+ ├── lego_001.png
60
+ ├── lego_001.caption
61
+ ```
62
+
63
+ The `.caption` files contain a **single line** of text that serves as the prompt for generating the corresponding image. The prompt **must specify the index of the lora_up** used for that particular training sample in the Asymmetric LoRA. The format is `--lora_up <index>`, where `<index>` is the index of the B matrix in the Asymmetric LoRA that corresponds to the domain of the training sample; indices **start from 1**, not 0.
64
+
65
+ For example, a .caption file for a portrait painting sequence might look as follows:
66
+
67
+ ```caption
68
+ 3*3 of 9 sub-images, step-by-step portrait painting process, 1 girl --lora_up 6
69
+ ```
70
+
71
+ Then, you should organize your **dataset configuration file** written in `TOML`. Here is an example:
72
+
73
+ ```toml
74
+ [general]
75
+ enable_bucket = false
76
+
77
+ [[datasets]]
78
+ resolution = 1056
79
+ batch_size = 1
80
+
81
+ [[datasets.subsets]]
82
+ image_dir = '/path/to/dataset/'
83
+ caption_extension = '.caption'
84
+ num_repeats = 1
85
+ ```
86
+
87
+ It is recommended to set the batch size to 1 and the resolution to 1024 (4-frame) or 1056 (9-frame).
88
+
89
+ #### 2.2 Start training
90
+ We have provided a template file for training the Asymmetric LoRA in `scripts/asylora_train.sh`. Simply replace the corresponding paths with your own to start training. Note that `lora_ups_num` in the script is the total number of B matrices in the Asymmetric LoRA, which you specify for training.
91
+
92
+ ```bash
93
+ chmod +x scripts/asylora_train.sh
94
+ scripts/asylora_train.sh
95
+ ```
96
+
97
+ Additionally, if you are directly **using our dataset for training**, note that the `.caption` files in our released dataset do not specify the `--lora_up <index>` field. You will need to update the `.caption` files to include the appropriate `--lora_up <index>` values before starting the training (a small helper sketch is shown below).
98
+
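+ The following is a minimal sketch of such an update step (not an official project script); the dataset path and index value are placeholders you should replace with your own:
+
+ ```python
+ from pathlib import Path
+
+ DATASET_DIR = Path("/path/to/dataset")  # placeholder: your dataset folder
+ LORA_UP_INDEX = 6                       # placeholder: domain index (starts from 1)
+
+ for caption_file in DATASET_DIR.glob("*.caption"):
+     text = caption_file.read_text(encoding="utf-8").strip()
+     if "--lora_up" not in text:  # skip captions that are already tagged
+         caption_file.write_text(f"{text} --lora_up {LORA_UP_INDEX}\n", encoding="utf-8")
+ ```
+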
99
+ ### 3. Inference
100
+ We have also provided a template file for running inference with the Asymmetric LoRA in `scripts/asylora_inference.sh`. Once training is done, replace the file paths, fill in your prompt, and run inference. Note that `lora_up_cur` in the script is the index of the B matrix to be used for inference.
101
+
102
+ ```bash
103
+ chmod +x scripts/asylora_inference.sh
104
+ scripts/asylora_inference.sh
105
+ ```
106
+
107
+
108
+ ## Recraft Model
109
+ ### 1. Weights
110
+ You can download the trained checkpoints of the Recraft Model for inference. Below are the details of the available models:
111
+ | **Model** | **Description** | **Resolution** |
112
+ |:-:|:-:|:-:|
113
+ | [recraft_9f_lego](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_9f_lego.safetensors) | The Recraft Model has been trained on the `LEGO` dataset. Supports `9-frame` generation. | 1056,1056 |
114
+ | [recraft_9f_portrait](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_9f_portrait.safetensors) | The Recraft Model has been trained on the `Portrait` dataset. Supports `9-frame` generation. | 1056,1056 |
115
+ | [recraft_9f_sketch](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_9f_sketch.safetensors) | The Recraft Model has been trained on the `Sketch` dataset. Supports `9-frame` generation. | 1056,1056 |
116
+ | [recraft_4f_wood_sculpture](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_4f_wood_sculpture.safetensors) | The Recraft Model has been trained on the `Wood sculpture` dataset. Supports `4-frame` generation. | 1024,1024 |
117
+
118
+ ### 2. Training
119
+ #### 2.1 Obtain standard LoRA
120
+ In the second training phase, which learns image-to-sequence generation with the Recraft model, a **standard LoRA** must be merged into flux.1 before the Recraft training is performed. Therefore, the first step is to decompose the Asymmetric LoRA into the standard LoRA format.
121
+
122
+ To achieve this, either **train a standard LoRA directly** (see the optional method below) or use the script template we provide in `scripts/asylora_split.sh` for **splitting the Asymmetric LoRA**. The script extracts the required B matrices from the Asymmetric LoRA model; specifically, `LORA_UP` in the script specifies the index of the B matrix you wish to extract and use as the standard LoRA.
123
+
124
+ ```bash
125
+ chmod +x scripts/asylora_split.sh
126
+ scripts/asylora_split.sh
127
+ ```
128
+
129
+ #### (Optional) Train standard LoRA
130
+ You can also **directly train a standard LoRA** for the Recraft process, eliminating the need to decompose the Asymmetric LoRA. In our project, we have included the standard LoRA training code from [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts) in the files `flux_train_network.py` (training) and `flux_minimal_inference.py` (inference). You can refer to the related documentation for guidance on how to train.
131
+
132
+ Alternatively, other training platforms such as [kijai/ComfyUI-FluxTrainer](https://github.com/kijai/ComfyUI-FluxTrainer) are also a viable option. These platforms provide tools that facilitate the training and inference of LoRA models for the Recraft process.
133
+
134
+ #### 2.2 Merge LoRA to flux.1
135
+ Once you have obtained a standard LoRA, use our `scripts/lora_merge.sh` template script to merge the LoRA into the flux.1 checkpoint for the subsequent Recraft training. Note that the merged model may take up **around 50GB** of storage space.
136
+
137
+ ```bash
138
+ chmod +x scripts/lora_merge.sh
139
+ scripts/lora_merge.sh
140
+ ```
141
+ #### 2.3 Settings for training
142
+
143
+ The dataset structure for Recraft training follows the same organization format as the dataset for Asymmetric LoRA, specifically described in [Asymmetric LoRA 2.1 Settings for dataset](#dataset_setting). A `TOML` configuration file is also required to organize and configure the dataset. Below is a template for the dataset configuration file:
144
+
145
+ ```toml
146
+ [general]
147
+ flip_aug = false
148
+ color_aug = false
149
+ keep_tokens_separator = "|||"
150
+ shuffle_caption = false
151
+ caption_tag_dropout_rate = 0
152
+ caption_extension = ".caption"
153
+
154
+ [[datasets]]
155
+ batch_size = 1
156
+ enable_bucket = true
157
+ resolution = [1024, 1024]
158
+
159
+ [[datasets.subsets]]
160
+ image_dir = "/path/to/dataset/"
161
+ num_repeats = 1
162
+ ```
163
+
164
+ Note that for training with 4-frame step sequences, the resolution must be set to `1024`. For training with 9-frame steps, the resolution should be `1056`.
165
+
166
+ For the sampling phase of the Recraft training process, we need to organize two text files: `sample_images.txt` and `sample_prompts.txt`. These files will store the sampled condition images and their corresponding prompts, respectively. Below are the templates for both files:
167
+
168
+ **sample_images.txt**
169
+ ```txt
170
+ /path/to/image_1.png
171
+ /path/to/image_2.png
172
+ ```
173
+
174
+ **sample_prompts.txt**
175
+ ```txt
176
+ image_1_prompt_content
177
+ image_2_prompt_content
178
+ ```
179
+ #### 2.4 Recraft training
180
+ We have provided a template file for training the Recraft Model in `scripts/recraft_train.sh`. Simply replace the corresponding paths with your own to start training. Note that `frame_num` in the script must be `4` (for 1024 resolution) or `9` (for 1056 resolution).
181
+
182
+ ```bash
183
+ chmod +x scripts/recraft_train.sh
184
+ scripts/recraft_train.sh
185
+ ```
186
+
187
+ ### 3. Inference
188
+ We have also provided a template file for running inference with the Recraft Model in `scripts/recraft_inference.sh`. Once training is done, replace the file paths, fill in your prompt, and run inference.
189
+
190
+ ```bash
191
+ chmod +x scripts/recraft_inference.sh
192
+ scripts/recraft_inference.sh
193
+ ```
194
+
195
+ ## Datasets
196
+
197
+ We have uploaded our datasets to [Hugging Face](https://huggingface.co/datasets/showlab/makeanything/). The dataset includes both 4-frame and 9-frame sequence images, covering a total of 21 domains of procedural sequences. For MakeAnything training, each domain consists of **50 sequences**, with resolutions of either **1024 (4-frame)** or **1056 (9-frame)**. Additionally, we provide an extensive collection of SVG and Sketch datasets for further research and experimentation.
198
+
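+ If you prefer to pull the data programmatically, a minimal sketch using `huggingface_hub` (an illustrative suggestion, not an official project script) looks like this:
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Download the full MakeAnything dataset repository to a local folder.
+ local_dir = snapshot_download(repo_id="showlab/makeanything", repo_type="dataset")
+ print(local_dir)  # path of the downloaded dataset
+ ```
+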
199
+ Note that the arrangement of **9-frame sequences follows an S-shape pattern**, whereas **4-frame sequences follow a ɔ-shape pattern** (a small slicing sketch for the 9-frame case is shown below).
200
+
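+ The sketch below splits a 3x3 sequence image into its 9 step frames. It is illustrative only and assumes that "S-shape" means boustrophedon order (row 0 left-to-right, row 1 right-to-left, row 2 left-to-right); adjust the ordering if your data differs:
+
+ ```python
+ from PIL import Image
+
+ def split_nine_frame_sequence(path: str, s_shape: bool = True):
+     """Split a 3x3 procedural-sequence image into 9 step frames (assumed S-shape reading order)."""
+     img = Image.open(path)
+     w, h = img.size
+     tile_w, tile_h = w // 3, h // 3
+     frames = []
+     for row in range(3):
+         cols = range(3) if (not s_shape or row % 2 == 0) else reversed(range(3))
+         for col in cols:
+             box = (col * tile_w, row * tile_h, (col + 1) * tile_w, (row + 1) * tile_h)
+             frames.append(img.crop(box))
+     return frames  # frames[0] is the first step, frames[8] is the finished piece
+ ```
+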
201
+ <details>
202
+ <summary>Click to preview the datasets</summary>
203
+ <br>
204
+
205
+ | Domain | Preview | Quantity | Domain | Preview | Quantity |
206
+ |:--------:|:---------:|:----------:|:--------:|:---------:|:----------:|
207
+ | LEGO | ![LEGO Preview](./images/datasets/lego.png) | 50 | Cook | ![Cook Preview](./images/datasets/cook.png) | 50 |
208
+ | Painting | ![Painting Preview](./images/datasets/painting.png) | 50 | Icon | ![Icon Preview](./images/datasets/icon.png) | 50+1.4k |
209
+ | Landscape Illustration | ![Landscape Illustration Preview](./images/datasets/landscape.png) | 50 | Portrait | ![Portrait Preview](./images/datasets/portrait.png) | 50+2k |
210
+ | Transformer | ![Transformer Preview](./images/datasets/transformer.png) | 50 | Sand Art | ![Sand Art Preview](./images/datasets/sandart.png) | 50 |
211
+ | Illustration | ![Illustration Preview](./images/datasets/illustration.png) | 50 | Sketch | ![Sketch Preview](./images/datasets/sketch.png) | 50+9k |
212
+ | Clay Toys | ![Clay Toys Preview](./images/datasets/claytoys.png) | 50 | Clay Sculpture | ![Clay Sculpture Preview](./images/datasets/claysculpture.png) | 50 |
213
+ | ZBrush Modeling | ![ZBrush Modeling Preview](./images/datasets/zbrush.png) | 50 | Wood Sculpture | ![Wood Sculpture Preview](./images/datasets/woodsculpture.png) | 50 |
214
+ | Ink Painting | ![Ink Painting Preview](./images/datasets/inkpainting.png) | 50 | Pencil Sketch | ![Pencil Sketch Preview](./images/datasets/pencilsketch.png) | 50 |
215
+ | Fabric Toys | ![Fabric Toys Preview](./images/datasets/fabrictoys.png) | 50 | Oil Painting | ![Oil Painting Preview](./images/datasets/oilpainting.png) | 50 |
216
+ | Jade Carving | ![Jade Carving Preview](./images/datasets/jadecarving.png) | 50 | Line Draw | ![Line Draw Preview](./images/datasets/linedraw.png) | 50 |
217
+ | Emoji | ![Emoji Preview](./images/datasets/emoji.png) | 50+12k | | | |
218
+
219
+ </details>
220
+
221
+ ## Results
222
+ ### Text-to-Sequence Generation (LoRA & Asymmetric LoRA)
223
+ <img src='./images/t2i.png' width='100%' />
224
+
225
+ ### Image-to-Sequence Generation (Recraft Model)
226
+ <img src='./images/i2i.png' width='100%' />
227
+
228
+ ### Generalization on Unseen Domains
229
+ <img src='./images/oneshot.png' width='100%' />
230
+
231
+ ## Citation
232
+ ```
233
+ @inproceedings{Song2025MakeAnythingHD,
234
+ title={MakeAnything: Harnessing Diffusion Transformers for Multi-Domain Procedural Sequence Generation},
235
+ author={Yiren Song and Cheng Liu and Mike Zheng Shou},
236
+ year={2025},
237
+ url={https://api.semanticscholar.org/CorpusID:276107845}
238
+ }
239
+ ```
flux_inference_recraft.py ADDED
@@ -0,0 +1,442 @@
1
+ import argparse
2
+ import copy
3
+ import math
4
+ import random
5
+ from typing import Any
6
+ import pdb
7
+ import os
8
+
9
+ import time
10
+ from PIL import Image, ImageOps
11
+
12
+ import torch
13
+ from accelerate import Accelerator
14
+ from library.device_utils import clean_memory_on_device
15
+ from safetensors.torch import load_file
16
+ from networks import lora_flux
17
+
18
+ from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, \
19
+ strategy_base, strategy_flux, train_util
20
+ from torchvision import transforms
21
+ import train_network
22
+ from library.utils import setup_logging
23
+ from diffusers.utils import load_image
24
+ import numpy as np
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ def load_target_model(
33
+ fp8_base: bool,
34
+ pretrained_model_name_or_path: str,
35
+ disable_mmap_load_safetensors: bool,
36
+ clip_l_path: str,
37
+ fp8_base_unet: bool,
38
+ t5xxl_path: str,
39
+ ae_path: str,
40
+ weight_dtype: torch.dtype,
41
+ accelerator: Accelerator
42
+ ):
43
+ # Determine the loading data type
44
+ loading_dtype = None if fp8_base else weight_dtype
45
+
46
+ # Load the main model to the accelerator's device
47
+ _, model = flux_utils.load_flow_model(
48
+ pretrained_model_name_or_path,
49
+ # loading_dtype,
50
+ torch.float8_e4m3fn,
51
+ # accelerator.device, # Changed from "cpu" to accelerator.device
52
+ "cpu",
53
+ disable_mmap=disable_mmap_load_safetensors
54
+ )
55
+
56
+ if fp8_base:
57
+ # Check dtype of the model
58
+ if model.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
59
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
60
+ elif model.dtype == torch.float8_e4m3fn:
61
+ logger.info("Loaded fp8 FLUX model")
62
+
63
+ # Load the CLIP model to the accelerator's device
64
+ clip_l = flux_utils.load_clip_l(
65
+ clip_l_path,
66
+ weight_dtype,
67
+ # accelerator.device, # Changed from "cpu" to accelerator.device
68
+ "cpu",
69
+ disable_mmap=disable_mmap_load_safetensors
70
+ )
71
+ clip_l.eval()
72
+
73
+ # Determine the loading data type for T5XXL
74
+ if fp8_base and not fp8_base_unet:
75
+ loading_dtype_t5xxl = None # as is
76
+ else:
77
+ loading_dtype_t5xxl = weight_dtype
78
+
79
+ # Load the T5XXL model to the accelerator's device
80
+ t5xxl = flux_utils.load_t5xxl(
81
+ t5xxl_path,
82
+ loading_dtype_t5xxl,
83
+ # accelerator.device, # Changed from "cpu" to accelerator.device
84
+ "cpu",
85
+ disable_mmap=disable_mmap_load_safetensors
86
+ )
87
+ t5xxl.eval()
88
+
89
+ if fp8_base and not fp8_base_unet:
90
+ # Check dtype of the T5XXL model
91
+ if t5xxl.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
92
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
93
+ elif t5xxl.dtype == torch.float8_e4m3fn:
94
+ logger.info("Loaded fp8 T5XXL model")
95
+
96
+ # Load the AE model to the accelerator's device
97
+ ae = flux_utils.load_ae(
98
+ ae_path,
99
+ weight_dtype,
100
+ # accelerator.device, # Changed from "cpu" to accelerator.device
101
+ "cpu",
102
+ disable_mmap=disable_mmap_load_safetensors
103
+ )
104
+
105
+ # # Wrap models with Accelerator for potential distributed setups
106
+ # model, clip_l, t5xxl, ae = accelerator.prepare(model, clip_l, t5xxl, ae)
107
+
108
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
109
+
110
+
111
+ import torchvision.transforms as transforms
112
+
113
+
114
+ class ResizeWithPadding:
115
+ def __init__(self, size, fill=255):
116
+ self.size = size
117
+ self.fill = fill
118
+
119
+ def __call__(self, img):
120
+ if isinstance(img, np.ndarray):
121
+ img = Image.fromarray(img)
122
+ elif not isinstance(img, Image.Image):
123
+ raise TypeError("Input must be a PIL Image or a NumPy array")
124
+
125
+ width, height = img.size
126
+
127
+ if width == height:
128
+ img = img.resize((self.size, self.size), Image.LANCZOS)
129
+ else:
130
+ max_dim = max(width, height)
131
+
132
+ new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
133
+ new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
134
+
135
+ img = new_img.resize((self.size, self.size), Image.LANCZOS)
136
+
137
+ return img
138
+
139
+
140
+ def sample(args, accelerator, vae, text_encoder, flux, output_dir, sample_images, sample_prompts):
141
+ def encode_images_to_latents(vae, images):
142
+ # Get image dimensions
143
+ b, c, h, w = images.shape
144
+ num_split = 2 if args.frame_num == 4 else 3
145
+ # Split the image into three parts
146
+ img_parts = [images[:, :, :, i * w // num_split:(i + 1) * w // num_split] for i in range(num_split)]
147
+ # Encode each part
148
+ latents = [vae.encode(img) for img in img_parts]
149
+ # Concatenate latents in the latent space to reconstruct the full image
150
+ latents = torch.cat(latents, dim=-1)
151
+ return latents
152
+
153
+ def encode_images_to_latents2(vae, images):
154
+ latents = vae.encode(images)
155
+ return latents
156
+
157
+ # Directly use precomputed conditions
158
+ conditions = {}
159
+ with torch.no_grad():
160
+ for image_path, prompt_dict in zip(sample_images, sample_prompts):
161
+ prompt = prompt_dict.get("prompt", "")
162
+ if prompt not in conditions:
163
+ logger.info(f"Cache conditions for image: {image_path} with prompt: {prompt}")
164
+ resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
165
+ img_transforms = transforms.Compose([
166
+ resize_transform,
167
+ transforms.ToTensor(),
168
+ transforms.Normalize([0.5], [0.5]),
169
+ ])
170
+ # Load and preprocess image
171
+ image = img_transforms(np.array(load_image(image_path), dtype=np.uint8)).unsqueeze(0).to(
172
+ # accelerator.device, # Move image to CUDA
173
+ vae.device,
174
+ dtype=vae.dtype
175
+ )
176
+ latents = encode_images_to_latents2(vae, image)
177
+
178
+ # Log the shape of latents
179
+ logger.debug(f"Encoded latents shape for prompt '{prompt}': {latents.shape}")
180
+ # Store conditions on CUDA
181
+ # conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
182
+ conditions[prompt] = latents.to("cpu")
183
+
184
+ sample_conditions = conditions
185
+
186
+ if sample_conditions is not None:
187
+ conditions = {k: v for k, v in sample_conditions.items()} # Already on CUDA
188
+
189
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
190
+ text_encoder[0].to(accelerator.device)
191
+ text_encoder[1].to(accelerator.device)
192
+
193
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
194
+ text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
195
+
196
+ with accelerator.autocast(), torch.no_grad():
197
+ for prompt_dict in sample_prompts:
198
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
199
+ if p not in sample_prompts_te_outputs:
200
+ logger.info(f"Cache Text Encoder outputs for prompt: {p}")
201
+ tokens_and_masks = tokenize_strategy.tokenize(p)
202
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
203
+ tokenize_strategy, text_encoder, tokens_and_masks, True
204
+ )
205
+
206
+ logger.info(f"Generating image")
207
+ save_dir = output_dir
208
+ os.makedirs(save_dir, exist_ok=True)
209
+
210
+ with torch.no_grad(), accelerator.autocast():
211
+ for prompt_dict in sample_prompts:
212
+ sample_image_inference(
213
+ args,
214
+ accelerator,
215
+ flux,
216
+ text_encoder,
217
+ vae,
218
+ save_dir,
219
+ prompt_dict,
220
+ sample_prompts_te_outputs,
221
+ None,
222
+ conditions
223
+ )
224
+
225
+ clean_memory_on_device(accelerator.device)
226
+
227
+
228
+ def sample_image_inference(
229
+ args,
230
+ accelerator: Accelerator,
231
+ flux: flux_models.Flux,
232
+ text_encoder,
233
+ ae: flux_models.AutoEncoder,
234
+ save_dir,
235
+ prompt_dict,
236
+ sample_prompts_te_outputs,
237
+ prompt_replacement,
238
+ sample_images_ae_outputs
239
+ ):
240
+ # Extract parameters from prompt_dict
241
+ sample_steps = prompt_dict.get("sample_steps", 20)
242
+ width = prompt_dict.get("width", 1024) if args.frame_num == 4 else prompt_dict.get("width", 1056)
243
+ height = prompt_dict.get("height", 1024) if args.frame_num == 4 else prompt_dict.get("height", 1056)
244
+ scale = prompt_dict.get("scale", 1.0)
245
+ seed = prompt_dict.get("seed")
246
+ prompt: str = prompt_dict.get("prompt", "")
247
+
248
+ if prompt_replacement is not None:
249
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
250
+
251
+ if seed is not None:
252
+ torch.manual_seed(seed)
253
+ torch.cuda.manual_seed(seed)
254
+ else:
255
+ # True random sample image generation
256
+ torch.seed()
257
+ torch.cuda.seed()
258
+
259
+ # Ensure height and width are divisible by 16
260
+ height = max(64, height - height % 16)
261
+ width = max(64, width - width % 16)
262
+ logger.info(f"prompt: {prompt}")
263
+ logger.info(f"height: {height}")
264
+ logger.info(f"width: {width}")
265
+ logger.info(f"sample_steps: {sample_steps}")
266
+ logger.info(f"scale: {scale}")
267
+ if seed is not None:
268
+ logger.info(f"seed: {seed}")
269
+
270
+ # Encode prompts
271
+ # Assuming that TokenizeStrategy and TextEncodingStrategy are compatible with Accelerator
272
+ text_encoder_conds = []
273
+ if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
274
+ text_encoder_conds = sample_prompts_te_outputs[prompt]
275
+ logger.info(f"Using cached text encoder outputs for prompt: {prompt}")
276
+
277
+ if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
278
+ ae_outputs = sample_images_ae_outputs[prompt]
279
+ else:
280
+ ae_outputs = None
281
+
282
+ # ae_outputs = torch.load('ae_outputs.pth', map_location='cuda:0')
283
+
284
+ # text_encoder_conds = torch.load('text_encoder_conds.pth', map_location='cuda:0')
285
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
286
+
287
+ # Log debug info
288
+ logger.debug(
289
+ f"l_pooled shape: {l_pooled.shape}, t5_out shape: {t5_out.shape}, txt_ids shape: {txt_ids.shape}, t5_attn_mask shape: {t5_attn_mask.shape}")
290
+
291
+ # Sample the image
292
+ weight_dtype = ae.dtype # TODO: give dtype as argument
293
+ packed_latent_height = height // 16
294
+ packed_latent_width = width // 16
295
+
296
+ # Log debug info
297
+ logger.debug(f"packed_latent_height: {packed_latent_height}, packed_latent_width: {packed_latent_width}")
298
+
299
+ # Prepare the noise tensor on the CUDA device
300
+ noise = torch.randn(
301
+ 1,
302
+ packed_latent_height * packed_latent_width,
303
+ 16 * 2 * 2,
304
+ device=accelerator.device,
305
+ dtype=weight_dtype,
306
+ generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
307
+ )
308
+
309
+ timesteps = flux_train_utils.get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
310
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(
311
+ accelerator.device, dtype=weight_dtype
312
+ )
313
+ t5_attn_mask = t5_attn_mask.to(accelerator.device)
314
+
315
+ clip_l, t5xxl = text_encoder
316
+ # ae.to("cpu")
317
+ clip_l.to("cpu")
318
+ t5xxl.to("cpu")
319
+
320
+ clean_memory_on_device(accelerator.device)
321
+ flux.to("cuda")
322
+
323
+ for param in flux.parameters():
324
+ param.requires_grad = False
325
+
326
+ # Run denoising
327
+ with accelerator.autocast(), torch.no_grad():
328
+ x = flux_train_utils.denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps,
329
+ guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
330
+
331
+ # Log the shape of x
332
+ logger.debug(f"x shape after denoise: {x.shape}")
333
+
334
+ x = x.float()
335
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
336
+
337
+ # Convert the latents back to an image
338
+ # clean_memory_on_device(accelerator.device)
339
+ ae.to(accelerator.device)
340
+ with accelerator.autocast(), torch.no_grad():
341
+ x = ae.decode(x)
342
+ ae.to("cpu")
343
+ clean_memory_on_device(accelerator.device)
344
+
345
+ x = x.clamp(-1, 1)
346
+ x = x.permute(0, 2, 3, 1)
347
+ image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
348
+
349
+ # Generate a unique filename
350
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
351
+ seed_suffix = "" if seed is None else f"_{seed}"
352
+ i: int = prompt_dict.get("enum", 0) # Ensure 'enum' exists
353
+ img_filename = f"{ts_str}{seed_suffix}_{i}.png" # Added 'i' to filename for uniqueness
354
+ image.save(os.path.join(save_dir, img_filename))
355
+
356
+
357
+ def setup_argparse():
358
+ parser = argparse.ArgumentParser(description="MakeAnything Recraft inference script")
359
+
360
+ # Paths
361
+ parser.add_argument('--base_flux_checkpoint', type=str, required=True,
362
+ help='Path to BASE_FLUX_CHECKPOINT')
363
+ parser.add_argument('--lora_weights_path', type=str, required=True,
364
+ help='Path to LORA_WEIGHTS_PATH')
365
+ parser.add_argument('--clip_l_path', type=str, required=True,
366
+ help='Path to CLIP_L_PATH')
367
+ parser.add_argument('--t5xxl_path', type=str, required=True,
368
+ help='Path to T5XXL_PATH')
369
+ parser.add_argument('--ae_path', type=str, required=True,
370
+ help='Path to AE_PATH')
371
+ parser.add_argument('--sample_images_file', type=str, required=True,
372
+ help='Path to SAMPLE_IMAGES_FILE')
373
+ parser.add_argument('--sample_prompts_file', type=str, required=True,
374
+ help='Path to SAMPLE_PROMPTS_FILE')
375
+ parser.add_argument('--output_dir', type=str, required=True,
376
+ help='Directory to save OUTPUT_DIR')
377
+ parser.add_argument('--frame_num', type=int, choices=[4, 9], required=True,
378
+ help="The number of steps in the generated step diagram (choose 4 or 9)")
379
+
380
+ return parser.parse_args()
381
+
382
+
383
+ def main(args):
384
+ accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
385
+
386
+ BASE_FLUX_CHECKPOINT = args.base_flux_checkpoint
387
+ LORA_WEIGHTS_PATH = args.lora_weights_path
388
+ CLIP_L_PATH = args.clip_l_path
389
+ T5XXL_PATH = args.t5xxl_path
390
+ AE_PATH = args.ae_path
391
+
392
+ SAMPLE_IMAGES_FILE = args.sample_images_file
393
+ SAMPLE_PROMPTS_FILE = args.sample_prompts_file
394
+ OUTPUT_DIR = args.output_dir
395
+
396
+ with open(SAMPLE_IMAGES_FILE, "r", encoding="utf-8") as f:
397
+ image_lines = f.readlines()
398
+ sample_images = [line.strip() for line in image_lines if line.strip() and not line.strip().startswith("#")]
399
+
400
+ sample_prompts = train_util.load_prompts(SAMPLE_PROMPTS_FILE)
401
+
402
+ # Load models onto CUDA via Accelerator
403
+ _, [clip_l, t5xxl], ae, model = load_target_model(
404
+ fp8_base=True,
405
+ pretrained_model_name_or_path=BASE_FLUX_CHECKPOINT,
406
+ disable_mmap_load_safetensors=False,
407
+ clip_l_path=CLIP_L_PATH,
408
+ fp8_base_unet=False,
409
+ t5xxl_path=T5XXL_PATH,
410
+ ae_path=AE_PATH,
411
+ weight_dtype=torch.bfloat16,
412
+ accelerator=accelerator
413
+ )
414
+
415
+ model.eval()
416
+ clip_l.eval()
417
+ t5xxl.eval()
418
+ ae.eval()
419
+
420
+ # LoRA
421
+ multiplier = 1.0
422
+ weights_sd = load_file(LORA_WEIGHTS_PATH)
423
+ lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
424
+ True)
425
+
426
+ lora_model.apply_to([clip_l, t5xxl], model)
427
+ info = lora_model.load_state_dict(weights_sd, strict=True)
428
+ logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
429
+ lora_model.eval()
430
+ lora_model.to("cuda")
431
+
432
+ # Set text encoders
433
+ text_encoder = [clip_l, t5xxl]
434
+
435
+ sample(args, accelerator, vae=ae, text_encoder=text_encoder, flux=model, output_dir=OUTPUT_DIR,
436
+ sample_images=sample_images, sample_prompts=sample_prompts)
437
+
438
+
439
+ if __name__ == "__main__":
440
+ args = setup_argparse()
441
+
442
+ main(args)
flux_minimal_inference.py ADDED
@@ -0,0 +1,576 @@
1
+ # Minimum Inference Code for FLUX
2
+
3
+ import argparse
4
+ import datetime
5
+ import math
6
+ import os
7
+ import random
8
+ from typing import Callable, List, Optional
9
+ import einops
10
+ import numpy as np
11
+
12
+ import torch
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+ import accelerate
16
+ from transformers import CLIPTextModel
17
+ from safetensors.torch import load_file
18
+
19
+ from library import device_utils
20
+ from library.device_utils import init_ipex, get_preferred_device
21
+ from networks import oft_flux
22
+
23
+ init_ipex()
24
+
25
+
26
+ from library.utils import setup_logging, str_to_dtype
27
+
28
+ setup_logging()
29
+ import logging
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ import networks.lora_flux as lora_flux
34
+ from library import flux_models, flux_utils, sd3_utils, strategy_flux
35
+
36
+
37
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
38
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
39
+
40
+
41
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
42
+ m = (y2 - y1) / (x2 - x1)
43
+ b = y1 - m * x1
44
+ return lambda x: m * x + b
45
+
46
+
47
+ def get_schedule(
48
+ num_steps: int,
49
+ image_seq_len: int,
50
+ base_shift: float = 0.5,
51
+ max_shift: float = 1.15,
52
+ shift: bool = True,
53
+ ) -> list[float]:
54
+ # extra step for zero
55
+ timesteps = torch.linspace(1, 0, num_steps + 1)
56
+
57
+ # shifting the schedule to favor high timesteps for higher signal images
58
+ if shift:
59
+ # eastimate mu based on linear estimation between two points
60
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
61
+ timesteps = time_shift(mu, 1.0, timesteps)
62
+
63
+ return timesteps.tolist()
64
+
65
+
66
+ def denoise(
67
+ model: flux_models.Flux,
68
+ img: torch.Tensor,
69
+ img_ids: torch.Tensor,
70
+ txt: torch.Tensor,
71
+ txt_ids: torch.Tensor,
72
+ vec: torch.Tensor,
73
+ timesteps: list[float],
74
+ guidance: float = 4.0,
75
+ t5_attn_mask: Optional[torch.Tensor] = None,
76
+ neg_txt: Optional[torch.Tensor] = None,
77
+ neg_vec: Optional[torch.Tensor] = None,
78
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
79
+ cfg_scale: Optional[float] = None,
80
+ ):
81
+ # this is ignored for schnell
82
+ logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
83
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
84
+
85
+ # prepare classifier free guidance
86
+ if neg_txt is not None and neg_vec is not None:
87
+ b_img_ids = torch.cat([img_ids, img_ids], dim=0)
88
+ b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
89
+ b_txt = torch.cat([neg_txt, txt], dim=0)
90
+ b_vec = torch.cat([neg_vec, vec], dim=0)
91
+ if t5_attn_mask is not None and neg_t5_attn_mask is not None:
92
+ b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
93
+ else:
94
+ b_t5_attn_mask = None
95
+ else:
96
+ b_img_ids = img_ids
97
+ b_txt_ids = txt_ids
98
+ b_txt = txt
99
+ b_vec = vec
100
+ b_t5_attn_mask = t5_attn_mask
101
+
102
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
103
+ t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
104
+
105
+ # classifier free guidance
106
+ if neg_txt is not None and neg_vec is not None:
107
+ b_img = torch.cat([img, img], dim=0)
108
+ else:
109
+ b_img = img
110
+
111
+ pred = model(
112
+ img=b_img,
113
+ img_ids=b_img_ids,
114
+ txt=b_txt,
115
+ txt_ids=b_txt_ids,
116
+ y=b_vec,
117
+ timesteps=t_vec,
118
+ guidance=guidance_vec,
119
+ txt_attention_mask=b_t5_attn_mask,
120
+ )
121
+
122
+ # classifier free guidance
123
+ if neg_txt is not None and neg_vec is not None:
124
+ pred_uncond, pred = torch.chunk(pred, 2, dim=0)
125
+ pred = pred_uncond + cfg_scale * (pred - pred_uncond)
126
+
127
+ img = img + (t_prev - t_curr) * pred
128
+
129
+ return img
130
+
131
+
132
+ def do_sample(
133
+ accelerator: Optional[accelerate.Accelerator],
134
+ model: flux_models.Flux,
135
+ img: torch.Tensor,
136
+ img_ids: torch.Tensor,
137
+ l_pooled: torch.Tensor,
138
+ t5_out: torch.Tensor,
139
+ txt_ids: torch.Tensor,
140
+ num_steps: int,
141
+ guidance: float,
142
+ t5_attn_mask: Optional[torch.Tensor],
143
+ is_schnell: bool,
144
+ device: torch.device,
145
+ flux_dtype: torch.dtype,
146
+ neg_l_pooled: Optional[torch.Tensor] = None,
147
+ neg_t5_out: Optional[torch.Tensor] = None,
148
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
149
+ cfg_scale: Optional[float] = None,
150
+ ):
151
+ logger.info(f"num_steps: {num_steps}")
152
+ timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
153
+
154
+ # denoise initial noise
155
+ if accelerator:
156
+ with accelerator.autocast(), torch.no_grad():
157
+ x = denoise(
158
+ model,
159
+ img,
160
+ img_ids,
161
+ t5_out,
162
+ txt_ids,
163
+ l_pooled,
164
+ timesteps,
165
+ guidance,
166
+ t5_attn_mask,
167
+ neg_t5_out,
168
+ neg_l_pooled,
169
+ neg_t5_attn_mask,
170
+ cfg_scale,
171
+ )
172
+ else:
173
+ with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
174
+ x = denoise(
175
+ model,
176
+ img,
177
+ img_ids,
178
+ t5_out,
179
+ txt_ids,
180
+ l_pooled,
181
+ timesteps,
182
+ guidance,
183
+ t5_attn_mask,
184
+ neg_t5_out,
185
+ neg_l_pooled,
186
+ neg_t5_attn_mask,
187
+ cfg_scale,
188
+ )
189
+
190
+ return x
191
+
192
+
193
+ def generate_image(
194
+ model,
195
+ clip_l: CLIPTextModel,
196
+ t5xxl,
197
+ ae,
198
+ prompt: str,
199
+ seed: Optional[int],
200
+ image_width: int,
201
+ image_height: int,
202
+ steps: Optional[int],
203
+ guidance: float,
204
+ negative_prompt: Optional[str],
205
+ cfg_scale: float,
206
+ ):
207
+ seed = seed if seed is not None else random.randint(0, 2**32 - 1)
208
+ logger.info(f"Seed: {seed}")
209
+
210
+ # make first noise with packed shape
211
+ # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
212
+ packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
213
+ noise_dtype = torch.float32 if is_fp8(dtype) else dtype
214
+ noise = torch.randn(
215
+ 1,
216
+ packed_latent_height * packed_latent_width,
217
+ 16 * 2 * 2,
218
+ device=device,
219
+ dtype=noise_dtype,
220
+ generator=torch.Generator(device=device).manual_seed(seed),
221
+ )
222
+
223
+ # prepare img and img ids
224
+
225
+ # this is needed only for img2img
226
+ # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
227
+ # if img.shape[0] == 1 and bs > 1:
228
+ # img = repeat(img, "1 ... -> bs ...", bs=bs)
229
+
230
+ # txt2img only needs img_ids
231
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
232
+
233
+ # prepare fp8 models
234
+ if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
235
+ logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
236
+ clip_l.to(clip_l_dtype) # fp8
237
+ clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
238
+ clip_l.fp8_prepared = True
239
+
240
+ if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
241
+ logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
242
+
243
+ def prepare_fp8(text_encoder, target_dtype):
244
+ def forward_hook(module):
245
+ def forward(hidden_states):
246
+ hidden_gelu = module.act(module.wi_0(hidden_states))
247
+ hidden_linear = module.wi_1(hidden_states)
248
+ hidden_states = hidden_gelu * hidden_linear
249
+ hidden_states = module.dropout(hidden_states)
250
+
251
+ hidden_states = module.wo(hidden_states)
252
+ return hidden_states
253
+
254
+ return forward
255
+
256
+ for module in text_encoder.modules():
257
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
258
+ # print("set", module.__class__.__name__, "to", target_dtype)
259
+ module.to(target_dtype)
260
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
261
+ # print("set", module.__class__.__name__, "hooks")
262
+ module.forward = forward_hook(module)
263
+
264
+ t5xxl.to(t5xxl_dtype)
265
+ prepare_fp8(t5xxl.encoder, torch.bfloat16)
266
+ t5xxl.fp8_prepared = True
267
+
268
+ # prepare embeddings
269
+ logger.info("Encoding prompts...")
270
+ clip_l = clip_l.to(device)
271
+ t5xxl = t5xxl.to(device)
272
+
273
+ def encode(prpt: str):
274
+ tokens_and_masks = tokenize_strategy.tokenize(prpt)
275
+ with torch.no_grad():
276
+ if is_fp8(clip_l_dtype):
277
+ with accelerator.autocast():
278
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
279
+ else:
280
+ with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
281
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
282
+
283
+ if is_fp8(t5xxl_dtype):
284
+ with accelerator.autocast():
285
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
286
+ tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
287
+ )
288
+ else:
289
+ with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
290
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
291
+ tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
292
+ )
293
+ return l_pooled, t5_out, txt_ids, t5_attn_mask
294
+
295
+ l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
296
+ if negative_prompt:
297
+ neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
298
+ else:
299
+ neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
300
+
301
+ # NaN check
302
+ if torch.isnan(l_pooled).any():
303
+ raise ValueError("NaN in l_pooled")
304
+ if torch.isnan(t5_out).any():
305
+ raise ValueError("NaN in t5_out")
306
+
307
+ if args.offload:
308
+ clip_l = clip_l.cpu()
309
+ t5xxl = t5xxl.cpu()
310
+ # del clip_l, t5xxl
311
+ device_utils.clean_memory()
312
+
313
+ # generate image
314
+ logger.info("Generating image...")
315
+ model = model.to(device)
316
+ if steps is None:
317
+ steps = 4 if is_schnell else 50
318
+
319
+ img_ids = img_ids.to(device)
320
+ t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
321
+
322
+ x = do_sample(
323
+ accelerator,
324
+ model,
325
+ noise,
326
+ img_ids,
327
+ l_pooled,
328
+ t5_out,
329
+ txt_ids,
330
+ steps,
331
+ guidance,
332
+ t5_attn_mask,
333
+ is_schnell,
334
+ device,
335
+ flux_dtype,
336
+ neg_l_pooled,
337
+ neg_t5_out,
338
+ neg_t5_attn_mask,
339
+ cfg_scale,
340
+ )
341
+ if args.offload:
342
+ model = model.cpu()
343
+ # del model
344
+ device_utils.clean_memory()
345
+
346
+ # unpack
347
+ x = x.float()
348
+ x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
349
+
350
+ # decode
351
+ logger.info("Decoding image...")
352
+ ae = ae.to(device)
353
+ with torch.no_grad():
354
+ if is_fp8(ae_dtype):
355
+ with accelerator.autocast():
356
+ x = ae.decode(x)
357
+ else:
358
+ with torch.autocast(device_type=device.type, dtype=ae_dtype):
359
+ x = ae.decode(x)
360
+ if args.offload:
361
+ ae = ae.cpu()
362
+
363
+ x = x.clamp(-1, 1)
364
+ x = x.permute(0, 2, 3, 1)
365
+ img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
366
+
367
+ # save image
368
+ output_dir = args.output_dir
369
+ os.makedirs(output_dir, exist_ok=True)
370
+ output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
371
+ img.save(output_path)
372
+
373
+ logger.info(f"Saved image to {output_path}")
374
+
375
+
376
+ if __name__ == "__main__":
377
+ target_height = 768 # 1024
378
+ target_width = 1360 # 1024
379
+
380
+ # steps = 50 # 28 # 50
381
+ # guidance_scale = 5
382
+ # seed = 1 # None # 1
383
+
384
+ device = get_preferred_device()
385
+
386
+ parser = argparse.ArgumentParser()
387
+ parser.add_argument("--ckpt_path", type=str, required=True)
388
+ parser.add_argument("--clip_l", type=str, required=False)
389
+ parser.add_argument("--t5xxl", type=str, required=False)
390
+ parser.add_argument("--ae", type=str, required=False)
391
+ parser.add_argument("--apply_t5_attn_mask", action="store_true")
392
+ parser.add_argument("--prompt", type=str, default="A photo of a cat")
393
+ parser.add_argument("--output_dir", type=str, default=".")
394
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
395
+ parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
396
+ parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
397
+ parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
398
+ parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
399
+ parser.add_argument("--seed", type=int, default=None)
400
+ parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
401
+ parser.add_argument("--guidance", type=float, default=3.5)
402
+ parser.add_argument("--negative_prompt", type=str, default=None)
403
+ parser.add_argument("--cfg_scale", type=float, default=1.0)
404
+ parser.add_argument("--offload", action="store_true", help="Offload to CPU")
405
+ parser.add_argument(
406
+ "--lora_weights",
407
+ type=str,
408
+ nargs="*",
409
+ default=[],
410
+ help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
411
+ )
412
+ parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
413
+ parser.add_argument("--width", type=int, default=target_width)
414
+ parser.add_argument("--height", type=int, default=target_height)
415
+ parser.add_argument("--interactive", action="store_true")
416
+ args = parser.parse_args()
417
+
418
+ seed = args.seed
419
+ steps = args.steps
420
+ guidance_scale = args.guidance
421
+
422
+ def is_fp8(dt):
423
+ return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
424
+
425
+ dtype = str_to_dtype(args.dtype)
426
+ clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
427
+ t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
428
+ ae_dtype = str_to_dtype(args.ae_dtype, dtype)
429
+ flux_dtype = str_to_dtype(args.flux_dtype, dtype)
430
+
431
+ logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
432
+
433
+ loading_device = "cpu" if args.offload else device
434
+
435
+ use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
436
+ if any(use_fp8):
437
+ accelerator = accelerate.Accelerator(mixed_precision="bf16")
438
+ else:
439
+ accelerator = None
440
+
441
+ # load clip_l
442
+ logger.info(f"Loading clip_l from {args.clip_l}...")
443
+ clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
444
+ clip_l.eval()
445
+
446
+ logger.info(f"Loading t5xxl from {args.t5xxl}...")
447
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
448
+ t5xxl.eval()
449
+
450
+ # if is_fp8(clip_l_dtype):
451
+ # clip_l = accelerator.prepare(clip_l)
452
+ # if is_fp8(t5xxl_dtype):
453
+ # t5xxl = accelerator.prepare(t5xxl)
454
+
455
+ # DiT
456
+ is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
457
+ model.eval()
458
+ logger.info(f"Casting model to {flux_dtype}")
459
+ model.to(flux_dtype) # make sure model is dtype
460
+ # if is_fp8(flux_dtype):
461
+ # model = accelerator.prepare(model)
462
+ # if args.offload:
463
+ # model = model.to("cpu")
464
+
465
+ t5xxl_max_length = 256 if is_schnell else 512
466
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
467
+ encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
468
+
469
+ # AE
470
+ ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
471
+ ae.eval()
472
+ # if is_fp8(ae_dtype):
473
+ # ae = accelerator.prepare(ae)
474
+
475
+ # LoRA
476
+ lora_models: List[lora_flux.LoRANetwork] = []
477
+ for weights_file in args.lora_weights:
478
+ if ";" in weights_file:
479
+ weights_file, multiplier = weights_file.split(";")
480
+ multiplier = float(multiplier)
481
+ else:
482
+ multiplier = 1.0
483
+
484
+ weights_sd = load_file(weights_file)
485
+ is_lora = is_oft = False
486
+ for key in weights_sd.keys():
487
+ if key.startswith("lora"):
488
+ is_lora = True
489
+ if key.startswith("oft"):
490
+ is_oft = True
491
+ if is_lora or is_oft:
492
+ break
493
+
494
+ module = lora_flux if is_lora else oft_flux
495
+ lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
496
+
497
+ if args.merge_lora_weights:
498
+ lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
499
+ else:
500
+ lora_model.apply_to([clip_l, t5xxl], model)
501
+ info = lora_model.load_state_dict(weights_sd, strict=True)
502
+ logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
503
+ lora_model.eval()
504
+ lora_model.to(device)
505
+
506
+ lora_models.append(lora_model)
507
+
508
+ if not args.interactive:
509
+ generate_image(
510
+ model,
511
+ clip_l,
512
+ t5xxl,
513
+ ae,
514
+ args.prompt,
515
+ args.seed,
516
+ args.width,
517
+ args.height,
518
+ args.steps,
519
+ args.guidance,
520
+ args.negative_prompt,
521
+ args.cfg_scale,
522
+ )
523
+ else:
524
+ # loop for interactive
525
+ width = target_width
526
+ height = target_height
527
+ steps = None
528
+ guidance = args.guidance
529
+ cfg_scale = args.cfg_scale
530
+
531
+ while True:
532
+ print(
533
+ "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
534
+ " --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
535
+ )
536
+ prompt = input()
537
+ if prompt == "":
538
+ break
539
+
540
+ # parse options
541
+ options = prompt.split("--")
542
+ prompt = options[0].strip()
543
+ seed = None
544
+ negative_prompt = None
545
+ for opt in options[1:]:
546
+ try:
547
+ opt = opt.strip()
548
+ if opt.startswith("w"):
549
+ width = int(opt[1:].strip())
550
+ elif opt.startswith("h"):
551
+ height = int(opt[1:].strip())
552
+ elif opt.startswith("s"):
553
+ steps = int(opt[1:].strip())
554
+ elif opt.startswith("d"):
555
+ seed = int(opt[1:].strip())
556
+ elif opt.startswith("g"):
557
+ guidance = float(opt[1:].strip())
558
+ elif opt.startswith("m"):
559
+ mutipliers = opt[1:].strip().split(",")
560
+ if len(mutipliers) != len(lora_models):
561
+ logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
562
+ continue
563
+ for i, lora_model in enumerate(lora_models):
564
+ lora_model.set_multiplier(float(mutipliers[i]))
565
+ elif opt.startswith("n"):
566
+ negative_prompt = opt[1:].strip()
567
+ if negative_prompt == "-":
568
+ negative_prompt = ""
569
+ elif opt.startswith("c"):
570
+ cfg_scale = float(opt[1:].strip())
571
+ except ValueError as e:
572
+ logger.error(f"Invalid option: {opt}, {e}")
573
+
574
+ generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
575
+
576
+ logger.info("Done!")
flux_minimal_inference_asylora.py ADDED
@@ -0,0 +1,583 @@
1
+
2
+ # Minimum Inference Code for FLUX
3
+
4
+ import argparse
5
+ import datetime
6
+ import math
7
+ import os
8
+ import random
9
+ from typing import Callable, List, Optional
10
+ import einops
11
+ import numpy as np
12
+
13
+ import torch
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ import accelerate
17
+ from transformers import CLIPTextModel
18
+ from safetensors.torch import load_file
19
+
20
+ from library import device_utils
21
+ from library.device_utils import init_ipex, get_preferred_device
22
+ from networks import oft_flux
23
+
24
+ init_ipex()
25
+
26
+
27
+ from library.utils import setup_logging, str_to_dtype
28
+
29
+ setup_logging()
30
+ import logging
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ import networks.asylora_flux as lora_flux
35
+ from library import flux_models, flux_utils, sd3_utils, strategy_flux
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
43
+ m = (y2 - y1) / (x2 - x1)
44
+ b = y1 - m * x1
45
+ return lambda x: m * x + b
46
+
47
+
48
+ def get_schedule(
49
+ num_steps: int,
50
+ image_seq_len: int,
51
+ base_shift: float = 0.5,
52
+ max_shift: float = 1.15,
53
+ shift: bool = True,
54
+ ) -> list[float]:
55
+ # extra step for zero
56
+ timesteps = torch.linspace(1, 0, num_steps + 1)
57
+
58
+ # shifting the schedule to favor high timesteps for higher signal images
59
+ if shift:
60
+ # estimate mu via linear interpolation between the two reference points
61
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
62
+ timesteps = time_shift(mu, 1.0, timesteps)
63
+
64
+ return timesteps.tolist()
65
+
66
+
67
+ def denoise(
68
+ model: flux_models.Flux,
69
+ img: torch.Tensor,
70
+ img_ids: torch.Tensor,
71
+ txt: torch.Tensor,
72
+ txt_ids: torch.Tensor,
73
+ vec: torch.Tensor,
74
+ timesteps: list[float],
75
+ guidance: float = 4.0,
76
+ t5_attn_mask: Optional[torch.Tensor] = None,
77
+ neg_txt: Optional[torch.Tensor] = None,
78
+ neg_vec: Optional[torch.Tensor] = None,
79
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
80
+ cfg_scale: Optional[float] = None,
81
+ ):
82
+ # this is ignored for schnell
83
+ logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
84
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
85
+
86
+ # prepare classifier free guidance
87
+ if neg_txt is not None and neg_vec is not None:
88
+ b_img_ids = torch.cat([img_ids, img_ids], dim=0)
89
+ b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
90
+ b_txt = torch.cat([neg_txt, txt], dim=0)
91
+ b_vec = torch.cat([neg_vec, vec], dim=0)
92
+ if t5_attn_mask is not None and neg_t5_attn_mask is not None:
93
+ b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
94
+ else:
95
+ b_t5_attn_mask = None
96
+ else:
97
+ b_img_ids = img_ids
98
+ b_txt_ids = txt_ids
99
+ b_txt = txt
100
+ b_vec = vec
101
+ b_t5_attn_mask = t5_attn_mask
102
+
103
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
104
+ t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
105
+
106
+ # classifier free guidance
107
+ if neg_txt is not None and neg_vec is not None:
108
+ b_img = torch.cat([img, img], dim=0)
109
+ else:
110
+ b_img = img
111
+
112
+ pred = model(
113
+ img=b_img,
114
+ img_ids=b_img_ids,
115
+ txt=b_txt,
116
+ txt_ids=b_txt_ids,
117
+ y=b_vec,
118
+ timesteps=t_vec,
119
+ guidance=guidance_vec,
120
+ txt_attention_mask=b_t5_attn_mask,
121
+ )
122
+
123
+ # classifier free guidance
124
+ if neg_txt is not None and neg_vec is not None:
125
+ pred_uncond, pred = torch.chunk(pred, 2, dim=0)
126
+ pred = pred_uncond + cfg_scale * (pred - pred_uncond)
127
+
128
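+ # Euler step of the flow ODE: timesteps run from 1 towards 0, so (t_prev - t_curr) is
+ # negative and the update moves img along the predicted velocity towards the clean image.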
+ img = img + (t_prev - t_curr) * pred
129
+
130
+ return img
131
+
132
+
133
+ def do_sample(
134
+ accelerator: Optional[accelerate.Accelerator],
135
+ model: flux_models.Flux,
136
+ img: torch.Tensor,
137
+ img_ids: torch.Tensor,
138
+ l_pooled: torch.Tensor,
139
+ t5_out: torch.Tensor,
140
+ txt_ids: torch.Tensor,
141
+ num_steps: int,
142
+ guidance: float,
143
+ t5_attn_mask: Optional[torch.Tensor],
144
+ is_schnell: bool,
145
+ device: torch.device,
146
+ flux_dtype: torch.dtype,
147
+ neg_l_pooled: Optional[torch.Tensor] = None,
148
+ neg_t5_out: Optional[torch.Tensor] = None,
149
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
150
+ cfg_scale: Optional[float] = None,
151
+ ):
152
+ logger.info(f"num_steps: {num_steps}")
153
+ timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
154
+
155
+ # denoise initial noise
156
+ if accelerator:
157
+ with accelerator.autocast(), torch.no_grad():
158
+ x = denoise(
159
+ model,
160
+ img,
161
+ img_ids,
162
+ t5_out,
163
+ txt_ids,
164
+ l_pooled,
165
+ timesteps,
166
+ guidance,
167
+ t5_attn_mask,
168
+ neg_t5_out,
169
+ neg_l_pooled,
170
+ neg_t5_attn_mask,
171
+ cfg_scale,
172
+ )
173
+ else:
174
+ with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
175
+ x = denoise(
176
+ model,
177
+ img,
178
+ img_ids,
179
+ t5_out,
180
+ txt_ids,
181
+ l_pooled,
182
+ timesteps,
183
+ guidance,
184
+ t5_attn_mask,
185
+ neg_t5_out,
186
+ neg_l_pooled,
187
+ neg_t5_attn_mask,
188
+ cfg_scale,
189
+ )
190
+
191
+ return x
192
+
193
+
194
+ def generate_image(
195
+ model,
196
+ clip_l: CLIPTextModel,
197
+ t5xxl,
198
+ ae,
199
+ prompt: str,
200
+ seed: Optional[int],
201
+ image_width: int,
202
+ image_height: int,
203
+ steps: Optional[int],
204
+ guidance: float,
205
+ negative_prompt: Optional[str],
206
+ cfg_scale: float,
207
+ ):
208
+ seed = seed if seed is not None else random.randint(0, 2**32 - 1)
209
+ logger.info(f"Seed: {seed}")
210
+
211
+ # make first noise with packed shape
212
+ # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
213
+ packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
214
+ noise_dtype = torch.float32 if is_fp8(dtype) else dtype
215
+ noise = torch.randn(
216
+ 1,
217
+ packed_latent_height * packed_latent_width,
218
+ 16 * 2 * 2,
219
+ device=device,
220
+ dtype=noise_dtype,
221
+ generator=torch.Generator(device=device).manual_seed(seed),
222
+ )
223
+
224
+ # prepare img and img ids
225
+
226
+ # this is needed only for img2img
227
+ # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
228
+ # if img.shape[0] == 1 and bs > 1:
229
+ # img = repeat(img, "1 ... -> bs ...", bs=bs)
230
+
231
+ # txt2img only needs img_ids
232
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
233
+
234
+ # prepare fp8 models
235
+ if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
236
+ logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
237
+ clip_l.to(clip_l_dtype) # fp8
238
+ clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
239
+ clip_l.fp8_prepared = True
240
+
241
+ if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
242
+ logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
243
+
244
+ def prepare_fp8(text_encoder, target_dtype):
245
+ def forward_hook(module):
246
+ def forward(hidden_states):
247
+ hidden_gelu = module.act(module.wi_0(hidden_states))
248
+ hidden_linear = module.wi_1(hidden_states)
249
+ hidden_states = hidden_gelu * hidden_linear
250
+ hidden_states = module.dropout(hidden_states)
251
+
252
+ hidden_states = module.wo(hidden_states)
253
+ return hidden_states
254
+
255
+ return forward
256
+
257
+ for module in text_encoder.modules():
258
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
259
+ # print("set", module.__class__.__name__, "to", target_dtype)
260
+ module.to(target_dtype)
261
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
262
+ # print("set", module.__class__.__name__, "hooks")
263
+ module.forward = forward_hook(module)
264
+
265
+ t5xxl.to(t5xxl_dtype)
266
+ prepare_fp8(t5xxl.encoder, torch.bfloat16)
267
+ t5xxl.fp8_prepared = True
268
+
269
+ # prepare embeddings
270
+ logger.info("Encoding prompts...")
271
+ clip_l = clip_l.to(device)
272
+ t5xxl = t5xxl.to(device)
273
+
274
+ def encode(prpt: str):
275
+ tokens_and_masks = tokenize_strategy.tokenize(prpt)
276
+ with torch.no_grad():
277
+ if is_fp8(clip_l_dtype):
278
+ with accelerator.autocast():
279
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
280
+ else:
281
+ with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
282
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
283
+
284
+ if is_fp8(t5xxl_dtype):
285
+ with accelerator.autocast():
286
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
287
+ tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
288
+ )
289
+ else:
290
+ with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
291
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
292
+ tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
293
+ )
294
+ return l_pooled, t5_out, txt_ids, t5_attn_mask
295
+
296
+ l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
297
+ if negative_prompt:
298
+ neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
299
+ else:
300
+ neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
301
+
302
+ # NaN check
303
+ if torch.isnan(l_pooled).any():
304
+ raise ValueError("NaN in l_pooled")
305
+ if torch.isnan(t5_out).any():
306
+ raise ValueError("NaN in t5_out")
307
+
308
+ if args.offload:
309
+ clip_l = clip_l.cpu()
310
+ t5xxl = t5xxl.cpu()
311
+ # del clip_l, t5xxl
312
+ device_utils.clean_memory()
313
+
314
+ # generate image
315
+ logger.info("Generating image...")
316
+ model = model.to(device)
317
+ if steps is None:
318
+ steps = 4 if is_schnell else 50
319
+
320
+ img_ids = img_ids.to(device)
321
+ t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
322
+
323
+ x = do_sample(
324
+ accelerator,
325
+ model,
326
+ noise,
327
+ img_ids,
328
+ l_pooled,
329
+ t5_out,
330
+ txt_ids,
331
+ steps,
332
+ guidance,
333
+ t5_attn_mask,
334
+ is_schnell,
335
+ device,
336
+ flux_dtype,
337
+ neg_l_pooled,
338
+ neg_t5_out,
339
+ neg_t5_attn_mask,
340
+ cfg_scale,
341
+ )
342
+ if args.offload:
343
+ model = model.cpu()
344
+ # del model
345
+ device_utils.clean_memory()
346
+
347
+ # unpack
348
+ x = x.float()
349
+ x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
350
+
351
+ # decode
352
+ logger.info("Decoding image...")
353
+ ae = ae.to(device)
354
+ with torch.no_grad():
355
+ if is_fp8(ae_dtype):
356
+ with accelerator.autocast():
357
+ x = ae.decode(x)
358
+ else:
359
+ with torch.autocast(device_type=device.type, dtype=ae_dtype):
360
+ x = ae.decode(x)
361
+ if args.offload:
362
+ ae = ae.cpu()
363
+
364
+ x = x.clamp(-1, 1)
365
+ x = x.permute(0, 2, 3, 1)
366
+ img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
367
+
368
+ # save image
369
+ output_dir = args.output_dir
370
+ os.makedirs(output_dir, exist_ok=True)
371
+ output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
372
+ img.save(output_path)
373
+
374
+ logger.info(f"Saved image to {output_path}")
375
+
376
+
377
+ if __name__ == "__main__":
378
+ target_height = 768 # 1024
379
+ target_width = 1360 # 1024
380
+
381
+ # steps = 50 # 28 # 50
382
+ # guidance_scale = 5
383
+ # seed = 1 # None # 1
384
+
385
+ device = get_preferred_device()
386
+
387
+ parser = argparse.ArgumentParser()
388
+ parser.add_argument("--lora_ups_num", type=int, required=True)
389
+ parser.add_argument("--lora_up_cur", type=int, required=True)
390
+ parser.add_argument("--ckpt_path", type=str, required=True)
391
+ parser.add_argument("--clip_l", type=str, required=False)
392
+ parser.add_argument("--t5xxl", type=str, required=False)
393
+ parser.add_argument("--ae", type=str, required=False)
394
+ parser.add_argument("--apply_t5_attn_mask", action="store_true")
395
+ parser.add_argument("--prompt", type=str, default="A photo of a cat")
396
+ parser.add_argument("--output_dir", type=str, default=".")
397
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
398
+ parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
399
+ parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
400
+ parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
401
+ parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
402
+ parser.add_argument("--seed", type=int, default=None)
403
+ parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
404
+ parser.add_argument("--guidance", type=float, default=3.5)
405
+ parser.add_argument("--negative_prompt", type=str, default=None)
406
+ parser.add_argument("--cfg_scale", type=float, default=1.0)
407
+ parser.add_argument("--offload", action="store_true", help="Offload to CPU")
408
+ parser.add_argument(
409
+ "--lora_weights",
410
+ type=str,
411
+ nargs="*",
412
+ default=[],
413
+ help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
414
+ )
415
+ parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
416
+ parser.add_argument("--width", type=int, default=target_width)
417
+ parser.add_argument("--height", type=int, default=target_height)
418
+ parser.add_argument("--interactive", action="store_true")
419
+ args = parser.parse_args()
420
+
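+ # Example invocation (a sketch only: every path below is a placeholder, and the AsyLoRA
+ # settings depend on how the LoRA was trained):
+ # python flux_minimal_inference_asylora.py \
+ #     --ckpt_path flux1-dev.safetensors --clip_l clip_l.safetensors \
+ #     --t5xxl t5xxl_fp16.safetensors --ae ae.safetensors \
+ #     --lora_weights "asylora.safetensors;1.0" --lora_ups_num 9 --lora_up_cur 4 \
+ #     --prompt "watercolor painting process, step by step" --steps 25 --guidance 3.5 --offload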
421
+ seed = args.seed
422
+ steps = args.steps
423
+ guidance_scale = args.guidance
424
+ lora_ups_num = args.lora_ups_num
425
+ lora_up_cur = args.lora_up_cur
426
+
427
+ def is_fp8(dt):
428
+ return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
429
+
430
+ dtype = str_to_dtype(args.dtype)
431
+ clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
432
+ t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
433
+ ae_dtype = str_to_dtype(args.ae_dtype, dtype)
434
+ flux_dtype = str_to_dtype(args.flux_dtype, dtype)
435
+
436
+ logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
437
+
438
+ loading_device = "cpu" if args.offload else device
439
+
440
+ use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
441
+ if any(use_fp8):
442
+ accelerator = accelerate.Accelerator(mixed_precision="bf16")
443
+ else:
444
+ accelerator = None
445
+
446
+ # load clip_l
447
+ logger.info(f"Loading clip_l from {args.clip_l}...")
448
+ clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
449
+ clip_l.eval()
450
+
451
+ logger.info(f"Loading t5xxl from {args.t5xxl}...")
452
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
453
+ t5xxl.eval()
454
+
455
+ # if is_fp8(clip_l_dtype):
456
+ # clip_l = accelerator.prepare(clip_l)
457
+ # if is_fp8(t5xxl_dtype):
458
+ # t5xxl = accelerator.prepare(t5xxl)
459
+
460
+ # DiT
461
+ is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
462
+ model.eval()
463
+ logger.info(f"Casting model to {flux_dtype}")
464
+ model.to(flux_dtype) # make sure model is dtype
465
+ # if is_fp8(flux_dtype):
466
+ # model = accelerator.prepare(model)
467
+ # if args.offload:
468
+ # model = model.to("cpu")
469
+
470
+ t5xxl_max_length = 256 if is_schnell else 512
471
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
472
+ encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
473
+
474
+ # AE
475
+ ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
476
+ ae.eval()
477
+ # if is_fp8(ae_dtype):
478
+ # ae = accelerator.prepare(ae)
479
+
480
+ # LoRA
481
+ lora_models: List[lora_flux.LoRANetwork] = []
482
+ for weights_file in args.lora_weights:
483
+ if ";" in weights_file:
484
+ weights_file, multiplier = weights_file.split(";")
485
+ multiplier = float(multiplier)
486
+ else:
487
+ multiplier = 1.0
488
+
489
+ weights_sd = load_file(weights_file)
490
+ is_lora = is_oft = False
491
+ for key in weights_sd.keys():
492
+ if key.startswith("lora"):
493
+ is_lora = True
494
+ if key.startswith("oft"):
495
+ is_oft = True
496
+ if is_lora or is_oft:
497
+ break
498
+
499
+ module = lora_flux if is_lora else oft_flux
500
+ lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num)
501
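+ # AsyLoRA bookkeeping: lora_ups_num tells the network how many LoRA-up branches the
+ # checkpoint carries, and lora_up_cur selects the branch to activate (1-based on the
+ # command line, converted to a 0-based index below).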
+ for sub_lora in lora_model.unet_loras:
502
+ sub_lora.set_lora_up_cur(lora_up_cur-1)
503
+
504
+ if args.merge_lora_weights:
505
+ lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
506
+ else:
507
+ lora_model.apply_to([clip_l, t5xxl], model)
508
+ info = lora_model.load_state_dict(weights_sd, strict=True)
509
+ logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
510
+ lora_model.eval()
511
+ lora_model.to(device)
512
+
513
+ lora_models.append(lora_model)
514
+
515
+ if not args.interactive:
516
+ generate_image(
517
+ model,
518
+ clip_l,
519
+ t5xxl,
520
+ ae,
521
+ args.prompt,
522
+ args.seed,
523
+ args.width,
524
+ args.height,
525
+ args.steps,
526
+ args.guidance,
527
+ args.negative_prompt,
528
+ args.cfg_scale,
529
+ )
530
+ else:
531
+ # loop for interactive
532
+ width = target_width
533
+ height = target_height
534
+ steps = None
535
+ guidance = args.guidance
536
+ cfg_scale = args.cfg_scale
537
+
538
+ while True:
539
+ print(
540
+ "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
541
+ " --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
542
+ )
543
+ prompt = input()
544
+ if prompt == "":
545
+ break
546
+
547
+ # parse options
548
+ options = prompt.split("--")
549
+ prompt = options[0].strip()
550
+ seed = None
551
+ negative_prompt = None
552
+ for opt in options[1:]:
553
+ try:
554
+ opt = opt.strip()
555
+ if opt.startswith("w"):
556
+ width = int(opt[1:].strip())
557
+ elif opt.startswith("h"):
558
+ height = int(opt[1:].strip())
559
+ elif opt.startswith("s"):
560
+ steps = int(opt[1:].strip())
561
+ elif opt.startswith("d"):
562
+ seed = int(opt[1:].strip())
563
+ elif opt.startswith("g"):
564
+ guidance = float(opt[1:].strip())
565
+ elif opt.startswith("m"):
566
+ multipliers = opt[1:].strip().split(",")
567
+ if len(multipliers) != len(lora_models):
568
+ logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
569
+ continue
570
+ for i, lora_model in enumerate(lora_models):
571
+ lora_model.set_multiplier(float(multipliers[i]))
572
+ elif opt.startswith("n"):
573
+ negative_prompt = opt[1:].strip()
574
+ if negative_prompt == "-":
575
+ negative_prompt = ""
576
+ elif opt.startswith("c"):
577
+ cfg_scale = float(opt[1:].strip())
578
+ except ValueError as e:
579
+ logger.error(f"Invalid option: {opt}, {e}")
580
+
581
+ generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
582
+
583
+ logger.info("Done!")
flux_train_network.py ADDED
@@ -0,0 +1,588 @@
1
+ import argparse
2
+ import copy
3
+ import math
4
+ import random
5
+ from typing import Any, Optional, Union
6
+
7
+ import torch
8
+ from accelerate import Accelerator
9
+
10
+ from library.device_utils import clean_memory_on_device, init_ipex
11
+
12
+ init_ipex()
13
+
14
+ import train_network
15
+ from library import (
16
+ flux_models,
17
+ flux_train_utils,
18
+ flux_utils,
19
+ sd3_train_utils,
20
+ strategy_base,
21
+ strategy_flux,
22
+ train_util,
23
+ )
24
+ from library.utils import setup_logging
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class FluxNetworkTrainer(train_network.NetworkTrainer):
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.sample_prompts_te_outputs = None
36
+ self.is_schnell: Optional[bool] = None
37
+ self.is_swapping_blocks: bool = False
38
+
39
+ def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
40
+ super().assert_extra_args(args, train_dataset_group, val_dataset_group)
41
+ # sdxl_train_util.verify_sdxl_training_args(args)
42
+
43
+ if args.fp8_base_unet:
44
+ args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
45
+
46
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
47
+ logger.warning(
48
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
49
+ )
50
+ args.cache_text_encoder_outputs = True
51
+
52
+ if args.cache_text_encoder_outputs:
53
+ assert (
54
+ train_dataset_group.is_text_encoder_output_cacheable()
55
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
56
+
57
+ # prepare CLIP-L/T5XXL training flags
58
+ self.train_clip_l = not args.network_train_unet_only
59
+ self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
60
+
61
+ if args.max_token_length is not None:
62
+ logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
63
+
64
+ assert (
65
+ args.blocks_to_swap is None or args.blocks_to_swap == 0
66
+ ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
67
+
68
+ # deprecated split_mode option
69
+ if args.split_mode:
70
+ if args.blocks_to_swap is not None:
71
+ logger.warning(
72
+ "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
73
+ " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
74
+ )
75
+ else:
76
+ logger.warning(
77
+ "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
78
+ " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
79
+ )
80
+ args.blocks_to_swap = 18 # 18 is safe for most cases
81
+
82
+ train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
83
+ if val_dataset_group is not None:
84
+ val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
85
+
86
+ def load_target_model(self, args, weight_dtype, accelerator):
87
+ # currently offload to cpu for some models
88
+
89
+ # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
90
+ loading_dtype = None if args.fp8_base else weight_dtype
91
+
92
+ # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
93
+ self.is_schnell, model = flux_utils.load_flow_model(
94
+ args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
95
+ )
96
+ if args.fp8_base:
97
+ # check dtype of model
98
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
99
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
100
+ elif model.dtype == torch.float8_e4m3fn:
101
+ logger.info("Loaded fp8 FLUX model")
102
+ else:
103
+ logger.info(
104
+ "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
105
+ " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
106
+ )
107
+ model.to(torch.float8_e4m3fn)
108
+
109
+ # if args.split_mode:
110
+ # model = self.prepare_split_model(model, weight_dtype, accelerator)
111
+
112
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
113
+ if self.is_swapping_blocks:
114
+ # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
115
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
116
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
117
+
118
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
119
+ clip_l.eval()
120
+
121
+ # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
122
+ if args.fp8_base and not args.fp8_base_unet:
123
+ loading_dtype = None # as is
124
+ else:
125
+ loading_dtype = weight_dtype
126
+
127
+ # loading t5xxl to cpu takes a long time, so we should load to gpu in future
128
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
129
+ t5xxl.eval()
130
+ if args.fp8_base and not args.fp8_base_unet:
131
+ # check dtype of model
132
+ if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
133
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
134
+ elif t5xxl.dtype == torch.float8_e4m3fn:
135
+ logger.info("Loaded fp8 T5XXL model")
136
+
137
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
138
+
139
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
140
+
141
+ def get_tokenize_strategy(self, args):
142
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
143
+
144
+ if args.t5xxl_max_token_length is None:
145
+ if is_schnell:
146
+ t5xxl_max_token_length = 256
147
+ else:
148
+ t5xxl_max_token_length = 512
149
+ else:
150
+ t5xxl_max_token_length = args.t5xxl_max_token_length
151
+
152
+ logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
153
+ return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
154
+
155
+ def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
156
+ return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
157
+
158
+ def get_latents_caching_strategy(self, args):
159
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
160
+ return latents_caching_strategy
161
+
162
+ def get_text_encoding_strategy(self, args):
163
+ return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
164
+
165
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
166
+ # check t5xxl is trained or not
167
+ self.train_t5xxl = network.train_t5xxl
168
+
169
+ if self.train_t5xxl and args.cache_text_encoder_outputs:
170
+ raise ValueError(
171
+ "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
172
+ )
173
+
174
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
175
+ if args.cache_text_encoder_outputs:
176
+ if self.train_clip_l and not self.train_t5xxl:
177
+ return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
178
+ else:
179
+ return None # no text encoders are needed for encoding because both are cached
180
+ else:
181
+ return text_encoders # both CLIP-L and T5XXL are needed for encoding
182
+
183
+ def get_text_encoders_train_flags(self, args, text_encoders):
184
+ return [self.train_clip_l, self.train_t5xxl]
185
+
186
+ def get_text_encoder_outputs_caching_strategy(self, args):
187
+ if args.cache_text_encoder_outputs:
188
+ # if a text encoder is trained, tokenization is still needed, so is_partial is True
189
+ return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
190
+ args.cache_text_encoder_outputs_to_disk,
191
+ args.text_encoder_batch_size,
192
+ args.skip_cache_check,
193
+ is_partial=self.train_clip_l or self.train_t5xxl,
194
+ apply_t5_attn_mask=args.apply_t5_attn_mask,
195
+ )
196
+ else:
197
+ return None
198
+
199
+ def cache_text_encoder_outputs_if_needed(
200
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
201
+ ):
202
+ if args.cache_text_encoder_outputs:
203
+ if not args.lowram:
204
+ # reduce memory consumption
205
+ logger.info("move vae and unet to cpu to save memory")
206
+ org_vae_device = vae.device
207
+ org_unet_device = unet.device
208
+ vae.to("cpu")
209
+ unet.to("cpu")
210
+ clean_memory_on_device(accelerator.device)
211
+
212
+ # When the text encoders are not trained, they are not prepared by the accelerator, so we need an explicit autocast
213
+ logger.info("move text encoders to gpu")
214
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
215
+ text_encoders[1].to(accelerator.device)
216
+
217
+ if text_encoders[1].dtype == torch.float8_e4m3fn:
218
+ # if we load fp8 weights, the model is already fp8, so we use it as is
219
+ self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
220
+ else:
221
+ # otherwise, we need to convert it to target dtype
222
+ text_encoders[1].to(weight_dtype)
223
+
224
+ with accelerator.autocast():
225
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
226
+
227
+ # cache sample prompts
228
+ if args.sample_prompts is not None:
229
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
230
+
231
+ tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
232
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
233
+
234
+ prompts = train_util.load_prompts(args.sample_prompts)
235
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
236
+ with accelerator.autocast(), torch.no_grad():
237
+ for prompt_dict in prompts:
238
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
239
+ if p not in sample_prompts_te_outputs:
240
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
241
+ tokens_and_masks = tokenize_strategy.tokenize(p)
242
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
243
+ tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
244
+ )
245
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
246
+
247
+ accelerator.wait_for_everyone()
248
+
249
+ # move back to cpu
250
+ if not self.is_train_text_encoder(args):
251
+ logger.info("move CLIP-L back to cpu")
252
+ text_encoders[0].to("cpu")
253
+ logger.info("move t5XXL back to cpu")
254
+ text_encoders[1].to("cpu")
255
+ clean_memory_on_device(accelerator.device)
256
+
257
+ if not args.lowram:
258
+ logger.info("move vae and unet back to original device")
259
+ vae.to(org_vae_device)
260
+ unet.to(org_unet_device)
261
+ else:
262
+ # Text encoder outputs are computed every step, so keep the encoders on the GPU
263
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
264
+ text_encoders[1].to(accelerator.device)
265
+
266
+ # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
267
+ # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
268
+
269
+ # # get size embeddings
270
+ # orig_size = batch["original_sizes_hw"]
271
+ # crop_size = batch["crop_top_lefts"]
272
+ # target_size = batch["target_sizes_hw"]
273
+ # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
274
+
275
+ # # concat embeddings
276
+ # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
277
+ # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
278
+ # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
279
+
280
+ # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
281
+ # return noise_pred
282
+
283
+ def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
284
+ text_encoders = text_encoder # for compatibility
285
+ text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
286
+
287
+ flux_train_utils.sample_images(
288
+ accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
289
+ )
290
+ # return
291
+
292
+ """
293
+ class FluxUpperLowerWrapper(torch.nn.Module):
294
+ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
295
+ super().__init__()
296
+ self.flux_upper = flux_upper
297
+ self.flux_lower = flux_lower
298
+ self.target_device = device
299
+
300
+ def prepare_block_swap_before_forward(self):
301
+ pass
302
+
303
+ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
304
+ self.flux_lower.to("cpu")
305
+ clean_memory_on_device(self.target_device)
306
+ self.flux_upper.to(self.target_device)
307
+ img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
308
+ self.flux_upper.to("cpu")
309
+ clean_memory_on_device(self.target_device)
310
+ self.flux_lower.to(self.target_device)
311
+ return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
312
+
313
+ wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
314
+ clean_memory_on_device(accelerator.device)
315
+ flux_train_utils.sample_images(
316
+ accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
317
+ )
318
+ clean_memory_on_device(accelerator.device)
319
+ """
320
+
321
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
322
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
323
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
324
+ return noise_scheduler
325
+
326
+ def encode_images_to_latents(self, args, accelerator, vae, images):
327
+ return vae.encode(images)
328
+
329
+ def shift_scale_latents(self, args, latents):
330
+ return latents
331
+
332
+ def get_noise_pred_and_target(
333
+ self,
334
+ args,
335
+ accelerator,
336
+ noise_scheduler,
337
+ latents,
338
+ batch,
339
+ text_encoder_conds,
340
+ unet: flux_models.Flux,
341
+ network,
342
+ weight_dtype,
343
+ train_unet,
344
+ is_train=True
345
+ ):
346
+ # Sample noise that we'll add to the latents
347
+ noise = torch.randn_like(latents)
348
+ bsz = latents.shape[0]
349
+
350
+ # get noisy model input and timesteps
351
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
352
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
353
+ )
354
+
355
+ # pack latents and get img_ids
356
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
357
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
358
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
359
+
360
+ # get guidance
361
+ # ensure guidance_scale in args is float
362
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
363
+
364
+ # ensure the hidden state will require grad
365
+ if args.gradient_checkpointing:
366
+ noisy_model_input.requires_grad_(True)
367
+ for t in text_encoder_conds:
368
+ if t is not None and t.dtype.is_floating_point:
369
+ t.requires_grad_(True)
370
+ img_ids.requires_grad_(True)
371
+ guidance_vec.requires_grad_(True)
372
+
373
+ # Predict the noise residual
374
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
375
+ if not args.apply_t5_attn_mask:
376
+ t5_attn_mask = None
377
+
378
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
379
+ # if not args.split_mode:
380
+ # normal forward
381
+ with torch.set_grad_enabled(is_train), accelerator.autocast():
382
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
383
+ model_pred = unet(
384
+ img=img,
385
+ img_ids=img_ids,
386
+ txt=t5_out,
387
+ txt_ids=txt_ids,
388
+ y=l_pooled,
389
+ timesteps=timesteps / 1000,
390
+ guidance=guidance_vec,
391
+ txt_attention_mask=t5_attn_mask,
392
+ )
393
+ """
394
+ else:
395
+ # split forward to reduce memory usage
396
+ assert network.train_blocks == "single", "train_blocks must be single for split mode"
397
+ with accelerator.autocast():
398
+ # move flux lower to cpu, and then move flux upper to gpu
399
+ unet.to("cpu")
400
+ clean_memory_on_device(accelerator.device)
401
+ self.flux_upper.to(accelerator.device)
402
+
403
+ # upper model does not require grad
404
+ with torch.no_grad():
405
+ intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
406
+ img=packed_noisy_model_input,
407
+ img_ids=img_ids,
408
+ txt=t5_out,
409
+ txt_ids=txt_ids,
410
+ y=l_pooled,
411
+ timesteps=timesteps / 1000,
412
+ guidance=guidance_vec,
413
+ txt_attention_mask=t5_attn_mask,
414
+ )
415
+
416
+ # move flux upper back to cpu, and then move flux lower to gpu
417
+ self.flux_upper.to("cpu")
418
+ clean_memory_on_device(accelerator.device)
419
+ unet.to(accelerator.device)
420
+
421
+ # lower model requires grad
422
+ intermediate_img.requires_grad_(True)
423
+ intermediate_txt.requires_grad_(True)
424
+ vec.requires_grad_(True)
425
+ pe.requires_grad_(True)
426
+
427
+ with torch.set_grad_enabled(is_train and train_unet):
428
+ model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
429
+ """
430
+
431
+ return model_pred
432
+
433
+ model_pred = call_dit(
434
+ img=packed_noisy_model_input,
435
+ img_ids=img_ids,
436
+ t5_out=t5_out,
437
+ txt_ids=txt_ids,
438
+ l_pooled=l_pooled,
439
+ timesteps=timesteps,
440
+ guidance_vec=guidance_vec,
441
+ t5_attn_mask=t5_attn_mask,
442
+ )
443
+
444
+ # unpack latents
445
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
446
+
447
+ # apply model prediction type
448
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
449
+
450
+ # flow matching loss: this is different from SD3
451
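+ # with x_t = (1 - t) * latents + t * noise, the flow-matching velocity dx_t/dt equals
+ # noise - latents, which is what the model is trained to predict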
+ target = noise - latents
452
+
453
+ # differential output preservation
454
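+ # (samples flagged with "diff_output_preservation" are regressed towards the base model's
+ # own prediction, computed below with the LoRA multiplier temporarily set to 0)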
+ if "custom_attributes" in batch:
455
+ diff_output_pr_indices = []
456
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
457
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
458
+ diff_output_pr_indices.append(i)
459
+
460
+ if len(diff_output_pr_indices) > 0:
461
+ network.set_multiplier(0.0)
462
+ unet.prepare_block_swap_before_forward()
463
+ with torch.no_grad():
464
+ model_pred_prior = call_dit(
465
+ img=packed_noisy_model_input[diff_output_pr_indices],
466
+ img_ids=img_ids[diff_output_pr_indices],
467
+ t5_out=t5_out[diff_output_pr_indices],
468
+ txt_ids=txt_ids[diff_output_pr_indices],
469
+ l_pooled=l_pooled[diff_output_pr_indices],
470
+ timesteps=timesteps[diff_output_pr_indices],
471
+ guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
472
+ t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
473
+ )
474
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
475
+
476
+ model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
477
+ model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
478
+ args,
479
+ model_pred_prior,
480
+ noisy_model_input[diff_output_pr_indices],
481
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
482
+ )
483
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
484
+
485
+ return model_pred, target, timesteps, weighting
486
+
487
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
488
+ return loss
489
+
490
+ def get_sai_model_spec(self, args):
491
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
492
+
493
+ def update_metadata(self, metadata, args):
494
+ metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
495
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
496
+ metadata["ss_logit_mean"] = args.logit_mean
497
+ metadata["ss_logit_std"] = args.logit_std
498
+ metadata["ss_mode_scale"] = args.mode_scale
499
+ metadata["ss_guidance_scale"] = args.guidance_scale
500
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
501
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
502
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
503
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
504
+
505
+ def is_text_encoder_not_needed_for_training(self, args):
506
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
507
+
508
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
509
+ if index == 0: # CLIP-L
510
+ return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
511
+ else: # T5XXL
512
+ text_encoder.encoder.embed_tokens.requires_grad_(True)
513
+
514
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
515
+ if index == 0: # CLIP-L
516
+ logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
517
+ text_encoder.to(te_weight_dtype) # fp8
518
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
519
+ else: # T5XXL
520
+
521
+ def prepare_fp8(text_encoder, target_dtype):
522
+ def forward_hook(module):
523
+ def forward(hidden_states):
524
+ hidden_gelu = module.act(module.wi_0(hidden_states))
525
+ hidden_linear = module.wi_1(hidden_states)
526
+ hidden_states = hidden_gelu * hidden_linear
527
+ hidden_states = module.dropout(hidden_states)
528
+
529
+ hidden_states = module.wo(hidden_states)
530
+ return hidden_states
531
+
532
+ return forward
533
+
534
+ for module in text_encoder.modules():
535
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
536
+ # print("set", module.__class__.__name__, "to", target_dtype)
537
+ module.to(target_dtype)
538
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
539
+ # print("set", module.__class__.__name__, "hooks")
540
+ module.forward = forward_hook(module)
541
+
542
+ if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
543
+ logger.info(f"T5XXL already prepared for fp8")
544
+ else:
545
+ logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
546
+ text_encoder.to(te_weight_dtype) # fp8
547
+ prepare_fp8(text_encoder, weight_dtype)
548
+
549
+ def prepare_unet_with_accelerator(
550
+ self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
551
+ ) -> torch.nn.Module:
552
+ if not self.is_swapping_blocks:
553
+ return super().prepare_unet_with_accelerator(args, accelerator, unet)
554
+
555
+ # when swapping blocks, skip automatic device placement and move only the non-swapped parts to the device
556
+ flux: flux_models.Flux = unet
557
+ flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
558
+ accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
559
+ accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
560
+
561
+ return flux
562
+
563
+
564
+ def setup_parser() -> argparse.ArgumentParser:
565
+ parser = train_network.setup_parser()
566
+ train_util.add_dit_training_arguments(parser)
567
+ flux_train_utils.add_flux_train_arguments(parser)
568
+
569
+ parser.add_argument(
570
+ "--split_mode",
571
+ action="store_true",
572
+ # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
573
+ # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
574
+ help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
575
+ " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
576
+ )
577
+ return parser
578
+
579
+
580
+ if __name__ == "__main__":
581
+ parser = setup_parser()
582
+
583
+ args = parser.parse_args()
584
+ train_util.verify_command_line_training_args(args)
585
+ args = train_util.read_config_from_file(args, parser)
586
+
587
+ trainer = FluxNetworkTrainer()
588
+ trainer.train(args)
flux_train_network_asylora.py ADDED
@@ -0,0 +1,591 @@
1
+ # coding=utf-8
2
+
3
+
4
+ import argparse
5
+ import copy
6
+ import math
7
+ import random
8
+ from typing import Any, Optional
9
+
10
+ import torch
11
+ from accelerate import Accelerator
12
+ from library.device_utils import init_ipex, clean_memory_on_device
13
+
14
+ init_ipex()
15
+
16
+ from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
17
+ import train_network_asylora
18
+ from library.utils import setup_logging
19
+
20
+ setup_logging()
21
+ import logging
22
+ import re
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class FluxNetworkTrainer(train_network_asylora.NetworkTrainer):
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.sample_prompts_te_outputs = None
31
+ self.is_schnell: Optional[bool] = None
32
+ self.is_swapping_blocks: bool = False
33
+
34
+ def assert_extra_args(self, args, train_dataset_group):
35
+ super().assert_extra_args(args, train_dataset_group)
36
+ # sdxl_train_util.verify_sdxl_training_args(args)
37
+
38
+ if args.fp8_base_unet:
39
+ args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
40
+
41
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
42
+ logger.warning(
43
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
44
+ )
45
+ args.cache_text_encoder_outputs = True
46
+
47
+ if args.cache_text_encoder_outputs:
48
+ assert (
49
+ train_dataset_group.is_text_encoder_output_cacheable()
50
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
51
+
52
+ # prepare CLIP-L/T5XXL training flags
53
+ self.train_clip_l = not args.network_train_unet_only
54
+ self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
55
+
56
+ if args.max_token_length is not None:
57
+ logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
58
+
59
+ assert (
60
+ args.blocks_to_swap is None or args.blocks_to_swap == 0
61
+ ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
62
+
63
+ # deprecated split_mode option
64
+ if args.split_mode:
65
+ if args.blocks_to_swap is not None:
66
+ logger.warning(
67
+ "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
68
+ " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
69
+ )
70
+ else:
71
+ logger.warning(
72
+ "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
73
+ " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
74
+ )
75
+ args.blocks_to_swap = 18 # 18 is safe for most cases
76
+
77
+ train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
78
+
79
+ def load_target_model(self, args, weight_dtype, accelerator):
80
+ # currently offload to cpu for some models
81
+
82
+ # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
83
+ loading_dtype = None if args.fp8_base else weight_dtype
84
+
85
+ # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
86
+ self.is_schnell, model = flux_utils.load_flow_model(
87
+ args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
88
+ )
89
+ if args.fp8_base:
90
+ # check dtype of model
91
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
92
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
93
+ elif model.dtype == torch.float8_e4m3fn:
94
+ logger.info("Loaded fp8 FLUX model")
95
+ else:
96
+ logger.info(
97
+ "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
98
+ " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
99
+ )
100
+ model.to(torch.float8_e4m3fn)
101
+
102
+ # if args.split_mode:
103
+ # model = self.prepare_split_model(model, weight_dtype, accelerator)
104
+
105
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
106
+ if self.is_swapping_blocks:
107
+ # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
108
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
109
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
110
+
111
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
112
+ clip_l.eval()
113
+
114
+ # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
115
+ if args.fp8_base and not args.fp8_base_unet:
116
+ loading_dtype = None # as is
117
+ else:
118
+ loading_dtype = weight_dtype
119
+
120
+ # loading t5xxl to cpu takes a long time, so we should load to gpu in future
121
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
122
+ t5xxl.eval()
123
+ if args.fp8_base and not args.fp8_base_unet:
124
+ # check dtype of model
125
+ if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
126
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
127
+ elif t5xxl.dtype == torch.float8_e4m3fn:
128
+ logger.info("Loaded fp8 T5XXL model")
129
+
130
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
131
+
132
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
133
+
134
+ def get_tokenize_strategy(self, args):
135
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
136
+
137
+ if args.t5xxl_max_token_length is None:
138
+ if is_schnell:
139
+ t5xxl_max_token_length = 256
140
+ else:
141
+ t5xxl_max_token_length = 512
142
+ else:
143
+ t5xxl_max_token_length = args.t5xxl_max_token_length
144
+
145
+ logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
146
+ return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
147
+
148
+ def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
149
+ return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
150
+
151
+ def get_latents_caching_strategy(self, args):
152
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
153
+ return latents_caching_strategy
154
+
155
+ def get_text_encoding_strategy(self, args):
156
+ return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
157
+
158
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
159
+ # check t5xxl is trained or not
160
+ self.train_t5xxl = network.train_t5xxl
161
+
162
+ if self.train_t5xxl and args.cache_text_encoder_outputs:
163
+ raise ValueError(
164
+ "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
165
+ )
166
+
167
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
168
+ if args.cache_text_encoder_outputs:
169
+ if self.train_clip_l and not self.train_t5xxl:
170
+ return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
171
+ else:
172
+ return None # no text encoders are needed for encoding because both are cached
173
+ else:
174
+ return text_encoders # both CLIP-L and T5XXL are needed for encoding
175
+
176
+ def get_text_encoders_train_flags(self, args, text_encoders):
177
+ return [self.train_clip_l, self.train_t5xxl]
178
+
179
+ def get_text_encoder_outputs_caching_strategy(self, args):
180
+ if args.cache_text_encoder_outputs:
181
+ # if a text encoder is trained, tokenization is still needed, so is_partial is True
182
+ return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
183
+ args.cache_text_encoder_outputs_to_disk,
184
+ args.text_encoder_batch_size,
185
+ args.skip_cache_check,
186
+ is_partial=self.train_clip_l or self.train_t5xxl,
187
+ apply_t5_attn_mask=args.apply_t5_attn_mask,
188
+ )
189
+ else:
190
+ return None
191
+
192
+ def cache_text_encoder_outputs_if_needed(
193
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
194
+ ):
195
+ if args.cache_text_encoder_outputs:
196
+ if not args.lowram:
197
+ # reduce memory consumption
198
+ logger.info("move vae and unet to cpu to save memory")
199
+ org_vae_device = vae.device
200
+ org_unet_device = unet.device
201
+ vae.to("cpu")
202
+ unet.to("cpu")
203
+ clean_memory_on_device(accelerator.device)
204
+
205
+ # When the text encoders are not trained, they are not prepared by the accelerator, so we need an explicit autocast
206
+ logger.info("move text encoders to gpu")
207
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
208
+ text_encoders[1].to(accelerator.device)
209
+
210
+ if text_encoders[1].dtype == torch.float8_e4m3fn:
211
+ # if we load fp8 weights, the model is already fp8, so we use it as is
212
+ self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
213
+ else:
214
+ # otherwise, we need to convert it to target dtype
215
+ text_encoders[1].to(weight_dtype)
216
+
217
+ with accelerator.autocast():
218
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
219
+
220
+ # cache sample prompts
221
+ if args.sample_prompts is not None:
222
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
223
+
224
+ tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
225
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
226
+
227
+ prompts = train_util.load_prompts(args.sample_prompts)
228
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
229
+ with accelerator.autocast(), torch.no_grad():
230
+ for prompt_dict in prompts:
231
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
232
+ if p not in sample_prompts_te_outputs:
233
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
234
+ tokens_and_masks = tokenize_strategy.tokenize(p)
235
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
236
+ tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
237
+ )
238
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
239
+
240
+ accelerator.wait_for_everyone()
241
+
242
+ # move back to cpu
243
+ if not self.is_train_text_encoder(args):
244
+ logger.info("move CLIP-L back to cpu")
245
+ text_encoders[0].to("cpu")
246
+ logger.info("move t5XXL back to cpu")
247
+ text_encoders[1].to("cpu")
248
+ clean_memory_on_device(accelerator.device)
249
+
250
+ if not args.lowram:
251
+ logger.info("move vae and unet back to original device")
252
+ vae.to(org_vae_device)
253
+ unet.to(org_unet_device)
254
+ else:
255
+ # outputs are taken from the Text Encoders every step, so keep them on the GPU
256
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
257
+ text_encoders[1].to(accelerator.device)
258
+
259
+ # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
260
+ # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
261
+
262
+ # # get size embeddings
263
+ # orig_size = batch["original_sizes_hw"]
264
+ # crop_size = batch["crop_top_lefts"]
265
+ # target_size = batch["target_sizes_hw"]
266
+ # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
267
+
268
+ # # concat embeddings
269
+ # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
270
+ # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
271
+ # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
272
+
273
+ # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
274
+ # return noise_pred
275
+
276
+ def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
277
+ text_encoders = text_encoder # for compatibility
278
+ text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
279
+
280
+ flux_train_utils.sample_images(
281
+ accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
282
+ )
283
+ # return
284
+
285
+ """
286
+ class FluxUpperLowerWrapper(torch.nn.Module):
287
+ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
288
+ super().__init__()
289
+ self.flux_upper = flux_upper
290
+ self.flux_lower = flux_lower
291
+ self.target_device = device
292
+
293
+ def prepare_block_swap_before_forward(self):
294
+ pass
295
+
296
+ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
297
+ self.flux_lower.to("cpu")
298
+ clean_memory_on_device(self.target_device)
299
+ self.flux_upper.to(self.target_device)
300
+ img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
301
+ self.flux_upper.to("cpu")
302
+ clean_memory_on_device(self.target_device)
303
+ self.flux_lower.to(self.target_device)
304
+ return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
305
+
306
+ wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
307
+ clean_memory_on_device(accelerator.device)
308
+ flux_train_utils.sample_images(
309
+ accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
310
+ )
311
+ clean_memory_on_device(accelerator.device)
312
+ """
313
+
314
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
315
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
316
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
317
+ return noise_scheduler
318
+
319
+ def encode_images_to_latents(self, args, accelerator, vae, images):
320
+ return vae.encode(images)
321
+
322
+ def shift_scale_latents(self, args, latents):
323
+ return latents
324
+
325
+ def get_noise_pred_and_target(
326
+ self,
327
+ args,
328
+ accelerator,
329
+ noise_scheduler,
330
+ latents,
331
+ batch,
332
+ text_encoder_conds,
333
+ unet: flux_models.Flux,
334
+ network,
335
+ weight_dtype,
336
+ train_unet,
337
+ ):
338
+ # Sample noise that we'll add to the latents
339
+ noise = torch.randn_like(latents)
340
+ bsz = latents.shape[0]
341
+
342
+ # get noisy model input and timesteps
343
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
344
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
345
+ )
346
+
347
+ # pack latents and get img_ids
348
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
349
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
350
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
351
+
352
+ # get guidance
353
+ # ensure guidance_scale in args is float
354
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
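+ # note: FLUX.1-dev is guidance-distilled, so the guidance scale is fed to the model as an embedded scalar here rather than applied via classifier-free guidance at sampling time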
355
+
356
+ # ensure the hidden state will require grad
357
+ if args.gradient_checkpointing:
358
+ noisy_model_input.requires_grad_(True)
359
+ for t in text_encoder_conds:
360
+ if t is not None and t.dtype.is_floating_point:
361
+ t.requires_grad_(True)
362
+ img_ids.requires_grad_(True)
363
+ guidance_vec.requires_grad_(True)
364
+
365
+ # Predict the noise residual
366
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
367
+ if not args.apply_t5_attn_mask:
368
+ t5_attn_mask = None
369
+
370
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
371
+ # if not args.split_mode:
372
+ # normal forward
373
+ with accelerator.autocast():
374
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
375
+ model_pred = unet(
376
+ img=img,
377
+ img_ids=img_ids,
378
+ txt=t5_out,
379
+ txt_ids=txt_ids,
380
+ y=l_pooled,
381
+ timesteps=timesteps / 1000,
382
+ guidance=guidance_vec,
383
+ txt_attention_mask=t5_attn_mask
384
+ )
385
+ """
386
+ else:
387
+ # split forward to reduce memory usage
388
+ assert network.train_blocks == "single", "train_blocks must be single for split mode"
389
+ with accelerator.autocast():
390
+ # move flux lower to cpu, and then move flux upper to gpu
391
+ unet.to("cpu")
392
+ clean_memory_on_device(accelerator.device)
393
+ self.flux_upper.to(accelerator.device)
394
+
395
+ # upper model does not require grad
396
+ with torch.no_grad():
397
+ intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
398
+ img=packed_noisy_model_input,
399
+ img_ids=img_ids,
400
+ txt=t5_out,
401
+ txt_ids=txt_ids,
402
+ y=l_pooled,
403
+ timesteps=timesteps / 1000,
404
+ guidance=guidance_vec,
405
+ txt_attention_mask=t5_attn_mask,
406
+ )
407
+
408
+ # move flux upper back to cpu, and then move flux lower to gpu
409
+ self.flux_upper.to("cpu")
410
+ clean_memory_on_device(accelerator.device)
411
+ unet.to(accelerator.device)
412
+
413
+ # lower model requires grad
414
+ intermediate_img.requires_grad_(True)
415
+ intermediate_txt.requires_grad_(True)
416
+ vec.requires_grad_(True)
417
+ pe.requires_grad_(True)
418
+ model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
419
+ """
420
+
421
+ return model_pred
422
+
423
+ # get the dataset category id from the caption text
424
+ # lora_category = batch["captions"][0].split(",")[0][3:]
425
+ # assert lora_category.isdigit(), f"lora_category is not an integer, got: {lora_category}, {batch['captions'][0]}"
426
+ # lora_category = int(lora_category)
427
+
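+ # select the asymmetric-LoRA "up" branch for this batch: the caption carries a "--lora_up_cur N" tag (1-indexed), and set_lora_up_cur below takes a 0-indexed id, hence the -1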
428
+ prompt_cur = batch["captions"][0]
429
+ match = re.search(r'--lora_up_cur (\d+)', prompt_cur)
430
+ assert match, "Pattern '--lora_up_cur' not found"
431
+ lora_category = int(match.group(1))
432
+
433
+ for lora in network.unet_loras:
434
+ lora.set_lora_up_cur(lora_category-1)
435
+
436
+ model_pred = call_dit(
437
+ img=packed_noisy_model_input,
438
+ img_ids=img_ids,
439
+ t5_out=t5_out,
440
+ txt_ids=txt_ids,
441
+ l_pooled=l_pooled,
442
+ timesteps=timesteps,
443
+ guidance_vec=guidance_vec,
444
+ t5_attn_mask=t5_attn_mask
445
+ )
446
+
447
+ # unpack latents
448
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
449
+
450
+ # apply model prediction type
451
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
452
+
453
+ # flow matching loss: this is different from SD3
454
+ target = noise - latents
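+ # i.e. the regression target is the constant velocity pointing from the clean latents to the noise, as in rectified-flow / flow-matching training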
455
+
456
+ # differential output preservation
457
+ if "custom_attributes" in batch:
458
+ diff_output_pr_indices = []
459
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
460
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
461
+ diff_output_pr_indices.append(i)
462
+
463
+ if len(diff_output_pr_indices) > 0:
464
+ network.set_multiplier(0.0)
465
+ unet.prepare_block_swap_before_forward()
466
+ with torch.no_grad():
467
+ model_pred_prior = call_dit(
468
+ img=packed_noisy_model_input[diff_output_pr_indices],
469
+ img_ids=img_ids[diff_output_pr_indices],
470
+ t5_out=t5_out[diff_output_pr_indices],
471
+ txt_ids=txt_ids[diff_output_pr_indices],
472
+ l_pooled=l_pooled[diff_output_pr_indices],
473
+ timesteps=timesteps[diff_output_pr_indices],
474
+ guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
475
+ t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
476
+ )
477
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
478
+
479
+ model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
480
+ model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
481
+ args,
482
+ model_pred_prior,
483
+ noisy_model_input[diff_output_pr_indices],
484
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
485
+ )
486
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
487
+
488
+ return model_pred, target, timesteps, None, weighting
489
+
490
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
491
+ return loss
492
+
493
+ def get_sai_model_spec(self, args):
494
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
495
+
496
+ def update_metadata(self, metadata, args):
497
+ metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
498
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
499
+ metadata["ss_logit_mean"] = args.logit_mean
500
+ metadata["ss_logit_std"] = args.logit_std
501
+ metadata["ss_mode_scale"] = args.mode_scale
502
+ metadata["ss_guidance_scale"] = args.guidance_scale
503
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
504
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
505
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
506
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
507
+
508
+ def is_text_encoder_not_needed_for_training(self, args):
509
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
510
+
511
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
512
+ if index == 0: # CLIP-L
513
+ return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
514
+ else: # T5XXL
515
+ text_encoder.encoder.embed_tokens.requires_grad_(True)
516
+
517
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
518
+ if index == 0: # CLIP-L
519
+ logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
520
+ text_encoder.to(te_weight_dtype) # fp8
521
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
522
+ else: # T5XXL
523
+
524
+ def prepare_fp8(text_encoder, target_dtype):
525
+ def forward_hook(module):
526
+ def forward(hidden_states):
527
+ hidden_gelu = module.act(module.wi_0(hidden_states))
528
+ hidden_linear = module.wi_1(hidden_states)
529
+ hidden_states = hidden_gelu * hidden_linear
530
+ hidden_states = module.dropout(hidden_states)
531
+
532
+ hidden_states = module.wo(hidden_states)
533
+ return hidden_states
534
+
535
+ return forward
536
+
537
+ for module in text_encoder.modules():
538
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
539
+ # print("set", module.__class__.__name__, "to", target_dtype)
540
+ module.to(target_dtype)
541
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
542
+ # print("set", module.__class__.__name__, "hooks")
543
+ module.forward = forward_hook(module)
544
+
545
+ if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
546
+ logger.info(f"T5XXL already prepared for fp8")
547
+ else:
548
+ logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
549
+ text_encoder.to(te_weight_dtype) # fp8
550
+ prepare_fp8(text_encoder, weight_dtype)
551
+
552
+ def prepare_unet_with_accelerator(
553
+ self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
554
+ ) -> torch.nn.Module:
555
+ if not self.is_swapping_blocks:
556
+ return super().prepare_unet_with_accelerator(args, accelerator, unet)
557
+
558
+ # we are swapping blocks here, so prepare the model without automatic device placement and only move the non-swapped blocks to the device
559
+ flux: flux_models.Flux = unet
560
+ flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
561
+ accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
562
+ accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
563
+
564
+ return flux
565
+
566
+
567
+ def setup_parser() -> argparse.ArgumentParser:
568
+ parser = train_network_asylora.setup_parser()
569
+ train_util.add_dit_training_arguments(parser)
570
+ flux_train_utils.add_flux_train_arguments(parser)
571
+
572
+ parser.add_argument(
573
+ "--split_mode",
574
+ action="store_true",
575
+ # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
576
+ # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
577
+ help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
578
+ " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
579
+ )
580
+ return parser
581
+
582
+
583
+ if __name__ == "__main__":
584
+ parser = setup_parser()
585
+
586
+ args = parser.parse_args()
587
+ train_util.verify_command_line_training_args(args)
588
+ args = train_util.read_config_from_file(args, parser)
589
+
590
+ trainer = FluxNetworkTrainer()
591
+ trainer.train(args)
flux_train_recraft.py ADDED
@@ -0,0 +1,713 @@
1
+ import argparse
2
+ import copy
3
+ import math
4
+ import random
5
+ from typing import Any, Optional
6
+ import pdb
7
+
8
+ import torch
9
+ from accelerate import Accelerator
10
+ from library.device_utils import init_ipex, clean_memory_on_device
11
+
12
+ init_ipex()
13
+
14
+ from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
15
+ from torchvision import transforms
16
+ import train_network
17
+ from library.utils import setup_logging
18
+ from diffusers.utils import load_image
19
+ import numpy as np
20
+ from PIL import Image, ImageOps
21
+
22
+ setup_logging()
23
+ import logging
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # NUM_SPLIT = 2
28
+
29
+ class ResizeWithPadding:
30
+ def __init__(self, size, fill=255):
31
+ self.size = size
32
+ self.fill = fill
33
+
34
+ def __call__(self, img):
35
+ if isinstance(img, np.ndarray):
36
+ img = Image.fromarray(img)
37
+ elif not isinstance(img, Image.Image):
38
+ raise TypeError("Input must be a PIL Image or a NumPy array")
39
+
40
+ width, height = img.size
41
+
42
+ if width == height:
43
+ img = img.resize((self.size, self.size), Image.LANCZOS)
44
+ else:
45
+ max_dim = max(width, height)
46
+
47
+ new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
48
+ new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
49
+
50
+ img = new_img.resize((self.size, self.size), Image.LANCZOS)
51
+
52
+ return img
53
+
54
+ class FluxNetworkTrainer(train_network.NetworkTrainer):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.sample_prompts_te_outputs = None
58
+ self.sample_conditions = None
59
+ self.is_schnell: Optional[bool] = None
60
+
61
+ def assert_extra_args(self, args, train_dataset_group):
62
+ super().assert_extra_args(args, train_dataset_group)
63
+ # sdxl_train_util.verify_sdxl_training_args(args)
64
+
65
+ if args.fp8_base_unet:
66
+ args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
67
+
68
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
69
+ logger.warning(
70
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
71
+ )
72
+ args.cache_text_encoder_outputs = True
73
+
74
+ if args.cache_text_encoder_outputs:
75
+ assert (
76
+ train_dataset_group.is_text_encoder_output_cacheable()
77
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
78
+
79
+ # prepare CLIP-L/T5XXL training flags
80
+ self.train_clip_l = not args.network_train_unet_only
81
+ self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
82
+
83
+ if args.max_token_length is not None:
84
+ logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
85
+
86
+ assert not args.split_mode or not args.cpu_offload_checkpointing, (
87
+ "split_mode and cpu_offload_checkpointing cannot be used together"
88
+ " / split_modeとcpu_offload_checkpointingは同時に使用できません"
89
+ )
90
+
91
+ train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
92
+
93
+ def load_target_model(self, args, weight_dtype, accelerator):
94
+ # currently offload to cpu for some models
95
+
96
+ # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
97
+ loading_dtype = None if args.fp8_base else weight_dtype
98
+
99
+ # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
100
+ self.is_schnell, model = flux_utils.load_flow_model(
101
+ args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
102
+ )
103
+ if args.fp8_base:
104
+ # check dtype of model
105
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
106
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
107
+ elif model.dtype == torch.float8_e4m3fn:
108
+ logger.info("Loaded fp8 FLUX model")
109
+
110
+ if args.split_mode:
111
+ model = self.prepare_split_model(model, weight_dtype, accelerator)
112
+
113
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
114
+ clip_l.eval()
115
+
116
+ # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
117
+ if args.fp8_base and not args.fp8_base_unet:
118
+ loading_dtype = None # as is
119
+ else:
120
+ loading_dtype = weight_dtype
121
+
122
+ # loading t5xxl to cpu takes a long time, so we should load to gpu in future
123
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
124
+ t5xxl.eval()
125
+ if args.fp8_base and not args.fp8_base_unet:
126
+ # check dtype of model
127
+ if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
128
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
129
+ elif t5xxl.dtype == torch.float8_e4m3fn:
130
+ logger.info("Loaded fp8 T5XXL model")
131
+
132
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
133
+
134
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
135
+
136
+ def prepare_split_model(self, model, weight_dtype, accelerator):
137
+ from accelerate import init_empty_weights
138
+
139
+ logger.info("prepare split model")
140
+ with init_empty_weights():
141
+ flux_upper = flux_models.FluxUpper(model.params)
142
+ flux_lower = flux_models.FluxLower(model.params)
143
+ sd = model.state_dict()
144
+
145
+ # lower (trainable)
146
+ logger.info("load state dict for lower")
147
+ flux_lower.load_state_dict(sd, strict=False, assign=True)
148
+ flux_lower.to(dtype=weight_dtype)
149
+
150
+ # upper (frozen)
151
+ logger.info("load state dict for upper")
152
+ flux_upper.load_state_dict(sd, strict=False, assign=True)
153
+
154
+ logger.info("prepare upper model")
155
+ target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
156
+ flux_upper.to(accelerator.device, dtype=target_dtype)
157
+ flux_upper.eval()
158
+
159
+ if args.fp8_base:
160
+ # this is required to run on fp8
161
+ flux_upper = accelerator.prepare(flux_upper)
162
+
163
+ flux_upper.to("cpu")
164
+
165
+ self.flux_upper = flux_upper
166
+ del model # we don't need model anymore
167
+ clean_memory_on_device(accelerator.device)
168
+
169
+ logger.info("split model prepared")
170
+
171
+ return flux_lower
172
+
173
+ def get_tokenize_strategy(self, args):
174
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
175
+
176
+ if args.t5xxl_max_token_length is None:
177
+ if is_schnell:
178
+ t5xxl_max_token_length = 256
179
+ else:
180
+ t5xxl_max_token_length = 512
181
+ else:
182
+ t5xxl_max_token_length = args.t5xxl_max_token_length
183
+
184
+ logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
185
+ return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
186
+
187
+ def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
188
+ return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
189
+
190
+ def get_latents_caching_strategy(self, args):
191
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
192
+ return latents_caching_strategy
193
+
194
+ def get_text_encoding_strategy(self, args):
195
+ return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
196
+
197
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
198
+ # check t5xxl is trained or not
199
+ self.train_t5xxl = network.train_t5xxl
200
+
201
+ if self.train_t5xxl and args.cache_text_encoder_outputs:
202
+ raise ValueError(
203
+ "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
204
+ )
205
+
206
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
207
+ if args.cache_text_encoder_outputs:
208
+ if self.train_clip_l and not self.train_t5xxl:
209
+ return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
210
+ else:
211
+ return None # no text encoders are needed for encoding because both are cached
212
+ else:
213
+ return text_encoders # both CLIP-L and T5XXL are needed for encoding
214
+
215
+ def get_text_encoders_train_flags(self, args, text_encoders):
216
+ return [self.train_clip_l, self.train_t5xxl]
217
+
218
+ def get_text_encoder_outputs_caching_strategy(self, args):
219
+ if args.cache_text_encoder_outputs:
220
+ # if the text encoders are trained, we still need tokenization, so is_partial is True
221
+ return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
222
+ args.cache_text_encoder_outputs_to_disk,
223
+ args.text_encoder_batch_size,
224
+ args.skip_cache_check,
225
+ is_partial=self.train_clip_l or self.train_t5xxl,
226
+ apply_t5_attn_mask=args.apply_t5_attn_mask,
227
+ )
228
+ else:
229
+ return None
230
+
231
+ def cache_text_encoder_outputs_if_needed(
232
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
233
+ ):
234
+ if args.cache_text_encoder_outputs:
235
+ if not args.lowram:
236
+ # reduce memory consumption
237
+ logger.info("move vae and unet to cpu to save memory")
238
+ org_vae_device = vae.device
239
+ org_unet_device = unet.device
240
+ vae.to("cpu")
241
+ unet.to("cpu")
242
+ clean_memory_on_device(accelerator.device)
243
+
244
+ # When the TE is not trained, it is not prepared by the accelerator, so we need an explicit autocast
245
+ logger.info("move text encoders to gpu")
246
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
247
+ text_encoders[1].to(accelerator.device)
248
+
249
+ if text_encoders[1].dtype == torch.float8_e4m3fn:
250
+ # if we load fp8 weights, the model is already fp8, so we use it as is
251
+ self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
252
+ else:
253
+ # otherwise, we need to convert it to target dtype
254
+ text_encoders[1].to(weight_dtype)
255
+
256
+ with accelerator.autocast():
257
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
258
+
259
+ # cache sample prompts
260
+ if args.sample_prompts is not None:
261
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
262
+
263
+ tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
264
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
265
+
266
+ prompts = train_util.load_prompts(args.sample_prompts)
267
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
268
+ with accelerator.autocast(), torch.no_grad():
269
+ for prompt_dict in prompts:
270
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
271
+ if p not in sample_prompts_te_outputs:
272
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
273
+ tokens_and_masks = tokenize_strategy.tokenize(p)
274
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
275
+ tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
276
+ )
277
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
278
+
279
+ # additionally cache the conditions (VAE latents of the sample condition images)
280
+ if args.sample_images is not None:
281
+ logger.info(f"cache conditions for sample images: {args.sample_images}")
282
+
283
+ # lc03lc
284
+ resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
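+ # 512 px per cell for the 2x2 (4-frame) grid and 352 px for the 3x3 (9-frame) grid, presumably chosen so the stitched grid matches the 1024 / 1056 px canvas used at inference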
285
+ img_transforms = transforms.Compose([
286
+ resize_transform,
287
+ transforms.ToTensor(),
288
+ transforms.Normalize([0.5], [0.5]),
289
+ ])
290
+
291
+ if args.sample_images.endswith(".txt"):
292
+ with open(args.sample_images, "r", encoding="utf-8") as f:
293
+ lines = f.readlines()
294
+ sample_images = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
295
+ else:
296
+ raise NotImplementedError(f"sample_images file format not supported: {args.sample_images}")
297
+
298
+ prompts = train_util.load_prompts(args.sample_prompts)
299
+ conditions = {} # key: prompt, value: latents
300
+
301
+ with torch.no_grad():
302
+ for image, prompt_dict in zip(sample_images, prompts):
303
+ prompt = prompt_dict.get("prompt", "")
304
+ if prompt not in conditions:
305
+ logger.info(f"cache conditions for image: {image} with prompt: {prompt}")
306
+ image = img_transforms(np.array(load_image(image), dtype=np.uint8)).unsqueeze(0).to(vae.device, dtype=vae.dtype)
307
+ latents = self.encode_images_to_latents2(args, accelerator, vae, image)
308
+ # lc03lc
309
+ conditions[prompt] = latents
310
+ # if args.frame_num == 4:
311
+ # conditions[prompt] = latents[:,:,2*latents.shape[2]//3:latents.shape[2], 2*latents.shape[3]//3:latents.shape[3]].to("cpu")
312
+ # else:
313
+ # conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
314
+
315
+ self.sample_conditions = conditions
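+ # these cached condition latents are handed to sample_images() below as the conditions dict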
316
+
317
+ accelerator.wait_for_everyone()
318
+
319
+ # move back to cpu
320
+ if not self.is_train_text_encoder(args):
321
+ logger.info("move CLIP-L back to cpu")
322
+ text_encoders[0].to("cpu")
323
+ logger.info("move t5XXL back to cpu")
324
+ text_encoders[1].to("cpu")
325
+ clean_memory_on_device(accelerator.device)
326
+
327
+ if not args.lowram:
328
+ logger.info("move vae and unet back to original device")
329
+ vae.to(org_vae_device)
330
+ unet.to(org_unet_device)
331
+ else:
332
+ # outputs are taken from the Text Encoders every step, so keep them on the GPU
333
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
334
+ text_encoders[1].to(accelerator.device)
335
+
336
+ # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
337
+ # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
338
+
339
+ # # get size embeddings
340
+ # orig_size = batch["original_sizes_hw"]
341
+ # crop_size = batch["crop_top_lefts"]
342
+ # target_size = batch["target_sizes_hw"]
343
+ # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
344
+
345
+ # # concat embeddings
346
+ # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
347
+ # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
348
+ # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
349
+
350
+ # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
351
+ # return noise_pred
352
+
353
+ def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
354
+ text_encoders = text_encoder # for compatibility
355
+ text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
356
+ # use the precomputed conditions directly
357
+ conditions = None
358
+ if self.sample_conditions is not None:
359
+ conditions = {k: v.to(accelerator.device) for k, v in self.sample_conditions.items()}
360
+
361
+ if not args.split_mode:
362
+ flux_train_utils.sample_images(
363
+ accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs, None, conditions
364
+ )
365
+ return
366
+
367
+ class FluxUpperLowerWrapper(torch.nn.Module):
368
+ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
369
+ super().__init__()
370
+ self.flux_upper = flux_upper
371
+ self.flux_lower = flux_lower
372
+ self.target_device = device
373
+
374
+ def prepare_block_swap_before_forward(self):
375
+ pass
376
+
377
+ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
378
+ self.flux_lower.to("cpu")
379
+ clean_memory_on_device(self.target_device)
380
+ self.flux_upper.to(self.target_device)
381
+ img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
382
+ self.flux_upper.to("cpu")
383
+ clean_memory_on_device(self.target_device)
384
+ self.flux_lower.to(self.target_device)
385
+ return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
386
+
387
+ wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
388
+ clean_memory_on_device(accelerator.device)
389
+ flux_train_utils.sample_images(
390
+ accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs, conditions
391
+ )
392
+ clean_memory_on_device(accelerator.device)
393
+
394
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
395
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
396
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
397
+ return noise_scheduler
398
+
399
+ def encode_images_to_latents(self, args, accelerator, vae, images):
400
+ # get the image dimensions
401
+ b, c, h, w = images.shape
402
+
403
+ # num_split = NUM_SPLIT
404
+ num_split = 2 if args.frame_num == 4 else 3
405
+ # split the image into num_split parts along the width
406
+ img_parts = [images[:,:,:,i*w//num_split:(i+1)*w//num_split] for i in range(num_split)]
407
+ # encode each part separately
408
+ latents = [vae.encode(img) for img in img_parts]
409
+ # concatenate back into the full image in latent space
410
+ latents = torch.cat(latents, dim=-1)
411
+
412
+ return latents
413
+
414
+ def encode_images_to_latents2(self, args, accelerator, vae, images):
415
+ # get the image dimensions
416
+ b, c, h, w = images.shape
417
+ # num_split = NUM_SPLIT
418
+ num_split = 2 if args.frame_num == 4 else 3
419
+ latents = vae.encode(images)
420
+ return latents
421
+
422
+ def encode_images_to_latents3(self, args, accelerator, vae, images):
423
+ b, c, h, w = images.shape
424
+ # Number of splits along each dimension
425
+ num_split = 3
426
+ # Check if the image can be evenly divided into 3x3 grid
427
+ assert h % num_split == 0 and w % num_split == 0, "Image dimensions must be divisible by 3."
428
+
429
+ # Height and width of each split
430
+ split_h, split_w = h // num_split, w // num_split
431
+
432
+ # Store latents for each split
433
+ latents = []
434
+
435
+ for i in range(num_split):
436
+ for j in range(num_split):
437
+ # Extract the (i, j) sub-image
438
+ img_part = images[:, :, i * split_h:(i + 1) * split_h, j * split_w:(j + 1) * split_w]
439
+ # Encode the sub-image using VAE
440
+ latent = vae.encode(img_part)
441
+ # Append the latent
442
+ latents.append(latent)
443
+
444
+ # Combine latents into a 3x3 grid in the latent space
445
+ # Latents list -> Tensor [num_split^2, b, latent_dim, h', w']
446
+ latents = torch.stack(latents, dim=0)
447
+
448
+ # Reshape into a 3x3 grid
449
+ # Shape: [num_split, num_split, b, latent_dim, h', w']
450
+ latents = latents.view(num_split, num_split, b, *latents.shape[2:])
451
+
452
+ # Combine the 3x3 grid along height and width in latent space
453
+ # Concatenate along width for each row, then concatenate rows along height
454
+ latents = torch.cat([torch.cat(latents[i], dim=-1) for i in range(num_split)], dim=-2)
455
+
456
+ # Final shape: [b, latent_dim, h', w']
457
+ return latents
458
+
459
+ def shift_scale_latents(self, args, latents):
460
+ return latents
461
+
462
+ def get_noise_pred_and_target(
463
+ self,
464
+ args,
465
+ accelerator,
466
+ noise_scheduler,
467
+ latents,
468
+ batch,
469
+ text_encoder_conds,
470
+ unet: flux_models.Flux,
471
+ network,
472
+ weight_dtype,
473
+ train_unet,
474
+ ):
475
+ # Sample noise that we'll add to the latents
476
+ noise = torch.randn_like(latents)
477
+ bsz = latents.shape[0]
478
+
479
+ # get noisy model input and timesteps
480
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
481
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
482
+ )
483
+
484
+ # pack latents and get img_ids
485
+ # TODO(yiren): check whether this needs to be modified
486
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
487
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
488
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
489
+
490
+ # get guidance
491
+ # ensure guidance_scale in args is float
492
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
493
+
494
+ # ensure the hidden state will require grad
495
+ if args.gradient_checkpointing:
496
+ noisy_model_input.requires_grad_(True)
497
+ for t in text_encoder_conds:
498
+ if t is not None and t.dtype.is_floating_point:
499
+ t.requires_grad_(True)
500
+ img_ids.requires_grad_(True)
501
+ guidance_vec.requires_grad_(True)
502
+
503
+ # Predict the noise residual
504
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
505
+ if not args.apply_t5_attn_mask:
506
+ t5_attn_mask = None
507
+
508
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
509
+ if not args.split_mode:
510
+ # normal forward
511
+ with accelerator.autocast():
512
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
513
+ model_pred = unet(
514
+ img=img,
515
+ img_ids=img_ids,
516
+ txt=t5_out,
517
+ txt_ids=txt_ids,
518
+ y=l_pooled,
519
+ timesteps=timesteps / 1000,
520
+ guidance=guidance_vec,
521
+ txt_attention_mask=t5_attn_mask,
522
+ )
523
+ else:
524
+ # split forward to reduce memory usage
525
+ assert network.train_blocks == "single", "train_blocks must be single for split mode"
526
+ with accelerator.autocast():
527
+ # move flux lower to cpu, and then move flux upper to gpu
528
+ unet.to("cpu")
529
+ clean_memory_on_device(accelerator.device)
530
+ self.flux_upper.to(accelerator.device)
531
+
532
+ # upper model does not require grad
533
+ with torch.no_grad():
534
+ intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
535
+ img=packed_noisy_model_input,
536
+ img_ids=img_ids,
537
+ txt=t5_out,
538
+ txt_ids=txt_ids,
539
+ y=l_pooled,
540
+ timesteps=timesteps / 1000,
541
+ guidance=guidance_vec,
542
+ txt_attention_mask=t5_attn_mask,
543
+ )
544
+
545
+ # move flux upper back to cpu, and then move flux lower to gpu
546
+ self.flux_upper.to("cpu")
547
+ clean_memory_on_device(accelerator.device)
548
+ unet.to(accelerator.device)
549
+
550
+ # lower model requires grad
551
+ intermediate_img.requires_grad_(True)
552
+ intermediate_txt.requires_grad_(True)
553
+ vec.requires_grad_(True)
554
+ pe.requires_grad_(True)
555
+ model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
556
+
557
+ return model_pred
558
+
559
+ model_pred = call_dit(
560
+ img=packed_noisy_model_input,
561
+ img_ids=img_ids,
562
+ t5_out=t5_out,
563
+ txt_ids=txt_ids,
564
+ l_pooled=l_pooled,
565
+ timesteps=timesteps,
566
+ guidance_vec=guidance_vec,
567
+ t5_attn_mask=t5_attn_mask,
568
+ )
569
+
570
+ # unpack latents
571
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
572
+
573
+ # apply model prediction type
574
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
575
+
576
+ # flow matching loss: this is different from SD3
577
+ target = noise - latents
578
+
579
+ # differential output preservation
580
+ if "custom_attributes" in batch:
581
+ diff_output_pr_indices = []
582
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
583
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
584
+ diff_output_pr_indices.append(i)
585
+
586
+ if len(diff_output_pr_indices) > 0:
587
+ network.set_multiplier(0.0)
588
+ with torch.no_grad():
589
+ model_pred_prior = call_dit(
590
+ img=packed_noisy_model_input[diff_output_pr_indices],
591
+ img_ids=img_ids[diff_output_pr_indices],
592
+ t5_out=t5_out[diff_output_pr_indices],
593
+ txt_ids=txt_ids[diff_output_pr_indices],
594
+ l_pooled=l_pooled[diff_output_pr_indices],
595
+ timesteps=timesteps[diff_output_pr_indices],
596
+ guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
597
+ t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
598
+ )
599
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
600
+
601
+ model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
602
+ model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
603
+ args,
604
+ model_pred_prior,
605
+ noisy_model_input[diff_output_pr_indices],
606
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
607
+ )
608
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
609
+
610
+ # eliminate the loss over the condition cell (the bottom-right cell of the 3x3 / 9-frame grid; for the 4-frame grid this slice is empty and nothing is masked)
611
+ h, w = target.shape[2], target.shape[3]
612
+ # num_split = NUM_SPLIT
613
+ num_split = 2 if args.frame_num == 4 else 3
614
+ # target[:, :, :, :w//num_split] = model_pred[:, :, :, :w//num_split]
616
+ target[:, :, 2*h//num_split:h, 2*w//num_split:w] = model_pred[:, :, 2*h//num_split:h, 2*w//num_split:w]
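+ # overwriting the target with the model's own prediction makes (pred - target) vanish in that region, so the condition cell contributes effectively zero loss and gradient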
617
+
618
+
619
+ return model_pred, target, timesteps, None, weighting
620
+
621
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
622
+ return loss
623
+
624
+ def get_sai_model_spec(self, args):
625
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
626
+
627
+ def update_metadata(self, metadata, args):
628
+ metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
629
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
630
+ metadata["ss_logit_mean"] = args.logit_mean
631
+ metadata["ss_logit_std"] = args.logit_std
632
+ metadata["ss_mode_scale"] = args.mode_scale
633
+ metadata["ss_guidance_scale"] = args.guidance_scale
634
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
635
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
636
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
637
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
638
+
639
+ def is_text_encoder_not_needed_for_training(self, args):
640
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
641
+
642
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
643
+ if index == 0: # CLIP-L
644
+ return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
645
+ else: # T5XXL
646
+ text_encoder.encoder.embed_tokens.requires_grad_(True)
647
+
648
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
649
+ if index == 0: # CLIP-L
650
+ logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
651
+ text_encoder.to(te_weight_dtype) # fp8
652
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
653
+ else: # T5XXL
654
+
655
+ def prepare_fp8(text_encoder, target_dtype):
656
+ def forward_hook(module):
657
+ def forward(hidden_states):
658
+ hidden_gelu = module.act(module.wi_0(hidden_states))
659
+ hidden_linear = module.wi_1(hidden_states)
660
+ hidden_states = hidden_gelu * hidden_linear
661
+ hidden_states = module.dropout(hidden_states)
662
+
663
+ hidden_states = module.wo(hidden_states)
664
+ return hidden_states
665
+
666
+ return forward
667
+
668
+ for module in text_encoder.modules():
669
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
670
+ # print("set", module.__class__.__name__, "to", target_dtype)
671
+ module.to(target_dtype)
672
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
673
+ # print("set", module.__class__.__name__, "hooks")
674
+ module.forward = forward_hook(module)
675
+
676
+ if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
677
+ logger.info(f"T5XXL already prepared for fp8")
678
+ else:
679
+ logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
680
+ text_encoder.to(te_weight_dtype) # fp8
681
+ prepare_fp8(text_encoder, weight_dtype)
682
+
683
+
684
+ def setup_parser() -> argparse.ArgumentParser:
685
+ parser = train_network.setup_parser()
686
+ flux_train_utils.add_flux_train_arguments(parser)
687
+
688
+ parser.add_argument(
689
+ "--split_mode",
690
+ action="store_true",
691
+ help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
692
+ + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
693
+ )
694
+
695
+ parser.add_argument(
696
+ '--frame_num',
697
+ type=int,
698
+ choices=[4, 9],
699
+ required=True,
700
+ help="The number of steps in the generated step diagram (choose 4 or 9)"
701
+ )
702
+ return parser
703
+
704
+
705
+ if __name__ == "__main__":
706
+ parser = setup_parser()
707
+
708
+ args = parser.parse_args()
709
+ train_util.verify_command_line_training_args(args)
710
+ args = train_util.read_config_from_file(args, parser)
711
+
712
+ trainer = FluxNetworkTrainer()
713
+ trainer.train(args)
gradio_app.py ADDED
@@ -0,0 +1,233 @@
1
+
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ from PIL import Image
7
+ from accelerate import Accelerator
8
+ import os
9
+ import time
10
+ from torchvision import transforms
11
+ from safetensors.torch import load_file
12
+ from networks import lora_flux
13
+ from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
14
+ import logging
15
+
16
+ # Set up logger
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.DEBUG)
19
+
20
+ # Ensure necessary devices are available
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
23
+
24
+ # Model paths (replace these with your actual model paths)
25
+ BASE_FLUX_CHECKPOINT="/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/6_Portrait/6_Portrait.safetensors"
26
+ LORA_WEIGHTS_PATH="/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/RecraftModel/6_Portrait/6_Portrait-step00025000.safetensors"
27
+ CLIP_L_PATH="/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
28
+ T5XXL_PATH="/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/t5xxl_fp16.safetensors"
29
+ AE_PATH="/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
30
+
31
+ # Load model function
32
+ def load_target_model():
33
+ logger.info("Loading models...")
34
+ try:
35
+ _, model = flux_utils.load_flow_model(
36
+ BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
37
+ )
38
+ clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
39
+ clip_l.eval()
40
+ t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
41
+ t5xxl.eval()
42
+ ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
43
+ logger.info("Models loaded successfully.")
44
+ return model, [clip_l, t5xxl], ae
45
+ except Exception as e:
46
+ logger.error(f"Error loading models: {e}")
47
+ raise
48
+
49
+ # Image pre-processing (resize and padding)
50
+ class ResizeWithPadding:
51
+ def __init__(self, size, fill=255):
52
+ self.size = size
53
+ self.fill = fill
54
+
55
+ def __call__(self, img):
56
+ if isinstance(img, np.ndarray):
57
+ img = Image.fromarray(img)
58
+ elif not isinstance(img, Image.Image):
59
+ raise TypeError("Input must be a PIL Image or a NumPy array")
60
+
61
+ width, height = img.size
62
+
63
+ if width == height:
64
+ img = img.resize((self.size, self.size), Image.LANCZOS)
65
+ else:
66
+ max_dim = max(width, height)
67
+ new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
68
+ new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
69
+ img = new_img.resize((self.size, self.size), Image.LANCZOS)
70
+ return img
71
+
72
+ # The function to generate image from a prompt and conditional image
73
+ def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
74
+ logger.info(f"Started generating image with prompt: {prompt}")
75
+
76
+ # Load models
77
+ model, [clip_l, t5xxl], ae = load_target_model()
78
+
79
+ model.eval()
80
+ clip_l.eval()
81
+ t5xxl.eval()
82
+ ae.eval()
83
+
84
+ # LoRA
85
+ multiplier = 1.0
86
+ weights_sd = load_file(LORA_WEIGHTS_PATH)
87
+ lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
88
+ True)
89
+
90
+ lora_model.apply_to([clip_l, t5xxl], model)
91
+ info = lora_model.load_state_dict(weights_sd, strict=True)
92
+ logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
93
+ lora_model.eval()
94
+ lora_model.to("cuda")
95
+
96
+ # Process the seed
97
+ if randomize_seed:
98
+ seed = random.randint(0, np.iinfo(np.int32).max)
99
+ logger.debug(f"Using seed: {seed}")
100
+
101
+ # Preprocess the conditional image
102
+ resize_transform = ResizeWithPadding(size=512) if frame_num == 4 else ResizeWithPadding(size=352)
103
+ img_transforms = transforms.Compose([
104
+ resize_transform,
105
+ transforms.ToTensor(),
106
+ transforms.Normalize([0.5], [0.5]),
107
+ ])
108
+ image = img_transforms(np.array(sample_image, dtype=np.uint8)).unsqueeze(0).to(
109
+ device=device,
110
+ dtype=torch.bfloat16
111
+ )
112
+ logger.debug("Conditional image preprocessed.")
113
+
114
+ # Encode the image to latents
115
+ ae.to("cuda")
116
+ latents = ae.encode(image)
117
+ logger.debug("Image encoded to latents.")
118
+
119
+ conditions = {}
120
+ conditions[prompt] = latents.to("cpu")
121
+
122
+ ae.to("cpu")
123
+ clip_l.to("cuda")
124
+ t5xxl.to("cuda")
125
+
126
+ # Encode the prompt
127
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
128
+ text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
129
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
130
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True)
131
+
132
+ logger.debug("Prompt encoded.")
133
+
134
+ # Prepare the noise and other parameters
135
+ width = 1024 if frame_num == 4 else 1056
136
+ height = 1024 if frame_num == 4 else 1056
137
+
138
+ height = max(64, height - height % 16)
139
+ width = max(64, width - width % 16)
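+ # clamp spatial dims to multiples of 16 (VAE stride 8 times the 2x2 latent packing)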
140
+
141
+ packed_latent_height = height // 16
142
+ packed_latent_width = width // 16
143
+
144
+ noise = torch.randn(1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, dtype=torch.float16)
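+ # packed FLUX latents: one token per 2x2 patch of the 16-channel latent, i.e. (H/16)*(W/16) tokens of 64 channels each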
145
+ logger.debug("Noise prepared.")
146
+
147
+ # Generate the image
148
+ timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True) # Sample steps = 20
149
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
150
+
151
+ t5_attn_mask = t5_attn_mask.to(device)
152
+ ae_outputs = conditions[prompt]
153
+
154
+ logger.debug("Image generation parameters set.")
155
+
156
+ args = lambda: None
157
+ args.frame_num = frame_num
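+ # "lambda: None" is a throwaway namespace standing in for the argparse Namespace that denoise() expects; only frame_num is attached here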
158
+
159
+ clip_l.to("cpu")
160
+ t5xxl.to("cpu")
161
+
162
+ torch.cuda.empty_cache()
163
+ model.to("cuda")
164
+
165
+ # import pdb
166
+ # pdb.set_trace()
167
+
168
+ # Run the denoising process
169
+ with accelerator.autocast(), torch.no_grad():
170
+ x = flux_train_utils.denoise(
171
+ args, model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=1.0, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs
172
+ )
173
+ logger.debug("Denoising process completed.")
174
+
175
+ # Decode the final image
176
+ x = x.float()
177
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
178
+ model.to("cpu")
179
+ ae.to("cuda")
180
+ with accelerator.autocast(), torch.no_grad():
181
+ x = ae.decode(x)
182
+ logger.debug("Latents decoded into image.")
183
+ ae.to("cpu")
184
+
185
+ # Convert the tensor to an image
186
+ x = x.clamp(-1, 1)
187
+ x = x.permute(0, 2, 3, 1)
188
+ generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
189
+
190
+ logger.info("Image generation completed.")
191
+ return generated_image
192
+
193
+ # Gradio interface
194
+ with gr.Blocks() as demo:
195
+ gr.Markdown("## FLUX Image Generation")
196
+
197
+ with gr.Row():
198
+ # Input for the prompt
199
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=1)
200
+
201
+ # File upload for image
202
+ sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
203
+
204
+ # Frame number selection
205
+ frame_num = gr.Radio([4, 9], label="Select Frame Number", value=4)
206
+
207
+ # Seed and randomize seed options
208
+ seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
209
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
210
+
211
+ # Run Button
212
+ run_button = gr.Button("Generate Image")
213
+
214
+ # Output result
215
+ result_image = gr.Image(label="Generated Image")
216
+
217
+ run_button.click(
218
+ fn=infer,
219
+ inputs=[prompt, sample_image, frame_num, seed, randomize_seed],
220
+ outputs=[result_image]
221
+ )
222
+
223
+ # Launch the Gradio app
224
+ demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
225
+
226
+
227
+ # prompt = "1girl"
228
+ # sample_image = Image.open("/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/test/1.png") # use a test image
229
+ # frame_num = 9
230
+ # seed = 42
231
+ # randomize_seed = False
232
+ # result = infer(prompt, sample_image, frame_num, seed, randomize_seed)
233
+ # result.save('asy_results/generated_image.png')
id_rsa ADDED
@@ -0,0 +1,50 @@
1
+ -----BEGIN OPENSSH PRIVATE KEY-----
2
+ b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABAcjqc3DU
3
+ L8b9ja3ALHmkowAAAAGAAAAAEAAAIXAAAAB3NzaC1yc2EAAAADAQABAAACAQDE/44P8+xW
4
+ HQVsebSAQhmu0nHf82prYlt2OMWH/xCCzvs8b9Z8QppPRQd0mexgyvk5jxttXXT22nQdSP
5
+ ILQsvaOSFJJGBR/RWUI0dNFdR3CdJqWWH331YYEsYe+SWwpUhVSyg0Ys9KhjkWYSVpwv+V
6
+ VKf65QVydeH4OkCe/zuJClIWOl7U/XEwksLjSj/FzIinWADamdnREBbGHK62l5Y4gHTtGI
7
+ DoHqvqbGnOWbDic3YSEyOwXdA4bkjqabuQKVAvsr/p9I6+OqAv2yHAHAG6nSclHz7a8pWi
8
+ FhTeQXVDMOB9W34mfyqN8/lcFlNLgYPab757p4fjk6Ox6TjCSNI6JfJeX/yQeqBwi4pR8h
9
+ YbOvui+NnHLy2bgWBxPinDb752qDpXQuy65aDGYWWUhGLd8LWaTQkx3LD/+iGR5bOwV18X
10
+ HuexVVh4pSGVidJ3m1fqpb7AGGcjL0inTvdnz4r3oGfhFREvuhjQo/O1mzgxElcRK2P+xG
11
+ nJ2eZzC8iwx4UjCotkTdQ3S/2uVI7CFsOnV3oqQ403xcp5OzanH/e/0sTukkM4UWUmUXIG
12
+ QZxicctZJCndVxffNHFutCL2uIwr2evpwwBjE5GasL4VcxaQlcP1KMwPjYyzZUR62qWRDg
13
+ sQNlTx+QQWRl4VWEU/jzvehXMza1vl9bbCo+Y9aSVmAQAAB1DCpgt/NwB5QcPJ3XXuqDJ+
14
+ NCS9iuXhqKGy0L/S0jwuxYNpHjoJ2fptVJo5fgIjGJI3TcMqOHlRNVIZG8e2Cw9YbVivat
15
+ 4WIZq57ZH645uE1XcngdbjxEZfCYozUATaD4piXWYAcsmszlG/r540lBSpGzxB6YESJJgk
16
+ KNh9+gXgMjZCk3NCTplvb3qdvoYYOvagGE68CqtZn10OavMgTwymG1fzrOQ9Biuq27VOhz
17
+ zS5ZYqdynWu/F7Vl3QxdS0Mbap7ratRdQlKgDx6WslsXVQwTKoDjyI0tykOVQxEjlwVhTu
18
+ loQRTJd43Po2BByMclbCM84rlZEViK/jaBc+evsjV0H6CqhJ0Sh28rfSuM/Rb1tYTzktEh
19
+ x7vd5fcTHRFFupxdAcusOaI4QfRNTQPyX+ShwS3q1CqbgcxWZutysPMpq4f2w7sa2L0Tu1
20
+ qUmGCxEkfKSUHO2sk0lIXlNG9p6cxR+fv5D3AcSK1LceyZuYLk45gGHpMlBPDs8R2grX1J
21
+ mTkO7D+EpHcCd0mbEZ37YBLOtERSBNohgVgOEc3SCTRQu5LrCj+or/47b7T2dU/XL/BE/6
22
+ oD1bYk+cSgT16xYOmqKlE7f+StbOVASkQ32Pfnv5b+JsbQRiEZXe9YKoHRRw4KRDiHNjMG
23
+ Zb+N8b4hpf2tOn6aWUVH471ciDlJBcbdyLQaupFM5QRwGobSdoAaQrmFIPzpIZ1rOVE8Jp
24
+ TeZ6CmpxiRNWeFXGojOp7x4GJc7oghdA9loJ5LRVD0Mw2K5VrD32+PMjJ1OZNJrXTZ+FW9
25
+ ujmjocpbL456Cuqr1yEFWzfy5liC9CItiEEp2sDRXjVnDwDOdO1DJe+1E+0ciTqkizxFnw
26
+ M/vTgF8BSVC+Vfg559A/0JRPsxyv4FdEpuOWOART2vnam+XgxjgJ4AoBDC44YNhlzkHNZ4
27
+ I08neqz45hs7B8pbMlE6DxUcJ2dKzzwIcImX0G4QGUdIdGyMC8y10hqI29XBMTa18LDy2i
28
+ n+tH3Q5E3hoRZVPpIeU+B9bLIznclwaOsDyFF+6E2ET4/Wwl9uiQA9XY9HCL2H1pUMr0Mi
29
+ QD04/qF6P2GCQqeI0w4NHtZRlQPiVdExtpHpgIPyyfCFM6G5p+3GH9ZL4/IMcj9R1ziiW8
30
+ 6cZUR1kFHoueAMv63Q4hOzWiaGxdbrYbZbmmp0geO6tXDbeBbBi3YmBHwu0MQheZDbjMxW
31
+ EWYtawjK4wcOnoFkoMh+hU5AheuZJsHKUKEVxq2XauFjpc6dJQaCAVZqoVwMaalS12T8Y1
32
+ m6+Qa/V5vOZvsGdRB80iJOzA5aUbHb203jG7AmIuOhCaW8DQdY1ai69s6wXV7XXS/A0nh9
33
+ vjZiFXxb/nCNNJhk1LrzrcvqXl3LRx/5NTRMt+Bp12bsA9fP0U8tlMPHtGq+ctQosS1nWh
34
+ qrJ2lw1P2fsxWjRVu/4t+n1Hv19tz2stTwcBvUrzYaOvC5SsD05xypRLOfca/QsTlqcqSl
35
+ P4rnP43y5ACsBSXAsnQ2dIfumZFtE+yUy41CXotLbj2MwrPmXFDwh7mNcWYwOz49YnlL5m
36
+ j06liu12UDZd/HVxehHj7HSkbS1AXlvbwKK7kUpmY/ZPhi7ElqtdWHOBg+ixgTdLZVGQdV
37
+ lSAArBMkiuc3CsFNPRJqNtcCLrRm2BYkX3ZCaR9jQ19a2jgBPex4FtOxSgZCcH101RXnKv
38
+ bqnn1hMbJnOUeaMU2w3SqUs2s3sR5RONnn97tfnKGwZRzzg7UGtlxs0yy/MqIl6FTe+HrD
39
+ tz/6mK7afM8fp+AoKqrE0S6atjcuIiu71BJ7T7NlOi/AobNZAmX1FeE8qSxg+IV8Indkn/
40
+ wkVpQ+WnbxPNmsBxJbSLRUNjOQahyzi7+e/sB1eBI1zlEjLW/rnk3SMX2xD3umQsYK6Zce
41
+ MsamiTJh+qrV9NF/+6y80iYtcsKkTU+MESETFJuWAcD0lfP+jjJPGN0XiCumaACMKPtYFz
42
+ qMyiPzOnxirPikNTZQ7FAM/YnVauwBjLG/NCmTsHKtCsJYSlNFn+OFCcyAeZjKJtwev6Aq
43
+ nf+l8YK0ieOdjhPTGzyBYbo6r/XndiVN1bKxS9xT88nzn3ZPKggT2vd5bUm133F2Z9qeRz
44
+ o9XyEW/CuoxjfBLbxIqCJh7Ow5oBDA4YYCxBjyP8Q2TlrjVpntTYh+fkAzeo9H8uNQ4Jop
45
+ Mmc/56IObSiLVzTcyRaNHcnx8hDi0QjgOPKmIcVzqdStgmyZX0nBxiboF9nPIUICkUUgXN
46
+ SqidyUZRjTGWcmuVESrpfS5eDYD7pgpwpogsC9Cr7wyXbZSWaCnYvmZz3B+UAoo0BNH4az
47
+ dHl8ZC7dhmPFUJj+gCxkOrYA1lkYczR23g2EpwVAELQnjHPzShtvLYniWIKtqBPk7tLrvK
48
+ 4PRb3IvyEGFIRNjvyGewsB4FJce470EK1aGJDRDf7r0UCxM5KU5dMUZQs529EWVkO1gt6l
49
+ y1uJAmV7ycLKV43ExoZhP5wwE=
50
+ -----END OPENSSH PRIVATE KEY-----
requirements.txt CHANGED
@@ -1,6 +1,47 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
1
+ accelerate==0.33.0
2
+ transformers==4.44.0
3
+ diffusers[torch]==0.25.0
4
+ ftfy==6.1.1
5
+ # albumentations==1.3.0
6
+ opencv-python==4.8.1.78
7
+ einops==0.7.0
8
+ pytorch-lightning==1.9.0
9
+ bitsandbytes==0.44.0
10
+ prodigyopt==1.0
11
+ lion-pytorch==0.0.6
12
+ came_pytorch==0.1.3
13
+ schedulefree==1.4
14
+ tensorboard
15
+ safetensors==0.4.4
16
+ # gradio==3.16.2
17
+ altair==4.2.2
18
+ easygui==0.98.3
19
+ toml==0.10.2
20
+ voluptuous==0.13.1
21
+ huggingface-hub==0.24.5
22
+ # for Image utils
23
+ imagesize==1.4.1
24
+ numpy<=2.0
25
+ # for BLIP captioning
26
+ # requests==2.28.2
27
+ # timm==0.6.12
28
+ # fairscale==0.4.13
29
+ # for WD14 captioning (tensorflow)
30
+ # tensorflow==2.10.1
31
+ # for WD14 captioning (onnx)
32
+ # onnx==1.15.0
33
+ # onnxruntime-gpu==1.17.1
34
+ # onnxruntime==1.17.1
35
+ # for cuda 12.1(default 11.8)
36
+ # onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
37
+
38
+ # this is for onnx:
39
+ # protobuf==3.20.3
40
+ # open clip for SDXL
41
+ # open-clip-torch==2.20.0
42
+ # For logging
43
+ rich==13.7.0
44
+ # for T5XXL tokenizer (SD3/FLUX)
45
+ sentencepiece==0.2.0
46
+ # for kohya_ss library
47
+ -e .
setup.py ADDED
@@ -0,0 +1,3 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(name = "library", packages = find_packages())
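This minimal setup.py exists so that the `-e .` line at the end of requirements.txt can perform an editable install of the bundled `library` package; once that install has run, the training scripts can import it directly, for example:

```python
# These imports (taken from train_network.py below) resolve after `pip install -r requirements.txt`.
import library.train_util as train_util
from library.device_utils import init_ipex, clean_memory_on_device
```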
split_asylora.py ADDED
@@ -0,0 +1,37 @@
1
+ import argparse
2
+ import os
3
+ from safetensors import safe_open
4
+ from safetensors.torch import save_file
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument('--asylora_path', type=str, required=True, help="Path to the input asylora file.")
9
+ parser.add_argument('--output_path', type=str, required=True, help="Path to save the modified safetensors file.")
10
+ parser.add_argument('--lora_up', type=int, required=True, help="The target lora_up value.")
11
+
12
+ args = parser.parse_args()
13
+
14
+ output_dir = os.path.dirname(args.output_path)
15
+ if output_dir and not os.path.exists(output_dir):
16
+ os.makedirs(output_dir)
17
+
18
+ with safe_open(args.asylora_path, framework="pt") as f:
19
+ tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
20
+
21
+ modified_dict = {}
22
+
23
+ for key, tensor in tensor_dict.items():
24
+ if 'lora_ups' in key:
25
+ lora_up_index = int(key.split('.')[2])
26
+ if lora_up_index != args.lora_up - 1:
27
+ continue
28
+ else:
29
+ new_key = key.replace(f'lora_ups.{lora_up_index}.', 'lora_up.')
30
+ modified_dict[new_key] = tensor
31
+ else:
32
+ modified_dict[key] = tensor
33
+
34
+ save_file(modified_dict, args.output_path)
35
+
36
+ if __name__ == "__main__":
37
+ main()
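A typical invocation would look something like `python split_asylora.py --asylora_path <input>.safetensors --output_path <output>.safetensors --lora_up 3` (paths are placeholders). The script keeps only the selected `lora_ups.<index>.` branch, renames it to `lora_up.`, and copies every other tensor unchanged; a toy illustration of that key handling (the key names are made up for the example):

```python
# Toy illustration of the renaming done in split_asylora.py for --lora_up 3 (index 2).
state = {
    "lora_unet_block0.lora_down.weight": "D",
    "lora_unet_block0.lora_ups.0.weight": "U0",
    "lora_unet_block0.lora_ups.2.weight": "U2",
}
target_index = 3 - 1  # --lora_up is 1-based in the script

out = {}
for key, tensor in state.items():
    if "lora_ups" in key:
        idx = int(key.split(".")[2])
        if idx == target_index:
            out[key.replace(f"lora_ups.{idx}.", "lora_up.")] = tensor
    else:
        out[key] = tensor

print(out)  # {'lora_unet_block0.lora_down.weight': 'D', 'lora_unet_block0.lora_up.weight': 'U2'}
```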
train_network.py ADDED
@@ -0,0 +1,1479 @@
1
+ import importlib
2
+ import argparse
3
+ import math
4
+ import os
5
+ import sys
6
+ import random
7
+ import time
8
+ import json
9
+ from multiprocessing import Value
10
+ from typing import Any, List
11
+ import toml
12
+
13
+ from tqdm import tqdm
14
+
15
+ import torch
16
+ from library.device_utils import init_ipex, clean_memory_on_device
17
+
18
+ init_ipex()
19
+
20
+ from accelerate.utils import set_seed
21
+ from diffusers import DDPMScheduler
22
+ from library import deepspeed_utils, model_util, strategy_base, strategy_sd
23
+
24
+ import library.train_util as train_util
25
+ from library.train_util import DreamBoothDataset
26
+ import library.config_util as config_util
27
+ from library.config_util import (
28
+ ConfigSanitizer,
29
+ BlueprintGenerator,
30
+ )
31
+ import library.huggingface_util as huggingface_util
32
+ import library.custom_train_functions as custom_train_functions
33
+ from library.custom_train_functions import (
34
+ apply_snr_weight,
35
+ get_weighted_text_embeddings,
36
+ prepare_scheduler_for_custom_training,
37
+ scale_v_prediction_loss_like_noise_prediction,
38
+ add_v_prediction_like_loss,
39
+ apply_debiased_estimation,
40
+ apply_masked_loss,
41
+ )
42
+ from library.utils import setup_logging, add_logging_arguments
43
+
44
+ setup_logging()
45
+ import logging
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
+ class NetworkTrainer:
51
+ def __init__(self):
52
+ self.vae_scale_factor = 0.18215
53
+ self.is_sdxl = False
54
+
55
+ # TODO: unify this with the other scripts
56
+ def generate_step_logs(
57
+ self,
58
+ args: argparse.Namespace,
59
+ current_loss,
60
+ avr_loss,
61
+ lr_scheduler,
62
+ lr_descriptions,
63
+ keys_scaled=None,
64
+ mean_norm=None,
65
+ maximum_norm=None,
66
+ ):
67
+ logs = {"loss/current": current_loss, "loss/average": avr_loss}
68
+
69
+ if keys_scaled is not None:
70
+ logs["max_norm/keys_scaled"] = keys_scaled
71
+ logs["max_norm/average_key_norm"] = mean_norm
72
+ logs["max_norm/max_key_norm"] = maximum_norm
73
+
74
+ lrs = lr_scheduler.get_last_lr()
75
+ for i, lr in enumerate(lrs):
76
+ if lr_descriptions is not None:
77
+ lr_desc = lr_descriptions[i]
78
+ else:
79
+ idx = i - (0 if args.network_train_unet_only else -1)
80
+ if idx == -1:
81
+ lr_desc = "textencoder"
82
+ else:
83
+ if len(lrs) > 2:
84
+ lr_desc = f"group{idx}"
85
+ else:
86
+ lr_desc = "unet"
87
+
88
+ logs[f"lr/{lr_desc}"] = lr
89
+
90
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
91
+ # tracking d*lr value
92
+ logs[f"lr/d*lr/{lr_desc}"] = (
93
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
94
+ )
95
+
96
+ return logs
97
+
98
+ def assert_extra_args(self, args, train_dataset_group):
99
+ train_dataset_group.verify_bucket_reso_steps(64)
100
+
101
+ def load_target_model(self, args, weight_dtype, accelerator):
102
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
103
+
104
+ # incorporate xformers or memory efficient attention into the model
105
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
106
+ if torch.__version__ >= "2.0.0": # the following can be used with xformers builds that support PyTorch 2.0.0 or later
107
+ vae.set_use_memory_efficient_attention_xformers(args.xformers)
108
+
109
+ return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
110
+
111
+ def get_tokenize_strategy(self, args):
112
+ return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
113
+
114
+ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
115
+ return [tokenize_strategy.tokenizer]
116
+
117
+ def get_latents_caching_strategy(self, args):
118
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
119
+ True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
120
+ )
121
+ return latents_caching_strategy
122
+
123
+ def get_text_encoding_strategy(self, args):
124
+ return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
125
+
126
+ def get_text_encoder_outputs_caching_strategy(self, args):
127
+ return None
128
+
129
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
130
+ """
131
+ Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
132
+ FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
133
+ """
134
+ return text_encoders
135
+
136
+ # returns a list of bool values indicating whether each text encoder should be trained
137
+ def get_text_encoders_train_flags(self, args, text_encoders):
138
+ return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders)
139
+
140
+ def is_train_text_encoder(self, args):
141
+ return not args.network_train_unet_only
142
+
143
+ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
144
+ for t_enc in text_encoders:
145
+ t_enc.to(accelerator.device, dtype=weight_dtype)
146
+
147
+ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
148
+ noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
149
+ return noise_pred
150
+
151
+ def all_reduce_network(self, accelerator, network):
152
+ for param in network.parameters():
153
+ if param.grad is not None:
154
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
155
+
156
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
157
+ train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
158
+
159
+ # region SD/SDXL
160
+
161
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
162
+ pass
163
+
164
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
165
+ noise_scheduler = DDPMScheduler(
166
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
167
+ )
168
+ prepare_scheduler_for_custom_training(noise_scheduler, device)
169
+ if args.zero_terminal_snr:
170
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
171
+ return noise_scheduler
172
+
173
+ def encode_images_to_latents(self, args, accelerator, vae, images):
174
+ return vae.encode(images).latent_dist.sample()
175
+
176
+ def shift_scale_latents(self, args, latents):
177
+ return latents * self.vae_scale_factor
178
+
179
+ def get_noise_pred_and_target(
180
+ self,
181
+ args,
182
+ accelerator,
183
+ noise_scheduler,
184
+ latents,
185
+ batch,
186
+ text_encoder_conds,
187
+ unet,
188
+ network,
189
+ weight_dtype,
190
+ train_unet,
191
+ ):
192
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
193
+ # with noise offset and/or multires noise if specified
194
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
195
+
196
+ # ensure the hidden state will require grad
197
+ if args.gradient_checkpointing:
198
+ for x in noisy_latents:
199
+ x.requires_grad_(True)
200
+ for t in text_encoder_conds:
201
+ t.requires_grad_(True)
202
+
203
+ # Predict the noise residual
204
+ with accelerator.autocast():
205
+ noise_pred = self.call_unet(
206
+ args,
207
+ accelerator,
208
+ unet,
209
+ noisy_latents.requires_grad_(train_unet),
210
+ timesteps,
211
+ text_encoder_conds,
212
+ batch,
213
+ weight_dtype,
214
+ )
215
+
216
+ if args.v_parameterization:
217
+ # v-parameterization training
218
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
219
+ else:
220
+ target = noise
221
+
222
+ # differential output preservation
223
+ if "custom_attributes" in batch:
224
+ diff_output_pr_indices = []
225
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
226
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
227
+ diff_output_pr_indices.append(i)
228
+
229
+ if len(diff_output_pr_indices) > 0:
230
+ network.set_multiplier(0.0)
231
+ with torch.no_grad(), accelerator.autocast():
232
+ noise_pred_prior = self.call_unet(
233
+ args,
234
+ accelerator,
235
+ unet,
236
+ noisy_latents,
237
+ timesteps,
238
+ text_encoder_conds,
239
+ batch,
240
+ weight_dtype,
241
+ indices=diff_output_pr_indices,
242
+ )
243
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
244
+ target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
245
+
246
+ return noise_pred, target, timesteps, huber_c, None
247
+
248
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
249
+ if args.min_snr_gamma:
250
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
251
+ if args.scale_v_pred_loss_like_noise_pred:
252
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
253
+ if args.v_pred_like_loss:
254
+ loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
255
+ if args.debiased_estimation_loss:
256
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
257
+ return loss
258
+
259
+ def get_sai_model_spec(self, args):
260
+ return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
261
+
262
+ def update_metadata(self, metadata, args):
263
+ pass
264
+
265
+ def is_text_encoder_not_needed_for_training(self, args):
266
+ return False # use for sample images
267
+
268
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
269
+ # set top parameter requires_grad = True for gradient checkpointing works
270
+ text_encoder.text_model.embeddings.requires_grad_(True)
271
+
272
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
273
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
274
+
275
+ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
276
+ pass
277
+
278
+ # endregion
279
+
280
+ def train(self, args):
281
+ session_id = random.randint(0, 2**32)
282
+ training_started_at = time.time()
283
+ train_util.verify_training_args(args)
284
+ train_util.prepare_dataset_args(args, True)
285
+ deepspeed_utils.prepare_deepspeed_args(args)
286
+ setup_logging(args, reset=True)
287
+
288
+ cache_latents = args.cache_latents
289
+ use_dreambooth_method = args.in_json is None
290
+ use_user_config = args.dataset_config is not None
291
+
292
+ if args.seed is None:
293
+ args.seed = random.randint(0, 2**32)
294
+ set_seed(args.seed)
295
+
296
+ tokenize_strategy = self.get_tokenize_strategy(args)
297
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
298
+ tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
299
+
300
+ # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
301
+ latents_caching_strategy = self.get_latents_caching_strategy(args)
302
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
303
+
304
+ # prepare the dataset
305
+ if args.dataset_class is None:
306
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
307
+ if use_user_config:
308
+ logger.info(f"Loading dataset config from {args.dataset_config}")
309
+ user_config = config_util.load_user_config(args.dataset_config)
310
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
311
+ if any(getattr(args, attr) is not None for attr in ignored):
312
+ logger.warning(
313
+ "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
314
+ ", ".join(ignored)
315
+ )
316
+ )
317
+ else:
318
+ if use_dreambooth_method:
319
+ logger.info("Using DreamBooth method.")
320
+ user_config = {
321
+ "datasets": [
322
+ {
323
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
324
+ args.train_data_dir, args.reg_data_dir
325
+ )
326
+ }
327
+ ]
328
+ }
329
+ else:
330
+ logger.info("Training with captions.")
331
+ user_config = {
332
+ "datasets": [
333
+ {
334
+ "subsets": [
335
+ {
336
+ "image_dir": args.train_data_dir,
337
+ "metadata_file": args.in_json,
338
+ }
339
+ ]
340
+ }
341
+ ]
342
+ }
343
+
344
+ blueprint = blueprint_generator.generate(user_config, args)
345
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
346
+ else:
347
+ # use arbitrary dataset class
348
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
349
+
350
+ current_epoch = Value("i", 0)
351
+ current_step = Value("i", 0)
352
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
353
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
354
+
355
+ if args.debug_dataset:
356
+ train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
357
+ train_util.debug_dataset(train_dataset_group)
358
+ return
359
+ if len(train_dataset_group) == 0:
360
+ logger.error(
361
+ "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではな��、画像があるフォルダの親フォルダを指定する必要があります)"
362
+ )
363
+ return
364
+
365
+ if cache_latents:
366
+ assert (
367
+ train_dataset_group.is_latent_cacheable()
368
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
369
+
370
+ self.assert_extra_args(args, train_dataset_group) # may change some args
371
+
372
+ # prepare the accelerator
373
+ logger.info("preparing accelerator")
374
+ accelerator = train_util.prepare_accelerator(args)
375
+ is_main_process = accelerator.is_main_process
376
+
377
+ # prepare dtypes compatible with mixed precision and cast as needed
378
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
379
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
380
+
381
+ # load the model
382
+ model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
383
+
384
+ # text_encoder is List[CLIPTextModel] or CLIPTextModel
385
+ text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
386
+
387
+ # load the network module for additional (difference) training
388
+ sys.path.append(os.path.dirname(__file__))
389
+ accelerator.print("import network module:", args.network_module)
390
+ network_module = importlib.import_module(args.network_module)
391
+
392
+ if args.base_weights is not None:
393
+ # if base_weights is specified, load and merge the specified weights
394
+ for i, weight_path in enumerate(args.base_weights):
395
+ if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
396
+ multiplier = 1.0
397
+ else:
398
+ multiplier = args.base_weights_multiplier[i]
399
+
400
+ accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
401
+
402
+ module, weights_sd = network_module.create_network_from_weights(
403
+ multiplier, weight_path, vae, text_encoder, unet, for_inference=True
404
+ )
405
+ module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
406
+
407
+ accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
408
+
409
+ # prepare for training
410
+ if cache_latents:
411
+ vae.to(accelerator.device, dtype=vae_dtype)
412
+ vae.requires_grad_(False)
413
+ vae.eval()
414
+
415
+ train_dataset_group.new_cache_latents(vae, accelerator)
416
+
417
+ vae.to("cpu")
418
+ clean_memory_on_device(accelerator.device)
419
+
420
+ accelerator.wait_for_everyone()
421
+
422
+ # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
423
+ # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
424
+ text_encoding_strategy = self.get_text_encoding_strategy(args)
425
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
426
+
427
+ text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
428
+ if text_encoder_outputs_caching_strategy is not None:
429
+ strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
430
+ self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
431
+
432
+ # prepare network
433
+ net_kwargs = {}
434
+ if args.network_args is not None:
435
+ for net_arg in args.network_args:
436
+ key, value = net_arg.split("=")
437
+ net_kwargs[key] = value
438
+
439
+ # if a new network is added in future, add if ~ then blocks for each network (;'∀')
440
+ if args.dim_from_weights:
441
+ network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
442
+ else:
443
+ if "dropout" not in net_kwargs:
444
+ # workaround for LyCORIS (;^ω^)
445
+ net_kwargs["dropout"] = args.network_dropout
446
+
447
+ network = network_module.create_network(
448
+ 1.0,
449
+ args.network_dim,
450
+ args.network_alpha,
451
+ vae,
452
+ text_encoder,
453
+ unet,
454
+ neuron_dropout=args.network_dropout,
455
+ **net_kwargs,
456
+ )
457
+ if network is None:
458
+ return
459
+ network_has_multiplier = hasattr(network, "set_multiplier")
460
+
461
+ if hasattr(network, "prepare_network"):
462
+ network.prepare_network(args)
463
+ if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
464
+ logger.warning(
465
+ "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応���ていません"
466
+ )
467
+ args.scale_weight_norms = False
468
+
469
+ self.post_process_network(args, accelerator, network, text_encoders, unet)
470
+
471
+ # apply network to unet and text_encoder
472
+ train_unet = not args.network_train_text_encoder_only
473
+ train_text_encoder = self.is_train_text_encoder(args)
474
+ network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
475
+
476
+ if args.network_weights is not None:
477
+ # FIXME consider alpha of weights: this assumes that the alpha is not changed
478
+ info = network.load_weights(args.network_weights)
479
+ accelerator.print(f"load network weights from {args.network_weights}: {info}")
480
+
481
+ if args.gradient_checkpointing:
482
+ if args.cpu_offload_checkpointing:
483
+ unet.enable_gradient_checkpointing(cpu_offload=True)
484
+ else:
485
+ unet.enable_gradient_checkpointing()
486
+
487
+ for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
488
+ if flag:
489
+ if t_enc.supports_gradient_checkpointing:
490
+ t_enc.gradient_checkpointing_enable()
491
+ del t_enc
492
+ network.enable_gradient_checkpointing() # may have no effect
493
+
494
+ # prepare the classes needed for training
495
+ accelerator.print("prepare optimizer, data loader etc.")
496
+
497
+ # keep backward compatibility for text_encoder_lr
498
+ support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
499
+ if support_multiple_lrs:
500
+ text_encoder_lr = args.text_encoder_lr
501
+ else:
502
+ # toml backward compatibility
503
+ if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int):
504
+ text_encoder_lr = args.text_encoder_lr
505
+ else:
506
+ text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
507
+ try:
508
+ if support_multiple_lrs:
509
+ results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
510
+ else:
511
+ results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
512
+ if type(results) is tuple:
513
+ trainable_params = results[0]
514
+ lr_descriptions = results[1]
515
+ else:
516
+ trainable_params = results
517
+ lr_descriptions = None
518
+ except TypeError as e:
519
+ trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
520
+ lr_descriptions = None
521
+
522
+ # if len(trainable_params) == 0:
523
+ # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
524
+ # for params in trainable_params:
525
+ # for k, v in params.items():
526
+ # if type(v) == float:
527
+ # pass
528
+ # else:
529
+ # v = len(v)
530
+ # accelerator.print(f"trainable_params: {k} = {v}")
531
+
532
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
533
+ optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
534
+
535
+ # prepare dataloader
536
+ # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
537
+ # some strategies can be None
538
+ train_dataset_group.set_current_strategies()
539
+
540
+ # note: persistent_workers cannot be used when the number of DataLoader workers is 0
541
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
542
+
543
+ train_dataloader = torch.utils.data.DataLoader(
544
+ train_dataset_group,
545
+ batch_size=1,
546
+ shuffle=True,
547
+ collate_fn=collator,
548
+ num_workers=n_workers,
549
+ persistent_workers=args.persistent_data_loader_workers,
550
+ )
551
+
552
+ # calculate the number of training steps
553
+ if args.max_train_epochs is not None:
554
+ args.max_train_steps = args.max_train_epochs * math.ceil(
555
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
556
+ )
557
+ accelerator.print(
558
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
559
+ )
560
+
561
+ # also send the number of training steps to the dataset
562
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
563
+
564
+ # prepare the lr scheduler
565
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
566
+
567
+ # experimental feature: perform fp16/bf16 training including gradients; cast the whole model to fp16/bf16
568
+ if args.full_fp16:
569
+ assert (
570
+ args.mixed_precision == "fp16"
571
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
572
+ accelerator.print("enable full fp16 training.")
573
+ network.to(weight_dtype)
574
+ elif args.full_bf16:
575
+ assert (
576
+ args.mixed_precision == "bf16"
577
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
578
+ accelerator.print("enable full bf16 training.")
579
+ network.to(weight_dtype)
580
+
581
+ unet_weight_dtype = te_weight_dtype = weight_dtype
582
+ # Experimental Feature: Put base model into fp8 to save vram
583
+ if args.fp8_base or args.fp8_base_unet:
584
+ assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
585
+ assert (
586
+ args.mixed_precision != "no"
587
+ ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
588
+ accelerator.print("enable fp8 training for U-Net.")
589
+ unet_weight_dtype = torch.float8_e4m3fn
590
+
591
+ if not args.fp8_base_unet:
592
+ accelerator.print("enable fp8 training for Text Encoder.")
593
+ te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
594
+
595
+ # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
596
+ # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
597
+
598
+ logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
599
+ unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
600
+
601
+ unet.requires_grad_(False)
602
+ unet.to(dtype=unet_weight_dtype)
603
+ for i, t_enc in enumerate(text_encoders):
604
+ t_enc.requires_grad_(False)
605
+
606
+ # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
607
+ if t_enc.device.type != "cpu":
608
+ t_enc.to(dtype=te_weight_dtype)
609
+
610
+ # nn.Embedding does not support FP8
611
+ if te_weight_dtype != weight_dtype:
612
+ self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
613
+
614
+ # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
615
+ if args.deepspeed:
616
+ flags = self.get_text_encoders_train_flags(args, text_encoders)
617
+ ds_model = deepspeed_utils.prepare_deepspeed_model(
618
+ args,
619
+ unet=unet if train_unet else None,
620
+ text_encoder1=text_encoders[0] if flags[0] else None,
621
+ text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
622
+ network=network,
623
+ )
624
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
625
+ ds_model, optimizer, train_dataloader, lr_scheduler
626
+ )
627
+ training_model = ds_model
628
+ else:
629
+ if train_unet:
630
+ unet = accelerator.prepare(unet)
631
+ else:
632
+ unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
633
+ if train_text_encoder:
634
+ text_encoders = [
635
+ (accelerator.prepare(t_enc) if flag else t_enc)
636
+ for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
637
+ ]
638
+ if len(text_encoders) > 1:
639
+ text_encoder = text_encoders
640
+ else:
641
+ text_encoder = text_encoders[0]
642
+ else:
643
+ pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
644
+
645
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
646
+ network, optimizer, train_dataloader, lr_scheduler
647
+ )
648
+ training_model = network
649
+
650
+ if args.gradient_checkpointing:
651
+ # according to TI example in Diffusers, train is required
652
+ unet.train()
653
+ for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
654
+ t_enc.train()
655
+
656
+ # set top parameter requires_grad = True for gradient checkpointing works
657
+ if frag:
658
+ self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
659
+
660
+ else:
661
+ unet.eval()
662
+ for t_enc in text_encoders:
663
+ t_enc.eval()
664
+
665
+ del t_enc
666
+
667
+ accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
668
+
669
+ if not cache_latents: # the VAE is used when latents are not cached, so prepare it
670
+ vae.requires_grad_(False)
671
+ vae.eval()
672
+ vae.to(accelerator.device, dtype=vae_dtype)
673
+
674
+ # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
675
+ if args.full_fp16:
676
+ train_util.patch_accelerator_for_fp16_training(accelerator)
677
+
678
+ # before resuming make hook for saving/loading to save/load the network weights only
679
+ def save_model_hook(models, weights, output_dir):
680
+ # pop weights of other models than network to save only network weights
681
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
682
+ if accelerator.is_main_process or args.deepspeed:
683
+ remove_indices = []
684
+ for i, model in enumerate(models):
685
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
686
+ remove_indices.append(i)
687
+ for i in reversed(remove_indices):
688
+ if len(weights) > i:
689
+ weights.pop(i)
690
+ # print(f"save model hook: {len(weights)} weights will be saved")
691
+
692
+ # save current epoch and step
693
+ train_state_file = os.path.join(output_dir, "train_state.json")
694
+ # +1 is needed because the state is saved before current_step is set from global_step
695
+ logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
696
+ with open(train_state_file, "w", encoding="utf-8") as f:
697
+ json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
698
+
699
+ steps_from_state = None
700
+
701
+ def load_model_hook(models, input_dir):
702
+ # remove models except network
703
+ remove_indices = []
704
+ for i, model in enumerate(models):
705
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
706
+ remove_indices.append(i)
707
+ for i in reversed(remove_indices):
708
+ models.pop(i)
709
+ # print(f"load model hook: {len(models)} models will be loaded")
710
+
711
+ # load current epoch and step
712
+ nonlocal steps_from_state
713
+ train_state_file = os.path.join(input_dir, "train_state.json")
714
+ if os.path.exists(train_state_file):
715
+ with open(train_state_file, "r", encoding="utf-8") as f:
716
+ data = json.load(f)
717
+ steps_from_state = data["current_step"]
718
+ logger.info(f"load train state from {train_state_file}: {data}")
719
+
720
+ accelerator.register_save_state_pre_hook(save_model_hook)
721
+ accelerator.register_load_state_pre_hook(load_model_hook)
722
+
723
+ # resume training if specified
724
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
725
+
726
+ # calculate the number of epochs
727
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
728
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
729
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
730
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
731
+
732
+ # start training
733
+ # TODO: find a way to handle total batch size when there are multiple datasets
734
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
735
+
736
+ accelerator.print("running training / 学習開始")
737
+ accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
738
+ accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
739
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
740
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
741
+ accelerator.print(
742
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
743
+ )
744
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
745
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
746
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
747
+
748
+ # TODO refactor metadata creation and move to util
749
+ metadata = {
750
+ "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
751
+ "ss_training_started_at": training_started_at, # unix timestamp
752
+ "ss_output_name": args.output_name,
753
+ "ss_learning_rate": args.learning_rate,
754
+ "ss_text_encoder_lr": text_encoder_lr,
755
+ "ss_unet_lr": args.unet_lr,
756
+ "ss_num_train_images": train_dataset_group.num_train_images,
757
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
758
+ "ss_num_batches_per_epoch": len(train_dataloader),
759
+ "ss_num_epochs": num_train_epochs,
760
+ "ss_gradient_checkpointing": args.gradient_checkpointing,
761
+ "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
762
+ "ss_max_train_steps": args.max_train_steps,
763
+ "ss_lr_warmup_steps": args.lr_warmup_steps,
764
+ "ss_lr_scheduler": args.lr_scheduler,
765
+ "ss_network_module": args.network_module,
766
+ "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
767
+ "ss_network_alpha": args.network_alpha, # some networks may not have alpha
768
+ "ss_network_dropout": args.network_dropout, # some networks may not have dropout
769
+ "ss_mixed_precision": args.mixed_precision,
770
+ "ss_full_fp16": bool(args.full_fp16),
771
+ "ss_v2": bool(args.v2),
772
+ "ss_base_model_version": model_version,
773
+ "ss_clip_skip": args.clip_skip,
774
+ "ss_max_token_length": args.max_token_length,
775
+ "ss_cache_latents": bool(args.cache_latents),
776
+ "ss_seed": args.seed,
777
+ "ss_lowram": args.lowram,
778
+ "ss_noise_offset": args.noise_offset,
779
+ "ss_multires_noise_iterations": args.multires_noise_iterations,
780
+ "ss_multires_noise_discount": args.multires_noise_discount,
781
+ "ss_adaptive_noise_scale": args.adaptive_noise_scale,
782
+ "ss_zero_terminal_snr": args.zero_terminal_snr,
783
+ "ss_training_comment": args.training_comment, # will not be updated after training
784
+ "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
785
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
786
+ "ss_max_grad_norm": args.max_grad_norm,
787
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
788
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
789
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
790
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
791
+ "ss_prior_loss_weight": args.prior_loss_weight,
792
+ "ss_min_snr_gamma": args.min_snr_gamma,
793
+ "ss_scale_weight_norms": args.scale_weight_norms,
794
+ "ss_ip_noise_gamma": args.ip_noise_gamma,
795
+ "ss_debiased_estimation": bool(args.debiased_estimation_loss),
796
+ "ss_noise_offset_random_strength": args.noise_offset_random_strength,
797
+ "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
798
+ "ss_loss_type": args.loss_type,
799
+ "ss_huber_schedule": args.huber_schedule,
800
+ "ss_huber_c": args.huber_c,
801
+ "ss_fp8_base": bool(args.fp8_base),
802
+ "ss_fp8_base_unet": bool(args.fp8_base_unet),
803
+ }
804
+
805
+ self.update_metadata(metadata, args) # architecture specific metadata
806
+
807
+ if use_user_config:
808
+ # save metadata of multiple datasets
809
+ # NOTE: pack "ss_datasets" value as json one time
810
+ # or should also pack nested collections as json?
811
+ datasets_metadata = []
812
+ tag_frequency = {} # merge tag frequency for metadata editor
813
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
814
+
815
+ for dataset in train_dataset_group.datasets:
816
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
817
+ dataset_metadata = {
818
+ "is_dreambooth": is_dreambooth_dataset,
819
+ "batch_size_per_device": dataset.batch_size,
820
+ "num_train_images": dataset.num_train_images, # includes repeating
821
+ "num_reg_images": dataset.num_reg_images,
822
+ "resolution": (dataset.width, dataset.height),
823
+ "enable_bucket": bool(dataset.enable_bucket),
824
+ "min_bucket_reso": dataset.min_bucket_reso,
825
+ "max_bucket_reso": dataset.max_bucket_reso,
826
+ "tag_frequency": dataset.tag_frequency,
827
+ "bucket_info": dataset.bucket_info,
828
+ }
829
+
830
+ subsets_metadata = []
831
+ for subset in dataset.subsets:
832
+ subset_metadata = {
833
+ "img_count": subset.img_count,
834
+ "num_repeats": subset.num_repeats,
835
+ "color_aug": bool(subset.color_aug),
836
+ "flip_aug": bool(subset.flip_aug),
837
+ "random_crop": bool(subset.random_crop),
838
+ "shuffle_caption": bool(subset.shuffle_caption),
839
+ "keep_tokens": subset.keep_tokens,
840
+ "keep_tokens_separator": subset.keep_tokens_separator,
841
+ "secondary_separator": subset.secondary_separator,
842
+ "enable_wildcard": bool(subset.enable_wildcard),
843
+ "caption_prefix": subset.caption_prefix,
844
+ "caption_suffix": subset.caption_suffix,
845
+ }
846
+
847
+ image_dir_or_metadata_file = None
848
+ if subset.image_dir:
849
+ image_dir = os.path.basename(subset.image_dir)
850
+ subset_metadata["image_dir"] = image_dir
851
+ image_dir_or_metadata_file = image_dir
852
+
853
+ if is_dreambooth_dataset:
854
+ subset_metadata["class_tokens"] = subset.class_tokens
855
+ subset_metadata["is_reg"] = subset.is_reg
856
+ if subset.is_reg:
857
+ image_dir_or_metadata_file = None # not merging reg dataset
858
+ else:
859
+ metadata_file = os.path.basename(subset.metadata_file)
860
+ subset_metadata["metadata_file"] = metadata_file
861
+ image_dir_or_metadata_file = metadata_file # may overwrite
862
+
863
+ subsets_metadata.append(subset_metadata)
864
+
865
+ # merge dataset dir: not reg subset only
866
+ # TODO update additional-network extension to show detailed dataset config from metadata
867
+ if image_dir_or_metadata_file is not None:
868
+ # datasets may have a certain dir multiple times
869
+ v = image_dir_or_metadata_file
870
+ i = 2
871
+ while v in dataset_dirs_info:
872
+ v = image_dir_or_metadata_file + f" ({i})"
873
+ i += 1
874
+ image_dir_or_metadata_file = v
875
+
876
+ dataset_dirs_info[image_dir_or_metadata_file] = {
877
+ "n_repeats": subset.num_repeats,
878
+ "img_count": subset.img_count,
879
+ }
880
+
881
+ dataset_metadata["subsets"] = subsets_metadata
882
+ datasets_metadata.append(dataset_metadata)
883
+
884
+ # merge tag frequency:
885
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
886
+ # if a directory is used by multiple datasets, count it only once
887
+ # since the number of repeats is specified anyway, the tag counts in captions do not match how often the tags are used in training
888
+ # so summing the counts across multiple datasets here is not very meaningful
889
+ if ds_dir_name in tag_frequency:
890
+ continue
891
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
892
+
893
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
894
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
895
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
896
+ else:
897
+ # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
898
+ assert (
899
+ len(train_dataset_group.datasets) == 1
900
+ ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
901
+
902
+ dataset = train_dataset_group.datasets[0]
903
+
904
+ dataset_dirs_info = {}
905
+ reg_dataset_dirs_info = {}
906
+ if use_dreambooth_method:
907
+ for subset in dataset.subsets:
908
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
909
+ info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
910
+ else:
911
+ for subset in dataset.subsets:
912
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
913
+ "n_repeats": subset.num_repeats,
914
+ "img_count": subset.img_count,
915
+ }
916
+
917
+ metadata.update(
918
+ {
919
+ "ss_batch_size_per_device": args.train_batch_size,
920
+ "ss_total_batch_size": total_batch_size,
921
+ "ss_resolution": args.resolution,
922
+ "ss_color_aug": bool(args.color_aug),
923
+ "ss_flip_aug": bool(args.flip_aug),
924
+ "ss_random_crop": bool(args.random_crop),
925
+ "ss_shuffle_caption": bool(args.shuffle_caption),
926
+ "ss_enable_bucket": bool(dataset.enable_bucket),
927
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
928
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
929
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
930
+ "ss_keep_tokens": args.keep_tokens,
931
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
932
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
933
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
934
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
935
+ }
936
+ )
937
+
938
+ # add extra args
939
+ if args.network_args:
940
+ metadata["ss_network_args"] = json.dumps(net_kwargs)
941
+
942
+ # model name and hash
943
+ if args.pretrained_model_name_or_path is not None:
944
+ sd_model_name = args.pretrained_model_name_or_path
945
+ if os.path.exists(sd_model_name):
946
+ metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
947
+ metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
948
+ sd_model_name = os.path.basename(sd_model_name)
949
+ metadata["ss_sd_model_name"] = sd_model_name
950
+
951
+ if args.vae is not None:
952
+ vae_name = args.vae
953
+ if os.path.exists(vae_name):
954
+ metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
955
+ metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
956
+ vae_name = os.path.basename(vae_name)
957
+ metadata["ss_vae_name"] = vae_name
958
+
959
+ metadata = {k: str(v) for k, v in metadata.items()}
960
+
961
+ # make minimum metadata for filtering
962
+ minimum_metadata = {}
963
+ for key in train_util.SS_METADATA_MINIMUM_KEYS:
964
+ if key in metadata:
965
+ minimum_metadata[key] = metadata[key]
966
+
967
+ # calculate steps to skip when resuming or starting from a specific step
968
+ initial_step = 0
969
+ if args.initial_epoch is not None or args.initial_step is not None:
970
+ # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
971
+ if steps_from_state is not None:
972
+ logger.warning(
973
+ "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
974
+ )
975
+ if args.initial_step is not None:
976
+ initial_step = args.initial_step
977
+ else:
978
+ # num steps per epoch is calculated by num_processes and gradient_accumulation_steps
979
+ initial_step = (args.initial_epoch - 1) * math.ceil(
980
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
981
+ )
982
+ else:
983
+ # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
984
+ if steps_from_state is not None:
985
+ initial_step = steps_from_state
986
+ steps_from_state = None
987
+
988
+ if initial_step > 0:
989
+ assert (
990
+ args.max_train_steps > initial_step
991
+ ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
992
+
993
+ progress_bar = tqdm(
994
+ range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
995
+ )
996
+
997
+ epoch_to_start = 0
998
+ if initial_step > 0:
999
+ if args.skip_until_initial_step:
1000
+ # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
1001
+ if not args.resume:
1002
+ logger.info(
1003
+ f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
1004
+ )
1005
+ logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
1006
+ initial_step *= args.gradient_accumulation_steps
1007
+
1008
+ # set epoch to start to make initial_step less than len(train_dataloader)
1009
+ epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1010
+ else:
1011
+ # if not, only epoch no is skipped for informative purpose
1012
+ epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1013
+ initial_step = 0 # do not skip
1014
+
1015
+ global_step = 0
1016
+
1017
+ noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
1018
+
1019
+ if accelerator.is_main_process:
1020
+ init_kwargs = {}
1021
+ if args.wandb_run_name:
1022
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
1023
+ if args.log_tracker_config is not None:
1024
+ init_kwargs = toml.load(args.log_tracker_config)
1025
+ accelerator.init_trackers(
1026
+ "network_train" if args.log_tracker_name is None else args.log_tracker_name,
1027
+ config=train_util.get_sanitized_config_or_none(args),
1028
+ init_kwargs=init_kwargs,
1029
+ )
1030
+
1031
+ loss_recorder = train_util.LossRecorder()
1032
+ del train_dataset_group
1033
+
1034
+ # callback for step start
1035
+ if hasattr(accelerator.unwrap_model(network), "on_step_start"):
1036
+ on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
1037
+ else:
1038
+ on_step_start_for_network = lambda *args, **kwargs: None
1039
+
1040
+ # function for saving/removing
1041
+ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
1042
+ os.makedirs(args.output_dir, exist_ok=True)
1043
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1044
+
1045
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
1046
+ metadata["ss_training_finished_at"] = str(time.time())
1047
+ metadata["ss_steps"] = str(steps)
1048
+ metadata["ss_epoch"] = str(epoch_no)
1049
+
1050
+ metadata_to_save = minimum_metadata if args.no_metadata else metadata
1051
+ sai_metadata = self.get_sai_model_spec(args)
1052
+ metadata_to_save.update(sai_metadata)
1053
+
1054
+ unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
1055
+ if args.huggingface_repo_id is not None:
1056
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
1057
+
1058
+ def remove_model(old_ckpt_name):
1059
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
1060
+ if os.path.exists(old_ckpt_file):
1061
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
1062
+ os.remove(old_ckpt_file)
1063
+
1064
+ # if text_encoder is not needed for training, delete it to save memory.
1065
+ # TODO this can be automated after SDXL sample prompt cache is implemented
1066
+ if self.is_text_encoder_not_needed_for_training(args):
1067
+ logger.info("text_encoder is not needed for training. deleting to save memory.")
1068
+ for t_enc in text_encoders:
1069
+ del t_enc
1070
+ text_encoders = []
1071
+ text_encoder = None
1072
+
1073
+ # For --sample_at_first
1074
+ optimizer_eval_fn()
1075
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
1076
+ optimizer_train_fn()
1077
+ if len(accelerator.trackers) > 0:
1078
+ # log empty object to commit the sample images to wandb
1079
+ accelerator.log({}, step=0)
1080
+
1081
+ # training loop
1082
+ if initial_step > 0: # only if skip_until_initial_step is specified
1083
+ for skip_epoch in range(epoch_to_start): # skip epochs
1084
+ logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
1085
+ initial_step -= len(train_dataloader)
1086
+ global_step = initial_step
1087
+
1088
+ # log device and dtype for each model
1089
+ logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
1090
+ for i, t_enc in enumerate(text_encoders):
1091
+ params_itr = t_enc.parameters()
1092
+ params_itr.__next__() # skip the first parameter
1093
+ params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings
1094
+ param_3rd = params_itr.__next__()
1095
+ logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
1096
+
1097
+ clean_memory_on_device(accelerator.device)
1098
+
1099
+ for epoch in range(epoch_to_start, num_train_epochs):
1100
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
1101
+ current_epoch.value = epoch + 1
1102
+
1103
+ metadata["ss_epoch"] = str(epoch + 1)
1104
+
1105
+ accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
1106
+
1107
+ skipped_dataloader = None
1108
+ if initial_step > 0:
1109
+ skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
1110
+ initial_step = 1
1111
+
1112
+ for step, batch in enumerate(skipped_dataloader or train_dataloader):
1113
+ current_step.value = global_step
1114
+ if initial_step > 0:
1115
+ initial_step -= 1
1116
+ continue
1117
+
1118
+ with accelerator.accumulate(training_model):
1119
+ on_step_start_for_network(text_encoder, unet)
1120
+
1121
+ # temporary, for batch processing
1122
+ self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
1123
+
1124
+ if "latents" in batch and batch["latents"] is not None:
1125
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
1126
+ else:
1127
+ with torch.no_grad():
1128
+ # convert to latents
1129
+ latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
1130
+ latents = latents.to(dtype=weight_dtype)
1131
+
1132
+ # if latents contain NaN, warn and replace them with zeros
1133
+ if torch.any(torch.isnan(latents)):
1134
+ accelerator.print("NaN found in latents, replacing with zeros")
1135
+ latents = torch.nan_to_num(latents, 0, out=latents)
1136
+
1137
+ latents = self.shift_scale_latents(args, latents)
1138
+
1139
+ # get multiplier for each sample
1140
+ if network_has_multiplier:
1141
+ multipliers = batch["network_multipliers"]
1142
+ # if all multipliers are same, use single multiplier
1143
+ if torch.all(multipliers == multipliers[0]):
1144
+ multipliers = multipliers[0].item()
1145
+ else:
1146
+ raise NotImplementedError("multipliers for each sample is not supported yet")
1147
+ # print(f"set multiplier: {multipliers}")
1148
+ accelerator.unwrap_model(network).set_multiplier(multipliers)
1149
+
1150
+ text_encoder_conds = []
1151
+ text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
1152
+ if text_encoder_outputs_list is not None:
1153
+ text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
1154
+
1155
+ if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
1156
+ # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
1157
+ with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
1158
+ # Get the text embedding for conditioning
1159
+ if args.weighted_captions:
1160
+ input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
1161
+ encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
1162
+ tokenize_strategy,
1163
+ self.get_models_for_text_encoding(args, accelerator, text_encoders),
1164
+ input_ids_list,
1165
+ weights_list,
1166
+ )
1167
+ else:
1168
+ input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
1169
+ encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
1170
+ tokenize_strategy,
1171
+ self.get_models_for_text_encoding(args, accelerator, text_encoders),
1172
+ input_ids,
1173
+ )
1174
+ if args.full_fp16:
1175
+ encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
1176
+
1177
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
1178
+ if len(text_encoder_conds) == 0:
1179
+ text_encoder_conds = encoded_text_encoder_conds
1180
+ else:
1181
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
1182
+ for i in range(len(encoded_text_encoder_conds)):
1183
+ if encoded_text_encoder_conds[i] is not None:
1184
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
1185
+
1186
+ # sample noise, call unet, get target
1187
+ noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
1188
+ args,
1189
+ accelerator,
1190
+ noise_scheduler,
1191
+ latents,
1192
+ batch,
1193
+ text_encoder_conds,
1194
+ unet,
1195
+ network,
1196
+ weight_dtype,
1197
+ train_unet,
1198
+ )
1199
+
1200
+ loss = train_util.conditional_loss(
1201
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
1202
+ )
1203
+ if weighting is not None:
1204
+ loss = loss * weighting
1205
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
1206
+ loss = apply_masked_loss(loss, batch)
1207
+ loss = loss.mean([1, 2, 3])
1208
+
1209
+ loss_weights = batch["loss_weights"] # per-sample loss weight
1210
+ loss = loss * loss_weights
1211
+
1212
+ # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
1213
+ loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
1214
+
1215
+ loss = loss.mean() # already averaged, so no need to divide by batch_size
1216
+
1217
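+ # backward pass; on gradient-sync steps the network gradients are all-reduced manually and optionally clipped before the optimizer step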
+ accelerator.backward(loss)
1218
+ if accelerator.sync_gradients:
1219
+ self.all_reduce_network(accelerator, network) # sync DDP grad manually
1220
+ if args.max_grad_norm != 0.0:
1221
+ params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
1222
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1223
+
1224
+ optimizer.step()
1225
+ lr_scheduler.step()
1226
+ optimizer.zero_grad(set_to_none=True)
1227
+
1228
+ if args.scale_weight_norms:
1229
+ keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
1230
+ args.scale_weight_norms, accelerator.device
1231
+ )
1232
+ max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
1233
+ else:
1234
+ keys_scaled, mean_norm, maximum_norm = None, None, None
1235
+
1236
+ # Checks if the accelerator has performed an optimization step behind the scenes
1237
+ if accelerator.sync_gradients:
1238
+ progress_bar.update(1)
1239
+ global_step += 1
1240
+
1241
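+ # temporarily switch the optimizer to eval mode while sampling images and saving checkpoints, then back to train mode below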
+ optimizer_eval_fn()
1242
+ self.sample_images(
1243
+ accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
1244
+ )
1245
+
1246
+ # save the model every specified number of steps
1247
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
1248
+ accelerator.wait_for_everyone()
1249
+ if accelerator.is_main_process:
1250
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
1251
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
1252
+
1253
+ if args.save_state:
1254
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
1255
+
1256
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
1257
+ if remove_step_no is not None:
1258
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
1259
+ remove_model(remove_ckpt_name)
1260
+ optimizer_train_fn()
1261
+
1262
+ current_loss = loss.detach().item()
1263
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
1264
+ avr_loss: float = loss_recorder.moving_average
1265
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
1266
+ progress_bar.set_postfix(**logs)
1267
+
1268
+ if args.scale_weight_norms:
1269
+ progress_bar.set_postfix(**{**max_mean_logs, **logs})
1270
+
1271
+ if len(accelerator.trackers) > 0:
1272
+ logs = self.generate_step_logs(
1273
+ args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
1274
+ )
1275
+ accelerator.log(logs, step=global_step)
1276
+
1277
+ if global_step >= args.max_train_steps:
1278
+ break
1279
+
1280
+ if len(accelerator.trackers) > 0:
1281
+ logs = {"loss/epoch": loss_recorder.moving_average}
1282
+ accelerator.log(logs, step=epoch + 1)
1283
+
1284
+ accelerator.wait_for_everyone()
1285
+
1286
+ # save the model every specified number of epochs
1287
+ optimizer_eval_fn()
1288
+ if args.save_every_n_epochs is not None:
1289
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
1290
+ if is_main_process and saving:
1291
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
1292
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
1293
+
1294
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
1295
+ if remove_epoch_no is not None:
1296
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
1297
+ remove_model(remove_ckpt_name)
1298
+
1299
+ if args.save_state:
1300
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
1301
+
1302
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
1303
+ optimizer_train_fn()
1304
+
1305
+ # end of epoch
1306
+
1307
+ # metadata["ss_epoch"] = str(num_train_epochs)
1308
+ metadata["ss_training_finished_at"] = str(time.time())
1309
+
1310
+ if is_main_process:
1311
+ network = accelerator.unwrap_model(network)
1312
+
1313
+ accelerator.end_training()
1314
+ optimizer_eval_fn()
1315
+
1316
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
1317
+ train_util.save_state_on_train_end(args, accelerator)
1318
+
1319
+ if is_main_process:
1320
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
1321
+ save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
1322
+
1323
+ logger.info("model saved.")
1324
+
1325
+
1326
+ def setup_parser() -> argparse.ArgumentParser:
1327
+ parser = argparse.ArgumentParser()
1328
+
1329
+ add_logging_arguments(parser)
1330
+ train_util.add_sd_models_arguments(parser)
1331
+ train_util.add_dataset_arguments(parser, True, True, True)
1332
+ train_util.add_training_arguments(parser, True)
1333
+ train_util.add_masked_loss_arguments(parser)
1334
+ deepspeed_utils.add_deepspeed_arguments(parser)
1335
+ train_util.add_optimizer_arguments(parser)
1336
+ config_util.add_config_arguments(parser)
1337
+ custom_train_functions.add_custom_train_arguments(parser)
1338
+
1339
+ parser.add_argument(
1340
+ "--cpu_offload_checkpointing",
1341
+ action="store_true",
1342
+ help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
1343
+ " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
1344
+ )
1345
+ parser.add_argument(
1346
+ "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
1347
+ )
1348
+ parser.add_argument(
1349
+ "--save_model_as",
1350
+ type=str,
1351
+ default="safetensors",
1352
+ choices=[None, "ckpt", "pt", "safetensors"],
1353
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
1354
+ )
1355
+
1356
+ parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
1357
+ parser.add_argument(
1358
+ "--text_encoder_lr",
1359
+ type=float,
1360
+ default=None,
1361
+ nargs="*",
1362
+ help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
1363
+ )
1364
+ parser.add_argument(
1365
+ "--fp8_base_unet",
1366
+ action="store_true",
1367
+ help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
1368
+ " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
1369
+ )
1370
+
1371
+ parser.add_argument(
1372
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
1373
+ )
1374
+ parser.add_argument(
1375
+ "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
1376
+ )
1377
+ parser.add_argument(
1378
+ "--network_dim",
1379
+ type=int,
1380
+ default=None,
1381
+ help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
1382
+ )
1383
+ parser.add_argument(
1384
+ "--network_alpha",
1385
+ type=float,
1386
+ default=1,
1387
+ help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
1388
+ )
1389
+ parser.add_argument(
1390
+ "--network_dropout",
1391
+ type=float,
1392
+ default=None,
1393
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
1394
+ )
1395
+ parser.add_argument(
1396
+ "--network_args",
1397
+ type=str,
1398
+ default=None,
1399
+ nargs="*",
1400
+ help="additional arguments for network (key=value) / ネットワークへの追加の引数",
1401
+ )
1402
+ parser.add_argument(
1403
+ "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
1404
+ )
1405
+ parser.add_argument(
1406
+ "--network_train_text_encoder_only",
1407
+ action="store_true",
1408
+ help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
1409
+ )
1410
+ parser.add_argument(
1411
+ "--training_comment",
1412
+ type=str,
1413
+ default=None,
1414
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
1415
+ )
1416
+ parser.add_argument(
1417
+ "--dim_from_weights",
1418
+ action="store_true",
1419
+ help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
1420
+ )
1421
+ parser.add_argument(
1422
+ "--scale_weight_norms",
1423
+ type=float,
1424
+ default=None,
1425
+ help="Scale the weight of each key pair to help prevent overtraining via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
1426
+ )
1427
+ parser.add_argument(
1428
+ "--base_weights",
1429
+ type=str,
1430
+ default=None,
1431
+ nargs="*",
1432
+ help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
1433
+ )
1434
+ parser.add_argument(
1435
+ "--base_weights_multiplier",
1436
+ type=float,
1437
+ default=None,
1438
+ nargs="*",
1439
+ help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
1440
+ )
1441
+ parser.add_argument(
1442
+ "--no_half_vae",
1443
+ action="store_true",
1444
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
1445
+ )
1446
+ parser.add_argument(
1447
+ "--skip_until_initial_step",
1448
+ action="store_true",
1449
+ help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
1450
+ )
1451
+ parser.add_argument(
1452
+ "--initial_epoch",
1453
+ type=int,
1454
+ default=None,
1455
+ help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
1456
+ + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
1457
+ )
1458
+ parser.add_argument(
1459
+ "--initial_step",
1460
+ type=int,
1461
+ default=None,
1462
+ help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
1463
+ + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
1464
+ )
1465
+ # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
1466
+ # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
1467
+ # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
1468
+ return parser
1469
+
1470
+
1471
+ if __name__ == "__main__":
1472
+ parser = setup_parser()
1473
+
1474
+ args = parser.parse_args()
1475
+ train_util.verify_command_line_training_args(args)
1476
+ args = train_util.read_config_from_file(args, parser)
1477
+
1478
+ trainer = NetworkTrainer()
1479
+ trainer.train(args)
train_network_asylora.py ADDED
@@ -0,0 +1,1492 @@
1
+ import importlib
2
+ import argparse
3
+ import math
4
+ import os
5
+ import sys
6
+ import random
7
+ import time
8
+ import json
9
+ from multiprocessing import Value
10
+ from typing import Any, List
11
+ import toml
12
+
13
+ from tqdm import tqdm
14
+
15
+ import torch
16
+ from library.device_utils import init_ipex, clean_memory_on_device
17
+
18
+ init_ipex()
19
+
20
+ from accelerate.utils import set_seed
21
+ from accelerate import Accelerator
22
+ from diffusers import DDPMScheduler
23
+ from library import deepspeed_utils, model_util, strategy_base, strategy_sd
24
+
25
+ import library.train_util as train_util
26
+ from library.train_util import DreamBoothDataset
27
+ import library.config_util as config_util
28
+ from library.config_util import (
29
+ ConfigSanitizer,
30
+ BlueprintGenerator,
31
+ )
32
+ import library.huggingface_util as huggingface_util
33
+ import library.custom_train_functions as custom_train_functions
34
+ from library.custom_train_functions import (
35
+ apply_snr_weight,
36
+ get_weighted_text_embeddings,
37
+ prepare_scheduler_for_custom_training,
38
+ scale_v_prediction_loss_like_noise_prediction,
39
+ add_v_prediction_like_loss,
40
+ apply_debiased_estimation,
41
+ apply_masked_loss,
42
+ )
43
+ from library.utils import setup_logging, add_logging_arguments
44
+
45
+ setup_logging()
46
+ import logging
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ class NetworkTrainer:
52
+ def __init__(self):
53
+ self.vae_scale_factor = 0.18215
54
+ self.is_sdxl = False
55
+
56
+ # TODO unify this with the other training scripts
57
+ def generate_step_logs(
58
+ self,
59
+ args: argparse.Namespace,
60
+ current_loss,
61
+ avr_loss,
62
+ lr_scheduler,
63
+ lr_descriptions,
64
+ keys_scaled=None,
65
+ mean_norm=None,
66
+ maximum_norm=None,
67
+ ):
68
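+ # per-step scalar logs: current and average loss, per-group learning rates, and optional max-norm statistics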
+ logs = {"loss/current": current_loss, "loss/average": avr_loss}
69
+
70
+ if keys_scaled is not None:
71
+ logs["max_norm/keys_scaled"] = keys_scaled
72
+ logs["max_norm/average_key_norm"] = mean_norm
73
+ logs["max_norm/max_key_norm"] = maximum_norm
74
+
75
+ lrs = lr_scheduler.get_last_lr()
76
+ for i, lr in enumerate(lrs):
77
+ if lr_descriptions is not None:
78
+ lr_desc = lr_descriptions[i]
79
+ else:
80
+ idx = i - (0 if args.network_train_unet_only else -1)
81
+ if idx == -1:
82
+ lr_desc = "textencoder"
83
+ else:
84
+ if len(lrs) > 2:
85
+ lr_desc = f"group{idx}"
86
+ else:
87
+ lr_desc = "unet"
88
+
89
+ logs[f"lr/{lr_desc}"] = lr
90
+
91
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
92
+ # tracking d*lr value
93
+ logs[f"lr/d*lr/{lr_desc}"] = (
94
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
95
+ )
96
+
97
+ return logs
98
+
99
+ def assert_extra_args(self, args, train_dataset_group):
100
+ train_dataset_group.verify_bucket_reso_steps(64)
101
+
102
+ def load_target_model(self, args, weight_dtype, accelerator):
103
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
104
+
105
+ # incorporate xformers or memory-efficient attention into the model
106
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
107
+ if torch.__version__ >= "2.0.0": # available when xformers supports PyTorch 2.0.0 or later
108
+ vae.set_use_memory_efficient_attention_xformers(args.xformers)
109
+
110
+ return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
111
+
112
+ def get_tokenize_strategy(self, args):
113
+ return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
114
+
115
+ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
116
+ return [tokenize_strategy.tokenizer]
117
+
118
+ def get_latents_caching_strategy(self, args):
119
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
120
+ True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
121
+ )
122
+ return latents_caching_strategy
123
+
124
+ def get_text_encoding_strategy(self, args):
125
+ return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
126
+
127
+ def get_text_encoder_outputs_caching_strategy(self, args):
128
+ return None
129
+
130
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
131
+ """
132
+ Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
133
+ FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
134
+ """
135
+ return text_encoders
136
+
137
+ # returns a list of bool values indicating whether each text encoder should be trained
138
+ def get_text_encoders_train_flags(self, args, text_encoders):
139
+ return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders)
140
+
141
+ def is_train_text_encoder(self, args):
142
+ return not args.network_train_unet_only
143
+
144
+ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
145
+ for t_enc in text_encoders:
146
+ t_enc.to(accelerator.device, dtype=weight_dtype)
147
+
148
+ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
149
+ noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
150
+ return noise_pred
151
+
152
+ def all_reduce_network(self, accelerator, network):
153
+ for param in network.parameters():
154
+ if param.grad is not None:
155
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
156
+
157
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
158
+ train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
159
+
160
+ # region SD/SDXL
161
+
162
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
163
+ pass
164
+
165
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
166
+ noise_scheduler = DDPMScheduler(
167
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
168
+ )
169
+ prepare_scheduler_for_custom_training(noise_scheduler, device)
170
+ if args.zero_terminal_snr:
171
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
172
+ return noise_scheduler
173
+
174
+ def encode_images_to_latents(self, args, accelerator, vae, images):
175
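+ # sample latents from the VAE posterior; scaling by vae_scale_factor is applied later in shift_scale_latents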
+ return vae.encode(images).latent_dist.sample()
176
+
177
+ def shift_scale_latents(self, args, latents):
178
+ return latents * self.vae_scale_factor
179
+
180
+ def get_noise_pred_and_target(
181
+ self,
182
+ args,
183
+ accelerator,
184
+ noise_scheduler,
185
+ latents,
186
+ batch,
187
+ text_encoder_conds,
188
+ unet,
189
+ network,
190
+ weight_dtype,
191
+ train_unet,
192
+ ):
193
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
194
+ # with noise offset and/or multires noise if specified
195
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
196
+
197
+ # ensure the hidden state will require grad
198
+ if args.gradient_checkpointing:
199
+ for x in noisy_latents:
200
+ x.requires_grad_(True)
201
+ for t in text_encoder_conds:
202
+ t.requires_grad_(True)
203
+
204
+ # Predict the noise residual
205
+ with accelerator.autocast():
206
+ noise_pred = self.call_unet(
207
+ args,
208
+ accelerator,
209
+ unet,
210
+ noisy_latents.requires_grad_(train_unet),
211
+ timesteps,
212
+ text_encoder_conds,
213
+ batch,
214
+ weight_dtype,
215
+ )
216
+
217
+ if args.v_parameterization:
218
+ # v-parameterization training
219
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
220
+ else:
221
+ target = noise
222
+
223
+ # differential output preservation
224
+ if "custom_attributes" in batch:
225
+ diff_output_pr_indices = []
226
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
227
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
228
+ diff_output_pr_indices.append(i)
229
+
230
+ if len(diff_output_pr_indices) > 0:
231
+ network.set_multiplier(0.0)
232
+ with torch.no_grad(), accelerator.autocast():
233
+ noise_pred_prior = self.call_unet(
234
+ args,
235
+ accelerator,
236
+ unet,
237
+ noisy_latents,
238
+ timesteps,
239
+ text_encoder_conds,
240
+ batch,
241
+ weight_dtype,
242
+ indices=diff_output_pr_indices,
243
+ )
244
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
245
+ target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
246
+
247
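+ # the final None is an optional per-sample weighting; None means uniform weighting in the loss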
+ return noise_pred, target, timesteps, huber_c, None
248
+
249
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
250
+ if args.min_snr_gamma:
251
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
252
+ if args.scale_v_pred_loss_like_noise_pred:
253
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
254
+ if args.v_pred_like_loss:
255
+ loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
256
+ if args.debiased_estimation_loss:
257
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
258
+ return loss
259
+
260
+ def get_sai_model_spec(self, args):
261
+ return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
262
+
263
+ def update_metadata(self, metadata, args):
264
+ pass
265
+
266
+ def is_text_encoder_not_needed_for_training(self, args):
267
+ return False # use for sample images
268
+
269
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
270
+ # set requires_grad = True on the top parameter so that gradient checkpointing works
271
+ text_encoder.text_model.embeddings.requires_grad_(True)
272
+
273
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
274
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
275
+
276
+ def prepare_unet_with_accelerator(
277
+ self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
278
+ ) -> torch.nn.Module:
279
+ return accelerator.prepare(unet)
280
+
281
+ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
282
+ pass
283
+
284
+ # endregion
285
+
286
+ def train(self, args):
287
+ session_id = random.randint(0, 2**32)
288
+ training_started_at = time.time()
289
+ train_util.verify_training_args(args)
290
+ train_util.prepare_dataset_args(args, True)
291
+ deepspeed_utils.prepare_deepspeed_args(args)
292
+ setup_logging(args, reset=True)
293
+
294
+ cache_latents = args.cache_latents
295
+ use_dreambooth_method = args.in_json is None
296
+ use_user_config = args.dataset_config is not None
297
+
298
+ if args.seed is None:
299
+ args.seed = random.randint(0, 2**32)
300
+ set_seed(args.seed)
301
+
302
+ tokenize_strategy = self.get_tokenize_strategy(args)
303
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
304
+ tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
305
+
306
+ # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
307
+ latents_caching_strategy = self.get_latents_caching_strategy(args)
308
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
309
+
310
+ # prepare the dataset
311
+ if args.dataset_class is None:
312
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
313
+ if use_user_config:
314
+ logger.info(f"Loading dataset config from {args.dataset_config}")
315
+ user_config = config_util.load_user_config(args.dataset_config)
316
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
317
+ # if any(getattr(args, attr) is not None for attr in ignored):
318
+ # logger.warning(
319
+ # "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
320
+ # ", ".join(ignored)
321
+ # )
322
+ # )
323
+ # else:
324
+ # if use_dreambooth_method:
325
+ # logger.info("Using DreamBooth method.")
326
+ # user_config = {
327
+ # "datasets": [
328
+ # {
329
+ # "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
330
+ # args.train_data_dir, args.reg_data_dir
331
+ # )
332
+ # }
333
+ # ]
334
+ # }
335
+ # else:
336
+ # logger.info("Training with captions.")
337
+ # user_config = {
338
+ # "datasets": [
339
+ # {
340
+ # "subsets": [
341
+ # {
342
+ # "image_dir": args.train_data_dir,
343
+ # "metadata_file": args.in_json,
344
+ # }
345
+ # ]
346
+ # }
347
+ # ]
348
+ # }
349
+
350
+ blueprint = blueprint_generator.generate(user_config, args) # user_config: LoraConfig.toml
351
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
352
+ else:
353
+ # use arbitrary dataset class
354
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
355
+
356
+ current_epoch = Value("i", 0)
357
+ current_step = Value("i", 0)
358
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
359
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
360
+
361
+ if args.debug_dataset:
362
+ train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
363
+ train_util.debug_dataset(train_dataset_group)
364
+ return
365
+ if len(train_dataset_group) == 0:
366
+ logger.error(
367
+ "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
368
+ )
369
+ return
370
+
371
+ if cache_latents:
372
+ assert (
373
+ train_dataset_group.is_latent_cacheable()
374
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
375
+
376
+ self.assert_extra_args(args, train_dataset_group) # may change some args
377
+
378
+ # prepare the accelerator
379
+ logger.info("preparing accelerator")
380
+ accelerator = train_util.prepare_accelerator(args)
381
+ is_main_process = accelerator.is_main_process
382
+
383
+ # prepare dtypes for mixed precision and cast as needed
384
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
385
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
386
+
387
+ # load the model (model_version: flux, text_encoder: t5, clip)
388
+ model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
389
+
390
+ # text_encoder is List[CLIPTextModel] or CLIPTextModel
391
+ text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
392
+
393
+ # import the network module used for additional differential training
394
+ sys.path.append(os.path.dirname(__file__))
395
+ accelerator.print("import network module:", args.network_module)
396
+ network_module = importlib.import_module(args.network_module)
397
+
398
+ # if args.base_weights is not None:
399
+ # # base_weights が指定されている場合は、指定された重みを読み込みマージする
400
+ # for i, weight_path in enumerate(args.base_weights):
401
+ # if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
402
+ # multiplier = 1.0
403
+ # else:
404
+ # multiplier = args.base_weights_multiplier[i]
405
+ #
406
+ # accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
407
+ #
408
+ # module, weights_sd = network_module.create_network_from_weights(
409
+ # multiplier, weight_path, vae, text_encoder, unet, for_inference=True
410
+ # )
411
+ # module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
412
+ #
413
+ # accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
414
+
415
+ # prepare for training
416
+ if cache_latents:
417
+ vae.to(accelerator.device, dtype=vae_dtype)
418
+ vae.requires_grad_(False)
419
+ vae.eval()
420
+
421
+ train_dataset_group.new_cache_latents(vae, accelerator)
422
+
423
+ vae.to("cpu")
424
+ clean_memory_on_device(accelerator.device)
425
+
426
+ accelerator.wait_for_everyone()
427
+
428
+ # 如有必要,缓存文本编码器输出:将Text Encoder移至cpu或gpu
429
+ # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
430
+ text_encoding_strategy = self.get_text_encoding_strategy(args)
431
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
432
+
433
+ text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
434
+ if text_encoder_outputs_caching_strategy is not None:
435
+ strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
436
+ self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
437
+
438
+ # prepare network
439
+ net_kwargs = {}
440
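+ # parse additional key=value pairs from --network_args into kwargs for the network constructor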
+ if args.network_args is not None:
441
+ for net_arg in args.network_args:
442
+ key, value = net_arg.split("=")
443
+ net_kwargs[key] = value
444
+
445
+ # if a new network is added in future, add if ~ then blocks for each network (;'?')
446
+ if args.dim_from_weights:
447
+ network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
448
+ else:
449
+ if "dropout" not in net_kwargs:
450
+ # workaround for LyCORIS (;^ω^)
451
+ net_kwargs["dropout"] = args.network_dropout
452
+
453
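+ # create the network; lora_ups_num is the extra argument here and appears to select how many LoRA up-projection branches the asymmetric LoRA (asylora) module builds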
+ network = network_module.create_network(
454
+ 1.0,
455
+ args.network_dim,
456
+ args.network_alpha,
457
+ vae,
458
+ text_encoder,
459
+ unet,
460
+ lora_ups_num=args.lora_ups_num,
461
+ neuron_dropout=args.network_dropout,
462
+ **net_kwargs,
463
+ )
464
+ if network is None:
465
+ return
466
+ network_has_multiplier = hasattr(network, "set_multiplier") # network_has_multiplier: True
467
+
468
+ if hasattr(network, "prepare_network"):
469
+ network.prepare_network(args)
470
+ if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
471
+ logger.warning(
472
+ "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
473
+ )
474
+ args.scale_weight_norms = False
475
+
476
+ self.post_process_network(args, accelerator, network, text_encoders, unet)
477
+
478
+ # apply network to unet and text_encoder
479
+ train_unet = not args.network_train_text_encoder_only # train_unet: True
480
+ train_text_encoder = self.is_train_text_encoder(args) # train_text_encoder: False
481
+ network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
482
+
483
+ if args.network_weights is not None:
484
+ # FIXME consider alpha of weights: this assumes that the alpha is not changed
485
+ info = network.load_weights(args.network_weights)
486
+ accelerator.print(f"load network weights from {args.network_weights}: {info}")
487
+
488
+ if args.gradient_checkpointing:
489
+ if args.cpu_offload_checkpointing:
490
+ unet.enable_gradient_checkpointing(cpu_offload=True)
491
+ else:
492
+ unet.enable_gradient_checkpointing()
493
+
494
+ for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
495
+ if flag:
496
+ if t_enc.supports_gradient_checkpointing:
497
+ t_enc.gradient_checkpointing_enable()
498
+ del t_enc
499
+ network.enable_gradient_checkpointing() # may have no effect
500
+
501
+ # prepare the optimizer and other training hyperparameters
502
+ accelerator.print("prepare optimizer, data loader etc.")
503
+
504
+ # keep backward compatibility for text_encoder_lr
505
+ support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
506
+ if support_multiple_lrs:
507
+ text_encoder_lr = args.text_encoder_lr
508
+ else:
509
+ # toml backward compatibility
510
+ if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int):
511
+ text_encoder_lr = args.text_encoder_lr
512
+ else:
513
+ text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
514
+ try:
515
+ if support_multiple_lrs:
516
+ results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
517
+ else:
518
+ results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
519
+ if type(results) is tuple:
520
+ trainable_params = results[0]
521
+ lr_descriptions = results[1]
522
+ else:
523
+ trainable_params = results
524
+ lr_descriptions = None
525
+ except TypeError as e:
526
+ trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
527
+ lr_descriptions = None
528
+
529
+ # if len(trainable_params) == 0:
530
+ # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
531
+ # for params in trainable_params:
532
+ # for k, v in params.items():
533
+ # if type(v) == float:
534
+ # pass
535
+ # else:
536
+ # v = len(v)
537
+ # accelerator.print(f"trainable_params: {k} = {v}")
538
+
539
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
540
+ optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
541
+
542
+ # prepare dataloader
543
+ # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
544
+ # some strategies can be None
545
+ train_dataset_group.set_current_strategies()
546
+
547
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
548
+
549
+ train_dataloader = torch.utils.data.DataLoader(
550
+ train_dataset_group,
551
+ batch_size=1,
552
+ shuffle=True,
553
+ collate_fn=collator,
554
+ num_workers=n_workers,
555
+ persistent_workers=args.persistent_data_loader_workers,
556
+ )
557
+
558
+ # calculate the number of training steps
559
+ if args.max_train_epochs is not None:
560
+ args.max_train_steps = args.max_train_epochs * math.ceil(
561
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
562
+ )
563
+ accelerator.print(
564
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
565
+ )
566
+
567
+ # set the maximum number of training steps
568
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
569
+
570
+ # set up the lr scheduler
571
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
572
+
573
+ # experimental feature: train in full fp16/bf16 (including gradients) by casting the whole model
574
+ if args.full_fp16: # not taken in this configuration
575
+ assert (
576
+ args.mixed_precision == "fp16"
577
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
578
+ accelerator.print("enable full fp16 training.")
579
+ network.to(weight_dtype)
580
+ elif args.full_bf16: # not taken in this configuration
581
+ assert (
582
+ args.mixed_precision == "bf16"
583
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
584
+ accelerator.print("enable full bf16 training.")
585
+ network.to(weight_dtype)
586
+
587
+ unet_weight_dtype = te_weight_dtype = weight_dtype
588
+ # Experimental Feature: Put base model into fp8 to save vram
589
+ if args.fp8_base or args.fp8_base_unet:
590
+ assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
591
+ assert (
592
+ args.mixed_precision != "no"
593
+ ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
594
+ accelerator.print("enable fp8 training for U-Net.")
595
+ unet_weight_dtype = torch.float8_e4m3fn # torch.float8_e4m3fn
596
+
597
+ if not args.fp8_base_unet:
598
+ accelerator.print("enable fp8 training for Text Encoder.")
599
+ te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
600
+
601
+ # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
602
+ # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
603
+
604
+ # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
605
+ # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
606
+ logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
607
+ unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
608
+
609
+ unet.requires_grad_(False)
610
+ unet.to(dtype=unet_weight_dtype)
611
+ for i, t_enc in enumerate(text_encoders):
612
+ t_enc.requires_grad_(False)
613
+
614
+ # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
615
+ if t_enc.device.type != "cpu":
616
+ t_enc.to(dtype=te_weight_dtype)
617
+
618
+ # nn.Embedding not support FP8
619
+ if te_weight_dtype != weight_dtype:
620
+ self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
621
+
622
+ # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
623
+ if args.deepspeed: # not taken in this configuration
624
+ flags = self.get_text_encoders_train_flags(args, text_encoders)
625
+ ds_model = deepspeed_utils.prepare_deepspeed_model(
626
+ args,
627
+ unet=unet if train_unet else None,
628
+ text_encoder1=text_encoders[0] if flags[0] else None,
629
+ text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
630
+ network=network,
631
+ )
632
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
633
+ ds_model, optimizer, train_dataloader, lr_scheduler
634
+ )
635
+ training_model = ds_model
636
+ else:
637
+ if train_unet:
638
+ # default implementation is: unet = accelerator.prepare(unet)
639
+ unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here
640
+ else:
641
+ unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
642
+ if train_text_encoder:
643
+ text_encoders = [
644
+ (accelerator.prepare(t_enc) if flag else t_enc)
645
+ for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
646
+ ]
647
+ if len(text_encoders) > 1:
648
+ text_encoder = text_encoders
649
+ else:
650
+ text_encoder = text_encoders[0]
651
+ else:
652
+ pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
653
+
654
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
655
+ network, optimizer, train_dataloader, lr_scheduler
656
+ )
657
+ training_model = network
658
+
659
+ if args.gradient_checkpointing:
660
+ # according to TI example in Diffusers, train is required
661
+ unet.train()
662
+ for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
663
+ t_enc.train()
664
+
665
+ # set requires_grad = True on the top parameter so that gradient checkpointing works
666
+ if frag:
667
+ self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
668
+
669
+ else:
670
+ unet.eval()
671
+ for t_enc in text_encoders:
672
+ t_enc.eval()
673
+
674
+ del t_enc
675
+
676
+ accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
677
+
678
+ if not cache_latents: # the VAE is needed when latents are not cached, so prepare it
679
+ vae.requires_grad_(False)
680
+ vae.eval()
681
+ vae.to(accelerator.device, dtype=vae_dtype)
682
+
683
+ # experimental feature: full fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
684
+ if args.full_fp16:
685
+ train_util.patch_accelerator_for_fp16_training(accelerator)
686
+
687
+ # before resuming make hook for saving/loading to save/load the network weights only
688
+ def save_model_hook(models, weights, output_dir):
689
+ # pop weights of other models than network to save only network weights
690
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
691
+ if accelerator.is_main_process or args.deepspeed:
692
+ remove_indices = []
693
+ for i, model in enumerate(models):
694
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
695
+ remove_indices.append(i)
696
+ for i in reversed(remove_indices):
697
+ if len(weights) > i:
698
+ weights.pop(i)
699
+ # print(f"save model hook: {len(weights)} weights will be saved")
700
+
701
+ # save current epoch and step
702
+ train_state_file = os.path.join(output_dir, "train_state.json")
703
+ # +1 is needed because the state is saved before current_step is set from global_step
704
+ logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
705
+ with open(train_state_file, "w", encoding="utf-8") as f:
706
+ json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
707
+
708
+ steps_from_state = None
709
+
710
+ def load_model_hook(models, input_dir):
711
+ # remove models except network
712
+ remove_indices = []
713
+ for i, model in enumerate(models):
714
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
715
+ remove_indices.append(i)
716
+ for i in reversed(remove_indices):
717
+ models.pop(i)
718
+ # print(f"load model hook: {len(models)} models will be loaded")
719
+
720
+ # load current epoch and step from the saved state
721
+ nonlocal steps_from_state
722
+ train_state_file = os.path.join(input_dir, "train_state.json")
723
+ if os.path.exists(train_state_file):
724
+ with open(train_state_file, "r", encoding="utf-8") as f:
725
+ data = json.load(f)
726
+ steps_from_state = data["current_step"]
727
+ logger.info(f"load train state from {train_state_file}: {data}")
728
+
729
+ accelerator.register_save_state_pre_hook(save_model_hook)
730
+ accelerator.register_load_state_pre_hook(load_model_hook)
731
+
732
+ # resume from a saved state if specified
733
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
734
+
735
+ # calculate the number of epochs
736
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
737
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
738
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
739
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
740
+
741
+ # start training
742
+ # TODO: find a way to handle total batch size when there are multiple datasets
743
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
744
+
745
+ accelerator.print("开始训练 / Training started")
746
+ accelerator.print(
747
+ f" 训练图片数 * 重复次数 / Number of training images * repeats: {train_dataset_group.num_train_images}")
748
+ accelerator.print(f" 正则化图片数 / Number of regularization images: {train_dataset_group.num_reg_images}")
749
+ accelerator.print(f" 每个 epoch 的批次数 / Number of batches per epoch: {len(train_dataloader)}")
750
+ accelerator.print(f" 训练的 epoch 数 / Number of epochs: {num_train_epochs}")
751
+ accelerator.print(
752
+ f" 每个设备的批次大小 / Batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
753
+ )
754
+ # accelerator.print(f" 总批次大小(包括并行和分布式训练及梯度累积)/ Total batch size (with parallel & distributed & accumulation): {total_batch_size}")
755
+ accelerator.print(f" 梯度累积步数 / Gradient accumulation steps: {args.gradient_accumulation_steps}")
756
+ accelerator.print(f" 总优化步骤数 / Total optimization steps: {args.max_train_steps}")
757
+
758
+ # TODO refactor metadata creation and move to util
759
+ metadata = {
760
+ "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
761
+ "ss_training_started_at": training_started_at, # unix timestamp
762
+ "ss_output_name": args.output_name,
763
+ "ss_learning_rate": args.learning_rate,
764
+ "ss_text_encoder_lr": text_encoder_lr,
765
+ "ss_unet_lr": args.unet_lr,
766
+ "ss_num_train_images": train_dataset_group.num_train_images,
767
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
768
+ "ss_num_batches_per_epoch": len(train_dataloader),
769
+ "ss_num_epochs": num_train_epochs,
770
+ "ss_gradient_checkpointing": args.gradient_checkpointing,
771
+ "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
772
+ "ss_max_train_steps": args.max_train_steps,
773
+ "ss_lr_warmup_steps": args.lr_warmup_steps,
774
+ "ss_lr_scheduler": args.lr_scheduler,
775
+ "ss_network_module": args.network_module,
776
+ "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
777
+ "ss_network_alpha": args.network_alpha, # some networks may not have alpha
778
+ "ss_network_dropout": args.network_dropout, # some networks may not have dropout
779
+ "ss_mixed_precision": args.mixed_precision,
780
+ "ss_full_fp16": bool(args.full_fp16),
781
+ "ss_v2": bool(args.v2),
782
+ "ss_base_model_version": model_version,
783
+ "ss_clip_skip": args.clip_skip,
784
+ "ss_max_token_length": args.max_token_length,
785
+ "ss_cache_latents": bool(args.cache_latents),
786
+ "ss_seed": args.seed,
787
+ "ss_lowram": args.lowram,
788
+ "ss_noise_offset": args.noise_offset,
789
+ "ss_multires_noise_iterations": args.multires_noise_iterations,
790
+ "ss_multires_noise_discount": args.multires_noise_discount,
791
+ "ss_adaptive_noise_scale": args.adaptive_noise_scale,
792
+ "ss_zero_terminal_snr": args.zero_terminal_snr,
793
+ "ss_training_comment": args.training_comment, # will not be updated after training
794
+ "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
795
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
796
+ "ss_max_grad_norm": args.max_grad_norm,
797
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
798
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
799
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
800
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
801
+ "ss_prior_loss_weight": args.prior_loss_weight,
802
+ "ss_min_snr_gamma": args.min_snr_gamma,
803
+ "ss_scale_weight_norms": args.scale_weight_norms,
804
+ "ss_ip_noise_gamma": args.ip_noise_gamma,
805
+ "ss_debiased_estimation": bool(args.debiased_estimation_loss),
806
+ "ss_noise_offset_random_strength": args.noise_offset_random_strength,
807
+ "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
808
+ "ss_loss_type": args.loss_type,
809
+ "ss_huber_schedule": args.huber_schedule,
810
+ "ss_huber_c": args.huber_c,
811
+ "ss_fp8_base": bool(args.fp8_base),
812
+ "ss_fp8_base_unet": bool(args.fp8_base_unet),
813
+ }
814
+
815
+ self.update_metadata(metadata, args) # architecture specific metadata
816
+
817
+ if use_user_config:
818
+ # save metadata of multiple datasets
819
+ # NOTE: pack "ss_datasets" value as json one time
820
+ # or should also pack nested collections as json?
821
+ datasets_metadata = []
822
+ tag_frequency = {} # merge tag frequency for metadata editor
823
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
824
+
825
+ for dataset in train_dataset_group.datasets:
826
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
827
+ dataset_metadata = {
828
+ "is_dreambooth": is_dreambooth_dataset,
829
+ "batch_size_per_device": dataset.batch_size,
830
+ "num_train_images": dataset.num_train_images, # includes repeating
831
+ "num_reg_images": dataset.num_reg_images,
832
+ "resolution": (dataset.width, dataset.height),
833
+ "enable_bucket": bool(dataset.enable_bucket),
834
+ "min_bucket_reso": dataset.min_bucket_reso,
835
+ "max_bucket_reso": dataset.max_bucket_reso,
836
+ "tag_frequency": dataset.tag_frequency,
837
+ "bucket_info": dataset.bucket_info,
838
+ }
839
+
840
+ subsets_metadata = []
841
+ for subset in dataset.subsets:
842
+ subset_metadata = {
843
+ "img_count": subset.img_count,
844
+ "num_repeats": subset.num_repeats,
845
+ "color_aug": bool(subset.color_aug),
846
+ "flip_aug": bool(subset.flip_aug),
847
+ "random_crop": bool(subset.random_crop),
848
+ "shuffle_caption": bool(subset.shuffle_caption),
849
+ "keep_tokens": subset.keep_tokens,
850
+ "keep_tokens_separator": subset.keep_tokens_separator,
851
+ "secondary_separator": subset.secondary_separator,
852
+ "enable_wildcard": bool(subset.enable_wildcard),
853
+ "caption_prefix": subset.caption_prefix,
854
+ "caption_suffix": subset.caption_suffix,
855
+ }
856
+
857
+ image_dir_or_metadata_file = None
858
+ if subset.image_dir:
859
+ image_dir = os.path.basename(subset.image_dir)
860
+ subset_metadata["image_dir"] = image_dir
861
+ image_dir_or_metadata_file = image_dir
862
+
863
+ if is_dreambooth_dataset:
864
+ subset_metadata["class_tokens"] = subset.class_tokens
865
+ subset_metadata["is_reg"] = subset.is_reg
866
+ if subset.is_reg:
867
+ image_dir_or_metadata_file = None # not merging reg dataset
868
+ else:
869
+ metadata_file = os.path.basename(subset.metadata_file)
870
+ subset_metadata["metadata_file"] = metadata_file
871
+ image_dir_or_metadata_file = metadata_file # may overwrite
872
+
873
+ subsets_metadata.append(subset_metadata)
874
+
875
+ # merge dataset dir: not reg subset only
876
+ # TODO update additional-network extension to show detailed dataset config from metadata
877
+ if image_dir_or_metadata_file is not None:
878
+ # datasets may have a certain dir multiple times
879
+ v = image_dir_or_metadata_file
880
+ i = 2
881
+ while v in dataset_dirs_info:
882
+ v = image_dir_or_metadata_file + f" ({i})"
883
+ i += 1
884
+ image_dir_or_metadata_file = v
885
+
886
+ dataset_dirs_info[image_dir_or_metadata_file] = {
887
+ "n_repeats": subset.num_repeats,
888
+ "img_count": subset.img_count,
889
+ }
890
+
891
+ dataset_metadata["subsets"] = subsets_metadata
892
+ datasets_metadata.append(dataset_metadata)
893
+
894
+ # merge tag frequency:
895
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
896
+ # if a directory is used by multiple datasets, count its tags only once
897
+ # because num_repeats is specified, the number of times a tag appears in captions does not match how often it is used during training,
898
+ # so summing the counts across multiple datasets here would not be meaningful anyway
899
+ if ds_dir_name in tag_frequency:
900
+ continue
901
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
902
+
903
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
904
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
905
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
906
+ else:
907
+ # preserving backward compatibility when using train_dataset_dir and reg_dataset_dir
908
+ assert (
909
+ len(train_dataset_group.datasets) == 1
910
+ ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
911
+
912
+ dataset = train_dataset_group.datasets[0]
913
+
914
+ dataset_dirs_info = {}
915
+ reg_dataset_dirs_info = {}
916
+ if use_dreambooth_method:
917
+ for subset in dataset.subsets:
918
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
919
+ info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
920
+ else:
921
+ for subset in dataset.subsets:
922
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
923
+ "n_repeats": subset.num_repeats,
924
+ "img_count": subset.img_count,
925
+ }
926
+
927
+ metadata.update(
928
+ {
929
+ "ss_batch_size_per_device": args.train_batch_size,
930
+ "ss_total_batch_size": total_batch_size,
931
+ "ss_resolution": args.resolution,
932
+ "ss_color_aug": bool(args.color_aug),
933
+ "ss_flip_aug": bool(args.flip_aug),
934
+ "ss_random_crop": bool(args.random_crop),
935
+ "ss_shuffle_caption": bool(args.shuffle_caption),
936
+ "ss_enable_bucket": bool(dataset.enable_bucket),
937
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
938
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
939
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
940
+ "ss_keep_tokens": args.keep_tokens,
941
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
942
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
943
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
944
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
945
+ }
946
+ )
947
+
948
+ # add extra args
949
+ if args.network_args:
950
+ metadata["ss_network_args"] = json.dumps(net_kwargs)
951
+
952
+ # model name and hash
953
+ if args.pretrained_model_name_or_path is not None:
954
+ sd_model_name = args.pretrained_model_name_or_path
955
+ if os.path.exists(sd_model_name):
956
+ metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
957
+ metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
958
+ sd_model_name = os.path.basename(sd_model_name)
959
+ metadata["ss_sd_model_name"] = sd_model_name
960
+
961
+ if args.vae is not None:
962
+ vae_name = args.vae
963
+ if os.path.exists(vae_name):
964
+ metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
965
+ metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
966
+ vae_name = os.path.basename(vae_name)
967
+ metadata["ss_vae_name"] = vae_name
968
+
969
+ metadata = {k: str(v) for k, v in metadata.items()}
970
+
971
+ # make minimum metadata for filtering
972
+ minimum_metadata = {}
973
+ for key in train_util.SS_METADATA_MINIMUM_KEYS:
974
+ if key in metadata:
975
+ minimum_metadata[key] = metadata[key]
976
+
977
+ # calculate steps to skip when resuming or starting from a specific step
978
+ initial_step = 0
979
+ if args.initial_epoch is not None or args.initial_step is not None:
980
+ # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
981
+ if steps_from_state is not None:
982
+ logger.warning(
983
+ "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
984
+ )
985
+ if args.initial_step is not None:
986
+ initial_step = args.initial_step
987
+ else:
988
+ # num steps per epoch is calculated by num_processes and gradient_accumulation_steps
989
+ initial_step = (args.initial_epoch - 1) * math.ceil(
990
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
991
+ )
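+ # worked example (hypothetical numbers): with len(train_dataloader)=1000, num_processes=2,
+ # gradient_accumulation_steps=4 and initial_epoch=3, one epoch is ceil(1000 / 2 / 4) = 125 optimizer steps,
+ # so initial_step = (3 - 1) * 125 = 250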
992
+ else:
993
+ # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
994
+ if steps_from_state is not None:
995
+ initial_step = steps_from_state
996
+ steps_from_state = None
997
+
998
+ if initial_step > 0:
999
+ assert (
1000
+ args.max_train_steps > initial_step
1001
+ ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
1002
+
1003
+ progress_bar = tqdm(
1004
+ range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
1005
+ )
1006
+
1007
+ epoch_to_start = 0
1008
+ if initial_step > 0:
1009
+ if args.skip_until_initial_step:
1010
+ # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
1011
+ if not args.resume:
1012
+ logger.info(
1013
+ f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
1014
+ )
1015
+ logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
1016
+ initial_step *= args.gradient_accumulation_steps
1017
+
1018
+ # set epoch to start to make initial_step less than len(train_dataloader)
1019
+ epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1020
+ else:
1021
+ # otherwise, only the epoch number is skipped, for informational purposes
1022
+ epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1023
+ initial_step = 0 # do not skip
1024
+
1025
+ global_step = 0
1026
+
1027
+ noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
1028
+
1029
+ if accelerator.is_main_process:
1030
+ init_kwargs = {}
1031
+ if args.wandb_run_name:
1032
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
1033
+ if args.log_tracker_config is not None:
1034
+ init_kwargs = toml.load(args.log_tracker_config)
1035
+ accelerator.init_trackers(
1036
+ "network_train" if args.log_tracker_name is None else args.log_tracker_name,
1037
+ config=train_util.get_sanitized_config_or_none(args),
1038
+ init_kwargs=init_kwargs,
1039
+ )
1040
+
1041
+ loss_recorder = train_util.LossRecorder()
1042
+ del train_dataset_group
1043
+
1044
+ # callback for step start
1045
+ if hasattr(accelerator.unwrap_model(network), "on_step_start"):
1046
+ on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
1047
+ else:
1048
+ on_step_start_for_network = lambda *args, **kwargs: None
1049
+
1050
+ # function for saving/removing
1051
+ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
1052
+ os.makedirs(args.output_dir, exist_ok=True)
1053
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
1054
+
1055
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
1056
+ metadata["ss_training_finished_at"] = str(time.time())
1057
+ metadata["ss_steps"] = str(steps)
1058
+ metadata["ss_epoch"] = str(epoch_no)
1059
+
1060
+ metadata_to_save = minimum_metadata if args.no_metadata else metadata
1061
+ sai_metadata = self.get_sai_model_spec(args)
1062
+ metadata_to_save.update(sai_metadata)
1063
+
1064
+ unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
1065
+ if args.huggingface_repo_id is not None:
1066
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
1067
+
1068
+ def remove_model(old_ckpt_name):
1069
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
1070
+ if os.path.exists(old_ckpt_file):
1071
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
1072
+ os.remove(old_ckpt_file)
1073
+
1074
+ # if text_encoder is not needed for training, delete it to save memory.
1075
+ # TODO this can be automated after SDXL sample prompt cache is implemented
1076
+ if self.is_text_encoder_not_needed_for_training(args):
1077
+ logger.info("text_encoder is not needed for training. deleting to save memory.")
1078
+ for t_enc in text_encoders:
1079
+ del t_enc
1080
+ text_encoders = []
1081
+ text_encoder = None
1082
+
1083
+ # For --sample_at_first
1084
+ optimizer_eval_fn()
1085
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
1086
+ optimizer_train_fn()
1087
+ if len(accelerator.trackers) > 0:
1088
+ # log empty object to commit the sample images to wandb
1089
+ accelerator.log({}, step=0)
1090
+
1091
+ # training loop
1092
+ if initial_step > 0: # only if skip_until_initial_step is specified
1093
+ for skip_epoch in range(epoch_to_start): # skip epochs
1094
+ logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
1095
+ initial_step -= len(train_dataloader)
1096
+ global_step = initial_step
1097
+
1098
+ # log device and dtype for each model
1099
+ logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
1100
+ for i, t_enc in enumerate(text_encoders):
1101
+ params_itr = t_enc.parameters()
1102
+ params_itr.__next__() # skip the first parameter
1103
+ params_itr.__next__() # skip the second parameter as well, because the first two parameters of CLIP are embeddings
1104
+ param_3rd = params_itr.__next__()
1105
+ logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
1106
+
1107
+ clean_memory_on_device(accelerator.device)
1108
+
1109
+ for epoch in range(epoch_to_start, num_train_epochs):
1110
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
1111
+ current_epoch.value = epoch + 1
1112
+
1113
+ metadata["ss_epoch"] = str(epoch + 1)
1114
+
1115
+ accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
1116
+
1117
+ skipped_dataloader = None
1118
+ if initial_step > 0:
1119
+ skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
1120
+ initial_step = 1
1121
+
1122
+ for step, batch in enumerate(skipped_dataloader or train_dataloader):
1123
+ current_step.value = global_step
1124
+ if initial_step > 0:
1125
+ initial_step -= 1
1126
+ continue
1127
+
1128
+ with accelerator.accumulate(training_model):
1129
+ on_step_start_for_network(text_encoder, unet)
1130
+
1131
+ # temporary, for batch processing
1132
+ self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
1133
+
1134
+ if "latents" in batch and batch["latents"] is not None:
1135
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
1136
+ else:
1137
+ with torch.no_grad():
1138
+ # convert images to latents
1139
+ latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
1140
+ latents = latents.to(dtype=weight_dtype)
1141
+
1142
+ # if NaN is found, show a warning and replace it with zeros
1143
+ if torch.any(torch.isnan(latents)):
1144
+ accelerator.print("NaN found in latents, replacing with zeros")
1145
+ latents = torch.nan_to_num(latents, 0, out=latents)
1146
+
1147
+ latents = self.shift_scale_latents(args, latents)
1148
+
1149
+ # get multiplier for each sample
1150
+ if network_has_multiplier:
1151
+ multipliers = batch["network_multipliers"]
1152
+ # if all multipliers are the same, use a single multiplier
1153
+ if torch.all(multipliers == multipliers[0]):
1154
+ multipliers = multipliers[0].item()
1155
+ else:
1156
+ raise NotImplementedError("multipliers for each sample is not supported yet")
1157
+ # print(f"set multiplier: {multipliers}")
1158
+ accelerator.unwrap_model(network).set_multiplier(multipliers)
1159
+
1160
+ text_encoder_conds = []
1161
+ text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
1162
+ if text_encoder_outputs_list is not None:
1163
+ text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
1164
+
1165
+ if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
1166
+ # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
1167
+ with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
1168
+ # Get the text embedding for conditioning
1169
+ if args.weighted_captions:
1170
+ input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
1171
+ encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
1172
+ tokenize_strategy,
1173
+ self.get_models_for_text_encoding(args, accelerator, text_encoders),
1174
+ input_ids_list,
1175
+ weights_list,
1176
+ )
1177
+ else:
1178
+ input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
1179
+ encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
1180
+ tokenize_strategy,
1181
+ self.get_models_for_text_encoding(args, accelerator, text_encoders),
1182
+ input_ids,
1183
+ )
1184
+ if args.full_fp16:
1185
+ encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
1186
+
1187
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
1188
+ if len(text_encoder_conds) == 0:
1189
+ text_encoder_conds = encoded_text_encoder_conds
1190
+ else:
1191
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
1192
+ for i in range(len(encoded_text_encoder_conds)):
1193
+ if encoded_text_encoder_conds[i] is not None:
1194
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
1195
+
1196
+ # sample noise, call unet, get target
1197
+ noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
1198
+ args,
1199
+ accelerator,
1200
+ noise_scheduler,
1201
+ latents,
1202
+ batch, # the batch also carries the text (caption) information
1203
+ text_encoder_conds,
1204
+ unet,
1205
+ network,
1206
+ weight_dtype,
1207
+ train_unet,
1208
+ )
1209
+
1210
+ loss = train_util.conditional_loss(
1211
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
1212
+ )
1213
+ if weighting is not None:
1214
+ loss = loss * weighting
1215
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
1216
+ loss = apply_masked_loss(loss, batch)
1217
+ loss = loss.mean([1, 2, 3])
1218
+
1219
+ loss_weights = batch["loss_weights"] # weight for each sample
1220
+ loss = loss * loss_weights
1221
+
1222
+ # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
1223
+ loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
1224
+
1225
+ loss = loss.mean() # already an average, so there is no need to divide by batch_size
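+ # note on shapes: conditional_loss(..., reduction="none") keeps a per-element tensor (e.g. [B, C, H, W] for latents);
+ # .mean([1, 2, 3]) gives one value per sample ([B]), loss_weights then scales each sample, and the final .mean() yields a scalar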
1226
+
1227
+ accelerator.backward(loss)
1228
+ if accelerator.sync_gradients:
1229
+ self.all_reduce_network(accelerator, network) # sync DDP grad manually
1230
+ if args.max_grad_norm != 0.0:
1231
+ params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
1232
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1233
+
1234
+ optimizer.step()
1235
+ lr_scheduler.step()
1236
+ optimizer.zero_grad(set_to_none=True)
1237
+
1238
+ if args.scale_weight_norms:
1239
+ keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
1240
+ args.scale_weight_norms, accelerator.device
1241
+ )
1242
+ max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
1243
+ else:
1244
+ keys_scaled, mean_norm, maximum_norm = None, None, None
1245
+
1246
+ # Checks if the accelerator has performed an optimization step behind the scenes
1247
+ if accelerator.sync_gradients:
1248
+ progress_bar.update(1)
1249
+ global_step += 1
1250
+
1251
+ optimizer_eval_fn()
1252
+ self.sample_images(
1253
+ accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
1254
+ )
1255
+
1256
+ # save the model every specified number of steps
1257
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
1258
+ accelerator.wait_for_everyone()
1259
+ if accelerator.is_main_process:
1260
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
1261
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
1262
+
1263
+ if args.save_state:
1264
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
1265
+
1266
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
1267
+ if remove_step_no is not None:
1268
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
1269
+ remove_model(remove_ckpt_name)
1270
+ optimizer_train_fn()
1271
+
1272
+ current_loss = loss.detach().item()
1273
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
1274
+ avr_loss: float = loss_recorder.moving_average
1275
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
1276
+ progress_bar.set_postfix(**logs)
1277
+
1278
+ if args.scale_weight_norms:
1279
+ progress_bar.set_postfix(**{**max_mean_logs, **logs})
1280
+
1281
+ if len(accelerator.trackers) > 0:
1282
+ logs = self.generate_step_logs(
1283
+ args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
1284
+ )
1285
+ accelerator.log(logs, step=global_step)
1286
+
1287
+ if global_step >= args.max_train_steps:
1288
+ break
1289
+
1290
+ if len(accelerator.trackers) > 0:
1291
+ logs = {"loss/epoch": loss_recorder.moving_average}
1292
+ accelerator.log(logs, step=epoch + 1)
1293
+
1294
+ accelerator.wait_for_everyone()
1295
+
1296
+ # save the model every specified number of epochs
1297
+ optimizer_eval_fn()
1298
+ if args.save_every_n_epochs is not None:
1299
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
1300
+ if is_main_process and saving:
1301
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
1302
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
1303
+
1304
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
1305
+ if remove_epoch_no is not None:
1306
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
1307
+ remove_model(remove_ckpt_name)
1308
+
1309
+ if args.save_state:
1310
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
1311
+
1312
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
1313
+ optimizer_train_fn()
1314
+
1315
+ # end of epoch
1316
+
1317
+ # metadata["ss_epoch"] = str(num_train_epochs)
1318
+ metadata["ss_training_finished_at"] = str(time.time())
1319
+
1320
+ if is_main_process:
1321
+ network = accelerator.unwrap_model(network)
1322
+
1323
+ accelerator.end_training()
1324
+ optimizer_eval_fn()
1325
+
1326
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
1327
+ train_util.save_state_on_train_end(args, accelerator)
1328
+
1329
+ if is_main_process:
1330
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
1331
+ save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
1332
+
1333
+ logger.info("model saved.")
1334
+
1335
+
1336
+ def setup_parser() -> argparse.ArgumentParser:
1337
+ parser = argparse.ArgumentParser()
1338
+
1339
+ add_logging_arguments(parser)
1340
+ train_util.add_sd_models_arguments(parser)
1341
+ train_util.add_dataset_arguments(parser, True, True, True)
1342
+ train_util.add_training_arguments(parser, True)
1343
+ train_util.add_masked_loss_arguments(parser)
1344
+ deepspeed_utils.add_deepspeed_arguments(parser)
1345
+ train_util.add_optimizer_arguments(parser)
1346
+ config_util.add_config_arguments(parser)
1347
+ custom_train_functions.add_custom_train_arguments(parser)
1348
+
1349
+ parser.add_argument(
1350
+ "--cpu_offload_checkpointing",
1351
+ action="store_true",
1352
+ help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
1353
+ " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
1354
+ )
1355
+ parser.add_argument(
1356
+ "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
1357
+ )
1358
+ parser.add_argument(
1359
+ "--save_model_as",
1360
+ type=str,
1361
+ default="safetensors",
1362
+ choices=[None, "ckpt", "pt", "safetensors"],
1363
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
1364
+ )
1365
+
1366
+ parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
1367
+ parser.add_argument(
1368
+ "--text_encoder_lr",
1369
+ type=float,
1370
+ default=None,
1371
+ nargs="*",
1372
+ help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
1373
+ )
1374
+ parser.add_argument(
1375
+ "--fp8_base_unet",
1376
+ action="store_true",
1377
+ help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
1378
+ " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
1379
+ )
1380
+
1381
+ parser.add_argument(
1382
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
1383
+ )
1384
+ parser.add_argument(
1385
+ "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
1386
+ )
1387
+ parser.add_argument(
1388
+ "--network_dim",
1389
+ type=int,
1390
+ default=None,
1391
+ help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
1392
+ )
1393
+ parser.add_argument(
1394
+ "--network_alpha",
1395
+ type=float,
1396
+ default=1,
1397
+ help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
1398
+ )
1399
+ parser.add_argument(
1400
+ "--network_dropout",
1401
+ type=float,
1402
+ default=None,
1403
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
1404
+ )
1405
+ parser.add_argument(
1406
+ "--network_args",
1407
+ type=str,
1408
+ default=None,
1409
+ nargs="*",
1410
+ help="additional arguments for network (key=value) / ネットワークへの追加の引数",
1411
+ )
1412
+ parser.add_argument(
1413
+ "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
1414
+ )
1415
+ parser.add_argument(
1416
+ "--network_train_text_encoder_only",
1417
+ action="store_true",
1418
+ help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
1419
+ )
1420
+ parser.add_argument(
1421
+ "--training_comment",
1422
+ type=str,
1423
+ default=None,
1424
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
1425
+ )
1426
+ parser.add_argument(
1427
+ "--dim_from_weights",
1428
+ action="store_true",
1429
+ help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
1430
+ )
1431
+ parser.add_argument(
1432
+ "--scale_weight_norms",
1433
+ type=float,
1434
+ default=None,
1435
+ help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
1436
+ )
1437
+ parser.add_argument(
1438
+ "--base_weights",
1439
+ type=str,
1440
+ default=None,
1441
+ nargs="*",
1442
+ help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
1443
+ )
1444
+ parser.add_argument(
1445
+ "--base_weights_multiplier",
1446
+ type=float,
1447
+ default=None,
1448
+ nargs="*",
1449
+ help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
1450
+ )
1451
+ parser.add_argument(
1452
+ "--no_half_vae",
1453
+ action="store_true",
1454
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
1455
+ )
1456
+ parser.add_argument(
1457
+ "--skip_until_initial_step",
1458
+ action="store_true",
1459
+ help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
1460
+ )
1461
+ parser.add_argument(
1462
+ "--initial_epoch",
1463
+ type=int,
1464
+ default=None,
1465
+ help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
1466
+ + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
1467
+ )
1468
+ parser.add_argument(
1469
+ "--initial_step",
1470
+ type=int,
1471
+ default=None,
1472
+ help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
1473
+ + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
1474
+ )
1475
+ parser.add_argument(
1476
+ "--lora_ups_num",
1477
+ type=int,
1478
+ required=True, # this argument must be provided
1479
+ help="初始化lora的下游矩阵个数"
1480
+ )
1481
+ return parser
1482
+
1483
+
1484
+ if __name__ == "__main__":
1485
+ parser = setup_parser()
1486
+
1487
+ args = parser.parse_args()
1488
+ train_util.verify_command_line_training_args(args)
1489
+ args = train_util.read_config_from_file(args, parser)
1490
+
1491
+ trainer = NetworkTrainer()
1492
+ trainer.train(args)