Upload 17 files
- .gitattributes +11 -0
- LICENSE +21 -0
- README.md +239 -14
- flux_inference_recraft.py +442 -0
- flux_minimal_inference.py +576 -0
- flux_minimal_inference_asylora.py +583 -0
- flux_train_network.py +588 -0
- flux_train_network_asylora.py +591 -0
- flux_train_recraft.py +713 -0
- gradio_app.py +233 -0
- id_rsa +50 -0
- requirements.txt +47 -6
- setup.py +3 -0
- split_asylora.py +37 -0
- train_network.py +1479 -0
- train_network_asylora.py +1492 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+id_rsa
+__pycache__
+*.egg-info
+.vscode
+wandb
+Merge
+asy_results
+recraft_results
+drop
+SplitAsy
+example*
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Show Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,14 +1,239 @@
# MakeAnything

> **MakeAnything: Harnessing Diffusion Transformers for Multi-Domain Procedural Sequence Generation**
> <br>
> [Yiren Song](https://scholar.google.com.hk/citations?user=L2YS0jgAAAAJ),
> [Cheng Liu](https://scholar.google.com.hk/citations?hl=zh-CN&user=TvdVuAYAAAAJ),
> and
> [Mike Zheng Shou](https://sites.google.com/view/showlab)
> <br>
> [Show Lab](https://sites.google.com/view/showlab), National University of Singapore
> <br>

<a href="https://arxiv.org/abs/2502.01572"><img src="https://img.shields.io/badge/arXiv-2502.01572-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/showlab/makeanything"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://huggingface.co/datasets/showlab/makeanything/"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>

<br>

<img src='./images/teaser.png' width='100%' />


## Configuration
### 1. **Environment setup**
```bash
git clone https://github.com/showlab/MakeAnything.git
cd MakeAnything

conda create -n makeanything python=3.11.10
conda activate makeanything
```
### 2. **Requirements installation**
```bash
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
pip install --upgrade -r requirements.txt

accelerate config
```

## Asymmetric LoRA
### 1. Weights
You can download the trained checkpoints of the Asymmetric LoRA & LoRA for inference. Below are the details of the available models:

| **Model** | **Description** | **Resolution** |
|:-:|:-:|:-:|
| [asylora_9f_general](https://huggingface.co/showlab/makeanything/blob/main/asymmetric_lora/asymmetric_lora_9f_general.safetensors) | Asymmetric LoRA fine-tuned on all 9-frame datasets. *Index of lora_up*: `1:LEGO` `2:Cook` `3:Painting` `4:Icon` `5:Landscape illustration` `6:Portrait` `7:Transformer` `8:Sand art` `9:Illustration` `10:Sketch` | 1056,1056 |
| [asylora_4f_general](https://huggingface.co/showlab/makeanything/blob/main/asymmetric_lora/asymmetric_lora_4f_general.safetensors) | Asymmetric LoRA fine-tuned on all 4-frame datasets. *Index of lora_up: (1~10 same as 9f)* `11:Clay toys` `12:Clay sculpture` `13:Zbrush Modeling` `14:Wood sculpture` `15:Ink painting` `16:Pencil sketch` `17:Fabric toys` `18:Oil painting` `19:Jade Carving` `20:Line draw` `21:Emoji` | 1024,1024 |

### 2. Training
<span id="dataset_setting"></span>
#### 2.1 Settings for dataset
The training process relies on a paired dataset consisting of text captions and images. Each dataset folder contains both `.caption` and `.png` files, where the filenames of the caption files correspond directly to the image filenames. Here is an example of the organized dataset.

```
dataset/
├── portrait_001.png
├── portrait_001.caption
├── portrait_002.png
├── portrait_002.caption
├── lego_001.png
├── lego_001.caption
```

The `.caption` files contain a **single line** of text that serves as the prompt for generating the corresponding image. The prompt **must specify the index of the lora_up** used for that particular training sample in the Asymmetric LoRA. The format is `--lora_up <index>`, where `<index>` is the index of the B matrix in the Asymmetric LoRA, i.e. the domain used for that training sample; the index **starts from 1**, not 0.

For example, a `.caption` file for a portrait painting sequence might look as follows:

```caption
3*3 of 9 sub-images, step-by-step portrait painting process, 1 girl --lora_up 6
```

Then, you should organize your **dataset configuration file** written in `TOML`. Here is an example:

```toml
[general]
enable_bucket = false

[[datasets]]
resolution = 1056
batch_size = 1

[[datasets.subsets]]
image_dir = '/path/to/dataset/'
caption_extension = '.caption'
num_repeats = 1
```

It is recommended to set the batch size to 1 and the resolution to 1024 (4-frame) or 1056 (9-frame).
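Before launching training, it can save time to sanity-check that every image has a matching caption and that each composite matches the expected resolution. The snippet below is an optional helper sketch, not part of this repository; the directory path and expected size are assumptions you should adapt.

```python
from pathlib import Path
from PIL import Image

def check_dataset(dataset_dir: str, expected_size: int) -> None:
    """Verify .png/.caption pairing and that each composite image has the expected resolution."""
    for png in sorted(Path(dataset_dir).glob("*.png")):
        caption = png.with_suffix(".caption")
        if not caption.exists():
            print(f"missing caption for {png.name}")
            continue
        w, h = Image.open(png).size
        if (w, h) != (expected_size, expected_size):
            print(f"{png.name}: expected {expected_size}x{expected_size}, got {w}x{h}")

# 1056 for 9-frame composites, 1024 for 4-frame composites
check_dataset("/path/to/dataset/", 1056)
```
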
#### 2.2 Start training
We have provided a template file for training the Asymmetric LoRA in `scripts/asylora_train.sh`. Simply replace the corresponding paths with yours to start the training. Note that `lora_ups_num` in the script is the total number of B matrices used in the Asymmetric LoRA that you specified during training.

```bash
chmod +x scripts/asylora_train.sh
scripts/asylora_train.sh
```

Additionally, if you are directly **using our dataset for training**, note that the `.caption` files in our released dataset do not specify the `--lora_up <index>` field. You will need to organize and update the `.caption` files to include the appropriate `--lora_up <index>` values before starting the training.
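If you need to batch-add the `--lora_up <index>` field, a small script can append it to every caption file. The sketch below is a minimal example and not part of the repository; the domain-to-index mapping and the filename convention (`<domain>_<number>.caption`) are assumptions you should adjust to your data.

```python
import argparse
from pathlib import Path

# Hypothetical mapping from filename prefix to lora_up index; adjust to your domains.
DOMAIN_TO_INDEX = {"lego": 1, "cook": 2, "painting": 3, "portrait": 6}

def add_lora_up(dataset_dir: str) -> None:
    for caption_path in Path(dataset_dir).glob("*.caption"):
        text = caption_path.read_text(encoding="utf-8").strip()
        if "--lora_up" in text:
            continue  # already tagged
        domain = caption_path.stem.rsplit("_", 1)[0]  # e.g. "portrait_001" -> "portrait"
        index = DOMAIN_TO_INDEX.get(domain)
        if index is None:
            print(f"skip {caption_path.name}: unknown domain '{domain}'")
            continue
        caption_path.write_text(f"{text} --lora_up {index}\n", encoding="utf-8")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_dir")
    add_lora_up(parser.parse_args().dataset_dir)
```
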
### 3. Inference
We have also provided a template file for running inference with the Asymmetric LoRA in `scripts/asylora_inference.sh`. Once training is done, replace the file paths, fill in your prompt, and run inference. Note that `lora_up_cur` in the script is the index of the B matrix to be used for inference.

```bash
chmod +x scripts/asylora_inference.sh
scripts/asylora_inference.sh
```


## Recraft Model
### 1. Weights
You can download the trained checkpoints of the Recraft Model for inference. Below are the details of the available models:
| **Model** | **Description** | **Resolution** |
|:-:|:-:|:-:|
| [recraft_9f_lego](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_9f_lego.safetensors) | Recraft Model trained on the `LEGO` dataset. Supports `9-frame` generation. | 1056,1056 |
| [recraft_9f_portrait](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_9f_portrait.safetensors) | Recraft Model trained on the `Portrait` dataset. Supports `9-frame` generation. | 1056,1056 |
| [recraft_9f_sketch](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_9f_sketch.safetensors) | Recraft Model trained on the `Sketch` dataset. Supports `9-frame` generation. | 1056,1056 |
| [recraft_4f_wood_sculpture](https://huggingface.co/showlab/makeanything/blob/main/recraft/recraft_4f_wood_sculpture.safetensors) | Recraft Model trained on the `Wood sculpture` dataset. Supports `4-frame` generation. | 1024,1024 |

### 2. Training
#### 2.1 Obtain standard LoRA
During the second phase of training, image-to-sequence generation with the Recraft model, a **standard LoRA architecture** must be merged into flux.1 before performing the Recraft training. Therefore, the first step is to decompose the Asymmetric LoRA into the original LoRA format.

To achieve this, either **train a standard LoRA directly** (see the optional method below) or use the script template we provide in `scripts/asylora_split.sh` for **splitting the Asymmetric LoRA**. The script extracts the required B matrices from the Asymmetric LoRA model. Specifically, `LORA_UP` in the script specifies the index of the B matrices you wish to extract for use as the original LoRA.

```bash
chmod +x scripts/asylora_split.sh
scripts/asylora_split.sh
```
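Conceptually, the split selects one B (lora_up) matrix per layer out of the stack stored in the Asymmetric LoRA and saves it alongside the shared A (lora_down) matrix as a plain LoRA. The sketch below only illustrates that idea; the key layout (`lora_up.<i>`) is an assumption and does not necessarily match the repository's `split_asylora.py`.

```python
from safetensors.torch import load_file, save_file

def split_asylora(asylora_path: str, out_path: str, lora_up_index: int) -> None:
    """Keep the shared lora_down/alpha weights and only the selected lora_up branch."""
    state = load_file(asylora_path)
    out = {}
    for key, tensor in state.items():
        if "lora_up" in key:
            # Assumed key layout: "...lora_up.<index>.weight"; keep only the chosen branch.
            if f"lora_up.{lora_up_index}." in key:
                out[key.replace(f"lora_up.{lora_up_index}.", "lora_up.")] = tensor
        else:
            out[key] = tensor  # lora_down / alpha are shared across branches
    save_file(out, out_path)

# Example: extract branch 6 (Portrait) as a standard LoRA
# split_asylora("asymmetric_lora_9f_general.safetensors", "portrait_lora.safetensors", 6)
```
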
#### (Optional) Train standard LoRA
You can also **directly train a standard LoRA** for the Recraft process, eliminating the need to decompose the Asymmetric LoRA. In this project we have included the standard LoRA training code from [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts) in the files `flux_train_network.py` for training and `flux_minimal_inference.py` for inference. You can refer to the related documentation for guidance on how to train.

Alternatively, using other training platforms such as [kijai/ComfyUI-FluxTrainer](https://github.com/kijai/ComfyUI-FluxTrainer) is also a viable option. These platforms provide tools to facilitate the training and inference of LoRA models for the Recraft process.

#### 2.2 Merge LoRA to flux.1
Now that you have obtained a standard LoRA, use our `scripts/lora_merge.sh` template script to merge the LoRA into the flux.1 checkpoint for further Recraft training. Note that the merged model may take up **around 50GB** of disk space.

```bash
chmod +x scripts/lora_merge.sh
scripts/lora_merge.sh
```
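Conceptually, merging folds each LoRA delta into the corresponding base weight as W ← W + multiplier · (alpha / rank) · (up @ down). The snippet below is a generic illustration of that update, not the repository's merge script; tensor shapes and names are illustrative.

```python
import torch

def merge_lora_delta(base_weight: torch.Tensor,
                     lora_down: torch.Tensor,   # shape (rank, in_features)
                     lora_up: torch.Tensor,     # shape (out_features, rank)
                     alpha: float,
                     multiplier: float = 1.0) -> torch.Tensor:
    """Return the base weight with the LoRA update folded in."""
    rank = lora_down.shape[0]
    scale = multiplier * alpha / rank
    return base_weight + scale * (lora_up @ lora_down)

# Toy example with random tensors
w = torch.randn(8, 16)
down = torch.randn(4, 16)
up = torch.randn(8, 4)
merged = merge_lora_delta(w, down, up, alpha=4.0)
print(merged.shape)  # torch.Size([8, 16])
```
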
#### 2.3 Settings for training

The dataset structure for Recraft training follows the same organization format as the dataset for the Asymmetric LoRA, described in [Asymmetric LoRA 2.1 Settings for dataset](#dataset_setting). A `TOML` configuration file is also required to organize and configure the dataset. Below is a template for the dataset configuration file:

```toml
[general]
flip_aug = false
color_aug = false
keep_tokens_separator = "|||"
shuffle_caption = false
caption_tag_dropout_rate = 0
caption_extension = ".caption"

[[datasets]]
batch_size = 1
enable_bucket = true
resolution = [1024, 1024]

[[datasets.subsets]]
image_dir = "/path/to/dataset/"
num_repeats = 1
```

Note that for training with 4-frame step sequences, the resolution must be set to `1024`. For training with 9-frame steps, the resolution should be `1056`.

For the sampling phase of the Recraft training process, we need to organize two text files: `sample_images.txt` and `sample_prompts.txt`. These files store the sampled condition images and their corresponding prompts, respectively. Below are the templates for both files:

**sample_images.txt**
```txt
/path/to/image_1.png
/path/to/image_2.png
```

**sample_prompts.txt**
```txt
image_1_prompt_content
image_2_prompt_content
```
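If you have a folder of condition images, the two files can be generated together so that line i of `sample_prompts.txt` corresponds to line i of `sample_images.txt`. This is an optional convenience sketch, not part of the repository; the folder name and the placeholder prompt are assumptions.

```python
from pathlib import Path

def write_sample_files(image_dir: str, images_txt: str, prompts_txt: str, prompt: str) -> None:
    """Write parallel sample_images.txt / sample_prompts.txt files (one line per image)."""
    image_paths = sorted(Path(image_dir).glob("*.png"))
    with open(images_txt, "w", encoding="utf-8") as f_img, open(prompts_txt, "w", encoding="utf-8") as f_prm:
        for p in image_paths:
            f_img.write(f"{p.resolve()}\n")
            f_prm.write(f"{prompt}\n")

# Placeholder prompt; in practice each image can get its own prompt line.
write_sample_files("conditions/", "sample_images.txt", "sample_prompts.txt",
                   "3*3 of 9 sub-images, step-by-step portrait painting process")
```
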
#### 2.4 Recraft training
We have provided a template file for training the Recraft Model in `scripts/recraft_train.sh`. Simply replace the corresponding paths with yours to start the training. Note that `frame_num` in the script must be `4` (for 1024 resolution) or `9` (for 1056 resolution).

```bash
chmod +x scripts/recraft_train.sh
scripts/recraft_train.sh
```

### 3. Inference
We have also provided a template file for running inference with the Recraft Model in `scripts/recraft_inference.sh`. Once training is done, replace the file paths, fill in your prompt, and run inference.

```bash
chmod +x scripts/recraft_inference.sh
scripts/recraft_inference.sh
```

## Datasets

We have uploaded our datasets to [Hugging Face](https://huggingface.co/datasets/showlab/makeanything/). The datasets include both 4-frame and 9-frame sequence images, covering a total of 21 domains of procedural sequences. For MakeAnything training, each domain consists of **50 sequences**, with resolutions of either **1024 (4-frame)** or **1056 (9-frame)**. Additionally, we provide an extensive collection of SVG datasets and Sketch datasets for further research and experimentation.

Note that the arrangement of **9-frame sequences follows an S-shape pattern**, whereas **4-frame sequences follow a ɔ-shape pattern**.
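To recover individual steps from a 9-frame composite, you would crop the 3x3 grid and reorder the tiles to follow the S-shape reading order. The helper below is a sketch under the assumption that "S-shape" means boustrophedon order (left-to-right, then right-to-left, alternating by row); verify against the released data before relying on it.

```python
from PIL import Image

def split_9frame_sequence(path: str) -> list[Image.Image]:
    """Crop a 1056x1056 composite into 9 tiles and reorder them in S-shape (boustrophedon) order."""
    img = Image.open(path)
    tile = img.width // 3  # 352 for a 1056x1056 composite
    grid = [[img.crop((c * tile, r * tile, (c + 1) * tile, (r + 1) * tile)) for c in range(3)]
            for r in range(3)]
    frames = []
    for r, row in enumerate(grid):
        frames.extend(row if r % 2 == 0 else row[::-1])  # reverse every other row
    return frames

# frames = split_9frame_sequence("lego_001.png")  # frames[0] is step 1, frames[8] is the final step
```
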
<details>
<summary>Click to preview the datasets</summary>
<br>

| Domain | Preview | Quantity | Domain | Preview | Quantity |
|:--------:|:---------:|:----------:|:--------:|:---------:|:----------:|
| LEGO | ![LEGO Preview](./images/datasets/lego.png) | 50 | Cook | ![Cook Preview](./images/datasets/cook.png) | 50 |
| Painting | ![Painting Preview](./images/datasets/painting.png) | 50 | Icon | ![Icon Preview](./images/datasets/icon.png) | 50+1.4k |
| Landscape Illustration | ![Landscape Illustration Preview](./images/datasets/landscape.png) | 50 | Portrait | ![Portrait Preview](./images/datasets/portrait.png) | 50+2k |
| Transformer | ![Transformer Preview](./images/datasets/transformer.png) | 50 | Sand Art | ![Sand Art Preview](./images/datasets/sandart.png) | 50 |
| Illustration | ![Illustration Preview](./images/datasets/illustration.png) | 50 | Sketch | ![Sketch Preview](./images/datasets/sketch.png) | 50+9k |
| Clay Toys | ![Clay Toys Preview](./images/datasets/claytoys.png) | 50 | Clay Sculpture | ![Clay Sculpture Preview](./images/datasets/claysculpture.png) | 50 |
| ZBrush Modeling | ![ZBrush Modeling Preview](./images/datasets/zbrush.png) | 50 | Wood Sculpture | ![Wood Sculpture Preview](./images/datasets/woodsculpture.png) | 50 |
| Ink Painting | ![Ink Painting Preview](./images/datasets/inkpainting.png) | 50 | Pencil Sketch | ![Pencil Sketch Preview](./images/datasets/pencilsketch.png) | 50 |
| Fabric Toys | ![Fabric Toys Preview](./images/datasets/fabrictoys.png) | 50 | Oil Painting | ![Oil Painting Preview](./images/datasets/oilpainting.png) | 50 |
| Jade Carving | ![Jade Carving Preview](./images/datasets/jadecarving.png) | 50 | Line Draw | ![Line Draw Preview](./images/datasets/linedraw.png) | 50 |
| Emoji | ![Emoji Preview](./images/datasets/emoji.png) | 50+12k | | | |

</details>

## Results
### Text-to-Sequence Generation (LoRA & Asymmetric LoRA)
<img src='./images/t2i.png' width='100%' />

### Image-to-Sequence Generation (Recraft Model)
<img src='./images/i2i.png' width='100%' />

### Generalization on Unseen Domains
<img src='./images/oneshot.png' width='100%' />

## Citation
```
@inproceedings{Song2025MakeAnythingHD,
  title={MakeAnything: Harnessing Diffusion Transformers for Multi-Domain Procedural Sequence Generation},
  author={Yiren Song and Cheng Liu and Mike Zheng Shou},
  year={2025},
  url={https://api.semanticscholar.org/CorpusID:276107845}
}
```
flux_inference_recraft.py
ADDED
@@ -0,0 +1,442 @@
```python
import argparse
import copy
import math
import random
from typing import Any
import pdb
import os

import time
from PIL import Image, ImageOps

import torch
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device
from safetensors.torch import load_file
from networks import lora_flux

from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, \
    strategy_base, strategy_flux, train_util
from torchvision import transforms
import train_network
from library.utils import setup_logging
from diffusers.utils import load_image
import numpy as np

setup_logging()
import logging

logger = logging.getLogger(__name__)


def load_target_model(
    fp8_base: bool,
    pretrained_model_name_or_path: str,
    disable_mmap_load_safetensors: bool,
    clip_l_path: str,
    fp8_base_unet: bool,
    t5xxl_path: str,
    ae_path: str,
    weight_dtype: torch.dtype,
    accelerator: Accelerator
):
    # Determine the loading data type
    loading_dtype = None if fp8_base else weight_dtype

    # Load the main model to the accelerator's device
    _, model = flux_utils.load_flow_model(
        pretrained_model_name_or_path,
        # loading_dtype,
        torch.float8_e4m3fn,
        # accelerator.device,  # Changed from "cpu" to accelerator.device
        "cpu",
        disable_mmap=disable_mmap_load_safetensors
    )

    if fp8_base:
        # Check dtype of the model
        if model.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
            raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
        elif model.dtype == torch.float8_e4m3fn:
            logger.info("Loaded fp8 FLUX model")

    # Load the CLIP model to the accelerator's device
    clip_l = flux_utils.load_clip_l(
        clip_l_path,
        weight_dtype,
        # accelerator.device,  # Changed from "cpu" to accelerator.device
        "cpu",
        disable_mmap=disable_mmap_load_safetensors
    )
    clip_l.eval()

    # Determine the loading data type for T5XXL
    if fp8_base and not fp8_base_unet:
        loading_dtype_t5xxl = None  # as is
    else:
        loading_dtype_t5xxl = weight_dtype

    # Load the T5XXL model to the accelerator's device
    t5xxl = flux_utils.load_t5xxl(
        t5xxl_path,
        loading_dtype_t5xxl,
        # accelerator.device,  # Changed from "cpu" to accelerator.device
        "cpu",
        disable_mmap=disable_mmap_load_safetensors
    )
    t5xxl.eval()

    if fp8_base and not fp8_base_unet:
        # Check dtype of the T5XXL model
        if t5xxl.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
            raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
        elif t5xxl.dtype == torch.float8_e4m3fn:
            logger.info("Loaded fp8 T5XXL model")

    # Load the AE model to the accelerator's device
    ae = flux_utils.load_ae(
        ae_path,
        weight_dtype,
        # accelerator.device,  # Changed from "cpu" to accelerator.device
        "cpu",
        disable_mmap=disable_mmap_load_safetensors
    )

    # # Wrap models with Accelerator for potential distributed setups
    # model, clip_l, t5xxl, ae = accelerator.prepare(model, clip_l, t5xxl, ae)

    return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model


import torchvision.transforms as transforms


class ResizeWithPadding:
    def __init__(self, size, fill=255):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image or a NumPy array")

        width, height = img.size

        if width == height:
            img = img.resize((self.size, self.size), Image.LANCZOS)
        else:
            max_dim = max(width, height)

            new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
            new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))

            img = new_img.resize((self.size, self.size), Image.LANCZOS)

        return img


def sample(args, accelerator, vae, text_encoder, flux, output_dir, sample_images, sample_prompts):
    def encode_images_to_latents(vae, images):
        # Get image dimensions
        b, c, h, w = images.shape
        num_split = 2 if args.frame_num == 4 else 3
        # Split the image into parts along the width
        img_parts = [images[:, :, :, i * w // num_split:(i + 1) * w // num_split] for i in range(num_split)]
        # Encode each part
        latents = [vae.encode(img) for img in img_parts]
        # Concatenate latents in the latent space to reconstruct the full image
        latents = torch.cat(latents, dim=-1)
        return latents

    def encode_images_to_latents2(vae, images):
        latents = vae.encode(images)
        return latents

    # Directly use precomputed conditions
    conditions = {}
    with torch.no_grad():
        for image_path, prompt_dict in zip(sample_images, sample_prompts):
            prompt = prompt_dict.get("prompt", "")
            if prompt not in conditions:
                logger.info(f"Cache conditions for image: {image_path} with prompt: {prompt}")
                resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
                img_transforms = transforms.Compose([
                    resize_transform,
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ])
                # Load and preprocess image
                image = img_transforms(np.array(load_image(image_path), dtype=np.uint8)).unsqueeze(0).to(
                    # accelerator.device,  # Move image to CUDA
                    vae.device,
                    dtype=vae.dtype
                )
                latents = encode_images_to_latents2(vae, image)

                # Log the shape of latents
                logger.debug(f"Encoded latents shape for prompt '{prompt}': {latents.shape}")
                # Store conditions on CPU
                # conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
                conditions[prompt] = latents.to("cpu")

    sample_conditions = conditions

    if sample_conditions is not None:
        conditions = {k: v for k, v in sample_conditions.items()}

    sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
    text_encoder[0].to(accelerator.device)
    text_encoder[1].to(accelerator.device)

    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
    text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)

    with accelerator.autocast(), torch.no_grad():
        for prompt_dict in sample_prompts:
            for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                if p not in sample_prompts_te_outputs:
                    logger.info(f"Cache Text Encoder outputs for prompt: {p}")
                    tokens_and_masks = tokenize_strategy.tokenize(p)
                    sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                        tokenize_strategy, text_encoder, tokens_and_masks, True
                    )

    logger.info(f"Generating image")
    save_dir = output_dir
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad(), accelerator.autocast():
        for prompt_dict in sample_prompts:
            sample_image_inference(
                args,
                accelerator,
                flux,
                text_encoder,
                vae,
                save_dir,
                prompt_dict,
                sample_prompts_te_outputs,
                None,
                conditions
            )

    clean_memory_on_device(accelerator.device)


def sample_image_inference(
    args,
    accelerator: Accelerator,
    flux: flux_models.Flux,
    text_encoder,
    ae: flux_models.AutoEncoder,
    save_dir,
    prompt_dict,
    sample_prompts_te_outputs,
    prompt_replacement,
    sample_images_ae_outputs
):
    # Extract parameters from prompt_dict
    sample_steps = prompt_dict.get("sample_steps", 20)
    width = prompt_dict.get("width", 1024) if args.frame_num == 4 else prompt_dict.get("width", 1056)
    height = prompt_dict.get("height", 1024) if args.frame_num == 4 else prompt_dict.get("height", 1056)
    scale = prompt_dict.get("scale", 1.0)
    seed = prompt_dict.get("seed")
    prompt: str = prompt_dict.get("prompt", "")

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    else:
        # True random sample image generation
        torch.seed()
        torch.cuda.seed()

    # Ensure height and width are divisible by 16
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)
    logger.info(f"prompt: {prompt}")
    logger.info(f"height: {height}")
    logger.info(f"width: {width}")
    logger.info(f"sample_steps: {sample_steps}")
    logger.info(f"scale: {scale}")
    if seed is not None:
        logger.info(f"seed: {seed}")

    # Encode prompts
    # Assuming that TokenizeStrategy and TextEncodingStrategy are compatible with Accelerator
    text_encoder_conds = []
    if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
        text_encoder_conds = sample_prompts_te_outputs[prompt]
        logger.info(f"Using cached text encoder outputs for prompt: {prompt}")

    if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
        ae_outputs = sample_images_ae_outputs[prompt]
    else:
        ae_outputs = None

    # ae_outputs = torch.load('ae_outputs.pth', map_location='cuda:0')

    # text_encoder_conds = torch.load('text_encoder_conds.pth', map_location='cuda:0')
    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds

    # Print debug info
    logger.debug(
        f"l_pooled shape: {l_pooled.shape}, t5_out shape: {t5_out.shape}, txt_ids shape: {txt_ids.shape}, t5_attn_mask shape: {t5_attn_mask.shape}")

    # Sample the image
    weight_dtype = ae.dtype  # TODO: give dtype as argument
    packed_latent_height = height // 16
    packed_latent_width = width // 16

    # Print debug info
    logger.debug(f"packed_latent_height: {packed_latent_height}, packed_latent_width: {packed_latent_width}")

    # Prepare the noise tensor on the accelerator device
    noise = torch.randn(
        1,
        packed_latent_height * packed_latent_width,
        16 * 2 * 2,
        device=accelerator.device,
        dtype=weight_dtype,
        generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
    )

    timesteps = flux_train_utils.get_schedule(sample_steps, noise.shape[1], shift=True)  # FLUX.1 dev -> shift=True
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(
        accelerator.device, dtype=weight_dtype
    )
    t5_attn_mask = t5_attn_mask.to(accelerator.device)

    clip_l, t5xxl = text_encoder
    # ae.to("cpu")
    clip_l.to("cpu")
    t5xxl.to("cpu")

    clean_memory_on_device(accelerator.device)
    flux.to("cuda")

    for param in flux.parameters():
        param.requires_grad = False

    # Run denoising
    with accelerator.autocast(), torch.no_grad():
        x = flux_train_utils.denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps,
                                     guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)

    # Print the shape of x
    logger.debug(f"x shape after denoise: {x.shape}")

    x = x.float()
    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)

    # Convert the latents to an image
    # clean_memory_on_device(accelerator.device)
    ae.to(accelerator.device)
    with accelerator.autocast(), torch.no_grad():
        x = ae.decode(x)
    ae.to("cpu")
    clean_memory_on_device(accelerator.device)

    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

    # Generate a unique filename
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict.get("enum", 0)  # Ensure 'enum' exists
    img_filename = f"{ts_str}{seed_suffix}_{i}.png"  # Added 'i' to filename for uniqueness
    image.save(os.path.join(save_dir, img_filename))


def setup_argparse():
    parser = argparse.ArgumentParser(description="FLUX-Controlnet-Inpainting Inference Script")

    # Paths
    parser.add_argument('--base_flux_checkpoint', type=str, required=True,
                        help='Path to BASE_FLUX_CHECKPOINT')
    parser.add_argument('--lora_weights_path', type=str, required=True,
                        help='Path to LORA_WEIGHTS_PATH')
    parser.add_argument('--clip_l_path', type=str, required=True,
                        help='Path to CLIP_L_PATH')
    parser.add_argument('--t5xxl_path', type=str, required=True,
                        help='Path to T5XXL_PATH')
    parser.add_argument('--ae_path', type=str, required=True,
                        help='Path to AE_PATH')
    parser.add_argument('--sample_images_file', type=str, required=True,
                        help='Path to SAMPLE_IMAGES_FILE')
    parser.add_argument('--sample_prompts_file', type=str, required=True,
                        help='Path to SAMPLE_PROMPTS_FILE')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save OUTPUT_DIR')
    parser.add_argument('--frame_num', type=int, choices=[4, 9], required=True,
                        help="The number of steps in the generated step diagram (choose 4 or 9)")

    return parser.parse_args()


def main(args):
    accelerator = Accelerator(mixed_precision='bf16', device_placement=True)

    BASE_FLUX_CHECKPOINT = args.base_flux_checkpoint
    LORA_WEIGHTS_PATH = args.lora_weights_path
    CLIP_L_PATH = args.clip_l_path
    T5XXL_PATH = args.t5xxl_path
    AE_PATH = args.ae_path

    SAMPLE_IMAGES_FILE = args.sample_images_file
    SAMPLE_PROMPTS_FILE = args.sample_prompts_file
    OUTPUT_DIR = args.output_dir

    with open(SAMPLE_IMAGES_FILE, "r", encoding="utf-8") as f:
        image_lines = f.readlines()
    sample_images = [line.strip() for line in image_lines if line.strip() and not line.strip().startswith("#")]

    sample_prompts = train_util.load_prompts(SAMPLE_PROMPTS_FILE)

    # Load models onto CUDA via Accelerator
    _, [clip_l, t5xxl], ae, model = load_target_model(
        fp8_base=True,
        pretrained_model_name_or_path=BASE_FLUX_CHECKPOINT,
        disable_mmap_load_safetensors=False,
        clip_l_path=CLIP_L_PATH,
        fp8_base_unet=False,
        t5xxl_path=T5XXL_PATH,
        ae_path=AE_PATH,
        weight_dtype=torch.bfloat16,
        accelerator=accelerator
    )

    model.eval()
    clip_l.eval()
    t5xxl.eval()
    ae.eval()

    # LoRA
    multiplier = 1.0
    weights_sd = load_file(LORA_WEIGHTS_PATH)
    lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
                                                          True)

    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
    lora_model.eval()
    lora_model.to("cuda")

    # Set text encoders
    text_encoder = [clip_l, t5xxl]

    sample(args, accelerator, vae=ae, text_encoder=text_encoder, flux=model, output_dir=OUTPUT_DIR,
           sample_images=sample_images, sample_prompts=sample_prompts)


if __name__ == "__main__":
    args = setup_argparse()

    main(args)
```
flux_minimal_inference.py
ADDED
@@ -0,0 +1,576 @@
1 |
+
# Minimum Inference Code for FLUX
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import datetime
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
from typing import Callable, List, Optional
|
9 |
+
import einops
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from tqdm import tqdm
|
14 |
+
from PIL import Image
|
15 |
+
import accelerate
|
16 |
+
from transformers import CLIPTextModel
|
17 |
+
from safetensors.torch import load_file
|
18 |
+
|
19 |
+
from library import device_utils
|
20 |
+
from library.device_utils import init_ipex, get_preferred_device
|
21 |
+
from networks import oft_flux
|
22 |
+
|
23 |
+
init_ipex()
|
24 |
+
|
25 |
+
|
26 |
+
from library.utils import setup_logging, str_to_dtype
|
27 |
+
|
28 |
+
setup_logging()
|
29 |
+
import logging
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
import networks.lora_flux as lora_flux
|
34 |
+
from library import flux_models, flux_utils, sd3_utils, strategy_flux
|
35 |
+
|
36 |
+
|
37 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
38 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
39 |
+
|
40 |
+
|
41 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
42 |
+
m = (y2 - y1) / (x2 - x1)
|
43 |
+
b = y1 - m * x1
|
44 |
+
return lambda x: m * x + b
|
45 |
+
|
46 |
+
|
47 |
+
def get_schedule(
|
48 |
+
num_steps: int,
|
49 |
+
image_seq_len: int,
|
50 |
+
base_shift: float = 0.5,
|
51 |
+
max_shift: float = 1.15,
|
52 |
+
shift: bool = True,
|
53 |
+
) -> list[float]:
|
54 |
+
# extra step for zero
|
55 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
56 |
+
|
57 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
58 |
+
if shift:
|
59 |
+
# eastimate mu based on linear estimation between two points
|
60 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
61 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
62 |
+
|
63 |
+
return timesteps.tolist()
|
64 |
+
|
65 |
+
|
66 |
+
def denoise(
|
67 |
+
model: flux_models.Flux,
|
68 |
+
img: torch.Tensor,
|
69 |
+
img_ids: torch.Tensor,
|
70 |
+
txt: torch.Tensor,
|
71 |
+
txt_ids: torch.Tensor,
|
72 |
+
vec: torch.Tensor,
|
73 |
+
timesteps: list[float],
|
74 |
+
guidance: float = 4.0,
|
75 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
76 |
+
neg_txt: Optional[torch.Tensor] = None,
|
77 |
+
neg_vec: Optional[torch.Tensor] = None,
|
78 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
79 |
+
cfg_scale: Optional[float] = None,
|
80 |
+
):
|
81 |
+
# this is ignored for schnell
|
82 |
+
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
83 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
84 |
+
|
85 |
+
# prepare classifier free guidance
|
86 |
+
if neg_txt is not None and neg_vec is not None:
|
87 |
+
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
|
88 |
+
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
|
89 |
+
b_txt = torch.cat([neg_txt, txt], dim=0)
|
90 |
+
b_vec = torch.cat([neg_vec, vec], dim=0)
|
91 |
+
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
|
92 |
+
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
93 |
+
else:
|
94 |
+
b_t5_attn_mask = None
|
95 |
+
else:
|
96 |
+
b_img_ids = img_ids
|
97 |
+
b_txt_ids = txt_ids
|
98 |
+
b_txt = txt
|
99 |
+
b_vec = vec
|
100 |
+
b_t5_attn_mask = t5_attn_mask
|
101 |
+
|
102 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
103 |
+
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
104 |
+
|
105 |
+
# classifier free guidance
|
106 |
+
if neg_txt is not None and neg_vec is not None:
|
107 |
+
b_img = torch.cat([img, img], dim=0)
|
108 |
+
else:
|
109 |
+
b_img = img
|
110 |
+
|
111 |
+
pred = model(
|
112 |
+
img=b_img,
|
113 |
+
img_ids=b_img_ids,
|
114 |
+
txt=b_txt,
|
115 |
+
txt_ids=b_txt_ids,
|
116 |
+
y=b_vec,
|
117 |
+
timesteps=t_vec,
|
118 |
+
guidance=guidance_vec,
|
119 |
+
txt_attention_mask=b_t5_attn_mask,
|
120 |
+
)
|
121 |
+
|
122 |
+
# classifier free guidance
|
123 |
+
if neg_txt is not None and neg_vec is not None:
|
124 |
+
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
|
125 |
+
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
|
126 |
+
|
127 |
+
img = img + (t_prev - t_curr) * pred
|
128 |
+
|
129 |
+
return img
|
130 |
+
|
131 |
+
|
132 |
+
def do_sample(
|
133 |
+
accelerator: Optional[accelerate.Accelerator],
|
134 |
+
model: flux_models.Flux,
|
135 |
+
img: torch.Tensor,
|
136 |
+
img_ids: torch.Tensor,
|
137 |
+
l_pooled: torch.Tensor,
|
138 |
+
t5_out: torch.Tensor,
|
139 |
+
txt_ids: torch.Tensor,
|
140 |
+
num_steps: int,
|
141 |
+
guidance: float,
|
142 |
+
t5_attn_mask: Optional[torch.Tensor],
|
143 |
+
is_schnell: bool,
|
144 |
+
device: torch.device,
|
145 |
+
flux_dtype: torch.dtype,
|
146 |
+
neg_l_pooled: Optional[torch.Tensor] = None,
|
147 |
+
neg_t5_out: Optional[torch.Tensor] = None,
|
148 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
149 |
+
cfg_scale: Optional[float] = None,
|
150 |
+
):
|
151 |
+
logger.info(f"num_steps: {num_steps}")
|
152 |
+
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
153 |
+
|
154 |
+
# denoise initial noise
|
155 |
+
if accelerator:
|
156 |
+
with accelerator.autocast(), torch.no_grad():
|
157 |
+
x = denoise(
|
158 |
+
model,
|
159 |
+
img,
|
160 |
+
img_ids,
|
161 |
+
t5_out,
|
162 |
+
txt_ids,
|
163 |
+
l_pooled,
|
164 |
+
timesteps,
|
165 |
+
guidance,
|
166 |
+
t5_attn_mask,
|
167 |
+
neg_t5_out,
|
168 |
+
neg_l_pooled,
|
169 |
+
neg_t5_attn_mask,
|
170 |
+
cfg_scale,
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
174 |
+
x = denoise(
|
175 |
+
model,
|
176 |
+
img,
|
177 |
+
img_ids,
|
178 |
+
t5_out,
|
179 |
+
txt_ids,
|
180 |
+
l_pooled,
|
181 |
+
timesteps,
|
182 |
+
guidance,
|
183 |
+
t5_attn_mask,
|
184 |
+
neg_t5_out,
|
185 |
+
neg_l_pooled,
|
186 |
+
neg_t5_attn_mask,
|
187 |
+
cfg_scale,
|
188 |
+
)
|
189 |
+
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
def generate_image(
|
194 |
+
model,
|
195 |
+
clip_l: CLIPTextModel,
|
196 |
+
t5xxl,
|
197 |
+
ae,
|
198 |
+
prompt: str,
|
199 |
+
seed: Optional[int],
|
200 |
+
image_width: int,
|
201 |
+
image_height: int,
|
202 |
+
steps: Optional[int],
|
203 |
+
guidance: float,
|
204 |
+
negative_prompt: Optional[str],
|
205 |
+
cfg_scale: float,
|
206 |
+
):
|
207 |
+
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
208 |
+
logger.info(f"Seed: {seed}")
|
209 |
+
|
210 |
+
# make first noise with packed shape
|
211 |
+
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
212 |
+
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
213 |
+
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
214 |
+
noise = torch.randn(
|
215 |
+
1,
|
216 |
+
packed_latent_height * packed_latent_width,
|
217 |
+
16 * 2 * 2,
|
218 |
+
device=device,
|
219 |
+
dtype=noise_dtype,
|
220 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
221 |
+
)
|
222 |
+
|
223 |
+
# prepare img and img ids
|
224 |
+
|
225 |
+
# this is needed only for img2img
|
226 |
+
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
227 |
+
# if img.shape[0] == 1 and bs > 1:
|
228 |
+
# img = repeat(img, "1 ... -> bs ...", bs=bs)
|
229 |
+
|
230 |
+
# txt2img only needs img_ids
|
231 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
232 |
+
|
233 |
+
# prepare fp8 models
|
234 |
+
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
235 |
+
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
236 |
+
clip_l.to(clip_l_dtype) # fp8
|
237 |
+
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
238 |
+
clip_l.fp8_prepared = True
|
239 |
+
|
240 |
+
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
|
241 |
+
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
242 |
+
|
243 |
+
def prepare_fp8(text_encoder, target_dtype):
|
244 |
+
def forward_hook(module):
|
245 |
+
def forward(hidden_states):
|
246 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
247 |
+
hidden_linear = module.wi_1(hidden_states)
|
248 |
+
hidden_states = hidden_gelu * hidden_linear
|
249 |
+
hidden_states = module.dropout(hidden_states)
|
250 |
+
|
251 |
+
hidden_states = module.wo(hidden_states)
|
252 |
+
return hidden_states
|
253 |
+
|
254 |
+
return forward
|
255 |
+
|
256 |
+
for module in text_encoder.modules():
|
257 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
258 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
259 |
+
module.to(target_dtype)
|
260 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
261 |
+
# print("set", module.__class__.__name__, "hooks")
|
262 |
+
module.forward = forward_hook(module)
|
263 |
+
|
264 |
+
t5xxl.to(t5xxl_dtype)
|
265 |
+
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
266 |
+
t5xxl.fp8_prepared = True
|
267 |
+
|
268 |
+
# prepare embeddings
|
269 |
+
logger.info("Encoding prompts...")
|
270 |
+
clip_l = clip_l.to(device)
|
271 |
+
t5xxl = t5xxl.to(device)
|
272 |
+
|
273 |
+
def encode(prpt: str):
|
274 |
+
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
275 |
+
with torch.no_grad():
|
276 |
+
if is_fp8(clip_l_dtype):
|
277 |
+
with accelerator.autocast():
|
278 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
279 |
+
else:
|
280 |
+
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
281 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
282 |
+
|
283 |
+
if is_fp8(t5xxl_dtype):
|
284 |
+
with accelerator.autocast():
|
285 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
286 |
+
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
290 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
291 |
+
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
292 |
+
)
|
293 |
+
return l_pooled, t5_out, txt_ids, t5_attn_mask
|
294 |
+
|
295 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
|
296 |
+
if negative_prompt:
|
297 |
+
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
|
298 |
+
else:
|
299 |
+
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
300 |
+
|
301 |
+
# NaN check
|
302 |
+
if torch.isnan(l_pooled).any():
|
303 |
+
raise ValueError("NaN in l_pooled")
|
304 |
+
if torch.isnan(t5_out).any():
|
305 |
+
raise ValueError("NaN in t5_out")
|
306 |
+
|
307 |
+
if args.offload:
|
308 |
+
clip_l = clip_l.cpu()
|
309 |
+
t5xxl = t5xxl.cpu()
|
310 |
+
# del clip_l, t5xxl
|
311 |
+
device_utils.clean_memory()
|
312 |
+
|
313 |
+
# generate image
|
314 |
+
logger.info("Generating image...")
|
315 |
+
model = model.to(device)
|
316 |
+
if steps is None:
|
317 |
+
steps = 4 if is_schnell else 50
|
318 |
+
|
319 |
+
img_ids = img_ids.to(device)
|
320 |
+
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
321 |
+
|
322 |
+
x = do_sample(
|
323 |
+
accelerator,
|
324 |
+
model,
|
325 |
+
noise,
|
326 |
+
img_ids,
|
327 |
+
l_pooled,
|
328 |
+
t5_out,
|
329 |
+
txt_ids,
|
330 |
+
steps,
|
331 |
+
guidance,
|
332 |
+
t5_attn_mask,
|
333 |
+
is_schnell,
|
334 |
+
device,
|
335 |
+
flux_dtype,
|
336 |
+
neg_l_pooled,
|
337 |
+
neg_t5_out,
|
338 |
+
neg_t5_attn_mask,
|
339 |
+
cfg_scale,
|
340 |
+
)
|
341 |
+
if args.offload:
|
342 |
+
model = model.cpu()
|
343 |
+
# del model
|
344 |
+
device_utils.clean_memory()
|
345 |
+
|
346 |
+
# unpack
|
347 |
+
x = x.float()
|
348 |
+
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
349 |
+
|
350 |
+
# decode
|
351 |
+
logger.info("Decoding image...")
|
352 |
+
ae = ae.to(device)
|
353 |
+
with torch.no_grad():
|
354 |
+
if is_fp8(ae_dtype):
|
355 |
+
with accelerator.autocast():
|
356 |
+
x = ae.decode(x)
|
357 |
+
else:
|
358 |
+
with torch.autocast(device_type=device.type, dtype=ae_dtype):
|
359 |
+
x = ae.decode(x)
|
360 |
+
if args.offload:
|
361 |
+
ae = ae.cpu()
|
362 |
+
|
363 |
+
x = x.clamp(-1, 1)
|
364 |
+
x = x.permute(0, 2, 3, 1)
|
365 |
+
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
366 |
+
|
367 |
+
# save image
|
368 |
+
output_dir = args.output_dir
|
369 |
+
os.makedirs(output_dir, exist_ok=True)
|
370 |
+
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
371 |
+
img.save(output_path)
|
372 |
+
|
373 |
+
logger.info(f"Saved image to {output_path}")
|
374 |
+
|
375 |
+
|
376 |
+
if __name__ == "__main__":
|
377 |
+
target_height = 768 # 1024
|
378 |
+
target_width = 1360 # 1024
|
379 |
+
|
380 |
+
# steps = 50 # 28 # 50
|
381 |
+
# guidance_scale = 5
|
382 |
+
# seed = 1 # None # 1
|
383 |
+
|
384 |
+
device = get_preferred_device()
|
385 |
+
|
386 |
+
parser = argparse.ArgumentParser()
|
387 |
+
parser.add_argument("--ckpt_path", type=str, required=True)
|
388 |
+
parser.add_argument("--clip_l", type=str, required=False)
|
389 |
+
parser.add_argument("--t5xxl", type=str, required=False)
|
390 |
+
parser.add_argument("--ae", type=str, required=False)
|
391 |
+
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
392 |
+
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
393 |
+
parser.add_argument("--output_dir", type=str, default=".")
|
394 |
+
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
|
395 |
+
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
|
396 |
+
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
|
397 |
+
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
|
398 |
+
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
|
399 |
+
parser.add_argument("--seed", type=int, default=None)
|
400 |
+
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
|
401 |
+
parser.add_argument("--guidance", type=float, default=3.5)
|
402 |
+
parser.add_argument("--negative_prompt", type=str, default=None)
|
403 |
+
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
404 |
+
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
405 |
+
parser.add_argument(
|
406 |
+
"--lora_weights",
|
407 |
+
type=str,
|
408 |
+
nargs="*",
|
409 |
+
default=[],
|
410 |
+
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
|
411 |
+
)
|
412 |
+
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
413 |
+
parser.add_argument("--width", type=int, default=target_width)
|
414 |
+
parser.add_argument("--height", type=int, default=target_height)
|
415 |
+
parser.add_argument("--interactive", action="store_true")
|
416 |
+
args = parser.parse_args()
|
417 |
+
|
418 |
+
seed = args.seed
|
419 |
+
steps = args.steps
|
420 |
+
guidance_scale = args.guidance
|
421 |
+
|
422 |
+
def is_fp8(dt):
|
423 |
+
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
424 |
+
|
425 |
+
dtype = str_to_dtype(args.dtype)
|
426 |
+
clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
|
427 |
+
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
|
428 |
+
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
|
429 |
+
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
|
430 |
+
|
431 |
+
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
|
432 |
+
|
433 |
+
loading_device = "cpu" if args.offload else device
|
434 |
+
|
435 |
+
use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
|
436 |
+
if any(use_fp8):
|
437 |
+
accelerator = accelerate.Accelerator(mixed_precision="bf16")
|
438 |
+
else:
|
439 |
+
accelerator = None
|
440 |
+
|
441 |
+
# load clip_l
|
442 |
+
logger.info(f"Loading clip_l from {args.clip_l}...")
|
443 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
|
444 |
+
clip_l.eval()
|
445 |
+
|
446 |
+
logger.info(f"Loading t5xxl from {args.t5xxl}...")
|
447 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
448 |
+
t5xxl.eval()
|
449 |
+
|
450 |
+
# if is_fp8(clip_l_dtype):
|
451 |
+
# clip_l = accelerator.prepare(clip_l)
|
452 |
+
# if is_fp8(t5xxl_dtype):
|
453 |
+
# t5xxl = accelerator.prepare(t5xxl)
|
454 |
+
|
455 |
+
# DiT
|
456 |
+
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
|
457 |
+
model.eval()
|
458 |
+
logger.info(f"Casting model to {flux_dtype}")
|
459 |
+
model.to(flux_dtype)  # make sure the model is in the requested dtype
|
460 |
+
# if is_fp8(flux_dtype):
|
461 |
+
# model = accelerator.prepare(model)
|
462 |
+
# if args.offload:
|
463 |
+
# model = model.to("cpu")
|
464 |
+
|
465 |
+
t5xxl_max_length = 256 if is_schnell else 512
|
466 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
467 |
+
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
468 |
+
|
469 |
+
# AE
|
470 |
+
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
|
471 |
+
ae.eval()
|
472 |
+
# if is_fp8(ae_dtype):
|
473 |
+
# ae = accelerator.prepare(ae)
|
474 |
+
|
475 |
+
# LoRA
|
476 |
+
lora_models: List[lora_flux.LoRANetwork] = []
|
477 |
+
for weights_file in args.lora_weights:
|
478 |
+
if ";" in weights_file:
|
479 |
+
weights_file, multiplier = weights_file.split(";")
|
480 |
+
multiplier = float(multiplier)
|
481 |
+
else:
|
482 |
+
multiplier = 1.0
|
483 |
+
|
484 |
+
weights_sd = load_file(weights_file)
|
485 |
+
is_lora = is_oft = False
|
486 |
+
for key in weights_sd.keys():
|
487 |
+
if key.startswith("lora"):
|
488 |
+
is_lora = True
|
489 |
+
if key.startswith("oft"):
|
490 |
+
is_oft = True
|
491 |
+
if is_lora or is_oft:
|
492 |
+
break
|
493 |
+
|
494 |
+
module = lora_flux if is_lora else oft_flux
|
495 |
+
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
496 |
+
|
497 |
+
if args.merge_lora_weights:
|
498 |
+
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
|
499 |
+
else:
|
500 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
501 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
502 |
+
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
503 |
+
lora_model.eval()
|
504 |
+
lora_model.to(device)
|
505 |
+
|
506 |
+
lora_models.append(lora_model)
|
507 |
+
|
508 |
+
if not args.interactive:
|
509 |
+
generate_image(
|
510 |
+
model,
|
511 |
+
clip_l,
|
512 |
+
t5xxl,
|
513 |
+
ae,
|
514 |
+
args.prompt,
|
515 |
+
args.seed,
|
516 |
+
args.width,
|
517 |
+
args.height,
|
518 |
+
args.steps,
|
519 |
+
args.guidance,
|
520 |
+
args.negative_prompt,
|
521 |
+
args.cfg_scale,
|
522 |
+
)
|
523 |
+
else:
|
524 |
+
# loop for interactive
|
525 |
+
width = target_width
|
526 |
+
height = target_height
|
527 |
+
steps = None
|
528 |
+
guidance = args.guidance
|
529 |
+
cfg_scale = args.cfg_scale
|
530 |
+
|
531 |
+
while True:
|
532 |
+
print(
|
533 |
+
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
534 |
+
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
|
535 |
+
)
|
536 |
+
prompt = input()
|
537 |
+
if prompt == "":
|
538 |
+
break
|
539 |
+
|
540 |
+
# parse options
|
541 |
+
options = prompt.split("--")
|
542 |
+
prompt = options[0].strip()
|
543 |
+
seed = None
|
544 |
+
negative_prompt = None
|
545 |
+
for opt in options[1:]:
|
546 |
+
try:
|
547 |
+
opt = opt.strip()
|
548 |
+
if opt.startswith("w"):
|
549 |
+
width = int(opt[1:].strip())
|
550 |
+
elif opt.startswith("h"):
|
551 |
+
height = int(opt[1:].strip())
|
552 |
+
elif opt.startswith("s"):
|
553 |
+
steps = int(opt[1:].strip())
|
554 |
+
elif opt.startswith("d"):
|
555 |
+
seed = int(opt[1:].strip())
|
556 |
+
elif opt.startswith("g"):
|
557 |
+
guidance = float(opt[1:].strip())
|
558 |
+
elif opt.startswith("m"):
|
559 |
+
multipliers = opt[1:].strip().split(",")
|
560 |
+
if len(multipliers) != len(lora_models):
|
561 |
+
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
562 |
+
continue
|
563 |
+
for i, lora_model in enumerate(lora_models):
|
564 |
+
lora_model.set_multiplier(float(multipliers[i]))
|
565 |
+
elif opt.startswith("n"):
|
566 |
+
negative_prompt = opt[1:].strip()
|
567 |
+
if negative_prompt == "-":
|
568 |
+
negative_prompt = ""
|
569 |
+
elif opt.startswith("c"):
|
570 |
+
cfg_scale = float(opt[1:].strip())
|
571 |
+
except ValueError as e:
|
572 |
+
logger.error(f"Invalid option: {opt}, {e}")
|
573 |
+
|
574 |
+
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
|
575 |
+
|
576 |
+
logger.info("Done!")
|
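A note for readers of this diff: the `--lora_weights` option above takes each entry in a `path;multiplier` form and picks the network module (`lora_flux` vs `oft_flux`) from the key prefixes found in the weights file. The standalone sketch below is not part of the uploaded files and the checkpoint name is only illustrative; it simply mirrors that parsing and dispatch. In interactive mode the same script also accepts inline options after the prompt, e.g. `a cat on a desk --w 1024 --h 768 --s 20 --d 42 --n blurry`.

# Minimal sketch of the `path;multiplier` convention used by --lora_weights above
# and of the lora/oft dispatch based on state-dict key prefixes.
from safetensors.torch import load_file


def parse_lora_weight_arg(arg: str):
    """Split a `path;multiplier` argument; the multiplier defaults to 1.0."""
    if ";" in arg:
        path, multiplier = arg.split(";")
        return path, float(multiplier)
    return arg, 1.0


def detect_network_type(weights_sd) -> str:
    """Return "lora" or "oft" depending on the key prefixes in the state dict."""
    for key in weights_sd.keys():
        if key.startswith("lora"):
            return "lora"
        if key.startswith("oft"):
            return "oft"
    raise ValueError("no lora/oft keys found in the weights file")


if __name__ == "__main__":
    path, multiplier = parse_lora_weight_arg("my_style_lora.safetensors;0.8")  # illustrative file name
    weights_sd = load_file(path)
    print(detect_network_type(weights_sd), multiplier)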
flux_minimal_inference_asylora.py
ADDED
@@ -0,0 +1,583 @@
|
1 |
+
|
2 |
+
# Minimum Inference Code for FLUX
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import datetime
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
from typing import Callable, List, Optional
|
10 |
+
import einops
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
import accelerate
|
17 |
+
from transformers import CLIPTextModel
|
18 |
+
from safetensors.torch import load_file
|
19 |
+
|
20 |
+
from library import device_utils
|
21 |
+
from library.device_utils import init_ipex, get_preferred_device
|
22 |
+
from networks import oft_flux
|
23 |
+
|
24 |
+
init_ipex()
|
25 |
+
|
26 |
+
|
27 |
+
from library.utils import setup_logging, str_to_dtype
|
28 |
+
|
29 |
+
setup_logging()
|
30 |
+
import logging
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
import networks.asylora_flux as lora_flux
|
35 |
+
from library import flux_models, flux_utils, sd3_utils, strategy_flux
|
36 |
+
|
37 |
+
|
38 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
39 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
40 |
+
|
41 |
+
|
42 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
43 |
+
m = (y2 - y1) / (x2 - x1)
|
44 |
+
b = y1 - m * x1
|
45 |
+
return lambda x: m * x + b
|
46 |
+
|
47 |
+
|
48 |
+
def get_schedule(
|
49 |
+
num_steps: int,
|
50 |
+
image_seq_len: int,
|
51 |
+
base_shift: float = 0.5,
|
52 |
+
max_shift: float = 1.15,
|
53 |
+
shift: bool = True,
|
54 |
+
) -> list[float]:
|
55 |
+
# extra step for zero
|
56 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
57 |
+
|
58 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
59 |
+
if shift:
|
60 |
+
# estimate mu based on linear interpolation between two points
|
61 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
62 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
63 |
+
|
64 |
+
return timesteps.tolist()
|
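As a quick reading aid for `get_schedule` above: `mu` is interpolated linearly between the points (256, 0.5) and (4096, 1.15) as a function of the packed image sequence length, and every timestep is then remapped through `time_shift`. A standalone check, not part of the uploaded files:

# Reproduces the shifting used above: t -> exp(mu) / (exp(mu) + (1/t - 1) ** sigma),
# with mu interpolated between (256, 0.5) and (4096, 1.15) over the sequence length.
import math


def shifted(t: float, image_seq_len: int, sigma: float = 1.0) -> float:
    m = (1.15 - 0.5) / (4096 - 256)
    b = 0.5 - m * 256
    mu = m * image_seq_len + b
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


# For a 1024x1024 image the packed sequence length is (1024 / 16) * (1024 / 16) = 4096,
# so mu = 1.15 and the schedule midpoint t = 0.5 is pushed up to about 0.76,
# i.e. the sampler spends more of its steps at high noise levels.
print(round(shifted(0.5, 4096), 3))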
65 |
+
|
66 |
+
|
67 |
+
def denoise(
|
68 |
+
model: flux_models.Flux,
|
69 |
+
img: torch.Tensor,
|
70 |
+
img_ids: torch.Tensor,
|
71 |
+
txt: torch.Tensor,
|
72 |
+
txt_ids: torch.Tensor,
|
73 |
+
vec: torch.Tensor,
|
74 |
+
timesteps: list[float],
|
75 |
+
guidance: float = 4.0,
|
76 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
77 |
+
neg_txt: Optional[torch.Tensor] = None,
|
78 |
+
neg_vec: Optional[torch.Tensor] = None,
|
79 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
80 |
+
cfg_scale: Optional[float] = None,
|
81 |
+
):
|
82 |
+
# this is ignored for schnell
|
83 |
+
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
84 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
85 |
+
|
86 |
+
# prepare classifier free guidance
|
87 |
+
if neg_txt is not None and neg_vec is not None:
|
88 |
+
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
|
89 |
+
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
|
90 |
+
b_txt = torch.cat([neg_txt, txt], dim=0)
|
91 |
+
b_vec = torch.cat([neg_vec, vec], dim=0)
|
92 |
+
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
|
93 |
+
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
94 |
+
else:
|
95 |
+
b_t5_attn_mask = None
|
96 |
+
else:
|
97 |
+
b_img_ids = img_ids
|
98 |
+
b_txt_ids = txt_ids
|
99 |
+
b_txt = txt
|
100 |
+
b_vec = vec
|
101 |
+
b_t5_attn_mask = t5_attn_mask
|
102 |
+
|
103 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
104 |
+
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
105 |
+
|
106 |
+
# classifier free guidance
|
107 |
+
if neg_txt is not None and neg_vec is not None:
|
108 |
+
b_img = torch.cat([img, img], dim=0)
|
109 |
+
else:
|
110 |
+
b_img = img
|
111 |
+
|
112 |
+
pred = model(
|
113 |
+
img=b_img,
|
114 |
+
img_ids=b_img_ids,
|
115 |
+
txt=b_txt,
|
116 |
+
txt_ids=b_txt_ids,
|
117 |
+
y=b_vec,
|
118 |
+
timesteps=t_vec,
|
119 |
+
guidance=guidance_vec,
|
120 |
+
txt_attention_mask=b_t5_attn_mask,
|
121 |
+
)
|
122 |
+
|
123 |
+
# classifier free guidance
|
124 |
+
if neg_txt is not None and neg_vec is not None:
|
125 |
+
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
|
126 |
+
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
|
127 |
+
|
128 |
+
img = img + (t_prev - t_curr) * pred
|
129 |
+
|
130 |
+
return img
|
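Written out, each iteration of the loop above is a first-order Euler step of the flow matching ODE, with classifier-free guidance mixing the unconditional and conditional velocity predictions when a negative prompt is supplied:

\[
x_{t_{\text{prev}}} = x_{t_{\text{curr}}} + (t_{\text{prev}} - t_{\text{curr}})\, v, \qquad
v = v_{\text{uncond}} + s_{\text{cfg}}\,(v_{\text{cond}} - v_{\text{uncond}})
\]

Because the schedule runs from 1 down to 0, the step `t_prev - t_curr` is negative and the latent moves from pure noise toward the final image.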
131 |
+
|
132 |
+
|
133 |
+
def do_sample(
|
134 |
+
accelerator: Optional[accelerate.Accelerator],
|
135 |
+
model: flux_models.Flux,
|
136 |
+
img: torch.Tensor,
|
137 |
+
img_ids: torch.Tensor,
|
138 |
+
l_pooled: torch.Tensor,
|
139 |
+
t5_out: torch.Tensor,
|
140 |
+
txt_ids: torch.Tensor,
|
141 |
+
num_steps: int,
|
142 |
+
guidance: float,
|
143 |
+
t5_attn_mask: Optional[torch.Tensor],
|
144 |
+
is_schnell: bool,
|
145 |
+
device: torch.device,
|
146 |
+
flux_dtype: torch.dtype,
|
147 |
+
neg_l_pooled: Optional[torch.Tensor] = None,
|
148 |
+
neg_t5_out: Optional[torch.Tensor] = None,
|
149 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
150 |
+
cfg_scale: Optional[float] = None,
|
151 |
+
):
|
152 |
+
logger.info(f"num_steps: {num_steps}")
|
153 |
+
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
154 |
+
|
155 |
+
# denoise initial noise
|
156 |
+
if accelerator:
|
157 |
+
with accelerator.autocast(), torch.no_grad():
|
158 |
+
x = denoise(
|
159 |
+
model,
|
160 |
+
img,
|
161 |
+
img_ids,
|
162 |
+
t5_out,
|
163 |
+
txt_ids,
|
164 |
+
l_pooled,
|
165 |
+
timesteps,
|
166 |
+
guidance,
|
167 |
+
t5_attn_mask,
|
168 |
+
neg_t5_out,
|
169 |
+
neg_l_pooled,
|
170 |
+
neg_t5_attn_mask,
|
171 |
+
cfg_scale,
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
175 |
+
x = denoise(
|
176 |
+
model,
|
177 |
+
img,
|
178 |
+
img_ids,
|
179 |
+
t5_out,
|
180 |
+
txt_ids,
|
181 |
+
l_pooled,
|
182 |
+
timesteps,
|
183 |
+
guidance,
|
184 |
+
t5_attn_mask,
|
185 |
+
neg_t5_out,
|
186 |
+
neg_l_pooled,
|
187 |
+
neg_t5_attn_mask,
|
188 |
+
cfg_scale,
|
189 |
+
)
|
190 |
+
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
def generate_image(
|
195 |
+
model,
|
196 |
+
clip_l: CLIPTextModel,
|
197 |
+
t5xxl,
|
198 |
+
ae,
|
199 |
+
prompt: str,
|
200 |
+
seed: Optional[int],
|
201 |
+
image_width: int,
|
202 |
+
image_height: int,
|
203 |
+
steps: Optional[int],
|
204 |
+
guidance: float,
|
205 |
+
negative_prompt: Optional[str],
|
206 |
+
cfg_scale: float,
|
207 |
+
):
|
208 |
+
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
209 |
+
logger.info(f"Seed: {seed}")
|
210 |
+
|
211 |
+
# make first noise with packed shape
|
212 |
+
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
213 |
+
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
214 |
+
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
215 |
+
noise = torch.randn(
|
216 |
+
1,
|
217 |
+
packed_latent_height * packed_latent_width,
|
218 |
+
16 * 2 * 2,
|
219 |
+
device=device,
|
220 |
+
dtype=noise_dtype,
|
221 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
222 |
+
)
|
223 |
+
|
224 |
+
# prepare img and img ids
|
225 |
+
|
226 |
+
# this is needed only for img2img
|
227 |
+
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
228 |
+
# if img.shape[0] == 1 and bs > 1:
|
229 |
+
# img = repeat(img, "1 ... -> bs ...", bs=bs)
|
230 |
+
|
231 |
+
# txt2img only needs img_ids
|
232 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
233 |
+
|
234 |
+
# prepare fp8 models
|
235 |
+
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
236 |
+
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
237 |
+
clip_l.to(clip_l_dtype) # fp8
|
238 |
+
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
239 |
+
clip_l.fp8_prepared = True
|
240 |
+
|
241 |
+
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
|
242 |
+
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
243 |
+
|
244 |
+
def prepare_fp8(text_encoder, target_dtype):
|
245 |
+
def forward_hook(module):
|
246 |
+
def forward(hidden_states):
|
247 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
248 |
+
hidden_linear = module.wi_1(hidden_states)
|
249 |
+
hidden_states = hidden_gelu * hidden_linear
|
250 |
+
hidden_states = module.dropout(hidden_states)
|
251 |
+
|
252 |
+
hidden_states = module.wo(hidden_states)
|
253 |
+
return hidden_states
|
254 |
+
|
255 |
+
return forward
|
256 |
+
|
257 |
+
for module in text_encoder.modules():
|
258 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
259 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
260 |
+
module.to(target_dtype)
|
261 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
262 |
+
# print("set", module.__class__.__name__, "hooks")
|
263 |
+
module.forward = forward_hook(module)
|
264 |
+
|
265 |
+
t5xxl.to(t5xxl_dtype)
|
266 |
+
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
267 |
+
t5xxl.fp8_prepared = True
|
268 |
+
|
269 |
+
# prepare embeddings
|
270 |
+
logger.info("Encoding prompts...")
|
271 |
+
clip_l = clip_l.to(device)
|
272 |
+
t5xxl = t5xxl.to(device)
|
273 |
+
|
274 |
+
def encode(prpt: str):
|
275 |
+
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
276 |
+
with torch.no_grad():
|
277 |
+
if is_fp8(clip_l_dtype):
|
278 |
+
with accelerator.autocast():
|
279 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
280 |
+
else:
|
281 |
+
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
282 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
283 |
+
|
284 |
+
if is_fp8(t5xxl_dtype):
|
285 |
+
with accelerator.autocast():
|
286 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
287 |
+
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
291 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
292 |
+
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
293 |
+
)
|
294 |
+
return l_pooled, t5_out, txt_ids, t5_attn_mask
|
295 |
+
|
296 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
|
297 |
+
if negative_prompt:
|
298 |
+
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
|
299 |
+
else:
|
300 |
+
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
301 |
+
|
302 |
+
# NaN check
|
303 |
+
if torch.isnan(l_pooled).any():
|
304 |
+
raise ValueError("NaN in l_pooled")
|
305 |
+
if torch.isnan(t5_out).any():
|
306 |
+
raise ValueError("NaN in t5_out")
|
307 |
+
|
308 |
+
if args.offload:
|
309 |
+
clip_l = clip_l.cpu()
|
310 |
+
t5xxl = t5xxl.cpu()
|
311 |
+
# del clip_l, t5xxl
|
312 |
+
device_utils.clean_memory()
|
313 |
+
|
314 |
+
# generate image
|
315 |
+
logger.info("Generating image...")
|
316 |
+
model = model.to(device)
|
317 |
+
if steps is None:
|
318 |
+
steps = 4 if is_schnell else 50
|
319 |
+
|
320 |
+
img_ids = img_ids.to(device)
|
321 |
+
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
322 |
+
|
323 |
+
x = do_sample(
|
324 |
+
accelerator,
|
325 |
+
model,
|
326 |
+
noise,
|
327 |
+
img_ids,
|
328 |
+
l_pooled,
|
329 |
+
t5_out,
|
330 |
+
txt_ids,
|
331 |
+
steps,
|
332 |
+
guidance,
|
333 |
+
t5_attn_mask,
|
334 |
+
is_schnell,
|
335 |
+
device,
|
336 |
+
flux_dtype,
|
337 |
+
neg_l_pooled,
|
338 |
+
neg_t5_out,
|
339 |
+
neg_t5_attn_mask,
|
340 |
+
cfg_scale,
|
341 |
+
)
|
342 |
+
if args.offload:
|
343 |
+
model = model.cpu()
|
344 |
+
# del model
|
345 |
+
device_utils.clean_memory()
|
346 |
+
|
347 |
+
# unpack
|
348 |
+
x = x.float()
|
349 |
+
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
350 |
+
|
351 |
+
# decode
|
352 |
+
logger.info("Decoding image...")
|
353 |
+
ae = ae.to(device)
|
354 |
+
with torch.no_grad():
|
355 |
+
if is_fp8(ae_dtype):
|
356 |
+
with accelerator.autocast():
|
357 |
+
x = ae.decode(x)
|
358 |
+
else:
|
359 |
+
with torch.autocast(device_type=device.type, dtype=ae_dtype):
|
360 |
+
x = ae.decode(x)
|
361 |
+
if args.offload:
|
362 |
+
ae = ae.cpu()
|
363 |
+
|
364 |
+
x = x.clamp(-1, 1)
|
365 |
+
x = x.permute(0, 2, 3, 1)
|
366 |
+
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
367 |
+
|
368 |
+
# save image
|
369 |
+
output_dir = args.output_dir
|
370 |
+
os.makedirs(output_dir, exist_ok=True)
|
371 |
+
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
372 |
+
img.save(output_path)
|
373 |
+
|
374 |
+
logger.info(f"Saved image to {output_path}")
|
375 |
+
|
376 |
+
|
377 |
+
if __name__ == "__main__":
|
378 |
+
target_height = 768 # 1024
|
379 |
+
target_width = 1360 # 1024
|
380 |
+
|
381 |
+
# steps = 50 # 28 # 50
|
382 |
+
# guidance_scale = 5
|
383 |
+
# seed = 1 # None # 1
|
384 |
+
|
385 |
+
device = get_preferred_device()
|
386 |
+
|
387 |
+
parser = argparse.ArgumentParser()
|
388 |
+
parser.add_argument("--lora_ups_num", type=int, required=True)
|
389 |
+
parser.add_argument("--lora_up_cur", type=int, required=True)
|
390 |
+
parser.add_argument("--ckpt_path", type=str, required=True)
|
391 |
+
parser.add_argument("--clip_l", type=str, required=False)
|
392 |
+
parser.add_argument("--t5xxl", type=str, required=False)
|
393 |
+
parser.add_argument("--ae", type=str, required=False)
|
394 |
+
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
395 |
+
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
396 |
+
parser.add_argument("--output_dir", type=str, default=".")
|
397 |
+
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
|
398 |
+
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
|
399 |
+
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
|
400 |
+
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
|
401 |
+
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
|
402 |
+
parser.add_argument("--seed", type=int, default=None)
|
403 |
+
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
|
404 |
+
parser.add_argument("--guidance", type=float, default=3.5)
|
405 |
+
parser.add_argument("--negative_prompt", type=str, default=None)
|
406 |
+
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
407 |
+
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
408 |
+
parser.add_argument(
|
409 |
+
"--lora_weights",
|
410 |
+
type=str,
|
411 |
+
nargs="*",
|
412 |
+
default=[],
|
413 |
+
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
|
414 |
+
)
|
415 |
+
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
416 |
+
parser.add_argument("--width", type=int, default=target_width)
|
417 |
+
parser.add_argument("--height", type=int, default=target_height)
|
418 |
+
parser.add_argument("--interactive", action="store_true")
|
419 |
+
args = parser.parse_args()
|
420 |
+
|
421 |
+
seed = args.seed
|
422 |
+
steps = args.steps
|
423 |
+
guidance_scale = args.guidance
|
424 |
+
lora_ups_num = args.lora_ups_num
|
425 |
+
lora_up_cur = args.lora_up_cur
|
426 |
+
|
427 |
+
def is_fp8(dt):
|
428 |
+
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
429 |
+
|
430 |
+
dtype = str_to_dtype(args.dtype)
|
431 |
+
clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
|
432 |
+
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
|
433 |
+
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
|
434 |
+
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
|
435 |
+
|
436 |
+
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
|
437 |
+
|
438 |
+
loading_device = "cpu" if args.offload else device
|
439 |
+
|
440 |
+
use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
|
441 |
+
if any(use_fp8):
|
442 |
+
accelerator = accelerate.Accelerator(mixed_precision="bf16")
|
443 |
+
else:
|
444 |
+
accelerator = None
|
445 |
+
|
446 |
+
# load clip_l
|
447 |
+
logger.info(f"Loading clip_l from {args.clip_l}...")
|
448 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
|
449 |
+
clip_l.eval()
|
450 |
+
|
451 |
+
logger.info(f"Loading t5xxl from {args.t5xxl}...")
|
452 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
453 |
+
t5xxl.eval()
|
454 |
+
|
455 |
+
# if is_fp8(clip_l_dtype):
|
456 |
+
# clip_l = accelerator.prepare(clip_l)
|
457 |
+
# if is_fp8(t5xxl_dtype):
|
458 |
+
# t5xxl = accelerator.prepare(t5xxl)
|
459 |
+
|
460 |
+
# DiT
|
461 |
+
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
|
462 |
+
model.eval()
|
463 |
+
logger.info(f"Casting model to {flux_dtype}")
|
464 |
+
model.to(flux_dtype)  # make sure the model is in the requested dtype
|
465 |
+
# if is_fp8(flux_dtype):
|
466 |
+
# model = accelerator.prepare(model)
|
467 |
+
# if args.offload:
|
468 |
+
# model = model.to("cpu")
|
469 |
+
|
470 |
+
t5xxl_max_length = 256 if is_schnell else 512
|
471 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
472 |
+
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
473 |
+
|
474 |
+
# AE
|
475 |
+
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
|
476 |
+
ae.eval()
|
477 |
+
# if is_fp8(ae_dtype):
|
478 |
+
# ae = accelerator.prepare(ae)
|
479 |
+
|
480 |
+
# LoRA
|
481 |
+
lora_models: List[lora_flux.LoRANetwork] = []
|
482 |
+
for weights_file in args.lora_weights:
|
483 |
+
if ";" in weights_file:
|
484 |
+
weights_file, multiplier = weights_file.split(";")
|
485 |
+
multiplier = float(multiplier)
|
486 |
+
else:
|
487 |
+
multiplier = 1.0
|
488 |
+
|
489 |
+
weights_sd = load_file(weights_file)
|
490 |
+
is_lora = is_oft = False
|
491 |
+
for key in weights_sd.keys():
|
492 |
+
if key.startswith("lora"):
|
493 |
+
is_lora = True
|
494 |
+
if key.startswith("oft"):
|
495 |
+
is_oft = True
|
496 |
+
if is_lora or is_oft:
|
497 |
+
break
|
498 |
+
|
499 |
+
module = lora_flux if is_lora else oft_flux
|
500 |
+
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num)
|
501 |
+
for sub_lora in lora_model.unet_loras:
|
502 |
+
sub_lora.set_lora_up_cur(lora_up_cur - 1)  # convert the 1-based CLI index to 0-based
|
503 |
+
|
504 |
+
if args.merge_lora_weights:
|
505 |
+
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
|
506 |
+
else:
|
507 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
508 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
509 |
+
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
510 |
+
lora_model.eval()
|
511 |
+
lora_model.to(device)
|
512 |
+
|
513 |
+
lora_models.append(lora_model)
|
514 |
+
|
515 |
+
if not args.interactive:
|
516 |
+
generate_image(
|
517 |
+
model,
|
518 |
+
clip_l,
|
519 |
+
t5xxl,
|
520 |
+
ae,
|
521 |
+
args.prompt,
|
522 |
+
args.seed,
|
523 |
+
args.width,
|
524 |
+
args.height,
|
525 |
+
args.steps,
|
526 |
+
args.guidance,
|
527 |
+
args.negative_prompt,
|
528 |
+
args.cfg_scale,
|
529 |
+
)
|
530 |
+
else:
|
531 |
+
# loop for interactive
|
532 |
+
width = target_width
|
533 |
+
height = target_height
|
534 |
+
steps = None
|
535 |
+
guidance = args.guidance
|
536 |
+
cfg_scale = args.cfg_scale
|
537 |
+
|
538 |
+
while True:
|
539 |
+
print(
|
540 |
+
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
541 |
+
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
|
542 |
+
)
|
543 |
+
prompt = input()
|
544 |
+
if prompt == "":
|
545 |
+
break
|
546 |
+
|
547 |
+
# parse options
|
548 |
+
options = prompt.split("--")
|
549 |
+
prompt = options[0].strip()
|
550 |
+
seed = None
|
551 |
+
negative_prompt = None
|
552 |
+
for opt in options[1:]:
|
553 |
+
try:
|
554 |
+
opt = opt.strip()
|
555 |
+
if opt.startswith("w"):
|
556 |
+
width = int(opt[1:].strip())
|
557 |
+
elif opt.startswith("h"):
|
558 |
+
height = int(opt[1:].strip())
|
559 |
+
elif opt.startswith("s"):
|
560 |
+
steps = int(opt[1:].strip())
|
561 |
+
elif opt.startswith("d"):
|
562 |
+
seed = int(opt[1:].strip())
|
563 |
+
elif opt.startswith("g"):
|
564 |
+
guidance = float(opt[1:].strip())
|
565 |
+
elif opt.startswith("m"):
|
566 |
+
multipliers = opt[1:].strip().split(",")
|
567 |
+
if len(multipliers) != len(lora_models):
|
568 |
+
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
569 |
+
continue
|
570 |
+
for i, lora_model in enumerate(lora_models):
|
571 |
+
lora_model.set_multiplier(float(multipliers[i]))
|
572 |
+
elif opt.startswith("n"):
|
573 |
+
negative_prompt = opt[1:].strip()
|
574 |
+
if negative_prompt == "-":
|
575 |
+
negative_prompt = ""
|
576 |
+
elif opt.startswith("c"):
|
577 |
+
cfg_scale = float(opt[1:].strip())
|
578 |
+
except ValueError as e:
|
579 |
+
logger.error(f"Invalid option: {opt}, {e}")
|
580 |
+
|
581 |
+
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
|
582 |
+
|
583 |
+
logger.info("Done!")
|
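The asymmetric-LoRA ("AsyLoRA") variant above adds two required arguments, `--lora_ups_num` and `--lora_up_cur`, and calls `set_lora_up_cur(lora_up_cur - 1)` on every UNet LoRA module before sampling. The `networks.asylora_flux` module itself is not shown in this section, so the toy layer below is only a conceptual sketch of what those options imply: one shared down-projection with several up-projection heads, of which a single head is active per run.

# Conceptual sketch only -- not the real networks.asylora_flux implementation.
# One shared lora_down, `lora_ups_num` alternative lora_up heads, and
# set_lora_up_cur() choosing which head contributes to the output.
import torch
import torch.nn as nn


class ToyAsyLoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int, lora_ups_num: int, multiplier: float = 1.0):
        super().__init__()
        self.base = base
        self.multiplier = multiplier
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_ups = nn.ModuleList(
            nn.Linear(rank, base.out_features, bias=False) for _ in range(lora_ups_num)
        )
        self.lora_up_cur = 0  # 0-based index of the active up-projection

    def set_lora_up_cur(self, idx: int) -> None:
        self.lora_up_cur = idx

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.lora_ups[self.lora_up_cur]
        return self.base(x) + self.multiplier * up(self.lora_down(x))


if __name__ == "__main__":
    layer = ToyAsyLoRALinear(nn.Linear(8, 8), rank=4, lora_ups_num=3)
    layer.set_lora_up_cur(1)  # mirrors `sub_lora.set_lora_up_cur(lora_up_cur - 1)` above
    print(layer(torch.randn(2, 8)).shape)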
flux_train_network.py
ADDED
@@ -0,0 +1,588 @@
|
1 |
+
import argparse
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from typing import Any, Optional, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from accelerate import Accelerator
|
9 |
+
|
10 |
+
from library.device_utils import clean_memory_on_device, init_ipex
|
11 |
+
|
12 |
+
init_ipex()
|
13 |
+
|
14 |
+
import train_network
|
15 |
+
from library import (
|
16 |
+
flux_models,
|
17 |
+
flux_train_utils,
|
18 |
+
flux_utils,
|
19 |
+
sd3_train_utils,
|
20 |
+
strategy_base,
|
21 |
+
strategy_flux,
|
22 |
+
train_util,
|
23 |
+
)
|
24 |
+
from library.utils import setup_logging
|
25 |
+
|
26 |
+
setup_logging()
|
27 |
+
import logging
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
33 |
+
def __init__(self):
|
34 |
+
super().__init__()
|
35 |
+
self.sample_prompts_te_outputs = None
|
36 |
+
self.is_schnell: Optional[bool] = None
|
37 |
+
self.is_swapping_blocks: bool = False
|
38 |
+
|
39 |
+
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
40 |
+
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
41 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
42 |
+
|
43 |
+
if args.fp8_base_unet:
|
44 |
+
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
45 |
+
|
46 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
47 |
+
logger.warning(
|
48 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
49 |
+
)
|
50 |
+
args.cache_text_encoder_outputs = True
|
51 |
+
|
52 |
+
if args.cache_text_encoder_outputs:
|
53 |
+
assert (
|
54 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
55 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
56 |
+
|
57 |
+
# prepare CLIP-L/T5XXL training flags
|
58 |
+
self.train_clip_l = not args.network_train_unet_only
|
59 |
+
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
60 |
+
|
61 |
+
if args.max_token_length is not None:
|
62 |
+
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
63 |
+
|
64 |
+
assert (
|
65 |
+
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
66 |
+
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
67 |
+
|
68 |
+
# deprecated split_mode option
|
69 |
+
if args.split_mode:
|
70 |
+
if args.blocks_to_swap is not None:
|
71 |
+
logger.warning(
|
72 |
+
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
|
73 |
+
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
logger.warning(
|
77 |
+
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
|
78 |
+
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
|
79 |
+
)
|
80 |
+
args.blocks_to_swap = 18 # 18 is safe for most cases
|
81 |
+
|
82 |
+
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
83 |
+
if val_dataset_group is not None:
|
84 |
+
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
85 |
+
|
86 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
87 |
+
# currently offload to cpu for some models
|
88 |
+
|
89 |
+
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
90 |
+
loading_dtype = None if args.fp8_base else weight_dtype
|
91 |
+
|
92 |
+
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in the future
|
93 |
+
self.is_schnell, model = flux_utils.load_flow_model(
|
94 |
+
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
95 |
+
)
|
96 |
+
if args.fp8_base:
|
97 |
+
# check dtype of model
|
98 |
+
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
99 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
100 |
+
elif model.dtype == torch.float8_e4m3fn:
|
101 |
+
logger.info("Loaded fp8 FLUX model")
|
102 |
+
else:
|
103 |
+
logger.info(
|
104 |
+
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
|
105 |
+
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
|
106 |
+
)
|
107 |
+
model.to(torch.float8_e4m3fn)
|
108 |
+
|
109 |
+
# if args.split_mode:
|
110 |
+
# model = self.prepare_split_model(model, weight_dtype, accelerator)
|
111 |
+
|
112 |
+
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
113 |
+
if self.is_swapping_blocks:
|
114 |
+
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
115 |
+
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
116 |
+
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
117 |
+
|
118 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
119 |
+
clip_l.eval()
|
120 |
+
|
121 |
+
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
122 |
+
if args.fp8_base and not args.fp8_base_unet:
|
123 |
+
loading_dtype = None # as is
|
124 |
+
else:
|
125 |
+
loading_dtype = weight_dtype
|
126 |
+
|
127 |
+
# loading t5xxl to cpu takes a long time, so we should load to gpu in the future
|
128 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
129 |
+
t5xxl.eval()
|
130 |
+
if args.fp8_base and not args.fp8_base_unet:
|
131 |
+
# check dtype of model
|
132 |
+
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
133 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
134 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
135 |
+
logger.info("Loaded fp8 T5XXL model")
|
136 |
+
|
137 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
138 |
+
|
139 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
140 |
+
|
141 |
+
def get_tokenize_strategy(self, args):
|
142 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
143 |
+
|
144 |
+
if args.t5xxl_max_token_length is None:
|
145 |
+
if is_schnell:
|
146 |
+
t5xxl_max_token_length = 256
|
147 |
+
else:
|
148 |
+
t5xxl_max_token_length = 512
|
149 |
+
else:
|
150 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
151 |
+
|
152 |
+
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
153 |
+
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
154 |
+
|
155 |
+
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
156 |
+
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
157 |
+
|
158 |
+
def get_latents_caching_strategy(self, args):
|
159 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
160 |
+
return latents_caching_strategy
|
161 |
+
|
162 |
+
def get_text_encoding_strategy(self, args):
|
163 |
+
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
164 |
+
|
165 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
166 |
+
# check t5xxl is trained or not
|
167 |
+
self.train_t5xxl = network.train_t5xxl
|
168 |
+
|
169 |
+
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
170 |
+
raise ValueError(
|
171 |
+
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
172 |
+
)
|
173 |
+
|
174 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
175 |
+
if args.cache_text_encoder_outputs:
|
176 |
+
if self.train_clip_l and not self.train_t5xxl:
|
177 |
+
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
178 |
+
else:
|
179 |
+
return None # no text encoders are needed for encoding because both are cached
|
180 |
+
else:
|
181 |
+
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
182 |
+
|
183 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
184 |
+
return [self.train_clip_l, self.train_t5xxl]
|
185 |
+
|
186 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
187 |
+
if args.cache_text_encoder_outputs:
|
188 |
+
# if the text encoders are trained, we need tokenization, so is_partial is True
|
189 |
+
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
190 |
+
args.cache_text_encoder_outputs_to_disk,
|
191 |
+
args.text_encoder_batch_size,
|
192 |
+
args.skip_cache_check,
|
193 |
+
is_partial=self.train_clip_l or self.train_t5xxl,
|
194 |
+
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
return None
|
198 |
+
|
199 |
+
def cache_text_encoder_outputs_if_needed(
|
200 |
+
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
201 |
+
):
|
202 |
+
if args.cache_text_encoder_outputs:
|
203 |
+
if not args.lowram:
|
204 |
+
# reduce memory consumption
|
205 |
+
logger.info("move vae and unet to cpu to save memory")
|
206 |
+
org_vae_device = vae.device
|
207 |
+
org_unet_device = unet.device
|
208 |
+
vae.to("cpu")
|
209 |
+
unet.to("cpu")
|
210 |
+
clean_memory_on_device(accelerator.device)
|
211 |
+
|
212 |
+
# When the TE is not trained, it will not be prepared, so we need to use explicit autocast
|
213 |
+
logger.info("move text encoders to gpu")
|
214 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
215 |
+
text_encoders[1].to(accelerator.device)
|
216 |
+
|
217 |
+
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
218 |
+
# if we load fp8 weights, the model is already fp8, so we use it as is
|
219 |
+
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
220 |
+
else:
|
221 |
+
# otherwise, we need to convert it to target dtype
|
222 |
+
text_encoders[1].to(weight_dtype)
|
223 |
+
|
224 |
+
with accelerator.autocast():
|
225 |
+
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
226 |
+
|
227 |
+
# cache sample prompts
|
228 |
+
if args.sample_prompts is not None:
|
229 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
230 |
+
|
231 |
+
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
232 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
233 |
+
|
234 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
235 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
236 |
+
with accelerator.autocast(), torch.no_grad():
|
237 |
+
for prompt_dict in prompts:
|
238 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
239 |
+
if p not in sample_prompts_te_outputs:
|
240 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
241 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
242 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
243 |
+
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
244 |
+
)
|
245 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
246 |
+
|
247 |
+
accelerator.wait_for_everyone()
|
248 |
+
|
249 |
+
# move back to cpu
|
250 |
+
if not self.is_train_text_encoder(args):
|
251 |
+
logger.info("move CLIP-L back to cpu")
|
252 |
+
text_encoders[0].to("cpu")
|
253 |
+
logger.info("move t5XXL back to cpu")
|
254 |
+
text_encoders[1].to("cpu")
|
255 |
+
clean_memory_on_device(accelerator.device)
|
256 |
+
|
257 |
+
if not args.lowram:
|
258 |
+
logger.info("move vae and unet back to original device")
|
259 |
+
vae.to(org_vae_device)
|
260 |
+
unet.to(org_unet_device)
|
261 |
+
else:
|
262 |
+
# the Text Encoder outputs are fetched at every step, so keep the encoders on the GPU
|
263 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
264 |
+
text_encoders[1].to(accelerator.device)
|
265 |
+
|
266 |
+
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
267 |
+
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
268 |
+
|
269 |
+
# # get size embeddings
|
270 |
+
# orig_size = batch["original_sizes_hw"]
|
271 |
+
# crop_size = batch["crop_top_lefts"]
|
272 |
+
# target_size = batch["target_sizes_hw"]
|
273 |
+
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
274 |
+
|
275 |
+
# # concat embeddings
|
276 |
+
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
277 |
+
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
278 |
+
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
279 |
+
|
280 |
+
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
281 |
+
# return noise_pred
|
282 |
+
|
283 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
284 |
+
text_encoders = text_encoder # for compatibility
|
285 |
+
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
286 |
+
|
287 |
+
flux_train_utils.sample_images(
|
288 |
+
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
289 |
+
)
|
290 |
+
# return
|
291 |
+
|
292 |
+
"""
|
293 |
+
class FluxUpperLowerWrapper(torch.nn.Module):
|
294 |
+
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
295 |
+
super().__init__()
|
296 |
+
self.flux_upper = flux_upper
|
297 |
+
self.flux_lower = flux_lower
|
298 |
+
self.target_device = device
|
299 |
+
|
300 |
+
def prepare_block_swap_before_forward(self):
|
301 |
+
pass
|
302 |
+
|
303 |
+
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
304 |
+
self.flux_lower.to("cpu")
|
305 |
+
clean_memory_on_device(self.target_device)
|
306 |
+
self.flux_upper.to(self.target_device)
|
307 |
+
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
308 |
+
self.flux_upper.to("cpu")
|
309 |
+
clean_memory_on_device(self.target_device)
|
310 |
+
self.flux_lower.to(self.target_device)
|
311 |
+
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
312 |
+
|
313 |
+
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
314 |
+
clean_memory_on_device(accelerator.device)
|
315 |
+
flux_train_utils.sample_images(
|
316 |
+
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
317 |
+
)
|
318 |
+
clean_memory_on_device(accelerator.device)
|
319 |
+
"""
|
320 |
+
|
321 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
322 |
+
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
323 |
+
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
324 |
+
return noise_scheduler
|
325 |
+
|
326 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
327 |
+
return vae.encode(images)
|
328 |
+
|
329 |
+
def shift_scale_latents(self, args, latents):
|
330 |
+
return latents
|
331 |
+
|
332 |
+
def get_noise_pred_and_target(
|
333 |
+
self,
|
334 |
+
args,
|
335 |
+
accelerator,
|
336 |
+
noise_scheduler,
|
337 |
+
latents,
|
338 |
+
batch,
|
339 |
+
text_encoder_conds,
|
340 |
+
unet: flux_models.Flux,
|
341 |
+
network,
|
342 |
+
weight_dtype,
|
343 |
+
train_unet,
|
344 |
+
is_train=True
|
345 |
+
):
|
346 |
+
# Sample noise that we'll add to the latents
|
347 |
+
noise = torch.randn_like(latents)
|
348 |
+
bsz = latents.shape[0]
|
349 |
+
|
350 |
+
# get noisy model input and timesteps
|
351 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
352 |
+
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
353 |
+
)
|
354 |
+
|
355 |
+
# pack latents and get img_ids
|
356 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
357 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
358 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
359 |
+
|
360 |
+
# get guidance
|
361 |
+
# ensure guidance_scale in args is float
|
362 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
363 |
+
|
364 |
+
# ensure the hidden state will require grad
|
365 |
+
if args.gradient_checkpointing:
|
366 |
+
noisy_model_input.requires_grad_(True)
|
367 |
+
for t in text_encoder_conds:
|
368 |
+
if t is not None and t.dtype.is_floating_point:
|
369 |
+
t.requires_grad_(True)
|
370 |
+
img_ids.requires_grad_(True)
|
371 |
+
guidance_vec.requires_grad_(True)
|
372 |
+
|
373 |
+
# Predict the noise residual
|
374 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
375 |
+
if not args.apply_t5_attn_mask:
|
376 |
+
t5_attn_mask = None
|
377 |
+
|
378 |
+
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
379 |
+
# if not args.split_mode:
|
380 |
+
# normal forward
|
381 |
+
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
382 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
383 |
+
model_pred = unet(
|
384 |
+
img=img,
|
385 |
+
img_ids=img_ids,
|
386 |
+
txt=t5_out,
|
387 |
+
txt_ids=txt_ids,
|
388 |
+
y=l_pooled,
|
389 |
+
timesteps=timesteps / 1000,
|
390 |
+
guidance=guidance_vec,
|
391 |
+
txt_attention_mask=t5_attn_mask,
|
392 |
+
)
|
393 |
+
"""
|
394 |
+
else:
|
395 |
+
# split forward to reduce memory usage
|
396 |
+
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
397 |
+
with accelerator.autocast():
|
398 |
+
# move flux lower to cpu, and then move flux upper to gpu
|
399 |
+
unet.to("cpu")
|
400 |
+
clean_memory_on_device(accelerator.device)
|
401 |
+
self.flux_upper.to(accelerator.device)
|
402 |
+
|
403 |
+
# upper model does not require grad
|
404 |
+
with torch.no_grad():
|
405 |
+
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
406 |
+
img=packed_noisy_model_input,
|
407 |
+
img_ids=img_ids,
|
408 |
+
txt=t5_out,
|
409 |
+
txt_ids=txt_ids,
|
410 |
+
y=l_pooled,
|
411 |
+
timesteps=timesteps / 1000,
|
412 |
+
guidance=guidance_vec,
|
413 |
+
txt_attention_mask=t5_attn_mask,
|
414 |
+
)
|
415 |
+
|
416 |
+
# move flux upper back to cpu, and then move flux lower to gpu
|
417 |
+
self.flux_upper.to("cpu")
|
418 |
+
clean_memory_on_device(accelerator.device)
|
419 |
+
unet.to(accelerator.device)
|
420 |
+
|
421 |
+
# lower model requires grad
|
422 |
+
intermediate_img.requires_grad_(True)
|
423 |
+
intermediate_txt.requires_grad_(True)
|
424 |
+
vec.requires_grad_(True)
|
425 |
+
pe.requires_grad_(True)
|
426 |
+
|
427 |
+
with torch.set_grad_enabled(is_train and train_unet):
|
428 |
+
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
429 |
+
"""
|
430 |
+
|
431 |
+
return model_pred
|
432 |
+
|
433 |
+
model_pred = call_dit(
|
434 |
+
img=packed_noisy_model_input,
|
435 |
+
img_ids=img_ids,
|
436 |
+
t5_out=t5_out,
|
437 |
+
txt_ids=txt_ids,
|
438 |
+
l_pooled=l_pooled,
|
439 |
+
timesteps=timesteps,
|
440 |
+
guidance_vec=guidance_vec,
|
441 |
+
t5_attn_mask=t5_attn_mask,
|
442 |
+
)
|
443 |
+
|
444 |
+
# unpack latents
|
445 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
446 |
+
|
447 |
+
# apply model prediction type
|
448 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
449 |
+
|
450 |
+
# flow matching loss: this is different from SD3
|
451 |
+
target = noise - latents
|
452 |
+
|
453 |
+
# differential output preservation
|
454 |
+
if "custom_attributes" in batch:
|
455 |
+
diff_output_pr_indices = []
|
456 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
457 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
458 |
+
diff_output_pr_indices.append(i)
|
459 |
+
|
460 |
+
if len(diff_output_pr_indices) > 0:
|
461 |
+
network.set_multiplier(0.0)
|
462 |
+
unet.prepare_block_swap_before_forward()
|
463 |
+
with torch.no_grad():
|
464 |
+
model_pred_prior = call_dit(
|
465 |
+
img=packed_noisy_model_input[diff_output_pr_indices],
|
466 |
+
img_ids=img_ids[diff_output_pr_indices],
|
467 |
+
t5_out=t5_out[diff_output_pr_indices],
|
468 |
+
txt_ids=txt_ids[diff_output_pr_indices],
|
469 |
+
l_pooled=l_pooled[diff_output_pr_indices],
|
470 |
+
timesteps=timesteps[diff_output_pr_indices],
|
471 |
+
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
472 |
+
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
473 |
+
)
|
474 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
475 |
+
|
476 |
+
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
477 |
+
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
478 |
+
args,
|
479 |
+
model_pred_prior,
|
480 |
+
noisy_model_input[diff_output_pr_indices],
|
481 |
+
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
482 |
+
)
|
483 |
+
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
484 |
+
|
485 |
+
return model_pred, target, timesteps, weighting
|
486 |
+
|
487 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
488 |
+
return loss
|
489 |
+
|
490 |
+
def get_sai_model_spec(self, args):
|
491 |
+
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
492 |
+
|
493 |
+
def update_metadata(self, metadata, args):
|
494 |
+
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
495 |
+
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
496 |
+
metadata["ss_logit_mean"] = args.logit_mean
|
497 |
+
metadata["ss_logit_std"] = args.logit_std
|
498 |
+
metadata["ss_mode_scale"] = args.mode_scale
|
499 |
+
metadata["ss_guidance_scale"] = args.guidance_scale
|
500 |
+
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
501 |
+
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
502 |
+
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
503 |
+
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
504 |
+
|
505 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
506 |
+
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
507 |
+
|
508 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
509 |
+
if index == 0: # CLIP-L
|
510 |
+
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
511 |
+
else: # T5XXL
|
512 |
+
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
513 |
+
|
514 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
515 |
+
if index == 0: # CLIP-L
|
516 |
+
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
517 |
+
text_encoder.to(te_weight_dtype) # fp8
|
518 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
519 |
+
else: # T5XXL
|
520 |
+
|
521 |
+
def prepare_fp8(text_encoder, target_dtype):
|
522 |
+
def forward_hook(module):
|
523 |
+
def forward(hidden_states):
|
524 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
525 |
+
hidden_linear = module.wi_1(hidden_states)
|
526 |
+
hidden_states = hidden_gelu * hidden_linear
|
527 |
+
hidden_states = module.dropout(hidden_states)
|
528 |
+
|
529 |
+
hidden_states = module.wo(hidden_states)
|
530 |
+
return hidden_states
|
531 |
+
|
532 |
+
return forward
|
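# note: this hook mirrors T5DenseGatedActDense.forward but, unlike the stock transformers implementation, does not cast hidden_states down to the fp8 weight dtype before `wo`; that is presumably why it is installed for fp8 T5XXL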
533 |
+
|
534 |
+
for module in text_encoder.modules():
|
535 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
536 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
537 |
+
module.to(target_dtype)
|
538 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
539 |
+
# print("set", module.__class__.__name__, "hooks")
|
540 |
+
module.forward = forward_hook(module)
|
541 |
+
|
542 |
+
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
543 |
+
logger.info(f"T5XXL already prepared for fp8")
|
544 |
+
else:
|
545 |
+
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
546 |
+
text_encoder.to(te_weight_dtype) # fp8
|
547 |
+
prepare_fp8(text_encoder, weight_dtype)
|
548 |
+
|
549 |
+
def prepare_unet_with_accelerator(
|
550 |
+
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
551 |
+
) -> torch.nn.Module:
|
552 |
+
if not self.is_swapping_blocks:
|
553 |
+
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
554 |
+
|
555 |
+
# when swapping blocks, keep the swapped blocks on CPU and move the rest of the model to the device
|
556 |
+
flux: flux_models.Flux = unet
|
557 |
+
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
|
558 |
+
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
559 |
+
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
|
560 |
+
|
561 |
+
return flux
|
562 |
+
|
563 |
+
|
564 |
+
def setup_parser() -> argparse.ArgumentParser:
|
565 |
+
parser = train_network.setup_parser()
|
566 |
+
train_util.add_dit_training_arguments(parser)
|
567 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
568 |
+
|
569 |
+
parser.add_argument(
|
570 |
+
"--split_mode",
|
571 |
+
action="store_true",
|
572 |
+
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
573 |
+
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
574 |
+
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
|
575 |
+
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
|
576 |
+
)
|
577 |
+
return parser
|
578 |
+
|
579 |
+
|
580 |
+
if __name__ == "__main__":
|
581 |
+
parser = setup_parser()
|
582 |
+
|
583 |
+
args = parser.parse_args()
|
584 |
+
train_util.verify_command_line_training_args(args)
|
585 |
+
args = train_util.read_config_from_file(args, parser)
|
586 |
+
|
587 |
+
trainer = FluxNetworkTrainer()
|
588 |
+
trainer.train(args)
|
flux_train_network_asylora.py
ADDED
@@ -0,0 +1,591 @@
|
1 |
+
# coding=utf-8
|
2 |
+
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import copy
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
from typing import Any, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from accelerate import Accelerator
|
12 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
13 |
+
|
14 |
+
init_ipex()
|
15 |
+
|
16 |
+
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
17 |
+
import train_network_asylora
|
18 |
+
from library.utils import setup_logging
|
19 |
+
|
20 |
+
setup_logging()
|
21 |
+
import logging
|
22 |
+
import re
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class FluxNetworkTrainer(train_network_asylora.NetworkTrainer):
|
28 |
+
def __init__(self):
|
29 |
+
super().__init__()
|
30 |
+
self.sample_prompts_te_outputs = None
|
31 |
+
self.is_schnell: Optional[bool] = None
|
32 |
+
self.is_swapping_blocks: bool = False
|
33 |
+
|
34 |
+
def assert_extra_args(self, args, train_dataset_group):
|
35 |
+
super().assert_extra_args(args, train_dataset_group)
|
36 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
37 |
+
|
38 |
+
if args.fp8_base_unet:
|
39 |
+
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
40 |
+
|
41 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
42 |
+
logger.warning(
|
43 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
|
44 |
+
)
|
45 |
+
args.cache_text_encoder_outputs = True
|
46 |
+
|
47 |
+
if args.cache_text_encoder_outputs:
|
48 |
+
assert (
|
49 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
50 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
51 |
+
|
52 |
+
# prepare CLIP-L/T5XXL training flags
|
53 |
+
self.train_clip_l = not args.network_train_unet_only
|
54 |
+
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
55 |
+
|
56 |
+
if args.max_token_length is not None:
|
57 |
+
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
58 |
+
|
59 |
+
assert (
|
60 |
+
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
61 |
+
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
62 |
+
|
63 |
+
# deprecated split_mode option
|
64 |
+
if args.split_mode:
|
65 |
+
if args.blocks_to_swap is not None:
|
66 |
+
logger.warning(
|
67 |
+
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
|
68 |
+
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
logger.warning(
|
72 |
+
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
|
73 |
+
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
|
74 |
+
)
|
75 |
+
args.blocks_to_swap = 18 # 18 is safe for most cases
|
76 |
+
|
77 |
+
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
78 |
+
|
79 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
80 |
+
# currently offload to cpu for some models
|
81 |
+
|
82 |
+
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
83 |
+
loading_dtype = None if args.fp8_base else weight_dtype
|
84 |
+
|
85 |
+
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
86 |
+
self.is_schnell, model = flux_utils.load_flow_model(
|
87 |
+
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
88 |
+
)
|
89 |
+
if args.fp8_base:
|
90 |
+
# check dtype of model
|
91 |
+
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
92 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
93 |
+
elif model.dtype == torch.float8_e4m3fn:
|
94 |
+
logger.info("Loaded fp8 FLUX model")
|
95 |
+
else:
|
96 |
+
logger.info(
|
97 |
+
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
|
98 |
+
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
|
99 |
+
)
|
100 |
+
model.to(torch.float8_e4m3fn)
|
101 |
+
|
102 |
+
# if args.split_mode:
|
103 |
+
# model = self.prepare_split_model(model, weight_dtype, accelerator)
|
104 |
+
|
105 |
+
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
106 |
+
if self.is_swapping_blocks:
|
107 |
+
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
108 |
+
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
109 |
+
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
110 |
+
|
111 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
112 |
+
clip_l.eval()
|
113 |
+
|
114 |
+
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
115 |
+
if args.fp8_base and not args.fp8_base_unet:
|
116 |
+
loading_dtype = None # as is
|
117 |
+
else:
|
118 |
+
loading_dtype = weight_dtype
|
119 |
+
|
120 |
+
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
|
121 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
122 |
+
t5xxl.eval()
|
123 |
+
if args.fp8_base and not args.fp8_base_unet:
|
124 |
+
# check dtype of model
|
125 |
+
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
126 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
127 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
128 |
+
logger.info("Loaded fp8 T5XXL model")
|
129 |
+
|
130 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
131 |
+
|
132 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
133 |
+
|
134 |
+
def get_tokenize_strategy(self, args):
|
135 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
136 |
+
|
137 |
+
if args.t5xxl_max_token_length is None:
|
138 |
+
if is_schnell:
|
139 |
+
t5xxl_max_token_length = 256
|
140 |
+
else:
|
141 |
+
t5xxl_max_token_length = 512
|
142 |
+
else:
|
143 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
144 |
+
|
145 |
+
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
146 |
+
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
147 |
+
|
148 |
+
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
149 |
+
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
150 |
+
|
151 |
+
def get_latents_caching_strategy(self, args):
|
152 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
153 |
+
return latents_caching_strategy
|
154 |
+
|
155 |
+
def get_text_encoding_strategy(self, args):
|
156 |
+
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
157 |
+
|
158 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
159 |
+
# check t5xxl is trained or not
|
160 |
+
self.train_t5xxl = network.train_t5xxl
|
161 |
+
|
162 |
+
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
163 |
+
raise ValueError(
|
164 |
+
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
165 |
+
)
|
166 |
+
|
167 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
168 |
+
if args.cache_text_encoder_outputs:
|
169 |
+
if self.train_clip_l and not self.train_t5xxl:
|
170 |
+
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
171 |
+
else:
|
172 |
+
return None # no text encoders are needed for encoding because both are cached
|
173 |
+
else:
|
174 |
+
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
175 |
+
|
176 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
177 |
+
return [self.train_clip_l, self.train_t5xxl]
|
178 |
+
|
179 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
180 |
+
if args.cache_text_encoder_outputs:
|
181 |
+
# if a text encoder is trained, we still need tokenization, so is_partial is True
|
182 |
+
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
183 |
+
args.cache_text_encoder_outputs_to_disk,
|
184 |
+
args.text_encoder_batch_size,
|
185 |
+
args.skip_cache_check,
|
186 |
+
is_partial=self.train_clip_l or self.train_t5xxl,
|
187 |
+
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
return None
|
191 |
+
|
192 |
+
def cache_text_encoder_outputs_if_needed(
|
193 |
+
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
194 |
+
):
|
195 |
+
if args.cache_text_encoder_outputs:
|
196 |
+
if not args.lowram:
|
197 |
+
# reduce memory consumption
|
198 |
+
logger.info("move vae and unet to cpu to save memory")
|
199 |
+
org_vae_device = vae.device
|
200 |
+
org_unet_device = unet.device
|
201 |
+
vae.to("cpu")
|
202 |
+
unet.to("cpu")
|
203 |
+
clean_memory_on_device(accelerator.device)
|
204 |
+
|
205 |
+
# When the text encoders are not trained, they are not prepared by the accelerator, so we need to use explicit autocast
|
206 |
+
logger.info("move text encoders to gpu")
|
207 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
208 |
+
text_encoders[1].to(accelerator.device)
|
209 |
+
|
210 |
+
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
211 |
+
# if we load fp8 weights, the model is already fp8, so we use it as is
|
212 |
+
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
213 |
+
else:
|
214 |
+
# otherwise, we need to convert it to target dtype
|
215 |
+
text_encoders[1].to(weight_dtype)
|
216 |
+
|
217 |
+
with accelerator.autocast():
|
218 |
+
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
219 |
+
|
220 |
+
# cache sample prompts
|
221 |
+
if args.sample_prompts is not None:
|
222 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
223 |
+
|
224 |
+
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
225 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
226 |
+
|
227 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
228 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
229 |
+
with accelerator.autocast(), torch.no_grad():
|
230 |
+
for prompt_dict in prompts:
|
231 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
232 |
+
if p not in sample_prompts_te_outputs:
|
233 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
234 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
235 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
236 |
+
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
237 |
+
)
|
238 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
239 |
+
|
240 |
+
accelerator.wait_for_everyone()
|
241 |
+
|
242 |
+
# move back to cpu
|
243 |
+
if not self.is_train_text_encoder(args):
|
244 |
+
logger.info("move CLIP-L back to cpu")
|
245 |
+
text_encoders[0].to("cpu")
|
246 |
+
logger.info("move t5XXL back to cpu")
|
247 |
+
text_encoders[1].to("cpu")
|
248 |
+
clean_memory_on_device(accelerator.device)
|
249 |
+
|
250 |
+
if not args.lowram:
|
251 |
+
logger.info("move vae and unet back to original device")
|
252 |
+
vae.to(org_vae_device)
|
253 |
+
unet.to(org_unet_device)
|
254 |
+
else:
|
255 |
+
# outputs are taken from the Text Encoders every step, so keep them on the GPU
|
256 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
257 |
+
text_encoders[1].to(accelerator.device)
|
258 |
+
|
259 |
+
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
260 |
+
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
261 |
+
|
262 |
+
# # get size embeddings
|
263 |
+
# orig_size = batch["original_sizes_hw"]
|
264 |
+
# crop_size = batch["crop_top_lefts"]
|
265 |
+
# target_size = batch["target_sizes_hw"]
|
266 |
+
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
267 |
+
|
268 |
+
# # concat embeddings
|
269 |
+
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
270 |
+
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
271 |
+
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
272 |
+
|
273 |
+
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
274 |
+
# return noise_pred
|
275 |
+
|
276 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
277 |
+
text_encoders = text_encoder # for compatibility
|
278 |
+
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
279 |
+
|
280 |
+
flux_train_utils.sample_images(
|
281 |
+
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
282 |
+
)
|
283 |
+
# return
|
284 |
+
|
285 |
+
"""
|
286 |
+
class FluxUpperLowerWrapper(torch.nn.Module):
|
287 |
+
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
288 |
+
super().__init__()
|
289 |
+
self.flux_upper = flux_upper
|
290 |
+
self.flux_lower = flux_lower
|
291 |
+
self.target_device = device
|
292 |
+
|
293 |
+
def prepare_block_swap_before_forward(self):
|
294 |
+
pass
|
295 |
+
|
296 |
+
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
297 |
+
self.flux_lower.to("cpu")
|
298 |
+
clean_memory_on_device(self.target_device)
|
299 |
+
self.flux_upper.to(self.target_device)
|
300 |
+
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
301 |
+
self.flux_upper.to("cpu")
|
302 |
+
clean_memory_on_device(self.target_device)
|
303 |
+
self.flux_lower.to(self.target_device)
|
304 |
+
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
305 |
+
|
306 |
+
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
307 |
+
clean_memory_on_device(accelerator.device)
|
308 |
+
flux_train_utils.sample_images(
|
309 |
+
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
310 |
+
)
|
311 |
+
clean_memory_on_device(accelerator.device)
|
312 |
+
"""
|
313 |
+
|
314 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
315 |
+
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
316 |
+
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
317 |
+
return noise_scheduler
|
318 |
+
|
319 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
320 |
+
return vae.encode(images)
|
321 |
+
|
322 |
+
def shift_scale_latents(self, args, latents):
|
323 |
+
return latents
|
324 |
+
|
325 |
+
def get_noise_pred_and_target(
|
326 |
+
self,
|
327 |
+
args,
|
328 |
+
accelerator,
|
329 |
+
noise_scheduler,
|
330 |
+
latents,
|
331 |
+
batch,
|
332 |
+
text_encoder_conds,
|
333 |
+
unet: flux_models.Flux,
|
334 |
+
network,
|
335 |
+
weight_dtype,
|
336 |
+
train_unet,
|
337 |
+
):
|
338 |
+
# Sample noise that we'll add to the latents
|
339 |
+
noise = torch.randn_like(latents)
|
340 |
+
bsz = latents.shape[0]
|
341 |
+
|
342 |
+
# get noisy model input and timesteps
|
343 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
344 |
+
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
345 |
+
)
|
346 |
+
|
347 |
+
# pack latents and get img_ids
|
348 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
349 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
350 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
351 |
+
|
352 |
+
# get guidance
|
353 |
+
# ensure guidance_scale in args is float
|
354 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
355 |
+
|
356 |
+
# ensure the hidden state will require grad
|
357 |
+
if args.gradient_checkpointing:
|
358 |
+
noisy_model_input.requires_grad_(True)
|
359 |
+
for t in text_encoder_conds:
|
360 |
+
if t is not None and t.dtype.is_floating_point:
|
361 |
+
t.requires_grad_(True)
|
362 |
+
img_ids.requires_grad_(True)
|
363 |
+
guidance_vec.requires_grad_(True)
|
364 |
+
|
365 |
+
# Predict the noise residual
|
366 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
367 |
+
if not args.apply_t5_attn_mask:
|
368 |
+
t5_attn_mask = None
|
369 |
+
|
370 |
+
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
371 |
+
# if not args.split_mode:
|
372 |
+
# normal forward
|
373 |
+
with accelerator.autocast():
|
374 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
375 |
+
model_pred = unet(
|
376 |
+
img=img,
|
377 |
+
img_ids=img_ids,
|
378 |
+
txt=t5_out,
|
379 |
+
txt_ids=txt_ids,
|
380 |
+
y=l_pooled,
|
381 |
+
timesteps=timesteps / 1000,
|
382 |
+
guidance=guidance_vec,
|
383 |
+
txt_attention_mask=t5_attn_mask
|
384 |
+
)
|
385 |
+
"""
|
386 |
+
else:
|
387 |
+
# split forward to reduce memory usage
|
388 |
+
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
389 |
+
with accelerator.autocast():
|
390 |
+
# move flux lower to cpu, and then move flux upper to gpu
|
391 |
+
unet.to("cpu")
|
392 |
+
clean_memory_on_device(accelerator.device)
|
393 |
+
self.flux_upper.to(accelerator.device)
|
394 |
+
|
395 |
+
# upper model does not require grad
|
396 |
+
with torch.no_grad():
|
397 |
+
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
398 |
+
img=packed_noisy_model_input,
|
399 |
+
img_ids=img_ids,
|
400 |
+
txt=t5_out,
|
401 |
+
txt_ids=txt_ids,
|
402 |
+
y=l_pooled,
|
403 |
+
timesteps=timesteps / 1000,
|
404 |
+
guidance=guidance_vec,
|
405 |
+
txt_attention_mask=t5_attn_mask,
|
406 |
+
)
|
407 |
+
|
408 |
+
# move flux upper back to cpu, and then move flux lower to gpu
|
409 |
+
self.flux_upper.to("cpu")
|
410 |
+
clean_memory_on_device(accelerator.device)
|
411 |
+
unet.to(accelerator.device)
|
412 |
+
|
413 |
+
# lower model requires grad
|
414 |
+
intermediate_img.requires_grad_(True)
|
415 |
+
intermediate_txt.requires_grad_(True)
|
416 |
+
vec.requires_grad_(True)
|
417 |
+
pe.requires_grad_(True)
|
418 |
+
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
419 |
+
"""
|
420 |
+
|
421 |
+
return model_pred
|
422 |
+
|
423 |
+
# get the dataset category index from the caption text
|
424 |
+
# lora_category = batch["captions"][0].split(",")[0][3:]
|
425 |
+
# assert lora_category.isdigit(), f"lora_category is not an integer, got: {lora_category}, {batch['captions'][0]}"
|
426 |
+
# lora_category = int(lora_category)
|
427 |
+
|
428 |
+
prompt_cur = batch["captions"][0]
|
429 |
+
match = re.search(r'--lora_up_cur (\d+)', prompt_cur)
|
430 |
+
assert match, "Pattern '--lora_up_cur' not found"
|
431 |
+
lora_category = int(match.group(1))
|
432 |
+
|
433 |
+
for lora in network.unet_loras:
|
434 |
+
lora.set_lora_up_cur(lora_category-1)
|
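# select the asymmetric-LoRA up-branch for this batch: the caption carries a 1-indexed `--lora_up_cur N` tag, while set_lora_up_cur expects a 0-indexed value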
435 |
+
|
436 |
+
model_pred = call_dit(
|
437 |
+
img=packed_noisy_model_input,
|
438 |
+
img_ids=img_ids,
|
439 |
+
t5_out=t5_out,
|
440 |
+
txt_ids=txt_ids,
|
441 |
+
l_pooled=l_pooled,
|
442 |
+
timesteps=timesteps,
|
443 |
+
guidance_vec=guidance_vec,
|
444 |
+
t5_attn_mask=t5_attn_mask
|
445 |
+
)
|
446 |
+
|
447 |
+
# unpack latents
|
448 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
449 |
+
|
450 |
+
# apply model prediction type
|
451 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
452 |
+
|
453 |
+
# flow matching loss: this is different from SD3
|
454 |
+
target = noise - latents
|
455 |
+
|
456 |
+
# differential output preservation
|
457 |
+
if "custom_attributes" in batch:
|
458 |
+
diff_output_pr_indices = []
|
459 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
460 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
461 |
+
diff_output_pr_indices.append(i)
|
462 |
+
|
463 |
+
if len(diff_output_pr_indices) > 0:
|
464 |
+
network.set_multiplier(0.0)
|
465 |
+
unet.prepare_block_swap_before_forward()
|
466 |
+
with torch.no_grad():
|
467 |
+
model_pred_prior = call_dit(
|
468 |
+
img=packed_noisy_model_input[diff_output_pr_indices],
|
469 |
+
img_ids=img_ids[diff_output_pr_indices],
|
470 |
+
t5_out=t5_out[diff_output_pr_indices],
|
471 |
+
txt_ids=txt_ids[diff_output_pr_indices],
|
472 |
+
l_pooled=l_pooled[diff_output_pr_indices],
|
473 |
+
timesteps=timesteps[diff_output_pr_indices],
|
474 |
+
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
475 |
+
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
476 |
+
)
|
477 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
478 |
+
|
479 |
+
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
480 |
+
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
481 |
+
args,
|
482 |
+
model_pred_prior,
|
483 |
+
noisy_model_input[diff_output_pr_indices],
|
484 |
+
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
485 |
+
)
|
486 |
+
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
487 |
+
|
488 |
+
return model_pred, target, timesteps, None, weighting
|
489 |
+
|
490 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
491 |
+
return loss
|
492 |
+
|
493 |
+
def get_sai_model_spec(self, args):
|
494 |
+
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
495 |
+
|
496 |
+
def update_metadata(self, metadata, args):
|
497 |
+
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
498 |
+
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
499 |
+
metadata["ss_logit_mean"] = args.logit_mean
|
500 |
+
metadata["ss_logit_std"] = args.logit_std
|
501 |
+
metadata["ss_mode_scale"] = args.mode_scale
|
502 |
+
metadata["ss_guidance_scale"] = args.guidance_scale
|
503 |
+
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
504 |
+
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
505 |
+
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
506 |
+
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
507 |
+
|
508 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
509 |
+
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
510 |
+
|
511 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
512 |
+
if index == 0: # CLIP-L
|
513 |
+
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
514 |
+
else: # T5XXL
|
515 |
+
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
516 |
+
|
517 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
518 |
+
if index == 0: # CLIP-L
|
519 |
+
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
520 |
+
text_encoder.to(te_weight_dtype) # fp8
|
521 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
522 |
+
else: # T5XXL
|
523 |
+
|
524 |
+
def prepare_fp8(text_encoder, target_dtype):
|
525 |
+
def forward_hook(module):
|
526 |
+
def forward(hidden_states):
|
527 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
528 |
+
hidden_linear = module.wi_1(hidden_states)
|
529 |
+
hidden_states = hidden_gelu * hidden_linear
|
530 |
+
hidden_states = module.dropout(hidden_states)
|
531 |
+
|
532 |
+
hidden_states = module.wo(hidden_states)
|
533 |
+
return hidden_states
|
534 |
+
|
535 |
+
return forward
|
536 |
+
|
537 |
+
for module in text_encoder.modules():
|
538 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
539 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
540 |
+
module.to(target_dtype)
|
541 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
542 |
+
# print("set", module.__class__.__name__, "hooks")
|
543 |
+
module.forward = forward_hook(module)
|
544 |
+
|
545 |
+
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
546 |
+
logger.info(f"T5XXL already prepared for fp8")
|
547 |
+
else:
|
548 |
+
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
549 |
+
text_encoder.to(te_weight_dtype) # fp8
|
550 |
+
prepare_fp8(text_encoder, weight_dtype)
|
551 |
+
|
552 |
+
def prepare_unet_with_accelerator(
|
553 |
+
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
554 |
+
) -> torch.nn.Module:
|
555 |
+
if not self.is_swapping_blocks:
|
556 |
+
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
557 |
+
|
558 |
+
# when swapping blocks, keep the swapped blocks on CPU and move the rest of the model to the device
|
559 |
+
flux: flux_models.Flux = unet
|
560 |
+
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
|
561 |
+
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
562 |
+
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
|
563 |
+
|
564 |
+
return flux
|
565 |
+
|
566 |
+
|
567 |
+
def setup_parser() -> argparse.ArgumentParser:
|
568 |
+
parser = train_network_asylora.setup_parser()
|
569 |
+
train_util.add_dit_training_arguments(parser)
|
570 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
571 |
+
|
572 |
+
parser.add_argument(
|
573 |
+
"--split_mode",
|
574 |
+
action="store_true",
|
575 |
+
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
576 |
+
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
577 |
+
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
|
578 |
+
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
|
579 |
+
)
|
580 |
+
return parser
|
581 |
+
|
582 |
+
|
583 |
+
if __name__ == "__main__":
|
584 |
+
parser = setup_parser()
|
585 |
+
|
586 |
+
args = parser.parse_args()
|
587 |
+
train_util.verify_command_line_training_args(args)
|
588 |
+
args = train_util.read_config_from_file(args, parser)
|
589 |
+
|
590 |
+
trainer = FluxNetworkTrainer()
|
591 |
+
trainer.train(args)
|
flux_train_recraft.py
ADDED
@@ -0,0 +1,713 @@
|
1 |
+
import argparse
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from typing import Any, Optional
|
6 |
+
import pdb
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from accelerate import Accelerator
|
10 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
11 |
+
|
12 |
+
init_ipex()
|
13 |
+
|
14 |
+
from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
15 |
+
from torchvision import transforms
|
16 |
+
import train_network
|
17 |
+
from library.utils import setup_logging
|
18 |
+
from diffusers.utils import load_image
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image, ImageOps
|
21 |
+
|
22 |
+
setup_logging()
|
23 |
+
import logging
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
# NUM_SPLIT = 2
|
28 |
+
|
29 |
+
class ResizeWithPadding:
|
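# pads a non-square image onto a square canvas filled with `fill` (white by default), then resizes it to `size` x `size` with LANCZOS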
30 |
+
def __init__(self, size, fill=255):
|
31 |
+
self.size = size
|
32 |
+
self.fill = fill
|
33 |
+
|
34 |
+
def __call__(self, img):
|
35 |
+
if isinstance(img, np.ndarray):
|
36 |
+
img = Image.fromarray(img)
|
37 |
+
elif not isinstance(img, Image.Image):
|
38 |
+
raise TypeError("Input must be a PIL Image or a NumPy array")
|
39 |
+
|
40 |
+
width, height = img.size
|
41 |
+
|
42 |
+
if width == height:
|
43 |
+
img = img.resize((self.size, self.size), Image.LANCZOS)
|
44 |
+
else:
|
45 |
+
max_dim = max(width, height)
|
46 |
+
|
47 |
+
new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
|
48 |
+
new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
|
49 |
+
|
50 |
+
img = new_img.resize((self.size, self.size), Image.LANCZOS)
|
51 |
+
|
52 |
+
return img
|
53 |
+
|
54 |
+
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
55 |
+
def __init__(self):
|
56 |
+
super().__init__()
|
57 |
+
self.sample_prompts_te_outputs = None
|
58 |
+
self.sample_conditions = None
|
59 |
+
self.is_schnell: Optional[bool] = None
|
60 |
+
|
61 |
+
def assert_extra_args(self, args, train_dataset_group):
|
62 |
+
super().assert_extra_args(args, train_dataset_group)
|
63 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
64 |
+
|
65 |
+
if args.fp8_base_unet:
|
66 |
+
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
67 |
+
|
68 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
69 |
+
logger.warning(
|
70 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
71 |
+
)
|
72 |
+
args.cache_text_encoder_outputs = True
|
73 |
+
|
74 |
+
if args.cache_text_encoder_outputs:
|
75 |
+
assert (
|
76 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
77 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
78 |
+
|
79 |
+
# prepare CLIP-L/T5XXL training flags
|
80 |
+
self.train_clip_l = not args.network_train_unet_only
|
81 |
+
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
82 |
+
|
83 |
+
if args.max_token_length is not None:
|
84 |
+
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
85 |
+
|
86 |
+
assert not args.split_mode or not args.cpu_offload_checkpointing, (
|
87 |
+
"split_mode and cpu_offload_checkpointing cannot be used together"
|
88 |
+
" / split_modeとcpu_offload_checkpointingは同時に使用できません"
|
89 |
+
)
|
90 |
+
|
91 |
+
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
92 |
+
|
93 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
94 |
+
# currently offload to cpu for some models
|
95 |
+
|
96 |
+
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
97 |
+
loading_dtype = None if args.fp8_base else weight_dtype
|
98 |
+
|
99 |
+
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
100 |
+
self.is_schnell, model = flux_utils.load_flow_model(
|
101 |
+
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
102 |
+
)
|
103 |
+
if args.fp8_base:
|
104 |
+
# check dtype of model
|
105 |
+
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
106 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
107 |
+
elif model.dtype == torch.float8_e4m3fn:
|
108 |
+
logger.info("Loaded fp8 FLUX model")
|
109 |
+
|
110 |
+
if args.split_mode:
|
111 |
+
model = self.prepare_split_model(model, weight_dtype, accelerator)
|
112 |
+
|
113 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
114 |
+
clip_l.eval()
|
115 |
+
|
116 |
+
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
117 |
+
if args.fp8_base and not args.fp8_base_unet:
|
118 |
+
loading_dtype = None # as is
|
119 |
+
else:
|
120 |
+
loading_dtype = weight_dtype
|
121 |
+
|
122 |
+
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
|
123 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
124 |
+
t5xxl.eval()
|
125 |
+
if args.fp8_base and not args.fp8_base_unet:
|
126 |
+
# check dtype of model
|
127 |
+
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
128 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
129 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
130 |
+
logger.info("Loaded fp8 T5XXL model")
|
131 |
+
|
132 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
133 |
+
|
134 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
135 |
+
|
136 |
+
def prepare_split_model(self, model, weight_dtype, accelerator):
|
137 |
+
from accelerate import init_empty_weights
|
138 |
+
|
139 |
+
logger.info("prepare split model")
|
140 |
+
with init_empty_weights():
|
141 |
+
flux_upper = flux_models.FluxUpper(model.params)
|
142 |
+
flux_lower = flux_models.FluxLower(model.params)
|
143 |
+
sd = model.state_dict()
|
144 |
+
|
145 |
+
# lower (trainable)
|
146 |
+
logger.info("load state dict for lower")
|
147 |
+
flux_lower.load_state_dict(sd, strict=False, assign=True)
|
148 |
+
flux_lower.to(dtype=weight_dtype)
|
149 |
+
|
150 |
+
# upper (frozen)
|
151 |
+
logger.info("load state dict for upper")
|
152 |
+
flux_upper.load_state_dict(sd, strict=False, assign=True)
|
153 |
+
|
154 |
+
logger.info("prepare upper model")
|
155 |
+
target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
|
156 |
+
flux_upper.to(accelerator.device, dtype=target_dtype)
|
157 |
+
flux_upper.eval()
|
158 |
+
|
159 |
+
if args.fp8_base:
|
160 |
+
# this is required to run on fp8
|
161 |
+
flux_upper = accelerator.prepare(flux_upper)
|
162 |
+
|
163 |
+
flux_upper.to("cpu")
|
164 |
+
|
165 |
+
self.flux_upper = flux_upper
|
166 |
+
del model # we don't need model anymore
|
167 |
+
clean_memory_on_device(accelerator.device)
|
168 |
+
|
169 |
+
logger.info("split model prepared")
|
170 |
+
|
171 |
+
return flux_lower
|
172 |
+
|
173 |
+
def get_tokenize_strategy(self, args):
|
174 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
175 |
+
|
176 |
+
if args.t5xxl_max_token_length is None:
|
177 |
+
if is_schnell:
|
178 |
+
t5xxl_max_token_length = 256
|
179 |
+
else:
|
180 |
+
t5xxl_max_token_length = 512
|
181 |
+
else:
|
182 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
183 |
+
|
184 |
+
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
185 |
+
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
186 |
+
|
187 |
+
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
188 |
+
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
189 |
+
|
190 |
+
def get_latents_caching_strategy(self, args):
|
191 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
192 |
+
return latents_caching_strategy
|
193 |
+
|
194 |
+
def get_text_encoding_strategy(self, args):
|
195 |
+
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
196 |
+
|
197 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
198 |
+
# check t5xxl is trained or not
|
199 |
+
self.train_t5xxl = network.train_t5xxl
|
200 |
+
|
201 |
+
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
202 |
+
raise ValueError(
|
203 |
+
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
204 |
+
)
|
205 |
+
|
206 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
207 |
+
if args.cache_text_encoder_outputs:
|
208 |
+
if self.train_clip_l and not self.train_t5xxl:
|
209 |
+
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
210 |
+
else:
|
211 |
+
return None # no text encoders are needed for encoding because both are cached
|
212 |
+
else:
|
213 |
+
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
214 |
+
|
215 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
216 |
+
return [self.train_clip_l, self.train_t5xxl]
|
217 |
+
|
218 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
219 |
+
if args.cache_text_encoder_outputs:
|
220 |
+
# if a text encoder is trained, we still need tokenization, so is_partial is True
|
221 |
+
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
222 |
+
args.cache_text_encoder_outputs_to_disk,
|
223 |
+
args.text_encoder_batch_size,
|
224 |
+
args.skip_cache_check,
|
225 |
+
is_partial=self.train_clip_l or self.train_t5xxl,
|
226 |
+
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
return None
|
230 |
+
|
231 |
+
def cache_text_encoder_outputs_if_needed(
|
232 |
+
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
233 |
+
):
|
234 |
+
if args.cache_text_encoder_outputs:
|
235 |
+
if not args.lowram:
|
236 |
+
# reduce memory consumption
|
237 |
+
logger.info("move vae and unet to cpu to save memory")
|
238 |
+
org_vae_device = vae.device
|
239 |
+
org_unet_device = unet.device
|
240 |
+
vae.to("cpu")
|
241 |
+
unet.to("cpu")
|
242 |
+
clean_memory_on_device(accelerator.device)
|
243 |
+
|
244 |
+
# When the text encoders are not trained, they are not prepared by the accelerator, so we need to use explicit autocast
|
245 |
+
logger.info("move text encoders to gpu")
|
246 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
247 |
+
text_encoders[1].to(accelerator.device)
|
248 |
+
|
249 |
+
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
250 |
+
# if we load fp8 weights, the model is already fp8, so we use it as is
|
251 |
+
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
252 |
+
else:
|
253 |
+
# otherwise, we need to convert it to target dtype
|
254 |
+
text_encoders[1].to(weight_dtype)
|
255 |
+
|
256 |
+
with accelerator.autocast():
|
257 |
+
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
258 |
+
|
259 |
+
# cache sample prompts
|
260 |
+
if args.sample_prompts is not None:
|
261 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
262 |
+
|
263 |
+
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
264 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
265 |
+
|
266 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
267 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
268 |
+
with accelerator.autocast(), torch.no_grad():
|
269 |
+
for prompt_dict in prompts:
|
270 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
271 |
+
if p not in sample_prompts_te_outputs:
|
272 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
273 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
274 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
275 |
+
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
276 |
+
)
|
277 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
278 |
+
|
279 |
+
# cache the conditions (VAE latents of the condition images) for sampling
|
280 |
+
if args.sample_images is not None:
|
281 |
+
logger.info(f"cache conditions for sample images: {args.sample_images}")
|
282 |
+
|
283 |
+
# lc03lc
|
284 |
+
resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
|
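# assumption: 512 px is used for the 4-frame layout and 352 px otherwise so the padded condition image matches the per-cell resolution of the training grid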
285 |
+
img_transforms = transforms.Compose([
|
286 |
+
resize_transform,
|
287 |
+
transforms.ToTensor(),
|
288 |
+
transforms.Normalize([0.5], [0.5]),
|
289 |
+
])
|
290 |
+
|
291 |
+
if args.sample_images.endswith(".txt"):
|
292 |
+
with open(args.sample_images, "r", encoding="utf-8") as f:
|
293 |
+
lines = f.readlines()
|
294 |
+
sample_images = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
295 |
+
else:
|
296 |
+
raise NotImplementedError(f"sample_images file format not supported: {args.sample_images}")
|
297 |
+
|
298 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
299 |
+
conditions = {} # key: prompt, value: latents
|
300 |
+
|
301 |
+
with torch.no_grad():
|
302 |
+
for image, prompt_dict in zip(sample_images, prompts):
|
303 |
+
prompt = prompt_dict.get("prompt", "")
|
304 |
+
if prompt not in conditions:
|
305 |
+
logger.info(f"cache conditions for image: {image} with prompt: {prompt}")
|
306 |
+
image = img_transforms(np.array(load_image(image), dtype=np.uint8)).unsqueeze(0).to(vae.device, dtype=vae.dtype)
|
307 |
+
latents = self.encode_images_to_latents2(args, accelerator, vae, image)
|
308 |
+
# lc03lc
|
309 |
+
conditions[prompt] = latents
|
310 |
+
# if args.frame_num == 4:
|
311 |
+
# conditions[prompt] = latents[:,:,2*latents.shape[2]//3:latents.shape[2], 2*latents.shape[3]//3:latents.shape[3]].to("cpu")
|
312 |
+
# else:
|
313 |
+
# conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
|
314 |
+
|
315 |
+
self.sample_conditions = conditions
|
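# the VAE latents of the condition images are cached once here; sample_images later moves them to the accelerator device before sampling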
316 |
+
|
317 |
+
accelerator.wait_for_everyone()
|
318 |
+
|
319 |
+
# move back to cpu
|
320 |
+
if not self.is_train_text_encoder(args):
|
321 |
+
logger.info("move CLIP-L back to cpu")
|
322 |
+
text_encoders[0].to("cpu")
|
323 |
+
logger.info("move t5XXL back to cpu")
|
324 |
+
text_encoders[1].to("cpu")
|
325 |
+
clean_memory_on_device(accelerator.device)
|
326 |
+
|
327 |
+
if not args.lowram:
|
328 |
+
logger.info("move vae and unet back to original device")
|
329 |
+
vae.to(org_vae_device)
|
330 |
+
unet.to(org_unet_device)
|
331 |
+
else:
|
332 |
+
# outputs are taken from the Text Encoders every step, so keep them on the GPU
|
333 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
334 |
+
text_encoders[1].to(accelerator.device)
|
335 |
+
|
336 |
+
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
337 |
+
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
338 |
+
|
339 |
+
# # get size embeddings
|
340 |
+
# orig_size = batch["original_sizes_hw"]
|
341 |
+
# crop_size = batch["crop_top_lefts"]
|
342 |
+
# target_size = batch["target_sizes_hw"]
|
343 |
+
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
344 |
+
|
345 |
+
# # concat embeddings
|
346 |
+
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
347 |
+
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
348 |
+
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
349 |
+
|
350 |
+
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
351 |
+
# return noise_pred
|
352 |
+
|
353 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
354 |
+
text_encoders = text_encoder # for compatibility
|
355 |
+
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
356 |
+
# use the precomputed conditions directly
|
357 |
+
conditions = None
|
358 |
+
if self.sample_conditions is not None:
|
359 |
+
conditions = {k: v.to(accelerator.device) for k, v in self.sample_conditions.items()}
|
360 |
+
|
361 |
+
if not args.split_mode:
|
362 |
+
flux_train_utils.sample_images(
|
363 |
+
accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs, None, conditions
|
364 |
+
)
|
365 |
+
return
|
366 |
+
|
367 |
+
class FluxUpperLowerWrapper(torch.nn.Module):
|
368 |
+
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
369 |
+
super().__init__()
|
370 |
+
self.flux_upper = flux_upper
|
371 |
+
self.flux_lower = flux_lower
|
372 |
+
self.target_device = device
|
373 |
+
|
374 |
+
def prepare_block_swap_before_forward(self):
|
375 |
+
pass
|
376 |
+
|
377 |
+
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
378 |
+
self.flux_lower.to("cpu")
|
379 |
+
clean_memory_on_device(self.target_device)
|
380 |
+
self.flux_upper.to(self.target_device)
|
381 |
+
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
382 |
+
self.flux_upper.to("cpu")
|
383 |
+
clean_memory_on_device(self.target_device)
|
384 |
+
self.flux_lower.to(self.target_device)
|
385 |
+
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
386 |
+
|
387 |
+
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
388 |
+
clean_memory_on_device(accelerator.device)
|
389 |
+
flux_train_utils.sample_images(
|
390 |
+
accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs, conditions
|
391 |
+
)
|
392 |
+
clean_memory_on_device(accelerator.device)
|
393 |
+
|
394 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
395 |
+
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
396 |
+
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
397 |
+
return noise_scheduler
|
398 |
+
|
399 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
400 |
+
# get the image dimensions
|
401 |
+
b, c, h, w = images.shape
|
402 |
+
|
403 |
+
# num_split = NUM_SPLIT
|
404 |
+
num_split = 2 if args.frame_num == 4 else 3
|
405 |
+
# split the image into num_split parts along the width
|
406 |
+
img_parts = [images[:,:,:,i*w//num_split:(i+1)*w//num_split] for i in range(num_split)]
|
407 |
+
# encode each part separately
|
408 |
+
latents = [vae.encode(img) for img in img_parts]
|
409 |
+
# concatenate the parts back into the full image in latent space
|
410 |
+
latents = torch.cat(latents, dim=-1)
|
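# note: each width-wise strip is encoded separately and the latents are concatenated back along the width, presumably so the VAE never blends adjacent frames across a seam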
411 |
+
|
412 |
+
return latents
|
413 |
+
|
414 |
+
def encode_images_to_latents2(self, args, accelerator, vae, images):
|
415 |
+
# get the image dimensions
|
416 |
+
b, c, h, w = images.shape
|
417 |
+
# num_split = NUM_SPLIT
|
418 |
+
num_split = 2 if args.frame_num == 4 else 3
|
419 |
+
latents = vae.encode(images)
|
420 |
+
return latents
|
421 |
+
|
422 |
+
def encode_images_to_latents3(self, args, accelerator, vae, images):
|
423 |
+
b, c, h, w = images.shape
|
424 |
+
# Number of splits along each dimension
|
425 |
+
num_split = 3
|
426 |
+
# Check if the image can be evenly divided into 3x3 grid
|
427 |
+
assert h % num_split == 0 and w % num_split == 0, "Image dimensions must be divisible by 3."
|
428 |
+
|
429 |
+
# Height and width of each split
|
430 |
+
split_h, split_w = h // num_split, w // num_split
|
431 |
+
|
432 |
+
# Store latents for each split
|
433 |
+
latents = []
|
434 |
+
|
435 |
+
for i in range(num_split):
|
436 |
+
for j in range(num_split):
|
437 |
+
# Extract the (i, j) sub-image
|
438 |
+
img_part = images[:, :, i * split_h:(i + 1) * split_h, j * split_w:(j + 1) * split_w]
|
439 |
+
# Encode the sub-image using VAE
|
440 |
+
latent = vae.encode(img_part)
|
441 |
+
# Append the latent
|
442 |
+
latents.append(latent)
|
443 |
+
|
444 |
+
# Combine latents into a 3x3 grid in the latent space
|
445 |
+
# Latents list -> Tensor [num_split^2, b, latent_dim, h', w']
|
446 |
+
latents = torch.stack(latents, dim=0)
|
447 |
+
|
448 |
+
# Reshape into a 3x3 grid
|
449 |
+
# Shape: [num_split, num_split, b, latent_dim, h', w']
|
450 |
+
latents = latents.view(num_split, num_split, b, *latents.shape[2:])
|
451 |
+
|
452 |
+
# Combine the 3x3 grid along height and width in latent space
|
453 |
+
# Concatenate along width for each row, then concatenate rows along height
|
454 |
+
latents = torch.cat([torch.cat(latents[i], dim=-1) for i in range(num_split)], dim=-2)
|
455 |
+
|
456 |
+
# Final shape: [b, latent_dim, h', w']
|
457 |
+
return latents
|
458 |
+
|
459 |
+
def shift_scale_latents(self, args, latents):
|
460 |
+
return latents
|
461 |
+
|
462 |
+
def get_noise_pred_and_target(
|
463 |
+
self,
|
464 |
+
args,
|
465 |
+
accelerator,
|
466 |
+
noise_scheduler,
|
467 |
+
latents,
|
468 |
+
batch,
|
469 |
+
text_encoder_conds,
|
470 |
+
unet: flux_models.Flux,
|
471 |
+
network,
|
472 |
+
weight_dtype,
|
473 |
+
train_unet,
|
474 |
+
):
|
475 |
+
# Sample noise that we'll add to the latents
|
476 |
+
noise = torch.randn_like(latents)
|
477 |
+
bsz = latents.shape[0]
|
478 |
+
|
479 |
+
# get noisy model input and timesteps
|
480 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
481 |
+
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
482 |
+
)
|
483 |
+
|
484 |
+
# pack latents and get img_ids
|
485 |
+
# yiren ? need modify?
|
486 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
487 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
488 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
489 |
+
|
490 |
+
# get guidance
|
491 |
+
# ensure guidance_scale in args is float
|
492 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
493 |
+
|
494 |
+
# ensure the hidden state will require grad
|
495 |
+
if args.gradient_checkpointing:
|
496 |
+
noisy_model_input.requires_grad_(True)
|
497 |
+
for t in text_encoder_conds:
|
498 |
+
if t is not None and t.dtype.is_floating_point:
|
499 |
+
t.requires_grad_(True)
|
500 |
+
img_ids.requires_grad_(True)
|
501 |
+
guidance_vec.requires_grad_(True)
|
502 |
+
|
503 |
+
# Predict the noise residual
|
504 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
505 |
+
if not args.apply_t5_attn_mask:
|
506 |
+
t5_attn_mask = None
|
507 |
+
|
508 |
+
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
509 |
+
if not args.split_mode:
|
510 |
+
# normal forward
|
511 |
+
with accelerator.autocast():
|
512 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
513 |
+
model_pred = unet(
|
514 |
+
img=img,
|
515 |
+
img_ids=img_ids,
|
516 |
+
txt=t5_out,
|
517 |
+
txt_ids=txt_ids,
|
518 |
+
y=l_pooled,
|
519 |
+
timesteps=timesteps / 1000,
|
520 |
+
guidance=guidance_vec,
|
521 |
+
txt_attention_mask=t5_attn_mask,
|
522 |
+
)
|
523 |
+
else:
|
524 |
+
# split forward to reduce memory usage
|
525 |
+
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
526 |
+
with accelerator.autocast():
|
527 |
+
# move flux lower to cpu, and then move flux upper to gpu
|
528 |
+
unet.to("cpu")
|
529 |
+
clean_memory_on_device(accelerator.device)
|
530 |
+
self.flux_upper.to(accelerator.device)
|
531 |
+
|
532 |
+
# upper model does not require grad
|
533 |
+
with torch.no_grad():
|
534 |
+
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
535 |
+
img=packed_noisy_model_input,
|
536 |
+
img_ids=img_ids,
|
537 |
+
txt=t5_out,
|
538 |
+
txt_ids=txt_ids,
|
539 |
+
y=l_pooled,
|
540 |
+
timesteps=timesteps / 1000,
|
541 |
+
guidance=guidance_vec,
|
542 |
+
txt_attention_mask=t5_attn_mask,
|
543 |
+
)
|
544 |
+
|
545 |
+
# move flux upper back to cpu, and then move flux lower to gpu
|
546 |
+
self.flux_upper.to("cpu")
|
547 |
+
clean_memory_on_device(accelerator.device)
|
548 |
+
unet.to(accelerator.device)
|
549 |
+
|
550 |
+
# lower model requires grad
|
551 |
+
intermediate_img.requires_grad_(True)
|
552 |
+
intermediate_txt.requires_grad_(True)
|
553 |
+
vec.requires_grad_(True)
|
554 |
+
pe.requires_grad_(True)
|
555 |
+
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
556 |
+
|
557 |
+
return model_pred
|
558 |
+
|
559 |
+
model_pred = call_dit(
|
560 |
+
img=packed_noisy_model_input,
|
561 |
+
img_ids=img_ids,
|
562 |
+
t5_out=t5_out,
|
563 |
+
txt_ids=txt_ids,
|
564 |
+
l_pooled=l_pooled,
|
565 |
+
timesteps=timesteps,
|
566 |
+
guidance_vec=guidance_vec,
|
567 |
+
t5_attn_mask=t5_attn_mask,
|
568 |
+
)
|
569 |
+
|
570 |
+
# unpack latents
|
571 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
572 |
+
|
573 |
+
# apply model prediction type
|
574 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
575 |
+
|
576 |
+
# flow matching loss: this is different from SD3
|
577 |
+
target = noise - latents
|
578 |
+
|
579 |
+
# differential output preservation
|
580 |
+
if "custom_attributes" in batch:
|
581 |
+
diff_output_pr_indices = []
|
582 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
583 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
584 |
+
diff_output_pr_indices.append(i)
|
585 |
+
|
586 |
+
if len(diff_output_pr_indices) > 0:
|
587 |
+
network.set_multiplier(0.0)
|
588 |
+
with torch.no_grad():
|
589 |
+
model_pred_prior = call_dit(
|
590 |
+
img=packed_noisy_model_input[diff_output_pr_indices],
|
591 |
+
img_ids=img_ids[diff_output_pr_indices],
|
592 |
+
t5_out=t5_out[diff_output_pr_indices],
|
593 |
+
txt_ids=txt_ids[diff_output_pr_indices],
|
594 |
+
l_pooled=l_pooled[diff_output_pr_indices],
|
595 |
+
timesteps=timesteps[diff_output_pr_indices],
|
596 |
+
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
597 |
+
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
598 |
+
)
|
599 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
600 |
+
|
601 |
+
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
602 |
+
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
603 |
+
args,
|
604 |
+
model_pred_prior,
|
605 |
+
noisy_model_input[diff_output_pr_indices],
|
606 |
+
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
607 |
+
)
|
608 |
+
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
609 |
+
|
610 |
+
# elimilate the loss in the left top quarter of the image
|
611 |
+
h, w = target.shape[2], target.shape[3]
|
612 |
+
# num_split = NUM_SPLIT
|
613 |
+
num_split = 2 if args.frame_num == 4 else 3
|
614 |
+
# target[:, :, :, :w//num_split] = model_pred[:, :, :, :w//num_split]
|
615 |
+
# target[:, :, :, :w//num_split] = model_pred[:, :, :, :w//num_split]
|
616 |
+
target[:, :, 2*h//num_split:h, 2*w//num_split:w] = model_pred[:, :, 2*h//num_split:h, 2*w//num_split:w]
|
617 |
+
|
618 |
+
|
619 |
+
return model_pred, target, timesteps, None, weighting
|
620 |
+
|
621 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
622 |
+
return loss
|
623 |
+
|
624 |
+
def get_sai_model_spec(self, args):
|
625 |
+
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
626 |
+
|
627 |
+
def update_metadata(self, metadata, args):
|
628 |
+
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
629 |
+
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
630 |
+
metadata["ss_logit_mean"] = args.logit_mean
|
631 |
+
metadata["ss_logit_std"] = args.logit_std
|
632 |
+
metadata["ss_mode_scale"] = args.mode_scale
|
633 |
+
metadata["ss_guidance_scale"] = args.guidance_scale
|
634 |
+
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
635 |
+
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
636 |
+
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
637 |
+
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
638 |
+
|
639 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
640 |
+
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
641 |
+
|
642 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
643 |
+
if index == 0: # CLIP-L
|
644 |
+
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
645 |
+
else: # T5XXL
|
646 |
+
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
647 |
+
|
648 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
649 |
+
if index == 0: # CLIP-L
|
650 |
+
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
651 |
+
text_encoder.to(te_weight_dtype) # fp8
|
652 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
653 |
+
else: # T5XXL
|
654 |
+
|
655 |
+
def prepare_fp8(text_encoder, target_dtype):
|
656 |
+
def forward_hook(module):
|
657 |
+
def forward(hidden_states):
|
658 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
659 |
+
hidden_linear = module.wi_1(hidden_states)
|
660 |
+
hidden_states = hidden_gelu * hidden_linear
|
661 |
+
hidden_states = module.dropout(hidden_states)
|
662 |
+
|
663 |
+
hidden_states = module.wo(hidden_states)
|
664 |
+
return hidden_states
|
665 |
+
|
666 |
+
return forward
|
667 |
+
|
668 |
+
for module in text_encoder.modules():
|
669 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
670 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
671 |
+
module.to(target_dtype)
|
672 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
673 |
+
# print("set", module.__class__.__name__, "hooks")
|
674 |
+
module.forward = forward_hook(module)
|
675 |
+
|
676 |
+
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
677 |
+
logger.info(f"T5XXL already prepared for fp8")
|
678 |
+
else:
|
679 |
+
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
680 |
+
text_encoder.to(te_weight_dtype) # fp8
|
681 |
+
prepare_fp8(text_encoder, weight_dtype)
|
682 |
+
|
683 |
+
|
684 |
+
def setup_parser() -> argparse.ArgumentParser:
|
685 |
+
parser = train_network.setup_parser()
|
686 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
687 |
+
|
688 |
+
parser.add_argument(
|
689 |
+
"--split_mode",
|
690 |
+
action="store_true",
|
691 |
+
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
692 |
+
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
693 |
+
)
|
694 |
+
|
695 |
+
parser.add_argument(
|
696 |
+
'--frame_num',
|
697 |
+
type=int,
|
698 |
+
choices=[4, 9],
|
699 |
+
required=True,
|
700 |
+
help="The number of steps in the generated step diagram (choose 4 or 9)"
|
701 |
+
)
|
702 |
+
return parser
|
703 |
+
|
704 |
+
|
705 |
+
if __name__ == "__main__":
|
706 |
+
parser = setup_parser()
|
707 |
+
|
708 |
+
args = parser.parse_args()
|
709 |
+
train_util.verify_command_line_training_args(args)
|
710 |
+
args = train_util.read_config_from_file(args, parser)
|
711 |
+
|
712 |
+
trainer = FluxNetworkTrainer()
|
713 |
+
trainer.train(args)
|
gradio_app.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
from PIL import Image
|
7 |
+
from accelerate import Accelerator
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from torchvision import transforms
|
11 |
+
from safetensors.torch import load_file
|
12 |
+
from networks import lora_flux
|
13 |
+
from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
|
14 |
+
import logging
|
15 |
+
|
16 |
+
# Set up logger
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
logging.basicConfig(level=logging.DEBUG)
|
19 |
+
|
20 |
+
# Ensure necessary devices are available
|
21 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
22 |
+
accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
|
23 |
+
|
24 |
+
# Model paths (replace these with your actual model paths)
|
25 |
+
BASE_FLUX_CHECKPOINT="/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/6_Portrait/6_Portrait.safetensors"
|
26 |
+
LORA_WEIGHTS_PATH="/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/RecraftModel/6_Portrait/6_Portrait-step00025000.safetensors"
|
27 |
+
CLIP_L_PATH="/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
|
28 |
+
T5XXL_PATH="/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/t5xxl_fp16.safetensors"
|
29 |
+
AE_PATH="/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
30 |
+
|
31 |
+
# Load model function
|
32 |
+
def load_target_model():
|
33 |
+
logger.info("Loading models...")
|
34 |
+
try:
|
35 |
+
_, model = flux_utils.load_flow_model(
|
36 |
+
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
|
37 |
+
)
|
38 |
+
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
39 |
+
clip_l.eval()
|
40 |
+
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
41 |
+
t5xxl.eval()
|
42 |
+
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
43 |
+
logger.info("Models loaded successfully.")
|
44 |
+
return model, [clip_l, t5xxl], ae
|
45 |
+
except Exception as e:
|
46 |
+
logger.error(f"Error loading models: {e}")
|
47 |
+
raise
|
48 |
+
|
49 |
+
# Image pre-processing (resize and padding)
|
50 |
+
class ResizeWithPadding:
|
51 |
+
def __init__(self, size, fill=255):
|
52 |
+
self.size = size
|
53 |
+
self.fill = fill
|
54 |
+
|
55 |
+
def __call__(self, img):
|
56 |
+
if isinstance(img, np.ndarray):
|
57 |
+
img = Image.fromarray(img)
|
58 |
+
elif not isinstance(img, Image.Image):
|
59 |
+
raise TypeError("Input must be a PIL Image or a NumPy array")
|
60 |
+
|
61 |
+
width, height = img.size
|
62 |
+
|
63 |
+
if width == height:
|
64 |
+
img = img.resize((self.size, self.size), Image.LANCZOS)
|
65 |
+
else:
|
66 |
+
max_dim = max(width, height)
|
67 |
+
new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
|
68 |
+
new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
|
69 |
+
img = new_img.resize((self.size, self.size), Image.LANCZOS)
|
70 |
+
return img
|
71 |
+
|
72 |
+
# The function to generate image from a prompt and conditional image
|
73 |
+
def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
|
74 |
+
logger.info(f"Started generating image with prompt: {prompt}")
|
75 |
+
|
76 |
+
# Load models
|
77 |
+
model, [clip_l, t5xxl], ae = load_target_model()
|
78 |
+
|
79 |
+
model.eval()
|
80 |
+
clip_l.eval()
|
81 |
+
t5xxl.eval()
|
82 |
+
ae.eval()
|
83 |
+
|
84 |
+
# LoRA
|
85 |
+
multiplier = 1.0
|
86 |
+
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
87 |
+
lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
|
88 |
+
True)
|
89 |
+
|
90 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
91 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
92 |
+
logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
93 |
+
lora_model.eval()
|
94 |
+
lora_model.to("cuda")
|
95 |
+
|
96 |
+
# Process the seed
|
97 |
+
if randomize_seed:
|
98 |
+
seed = random.randint(0, np.iinfo(np.int32).max)
|
99 |
+
logger.debug(f"Using seed: {seed}")
|
100 |
+
|
101 |
+
# Preprocess the conditional image
|
102 |
+
resize_transform = ResizeWithPadding(size=512) if frame_num == 4 else ResizeWithPadding(size=352)
|
103 |
+
img_transforms = transforms.Compose([
|
104 |
+
resize_transform,
|
105 |
+
transforms.ToTensor(),
|
106 |
+
transforms.Normalize([0.5], [0.5]),
|
107 |
+
])
|
108 |
+
image = img_transforms(np.array(sample_image, dtype=np.uint8)).unsqueeze(0).to(
|
109 |
+
device=device,
|
110 |
+
dtype=torch.bfloat16
|
111 |
+
)
|
112 |
+
logger.debug("Conditional image preprocessed.")
|
113 |
+
|
114 |
+
# Encode the image to latents
|
115 |
+
ae.to("cuda")
|
116 |
+
latents = ae.encode(image)
|
117 |
+
logger.debug("Image encoded to latents.")
|
118 |
+
|
119 |
+
conditions = {}
|
120 |
+
conditions[prompt] = latents.to("cpu")
|
121 |
+
|
122 |
+
ae.to("cpu")
|
123 |
+
clip_l.to("cuda")
|
124 |
+
t5xxl.to("cuda")
|
125 |
+
|
126 |
+
# Encode the prompt
|
127 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
|
128 |
+
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
|
129 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
130 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True)
|
131 |
+
|
132 |
+
logger.debug("Prompt encoded.")
|
133 |
+
|
134 |
+
# Prepare the noise and other parameters
|
135 |
+
width = 1024 if frame_num == 4 else 1056
|
136 |
+
height = 1024 if frame_num == 4 else 1056
|
137 |
+
|
138 |
+
height = max(64, height - height % 16)
|
139 |
+
width = max(64, width - width % 16)
|
140 |
+
|
141 |
+
packed_latent_height = height // 16
|
142 |
+
packed_latent_width = width // 16
|
143 |
+
|
144 |
+
noise = torch.randn(1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, dtype=torch.float16)
|
145 |
+
logger.debug("Noise prepared.")
|
146 |
+
|
147 |
+
# Generate the image
|
148 |
+
timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True) # Sample steps = 20
|
149 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
|
150 |
+
|
151 |
+
t5_attn_mask = t5_attn_mask.to(device)
|
152 |
+
ae_outputs = conditions[prompt]
|
153 |
+
|
154 |
+
logger.debug("Image generation parameters set.")
|
155 |
+
|
156 |
+
args = lambda: None
|
157 |
+
args.frame_num = frame_num
|
158 |
+
|
159 |
+
clip_l.to("cpu")
|
160 |
+
t5xxl.to("cpu")
|
161 |
+
|
162 |
+
torch.cuda.empty_cache()
|
163 |
+
model.to("cuda")
|
164 |
+
|
165 |
+
# import pdb
|
166 |
+
# pdb.set_trace()
|
167 |
+
|
168 |
+
# Run the denoising process
|
169 |
+
with accelerator.autocast(), torch.no_grad():
|
170 |
+
x = flux_train_utils.denoise(
|
171 |
+
args, model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=1.0, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs
|
172 |
+
)
|
173 |
+
logger.debug("Denoising process completed.")
|
174 |
+
|
175 |
+
# Decode the final image
|
176 |
+
x = x.float()
|
177 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
178 |
+
model.to("cpu")
|
179 |
+
ae.to("cuda")
|
180 |
+
with accelerator.autocast(), torch.no_grad():
|
181 |
+
x = ae.decode(x)
|
182 |
+
logger.debug("Latents decoded into image.")
|
183 |
+
ae.to("cpu")
|
184 |
+
|
185 |
+
# Convert the tensor to an image
|
186 |
+
x = x.clamp(-1, 1)
|
187 |
+
x = x.permute(0, 2, 3, 1)
|
188 |
+
generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
189 |
+
|
190 |
+
logger.info("Image generation completed.")
|
191 |
+
return generated_image
|
192 |
+
|
193 |
+
# Gradio interface
|
194 |
+
with gr.Blocks() as demo:
|
195 |
+
gr.Markdown("## FLUX Image Generation")
|
196 |
+
|
197 |
+
with gr.Row():
|
198 |
+
# Input for the prompt
|
199 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=1)
|
200 |
+
|
201 |
+
# File upload for image
|
202 |
+
sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
|
203 |
+
|
204 |
+
# Frame number selection
|
205 |
+
frame_num = gr.Radio([4, 9], label="Select Frame Number", value=4)
|
206 |
+
|
207 |
+
# Seed and randomize seed options
|
208 |
+
seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
|
209 |
+
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
210 |
+
|
211 |
+
# Run Button
|
212 |
+
run_button = gr.Button("Generate Image")
|
213 |
+
|
214 |
+
# Output result
|
215 |
+
result_image = gr.Image(label="Generated Image")
|
216 |
+
|
217 |
+
run_button.click(
|
218 |
+
fn=infer,
|
219 |
+
inputs=[prompt, sample_image, frame_num, seed, randomize_seed],
|
220 |
+
outputs=[result_image]
|
221 |
+
)
|
222 |
+
|
223 |
+
# Launch the Gradio app
|
224 |
+
demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
|
225 |
+
|
226 |
+
|
227 |
+
# prompt = "1girl"
|
228 |
+
# sample_image = Image.open("/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/test/1.png") # 使用一个测试图像
|
229 |
+
# frame_num = 9
|
230 |
+
# seed = 42
|
231 |
+
# randomize_seed = False
|
232 |
+
# result = infer(prompt, sample_image, frame_num, seed, randomize_seed)
|
233 |
+
# result.save('asy_results/generated_image.png')
|
id_rsa
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN OPENSSH PRIVATE KEY-----
|
2 |
+
b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABAcjqc3DU
|
3 |
+
L8b9ja3ALHmkowAAAAGAAAAAEAAAIXAAAAB3NzaC1yc2EAAAADAQABAAACAQDE/44P8+xW
|
4 |
+
HQVsebSAQhmu0nHf82prYlt2OMWH/xCCzvs8b9Z8QppPRQd0mexgyvk5jxttXXT22nQdSP
|
5 |
+
ILQsvaOSFJJGBR/RWUI0dNFdR3CdJqWWH331YYEsYe+SWwpUhVSyg0Ys9KhjkWYSVpwv+V
|
6 |
+
VKf65QVydeH4OkCe/zuJClIWOl7U/XEwksLjSj/FzIinWADamdnREBbGHK62l5Y4gHTtGI
|
7 |
+
DoHqvqbGnOWbDic3YSEyOwXdA4bkjqabuQKVAvsr/p9I6+OqAv2yHAHAG6nSclHz7a8pWi
|
8 |
+
FhTeQXVDMOB9W34mfyqN8/lcFlNLgYPab757p4fjk6Ox6TjCSNI6JfJeX/yQeqBwi4pR8h
|
9 |
+
YbOvui+NnHLy2bgWBxPinDb752qDpXQuy65aDGYWWUhGLd8LWaTQkx3LD/+iGR5bOwV18X
|
10 |
+
HuexVVh4pSGVidJ3m1fqpb7AGGcjL0inTvdnz4r3oGfhFREvuhjQo/O1mzgxElcRK2P+xG
|
11 |
+
nJ2eZzC8iwx4UjCotkTdQ3S/2uVI7CFsOnV3oqQ403xcp5OzanH/e/0sTukkM4UWUmUXIG
|
12 |
+
QZxicctZJCndVxffNHFutCL2uIwr2evpwwBjE5GasL4VcxaQlcP1KMwPjYyzZUR62qWRDg
|
13 |
+
sQNlTx+QQWRl4VWEU/jzvehXMza1vl9bbCo+Y9aSVmAQAAB1DCpgt/NwB5QcPJ3XXuqDJ+
|
14 |
+
NCS9iuXhqKGy0L/S0jwuxYNpHjoJ2fptVJo5fgIjGJI3TcMqOHlRNVIZG8e2Cw9YbVivat
|
15 |
+
4WIZq57ZH645uE1XcngdbjxEZfCYozUATaD4piXWYAcsmszlG/r540lBSpGzxB6YESJJgk
|
16 |
+
KNh9+gXgMjZCk3NCTplvb3qdvoYYOvagGE68CqtZn10OavMgTwymG1fzrOQ9Biuq27VOhz
|
17 |
+
zS5ZYqdynWu/F7Vl3QxdS0Mbap7ratRdQlKgDx6WslsXVQwTKoDjyI0tykOVQxEjlwVhTu
|
18 |
+
loQRTJd43Po2BByMclbCM84rlZEViK/jaBc+evsjV0H6CqhJ0Sh28rfSuM/Rb1tYTzktEh
|
19 |
+
x7vd5fcTHRFFupxdAcusOaI4QfRNTQPyX+ShwS3q1CqbgcxWZutysPMpq4f2w7sa2L0Tu1
|
20 |
+
qUmGCxEkfKSUHO2sk0lIXlNG9p6cxR+fv5D3AcSK1LceyZuYLk45gGHpMlBPDs8R2grX1J
|
21 |
+
mTkO7D+EpHcCd0mbEZ37YBLOtERSBNohgVgOEc3SCTRQu5LrCj+or/47b7T2dU/XL/BE/6
|
22 |
+
oD1bYk+cSgT16xYOmqKlE7f+StbOVASkQ32Pfnv5b+JsbQRiEZXe9YKoHRRw4KRDiHNjMG
|
23 |
+
Zb+N8b4hpf2tOn6aWUVH471ciDlJBcbdyLQaupFM5QRwGobSdoAaQrmFIPzpIZ1rOVE8Jp
|
24 |
+
TeZ6CmpxiRNWeFXGojOp7x4GJc7oghdA9loJ5LRVD0Mw2K5VrD32+PMjJ1OZNJrXTZ+FW9
|
25 |
+
ujmjocpbL456Cuqr1yEFWzfy5liC9CItiEEp2sDRXjVnDwDOdO1DJe+1E+0ciTqkizxFnw
|
26 |
+
M/vTgF8BSVC+Vfg559A/0JRPsxyv4FdEpuOWOART2vnam+XgxjgJ4AoBDC44YNhlzkHNZ4
|
27 |
+
I08neqz45hs7B8pbMlE6DxUcJ2dKzzwIcImX0G4QGUdIdGyMC8y10hqI29XBMTa18LDy2i
|
28 |
+
n+tH3Q5E3hoRZVPpIeU+B9bLIznclwaOsDyFF+6E2ET4/Wwl9uiQA9XY9HCL2H1pUMr0Mi
|
29 |
+
QD04/qF6P2GCQqeI0w4NHtZRlQPiVdExtpHpgIPyyfCFM6G5p+3GH9ZL4/IMcj9R1ziiW8
|
30 |
+
6cZUR1kFHoueAMv63Q4hOzWiaGxdbrYbZbmmp0geO6tXDbeBbBi3YmBHwu0MQheZDbjMxW
|
31 |
+
EWYtawjK4wcOnoFkoMh+hU5AheuZJsHKUKEVxq2XauFjpc6dJQaCAVZqoVwMaalS12T8Y1
|
32 |
+
m6+Qa/V5vOZvsGdRB80iJOzA5aUbHb203jG7AmIuOhCaW8DQdY1ai69s6wXV7XXS/A0nh9
|
33 |
+
vjZiFXxb/nCNNJhk1LrzrcvqXl3LRx/5NTRMt+Bp12bsA9fP0U8tlMPHtGq+ctQosS1nWh
|
34 |
+
qrJ2lw1P2fsxWjRVu/4t+n1Hv19tz2stTwcBvUrzYaOvC5SsD05xypRLOfca/QsTlqcqSl
|
35 |
+
P4rnP43y5ACsBSXAsnQ2dIfumZFtE+yUy41CXotLbj2MwrPmXFDwh7mNcWYwOz49YnlL5m
|
36 |
+
j06liu12UDZd/HVxehHj7HSkbS1AXlvbwKK7kUpmY/ZPhi7ElqtdWHOBg+ixgTdLZVGQdV
|
37 |
+
lSAArBMkiuc3CsFNPRJqNtcCLrRm2BYkX3ZCaR9jQ19a2jgBPex4FtOxSgZCcH101RXnKv
|
38 |
+
bqnn1hMbJnOUeaMU2w3SqUs2s3sR5RONnn97tfnKGwZRzzg7UGtlxs0yy/MqIl6FTe+HrD
|
39 |
+
tz/6mK7afM8fp+AoKqrE0S6atjcuIiu71BJ7T7NlOi/AobNZAmX1FeE8qSxg+IV8Indkn/
|
40 |
+
wkVpQ+WnbxPNmsBxJbSLRUNjOQahyzi7+e/sB1eBI1zlEjLW/rnk3SMX2xD3umQsYK6Zce
|
41 |
+
MsamiTJh+qrV9NF/+6y80iYtcsKkTU+MESETFJuWAcD0lfP+jjJPGN0XiCumaACMKPtYFz
|
42 |
+
qMyiPzOnxirPikNTZQ7FAM/YnVauwBjLG/NCmTsHKtCsJYSlNFn+OFCcyAeZjKJtwev6Aq
|
43 |
+
nf+l8YK0ieOdjhPTGzyBYbo6r/XndiVN1bKxS9xT88nzn3ZPKggT2vd5bUm133F2Z9qeRz
|
44 |
+
o9XyEW/CuoxjfBLbxIqCJh7Ow5oBDA4YYCxBjyP8Q2TlrjVpntTYh+fkAzeo9H8uNQ4Jop
|
45 |
+
Mmc/56IObSiLVzTcyRaNHcnx8hDi0QjgOPKmIcVzqdStgmyZX0nBxiboF9nPIUICkUUgXN
|
46 |
+
SqidyUZRjTGWcmuVESrpfS5eDYD7pgpwpogsC9Cr7wyXbZSWaCnYvmZz3B+UAoo0BNH4az
|
47 |
+
dHl8ZC7dhmPFUJj+gCxkOrYA1lkYczR23g2EpwVAELQnjHPzShtvLYniWIKtqBPk7tLrvK
|
48 |
+
4PRb3IvyEGFIRNjvyGewsB4FJce470EK1aGJDRDf7r0UCxM5KU5dMUZQs529EWVkO1gt6l
|
49 |
+
y1uJAmV7ycLKV43ExoZhP5wwE=
|
50 |
+
-----END OPENSSH PRIVATE KEY-----
|
requirements.txt
CHANGED
@@ -1,6 +1,47 @@
|
|
1 |
-
accelerate
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.33.0
|
2 |
+
transformers==4.44.0
|
3 |
+
diffusers[torch]==0.25.0
|
4 |
+
ftfy==6.1.1
|
5 |
+
# albumentations==1.3.0
|
6 |
+
opencv-python==4.8.1.78
|
7 |
+
einops==0.7.0
|
8 |
+
pytorch-lightning==1.9.0
|
9 |
+
bitsandbytes==0.44.0
|
10 |
+
prodigyopt==1.0
|
11 |
+
lion-pytorch==0.0.6
|
12 |
+
came_pytorch==0.1.3
|
13 |
+
schedulefree==1.4
|
14 |
+
tensorboard
|
15 |
+
safetensors==0.4.4
|
16 |
+
# gradio==3.16.2
|
17 |
+
altair==4.2.2
|
18 |
+
easygui==0.98.3
|
19 |
+
toml==0.10.2
|
20 |
+
voluptuous==0.13.1
|
21 |
+
huggingface-hub==0.24.5
|
22 |
+
# for Image utils
|
23 |
+
imagesize==1.4.1
|
24 |
+
numpy<=2.0
|
25 |
+
# for BLIP captioning
|
26 |
+
# requests==2.28.2
|
27 |
+
# timm==0.6.12
|
28 |
+
# fairscale==0.4.13
|
29 |
+
# for WD14 captioning (tensorflow)
|
30 |
+
# tensorflow==2.10.1
|
31 |
+
# for WD14 captioning (onnx)
|
32 |
+
# onnx==1.15.0
|
33 |
+
# onnxruntime-gpu==1.17.1
|
34 |
+
# onnxruntime==1.17.1
|
35 |
+
# for cuda 12.1(default 11.8)
|
36 |
+
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
|
37 |
+
|
38 |
+
# this is for onnx:
|
39 |
+
# protobuf==3.20.3
|
40 |
+
# open clip for SDXL
|
41 |
+
# open-clip-torch==2.20.0
|
42 |
+
# For logging
|
43 |
+
rich==13.7.0
|
44 |
+
# for T5XXL tokenizer (SD3/FLUX)
|
45 |
+
sentencepiece==0.2.0
|
46 |
+
# for kohya_ss library
|
47 |
+
-e .
|
setup.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(name = "library", packages = find_packages())
|
split_asylora.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from safetensors import safe_open
|
4 |
+
from safetensors.torch import save_file
|
5 |
+
|
6 |
+
def main():
|
7 |
+
parser = argparse.ArgumentParser()
|
8 |
+
parser.add_argument('--asylora_path', type=str, required=True, help="Path to the input asylora file.")
|
9 |
+
parser.add_argument('--output_path', type=str, required=True, help="Path to save the modified safetensors file.")
|
10 |
+
parser.add_argument('--lora_up', type=int, required=True, help="The target lora_up value.")
|
11 |
+
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
output_dir = os.path.dirname(args.output_path)
|
15 |
+
if not os.path.exists(output_dir):
|
16 |
+
os.makedirs(output_dir)
|
17 |
+
|
18 |
+
with safe_open(args.asylora_path, framework="pt") as f:
|
19 |
+
tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
|
20 |
+
|
21 |
+
modified_dict = {}
|
22 |
+
|
23 |
+
for key, tensor in tensor_dict.items():
|
24 |
+
if 'lora_ups' in key:
|
25 |
+
lora_up_index = int(key.split('.')[2])
|
26 |
+
if lora_up_index != args.lora_up - 1:
|
27 |
+
continue
|
28 |
+
else:
|
29 |
+
new_key = key.replace(f'lora_ups.{lora_up_index}.', 'lora_up.')
|
30 |
+
modified_dict[new_key] = tensor
|
31 |
+
else:
|
32 |
+
modified_dict[key] = tensor
|
33 |
+
|
34 |
+
save_file(modified_dict, args.output_path)
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
main()
|
train_network.py
ADDED
@@ -0,0 +1,1479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import argparse
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
from multiprocessing import Value
|
10 |
+
from typing import Any, List
|
11 |
+
import toml
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
17 |
+
|
18 |
+
init_ipex()
|
19 |
+
|
20 |
+
from accelerate.utils import set_seed
|
21 |
+
from diffusers import DDPMScheduler
|
22 |
+
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
23 |
+
|
24 |
+
import library.train_util as train_util
|
25 |
+
from library.train_util import DreamBoothDataset
|
26 |
+
import library.config_util as config_util
|
27 |
+
from library.config_util import (
|
28 |
+
ConfigSanitizer,
|
29 |
+
BlueprintGenerator,
|
30 |
+
)
|
31 |
+
import library.huggingface_util as huggingface_util
|
32 |
+
import library.custom_train_functions as custom_train_functions
|
33 |
+
from library.custom_train_functions import (
|
34 |
+
apply_snr_weight,
|
35 |
+
get_weighted_text_embeddings,
|
36 |
+
prepare_scheduler_for_custom_training,
|
37 |
+
scale_v_prediction_loss_like_noise_prediction,
|
38 |
+
add_v_prediction_like_loss,
|
39 |
+
apply_debiased_estimation,
|
40 |
+
apply_masked_loss,
|
41 |
+
)
|
42 |
+
from library.utils import setup_logging, add_logging_arguments
|
43 |
+
|
44 |
+
setup_logging()
|
45 |
+
import logging
|
46 |
+
|
47 |
+
logger = logging.getLogger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
class NetworkTrainer:
|
51 |
+
def __init__(self):
|
52 |
+
self.vae_scale_factor = 0.18215
|
53 |
+
self.is_sdxl = False
|
54 |
+
|
55 |
+
# TODO 他のスクリプトと共通化する
|
56 |
+
def generate_step_logs(
|
57 |
+
self,
|
58 |
+
args: argparse.Namespace,
|
59 |
+
current_loss,
|
60 |
+
avr_loss,
|
61 |
+
lr_scheduler,
|
62 |
+
lr_descriptions,
|
63 |
+
keys_scaled=None,
|
64 |
+
mean_norm=None,
|
65 |
+
maximum_norm=None,
|
66 |
+
):
|
67 |
+
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
68 |
+
|
69 |
+
if keys_scaled is not None:
|
70 |
+
logs["max_norm/keys_scaled"] = keys_scaled
|
71 |
+
logs["max_norm/average_key_norm"] = mean_norm
|
72 |
+
logs["max_norm/max_key_norm"] = maximum_norm
|
73 |
+
|
74 |
+
lrs = lr_scheduler.get_last_lr()
|
75 |
+
for i, lr in enumerate(lrs):
|
76 |
+
if lr_descriptions is not None:
|
77 |
+
lr_desc = lr_descriptions[i]
|
78 |
+
else:
|
79 |
+
idx = i - (0 if args.network_train_unet_only else -1)
|
80 |
+
if idx == -1:
|
81 |
+
lr_desc = "textencoder"
|
82 |
+
else:
|
83 |
+
if len(lrs) > 2:
|
84 |
+
lr_desc = f"group{idx}"
|
85 |
+
else:
|
86 |
+
lr_desc = "unet"
|
87 |
+
|
88 |
+
logs[f"lr/{lr_desc}"] = lr
|
89 |
+
|
90 |
+
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
91 |
+
# tracking d*lr value
|
92 |
+
logs[f"lr/d*lr/{lr_desc}"] = (
|
93 |
+
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
94 |
+
)
|
95 |
+
|
96 |
+
return logs
|
97 |
+
|
98 |
+
def assert_extra_args(self, args, train_dataset_group):
|
99 |
+
train_dataset_group.verify_bucket_reso_steps(64)
|
100 |
+
|
101 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
102 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
103 |
+
|
104 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
105 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
106 |
+
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
107 |
+
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
108 |
+
|
109 |
+
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
110 |
+
|
111 |
+
def get_tokenize_strategy(self, args):
|
112 |
+
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
113 |
+
|
114 |
+
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
|
115 |
+
return [tokenize_strategy.tokenizer]
|
116 |
+
|
117 |
+
def get_latents_caching_strategy(self, args):
|
118 |
+
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
119 |
+
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
120 |
+
)
|
121 |
+
return latents_caching_strategy
|
122 |
+
|
123 |
+
def get_text_encoding_strategy(self, args):
|
124 |
+
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
125 |
+
|
126 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
127 |
+
return None
|
128 |
+
|
129 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
130 |
+
"""
|
131 |
+
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
|
132 |
+
FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
|
133 |
+
"""
|
134 |
+
return text_encoders
|
135 |
+
|
136 |
+
# returns a list of bool values indicating whether each text encoder should be trained
|
137 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
138 |
+
return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders)
|
139 |
+
|
140 |
+
def is_train_text_encoder(self, args):
|
141 |
+
return not args.network_train_unet_only
|
142 |
+
|
143 |
+
def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
|
144 |
+
for t_enc in text_encoders:
|
145 |
+
t_enc.to(accelerator.device, dtype=weight_dtype)
|
146 |
+
|
147 |
+
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
|
148 |
+
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
149 |
+
return noise_pred
|
150 |
+
|
151 |
+
def all_reduce_network(self, accelerator, network):
|
152 |
+
for param in network.parameters():
|
153 |
+
if param.grad is not None:
|
154 |
+
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
155 |
+
|
156 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
|
157 |
+
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
|
158 |
+
|
159 |
+
# region SD/SDXL
|
160 |
+
|
161 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
162 |
+
pass
|
163 |
+
|
164 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
165 |
+
noise_scheduler = DDPMScheduler(
|
166 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
167 |
+
)
|
168 |
+
prepare_scheduler_for_custom_training(noise_scheduler, device)
|
169 |
+
if args.zero_terminal_snr:
|
170 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
171 |
+
return noise_scheduler
|
172 |
+
|
173 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
174 |
+
return vae.encode(images).latent_dist.sample()
|
175 |
+
|
176 |
+
def shift_scale_latents(self, args, latents):
|
177 |
+
return latents * self.vae_scale_factor
|
178 |
+
|
179 |
+
def get_noise_pred_and_target(
|
180 |
+
self,
|
181 |
+
args,
|
182 |
+
accelerator,
|
183 |
+
noise_scheduler,
|
184 |
+
latents,
|
185 |
+
batch,
|
186 |
+
text_encoder_conds,
|
187 |
+
unet,
|
188 |
+
network,
|
189 |
+
weight_dtype,
|
190 |
+
train_unet,
|
191 |
+
):
|
192 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
193 |
+
# with noise offset and/or multires noise if specified
|
194 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
195 |
+
|
196 |
+
# ensure the hidden state will require grad
|
197 |
+
if args.gradient_checkpointing:
|
198 |
+
for x in noisy_latents:
|
199 |
+
x.requires_grad_(True)
|
200 |
+
for t in text_encoder_conds:
|
201 |
+
t.requires_grad_(True)
|
202 |
+
|
203 |
+
# Predict the noise residual
|
204 |
+
with accelerator.autocast():
|
205 |
+
noise_pred = self.call_unet(
|
206 |
+
args,
|
207 |
+
accelerator,
|
208 |
+
unet,
|
209 |
+
noisy_latents.requires_grad_(train_unet),
|
210 |
+
timesteps,
|
211 |
+
text_encoder_conds,
|
212 |
+
batch,
|
213 |
+
weight_dtype,
|
214 |
+
)
|
215 |
+
|
216 |
+
if args.v_parameterization:
|
217 |
+
# v-parameterization training
|
218 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
219 |
+
else:
|
220 |
+
target = noise
|
221 |
+
|
222 |
+
# differential output preservation
|
223 |
+
if "custom_attributes" in batch:
|
224 |
+
diff_output_pr_indices = []
|
225 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
226 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
227 |
+
diff_output_pr_indices.append(i)
|
228 |
+
|
229 |
+
if len(diff_output_pr_indices) > 0:
|
230 |
+
network.set_multiplier(0.0)
|
231 |
+
with torch.no_grad(), accelerator.autocast():
|
232 |
+
noise_pred_prior = self.call_unet(
|
233 |
+
args,
|
234 |
+
accelerator,
|
235 |
+
unet,
|
236 |
+
noisy_latents,
|
237 |
+
timesteps,
|
238 |
+
text_encoder_conds,
|
239 |
+
batch,
|
240 |
+
weight_dtype,
|
241 |
+
indices=diff_output_pr_indices,
|
242 |
+
)
|
243 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
244 |
+
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
245 |
+
|
246 |
+
return noise_pred, target, timesteps, huber_c, None
|
247 |
+
|
248 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
249 |
+
if args.min_snr_gamma:
|
250 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
251 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
252 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
253 |
+
if args.v_pred_like_loss:
|
254 |
+
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
255 |
+
if args.debiased_estimation_loss:
|
256 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
257 |
+
return loss
|
258 |
+
|
259 |
+
def get_sai_model_spec(self, args):
|
260 |
+
return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
261 |
+
|
262 |
+
def update_metadata(self, metadata, args):
|
263 |
+
pass
|
264 |
+
|
265 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
266 |
+
return False # use for sample images
|
267 |
+
|
268 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
269 |
+
# set top parameter requires_grad = True for gradient checkpointing works
|
270 |
+
text_encoder.text_model.embeddings.requires_grad_(True)
|
271 |
+
|
272 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
273 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
274 |
+
|
275 |
+
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
276 |
+
pass
|
277 |
+
|
278 |
+
# endregion
|
279 |
+
|
280 |
+
def train(self, args):
|
281 |
+
session_id = random.randint(0, 2**32)
|
282 |
+
training_started_at = time.time()
|
283 |
+
train_util.verify_training_args(args)
|
284 |
+
train_util.prepare_dataset_args(args, True)
|
285 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
286 |
+
setup_logging(args, reset=True)
|
287 |
+
|
288 |
+
cache_latents = args.cache_latents
|
289 |
+
use_dreambooth_method = args.in_json is None
|
290 |
+
use_user_config = args.dataset_config is not None
|
291 |
+
|
292 |
+
if args.seed is None:
|
293 |
+
args.seed = random.randint(0, 2**32)
|
294 |
+
set_seed(args.seed)
|
295 |
+
|
296 |
+
tokenize_strategy = self.get_tokenize_strategy(args)
|
297 |
+
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
298 |
+
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
|
299 |
+
|
300 |
+
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
301 |
+
latents_caching_strategy = self.get_latents_caching_strategy(args)
|
302 |
+
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
303 |
+
|
304 |
+
# データセットを準備する
|
305 |
+
if args.dataset_class is None:
|
306 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
307 |
+
if use_user_config:
|
308 |
+
logger.info(f"Loading dataset config from {args.dataset_config}")
|
309 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
310 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
311 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
312 |
+
logger.warning(
|
313 |
+
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
314 |
+
", ".join(ignored)
|
315 |
+
)
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
if use_dreambooth_method:
|
319 |
+
logger.info("Using DreamBooth method.")
|
320 |
+
user_config = {
|
321 |
+
"datasets": [
|
322 |
+
{
|
323 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
324 |
+
args.train_data_dir, args.reg_data_dir
|
325 |
+
)
|
326 |
+
}
|
327 |
+
]
|
328 |
+
}
|
329 |
+
else:
|
330 |
+
logger.info("Training with captions.")
|
331 |
+
user_config = {
|
332 |
+
"datasets": [
|
333 |
+
{
|
334 |
+
"subsets": [
|
335 |
+
{
|
336 |
+
"image_dir": args.train_data_dir,
|
337 |
+
"metadata_file": args.in_json,
|
338 |
+
}
|
339 |
+
]
|
340 |
+
}
|
341 |
+
]
|
342 |
+
}
|
343 |
+
|
344 |
+
blueprint = blueprint_generator.generate(user_config, args)
|
345 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
346 |
+
else:
|
347 |
+
# use arbitrary dataset class
|
348 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
349 |
+
|
350 |
+
current_epoch = Value("i", 0)
|
351 |
+
current_step = Value("i", 0)
|
352 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
353 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
354 |
+
|
355 |
+
if args.debug_dataset:
|
356 |
+
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
|
357 |
+
train_util.debug_dataset(train_dataset_group)
|
358 |
+
return
|
359 |
+
if len(train_dataset_group) == 0:
|
360 |
+
logger.error(
|
361 |
+
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではな��、画像があるフォルダの親フォルダを指定する必要があります)"
|
362 |
+
)
|
363 |
+
return
|
364 |
+
|
365 |
+
if cache_latents:
|
366 |
+
assert (
|
367 |
+
train_dataset_group.is_latent_cacheable()
|
368 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
369 |
+
|
370 |
+
self.assert_extra_args(args, train_dataset_group) # may change some args
|
371 |
+
|
372 |
+
# acceleratorを準備する
|
373 |
+
logger.info("preparing accelerator")
|
374 |
+
accelerator = train_util.prepare_accelerator(args)
|
375 |
+
is_main_process = accelerator.is_main_process
|
376 |
+
|
377 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
378 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
379 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
380 |
+
|
381 |
+
# モデルを読み込む
|
382 |
+
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
383 |
+
|
384 |
+
# text_encoder is List[CLIPTextModel] or CLIPTextModel
|
385 |
+
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
386 |
+
|
387 |
+
# 差分追加学習のためにモデルを読み込む
|
388 |
+
sys.path.append(os.path.dirname(__file__))
|
389 |
+
accelerator.print("import network module:", args.network_module)
|
390 |
+
network_module = importlib.import_module(args.network_module)
|
391 |
+
|
392 |
+
if args.base_weights is not None:
|
393 |
+
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
394 |
+
for i, weight_path in enumerate(args.base_weights):
|
395 |
+
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
396 |
+
multiplier = 1.0
|
397 |
+
else:
|
398 |
+
multiplier = args.base_weights_multiplier[i]
|
399 |
+
|
400 |
+
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
|
401 |
+
|
402 |
+
module, weights_sd = network_module.create_network_from_weights(
|
403 |
+
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
404 |
+
)
|
405 |
+
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
406 |
+
|
407 |
+
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
408 |
+
|
409 |
+
# 学習を準備する
|
410 |
+
if cache_latents:
|
411 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
412 |
+
vae.requires_grad_(False)
|
413 |
+
vae.eval()
|
414 |
+
|
415 |
+
train_dataset_group.new_cache_latents(vae, accelerator)
|
416 |
+
|
417 |
+
vae.to("cpu")
|
418 |
+
clean_memory_on_device(accelerator.device)
|
419 |
+
|
420 |
+
accelerator.wait_for_everyone()
|
421 |
+
|
422 |
+
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
423 |
+
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
424 |
+
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
425 |
+
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
426 |
+
|
427 |
+
text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
|
428 |
+
if text_encoder_outputs_caching_strategy is not None:
|
429 |
+
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
|
430 |
+
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
|
431 |
+
|
432 |
+
# prepare network
|
433 |
+
net_kwargs = {}
|
434 |
+
if args.network_args is not None:
|
435 |
+
for net_arg in args.network_args:
|
436 |
+
key, value = net_arg.split("=")
|
437 |
+
net_kwargs[key] = value
|
438 |
+
|
439 |
+
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
440 |
+
if args.dim_from_weights:
|
441 |
+
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
442 |
+
else:
|
443 |
+
if "dropout" not in net_kwargs:
|
444 |
+
# workaround for LyCORIS (;^ω^)
|
445 |
+
net_kwargs["dropout"] = args.network_dropout
|
446 |
+
|
447 |
+
network = network_module.create_network(
|
448 |
+
1.0,
|
449 |
+
args.network_dim,
|
450 |
+
args.network_alpha,
|
451 |
+
vae,
|
452 |
+
text_encoder,
|
453 |
+
unet,
|
454 |
+
neuron_dropout=args.network_dropout,
|
455 |
+
**net_kwargs,
|
456 |
+
)
|
457 |
+
if network is None:
|
458 |
+
return
|
459 |
+
network_has_multiplier = hasattr(network, "set_multiplier")
|
460 |
+
|
461 |
+
if hasattr(network, "prepare_network"):
|
462 |
+
network.prepare_network(args)
|
463 |
+
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
464 |
+
logger.warning(
|
465 |
+
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応���ていません"
|
466 |
+
)
|
467 |
+
args.scale_weight_norms = False
|
468 |
+
|
469 |
+
self.post_process_network(args, accelerator, network, text_encoders, unet)
|
470 |
+
|
471 |
+
# apply network to unet and text_encoder
|
472 |
+
train_unet = not args.network_train_text_encoder_only
|
473 |
+
train_text_encoder = self.is_train_text_encoder(args)
|
474 |
+
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
475 |
+
|
476 |
+
if args.network_weights is not None:
|
477 |
+
# FIXME consider alpha of weights: this assumes that the alpha is not changed
|
478 |
+
info = network.load_weights(args.network_weights)
|
479 |
+
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
480 |
+
|
481 |
+
if args.gradient_checkpointing:
|
482 |
+
if args.cpu_offload_checkpointing:
|
483 |
+
unet.enable_gradient_checkpointing(cpu_offload=True)
|
484 |
+
else:
|
485 |
+
unet.enable_gradient_checkpointing()
|
486 |
+
|
487 |
+
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
488 |
+
if flag:
|
489 |
+
if t_enc.supports_gradient_checkpointing:
|
490 |
+
t_enc.gradient_checkpointing_enable()
|
491 |
+
del t_enc
|
492 |
+
network.enable_gradient_checkpointing() # may have no effect
|
493 |
+
|
494 |
+
# 学習に必要なクラスを準備する
|
495 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
496 |
+
|
497 |
+
# make backward compatibility for text_encoder_lr
|
498 |
+
support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
|
499 |
+
if support_multiple_lrs:
|
500 |
+
text_encoder_lr = args.text_encoder_lr
|
501 |
+
else:
|
502 |
+
# toml backward compatibility
|
503 |
+
if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int):
|
504 |
+
text_encoder_lr = args.text_encoder_lr
|
505 |
+
else:
|
506 |
+
text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]  # use only the first value for backward compatibility
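# For illustration (hypothetical values): --text_encoder_lr 1e-4 5e-5 with a network that
# does not support multiple Text Encoder learning rates falls back to the first value, 1e-4.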
|
507 |
+
try:
|
508 |
+
if support_multiple_lrs:
|
509 |
+
results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
|
510 |
+
else:
|
511 |
+
results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
|
512 |
+
if type(results) is tuple:
|
513 |
+
trainable_params = results[0]
|
514 |
+
lr_descriptions = results[1]
|
515 |
+
else:
|
516 |
+
trainable_params = results
|
517 |
+
lr_descriptions = None
|
518 |
+
except TypeError as e:
|
519 |
+
trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
|
520 |
+
lr_descriptions = None
|
521 |
+
|
522 |
+
# if len(trainable_params) == 0:
|
523 |
+
# accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
|
524 |
+
# for params in trainable_params:
|
525 |
+
# for k, v in params.items():
|
526 |
+
# if type(v) == float:
|
527 |
+
# pass
|
528 |
+
# else:
|
529 |
+
# v = len(v)
|
530 |
+
# accelerator.print(f"trainable_params: {k} = {v}")
|
531 |
+
|
532 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
533 |
+
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
534 |
+
|
535 |
+
# prepare dataloader
|
536 |
+
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
537 |
+
# some strategies can be None
|
538 |
+
train_dataset_group.set_current_strategies()
|
539 |
+
|
540 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 / number of DataLoader workers: note that persistent_workers cannot be used when this is 0
|
541 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
542 |
+
|
543 |
+
train_dataloader = torch.utils.data.DataLoader(
|
544 |
+
train_dataset_group,
|
545 |
+
batch_size=1,
|
546 |
+
shuffle=True,
|
547 |
+
collate_fn=collator,
|
548 |
+
num_workers=n_workers,
|
549 |
+
persistent_workers=args.persistent_data_loader_workers,
|
550 |
+
)
|
551 |
+
|
552 |
+
# 学習ステップ数を計算する / calculate the number of training steps
|
553 |
+
if args.max_train_epochs is not None:
|
554 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
555 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
556 |
+
)
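# Worked example (hypothetical numbers): 1000 batches per epoch, 2 processes,
# gradient_accumulation_steps=4 and max_train_epochs=10 give
# 10 * ceil(1000 / 2 / 4) = 10 * 125 = 1250 optimization steps.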
|
557 |
+
accelerator.print(
|
558 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
559 |
+
)
|
560 |
+
|
561 |
+
# データセット側にも学習ステップを送信 / also send the number of training steps to the dataset
|
562 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
563 |
+
|
564 |
+
# lr schedulerを用意する / prepare the lr scheduler
|
565 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
566 |
+
|
567 |
+
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする / experimental feature: train in fp16/bf16 including gradients, putting the whole model in fp16/bf16
|
568 |
+
if args.full_fp16:
|
569 |
+
assert (
|
570 |
+
args.mixed_precision == "fp16"
|
571 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
572 |
+
accelerator.print("enable full fp16 training.")
|
573 |
+
network.to(weight_dtype)
|
574 |
+
elif args.full_bf16:
|
575 |
+
assert (
|
576 |
+
args.mixed_precision == "bf16"
|
577 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
578 |
+
accelerator.print("enable full bf16 training.")
|
579 |
+
network.to(weight_dtype)
|
580 |
+
|
581 |
+
unet_weight_dtype = te_weight_dtype = weight_dtype
|
582 |
+
# Experimental Feature: Put base model into fp8 to save vram
|
583 |
+
if args.fp8_base or args.fp8_base_unet:
|
584 |
+
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
585 |
+
assert (
|
586 |
+
args.mixed_precision != "no"
|
587 |
+
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
588 |
+
accelerator.print("enable fp8 training for U-Net.")
|
589 |
+
unet_weight_dtype = torch.float8_e4m3fn
|
590 |
+
|
591 |
+
if not args.fp8_base_unet:
|
592 |
+
accelerator.print("enable fp8 training for Text Encoder.")
|
593 |
+
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
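# Summary of the fp8 options as implemented above: --fp8_base puts both the U-Net and the Text Encoders
# in torch.float8_e4m3fn, while --fp8_base_unet keeps the Text Encoders in the mixed-precision dtype.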
|
594 |
+
|
595 |
+
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
596 |
+
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
597 |
+
|
598 |
+
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
|
599 |
+
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
600 |
+
|
601 |
+
unet.requires_grad_(False)
|
602 |
+
unet.to(dtype=unet_weight_dtype)
|
603 |
+
for i, t_enc in enumerate(text_encoders):
|
604 |
+
t_enc.requires_grad_(False)
|
605 |
+
|
606 |
+
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
607 |
+
if t_enc.device.type != "cpu":
|
608 |
+
t_enc.to(dtype=te_weight_dtype)
|
609 |
+
|
610 |
+
# nn.Embedding does not support FP8
|
611 |
+
if te_weight_dtype != weight_dtype:
|
612 |
+
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
|
613 |
+
|
614 |
+
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
615 |
+
if args.deepspeed:
|
616 |
+
flags = self.get_text_encoders_train_flags(args, text_encoders)
|
617 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
618 |
+
args,
|
619 |
+
unet=unet if train_unet else None,
|
620 |
+
text_encoder1=text_encoders[0] if flags[0] else None,
|
621 |
+
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
|
622 |
+
network=network,
|
623 |
+
)
|
624 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
625 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
626 |
+
)
|
627 |
+
training_model = ds_model
|
628 |
+
else:
|
629 |
+
if train_unet:
|
630 |
+
unet = accelerator.prepare(unet)
|
631 |
+
else:
|
632 |
+
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
633 |
+
if train_text_encoder:
|
634 |
+
text_encoders = [
|
635 |
+
(accelerator.prepare(t_enc) if flag else t_enc)
|
636 |
+
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
|
637 |
+
]
|
638 |
+
if len(text_encoders) > 1:
|
639 |
+
text_encoder = text_encoders
|
640 |
+
else:
|
641 |
+
text_encoder = text_encoders[0]
|
642 |
+
else:
|
643 |
+
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
644 |
+
|
645 |
+
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
646 |
+
network, optimizer, train_dataloader, lr_scheduler
|
647 |
+
)
|
648 |
+
training_model = network
|
649 |
+
|
650 |
+
if args.gradient_checkpointing:
|
651 |
+
# according to TI example in Diffusers, train is required
|
652 |
+
unet.train()
|
653 |
+
for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
|
654 |
+
t_enc.train()
|
655 |
+
|
656 |
+
# set requires_grad = True on a top-level parameter so that gradient checkpointing works
|
657 |
+
if frag:
|
658 |
+
self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
|
659 |
+
|
660 |
+
else:
|
661 |
+
unet.eval()
|
662 |
+
for t_enc in text_encoders:
|
663 |
+
t_enc.eval()
|
664 |
+
|
665 |
+
del t_enc
|
666 |
+
|
667 |
+
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
|
668 |
+
|
669 |
+
if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する / prepare the VAE because it is used when latents are not cached
|
670 |
+
vae.requires_grad_(False)
|
671 |
+
vae.eval()
|
672 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
673 |
+
|
674 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする / experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
|
675 |
+
if args.full_fp16:
|
676 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
677 |
+
|
678 |
+
# before resuming make hook for saving/loading to save/load the network weights only
|
679 |
+
def save_model_hook(models, weights, output_dir):
|
680 |
+
# pop weights of other models than network to save only network weights
|
681 |
+
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
|
682 |
+
if accelerator.is_main_process or args.deepspeed:
|
683 |
+
remove_indices = []
|
684 |
+
for i, model in enumerate(models):
|
685 |
+
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
686 |
+
remove_indices.append(i)
|
687 |
+
for i in reversed(remove_indices):
|
688 |
+
if len(weights) > i:
|
689 |
+
weights.pop(i)
|
690 |
+
# print(f"save model hook: {len(weights)} weights will be saved")
|
691 |
+
|
692 |
+
# save current epoch and step
|
693 |
+
train_state_file = os.path.join(output_dir, "train_state.json")
|
694 |
+
# +1 is needed because the state is saved before current_step is set from global_step
|
695 |
+
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
|
696 |
+
with open(train_state_file, "w", encoding="utf-8") as f:
|
697 |
+
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
|
698 |
+
|
699 |
+
steps_from_state = None
|
700 |
+
|
701 |
+
def load_model_hook(models, input_dir):
|
702 |
+
# remove models except network
|
703 |
+
remove_indices = []
|
704 |
+
for i, model in enumerate(models):
|
705 |
+
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
706 |
+
remove_indices.append(i)
|
707 |
+
for i in reversed(remove_indices):
|
708 |
+
models.pop(i)
|
709 |
+
# print(f"load model hook: {len(models)} models will be loaded")
|
710 |
+
|
711 |
+
# load current epoch and step from the saved train state
|
712 |
+
nonlocal steps_from_state
|
713 |
+
train_state_file = os.path.join(input_dir, "train_state.json")
|
714 |
+
if os.path.exists(train_state_file):
|
715 |
+
with open(train_state_file, "r", encoding="utf-8") as f:
|
716 |
+
data = json.load(f)
|
717 |
+
steps_from_state = data["current_step"]
|
718 |
+
logger.info(f"load train state from {train_state_file}: {data}")
|
719 |
+
|
720 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
721 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
722 |
+
|
723 |
+
# resumeする / resume from a saved state if specified
|
724 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
725 |
+
|
726 |
+
# epoch数を計算する / calculate the number of epochs
|
727 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
728 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
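# Worked example (hypothetical numbers): 1000 batches, gradient_accumulation_steps=4 and
# max_train_steps=1250 give num_update_steps_per_epoch = 250 and num_train_epochs = 5.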
|
729 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
730 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
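# For illustration (hypothetical values): num_train_epochs=20 and save_n_epoch_ratio=4
# give save_every_n_epochs = 5, i.e. roughly 4 saves spread over the run.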
|
731 |
+
|
732 |
+
# 学習する / start training
|
733 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
734 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
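# For illustration (hypothetical values): train_batch_size=4, 2 processes and
# gradient_accumulation_steps=2 give total_batch_size = 16.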
|
735 |
+
|
736 |
+
accelerator.print("running training / 学習開始")
|
737 |
+
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
738 |
+
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
739 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
740 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
741 |
+
accelerator.print(
|
742 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
743 |
+
)
|
744 |
+
# accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
745 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
746 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
747 |
+
|
748 |
+
# TODO refactor metadata creation and move to util
|
749 |
+
metadata = {
|
750 |
+
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
751 |
+
"ss_training_started_at": training_started_at, # unix timestamp
|
752 |
+
"ss_output_name": args.output_name,
|
753 |
+
"ss_learning_rate": args.learning_rate,
|
754 |
+
"ss_text_encoder_lr": text_encoder_lr,
|
755 |
+
"ss_unet_lr": args.unet_lr,
|
756 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
757 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
758 |
+
"ss_num_batches_per_epoch": len(train_dataloader),
|
759 |
+
"ss_num_epochs": num_train_epochs,
|
760 |
+
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
761 |
+
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
762 |
+
"ss_max_train_steps": args.max_train_steps,
|
763 |
+
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
764 |
+
"ss_lr_scheduler": args.lr_scheduler,
|
765 |
+
"ss_network_module": args.network_module,
|
766 |
+
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
767 |
+
"ss_network_alpha": args.network_alpha, # some networks may not have alpha
|
768 |
+
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
|
769 |
+
"ss_mixed_precision": args.mixed_precision,
|
770 |
+
"ss_full_fp16": bool(args.full_fp16),
|
771 |
+
"ss_v2": bool(args.v2),
|
772 |
+
"ss_base_model_version": model_version,
|
773 |
+
"ss_clip_skip": args.clip_skip,
|
774 |
+
"ss_max_token_length": args.max_token_length,
|
775 |
+
"ss_cache_latents": bool(args.cache_latents),
|
776 |
+
"ss_seed": args.seed,
|
777 |
+
"ss_lowram": args.lowram,
|
778 |
+
"ss_noise_offset": args.noise_offset,
|
779 |
+
"ss_multires_noise_iterations": args.multires_noise_iterations,
|
780 |
+
"ss_multires_noise_discount": args.multires_noise_discount,
|
781 |
+
"ss_adaptive_noise_scale": args.adaptive_noise_scale,
|
782 |
+
"ss_zero_terminal_snr": args.zero_terminal_snr,
|
783 |
+
"ss_training_comment": args.training_comment, # will not be updated after training
|
784 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
785 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
786 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
787 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
788 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
789 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
790 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
791 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
792 |
+
"ss_min_snr_gamma": args.min_snr_gamma,
|
793 |
+
"ss_scale_weight_norms": args.scale_weight_norms,
|
794 |
+
"ss_ip_noise_gamma": args.ip_noise_gamma,
|
795 |
+
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
|
796 |
+
"ss_noise_offset_random_strength": args.noise_offset_random_strength,
|
797 |
+
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
|
798 |
+
"ss_loss_type": args.loss_type,
|
799 |
+
"ss_huber_schedule": args.huber_schedule,
|
800 |
+
"ss_huber_c": args.huber_c,
|
801 |
+
"ss_fp8_base": bool(args.fp8_base),
|
802 |
+
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
803 |
+
}
|
804 |
+
|
805 |
+
self.update_metadata(metadata, args) # architecture specific metadata
|
806 |
+
|
807 |
+
if use_user_config:
|
808 |
+
# save metadata of multiple datasets
|
809 |
+
# NOTE: pack "ss_datasets" value as json one time
|
810 |
+
# or should also pack nested collections as json?
|
811 |
+
datasets_metadata = []
|
812 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
813 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
814 |
+
|
815 |
+
for dataset in train_dataset_group.datasets:
|
816 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
817 |
+
dataset_metadata = {
|
818 |
+
"is_dreambooth": is_dreambooth_dataset,
|
819 |
+
"batch_size_per_device": dataset.batch_size,
|
820 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
821 |
+
"num_reg_images": dataset.num_reg_images,
|
822 |
+
"resolution": (dataset.width, dataset.height),
|
823 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
824 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
825 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
826 |
+
"tag_frequency": dataset.tag_frequency,
|
827 |
+
"bucket_info": dataset.bucket_info,
|
828 |
+
}
|
829 |
+
|
830 |
+
subsets_metadata = []
|
831 |
+
for subset in dataset.subsets:
|
832 |
+
subset_metadata = {
|
833 |
+
"img_count": subset.img_count,
|
834 |
+
"num_repeats": subset.num_repeats,
|
835 |
+
"color_aug": bool(subset.color_aug),
|
836 |
+
"flip_aug": bool(subset.flip_aug),
|
837 |
+
"random_crop": bool(subset.random_crop),
|
838 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
839 |
+
"keep_tokens": subset.keep_tokens,
|
840 |
+
"keep_tokens_separator": subset.keep_tokens_separator,
|
841 |
+
"secondary_separator": subset.secondary_separator,
|
842 |
+
"enable_wildcard": bool(subset.enable_wildcard),
|
843 |
+
"caption_prefix": subset.caption_prefix,
|
844 |
+
"caption_suffix": subset.caption_suffix,
|
845 |
+
}
|
846 |
+
|
847 |
+
image_dir_or_metadata_file = None
|
848 |
+
if subset.image_dir:
|
849 |
+
image_dir = os.path.basename(subset.image_dir)
|
850 |
+
subset_metadata["image_dir"] = image_dir
|
851 |
+
image_dir_or_metadata_file = image_dir
|
852 |
+
|
853 |
+
if is_dreambooth_dataset:
|
854 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
855 |
+
subset_metadata["is_reg"] = subset.is_reg
|
856 |
+
if subset.is_reg:
|
857 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
858 |
+
else:
|
859 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
860 |
+
subset_metadata["metadata_file"] = metadata_file
|
861 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
862 |
+
|
863 |
+
subsets_metadata.append(subset_metadata)
|
864 |
+
|
865 |
+
# merge dataset dir: not reg subset only
|
866 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
867 |
+
if image_dir_or_metadata_file is not None:
|
868 |
+
# datasets may have a certain dir multiple times
|
869 |
+
v = image_dir_or_metadata_file
|
870 |
+
i = 2
|
871 |
+
while v in dataset_dirs_info:
|
872 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
873 |
+
i += 1
|
874 |
+
image_dir_or_metadata_file = v
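# e.g. if "train_images" is already present as a key, the next occurrence is stored as "train_images (2)"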
|
875 |
+
|
876 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
877 |
+
"n_repeats": subset.num_repeats,
|
878 |
+
"img_count": subset.img_count,
|
879 |
+
}
|
880 |
+
|
881 |
+
dataset_metadata["subsets"] = subsets_metadata
|
882 |
+
datasets_metadata.append(dataset_metadata)
|
883 |
+
|
884 |
+
# merge tag frequency:
|
885 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
886 |
+
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える / if a directory is used by multiple datasets, count it only once
|
887 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない / since the number of repeats is specified, the tag count in captions does not match how often a tag is used in training
|
888 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない / so summing the counts across multiple datasets here has little meaning
|
889 |
+
if ds_dir_name in tag_frequency:
|
890 |
+
continue
|
891 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
892 |
+
|
893 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
894 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
895 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
896 |
+
else:
|
897 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
898 |
+
assert (
|
899 |
+
len(train_dataset_group.datasets) == 1
|
900 |
+
), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
901 |
+
|
902 |
+
dataset = train_dataset_group.datasets[0]
|
903 |
+
|
904 |
+
dataset_dirs_info = {}
|
905 |
+
reg_dataset_dirs_info = {}
|
906 |
+
if use_dreambooth_method:
|
907 |
+
for subset in dataset.subsets:
|
908 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
909 |
+
info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
|
910 |
+
else:
|
911 |
+
for subset in dataset.subsets:
|
912 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
913 |
+
"n_repeats": subset.num_repeats,
|
914 |
+
"img_count": subset.img_count,
|
915 |
+
}
|
916 |
+
|
917 |
+
metadata.update(
|
918 |
+
{
|
919 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
920 |
+
"ss_total_batch_size": total_batch_size,
|
921 |
+
"ss_resolution": args.resolution,
|
922 |
+
"ss_color_aug": bool(args.color_aug),
|
923 |
+
"ss_flip_aug": bool(args.flip_aug),
|
924 |
+
"ss_random_crop": bool(args.random_crop),
|
925 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
926 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
927 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
928 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
929 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
930 |
+
"ss_keep_tokens": args.keep_tokens,
|
931 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
932 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
933 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
934 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
935 |
+
}
|
936 |
+
)
|
937 |
+
|
938 |
+
# add extra args
|
939 |
+
if args.network_args:
|
940 |
+
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
941 |
+
|
942 |
+
# model name and hash
|
943 |
+
if args.pretrained_model_name_or_path is not None:
|
944 |
+
sd_model_name = args.pretrained_model_name_or_path
|
945 |
+
if os.path.exists(sd_model_name):
|
946 |
+
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
947 |
+
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
|
948 |
+
sd_model_name = os.path.basename(sd_model_name)
|
949 |
+
metadata["ss_sd_model_name"] = sd_model_name
|
950 |
+
|
951 |
+
if args.vae is not None:
|
952 |
+
vae_name = args.vae
|
953 |
+
if os.path.exists(vae_name):
|
954 |
+
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
955 |
+
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
|
956 |
+
vae_name = os.path.basename(vae_name)
|
957 |
+
metadata["ss_vae_name"] = vae_name
|
958 |
+
|
959 |
+
metadata = {k: str(v) for k, v in metadata.items()}
|
960 |
+
|
961 |
+
# make minimum metadata for filtering
|
962 |
+
minimum_metadata = {}
|
963 |
+
for key in train_util.SS_METADATA_MINIMUM_KEYS:
|
964 |
+
if key in metadata:
|
965 |
+
minimum_metadata[key] = metadata[key]
|
966 |
+
|
967 |
+
# calculate steps to skip when resuming or starting from a specific step
|
968 |
+
initial_step = 0
|
969 |
+
if args.initial_epoch is not None or args.initial_step is not None:
|
970 |
+
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
|
971 |
+
if steps_from_state is not None:
|
972 |
+
logger.warning(
|
973 |
+
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
|
974 |
+
)
|
975 |
+
if args.initial_step is not None:
|
976 |
+
initial_step = args.initial_step
|
977 |
+
else:
|
978 |
+
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
|
979 |
+
initial_step = (args.initial_epoch - 1) * math.ceil(
|
980 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
981 |
+
)
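# Worked example (hypothetical numbers): initial_epoch=3, 1000 batches, 2 processes and
# gradient_accumulation_steps=4 give initial_step = (3 - 1) * ceil(1000 / 2 / 4) = 250.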
|
982 |
+
else:
|
983 |
+
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
|
984 |
+
if steps_from_state is not None:
|
985 |
+
initial_step = steps_from_state
|
986 |
+
steps_from_state = None
|
987 |
+
|
988 |
+
if initial_step > 0:
|
989 |
+
assert (
|
990 |
+
args.max_train_steps > initial_step
|
991 |
+
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
992 |
+
|
993 |
+
progress_bar = tqdm(
|
994 |
+
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
995 |
+
)
|
996 |
+
|
997 |
+
epoch_to_start = 0
|
998 |
+
if initial_step > 0:
|
999 |
+
if args.skip_until_initial_step:
|
1000 |
+
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
|
1001 |
+
if not args.resume:
|
1002 |
+
logger.info(
|
1003 |
+
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
|
1004 |
+
)
|
1005 |
+
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
1006 |
+
initial_step *= args.gradient_accumulation_steps
|
1007 |
+
|
1008 |
+
# set epoch to start to make initial_step less than len(train_dataloader)
|
1009 |
+
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1010 |
+
else:
|
1011 |
+
# if not, only the epoch number is skipped, for informational purposes
|
1012 |
+
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1013 |
+
initial_step = 0 # do not skip
|
1014 |
+
|
1015 |
+
global_step = 0
|
1016 |
+
|
1017 |
+
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
|
1018 |
+
|
1019 |
+
if accelerator.is_main_process:
|
1020 |
+
init_kwargs = {}
|
1021 |
+
if args.wandb_run_name:
|
1022 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
1023 |
+
if args.log_tracker_config is not None:
|
1024 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
1025 |
+
accelerator.init_trackers(
|
1026 |
+
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
1027 |
+
config=train_util.get_sanitized_config_or_none(args),
|
1028 |
+
init_kwargs=init_kwargs,
|
1029 |
+
)
|
1030 |
+
|
1031 |
+
loss_recorder = train_util.LossRecorder()
|
1032 |
+
del train_dataset_group
|
1033 |
+
|
1034 |
+
# callback for step start
|
1035 |
+
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
1036 |
+
on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
|
1037 |
+
else:
|
1038 |
+
on_step_start_for_network = lambda *args, **kwargs: None
|
1039 |
+
|
1040 |
+
# function for saving/removing
|
1041 |
+
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
1042 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1043 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1044 |
+
|
1045 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
1046 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
1047 |
+
metadata["ss_steps"] = str(steps)
|
1048 |
+
metadata["ss_epoch"] = str(epoch_no)
|
1049 |
+
|
1050 |
+
metadata_to_save = minimum_metadata if args.no_metadata else metadata
|
1051 |
+
sai_metadata = self.get_sai_model_spec(args)
|
1052 |
+
metadata_to_save.update(sai_metadata)
|
1053 |
+
|
1054 |
+
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
1055 |
+
if args.huggingface_repo_id is not None:
|
1056 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
1057 |
+
|
1058 |
+
def remove_model(old_ckpt_name):
|
1059 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
1060 |
+
if os.path.exists(old_ckpt_file):
|
1061 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
1062 |
+
os.remove(old_ckpt_file)
|
1063 |
+
|
1064 |
+
# if text_encoder is not needed for training, delete it to save memory.
|
1065 |
+
# TODO this can be automated after SDXL sample prompt cache is implemented
|
1066 |
+
if self.is_text_encoder_not_needed_for_training(args):
|
1067 |
+
logger.info("text_encoder is not needed for training. deleting to save memory.")
|
1068 |
+
for t_enc in text_encoders:
|
1069 |
+
del t_enc
|
1070 |
+
text_encoders = []
|
1071 |
+
text_encoder = None
|
1072 |
+
|
1073 |
+
# For --sample_at_first
|
1074 |
+
optimizer_eval_fn()
|
1075 |
+
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
1076 |
+
optimizer_train_fn()
|
1077 |
+
if len(accelerator.trackers) > 0:
|
1078 |
+
# log empty object to commit the sample images to wandb
|
1079 |
+
accelerator.log({}, step=0)
|
1080 |
+
|
1081 |
+
# training loop
|
1082 |
+
if initial_step > 0: # only if skip_until_initial_step is specified
|
1083 |
+
for skip_epoch in range(epoch_to_start): # skip epochs
|
1084 |
+
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
1085 |
+
initial_step -= len(train_dataloader)
|
1086 |
+
global_step = initial_step
|
1087 |
+
|
1088 |
+
# log device and dtype for each model
|
1089 |
+
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
1090 |
+
for i, t_enc in enumerate(text_encoders):
|
1091 |
+
params_itr = t_enc.parameters()
|
1092 |
+
params_itr.__next__() # skip the first parameter
|
1093 |
+
params_itr.__next__()  # skip the second parameter, because the first two parameters of CLIP are embeddings
|
1094 |
+
param_3rd = params_itr.__next__()
|
1095 |
+
logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
|
1096 |
+
|
1097 |
+
clean_memory_on_device(accelerator.device)
|
1098 |
+
|
1099 |
+
for epoch in range(epoch_to_start, num_train_epochs):
|
1100 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
1101 |
+
current_epoch.value = epoch + 1
|
1102 |
+
|
1103 |
+
metadata["ss_epoch"] = str(epoch + 1)
|
1104 |
+
|
1105 |
+
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
1106 |
+
|
1107 |
+
skipped_dataloader = None
|
1108 |
+
if initial_step > 0:
|
1109 |
+
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
1110 |
+
initial_step = 1
|
1111 |
+
|
1112 |
+
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
1113 |
+
current_step.value = global_step
|
1114 |
+
if initial_step > 0:
|
1115 |
+
initial_step -= 1
|
1116 |
+
continue
|
1117 |
+
|
1118 |
+
with accelerator.accumulate(training_model):
|
1119 |
+
on_step_start_for_network(text_encoder, unet)
|
1120 |
+
|
1121 |
+
# temporary, for batch processing
|
1122 |
+
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
1123 |
+
|
1124 |
+
if "latents" in batch and batch["latents"] is not None:
|
1125 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
1126 |
+
else:
|
1127 |
+
with torch.no_grad():
|
1128 |
+
# latentに変換 / convert images to latents
|
1129 |
+
latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
|
1130 |
+
latents = latents.to(dtype=weight_dtype)
|
1131 |
+
|
1132 |
+
# NaNが含まれていれば警告を表示し0に置き換える / if NaN is found, warn and replace it with zeros
|
1133 |
+
if torch.any(torch.isnan(latents)):
|
1134 |
+
accelerator.print("NaN found in latents, replacing with zeros")
|
1135 |
+
latents = torch.nan_to_num(latents, 0, out=latents)
|
1136 |
+
|
1137 |
+
latents = self.shift_scale_latents(args, latents)
|
1138 |
+
|
1139 |
+
# get multiplier for each sample
|
1140 |
+
if network_has_multiplier:
|
1141 |
+
multipliers = batch["network_multipliers"]
|
1142 |
+
# if all multipliers are same, use single multiplier
|
1143 |
+
if torch.all(multipliers == multipliers[0]):
|
1144 |
+
multipliers = multipliers[0].item()
|
1145 |
+
else:
|
1146 |
+
raise NotImplementedError("multipliers for each sample is not supported yet")
|
1147 |
+
# print(f"set multiplier: {multipliers}")
|
1148 |
+
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
1149 |
+
|
1150 |
+
text_encoder_conds = []
|
1151 |
+
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
1152 |
+
if text_encoder_outputs_list is not None:
|
1153 |
+
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
1154 |
+
|
1155 |
+
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
1156 |
+
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
1157 |
+
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
1158 |
+
# Get the text embedding for conditioning
|
1159 |
+
if args.weighted_captions:
|
1160 |
+
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
1161 |
+
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
1162 |
+
tokenize_strategy,
|
1163 |
+
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
1164 |
+
input_ids_list,
|
1165 |
+
weights_list,
|
1166 |
+
)
|
1167 |
+
else:
|
1168 |
+
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
1169 |
+
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
1170 |
+
tokenize_strategy,
|
1171 |
+
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
1172 |
+
input_ids,
|
1173 |
+
)
|
1174 |
+
if args.full_fp16:
|
1175 |
+
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
1176 |
+
|
1177 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
1178 |
+
if len(text_encoder_conds) == 0:
|
1179 |
+
text_encoder_conds = encoded_text_encoder_conds
|
1180 |
+
else:
|
1181 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
1182 |
+
for i in range(len(encoded_text_encoder_conds)):
|
1183 |
+
if encoded_text_encoder_conds[i] is not None:
|
1184 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
1185 |
+
|
1186 |
+
# sample noise, call unet, get target
|
1187 |
+
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
1188 |
+
args,
|
1189 |
+
accelerator,
|
1190 |
+
noise_scheduler,
|
1191 |
+
latents,
|
1192 |
+
batch,
|
1193 |
+
text_encoder_conds,
|
1194 |
+
unet,
|
1195 |
+
network,
|
1196 |
+
weight_dtype,
|
1197 |
+
train_unet,
|
1198 |
+
)
|
1199 |
+
|
1200 |
+
loss = train_util.conditional_loss(
|
1201 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
1202 |
+
)
|
1203 |
+
if weighting is not None:
|
1204 |
+
loss = loss * weighting
|
1205 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
1206 |
+
loss = apply_masked_loss(loss, batch)
|
1207 |
+
loss = loss.mean([1, 2, 3])
|
1208 |
+
|
1209 |
+
loss_weights = batch["loss_weights"]  # 各sampleごとのweight / per-sample weight
|
1210 |
+
loss = loss * loss_weights
|
1211 |
+
|
1212 |
+
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
1213 |
+
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
1214 |
+
|
1215 |
+
loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし / already a mean, so no need to divide by batch_size
|
1216 |
+
|
1217 |
+
accelerator.backward(loss)
|
1218 |
+
if accelerator.sync_gradients:
|
1219 |
+
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
1220 |
+
if args.max_grad_norm != 0.0:
|
1221 |
+
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
1222 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
1223 |
+
|
1224 |
+
optimizer.step()
|
1225 |
+
lr_scheduler.step()
|
1226 |
+
optimizer.zero_grad(set_to_none=True)
|
1227 |
+
|
1228 |
+
if args.scale_weight_norms:
|
1229 |
+
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
1230 |
+
args.scale_weight_norms, accelerator.device
|
1231 |
+
)
|
1232 |
+
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
1233 |
+
else:
|
1234 |
+
keys_scaled, mean_norm, maximum_norm = None, None, None
|
1235 |
+
|
1236 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1237 |
+
if accelerator.sync_gradients:
|
1238 |
+
progress_bar.update(1)
|
1239 |
+
global_step += 1
|
1240 |
+
|
1241 |
+
optimizer_eval_fn()
|
1242 |
+
self.sample_images(
|
1243 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
1244 |
+
)
|
1245 |
+
|
1246 |
+
# 指定ステップごとにモデルを保存 / save the model every specified number of steps
|
1247 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
1248 |
+
accelerator.wait_for_everyone()
|
1249 |
+
if accelerator.is_main_process:
|
1250 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
1251 |
+
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
1252 |
+
|
1253 |
+
if args.save_state:
|
1254 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
1255 |
+
|
1256 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
1257 |
+
if remove_step_no is not None:
|
1258 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
1259 |
+
remove_model(remove_ckpt_name)
|
1260 |
+
optimizer_train_fn()
|
1261 |
+
|
1262 |
+
current_loss = loss.detach().item()
|
1263 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
1264 |
+
avr_loss: float = loss_recorder.moving_average
|
1265 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
1266 |
+
progress_bar.set_postfix(**logs)
|
1267 |
+
|
1268 |
+
if args.scale_weight_norms:
|
1269 |
+
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
1270 |
+
|
1271 |
+
if len(accelerator.trackers) > 0:
|
1272 |
+
logs = self.generate_step_logs(
|
1273 |
+
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
|
1274 |
+
)
|
1275 |
+
accelerator.log(logs, step=global_step)
|
1276 |
+
|
1277 |
+
if global_step >= args.max_train_steps:
|
1278 |
+
break
|
1279 |
+
|
1280 |
+
if len(accelerator.trackers) > 0:
|
1281 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
1282 |
+
accelerator.log(logs, step=epoch + 1)
|
1283 |
+
|
1284 |
+
accelerator.wait_for_everyone()
|
1285 |
+
|
1286 |
+
# 指定エポックごとにモデルを保存 / save the model every specified number of epochs
|
1287 |
+
optimizer_eval_fn()
|
1288 |
+
if args.save_every_n_epochs is not None:
|
1289 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
1290 |
+
if is_main_process and saving:
|
1291 |
+
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
1292 |
+
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
1293 |
+
|
1294 |
+
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
1295 |
+
if remove_epoch_no is not None:
|
1296 |
+
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
1297 |
+
remove_model(remove_ckpt_name)
|
1298 |
+
|
1299 |
+
if args.save_state:
|
1300 |
+
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
1301 |
+
|
1302 |
+
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
1303 |
+
optimizer_train_fn()
|
1304 |
+
|
1305 |
+
# end of epoch
|
1306 |
+
|
1307 |
+
# metadata["ss_epoch"] = str(num_train_epochs)
|
1308 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
1309 |
+
|
1310 |
+
if is_main_process:
|
1311 |
+
network = accelerator.unwrap_model(network)
|
1312 |
+
|
1313 |
+
accelerator.end_training()
|
1314 |
+
optimizer_eval_fn()
|
1315 |
+
|
1316 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
1317 |
+
train_util.save_state_on_train_end(args, accelerator)
|
1318 |
+
|
1319 |
+
if is_main_process:
|
1320 |
+
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
1321 |
+
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
1322 |
+
|
1323 |
+
logger.info("model saved.")
|
1324 |
+
|
1325 |
+
|
1326 |
+
def setup_parser() -> argparse.ArgumentParser:
|
1327 |
+
parser = argparse.ArgumentParser()
|
1328 |
+
|
1329 |
+
add_logging_arguments(parser)
|
1330 |
+
train_util.add_sd_models_arguments(parser)
|
1331 |
+
train_util.add_dataset_arguments(parser, True, True, True)
|
1332 |
+
train_util.add_training_arguments(parser, True)
|
1333 |
+
train_util.add_masked_loss_arguments(parser)
|
1334 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
1335 |
+
train_util.add_optimizer_arguments(parser)
|
1336 |
+
config_util.add_config_arguments(parser)
|
1337 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
1338 |
+
|
1339 |
+
parser.add_argument(
|
1340 |
+
"--cpu_offload_checkpointing",
|
1341 |
+
action="store_true",
|
1342 |
+
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
|
1343 |
+
" / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
|
1344 |
+
)
|
1345 |
+
parser.add_argument(
|
1346 |
+
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
|
1347 |
+
)
|
1348 |
+
parser.add_argument(
|
1349 |
+
"--save_model_as",
|
1350 |
+
type=str,
|
1351 |
+
default="safetensors",
|
1352 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
1353 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
1354 |
+
)
|
1355 |
+
|
1356 |
+
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
1357 |
+
parser.add_argument(
|
1358 |
+
"--text_encoder_lr",
|
1359 |
+
type=float,
|
1360 |
+
default=None,
|
1361 |
+
nargs="*",
|
1362 |
+
help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
|
1363 |
+
)
|
1364 |
+
parser.add_argument(
|
1365 |
+
"--fp8_base_unet",
|
1366 |
+
action="store_true",
|
1367 |
+
help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
|
1368 |
+
" / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
|
1369 |
+
)
|
1370 |
+
|
1371 |
+
parser.add_argument(
|
1372 |
+
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
1373 |
+
)
|
1374 |
+
parser.add_argument(
|
1375 |
+
"--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
|
1376 |
+
)
|
1377 |
+
parser.add_argument(
|
1378 |
+
"--network_dim",
|
1379 |
+
type=int,
|
1380 |
+
default=None,
|
1381 |
+
help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
|
1382 |
+
)
|
1383 |
+
parser.add_argument(
|
1384 |
+
"--network_alpha",
|
1385 |
+
type=float,
|
1386 |
+
default=1,
|
1387 |
+
help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
|
1388 |
+
)
|
1389 |
+
parser.add_argument(
|
1390 |
+
"--network_dropout",
|
1391 |
+
type=float,
|
1392 |
+
default=None,
|
1393 |
+
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
|
1394 |
+
)
|
1395 |
+
parser.add_argument(
|
1396 |
+
"--network_args",
|
1397 |
+
type=str,
|
1398 |
+
default=None,
|
1399 |
+
nargs="*",
|
1400 |
+
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
|
1401 |
+
)
|
1402 |
+
parser.add_argument(
|
1403 |
+
"--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
|
1404 |
+
)
|
1405 |
+
parser.add_argument(
|
1406 |
+
"--network_train_text_encoder_only",
|
1407 |
+
action="store_true",
|
1408 |
+
help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
|
1409 |
+
)
|
1410 |
+
parser.add_argument(
|
1411 |
+
"--training_comment",
|
1412 |
+
type=str,
|
1413 |
+
default=None,
|
1414 |
+
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
|
1415 |
+
)
|
1416 |
+
parser.add_argument(
|
1417 |
+
"--dim_from_weights",
|
1418 |
+
action="store_true",
|
1419 |
+
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
|
1420 |
+
)
|
1421 |
+
parser.add_argument(
|
1422 |
+
"--scale_weight_norms",
|
1423 |
+
type=float,
|
1424 |
+
default=None,
|
1425 |
+
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケ��リングして勾配爆発を防ぐ(1が初期値としては適当)",
|
1426 |
+
)
|
1427 |
+
parser.add_argument(
|
1428 |
+
"--base_weights",
|
1429 |
+
type=str,
|
1430 |
+
default=None,
|
1431 |
+
nargs="*",
|
1432 |
+
help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
|
1433 |
+
)
|
1434 |
+
parser.add_argument(
|
1435 |
+
"--base_weights_multiplier",
|
1436 |
+
type=float,
|
1437 |
+
default=None,
|
1438 |
+
nargs="*",
|
1439 |
+
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
1440 |
+
)
|
1441 |
+
parser.add_argument(
|
1442 |
+
"--no_half_vae",
|
1443 |
+
action="store_true",
|
1444 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
1445 |
+
)
|
1446 |
+
parser.add_argument(
|
1447 |
+
"--skip_until_initial_step",
|
1448 |
+
action="store_true",
|
1449 |
+
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
|
1450 |
+
)
|
1451 |
+
parser.add_argument(
|
1452 |
+
"--initial_epoch",
|
1453 |
+
type=int,
|
1454 |
+
default=None,
|
1455 |
+
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
|
1456 |
+
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
|
1457 |
+
)
|
1458 |
+
parser.add_argument(
|
1459 |
+
"--initial_step",
|
1460 |
+
type=int,
|
1461 |
+
default=None,
|
1462 |
+
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
1463 |
+
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
1464 |
+
)
|
1465 |
+
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
1466 |
+
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
1467 |
+
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
1468 |
+
return parser
|
1469 |
+
|
1470 |
+
|
1471 |
+
if __name__ == "__main__":
|
1472 |
+
parser = setup_parser()
|
1473 |
+
|
1474 |
+
args = parser.parse_args()
|
1475 |
+
train_util.verify_command_line_training_args(args)
|
1476 |
+
args = train_util.read_config_from_file(args, parser)
|
1477 |
+
|
1478 |
+
trainer = NetworkTrainer()
|
1479 |
+
trainer.train(args)
|
train_network_asylora.py
ADDED
@@ -0,0 +1,1492 @@
|
|
1 |
+
import importlib
|
2 |
+
import argparse
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
from multiprocessing import Value
|
10 |
+
from typing import Any, List
|
11 |
+
import toml
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
17 |
+
|
18 |
+
init_ipex()
|
19 |
+
|
20 |
+
from accelerate.utils import set_seed
|
21 |
+
from accelerate import Accelerator
|
22 |
+
from diffusers import DDPMScheduler
|
23 |
+
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
24 |
+
|
25 |
+
import library.train_util as train_util
|
26 |
+
from library.train_util import DreamBoothDataset
|
27 |
+
import library.config_util as config_util
|
28 |
+
from library.config_util import (
|
29 |
+
ConfigSanitizer,
|
30 |
+
BlueprintGenerator,
|
31 |
+
)
|
32 |
+
import library.huggingface_util as huggingface_util
|
33 |
+
import library.custom_train_functions as custom_train_functions
|
34 |
+
from library.custom_train_functions import (
|
35 |
+
apply_snr_weight,
|
36 |
+
get_weighted_text_embeddings,
|
37 |
+
prepare_scheduler_for_custom_training,
|
38 |
+
scale_v_prediction_loss_like_noise_prediction,
|
39 |
+
add_v_prediction_like_loss,
|
40 |
+
apply_debiased_estimation,
|
41 |
+
apply_masked_loss,
|
42 |
+
)
|
43 |
+
from library.utils import setup_logging, add_logging_arguments
|
44 |
+
|
45 |
+
setup_logging()
|
46 |
+
import logging
|
47 |
+
|
48 |
+
logger = logging.getLogger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
class NetworkTrainer:
|
52 |
+
def __init__(self):
|
53 |
+
self.vae_scale_factor = 0.18215
|
54 |
+
self.is_sdxl = False
|
55 |
+
|
56 |
+
# TODO 他のスクリプトと共通化する / TODO: unify this with the other scripts
|
57 |
+
def generate_step_logs(
|
58 |
+
self,
|
59 |
+
args: argparse.Namespace,
|
60 |
+
current_loss,
|
61 |
+
avr_loss,
|
62 |
+
lr_scheduler,
|
63 |
+
lr_descriptions,
|
64 |
+
keys_scaled=None,
|
65 |
+
mean_norm=None,
|
66 |
+
maximum_norm=None,
|
67 |
+
):
|
68 |
+
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
69 |
+
|
70 |
+
if keys_scaled is not None:
|
71 |
+
logs["max_norm/keys_scaled"] = keys_scaled
|
72 |
+
logs["max_norm/average_key_norm"] = mean_norm
|
73 |
+
logs["max_norm/max_key_norm"] = maximum_norm
|
74 |
+
|
75 |
+
lrs = lr_scheduler.get_last_lr()
|
76 |
+
for i, lr in enumerate(lrs):
|
77 |
+
if lr_descriptions is not None:
|
78 |
+
lr_desc = lr_descriptions[i]
|
79 |
+
else:
|
80 |
+
idx = i - (0 if args.network_train_unet_only else -1)
|
81 |
+
if idx == -1:
|
82 |
+
lr_desc = "textencoder"
|
83 |
+
else:
|
84 |
+
if len(lrs) > 2:
|
85 |
+
lr_desc = f"group{idx}"
|
86 |
+
else:
|
87 |
+
lr_desc = "unet"
|
88 |
+
|
89 |
+
logs[f"lr/{lr_desc}"] = lr
|
90 |
+
|
91 |
+
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
92 |
+
# tracking d*lr value
|
93 |
+
logs[f"lr/d*lr/{lr_desc}"] = (
|
94 |
+
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
95 |
+
)
|
96 |
+
|
97 |
+
return logs
|
98 |
+
|
99 |
+
def assert_extra_args(self, args, train_dataset_group):
|
100 |
+
train_dataset_group.verify_bucket_reso_steps(64)
|
101 |
+
|
102 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
103 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
104 |
+
|
105 |
+
# incorporate xformers or memory efficient attention into the model
|
106 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
107 |
+
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
108 |
+
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
109 |
+
|
110 |
+
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
111 |
+
|
112 |
+
def get_tokenize_strategy(self, args):
|
113 |
+
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
114 |
+
|
115 |
+
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
|
116 |
+
return [tokenize_strategy.tokenizer]
|
117 |
+
|
118 |
+
def get_latents_caching_strategy(self, args):
|
119 |
+
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
120 |
+
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
121 |
+
)
|
122 |
+
return latents_caching_strategy
|
123 |
+
|
124 |
+
def get_text_encoding_strategy(self, args):
|
125 |
+
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
126 |
+
|
127 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
128 |
+
return None
|
129 |
+
|
130 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
131 |
+
"""
|
132 |
+
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
|
133 |
+
FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
|
134 |
+
"""
|
135 |
+
return text_encoders
|
136 |
+
|
137 |
+
# returns a list of bool values indicating whether each text encoder should be trained
|
138 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
139 |
+
return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders)
|
140 |
+
|
141 |
+
def is_train_text_encoder(self, args):
|
142 |
+
return not args.network_train_unet_only
|
143 |
+
|
144 |
+
def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
|
145 |
+
for t_enc in text_encoders:
|
146 |
+
t_enc.to(accelerator.device, dtype=weight_dtype)
|
147 |
+
|
148 |
+
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
|
149 |
+
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
150 |
+
return noise_pred
|
151 |
+
|
152 |
+
def all_reduce_network(self, accelerator, network):
|
153 |
+
for param in network.parameters():
|
154 |
+
if param.grad is not None:
|
155 |
+
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
156 |
+
|
157 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
|
158 |
+
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
|
159 |
+
|
160 |
+
# region SD/SDXL
|
161 |
+
|
162 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
163 |
+
pass
|
164 |
+
|
165 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
166 |
+
noise_scheduler = DDPMScheduler(
|
167 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
168 |
+
)
|
169 |
+
prepare_scheduler_for_custom_training(noise_scheduler, device)
|
170 |
+
if args.zero_terminal_snr:
|
171 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
172 |
+
return noise_scheduler
|
173 |
+
|
174 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
175 |
+
return vae.encode(images).latent_dist.sample()
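# stochastic sample from the VAE posterior; the latents are scaled later in shift_scale_latents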
|
176 |
+
|
177 |
+
def shift_scale_latents(self, args, latents):
|
178 |
+
return latents * self.vae_scale_factor
|
179 |
+
|
180 |
+
def get_noise_pred_and_target(
|
181 |
+
self,
|
182 |
+
args,
|
183 |
+
accelerator,
|
184 |
+
noise_scheduler,
|
185 |
+
latents,
|
186 |
+
batch,
|
187 |
+
text_encoder_conds,
|
188 |
+
unet,
|
189 |
+
network,
|
190 |
+
weight_dtype,
|
191 |
+
train_unet,
|
192 |
+
):
|
193 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
194 |
+
# with noise offset and/or multires noise if specified
|
195 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
196 |
+
|
197 |
+
# ensure the hidden state will require grad
|
198 |
+
if args.gradient_checkpointing:
|
199 |
+
for x in noisy_latents:
|
200 |
+
x.requires_grad_(True)
|
201 |
+
for t in text_encoder_conds:
|
202 |
+
t.requires_grad_(True)
|
203 |
+
|
204 |
+
# Predict the noise residual
|
205 |
+
with accelerator.autocast():
|
206 |
+
noise_pred = self.call_unet(
|
207 |
+
args,
|
208 |
+
accelerator,
|
209 |
+
unet,
|
210 |
+
noisy_latents.requires_grad_(train_unet),
|
211 |
+
timesteps,
|
212 |
+
text_encoder_conds,
|
213 |
+
batch,
|
214 |
+
weight_dtype,
|
215 |
+
)
|
216 |
+
|
217 |
+
if args.v_parameterization:
|
218 |
+
# v-parameterization training
|
219 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
220 |
+
else:
|
221 |
+
target = noise
|
222 |
+
|
223 |
+
# differential output preservation
|
224 |
+
if "custom_attributes" in batch:
|
225 |
+
diff_output_pr_indices = []
|
226 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
227 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
228 |
+
diff_output_pr_indices.append(i)
|
229 |
+
|
230 |
+
if len(diff_output_pr_indices) > 0:
|
231 |
+
network.set_multiplier(0.0)
|
232 |
+
with torch.no_grad(), accelerator.autocast():
|
233 |
+
noise_pred_prior = self.call_unet(
|
234 |
+
args,
|
235 |
+
accelerator,
|
236 |
+
unet,
|
237 |
+
noisy_latents,
|
238 |
+
timesteps,
|
239 |
+
text_encoder_conds,
|
240 |
+
batch,
|
241 |
+
weight_dtype,
|
242 |
+
indices=diff_output_pr_indices,
|
243 |
+
)
|
244 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
245 |
+
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
246 |
+
|
247 |
+
return noise_pred, target, timesteps, huber_c, None
|
248 |
+
|
249 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
250 |
+
if args.min_snr_gamma:
|
251 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
252 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
253 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
254 |
+
if args.v_pred_like_loss:
|
255 |
+
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
256 |
+
if args.debiased_estimation_loss:
|
257 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
258 |
+
return loss
|
259 |
+
|
260 |
+
def get_sai_model_spec(self, args):
|
261 |
+
return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
262 |
+
|
263 |
+
def update_metadata(self, metadata, args):
|
264 |
+
pass
|
265 |
+
|
266 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
267 |
+
return False # use for sample images
|
268 |
+
|
269 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
270 |
+
# set requires_grad = True on a top-level parameter so that gradient checkpointing works
|
271 |
+
text_encoder.text_model.embeddings.requires_grad_(True)
|
272 |
+
|
273 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
274 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
275 |
+
|
276 |
+
def prepare_unet_with_accelerator(
|
277 |
+
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
278 |
+
) -> torch.nn.Module:
|
279 |
+
return accelerator.prepare(unet)
|
280 |
+
|
281 |
+
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
282 |
+
pass
|
283 |
+
|
284 |
+
# endregion
|
285 |
+
|
286 |
+
def train(self, args):
|
287 |
+
session_id = random.randint(0, 2**32)
|
288 |
+
training_started_at = time.time()
|
289 |
+
train_util.verify_training_args(args)
|
290 |
+
train_util.prepare_dataset_args(args, True)
|
291 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
292 |
+
setup_logging(args, reset=True)
|
293 |
+
|
294 |
+
cache_latents = args.cache_latents
|
295 |
+
use_dreambooth_method = args.in_json is None
|
296 |
+
use_user_config = args.dataset_config is not None
|
297 |
+
|
298 |
+
if args.seed is None:
|
299 |
+
args.seed = random.randint(0, 2**32)
|
300 |
+
set_seed(args.seed)
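# when no seed is given, a random one is drawn above and kept in args.seed so it is recorded in the saved metadata (ss_seed)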
|
301 |
+
|
302 |
+
tokenize_strategy = self.get_tokenize_strategy(args)
|
303 |
+
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
304 |
+
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
|
305 |
+
|
306 |
+
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization
|
307 |
+
latents_caching_strategy = self.get_latents_caching_strategy(args)
|
308 |
+
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
309 |
+
|
310 |
+
# prepare the dataset
|
311 |
+
if args.dataset_class is None:
|
312 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
313 |
+
if use_user_config:
|
314 |
+
logger.info(f"Loading dataset config from {args.dataset_config}")
|
315 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
316 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
317 |
+
# if any(getattr(args, attr) is not None for attr in ignored):
|
318 |
+
# logger.warning(
|
319 |
+
# "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
320 |
+
# ", ".join(ignored)
|
321 |
+
# )
|
322 |
+
# )
|
323 |
+
# else:
|
324 |
+
# if use_dreambooth_method:
|
325 |
+
# logger.info("Using DreamBooth method.")
|
326 |
+
# user_config = {
|
327 |
+
# "datasets": [
|
328 |
+
# {
|
329 |
+
# "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
330 |
+
# args.train_data_dir, args.reg_data_dir
|
331 |
+
# )
|
332 |
+
# }
|
333 |
+
# ]
|
334 |
+
# }
|
335 |
+
# else:
|
336 |
+
# logger.info("Training with captions.")
|
337 |
+
# user_config = {
|
338 |
+
# "datasets": [
|
339 |
+
# {
|
340 |
+
# "subsets": [
|
341 |
+
# {
|
342 |
+
# "image_dir": args.train_data_dir,
|
343 |
+
# "metadata_file": args.in_json,
|
344 |
+
# }
|
345 |
+
# ]
|
346 |
+
# }
|
347 |
+
# ]
|
348 |
+
# }
|
349 |
+
|
350 |
+
blueprint = blueprint_generator.generate(user_config, args) # user_config: LoraConfig.toml
|
351 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
352 |
+
else:
|
353 |
+
# use arbitrary dataset class
|
354 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
355 |
+
|
356 |
+
current_epoch = Value("i", 0)
|
357 |
+
current_step = Value("i", 0)
|
358 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
359 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
360 |
+
|
361 |
+
if args.debug_dataset:
|
362 |
+
train_dataset_group.set_current_strategies()  # dataset needs to know the strategies explicitly
|
363 |
+
train_util.debug_dataset(train_dataset_group)
|
364 |
+
return
|
365 |
+
if len(train_dataset_group) == 0:
|
366 |
+
logger.error(
|
367 |
+
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
368 |
+
)
|
369 |
+
return
|
370 |
+
|
371 |
+
if cache_latents:
|
372 |
+
assert (
|
373 |
+
train_dataset_group.is_latent_cacheable()
|
374 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
375 |
+
|
376 |
+
self.assert_extra_args(args, train_dataset_group) # may change some args
|
377 |
+
|
378 |
+
# prepare the accelerator
|
379 |
+
logger.info("preparing accelerator")
|
380 |
+
accelerator = train_util.prepare_accelerator(args)
|
381 |
+
is_main_process = accelerator.is_main_process
|
382 |
+
|
383 |
+
# prepare dtypes matching the mixed precision setting and cast as appropriate
|
384 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
385 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
386 |
+
|
387 |
+
# load the model; model_version: flux, text_encoder: t5, clip
|
388 |
+
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
389 |
+
|
390 |
+
# text_encoder is List[CLIPTextModel] or CLIPTextModel
|
391 |
+
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
392 |
+
|
393 |
+
# load the network module for additional difference training
|
394 |
+
sys.path.append(os.path.dirname(__file__))
|
395 |
+
accelerator.print("import network module:", args.network_module)
|
396 |
+
network_module = importlib.import_module(args.network_module)
|
397 |
+
|
398 |
+
# if args.base_weights is not None:
|
399 |
+
# # base_weights が指定されている場合は、指定された重みを読み込みマージする
|
400 |
+
# for i, weight_path in enumerate(args.base_weights):
|
401 |
+
# if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
402 |
+
# multiplier = 1.0
|
403 |
+
# else:
|
404 |
+
# multiplier = args.base_weights_multiplier[i]
|
405 |
+
#
|
406 |
+
# accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
|
407 |
+
#
|
408 |
+
# module, weights_sd = network_module.create_network_from_weights(
|
409 |
+
# multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
410 |
+
# )
|
411 |
+
# module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
412 |
+
#
|
413 |
+
# accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
414 |
+
|
415 |
+
# prepare for training
|
416 |
+
if cache_latents:
|
417 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
418 |
+
vae.requires_grad_(False)
|
419 |
+
vae.eval()
|
420 |
+
|
421 |
+
train_dataset_group.new_cache_latents(vae, accelerator)
|
422 |
+
|
423 |
+
vae.to("cpu")
|
424 |
+
clean_memory_on_device(accelerator.device)
|
425 |
+
|
426 |
+
accelerator.wait_for_everyone()
|
427 |
+
|
428 |
+
|
429 |
+
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
430 |
+
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
431 |
+
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
432 |
+
|
433 |
+
text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
|
434 |
+
if text_encoder_outputs_caching_strategy is not None:
|
435 |
+
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
|
436 |
+
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
|
437 |
+
|
438 |
+
# prepare network
|
439 |
+
net_kwargs = {}
|
440 |
+
if args.network_args is not None:
|
441 |
+
for net_arg in args.network_args:
|
442 |
+
key, value = net_arg.split("=")
|
443 |
+
net_kwargs[key] = value
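# e.g. "conv_dim=4"; values stay strings here and are cast inside the network module as needed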
|
444 |
+
|
445 |
+
# if a new network is added in future, add if ~ then blocks for each network (;'?')
|
446 |
+
if args.dim_from_weights:
|
447 |
+
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
448 |
+
else:
|
449 |
+
if "dropout" not in net_kwargs:
|
450 |
+
# workaround for LyCORIS (;^ω^)
|
451 |
+
net_kwargs["dropout"] = args.network_dropout
|
452 |
+
|
453 |
+
network = network_module.create_network(
|
454 |
+
1.0,
|
455 |
+
args.network_dim,
|
456 |
+
args.network_alpha,
|
457 |
+
vae,
|
458 |
+
text_encoder,
|
459 |
+
unet,
|
460 |
+
lora_ups_num=args.lora_ups_num,
|
461 |
+
neuron_dropout=args.network_dropout,
|
462 |
+
**net_kwargs,
|
463 |
+
)
|
464 |
+
if network is None:
|
465 |
+
return
|
466 |
+
network_has_multiplier = hasattr(network, "set_multiplier") # network_has_multiplier: True
|
467 |
+
|
468 |
+
if hasattr(network, "prepare_network"):
|
469 |
+
network.prepare_network(args)
|
470 |
+
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
471 |
+
logger.warning(
|
472 |
+
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
473 |
+
)
|
474 |
+
args.scale_weight_norms = False
|
475 |
+
|
476 |
+
self.post_process_network(args, accelerator, network, text_encoders, unet)
|
477 |
+
|
478 |
+
# apply network to unet and text_encoder
|
479 |
+
train_unet = not args.network_train_text_encoder_only # train_unet: True
|
480 |
+
train_text_encoder = self.is_train_text_encoder(args) # train_text_encoder: False
|
481 |
+
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
482 |
+
|
483 |
+
if args.network_weights is not None:
|
484 |
+
# FIXME consider alpha of weights: this assumes that the alpha is not changed
|
485 |
+
info = network.load_weights(args.network_weights)
|
486 |
+
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
487 |
+
|
488 |
+
if args.gradient_checkpointing:
|
489 |
+
if args.cpu_offload_checkpointing:
|
490 |
+
unet.enable_gradient_checkpointing(cpu_offload=True)
|
491 |
+
else:
|
492 |
+
unet.enable_gradient_checkpointing()
|
493 |
+
|
494 |
+
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
495 |
+
if flag:
|
496 |
+
if t_enc.supports_gradient_checkpointing:
|
497 |
+
t_enc.gradient_checkpointing_enable()
|
498 |
+
del t_enc
|
499 |
+
network.enable_gradient_checkpointing() # may have no effect
|
500 |
+
|
501 |
+
# prepare the optimizer and other training hyperparameters
|
502 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
503 |
+
|
504 |
+
# make backward compatibility for text_encoder_lr
|
505 |
+
support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
|
506 |
+
if support_multiple_lrs:
|
507 |
+
text_encoder_lr = args.text_encoder_lr
|
508 |
+
else:
|
509 |
+
# toml backward compatibility
|
510 |
+
if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int):
|
511 |
+
text_encoder_lr = args.text_encoder_lr
|
512 |
+
else:
|
513 |
+
text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
|
514 |
+
try:
|
515 |
+
if support_multiple_lrs:
|
516 |
+
results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
|
517 |
+
else:
|
518 |
+
results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
|
519 |
+
if type(results) is tuple:
|
520 |
+
trainable_params = results[0]
|
521 |
+
lr_descriptions = results[1]
|
522 |
+
else:
|
523 |
+
trainable_params = results
|
524 |
+
lr_descriptions = None
|
525 |
+
except TypeError as e:
|
526 |
+
trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
|
527 |
+
lr_descriptions = None
|
528 |
+
|
529 |
+
# if len(trainable_params) == 0:
|
530 |
+
# accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
|
531 |
+
# for params in trainable_params:
|
532 |
+
# for k, v in params.items():
|
533 |
+
# if type(v) == float:
|
534 |
+
# pass
|
535 |
+
# else:
|
536 |
+
# v = len(v)
|
537 |
+
# accelerator.print(f"trainable_params: {k} = {v}")
|
538 |
+
|
539 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
540 |
+
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
541 |
+
|
542 |
+
# prepare dataloader
|
543 |
+
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
544 |
+
# some strategies can be None
|
545 |
+
train_dataset_group.set_current_strategies()
|
546 |
+
|
547 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
548 |
+
|
549 |
+
train_dataloader = torch.utils.data.DataLoader(
|
550 |
+
train_dataset_group,
|
551 |
+
batch_size=1,
|
552 |
+
shuffle=True,
|
553 |
+
collate_fn=collator,
|
554 |
+
num_workers=n_workers,
|
555 |
+
persistent_workers=args.persistent_data_loader_workers,
|
556 |
+
)
|
557 |
+
|
558 |
+
# calculate the number of training steps
|
559 |
+
if args.max_train_epochs is not None:
|
560 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
561 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
562 |
+
)
|
563 |
+
accelerator.print(
|
564 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
565 |
+
)
|
566 |
+
|
567 |
+
# set the maximum number of training steps
|
568 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
569 |
+
|
570 |
+
# set up the lr scheduler
|
571 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
572 |
+
|
573 |
+
# experimental feature: perform fp16/bf16 training including gradients; set the whole model to fp16/bf16
|
574 |
+
if args.full_fp16:  # not taken in this configuration
|
575 |
+
assert (
|
576 |
+
args.mixed_precision == "fp16"
|
577 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
578 |
+
accelerator.print("enable full fp16 training.")
|
579 |
+
network.to(weight_dtype)
|
580 |
+
elif args.full_bf16:  # not taken in this configuration
|
581 |
+
assert (
|
582 |
+
args.mixed_precision == "bf16"
|
583 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
584 |
+
accelerator.print("enable full bf16 training.")
|
585 |
+
network.to(weight_dtype)
|
586 |
+
|
587 |
+
unet_weight_dtype = te_weight_dtype = weight_dtype
|
588 |
+
# Experimental Feature: Put base model into fp8 to save vram
|
589 |
+
if args.fp8_base or args.fp8_base_unet:
|
590 |
+
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
591 |
+
assert (
|
592 |
+
args.mixed_precision != "no"
|
593 |
+
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
594 |
+
accelerator.print("enable fp8 training for U-Net.")
|
595 |
+
unet_weight_dtype = torch.float8_e4m3fn # torch.float8_e4m3fn
|
596 |
+
|
597 |
+
if not args.fp8_base_unet:
|
598 |
+
accelerator.print("enable fp8 training for Text Encoder.")
|
599 |
+
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
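# args.fp8_base_unet is False in this branch, so the text encoders are also cast to float8_e4m3fn below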
|
600 |
+
|
601 |
+
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
602 |
+
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
603 |
+
|
604 |
+
# logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
|
605 |
+
# unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
606 |
+
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
|
607 |
+
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
|
608 |
+
|
609 |
+
unet.requires_grad_(False)
|
610 |
+
unet.to(dtype=unet_weight_dtype)
|
611 |
+
for i, t_enc in enumerate(text_encoders):
|
612 |
+
t_enc.requires_grad_(False)
|
613 |
+
|
614 |
+
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
615 |
+
if t_enc.device.type != "cpu":
|
616 |
+
t_enc.to(dtype=te_weight_dtype)
|
617 |
+
|
618 |
+
# nn.Embedding not support FP8
|
619 |
+
if te_weight_dtype != weight_dtype:
|
620 |
+
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
|
621 |
+
|
622 |
+
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
623 |
+
if args.deepspeed:  # not taken in this configuration
|
624 |
+
flags = self.get_text_encoders_train_flags(args, text_encoders)
|
625 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
626 |
+
args,
|
627 |
+
unet=unet if train_unet else None,
|
628 |
+
text_encoder1=text_encoders[0] if flags[0] else None,
|
629 |
+
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
|
630 |
+
network=network,
|
631 |
+
)
|
632 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
633 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
634 |
+
)
|
635 |
+
training_model = ds_model
|
636 |
+
else:
|
637 |
+
if train_unet:
|
638 |
+
# default implementation is: unet = accelerator.prepare(unet)
|
639 |
+
unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here
|
640 |
+
else:
|
641 |
+
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
642 |
+
if train_text_encoder:
|
643 |
+
text_encoders = [
|
644 |
+
(accelerator.prepare(t_enc) if flag else t_enc)
|
645 |
+
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
|
646 |
+
]
|
647 |
+
if len(text_encoders) > 1:
|
648 |
+
text_encoder = text_encoders
|
649 |
+
else:
|
650 |
+
text_encoder = text_encoders[0]
|
651 |
+
else:
|
652 |
+
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
653 |
+
|
654 |
+
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
655 |
+
network, optimizer, train_dataloader, lr_scheduler
|
656 |
+
)
|
657 |
+
training_model = network
|
658 |
+
|
659 |
+
if args.gradient_checkpointing:
|
660 |
+
# according to TI example in Diffusers, train is required
|
661 |
+
unet.train()
|
662 |
+
for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
|
663 |
+
t_enc.train()
|
664 |
+
|
665 |
+
# set requires_grad = True on a top-level parameter so that gradient checkpointing works
|
666 |
+
if frag:
|
667 |
+
self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
|
668 |
+
|
669 |
+
else:
|
670 |
+
unet.eval()
|
671 |
+
for t_enc in text_encoders:
|
672 |
+
t_enc.eval()
|
673 |
+
|
674 |
+
del t_enc
|
675 |
+
|
676 |
+
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
|
677 |
+
|
678 |
+
if not cache_latents:  # when latents are not cached, the VAE is used during training, so prepare it here
|
679 |
+
vae.requires_grad_(False)
|
680 |
+
vae.eval()
|
681 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
682 |
+
|
683 |
+
# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
|
684 |
+
if args.full_fp16:
|
685 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
686 |
+
|
687 |
+
# before resuming make hook for saving/loading to save/load the network weights only
|
688 |
+
def save_model_hook(models, weights, output_dir):
|
689 |
+
# pop weights of other models than network to save only network weights
|
690 |
+
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
|
691 |
+
if accelerator.is_main_process or args.deepspeed:
|
692 |
+
remove_indices = []
|
693 |
+
for i, model in enumerate(models):
|
694 |
+
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
695 |
+
remove_indices.append(i)
|
696 |
+
for i in reversed(remove_indices):
|
697 |
+
if len(weights) > i:
|
698 |
+
weights.pop(i)
|
699 |
+
# print(f"save model hook: {len(weights)} weights will be saved")
|
700 |
+
|
701 |
+
# save current epoch and step
|
702 |
+
train_state_file = os.path.join(output_dir, "train_state.json")
|
703 |
+
# +1 is needed because the state is saved before current_step is set from global_step
|
704 |
+
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
|
705 |
+
with open(train_state_file, "w", encoding="utf-8") as f:
|
706 |
+
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
|
707 |
+
|
708 |
+
steps_from_state = None
|
709 |
+
|
710 |
+
def load_model_hook(models, input_dir):
|
711 |
+
# remove models except network
|
712 |
+
remove_indices = []
|
713 |
+
for i, model in enumerate(models):
|
714 |
+
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
715 |
+
remove_indices.append(i)
|
716 |
+
for i in reversed(remove_indices):
|
717 |
+
models.pop(i)
|
718 |
+
# print(f"load model hook: {len(models)} models will be loaded")
|
719 |
+
|
720 |
+
# load current epoch and step
|
721 |
+
nonlocal steps_from_state
|
722 |
+
train_state_file = os.path.join(input_dir, "train_state.json")
|
723 |
+
if os.path.exists(train_state_file):
|
724 |
+
with open(train_state_file, "r", encoding="utf-8") as f:
|
725 |
+
data = json.load(f)
|
726 |
+
steps_from_state = data["current_step"]
|
727 |
+
logger.info(f"load train state from {train_state_file}: {data}")
|
728 |
+
|
729 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
730 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
731 |
+
|
732 |
+
# resume training if specified
|
733 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
734 |
+
|
735 |
+
# calculate the number of epochs
|
736 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
737 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
738 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
739 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
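# the trailing "or 1" keeps the save interval at least one epoch when the ratio exceeds num_train_epochs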
|
740 |
+
|
741 |
+
# train
|
742 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
743 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
744 |
+
|
745 |
+
accelerator.print("开始训练 / Training started")
|
746 |
+
accelerator.print(
|
747 |
+
f" 训练图片数 * 重复次数 / Number of training images * repeats: {train_dataset_group.num_train_images}")
|
748 |
+
accelerator.print(f" 正则化图片数 / Number of regularization images: {train_dataset_group.num_reg_images}")
|
749 |
+
accelerator.print(f" 每个 epoch 的批次数 / Number of batches per epoch: {len(train_dataloader)}")
|
750 |
+
accelerator.print(f" 训练的 epoch 数 / Number of epochs: {num_train_epochs}")
|
751 |
+
accelerator.print(
|
752 |
+
f" 每个设备的批次大小 / Batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
753 |
+
)
|
754 |
+
# accelerator.print(f" 总批次大小(包括并行和分布式训练及梯度累积)/ Total batch size (with parallel & distributed & accumulation): {total_batch_size}")
|
755 |
+
accelerator.print(f" 梯度累积步数 / Gradient accumulation steps: {args.gradient_accumulation_steps}")
|
756 |
+
accelerator.print(f" 总优化步骤数 / Total optimization steps: {args.max_train_steps}")
|
757 |
+
|
758 |
+
# TODO refactor metadata creation and move to util
|
759 |
+
metadata = {
|
760 |
+
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
761 |
+
"ss_training_started_at": training_started_at, # unix timestamp
|
762 |
+
"ss_output_name": args.output_name,
|
763 |
+
"ss_learning_rate": args.learning_rate,
|
764 |
+
"ss_text_encoder_lr": text_encoder_lr,
|
765 |
+
"ss_unet_lr": args.unet_lr,
|
766 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
767 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
768 |
+
"ss_num_batches_per_epoch": len(train_dataloader),
|
769 |
+
"ss_num_epochs": num_train_epochs,
|
770 |
+
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
771 |
+
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
772 |
+
"ss_max_train_steps": args.max_train_steps,
|
773 |
+
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
774 |
+
"ss_lr_scheduler": args.lr_scheduler,
|
775 |
+
"ss_network_module": args.network_module,
|
776 |
+
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
777 |
+
"ss_network_alpha": args.network_alpha, # some networks may not have alpha
|
778 |
+
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
|
779 |
+
"ss_mixed_precision": args.mixed_precision,
|
780 |
+
"ss_full_fp16": bool(args.full_fp16),
|
781 |
+
"ss_v2": bool(args.v2),
|
782 |
+
"ss_base_model_version": model_version,
|
783 |
+
"ss_clip_skip": args.clip_skip,
|
784 |
+
"ss_max_token_length": args.max_token_length,
|
785 |
+
"ss_cache_latents": bool(args.cache_latents),
|
786 |
+
"ss_seed": args.seed,
|
787 |
+
"ss_lowram": args.lowram,
|
788 |
+
"ss_noise_offset": args.noise_offset,
|
789 |
+
"ss_multires_noise_iterations": args.multires_noise_iterations,
|
790 |
+
"ss_multires_noise_discount": args.multires_noise_discount,
|
791 |
+
"ss_adaptive_noise_scale": args.adaptive_noise_scale,
|
792 |
+
"ss_zero_terminal_snr": args.zero_terminal_snr,
|
793 |
+
"ss_training_comment": args.training_comment, # will not be updated after training
|
794 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
795 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
796 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
797 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
798 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
799 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
800 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
801 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
802 |
+
"ss_min_snr_gamma": args.min_snr_gamma,
|
803 |
+
"ss_scale_weight_norms": args.scale_weight_norms,
|
804 |
+
"ss_ip_noise_gamma": args.ip_noise_gamma,
|
805 |
+
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
|
806 |
+
"ss_noise_offset_random_strength": args.noise_offset_random_strength,
|
807 |
+
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
|
808 |
+
"ss_loss_type": args.loss_type,
|
809 |
+
"ss_huber_schedule": args.huber_schedule,
|
810 |
+
"ss_huber_c": args.huber_c,
|
811 |
+
"ss_fp8_base": bool(args.fp8_base),
|
812 |
+
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
813 |
+
}
|
814 |
+
|
815 |
+
self.update_metadata(metadata, args) # architecture specific metadata
|
816 |
+
|
817 |
+
if use_user_config:
|
818 |
+
# save metadata of multiple datasets
|
819 |
+
# NOTE: pack "ss_datasets" value as json one time
|
820 |
+
# or should also pack nested collections as json?
|
821 |
+
datasets_metadata = []
|
822 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
823 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
824 |
+
|
825 |
+
for dataset in train_dataset_group.datasets:
|
826 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
827 |
+
dataset_metadata = {
|
828 |
+
"is_dreambooth": is_dreambooth_dataset,
|
829 |
+
"batch_size_per_device": dataset.batch_size,
|
830 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
831 |
+
"num_reg_images": dataset.num_reg_images,
|
832 |
+
"resolution": (dataset.width, dataset.height),
|
833 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
834 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
835 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
836 |
+
"tag_frequency": dataset.tag_frequency,
|
837 |
+
"bucket_info": dataset.bucket_info,
|
838 |
+
}
|
839 |
+
|
840 |
+
subsets_metadata = []
|
841 |
+
for subset in dataset.subsets:
|
842 |
+
subset_metadata = {
|
843 |
+
"img_count": subset.img_count,
|
844 |
+
"num_repeats": subset.num_repeats,
|
845 |
+
"color_aug": bool(subset.color_aug),
|
846 |
+
"flip_aug": bool(subset.flip_aug),
|
847 |
+
"random_crop": bool(subset.random_crop),
|
848 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
849 |
+
"keep_tokens": subset.keep_tokens,
|
850 |
+
"keep_tokens_separator": subset.keep_tokens_separator,
|
851 |
+
"secondary_separator": subset.secondary_separator,
|
852 |
+
"enable_wildcard": bool(subset.enable_wildcard),
|
853 |
+
"caption_prefix": subset.caption_prefix,
|
854 |
+
"caption_suffix": subset.caption_suffix,
|
855 |
+
}
|
856 |
+
|
857 |
+
image_dir_or_metadata_file = None
|
858 |
+
if subset.image_dir:
|
859 |
+
image_dir = os.path.basename(subset.image_dir)
|
860 |
+
subset_metadata["image_dir"] = image_dir
|
861 |
+
image_dir_or_metadata_file = image_dir
|
862 |
+
|
863 |
+
if is_dreambooth_dataset:
|
864 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
865 |
+
subset_metadata["is_reg"] = subset.is_reg
|
866 |
+
if subset.is_reg:
|
867 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
868 |
+
else:
|
869 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
870 |
+
subset_metadata["metadata_file"] = metadata_file
|
871 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
872 |
+
|
873 |
+
subsets_metadata.append(subset_metadata)
|
874 |
+
|
875 |
+
# merge dataset dir: not reg subset only
|
876 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
877 |
+
if image_dir_or_metadata_file is not None:
|
878 |
+
# datasets may have a certain dir multiple times
|
879 |
+
v = image_dir_or_metadata_file
|
880 |
+
i = 2
|
881 |
+
while v in dataset_dirs_info:
|
882 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
883 |
+
i += 1
|
884 |
+
image_dir_or_metadata_file = v
|
885 |
+
|
886 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
887 |
+
"n_repeats": subset.num_repeats,
|
888 |
+
"img_count": subset.img_count,
|
889 |
+
}
|
890 |
+
|
891 |
+
dataset_metadata["subsets"] = subsets_metadata
|
892 |
+
datasets_metadata.append(dataset_metadata)
|
893 |
+
|
894 |
+
# merge tag frequency:
|
895 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
896 |
+
# if a directory is used by multiple datasets, count it only once
|
897 |
+
# because repeats are specified per dataset, the number of times a tag appears in captions does not match how often it is used in training,
|
898 |
+
# so summing the counts across datasets here would not be very meaningful
|
899 |
+
if ds_dir_name in tag_frequency:
|
900 |
+
continue
|
901 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
902 |
+
|
903 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
904 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
905 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
906 |
+
else:
|
907 |
+
# preserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
908 |
+
assert (
|
909 |
+
len(train_dataset_group.datasets) == 1
|
910 |
+
), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
911 |
+
|
912 |
+
dataset = train_dataset_group.datasets[0]
|
913 |
+
|
914 |
+
dataset_dirs_info = {}
|
915 |
+
reg_dataset_dirs_info = {}
|
916 |
+
if use_dreambooth_method:
|
917 |
+
for subset in dataset.subsets:
|
918 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
919 |
+
info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
|
920 |
+
else:
|
921 |
+
for subset in dataset.subsets:
|
922 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
923 |
+
"n_repeats": subset.num_repeats,
|
924 |
+
"img_count": subset.img_count,
|
925 |
+
}
|
926 |
+
|
927 |
+
metadata.update(
|
928 |
+
{
|
929 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
930 |
+
"ss_total_batch_size": total_batch_size,
|
931 |
+
"ss_resolution": args.resolution,
|
932 |
+
"ss_color_aug": bool(args.color_aug),
|
933 |
+
"ss_flip_aug": bool(args.flip_aug),
|
934 |
+
"ss_random_crop": bool(args.random_crop),
|
935 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
936 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
937 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
938 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
939 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
940 |
+
"ss_keep_tokens": args.keep_tokens,
|
941 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
942 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
943 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
944 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
945 |
+
}
|
946 |
+
)
|
947 |
+
|
948 |
+
# add extra args
|
949 |
+
if args.network_args:
|
950 |
+
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
951 |
+
|
952 |
+
# model name and hash
|
953 |
+
if args.pretrained_model_name_or_path is not None:
|
954 |
+
sd_model_name = args.pretrained_model_name_or_path
|
955 |
+
if os.path.exists(sd_model_name):
|
956 |
+
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
957 |
+
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
|
958 |
+
sd_model_name = os.path.basename(sd_model_name)
|
959 |
+
metadata["ss_sd_model_name"] = sd_model_name
|
960 |
+
|
961 |
+
if args.vae is not None:
|
962 |
+
vae_name = args.vae
|
963 |
+
if os.path.exists(vae_name):
|
964 |
+
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
965 |
+
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
|
966 |
+
vae_name = os.path.basename(vae_name)
|
967 |
+
metadata["ss_vae_name"] = vae_name
|
968 |
+
|
969 |
+
metadata = {k: str(v) for k, v in metadata.items()}
|
970 |
+
|
971 |
+
# make minimum metadata for filtering
|
972 |
+
minimum_metadata = {}
|
973 |
+
for key in train_util.SS_METADATA_MINIMUM_KEYS:
|
974 |
+
if key in metadata:
|
975 |
+
minimum_metadata[key] = metadata[key]
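# when --no_metadata is set, only these minimum keys are saved with the model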
|
976 |
+
|
977 |
+
# calculate steps to skip when resuming or starting from a specific step
|
978 |
+
initial_step = 0
|
979 |
+
if args.initial_epoch is not None or args.initial_step is not None:
|
980 |
+
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
|
981 |
+
if steps_from_state is not None:
|
982 |
+
logger.warning(
|
983 |
+
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
|
984 |
+
)
|
985 |
+
if args.initial_step is not None:
|
986 |
+
initial_step = args.initial_step
|
987 |
+
else:
|
988 |
+
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
|
989 |
+
initial_step = (args.initial_epoch - 1) * math.ceil(
|
990 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
991 |
+
)
|
992 |
+
else:
|
993 |
+
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
|
994 |
+
if steps_from_state is not None:
|
995 |
+
initial_step = steps_from_state
|
996 |
+
steps_from_state = None
|
997 |
+
|
998 |
+
if initial_step > 0:
|
999 |
+
assert (
|
1000 |
+
args.max_train_steps > initial_step
|
1001 |
+
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
1002 |
+
|
1003 |
+
progress_bar = tqdm(
|
1004 |
+
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
1005 |
+
)
|
1006 |
+
|
1007 |
+
epoch_to_start = 0
|
1008 |
+
if initial_step > 0:
|
1009 |
+
if args.skip_until_initial_step:
|
1010 |
+
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
|
1011 |
+
if not args.resume:
|
1012 |
+
logger.info(
|
1013 |
+
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
|
1014 |
+
)
|
1015 |
+
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
1016 |
+
initial_step *= args.gradient_accumulation_steps
|
1017 |
+
|
1018 |
+
# set epoch to start to make initial_step less than len(train_dataloader)
|
1019 |
+
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1020 |
+
else:
|
1021 |
+
# otherwise, only the epoch number is skipped, for informational purposes
|
1022 |
+
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1023 |
+
initial_step = 0 # do not skip
|
1024 |
+
|
1025 |
+
global_step = 0
|
1026 |
+
|
1027 |
+
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
|
1028 |
+
|
1029 |
+
if accelerator.is_main_process:
|
1030 |
+
init_kwargs = {}
|
1031 |
+
if args.wandb_run_name:
|
1032 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
1033 |
+
if args.log_tracker_config is not None:
|
1034 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
1035 |
+
accelerator.init_trackers(
|
1036 |
+
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
1037 |
+
config=train_util.get_sanitized_config_or_none(args),
|
1038 |
+
init_kwargs=init_kwargs,
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
loss_recorder = train_util.LossRecorder()
|
1042 |
+
del train_dataset_group
|
1043 |
+
|
1044 |
+
# callback for step start
|
1045 |
+
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
1046 |
+
on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
|
1047 |
+
else:
|
1048 |
+
on_step_start_for_network = lambda *args, **kwargs: None
|
1049 |
+
|
1050 |
+
# function for saving/removing
|
1051 |
+
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
1052 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1053 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1054 |
+
|
1055 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
1056 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
1057 |
+
metadata["ss_steps"] = str(steps)
|
1058 |
+
metadata["ss_epoch"] = str(epoch_no)
|
1059 |
+
|
1060 |
+
metadata_to_save = minimum_metadata if args.no_metadata else metadata
|
1061 |
+
sai_metadata = self.get_sai_model_spec(args)
|
1062 |
+
metadata_to_save.update(sai_metadata)
|
1063 |
+
|
1064 |
+
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
1065 |
+
if args.huggingface_repo_id is not None:
|
1066 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
1067 |
+
|
1068 |
+
def remove_model(old_ckpt_name):
|
1069 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
1070 |
+
if os.path.exists(old_ckpt_file):
|
1071 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
1072 |
+
os.remove(old_ckpt_file)
|
1073 |
+
|
1074 |
+
# if text_encoder is not needed for training, delete it to save memory.
|
1075 |
+
# TODO this can be automated after SDXL sample prompt cache is implemented
|
1076 |
+
if self.is_text_encoder_not_needed_for_training(args):
|
1077 |
+
logger.info("text_encoder is not needed for training. deleting to save memory.")
|
1078 |
+
for t_enc in text_encoders:
|
1079 |
+
del t_enc
|
1080 |
+
text_encoders = []
|
1081 |
+
text_encoder = None
|
1082 |
+
|
1083 |
+
# For --sample_at_first
|
1084 |
+
optimizer_eval_fn()
|
1085 |
+
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
1086 |
+
optimizer_train_fn()
|
1087 |
+
if len(accelerator.trackers) > 0:
|
1088 |
+
# log empty object to commit the sample images to wandb
|
1089 |
+
accelerator.log({}, step=0)
|
1090 |
+
|
1091 |
+
# training loop
|
1092 |
+
if initial_step > 0: # only if skip_until_initial_step is specified
|
1093 |
+
for skip_epoch in range(epoch_to_start): # skip epochs
|
1094 |
+
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
1095 |
+
initial_step -= len(train_dataloader)
|
1096 |
+
global_step = initial_step
|
1097 |
+
|
1098 |
+
# log device and dtype for each model
|
1099 |
+
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
1100 |
+
for i, t_enc in enumerate(text_encoders):
|
1101 |
+
params_itr = t_enc.parameters()
|
1102 |
+
params_itr.__next__() # skip the first parameter
|
1103 |
+
params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings
|
1104 |
+
param_3rd = params_itr.__next__()
|
1105 |
+
logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
|
1106 |
+
|
1107 |
+
clean_memory_on_device(accelerator.device)
|
1108 |
+
|
1109 |
+
for epoch in range(epoch_to_start, num_train_epochs):
|
1110 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
1111 |
+
current_epoch.value = epoch + 1
|
1112 |
+
|
1113 |
+
metadata["ss_epoch"] = str(epoch + 1)
|
1114 |
+
|
1115 |
+
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
1116 |
+
|
1117 |
+
skipped_dataloader = None
|
1118 |
+
if initial_step > 0:
|
1119 |
+
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
1120 |
+
initial_step = 1
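# skip_first_batches drops initial_step - 1 batches; the remaining one is consumed by the "continue" below so the step counter stays aligned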
|
1121 |
+
|
1122 |
+
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
1123 |
+
current_step.value = global_step
|
1124 |
+
if initial_step > 0:
|
1125 |
+
initial_step -= 1
|
1126 |
+
continue
|
1127 |
+
|
1128 |
+
with accelerator.accumulate(training_model):
|
1129 |
+
on_step_start_for_network(text_encoder, unet)
|
1130 |
+
|
1131 |
+
# temporary, for batch processing
|
1132 |
+
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
1133 |
+
|
1134 |
+
if "latents" in batch and batch["latents"] is not None:
|
1135 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
1136 |
+
else:
|
1137 |
+
with torch.no_grad():
|
1138 |
+
# convert images to latents
|
1139 |
+
latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
|
1140 |
+
latents = latents.to(dtype=weight_dtype)
|
1141 |
+
|
1142 |
+
# if latents contain NaN, show a warning and replace them with zeros
|
1143 |
+
if torch.any(torch.isnan(latents)):
|
1144 |
+
accelerator.print("NaN found in latents, replacing with zeros")
|
1145 |
+
latents = torch.nan_to_num(latents, 0, out=latents)
|
1146 |
+
|
1147 |
+
latents = self.shift_scale_latents(args, latents)
|
1148 |
+
|
1149 |
+
# get multiplier for each sample
|
1150 |
+
if network_has_multiplier:
|
1151 |
+
multipliers = batch["network_multipliers"]
|
1152 |
+
# if all multipliers are same, use single multiplier
|
1153 |
+
if torch.all(multipliers == multipliers[0]):
|
1154 |
+
multipliers = multipliers[0].item()
|
1155 |
+
else:
|
1156 |
+
raise NotImplementedError("multipliers for each sample is not supported yet")
|
1157 |
+
# print(f"set multiplier: {multipliers}")
|
1158 |
+
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
1159 |
+
|
1160 |
+
text_encoder_conds = []
|
1161 |
+
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
1162 |
+
if text_encoder_outputs_list is not None:
|
1163 |
+
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
1164 |
+
|
1165 |
+
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
1166 |
+
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
1167 |
+
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
1168 |
+
# Get the text embedding for conditioning
|
1169 |
+
if args.weighted_captions:
|
1170 |
+
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
1171 |
+
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
1172 |
+
tokenize_strategy,
|
1173 |
+
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
1174 |
+
input_ids_list,
|
1175 |
+
weights_list,
|
1176 |
+
)
|
1177 |
+
else:
|
1178 |
+
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
1179 |
+
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
1180 |
+
tokenize_strategy,
|
1181 |
+
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
1182 |
+
input_ids,
|
1183 |
+
)
|
1184 |
+
if args.full_fp16:
|
1185 |
+
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
1186 |
+
|
1187 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
1188 |
+
if len(text_encoder_conds) == 0:
|
1189 |
+
text_encoder_conds = encoded_text_encoder_conds
|
1190 |
+
else:
|
1191 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
1192 |
+
for i in range(len(encoded_text_encoder_conds)):
|
1193 |
+
if encoded_text_encoder_conds[i] is not None:
|
1194 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
1195 |
+
|
1196 |
+
# sample noise, call unet, get target
|
1197 |
+
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
1198 |
+
args,
|
1199 |
+
accelerator,
|
1200 |
+
noise_scheduler,
|
1201 |
+
latents,
|
1202 |
+
batch,  # this carries the caption/text information
|
1203 |
+
text_encoder_conds,
|
1204 |
+
unet,
|
1205 |
+
network,
|
1206 |
+
weight_dtype,
|
1207 |
+
train_unet,
|
1208 |
+
)
|
1209 |
+
|
1210 |
+
loss = train_util.conditional_loss(
|
1211 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
1212 |
+
)
|
1213 |
+
if weighting is not None:
|
1214 |
+
loss = loss * weighting
|
1215 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
1216 |
+
loss = apply_masked_loss(loss, batch)
|
1217 |
+
loss = loss.mean([1, 2, 3])
|
1218 |
+
|
1219 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
1220 |
+
loss = loss * loss_weights
|
1221 |
+
|
1222 |
+
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
1223 |
+
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
1224 |
+
|
1225 |
+
loss = loss.mean()  # mean over the batch, so no need to divide by batch_size
|
1226 |
+
|
1227 |
+
accelerator.backward(loss)
|
1228 |
+
if accelerator.sync_gradients:
|
1229 |
+
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
1230 |
+
if args.max_grad_norm != 0.0:
|
1231 |
+
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
1232 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
1233 |
+
|
1234 |
+
optimizer.step()
|
1235 |
+
lr_scheduler.step()
|
1236 |
+
optimizer.zero_grad(set_to_none=True)
|
1237 |
+
|
1238 |
+
if args.scale_weight_norms:
|
1239 |
+
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
1240 |
+
args.scale_weight_norms, accelerator.device
|
1241 |
+
)
|
1242 |
+
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
1243 |
+
else:
|
1244 |
+
keys_scaled, mean_norm, maximum_norm = None, None, None
|
1245 |
+
|
1246 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1247 |
+
if accelerator.sync_gradients:
|
1248 |
+
progress_bar.update(1)
|
1249 |
+
global_step += 1
|
1250 |
+
|
1251 |
+
optimizer_eval_fn()
|
1252 |
+
self.sample_images(
|
1253 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
1254 |
+
)
|
1255 |
+
|
1256 |
+
# save the model every specified number of steps
|
1257 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
1258 |
+
accelerator.wait_for_everyone()
|
1259 |
+
if accelerator.is_main_process:
|
1260 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
1261 |
+
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
1262 |
+
|
1263 |
+
if args.save_state:
|
1264 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
1265 |
+
|
1266 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
1267 |
+
if remove_step_no is not None:
|
1268 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
1269 |
+
remove_model(remove_ckpt_name)
|
1270 |
+
optimizer_train_fn()
|
1271 |
+
|
                current_loss = loss.detach().item()
                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                avr_loss: float = loss_recorder.moving_average
                logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)

                if args.scale_weight_norms:
                    progress_bar.set_postfix(**{**max_mean_logs, **logs})

                if len(accelerator.trackers) > 0:
                    logs = self.generate_step_logs(
                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
                    )
                    accelerator.log(logs, step=global_step)

                if global_step >= args.max_train_steps:
                    break

            if len(accelerator.trackers) > 0:
                logs = {"loss/epoch": loss_recorder.moving_average}
                accelerator.log(logs, step=epoch + 1)

            accelerator.wait_for_everyone()

            # save the model every specified number of epochs
            optimizer_eval_fn()
            if args.save_every_n_epochs is not None:
                saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
                if is_main_process and saving:
                    ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
                    save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)

                    remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                    if remove_epoch_no is not None:
                        remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
                        remove_model(remove_ckpt_name)

                    if args.save_state:
                        train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
            optimizer_train_fn()

            # end of epoch

        # metadata["ss_epoch"] = str(num_train_epochs)
        metadata["ss_training_finished_at"] = str(time.time())

        if is_main_process:
            network = accelerator.unwrap_model(network)

        accelerator.end_training()
        optimizer_eval_fn()

        if is_main_process and (args.save_state or args.save_state_on_train_end):
            train_util.save_state_on_train_end(args, accelerator)

        if is_main_process:
            ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
            save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)

            logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, True)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)

    parser.add_argument(
        "--cpu_offload_checkpointing",
        action="store_true",
        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
        " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
    )
    parser.add_argument(
        "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
    )
    parser.add_argument(
        "--save_model_as",
        type=str,
        default="safetensors",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
    )

    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
    parser.add_argument(
        "--text_encoder_lr",
        type=float,
        default=None,
        nargs="*",
        help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
    )
    parser.add_argument(
        "--fp8_base_unet",
        action="store_true",
        help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
        " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
    )

    parser.add_argument(
        "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
    )
    parser.add_argument(
        "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
    )
    parser.add_argument(
        "--network_dim",
        type=int,
        default=None,
        help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
    )
    parser.add_argument(
        "--network_alpha",
        type=float,
        default=1,
        help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
    )
    parser.add_argument(
        "--network_dropout",
        type=float,
        default=None,
        help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
    )
    parser.add_argument(
        "--network_args",
        type=str,
        default=None,
        nargs="*",
        help="additional arguments for network (key=value) / ネットワークへの追加の引数",
    )
    parser.add_argument(
        "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
    )
    parser.add_argument(
        "--network_train_text_encoder_only",
        action="store_true",
        help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
    )
    parser.add_argument(
        "--training_comment",
        type=str,
        default=None,
        help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
    )
    parser.add_argument(
        "--dim_from_weights",
        action="store_true",
        help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
    )
    parser.add_argument(
        "--scale_weight_norms",
        type=float,
        default=None,
        help="Scale the weight of each key pair to help prevent overtraining via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
    )
    parser.add_argument(
        "--base_weights",
        type=str,
        default=None,
        nargs="*",
        help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
    )
    parser.add_argument(
        "--base_weights_multiplier",
        type=float,
        default=None,
        nargs="*",
        help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
    )
    parser.add_argument(
        "--no_half_vae",
        action="store_true",
        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
    )
    parser.add_argument(
        "--skip_until_initial_step",
        action="store_true",
        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
    )
    parser.add_argument(
        "--initial_epoch",
        type=int,
        default=None,
        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect the lr scheduler, which means the lr scheduler will start from 0 without `--resume`."
        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
    )
    parser.add_argument(
        "--initial_step",
        type=int,
        default=None,
        help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
        + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
    )
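    # Editor's note (added comment, not in the original script): --lora_ups_num is the argument specific
    # to this asymmetric-LoRA trainer; per its help text it sets how many downstream (up-projection)
    # matrices are initialized for each LoRA module.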
    parser.add_argument(
        "--lora_ups_num",
        type=int,
        required=True,  # this argument must be provided
        help="number of downstream (up) matrices to initialize for the LoRA"
    )
    return parser


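# Editor's note (illustrative only, not part of the original file): a typical invocation might look like
# the following, assuming the standard sd-scripts arguments added above via train_util/config_util
# (paths and the network module name are placeholders):
#
#   accelerate launch train_network_asylora.py \
#       --pretrained_model_name_or_path <base_model> \
#       --dataset_config <dataset.toml> \
#       --output_dir <output_dir> --output_name my_asylora \
#       --network_module <asylora_network_module> --network_dim 16 --network_alpha 16 \
#       --lora_ups_num 2 --save_model_as safetensors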
if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = NetworkTrainer()
    trainer.train(args)