Commit
·
94a2309
0
Parent(s):
Fresh start
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +39 -0
- .gitignore +13 -0
- README.md +178 -0
- TRAIN.md +44 -0
- budget.jpg +3 -0
- cherrypick-vavae.ipynb +3 -0
- dataset_fromfolder.py +386 -0
- model_index.json +3 -0
- pipeline_sdxs.py +295 -0
- promo.png +3 -0
- requirements.txt +11 -0
- result_grid.jpg +3 -0
- samples/unet_192x384_0.jpg +3 -0
- samples/unet_256x384_0.jpg +3 -0
- samples/unet_320x384_0.jpg +3 -0
- samples/unet_384x192_0.jpg +3 -0
- samples/unet_384x256_0.jpg +3 -0
- samples/unet_384x320_0.jpg +3 -0
- samples/unet_384x384_0.jpg +3 -0
- scheduler/scheduler_config.json +3 -0
- src/captions_moondream2.ipynb +3 -0
- src/captions_moondream2_wd3.ipynb +3 -0
- src/captions_qwen2-vl-7b.py +261 -0
- src/captions_wd.ipynb +3 -0
- src/cherrypick.ipynb +3 -0
- src/cuda.ipynb +3 -0
- src/dataset_clean.ipynb +3 -0
- src/dataset_combine.py +68 -0
- src/dataset_fromzip.ipynb +3 -0
- src/dataset_imagenet.ipynb +3 -0
- src/dataset_laion_coco.ipynb +3 -0
- src/dataset_mjnj.ipynb +3 -0
- src/dataset_mnist-te.ipynb +3 -0
- src/dataset_mnist.ipynb +3 -0
- src/dataset_sample.ipynb +3 -0
- src/inference.ipynb +3 -0
- src/sdxs_create-vavae.ipynb +3 -0
- src/sdxs_create.ipynb +3 -0
- src/sdxs_create_simple.ipynb +3 -0
- src/sdxs_create_unet.ipynb +3 -0
- src/sdxs_sdxxs_transfer.ipynb +3 -0
- test.ipynb +3 -0
- text_encoder/config.json +3 -0
- text_encoder/model.fp16.safetensors +3 -0
- text_projector/config.json +3 -0
- text_projector/model.safetensors +3 -0
- tokenizer/special_tokens_map.json +3 -0
- tokenizer/tokenizer.json +3 -0
- tokenizer/tokenizer_config.json +3 -0
- train-Copy1.py +789 -0
.gitattributes
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.ipynb filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Jupyter Notebook
|
2 |
+
__pycache__/
|
3 |
+
*.pyc
|
4 |
+
.ipynb_checkpoints/
|
5 |
+
*.ipynb_checkpoints/*
|
6 |
+
.ipynb_checkpoints/*
|
7 |
+
src/samples
|
8 |
+
# cache
|
9 |
+
cache
|
10 |
+
datasets
|
11 |
+
test
|
12 |
+
wandb
|
13 |
+
nohup.out
|
README.md
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
pipeline_tag: text-to-image
|
4 |
+
---
|
5 |
+
|
6 |
+
# Simple Diffusion XS
|
7 |
+
|
8 |
+
*XS Size, Excess Quality*
|
9 |
+
|
10 |
+
|
11 |
+
Train status, in progress: [wandb](https://wandb.ai/recoilme/unet)
|
12 |
+
|
13 |
+

|
14 |
+
|
15 |
+
## Example
|
16 |
+
|
17 |
+
```python
|
18 |
+
import torch
|
19 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
20 |
+
from transformers import AutoModel, AutoTokenizer
|
21 |
+
from PIL import Image
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
import os
|
24 |
+
|
25 |
+
def encode_prompt(prompt, negative_prompt, device, dtype):
|
26 |
+
if negative_prompt is None:
|
27 |
+
negative_prompt = ""
|
28 |
+
|
29 |
+
with torch.no_grad():
|
30 |
+
positive_inputs = tokenizer(
|
31 |
+
prompt,
|
32 |
+
return_tensors="pt",
|
33 |
+
padding="max_length",
|
34 |
+
max_length=512,
|
35 |
+
truncation=True,
|
36 |
+
).to(device)
|
37 |
+
positive_embeddings = text_model.encode_texts(
|
38 |
+
positive_inputs.input_ids, positive_inputs.attention_mask
|
39 |
+
)
|
40 |
+
if positive_embeddings.ndim == 2:
|
41 |
+
positive_embeddings = positive_embeddings.unsqueeze(1)
|
42 |
+
positive_embeddings = positive_embeddings.to(device, dtype=dtype)
|
43 |
+
|
44 |
+
negative_inputs = tokenizer(
|
45 |
+
negative_prompt,
|
46 |
+
return_tensors="pt",
|
47 |
+
padding="max_length",
|
48 |
+
max_length=512,
|
49 |
+
truncation=True,
|
50 |
+
).to(device)
|
51 |
+
negative_embeddings = text_model.encode_texts(negative_inputs.input_ids, negative_inputs.attention_mask)
|
52 |
+
if negative_embeddings.ndim == 2:
|
53 |
+
negative_embeddings = negative_embeddings.unsqueeze(1)
|
54 |
+
negative_embeddings = negative_embeddings.to(device, dtype=dtype)
|
55 |
+
return torch.cat([negative_embeddings, positive_embeddings], dim=0)
|
56 |
+
|
57 |
+
def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5):
|
58 |
+
with torch.no_grad():
|
59 |
+
device, dtype = embeddings.device, embeddings.dtype
|
60 |
+
half = embeddings.shape[0] // 2
|
61 |
+
latent_shape = (half, 16, height // 8, width // 8)
|
62 |
+
latents = torch.randn(latent_shape, device=device, dtype=dtype)
|
63 |
+
embeddings = embeddings.repeat_interleave(half, dim=0)
|
64 |
+
|
65 |
+
scheduler.set_timesteps(num_inference_steps)
|
66 |
+
|
67 |
+
for t in tqdm(scheduler.timesteps, desc="Генерация"):
|
68 |
+
latent_model_input = torch.cat([latents] * 2)
|
69 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
70 |
+
noise_pred = unet(latent_model_input, t, embeddings).sample
|
71 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
72 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
73 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
74 |
+
return latents
|
75 |
+
|
76 |
+
|
77 |
+
def decode_latents(latents, vae, output_type="pil"):
|
78 |
+
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
|
79 |
+
with torch.no_grad():
|
80 |
+
images = vae.decode(latents).sample
|
81 |
+
images = (images / 2 + 0.5).clamp(0, 1)
|
82 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
83 |
+
if output_type == "pil":
|
84 |
+
images = (images * 255).round().astype("uint8")
|
85 |
+
images = [Image.fromarray(image) for image in images]
|
86 |
+
return images
|
87 |
+
|
88 |
+
# Example usage:
|
89 |
+
if __name__ == "__main__":
|
90 |
+
device = "cuda"
|
91 |
+
dtype = torch.float16
|
92 |
+
|
93 |
+
prompt = "кот"
|
94 |
+
negative_prompt = "bad quality"
|
95 |
+
tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
|
96 |
+
text_model = AutoModel.from_pretrained(
|
97 |
+
"visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
|
98 |
+
).to(device, dtype=dtype).eval()
|
99 |
+
|
100 |
+
embeddings = encode_prompt(prompt, negative_prompt, device, dtype)
|
101 |
+
|
102 |
+
pipeid = "AiArtLab/sdxs"
|
103 |
+
variant = "fp16"
|
104 |
+
|
105 |
+
unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
|
106 |
+
vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
|
107 |
+
scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")
|
108 |
+
|
109 |
+
|
110 |
+
height, width = 576, 576
|
111 |
+
num_inference_steps = 40
|
112 |
+
output_folder, project_name = "samples", "sdxs"
|
113 |
+
latents = generate_latents(
|
114 |
+
embeddings=embeddings,
|
115 |
+
height=height,
|
116 |
+
width=width,
|
117 |
+
num_inference_steps = num_inference_steps
|
118 |
+
)
|
119 |
+
|
120 |
+
images = decode_latents(latents, vae)
|
121 |
+
|
122 |
+
os.makedirs(output_folder, exist_ok=True)
|
123 |
+
for idx, image in enumerate(images):
|
124 |
+
image.save(f"{output_folder}/{project_name}_{idx}.jpg")
|
125 |
+
|
126 |
+
print("Images generated and saved to:", output_folder)
|
127 |
+
```
|
128 |
+
|
129 |
+
## Introduction
|
130 |
+
*Fast, Lightweight & Multilingual Diffusion for Everyone*
|
131 |
+
|
132 |
+
We are **AiArtLab**, a small team of enthusiasts with a limited budget. Our goal is to create a compact and fast model that can be trained on consumer graphics cards (full training cycle, not LoRA). We chose U-Net for its ability to efficiently handle small datasets and train quickly even on a 16GB GPU (e.g., RTX 4080). Our budget was limited to a few thousand dollars, significantly less than competitors like SDXL (tens of millions), so we decided to create a small but efficient model, similar to SD1.5 but for 2025 year.
|
133 |
+
|
134 |
+
## Encoder Architecture (Text and Images)
|
135 |
+
We experimented with various encoders and concluded that large models like LLaMA or T5 XXL are unnecessary for high-quality generation. However, we needed an encoder that understands the context of the query, focusing on "prompt understanding" versus "prompt following." We chose the multilingual encoder Mexma-SigLIP, which supports 80 languages and processes sentences rather than individual tokens. Mexma accepts up to 512 tokens, creating a large matrix that slows down training. Therefore, we used a pooling layer to simplify 512x1152 matrix with plain 1x1152 vector. Specifically, we passed it through a linear model/text projector to achieve compatibility with SigLIP embeddings. This allowed us to synchronize text embeddings with images, potentially leading to a unified multimodal model. This functionality enables mixing image embeddings with textual descriptions in queries. Moreover, the model can be trained without text descriptions, using only images. This should simplify training on videos, where annotation is challenging, and achieve more consistent and seamless video generation by inputting embeddings of previous frames with decay. In the future, we aim to expand the model to 3D/video generation.
|
136 |
+
|
137 |
+
## U-Net Architecture
|
138 |
+
We chose a smooth channel pyramid: [384, 576, 768, 960] with two layers per block and [4, 6, 8, 10] transformers with 1152/48=24 attention heads. This architecture provides the highest training speed with a model size of around 2 billion parameters (and fitting perfectly in my RTX 4080). We believe that due to its greater 'depth,' the quality will be on par with SDXL despite the smaller 'size.' The model can be expanded to 4 billion parameters by adding an 1152 layer, achieving perfect symmetry with the embedding size, which we value for its elegance, and probably 'Flux/MJ level' quality.
|
139 |
+
|
140 |
+
## VAE Architecture
|
141 |
+
We chose an unconventional 8x 16-channel AuraDiffusion VAE, which preserves details, text, and anatomy without the 'haze' characteristic of SD3/Flux. We used a fast version with FFN convolution, observing minor texture damage on fine patterns, which may lower its rating on benchmarks. Upscalers like ESRGAN can address these artifacts. Overall, we believe this VAE is highly underrated."
|
142 |
+
|
143 |
+
## Training Process
|
144 |
+
### Optimizer
|
145 |
+
We tested several optimizers (AdamW, Laion, Optimi-AdamW, Adafactor, and AdamW-8bit) and chose AdamW-8bit. Optimi-AdamW demonstrated the smoothest gradient decay curve, although AdamW-8bit behaves more chaotically. However, its smaller size allows for larger batch sizes, maximizing training speed on low-cost GPUs (we used 4xA6000 and 5xL40s for training).
|
146 |
+
|
147 |
+
### Learning Rate
|
148 |
+
We found that manipulating the decay/warm-up curve has an effect but is not significant. The optimal learning rate is often overestimated. Our experiments showed that Adam allows for a wide learning rate range. We started at 1e-4, gradually decreasing to 1e-6 during training. In other words, choosing the correct model architecture is far more critical than tweaking hyperparameters.
|
149 |
+
|
150 |
+
### Dataset
|
151 |
+
We trained the model on approximately 1 million images: 60 epochs on ImageNet at 256 resolution (wasted time because of low-quality annotations) and 8 epochs on CaptionEmporium/midjourney-niji-1m-llavanext, plus realistic photos and anime/art at 576 resolution. We used human prompts, Caption Emporium provided prompts, WD-Tagger from SmilingWolf, and Moondream2 for annotation, varying prompt length and composition to ensure the model understands different prompting styles. The dataset is extremely small, leading the model to miss many entities and struggle with unseen concepts like 'a goose on a bicycle.' The dataset also included many waifu-style images, as we were interested in how well the model learns human anatomy rather than drawing 'The Astronaut on horseback' skills. While most descriptions were in English, our tests indicate the model is multilingual.
|
152 |
+
|
153 |
+
## Limitations
|
154 |
+
- Limited concept coverage due to the extremely small dataset.
|
155 |
+
- The Image2Image functionality needs further training (we reduced the SigLIP portion to 5% to focus on text-to-image training).
|
156 |
+
|
157 |
+
## Acknowledgments
|
158 |
+
- **[Stan](https://t.me/Stangle)** — Key investor. Primary financial support - thank you for believing in us when others called it madness.
|
159 |
+
- **Captainsaturnus** — Material support.
|
160 |
+
- **Lovescape** & **Whargarbl** — Moral support.
|
161 |
+
- **[CaptionEmporium](https://huggingface.co/CaptionEmporium)** — Datasets.
|
162 |
+
|
163 |
+
> "We believe the future lies in efficient, compact models. We are grateful for the donations and hope for your continued support."
|
164 |
+
|
165 |
+
## Training budget
|
166 |
+
|
167 |
+

|
168 |
+
|
169 |
+
## Donations
|
170 |
+
|
171 |
+
Please contact with us if you may provide some GPU's or money on training
|
172 |
+
|
173 |
+
DOGE: DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83
|
174 |
+
BTC: 3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN
|
175 |
+
|
176 |
+
## Contacts
|
177 |
+
|
178 |
+
[recoilme](https://t.me/recoilme)
|
TRAIN.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
|
5 |
+
Краткая инструкция по установке
|
6 |
+
Обновите систему и установите git-lfs:
|
7 |
+
|
8 |
+
```
|
9 |
+
apt update
|
10 |
+
apt install git-lfs
|
11 |
+
git config --global credential.helper store
|
12 |
+
```
|
13 |
+
Обновите pip и установите требуемые пакеты:
|
14 |
+
|
15 |
+
```
|
16 |
+
python -m pip install --upgrade pip
|
17 |
+
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 -U
|
18 |
+
pip install flash-attn --no-build-isolation # optional
|
19 |
+
```
|
20 |
+
Клонируйте репозиторий:
|
21 |
+
|
22 |
+
```
|
23 |
+
git clone https://huggingface.co/AiArtLab/sdxs
|
24 |
+
cd sdxs/
|
25 |
+
pip install -r requirements.txt
|
26 |
+
```
|
27 |
+
Подготовьте датасет:
|
28 |
+
|
29 |
+
```
|
30 |
+
mkdir datasets
|
31 |
+
cd datasets
|
32 |
+
huggingface-cli download AiArtLab/384 --local-dir 384 --repo-type dataset
|
33 |
+
```
|
34 |
+
Выполните вход в сервисы:
|
35 |
+
|
36 |
+
```
|
37 |
+
huggingface-cli login
|
38 |
+
wandb login
|
39 |
+
```
|
40 |
+
Запустите обучение!
|
41 |
+
|
42 |
+
```
|
43 |
+
nohup accelerate launch train.py &
|
44 |
+
```
|
budget.jpg
ADDED
![]() |
Git LFS Details
|
cherrypick-vavae.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d76d1eb4ba99e75f234a07e86ef0503ebe6ddec695c61308b6d20a5a6eca9f0f
|
3 |
+
size 16287
|
dataset_fromfolder.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pip install flash-attn --no-build-isolation
|
2 |
+
from datasets import Dataset, load_from_disk, concatenate_datasets
|
3 |
+
from diffusers import AutoencoderKL
|
4 |
+
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
|
5 |
+
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
import gc
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
from tqdm import tqdm
|
12 |
+
import random
|
13 |
+
import json
|
14 |
+
import shutil
|
15 |
+
import time
|
16 |
+
from datetime import timedelta
|
17 |
+
|
18 |
+
# ---------------- 1️⃣ Настройки ----------------
|
19 |
+
dtype = torch.float16
|
20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
+
batch_size = 10
|
22 |
+
min_size = 192
|
23 |
+
max_size = 384
|
24 |
+
step = 64
|
25 |
+
img_share = 0.05
|
26 |
+
empty_share = 0.05
|
27 |
+
limit = 0
|
28 |
+
textemb_full = False
|
29 |
+
# Основная процедура обработки
|
30 |
+
folder_path = "/workspace/d3"
|
31 |
+
save_path = "/workspace/sdxs/datasets/ds3_384"
|
32 |
+
os.makedirs(save_path, exist_ok=True)
|
33 |
+
|
34 |
+
# Функция для очистки CUDA памяти
|
35 |
+
def clear_cuda_memory():
|
36 |
+
if torch.cuda.is_available():
|
37 |
+
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
38 |
+
print(f"used_gb: {used_gb:.2f} GB")
|
39 |
+
torch.cuda.empty_cache()
|
40 |
+
gc.collect()
|
41 |
+
|
42 |
+
# ---------------- 2️⃣ Загрузка моделей ----------------
|
43 |
+
def load_models():
|
44 |
+
print("Загрузка моделей...")
|
45 |
+
vae = AutoencoderKL.from_pretrained("/workspace/sdxs/vae", variant="fp16",torch_dtype=dtype).to(device).eval()
|
46 |
+
model = AutoModel.from_pretrained("visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True, optimized=True).to(device).eval()
|
47 |
+
processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip", use_fast=True)
|
48 |
+
tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
|
49 |
+
return vae, model, processor, tokenizer
|
50 |
+
|
51 |
+
vae, model, processor, tokenizer = load_models()
|
52 |
+
|
53 |
+
|
54 |
+
# ---------------- 3️⃣ Трансформации ----------------
|
55 |
+
def get_image_transform(min_size=256, max_size=512, step=64):
|
56 |
+
def transform(img, dry_run=False):
|
57 |
+
# Сохраняем исходные размеры изображения
|
58 |
+
original_width, original_height = img.size
|
59 |
+
|
60 |
+
# 0. Ресайз: масштабируем изображение, чтобы максимальная сторона была равна max_size
|
61 |
+
if original_width >= original_height:
|
62 |
+
new_width = max_size
|
63 |
+
new_height = int(max_size * original_height / original_width)
|
64 |
+
else:
|
65 |
+
new_height = max_size
|
66 |
+
new_width = int(max_size * original_width / original_height)
|
67 |
+
|
68 |
+
if new_height < min_size or new_width < min_size:
|
69 |
+
# 1. Ресайз: масштабируем изображение, чтобы минимальная сторона была равна min_size
|
70 |
+
if original_width <= original_height:
|
71 |
+
new_width = min_size
|
72 |
+
new_height = int(min_size * original_height / original_width)
|
73 |
+
else:
|
74 |
+
new_height = min_size
|
75 |
+
new_width = int(min_size * original_width / original_height)
|
76 |
+
|
77 |
+
# 2. Проверка: если одна из сторон превышает max_size, готовимся к обрезке
|
78 |
+
crop_width = min(max_size, (new_width // step) * step)
|
79 |
+
crop_height = min(max_size, (new_height // step) * step)
|
80 |
+
|
81 |
+
# Убеждаемся, что размеры обрезки не меньше min_size
|
82 |
+
crop_width = max(min_size, crop_width)
|
83 |
+
crop_height = max(min_size, crop_height)
|
84 |
+
|
85 |
+
# Если запрошен только предварительный расчёт размеров
|
86 |
+
if dry_run:
|
87 |
+
return crop_width, crop_height
|
88 |
+
|
89 |
+
# Конвертация в RGB и ресайз
|
90 |
+
img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
|
91 |
+
|
92 |
+
# Определение координат обрезки (обрезаем с учетом вотермарок - треть сверху)
|
93 |
+
top = (new_height - crop_height) // 3
|
94 |
+
left = 0
|
95 |
+
|
96 |
+
# Обрезка изображения
|
97 |
+
img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
|
98 |
+
|
99 |
+
# Сохраняем итоговые размеры после всех преобразований
|
100 |
+
final_width, final_height = img_cropped.size
|
101 |
+
|
102 |
+
# тензор
|
103 |
+
img_tensor = ToTensor()(img_cropped)
|
104 |
+
img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
|
105 |
+
return img_tensor, img_cropped, final_width, final_height
|
106 |
+
|
107 |
+
return transform
|
108 |
+
|
109 |
+
# ---------------- 4️⃣ Функции обработки ----------------
|
110 |
+
def encode_images_batch(images, processor, model):
|
111 |
+
pixel_values = torch.stack([processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0) for img in images]).to(device, dtype)
|
112 |
+
|
113 |
+
with torch.inference_mode():
|
114 |
+
image_embeddings = model.vision_model(pixel_values).pooler_output
|
115 |
+
|
116 |
+
return image_embeddings.unsqueeze(1).cpu().numpy()
|
117 |
+
|
118 |
+
def encode_texts_batch(texts, tokenizer, model):
|
119 |
+
with torch.inference_mode():
|
120 |
+
text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",
|
121 |
+
max_length=512,
|
122 |
+
truncation=True).to(device)
|
123 |
+
text_embeddings = model.encode_texts(text_tokenized.input_ids, text_tokenized.attention_mask)
|
124 |
+
return text_embeddings.unsqueeze(1).cpu().numpy()
|
125 |
+
|
126 |
+
def encode_texts_batch_full(texts, tokenizer, model):
|
127 |
+
with torch.inference_mode():
|
128 |
+
text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",max_length=512,truncation=True).to(device)
|
129 |
+
features = model.text_model(
|
130 |
+
input_ids=text_tokenized.input_ids, attention_mask=text_tokenized.attention_mask
|
131 |
+
).last_hidden_state
|
132 |
+
features_proj = model.text_projector(features)
|
133 |
+
return features_proj.cpu().numpy()
|
134 |
+
|
135 |
+
def clean_label(label):
|
136 |
+
label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
|
137 |
+
return label
|
138 |
+
|
139 |
+
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
|
140 |
+
"""
|
141 |
+
Обрабатывает список меток для classifier-free guidance.
|
142 |
+
|
143 |
+
С вероятностью prob_to_make_empty:
|
144 |
+
- Метка в первом списке заменяется на пустую строку.
|
145 |
+
- К метке во втором списке добавляется префикс "zero:".
|
146 |
+
|
147 |
+
В противном случае метки в обоих списках остаются оригинальными.
|
148 |
+
|
149 |
+
"""
|
150 |
+
labels_for_model = []
|
151 |
+
labels_for_logging = []
|
152 |
+
|
153 |
+
for label in original_labels:
|
154 |
+
if random.random() < prob_to_make_empty:
|
155 |
+
labels_for_model.append("") # Заменяем на пустую строку для модели
|
156 |
+
labels_for_logging.append(f"zero: {label}") # Добавляем префикс для логгирования
|
157 |
+
else:
|
158 |
+
labels_for_model.append(label) # Оставляем оригинальную метку для модели
|
159 |
+
labels_for_logging.append(label) # Оставляем оригинальную метку для логгирования
|
160 |
+
|
161 |
+
return labels_for_model, labels_for_logging
|
162 |
+
|
163 |
+
def encode_to_latents(images, texts):
|
164 |
+
transform = get_image_transform(min_size, max_size, step)
|
165 |
+
|
166 |
+
try:
|
167 |
+
# Обработка изображений (все одинакового размера)
|
168 |
+
transformed_tensors = []
|
169 |
+
pil_images = []
|
170 |
+
widths, heights = [], []
|
171 |
+
|
172 |
+
# Применяем трансформацию ко всем изображениям
|
173 |
+
for img in images:
|
174 |
+
try:
|
175 |
+
t_img, pil_img, w, h = transform(img)
|
176 |
+
transformed_tensors.append(t_img)
|
177 |
+
pil_images.append(pil_img)
|
178 |
+
widths.append(w)
|
179 |
+
heights.append(h)
|
180 |
+
except Exception as e:
|
181 |
+
print(f"Ошибка трансформации: {e}")
|
182 |
+
continue
|
183 |
+
|
184 |
+
if not transformed_tensors:
|
185 |
+
return None
|
186 |
+
|
187 |
+
# Создаём батч
|
188 |
+
batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
|
189 |
+
|
190 |
+
# Кодируем батч
|
191 |
+
with torch.no_grad():
|
192 |
+
posteriors = vae.encode(batch_tensor).latent_dist.mode()
|
193 |
+
latents = (posteriors - vae.config.shift_factor) * vae.config.scaling_factor
|
194 |
+
|
195 |
+
latents_np = latents.cpu().numpy()
|
196 |
+
|
197 |
+
# Проверка однородности форм
|
198 |
+
base_shape = latents_np.shape[1:] # Форма без батча
|
199 |
+
valid_indices = []
|
200 |
+
valid_latents = []
|
201 |
+
|
202 |
+
for idx, latent in enumerate(latents_np):
|
203 |
+
if latent.shape != base_shape:
|
204 |
+
print(f"❌ Несоответствие формы в индексе {idx}: {latent.shape} vs {base_shape}")
|
205 |
+
continue
|
206 |
+
valid_indices.append(idx)
|
207 |
+
valid_latents.append(latent)
|
208 |
+
|
209 |
+
# Фильтруем данные
|
210 |
+
valid_pil = [pil_images[i] for i in valid_indices]
|
211 |
+
valid_widths = [widths[i] for i in valid_indices]
|
212 |
+
valid_heights = [heights[i] for i in valid_indices]
|
213 |
+
|
214 |
+
# Обрабатываем тексты
|
215 |
+
text_labels = [clean_label(texts[i]) for i in valid_indices]
|
216 |
+
if random.random() < img_share:
|
217 |
+
embeddings = encode_images_batch(valid_pil, processor, model)
|
218 |
+
text_labels = [f"img: {text_labels[i]}" for i in valid_indices]
|
219 |
+
else:
|
220 |
+
model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
|
221 |
+
if textemb_full:
|
222 |
+
embeddings = encode_texts_batch_full(model_prompts, tokenizer, model)
|
223 |
+
else:
|
224 |
+
embeddings = encode_texts_batch(model_prompts, tokenizer, model)
|
225 |
+
|
226 |
+
return {
|
227 |
+
"vae": np.array(valid_latents),
|
228 |
+
"embeddings": embeddings,
|
229 |
+
"text": text_labels,
|
230 |
+
"width": valid_widths,
|
231 |
+
"height": valid_heights
|
232 |
+
}
|
233 |
+
|
234 |
+
except Exception as e:
|
235 |
+
print(f"Критическая ошибка в encode_to_latents: {e}")
|
236 |
+
raise
|
237 |
+
|
238 |
+
# ---------------- 5️⃣ Обработка папки с изображениями и текстами ----------------
|
239 |
+
def process_folder(folder_path, limit=None):
|
240 |
+
"""
|
241 |
+
Рекурсивно обходит указанную директорию и все вложенные директории,
|
242 |
+
собирая пути к изображениям и соответствующим текстовым файлам.
|
243 |
+
"""
|
244 |
+
image_paths = []
|
245 |
+
text_paths = []
|
246 |
+
width = []
|
247 |
+
height = []
|
248 |
+
transform = get_image_transform(min_size, max_size, step)
|
249 |
+
|
250 |
+
# Используем os.walk для рекурсивного обхода директорий
|
251 |
+
for root, dirs, files in os.walk(folder_path):
|
252 |
+
for filename in files:
|
253 |
+
# Проверяем, является ли файл изображением
|
254 |
+
if filename.lower().endswith((".jpg", ".jpeg", ".png")):
|
255 |
+
image_path = os.path.join(root, filename)
|
256 |
+
try:
|
257 |
+
img = Image.open(image_path)
|
258 |
+
except Exception as e:
|
259 |
+
print(f"Ошибка при открытии {image_path}: {e}")
|
260 |
+
os.remove(image_path)
|
261 |
+
text_path = os.path.splitext(image_path)[0] + ".txt"
|
262 |
+
if os.path.exists(text_path):
|
263 |
+
os.remove(text_path)
|
264 |
+
continue
|
265 |
+
# Применяем трансформацию только для получения размеров
|
266 |
+
w, h = transform(img, dry_run=True)
|
267 |
+
# Формируем путь к текстовому файлу
|
268 |
+
text_path = os.path.splitext(image_path)[0] + ".txt"
|
269 |
+
|
270 |
+
# Добавляем пути, если текстовый файл существует
|
271 |
+
if os.path.exists(text_path) and min(w, h)>0:
|
272 |
+
image_paths.append(image_path)
|
273 |
+
text_paths.append(text_path)
|
274 |
+
width.append(w) # Добавляем в список
|
275 |
+
height.append(h) # Добавляем в список
|
276 |
+
|
277 |
+
# Проверяем ограничение на количество
|
278 |
+
if limit and limit>0 and len(image_paths) >= limit:
|
279 |
+
print(f"Достигнут лимит в {limit} изображений")
|
280 |
+
return image_paths, text_paths, width, height
|
281 |
+
|
282 |
+
print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями")
|
283 |
+
return image_paths, text_paths, width, height
|
284 |
+
|
285 |
+
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=50000, batch_size=1):
|
286 |
+
total_files = len(image_paths)
|
287 |
+
start_time = time.time()
|
288 |
+
chunks = range(0, total_files, chunk_size)
|
289 |
+
|
290 |
+
for chunk_idx, start in enumerate(chunks, 1):
|
291 |
+
end = min(start + chunk_size, total_files)
|
292 |
+
chunk_image_paths = image_paths[start:end]
|
293 |
+
chunk_text_paths = text_paths[start:end]
|
294 |
+
chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
|
295 |
+
chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)
|
296 |
+
|
297 |
+
# Чтение текстов
|
298 |
+
chunk_texts = []
|
299 |
+
for text_path in chunk_text_paths:
|
300 |
+
try:
|
301 |
+
with open(text_path, 'r', encoding='utf-8') as f:
|
302 |
+
text = f.read().strip()
|
303 |
+
chunk_texts.append(text)
|
304 |
+
except Exception as e:
|
305 |
+
print(f"Ошибка чтения {text_path}: {e}")
|
306 |
+
chunk_texts.append("")
|
307 |
+
|
308 |
+
# Группируем изображения по размерам
|
309 |
+
size_groups = {}
|
310 |
+
for i in range(len(chunk_image_paths)):
|
311 |
+
size_key = (chunk_widths[i], chunk_heights[i])
|
312 |
+
if size_key not in size_groups:
|
313 |
+
size_groups[size_key] = {"image_paths": [], "texts": []}
|
314 |
+
size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
|
315 |
+
size_groups[size_key]["texts"].append(chunk_texts[i])
|
316 |
+
|
317 |
+
# Обрабатываем каждую группу размеров отдельно
|
318 |
+
for size_key, group_data in size_groups.items():
|
319 |
+
print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений")
|
320 |
+
|
321 |
+
group_dataset = Dataset.from_dict({
|
322 |
+
"image_path": group_data["image_paths"],
|
323 |
+
"text": group_data["texts"]
|
324 |
+
})
|
325 |
+
|
326 |
+
# Теперь можно использовать указанный batch_size, т.к. все изображения одного размера
|
327 |
+
processed_group = group_dataset.map(
|
328 |
+
lambda examples: encode_to_latents(
|
329 |
+
[Image.open(path) for path in examples["image_path"]],
|
330 |
+
examples["text"]
|
331 |
+
),
|
332 |
+
batched=True,
|
333 |
+
batch_size=batch_size,
|
334 |
+
remove_columns=["image_path"],
|
335 |
+
desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}"
|
336 |
+
)
|
337 |
+
|
338 |
+
# Сохраняем результаты группы
|
339 |
+
group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
|
340 |
+
processed_group.save_to_disk(group_save_path)
|
341 |
+
clear_cuda_memory()
|
342 |
+
elapsed = time.time() - start_time
|
343 |
+
processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]])
|
344 |
+
if processed > 0:
|
345 |
+
remaining = (elapsed / processed) * (total_files - processed)
|
346 |
+
elapsed_str = str(timedelta(seconds=int(elapsed)))
|
347 |
+
remaining_str = str(timedelta(seconds=int(remaining)))
|
348 |
+
print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})")
|
349 |
+
|
350 |
+
# ---------------- 7️⃣ Объединение чанков ----------------
|
351 |
+
def combine_chunks(temp_path, final_path):
|
352 |
+
"""Объединение обработанных чанков в финальный датасет"""
|
353 |
+
chunks = sorted([
|
354 |
+
os.path.join(temp_path, d)
|
355 |
+
for d in os.listdir(temp_path)
|
356 |
+
if d.startswith("chunk_")
|
357 |
+
])
|
358 |
+
|
359 |
+
datasets = [load_from_disk(chunk) for chunk in chunks]
|
360 |
+
combined = concatenate_datasets(datasets)
|
361 |
+
combined.save_to_disk(final_path)
|
362 |
+
|
363 |
+
print(f"✅ Датасет успешно сохранен в: {final_path}")
|
364 |
+
|
365 |
+
|
366 |
+
|
367 |
+
# Создаем временную папку для чанков
|
368 |
+
temp_path = f"{save_path}_temp"
|
369 |
+
os.makedirs(temp_path, exist_ok=True)
|
370 |
+
|
371 |
+
# Получаем список файлов
|
372 |
+
image_paths, text_paths, width, height = process_folder(folder_path,limit)
|
373 |
+
print(f"Всего найдено {len(image_paths)} изображений")
|
374 |
+
|
375 |
+
# Обработка с чанкованием
|
376 |
+
process_in_chunks(image_paths, text_paths, width, height, chunk_size=100000, batch_size=batch_size)
|
377 |
+
|
378 |
+
# Объединение чанков в финальный датасет
|
379 |
+
combine_chunks(temp_path, save_path)
|
380 |
+
|
381 |
+
# Удаление временной папки
|
382 |
+
try:
|
383 |
+
shutil.rmtree(temp_path)
|
384 |
+
print(f"✅ Временная папка {temp_path} успешно удалена")
|
385 |
+
except Exception as e:
|
386 |
+
print(f"⚠️ Ошибка при удалении временной папки: {e}")
|
model_index.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85e884f29f7a6282a634d90350933aa68021326035a0980072667757c3bc9112
|
3 |
+
size 476
|
pipeline_sdxs.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import DiffusionPipeline
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import os
|
5 |
+
from diffusers.utils import BaseOutput
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import List, Union, Optional
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
import json
|
11 |
+
from safetensors.torch import load_file
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class SdxsPipelineOutput(BaseOutput):
|
16 |
+
images: Union[List[Image.Image], np.ndarray]
|
17 |
+
|
18 |
+
class SdxsPipeline(DiffusionPipeline):
|
19 |
+
def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
# Register components
|
23 |
+
self.register_modules(
|
24 |
+
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
|
25 |
+
unet=unet, scheduler=scheduler
|
26 |
+
)
|
27 |
+
|
28 |
+
# Get the model path, which is either provided directly or from internal dict
|
29 |
+
model_path = None
|
30 |
+
if hasattr(self, '_internal_dict') and self._internal_dict.get('_name_or_path'):
|
31 |
+
model_path = self._internal_dict.get('_name_or_path')
|
32 |
+
|
33 |
+
# Get device and dtype from existing components
|
34 |
+
device = "cuda"
|
35 |
+
dtype = torch.float16
|
36 |
+
|
37 |
+
# Always load text_projector, regardless of whether one was provided
|
38 |
+
projector_path = None
|
39 |
+
|
40 |
+
# Try to find projector path
|
41 |
+
if model_path and os.path.exists(f"{model_path}/text_projector"):
|
42 |
+
projector_path = f"{model_path}/text_projector"
|
43 |
+
elif os.path.exists("./text_projector"):
|
44 |
+
projector_path = "./text_projector"
|
45 |
+
|
46 |
+
if projector_path:
|
47 |
+
# Create and load projector
|
48 |
+
try:
|
49 |
+
with open(f"{projector_path}/config.json", "r") as f:
|
50 |
+
projector_config = json.load(f)
|
51 |
+
|
52 |
+
# Create Linear layer with bias=False
|
53 |
+
self.text_projector = nn.Linear(
|
54 |
+
in_features=projector_config["in_features"],
|
55 |
+
out_features=projector_config["out_features"],
|
56 |
+
bias=False
|
57 |
+
)
|
58 |
+
|
59 |
+
# Load the state dict using safetensors
|
60 |
+
self.text_projector.load_state_dict(load_file(f"{projector_path}/model.safetensors"))
|
61 |
+
self.text_projector.to(device=device, dtype=dtype)
|
62 |
+
print(f"Successfully loaded text_projector from {projector_path}",device, dtype)
|
63 |
+
except Exception as e:
|
64 |
+
print(f"Error loading text_projector: {e}")
|
65 |
+
|
66 |
+
self.vae_scale_factor = 8
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
|
71 |
+
"""Кодирование текстовых промптов в эмбеддинги.
|
72 |
+
|
73 |
+
Возвращает:
|
74 |
+
- text_embeddings: Тензор эмбеддингов [batch_size, 1, dim] или [2*batch_size, 1, dim] с guidance
|
75 |
+
"""
|
76 |
+
if prompt is None and negative_prompt is None:
|
77 |
+
raise ValueError("Требуется хотя бы один из параметров: prompt или negative_prompt")
|
78 |
+
|
79 |
+
# Устанавливаем device и dtype
|
80 |
+
device = device or self.device
|
81 |
+
dtype = dtype or next(self.unet.parameters()).dtype
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
# Обрабатываем позитивный промпт
|
85 |
+
if prompt is not None:
|
86 |
+
if isinstance(prompt, str):
|
87 |
+
prompt = [prompt]
|
88 |
+
|
89 |
+
text_inputs = self.tokenizer(
|
90 |
+
prompt, return_tensors="pt", padding="max_length",
|
91 |
+
max_length=512, truncation=True
|
92 |
+
).to(device)
|
93 |
+
|
94 |
+
# Получаем эмбеддинги
|
95 |
+
outputs = self.text_encoder(text_inputs.input_ids, text_inputs.attention_mask)
|
96 |
+
last_hidden_state = outputs.last_hidden_state.to(device, dtype=dtype)
|
97 |
+
pos_embeddings = self.text_projector(last_hidden_state[:, 0])
|
98 |
+
|
99 |
+
# Добавляем размерность для batch processing
|
100 |
+
if pos_embeddings.ndim == 2:
|
101 |
+
pos_embeddings = pos_embeddings.unsqueeze(1)
|
102 |
+
else:
|
103 |
+
# Создаем пустые эмбеддинги, если нет позитивного промпта
|
104 |
+
# (полезно для некоторых сценариев с unconditional generation)
|
105 |
+
batch_size = len(negative_prompt) if isinstance(negative_prompt, list) else 1
|
106 |
+
pos_embeddings = torch.zeros(
|
107 |
+
batch_size, 1, self.unet.config.cross_attention_dim,
|
108 |
+
device=device, dtype=dtype
|
109 |
+
)
|
110 |
+
|
111 |
+
# Обрабатываем негативный промпт
|
112 |
+
if negative_prompt is not None:
|
113 |
+
if isinstance(negative_prompt, str):
|
114 |
+
negative_prompt = [negative_prompt]
|
115 |
+
|
116 |
+
# Убеждаемся, что размеры негативного и позитивного промптов совпадают
|
117 |
+
if prompt is not None and len(negative_prompt) != len(prompt):
|
118 |
+
neg_batch_size = len(prompt)
|
119 |
+
if len(negative_prompt) == 1:
|
120 |
+
negative_prompt = negative_prompt * neg_batch_size
|
121 |
+
else:
|
122 |
+
negative_prompt = negative_prompt[:neg_batch_size]
|
123 |
+
|
124 |
+
neg_inputs = self.tokenizer(
|
125 |
+
negative_prompt, return_tensors="pt", padding="max_length",
|
126 |
+
max_length=512, truncation=True
|
127 |
+
).to(device)
|
128 |
+
|
129 |
+
neg_outputs = self.text_encoder(neg_inputs.input_ids, neg_inputs.attention_mask)
|
130 |
+
neg_last_hidden_state = neg_outputs.last_hidden_state.to(device, dtype=dtype)
|
131 |
+
neg_embeddings = self.text_projector(neg_last_hidden_state[:, 0])
|
132 |
+
|
133 |
+
if neg_embeddings.ndim == 2:
|
134 |
+
neg_embeddings = neg_embeddings.unsqueeze(1)
|
135 |
+
|
136 |
+
# Объединяем для classifier-free guidance
|
137 |
+
text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
|
138 |
+
else:
|
139 |
+
# Если нет негативного промпта, используем нулевые эмбеддинги
|
140 |
+
batch_size = pos_embeddings.shape[0]
|
141 |
+
neg_embeddings = torch.zeros_like(pos_embeddings)
|
142 |
+
text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
|
143 |
+
|
144 |
+
return text_embeddings.to(device=device, dtype=dtype)
|
145 |
+
|
146 |
+
@torch.no_grad()
|
147 |
+
def generate_latents(
|
148 |
+
self,
|
149 |
+
text_embeddings,
|
150 |
+
height: int = 576,
|
151 |
+
width: int = 576,
|
152 |
+
num_inference_steps: int = 40,
|
153 |
+
guidance_scale: float = 5.0,
|
154 |
+
latent_channels: int = 16,
|
155 |
+
batch_size: int = 1,
|
156 |
+
generator = None,
|
157 |
+
):
|
158 |
+
"""Генерация латентов с использованием эмбеддингов промптов."""
|
159 |
+
device = self.device
|
160 |
+
dtype = next(self.unet.parameters()).dtype
|
161 |
+
|
162 |
+
# Проверка размера эмбеддингов
|
163 |
+
do_classifier_free_guidance = guidance_scale > 0
|
164 |
+
embedding_dim = text_embeddings.shape[0] // 2 if do_classifier_free_guidance else text_embeddings.shape[0]
|
165 |
+
|
166 |
+
if batch_size > embedding_dim:
|
167 |
+
# Повторяем эмбеддинги до нужного размера батча
|
168 |
+
if do_classifier_free_guidance:
|
169 |
+
neg_embeds, pos_embeds = text_embeddings.chunk(2)
|
170 |
+
neg_embeds = neg_embeds.repeat(batch_size // embedding_dim, 1, 1)
|
171 |
+
pos_embeds = pos_embeds.repeat(batch_size // embedding_dim, 1, 1)
|
172 |
+
text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
|
173 |
+
else:
|
174 |
+
text_embeddings = text_embeddings.repeat(batch_size // embedding_dim, 1, 1)
|
175 |
+
|
176 |
+
# Установка timesteps
|
177 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
178 |
+
|
179 |
+
# Инициализация латентов с заданным seed
|
180 |
+
latent_shape = (
|
181 |
+
batch_size,
|
182 |
+
latent_channels,
|
183 |
+
height // self.vae_scale_factor,
|
184 |
+
width // self.vae_scale_factor
|
185 |
+
)
|
186 |
+
latents = torch.randn(
|
187 |
+
latent_shape,
|
188 |
+
device=device,
|
189 |
+
dtype=dtype,
|
190 |
+
generator=generator
|
191 |
+
)
|
192 |
+
|
193 |
+
# Процесс диффузии
|
194 |
+
for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
|
195 |
+
# Подготовка входных данных
|
196 |
+
if do_classifier_free_guidance:
|
197 |
+
latent_input = torch.cat([latents] * 2)
|
198 |
+
else:
|
199 |
+
latent_input = latents
|
200 |
+
|
201 |
+
latent_input = self.scheduler.scale_model_input(latent_input, t)
|
202 |
+
|
203 |
+
# Предсказание шума
|
204 |
+
noise_pred = self.unet(latent_input, t, text_embeddings).sample
|
205 |
+
|
206 |
+
# Применение guidance
|
207 |
+
if do_classifier_free_guidance:
|
208 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
209 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
210 |
+
noise_pred_text - noise_pred_uncond
|
211 |
+
)
|
212 |
+
|
213 |
+
# Обновление латентов
|
214 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
215 |
+
|
216 |
+
return latents
|
217 |
+
|
218 |
+
def decode_latents(self, latents, output_type="pil"):
|
219 |
+
"""Декодирование латентов в изображения."""
|
220 |
+
# Нормализация латентов
|
221 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
222 |
+
|
223 |
+
# Декодирование
|
224 |
+
with torch.no_grad():
|
225 |
+
images = self.vae.decode(latents).sample
|
226 |
+
|
227 |
+
# Нормализация изображений
|
228 |
+
images = (images / 2 + 0.5).clamp(0, 1)
|
229 |
+
|
230 |
+
# Конвертация в нужный формат
|
231 |
+
if output_type == "pil":
|
232 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
233 |
+
images = (images * 255).round().astype("uint8")
|
234 |
+
return [Image.fromarray(image) for image in images]
|
235 |
+
else:
|
236 |
+
return images.cpu().permute(0, 2, 3, 1).float().numpy()
|
237 |
+
|
238 |
+
@torch.no_grad()
|
239 |
+
def __call__(
|
240 |
+
self,
|
241 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
242 |
+
height: int = 576,
|
243 |
+
width: int = 576,
|
244 |
+
num_inference_steps: int = 40,
|
245 |
+
guidance_scale: float = 5.0,
|
246 |
+
latent_channels: int = 16,
|
247 |
+
output_type: str = "pil",
|
248 |
+
return_dict: bool = True,
|
249 |
+
batch_size: int = 1,
|
250 |
+
seed: Optional[int] = None,
|
251 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
252 |
+
text_embeddings: Optional[torch.FloatTensor] = None,
|
253 |
+
):
|
254 |
+
"""Генерация изображения из текстовых промптов или эмбеддингов."""
|
255 |
+
device = self.device
|
256 |
+
|
257 |
+
# Устанавливаем генератор с seed для воспроизводимости
|
258 |
+
generator = None
|
259 |
+
if seed is not None:
|
260 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
261 |
+
|
262 |
+
# Получаем эмбеддинги, если они не предоставлены
|
263 |
+
if text_embeddings is None:
|
264 |
+
if prompt is None and negative_prompt is None:
|
265 |
+
raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
|
266 |
+
|
267 |
+
# Вычисляем эмбеддинги
|
268 |
+
text_embeddings = self.encode_prompt(
|
269 |
+
prompt=prompt,
|
270 |
+
negative_prompt=negative_prompt,
|
271 |
+
device=device
|
272 |
+
)
|
273 |
+
else:
|
274 |
+
# Убеждаемся, что эмбеддинги на правильном устройстве
|
275 |
+
text_embeddings = text_embeddings.to(device)
|
276 |
+
|
277 |
+
# Генерируем латенты
|
278 |
+
latents = self.generate_latents(
|
279 |
+
text_embeddings=text_embeddings,
|
280 |
+
height=height,
|
281 |
+
width=width,
|
282 |
+
num_inference_steps=num_inference_steps,
|
283 |
+
guidance_scale=guidance_scale,
|
284 |
+
latent_channels=latent_channels,
|
285 |
+
batch_size=batch_size,
|
286 |
+
generator=generator
|
287 |
+
)
|
288 |
+
|
289 |
+
# Декодируем латенты в изображения
|
290 |
+
images = self.decode_latents(latents, output_type=output_type)
|
291 |
+
|
292 |
+
if not return_dict:
|
293 |
+
return images
|
294 |
+
|
295 |
+
return SdxsPipelineOutput(images=images)
|
promo.png
ADDED
![]() |
Git LFS Details
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# torch>=2.6.0
|
2 |
+
# torchvision>=0.21.0
|
3 |
+
# torchaudio>=2.6.0
|
4 |
+
diffusers>=0.32.2
|
5 |
+
accelerate>=1.5.2
|
6 |
+
datasets>=3.5.0
|
7 |
+
matplotlib>=3.10.1
|
8 |
+
wandb>=0.19.8
|
9 |
+
huggingface_hub>=0.29.3
|
10 |
+
bitsandbytes>=0.45.4
|
11 |
+
transformers
|
result_grid.jpg
ADDED
![]() |
Git LFS Details
|
samples/unet_192x384_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/unet_256x384_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/unet_320x384_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/unet_384x192_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/unet_384x256_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/unet_384x320_0.jpg
ADDED
![]() |
Git LFS Details
|
samples/unet_384x384_0.jpg
ADDED
![]() |
Git LFS Details
|
scheduler/scheduler_config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2e919ad3cde5f0bdf9529c68ee7c3306b1ceef40245778d29050d58ebc074158
|
3 |
+
size 507
|
src/captions_moondream2.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cd389dacb701c76713fa256b68a05972b89cf80c8d2fafef341abccbba826765
|
3 |
+
size 4999
|
src/captions_moondream2_wd3.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c1519f45fc46f644f46631acc4250c5b558a9d697e47548a00bb3eeedbb14e75
|
3 |
+
size 9956
|
src/captions_qwen2-vl-7b.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
5 |
+
import numpy as np
|
6 |
+
from pathlib import Path
|
7 |
+
from tqdm import tqdm
|
8 |
+
import argparse
|
9 |
+
import gc
|
10 |
+
|
11 |
+
# Configuration options
|
12 |
+
PRINT_CAPTIONS = False # Print captions to the console during inference
|
13 |
+
PRINT_CAPTIONING_STATUS = False # Print captioning file status to the console
|
14 |
+
OVERWRITE = True # Allow overwriting existing caption files
|
15 |
+
PREPEND_STRING = "" # Prefix string to prepend to the generated caption
|
16 |
+
APPEND_STRING = "" # Suffix string to append to the generated caption
|
17 |
+
STRIP_LINEBREAKS = True # Remove line breaks from generated captions before saving
|
18 |
+
DEFAULT_SAVE_FORMAT = ".txt" # Default format for saving captions
|
19 |
+
|
20 |
+
# Image resizing options
|
21 |
+
MAX_WIDTH = 512 # Set to 0 or less to ignore
|
22 |
+
MAX_HEIGHT = 512 # Set to 0 or less to ignore
|
23 |
+
|
24 |
+
# Generation parameters
|
25 |
+
REPETITION_PENALTY = 1.3 # Penalty for repeating phrases, float ~1.5
|
26 |
+
TEMPERATURE = 0.7 # Sampling temperature to control randomness
|
27 |
+
TOP_K = 50 # Top-k sampling to limit number of potential next tokens
|
28 |
+
|
29 |
+
# Default values for input folder, output folder, prompt, and save format
|
30 |
+
DEFAULT_INPUT_FOLDER = Path(__file__).parent / "input"
|
31 |
+
DEFAULT_OUTPUT_FOLDER = DEFAULT_INPUT_FOLDER
|
32 |
+
DEFAULT_PROMPT = "In two medium sentence, caption the key aspects of this image."
|
33 |
+
|
34 |
+
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
35 |
+
|
36 |
+
# Function to parse command-line arguments
|
37 |
+
def parse_arguments():
|
38 |
+
parser = argparse.ArgumentParser(description="Process images and generate captions using Qwen model.")
|
39 |
+
parser.add_argument("--input_folder", type=str, default=DEFAULT_INPUT_FOLDER, help="Path to the input folder containing images.")
|
40 |
+
parser.add_argument("--output_folder", type=str, default=DEFAULT_OUTPUT_FOLDER, help="Path to the output folder for saving captions.")
|
41 |
+
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for generating the caption.")
|
42 |
+
parser.add_argument("--save_format", type=str, default=DEFAULT_SAVE_FORMAT, help="Format for saving captions (e.g., .txt, .md, .json).")
|
43 |
+
parser.add_argument("--max_width", type=int, default=MAX_WIDTH, help="Maximum width for resizing images (default: no resizing).")
|
44 |
+
parser.add_argument("--max_height", type=int, default=MAX_HEIGHT, help="Maximum height for resizing images (default: no resizing).")
|
45 |
+
parser.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY, help="Penalty for repetition during caption generation (default: 1.10).")
|
46 |
+
parser.add_argument("--temperature", type=float, default=TEMPERATURE, help="Sampling temperature for generation (default: 0.7).")
|
47 |
+
parser.add_argument("--top_k", type=int, default=TOP_K, help="Top-k sampling during generation (default: 50).")
|
48 |
+
return parser.parse_args()
|
49 |
+
|
50 |
+
# Function to ignore images that don't have output files yet
|
51 |
+
def filter_images_without_output(input_folder, save_format):
|
52 |
+
images_to_caption = []
|
53 |
+
skipped_images = 0
|
54 |
+
total_images = 0
|
55 |
+
|
56 |
+
for root, _, files in os.walk(input_folder):
|
57 |
+
for file in files:
|
58 |
+
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
|
59 |
+
total_images += 1
|
60 |
+
image_path = os.path.join(root, file)
|
61 |
+
output_path = os.path.splitext(image_path)[0] + save_format
|
62 |
+
if not OVERWRITE and os.path.exists(output_path):
|
63 |
+
skipped_images += 1
|
64 |
+
else:
|
65 |
+
images_to_caption.append(image_path)
|
66 |
+
|
67 |
+
return images_to_caption, total_images, skipped_images
|
68 |
+
|
69 |
+
# Function to save caption to a file
|
70 |
+
def save_caption_to_file(image_path, caption, save_format):
|
71 |
+
txt_file_path = os.path.splitext(image_path)[0] + save_format # Same name, but with chosen save format
|
72 |
+
caption = PREPEND_STRING + caption + APPEND_STRING # Apply prepend/append strings
|
73 |
+
|
74 |
+
with open(txt_file_path, "w") as txt_file:
|
75 |
+
txt_file.write(caption)
|
76 |
+
|
77 |
+
if PRINT_CAPTIONING_STATUS:
|
78 |
+
print(f"Caption for {os.path.abspath(image_path)} saved in {save_format} format.")
|
79 |
+
|
80 |
+
|
81 |
+
# Function to process all images recursively in a folder
|
82 |
+
def process_images_in_folder(images_to_caption, prompt, save_format, max_width=MAX_WIDTH, max_height=MAX_HEIGHT, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
|
83 |
+
for image_path in tqdm(images_to_caption, desc="Processing Images"):
|
84 |
+
try:
|
85 |
+
# Загружаем и обрабатываем изображение
|
86 |
+
image = Image.open(image_path).convert("RGB")
|
87 |
+
image = resize_image_proportionally(image, max_width, max_height)
|
88 |
+
|
89 |
+
# Читаем теги
|
90 |
+
tags_filename = str(Path(image_path).with_suffix('.tag'))
|
91 |
+
with open(tags_filename, "r") as file:
|
92 |
+
tag_caption = file.read()
|
93 |
+
#prompt = f'In one medium sentence, caption the key aspects of this image. Be direct for NSFW. You may use following tags as context: {tag_caption}'
|
94 |
+
prompt = f'In two medium sentence, caption the key aspects of this image. Be direct for NSFW, Don\'t be shy in your expressions like pussy, anal. Add approximate age for humans like: a child,young girl or 40 years old woman. You may use following tags as context if they are relevant: {tag_caption}'
|
95 |
+
|
96 |
+
#print("prompt",prompt)
|
97 |
+
|
98 |
+
# Генерируем подпись
|
99 |
+
caption = qwen_caption(image, prompt, repetition_penalty, temperature, top_k)
|
100 |
+
save_caption_to_file(image_path, caption, save_format)
|
101 |
+
|
102 |
+
if PRINT_CAPTIONS:
|
103 |
+
print(f"Caption for {os.path.abspath(image_path)}: {caption}")
|
104 |
+
|
105 |
+
# Освобождаем память
|
106 |
+
del image, tag_caption, caption
|
107 |
+
torch.cuda.empty_cache()
|
108 |
+
gc.collect()
|
109 |
+
|
110 |
+
except Exception as e:
|
111 |
+
print(f"Error processing {os.path.abspath(image_path)}: {str(e)}")
|
112 |
+
torch.cuda.empty_cache()
|
113 |
+
gc.collect()
|
114 |
+
|
115 |
+
# Resize the image proportionally based on max width and/or max height.
|
116 |
+
def resize_image_proportionally(image, max_width=None, max_height=None):
|
117 |
+
"""
|
118 |
+
If both max_width and max_height are provided, the image is resized to fit within both dimensions,
|
119 |
+
keeping the aspect ratio intact. If only one dimension is provided, the image is resized based on that dimension.
|
120 |
+
"""
|
121 |
+
if (max_width is None or max_width <= 0) and (max_height is None or max_height <= 0):
|
122 |
+
return image # No resizing if both dimensions are not provided or set to 0 or less
|
123 |
+
|
124 |
+
original_width, original_height = image.size
|
125 |
+
aspect_ratio = original_width / original_height
|
126 |
+
|
127 |
+
# Determine the new dimensions
|
128 |
+
if max_width and not max_height:
|
129 |
+
# Resize based on width
|
130 |
+
new_width = max_width
|
131 |
+
new_height = int(new_width / aspect_ratio)
|
132 |
+
elif max_height and not max_width:
|
133 |
+
# Resize based on height
|
134 |
+
new_height = max_height
|
135 |
+
new_width = int(new_height * aspect_ratio)
|
136 |
+
else:
|
137 |
+
# Resize based on both width and height, keeping the aspect ratio
|
138 |
+
new_width = max_width
|
139 |
+
new_height = max_height
|
140 |
+
|
141 |
+
# Adjust the dimensions proportionally to the aspect ratio
|
142 |
+
if new_width / aspect_ratio > new_height:
|
143 |
+
new_width = int(new_height * aspect_ratio)
|
144 |
+
else:
|
145 |
+
new_height = int(new_width / aspect_ratio)
|
146 |
+
|
147 |
+
# Resize the image using LANCZOS (equivalent to ANTIALIAS in older versions)
|
148 |
+
resized_image = image.resize((new_width, new_height))
|
149 |
+
return resized_image
|
150 |
+
|
151 |
+
# Generate a caption for the provided image using the Ertugrul/Qwen2-VL-7B-Captioner-Relaxed model
|
152 |
+
def qwen_caption(image, prompt, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
|
153 |
+
if not isinstance(image, Image.Image):
|
154 |
+
image = Image.fromarray(np.uint8(image))
|
155 |
+
|
156 |
+
# Prepare the conversation content, which includes the image and the text prompt
|
157 |
+
conversation = [
|
158 |
+
{
|
159 |
+
"role": "user",
|
160 |
+
"content": [
|
161 |
+
{
|
162 |
+
"type": "image",
|
163 |
+
},
|
164 |
+
{"type": "text", "text": prompt},
|
165 |
+
],
|
166 |
+
}
|
167 |
+
]
|
168 |
+
|
169 |
+
# Apply the chat template to format the message for processing
|
170 |
+
text_prompt = qwen_processor.apply_chat_template(
|
171 |
+
conversation, add_generation_prompt=True
|
172 |
+
)
|
173 |
+
|
174 |
+
# Prepare the inputs for the model, padding as necessary and converting to tensors
|
175 |
+
inputs = qwen_processor(
|
176 |
+
text=[text_prompt],
|
177 |
+
images=[image],
|
178 |
+
padding=True,
|
179 |
+
return_tensors="pt",
|
180 |
+
)
|
181 |
+
inputs = inputs.to("cuda")
|
182 |
+
|
183 |
+
with torch.no_grad():
|
184 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
185 |
+
output_ids = qwen_model.generate(
|
186 |
+
**inputs,
|
187 |
+
max_new_tokens=384,
|
188 |
+
do_sample=True,
|
189 |
+
temperature=temperature,
|
190 |
+
use_cache=True,
|
191 |
+
top_k=top_k,
|
192 |
+
repetition_penalty=repetition_penalty,
|
193 |
+
)
|
194 |
+
|
195 |
+
# Trim the generated IDs to remove the input part from the output
|
196 |
+
generated_ids_trimmed = [
|
197 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
|
198 |
+
]
|
199 |
+
|
200 |
+
# Decode the trimmed output into text, skipping special tokens
|
201 |
+
output_text = qwen_processor.batch_decode(
|
202 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
203 |
+
)
|
204 |
+
|
205 |
+
# Strip line breaks if the option is enabled
|
206 |
+
if STRIP_LINEBREAKS:
|
207 |
+
output_text[0] = output_text[0].replace('\n', ' ')
|
208 |
+
|
209 |
+
# Освобождаем память
|
210 |
+
del inputs, output_ids, generated_ids_trimmed
|
211 |
+
torch.cuda.empty_cache()
|
212 |
+
gc.collect()
|
213 |
+
|
214 |
+
return output_text[0]
|
215 |
+
|
216 |
+
# Run the script
|
217 |
+
if __name__ == "__main__":
|
218 |
+
args = parse_arguments()
|
219 |
+
input_folder = args.input_folder
|
220 |
+
output_folder = args.output_folder
|
221 |
+
prompt = args.prompt
|
222 |
+
save_format = args.save_format
|
223 |
+
max_width = args.max_width
|
224 |
+
max_height = args.max_height
|
225 |
+
repetition_penalty = args.repetition_penalty
|
226 |
+
temperature = args.temperature
|
227 |
+
top_k = args.top_k
|
228 |
+
|
229 |
+
# Define model_id
|
230 |
+
model_id = "Ertugrul/Qwen2-VL-7B-Captioner-Relaxed"
|
231 |
+
|
232 |
+
# Filter images before loading the model
|
233 |
+
images_to_caption, total_images, skipped_images = filter_images_without_output(input_folder, save_format)
|
234 |
+
|
235 |
+
# Print summary of found, skipped, and to-be-processed images
|
236 |
+
print(f"\nFound {total_images} image{'s' if total_images != 1 else ''}.")
|
237 |
+
if not OVERWRITE:
|
238 |
+
print(f"{skipped_images} image{'s' if skipped_images != 1 else ''} already have captions with format {save_format}, skipping.")
|
239 |
+
print(f"\nCaptioning {len(images_to_caption)} image{'s' if len(images_to_caption) != 1 else ''}.\n\n")
|
240 |
+
|
241 |
+
# Only load the model if there are images to caption
|
242 |
+
if len(images_to_caption) == 0:
|
243 |
+
print("No images to process. Exiting.\n\n")
|
244 |
+
else:
|
245 |
+
# Initialize the Ertugrul/Qwen2-VL-7B-Captioner-Relaxed model
|
246 |
+
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
|
247 |
+
model_id, torch_dtype=torch.bfloat16, device_map="auto"
|
248 |
+
)
|
249 |
+
qwen_processor = AutoProcessor.from_pretrained(model_id)
|
250 |
+
|
251 |
+
# Process the images with optional resizing and caption generation
|
252 |
+
process_images_in_folder(
|
253 |
+
images_to_caption,
|
254 |
+
prompt,
|
255 |
+
save_format,
|
256 |
+
max_width=max_width,
|
257 |
+
max_height=max_height,
|
258 |
+
repetition_penalty=repetition_penalty,
|
259 |
+
temperature=temperature,
|
260 |
+
top_k=top_k
|
261 |
+
)
|
src/captions_wd.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b46993285995bdc52e69d82e900f239ebafd9dc924be046e80372639c8796ed8
|
3 |
+
size 29850
|
src/cherrypick.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c94efa8bf16993a1d15ef5c455fadf92084dd9349fdd8a9b5a9ca66fe869f565
|
3 |
+
size 48464
|
src/cuda.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cba78a89ed8649cb384e15b3b241df0a1aa35b36ca89a5453e59fc4b875ec0f1
|
3 |
+
size 1503
|
src/dataset_clean.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b1f0686dd3fe000d5bc00ff5d676f173793699630b0017f53530f3dcb1ec474e
|
3 |
+
size 5085
|
src/dataset_combine.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
from datasets import load_from_disk, concatenate_datasets
|
4 |
+
|
5 |
+
def combine_datasets(main_dataset_path, datasets_to_add):
|
6 |
+
"""
|
7 |
+
Объединяет указанные датасеты с основным датасетом.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
main_dataset_path (str): Путь к основному датасету, в который нужно добавить данные
|
11 |
+
datasets_to_add (list): Список путей к датасетам, которые нужно добавить
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
Dataset: Объединенный датасет
|
15 |
+
"""
|
16 |
+
# Загружаем основной датасет
|
17 |
+
try:
|
18 |
+
main_dataset = load_from_disk(main_dataset_path)
|
19 |
+
print(f"Загружен основной датасет: {main_dataset_path} ({len(main_dataset)} записей)")
|
20 |
+
except Exception as e:
|
21 |
+
print(f"Ошибка загрузки основного датасета: {e}")
|
22 |
+
return None
|
23 |
+
|
24 |
+
# Список всех датасетов для объединения
|
25 |
+
all_datasets = [main_dataset]
|
26 |
+
|
27 |
+
# Загружаем и добавляем все дополнительные датасеты
|
28 |
+
for path in datasets_to_add:
|
29 |
+
try:
|
30 |
+
ds = load_from_disk(path)
|
31 |
+
all_datasets.append(ds)
|
32 |
+
print(f"Добавлен датасет: {path} ({len(ds)} записей)")
|
33 |
+
except Exception as e:
|
34 |
+
print(f"Ошибка загрузки датасета {path}: {e}")
|
35 |
+
|
36 |
+
# Объединяем все датасеты
|
37 |
+
print(f"Объединение {len(all_datasets)} датасетов...")
|
38 |
+
combined = concatenate_datasets(all_datasets)
|
39 |
+
|
40 |
+
# Создаем временную директорию на основе имени основного датасета
|
41 |
+
temp_dir = f"{main_dataset_path}_temp"
|
42 |
+
|
43 |
+
# Удаляем временную директорию, если она уже существует
|
44 |
+
if os.path.exists(temp_dir):
|
45 |
+
shutil.rmtree(temp_dir)
|
46 |
+
|
47 |
+
try:
|
48 |
+
# Сохраняем в временную директорию
|
49 |
+
print(f"Сохранение во временную директорию {temp_dir}...")
|
50 |
+
combined.save_to_disk(temp_dir)
|
51 |
+
|
52 |
+
# Удаляем старую директорию и перемещаем новую на ее место
|
53 |
+
print(f"Обновление основного датасета...")
|
54 |
+
#if os.path.exists(main_dataset_path):
|
55 |
+
# shutil.rmtree(main_dataset_path)
|
56 |
+
#shutil.copytree(temp_dir, main_dataset_path)
|
57 |
+
|
58 |
+
# Удаляем временную директорию после успешного копирования
|
59 |
+
#shutil.rmtree(temp_dir)
|
60 |
+
|
61 |
+
print(f"✅ Объединенный датасет ({len(combined)} записей) успешно сохранен в: {main_dataset_path}")
|
62 |
+
except Exception as e:
|
63 |
+
print(f"Ошибка при сохранении датасета: {e}")
|
64 |
+
print(f"Временные данные сохранены в: {temp_dir}")
|
65 |
+
|
66 |
+
return combined
|
67 |
+
|
68 |
+
combine_datasets("/workspace/sdxs/datasets/384_temp", ["/workspace/sdxs/datasets/ds3_384"])
|
src/dataset_fromzip.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d12f225412aeed4f5ff6cd2dc23db6bcdfd944ff00fe6f0d9c2e8fe0ec426ee
|
3 |
+
size 6167
|
src/dataset_imagenet.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:92ea6bc9e4033a778b9e36defff6adf481baae0d8b0a3fa537313df6fb5b4472
|
3 |
+
size 318505
|
src/dataset_laion_coco.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:31395e7f40ef370971b523fb9d9ab56b404ca8cc1e8e932cc602beaf72140411
|
3 |
+
size 25403
|
src/dataset_mjnj.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8cf317c438de242a8cc0c7d710c00ceec53e887108b081235a1fb05dae0074b0
|
3 |
+
size 23158
|
src/dataset_mnist-te.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac1571369244de9ff15d4b1785e962e06521630fa1be32f0471175e42ef00630
|
3 |
+
size 34388
|
src/dataset_mnist.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c644e111748cb374d2fb9fec28ef99a5ed616898100e689cd02c6ba80b3431a7
|
3 |
+
size 33829
|
src/dataset_sample.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac0384c01b5ed29625df6ab7c2da36bbf9b7b9beb4ba83746eb6c00fbd6046e1
|
3 |
+
size 1986940
|
src/inference.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9fdead0e35dd039c20314c1f7f8579c92b2a891a310965ffdec6002fd8a78c00
|
3 |
+
size 2147113
|
src/sdxs_create-vavae.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:11c2151ba855c0c0fda1e58c295f56612843c5b42aecd779cdb3a03b3802b991
|
3 |
+
size 9794
|
src/sdxs_create.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9fd812a7bc5233c0c2e3932fe11c5ba132e5d0389a505726fa54a95c26b42edf
|
3 |
+
size 7417
|
src/sdxs_create_simple.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:93e0cdf53493f39cd1b8b76f41055aee1e8377e128446d03ccf524e0bb0dcd00
|
3 |
+
size 51335
|
src/sdxs_create_unet.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e2dec0ba7a9a8d4aaaeaad2201dd660ec266cb887d7f8eb127ffbe8c7d80c4f
|
3 |
+
size 35930
|
src/sdxs_sdxxs_transfer.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:57db74be42dbf73bc551cf86b4302dd2717555280e26325e599ca89f51b4916e
|
3 |
+
size 168192
|
test.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6bca9de56c7bdda4a032e2c84d15fd5dfc8108aea08fc3186f8203f428b966f8
|
3 |
+
size 5148457
|
text_encoder/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:87131a858ee394af6afae023f733cdebc36eda2ccbed27c36bc887cfae427392
|
3 |
+
size 721
|
text_encoder/model.fp16.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:107fe15da52fe6d13d877512fa36861d1100534d1b9b88015ad9fd017db095a7
|
3 |
+
size 1119825680
|
text_projector/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ae2f211593cd2cc736bf8617bcb0a5e6abd4db0265170de82ae03b7a6664feda
|
3 |
+
size 83
|
text_projector/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e7060a387b4a6419f9d1d852759cb5b94541a1845e996f6062a07462d8b7b6a
|
3 |
+
size 2359384
|
tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c785abebea9ae3257b61681b4e6fd8365ceafde980c21970d001e834cf10835
|
3 |
+
size 964
|
tokenizer/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3ffb37461c391f096759f4a9bbbc329da0f36952f88bab061fcf84940c022e98
|
3 |
+
size 17082999
|
tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ccf223ba3d5b3cc7fa6c3bf451f3bb40557a5c92b0aa33f63d17802ff1a96fd9
|
3 |
+
size 1178
|
train-Copy1.py
ADDED
@@ -0,0 +1,789 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from torch.utils.data import DataLoader, Sampler
|
7 |
+
from torch.utils.data.distributed import DistributedSampler
|
8 |
+
from collections import defaultdict
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
|
11 |
+
from accelerate import Accelerator
|
12 |
+
from datasets import load_from_disk
|
13 |
+
from tqdm import tqdm
|
14 |
+
from PIL import Image,ImageOps
|
15 |
+
import wandb
|
16 |
+
import random
|
17 |
+
import gc
|
18 |
+
from accelerate.state import DistributedType
|
19 |
+
from torch.distributed import broadcast_object_list
|
20 |
+
from torch.utils.checkpoint import checkpoint
|
21 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
22 |
+
from datetime import datetime
|
23 |
+
import bitsandbytes as bnb
|
24 |
+
|
25 |
+
# --------------------------- Параметры ---------------------------
|
26 |
+
ds_path = "datasets/384"
|
27 |
+
batch_size = 50
|
28 |
+
base_learning_rate = 3e-5
|
29 |
+
min_learning_rate = 3e-6
|
30 |
+
num_epochs = 10
|
31 |
+
num_warmup_steps = 1000
|
32 |
+
project = "unet"
|
33 |
+
use_wandb = True
|
34 |
+
save_model = True
|
35 |
+
sample_interval_share = 5 # samples/save per epoch
|
36 |
+
fbp = False # fused backward pass
|
37 |
+
adam8bit = True
|
38 |
+
percentile_clipping = 97 # Lion
|
39 |
+
torch_compile = False
|
40 |
+
unet_gradient = True
|
41 |
+
clip_sample = False #Scheduler
|
42 |
+
fixed_seed = False
|
43 |
+
shuffle = True
|
44 |
+
dtype = torch.float32
|
45 |
+
steps_offset = 1 # Scheduler
|
46 |
+
limit = 0
|
47 |
+
checkpoints_folder = ""
|
48 |
+
mixed_precision = "no"
|
49 |
+
accelerator = Accelerator(mixed_precision=mixed_precision)
|
50 |
+
device = accelerator.device
|
51 |
+
|
52 |
+
# Параметры для диффузии
|
53 |
+
n_diffusion_steps = 50
|
54 |
+
samples_to_generate = 12
|
55 |
+
guidance_scale = 5
|
56 |
+
|
57 |
+
# Папки для сохранения результатов
|
58 |
+
generated_folder = "samples"
|
59 |
+
os.makedirs(generated_folder, exist_ok=True)
|
60 |
+
|
61 |
+
# Настройка seed для воспроизводимости
|
62 |
+
current_date = datetime.now()
|
63 |
+
seed = int(current_date.strftime("%Y%m%d"))
|
64 |
+
if fixed_seed:
|
65 |
+
torch.manual_seed(seed)
|
66 |
+
np.random.seed(seed)
|
67 |
+
random.seed(seed)
|
68 |
+
if torch.cuda.is_available():
|
69 |
+
torch.cuda.manual_seed_all(seed)
|
70 |
+
|
71 |
+
#torch.backends.cuda.matmul.allow_tf32 = True
|
72 |
+
#torch.backends.cudnn.allow_tf32 = True
|
73 |
+
# --------------------------- Параметры LoRA ---------------------------
|
74 |
+
# pip install peft
|
75 |
+
lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
|
76 |
+
lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
|
77 |
+
lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
|
78 |
+
|
79 |
+
print("init")
|
80 |
+
|
81 |
+
# --------------------------- Инициализация WandB ---------------------------
|
82 |
+
if use_wandb and accelerator.is_main_process:
|
83 |
+
wandb.init(project=project+lora_name, config={
|
84 |
+
"batch_size": batch_size,
|
85 |
+
"base_learning_rate": base_learning_rate,
|
86 |
+
"num_epochs": num_epochs,
|
87 |
+
"fbp": fbp,
|
88 |
+
"adam8bit": adam8bit,
|
89 |
+
})
|
90 |
+
|
91 |
+
# Включение Flash Attention 2/SDPA
|
92 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
93 |
+
# --------------------------- Инициализация Accelerator --------------------
|
94 |
+
gen = torch.Generator(device=device)
|
95 |
+
gen.manual_seed(seed)
|
96 |
+
|
97 |
+
# --------------------------- Загрузка моделей ---------------------------
|
98 |
+
# VAE загружается на CPU для экономии GPU-памяти
|
99 |
+
vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
|
100 |
+
|
101 |
+
# DDPMScheduler с V_Prediction и Zero-SNR
|
102 |
+
scheduler = DDPMScheduler(
|
103 |
+
num_train_timesteps=1000, # Полный график шагов для обучения
|
104 |
+
prediction_type="v_prediction", # V-Prediction
|
105 |
+
rescale_betas_zero_snr=True, # Включение Zero-SNR
|
106 |
+
clip_sample = clip_sample,
|
107 |
+
steps_offset = steps_offset
|
108 |
+
)
|
109 |
+
|
110 |
+
class DistributedResolutionBatchSampler(Sampler):
|
111 |
+
def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
|
112 |
+
self.dataset = dataset
|
113 |
+
self.batch_size = max(1, batch_size // num_replicas)
|
114 |
+
self.num_replicas = num_replicas
|
115 |
+
self.rank = rank
|
116 |
+
self.shuffle = shuffle
|
117 |
+
self.drop_last = drop_last
|
118 |
+
self.epoch = 0
|
119 |
+
|
120 |
+
# Используем numpy для ускорения
|
121 |
+
try:
|
122 |
+
widths = np.array(dataset["width"])
|
123 |
+
heights = np.array(dataset["height"])
|
124 |
+
except KeyError:
|
125 |
+
widths = np.zeros(len(dataset))
|
126 |
+
heights = np.zeros(len(dataset))
|
127 |
+
|
128 |
+
# Создаем уникальные ключи для размеров
|
129 |
+
self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
|
130 |
+
|
131 |
+
# Группируем индексы по размерам используя numpy
|
132 |
+
self.size_groups = {}
|
133 |
+
for w, h in self.size_keys:
|
134 |
+
mask = (widths == w) & (heights == h)
|
135 |
+
self.size_groups[(w, h)] = np.where(mask)[0]
|
136 |
+
|
137 |
+
# Предварительно вычисляем количество пол��ых батчей для каждой группы
|
138 |
+
self.group_num_batches = {}
|
139 |
+
total_batches = 0
|
140 |
+
for size, indices in self.size_groups.items():
|
141 |
+
num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
|
142 |
+
self.group_num_batches[size] = num_full_batches
|
143 |
+
total_batches += num_full_batches
|
144 |
+
|
145 |
+
# Округляем до числа, делящегося на num_replicas
|
146 |
+
self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
|
147 |
+
|
148 |
+
def __iter__(self):
|
149 |
+
# print(f"Rank {self.rank}: Starting iteration")
|
150 |
+
# Очищаем CUDA кэш перед формированием новых батчей
|
151 |
+
if torch.cuda.is_available():
|
152 |
+
torch.cuda.empty_cache()
|
153 |
+
all_batches = []
|
154 |
+
rng = np.random.RandomState(self.epoch)
|
155 |
+
|
156 |
+
for size, indices in self.size_groups.items():
|
157 |
+
# print(f"Rank {self.rank}: Processing size {size}, {len(indices)} samples")
|
158 |
+
indices = indices.copy()
|
159 |
+
if self.shuffle:
|
160 |
+
rng.shuffle(indices)
|
161 |
+
|
162 |
+
num_full_batches = self.group_num_batches[size]
|
163 |
+
if num_full_batches == 0:
|
164 |
+
continue
|
165 |
+
|
166 |
+
# Берем только индексы для полных батчей
|
167 |
+
valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
|
168 |
+
|
169 |
+
# Reshape для быстрого разделения на батчи
|
170 |
+
batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
|
171 |
+
|
172 |
+
# Выбираем часть для текущего GPU
|
173 |
+
start_idx = self.rank * self.batch_size
|
174 |
+
end_idx = start_idx + self.batch_size
|
175 |
+
gpu_batches = batches[:, start_idx:end_idx]
|
176 |
+
|
177 |
+
all_batches.extend(gpu_batches)
|
178 |
+
|
179 |
+
if self.shuffle:
|
180 |
+
rng.shuffle(all_batches)
|
181 |
+
|
182 |
+
# Синхронизируем все процессы после формирования батчей
|
183 |
+
accelerator.wait_for_everyone()
|
184 |
+
# print(f"Rank {self.rank}: Created {len(all_batches)} batches")
|
185 |
+
return iter(all_batches)
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
return self.num_batches
|
189 |
+
|
190 |
+
def set_epoch(self, epoch):
|
191 |
+
self.epoch = epoch
|
192 |
+
|
193 |
+
# Функция для выборки фиксированных семплов по размерам
|
194 |
+
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
|
195 |
+
"""Выбирает фиксированные семплы для каждого уникального разрешения"""
|
196 |
+
# Группируем по размерам
|
197 |
+
size_groups = defaultdict(list)
|
198 |
+
try:
|
199 |
+
widths = dataset["width"]
|
200 |
+
heights = dataset["height"]
|
201 |
+
except KeyError:
|
202 |
+
widths = [0] * len(dataset)
|
203 |
+
heights = [0] * len(dataset)
|
204 |
+
for i, (w, h) in enumerate(zip(widths, heights)):
|
205 |
+
size = (w, h)
|
206 |
+
size_groups[size].append(i)
|
207 |
+
|
208 |
+
# Выбираем фиксированные примеры из каждой группы
|
209 |
+
fixed_samples = {}
|
210 |
+
for size, indices in size_groups.items():
|
211 |
+
# Определяем сколько семплов брать из этой группы
|
212 |
+
n_samples = min(samples_per_group, len(indices))
|
213 |
+
if len(size_groups)==1:
|
214 |
+
n_samples = samples_to_generate
|
215 |
+
if n_samples == 0:
|
216 |
+
continue
|
217 |
+
|
218 |
+
# Выбираем случайные индексы
|
219 |
+
sample_indices = random.sample(indices, n_samples)
|
220 |
+
samples_data = [dataset[idx] for idx in sample_indices]
|
221 |
+
|
222 |
+
# Собираем данные
|
223 |
+
latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
|
224 |
+
embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
|
225 |
+
texts = [item["text"] for item in samples_data]
|
226 |
+
|
227 |
+
# Сохраняем для этого размера
|
228 |
+
fixed_samples[size] = (latents, embeddings, texts)
|
229 |
+
|
230 |
+
print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
|
231 |
+
return fixed_samples
|
232 |
+
|
233 |
+
if limit > 0:
|
234 |
+
dataset = load_from_disk(ds_path).select(range(limit))
|
235 |
+
else:
|
236 |
+
dataset = load_from_disk(ds_path)
|
237 |
+
|
238 |
+
def collate_fn_simple(batch):
|
239 |
+
# Преобразуем список в тензоры и перемещаем на девайс
|
240 |
+
latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
|
241 |
+
embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
|
242 |
+
return latents, embeddings
|
243 |
+
|
244 |
+
def collate_fn(batch):
|
245 |
+
if not batch:
|
246 |
+
return [], []
|
247 |
+
|
248 |
+
# Берем эталонную форму
|
249 |
+
ref_vae_shape = np.array(batch[0]["vae"]).shape
|
250 |
+
ref_embed_shape = np.array(batch[0]["embeddings"]).shape
|
251 |
+
|
252 |
+
# Фильтруем
|
253 |
+
valid_latents = []
|
254 |
+
valid_embeddings = []
|
255 |
+
for item in batch:
|
256 |
+
if (np.array(item["vae"]).shape == ref_vae_shape and
|
257 |
+
np.array(item["embeddings"]).shape == ref_embed_shape):
|
258 |
+
valid_latents.append(item["vae"])
|
259 |
+
valid_embeddings.append(item["embeddings"])
|
260 |
+
|
261 |
+
# Создаем тензоры
|
262 |
+
latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype)
|
263 |
+
embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
|
264 |
+
|
265 |
+
return latents, embeddings
|
266 |
+
|
267 |
+
# Используем наш ResolutionBatchSampler
|
268 |
+
#batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
|
269 |
+
#dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
|
270 |
+
|
271 |
+
# Создаем ResolutionBatchSampler на основе индексов от DistributedSampler
|
272 |
+
batch_sampler = DistributedResolutionBatchSampler(
|
273 |
+
dataset=dataset,
|
274 |
+
batch_size=batch_size,
|
275 |
+
num_replicas=accelerator.num_processes,
|
276 |
+
rank=accelerator.process_index,
|
277 |
+
shuffle=shuffle
|
278 |
+
)
|
279 |
+
|
280 |
+
# Создаем DataLoader
|
281 |
+
dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
|
282 |
+
|
283 |
+
print("Total samples",len(dataloader))
|
284 |
+
dataloader = accelerator.prepare(dataloader)
|
285 |
+
|
286 |
+
# Инициализация переменных для возобновления обучения
|
287 |
+
start_epoch = 0
|
288 |
+
global_step = 0
|
289 |
+
|
290 |
+
# Расчёт общего количества шагов
|
291 |
+
total_training_steps = (len(dataloader) * num_epochs)
|
292 |
+
# Get the world size
|
293 |
+
world_size = accelerator.state.num_processes
|
294 |
+
#print(f"World Size: {world_size}")
|
295 |
+
|
296 |
+
# Опция загрузки модели из последнего чекпоинта (если существует)
|
297 |
+
latest_checkpoint = os.path.join(checkpoints_folder, project)
|
298 |
+
if os.path.isdir(latest_checkpoint):
|
299 |
+
print("Загружаем UNet из чекпоинта:", latest_checkpoint)
|
300 |
+
if dtype == torch.float32:
|
301 |
+
unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
|
302 |
+
else:
|
303 |
+
unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
|
304 |
+
if unet_gradient:
|
305 |
+
unet.enable_gradient_checkpointing()
|
306 |
+
unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
|
307 |
+
try:
|
308 |
+
unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
|
309 |
+
except Exception as e:
|
310 |
+
print(f"Ошибка при включении SDPA: {e}")
|
311 |
+
print("Попытка использовать enable_xformers_memory_efficient_attention.")
|
312 |
+
unet.set_use_memory_efficient_attention_xformers(True)
|
313 |
+
|
314 |
+
if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
|
315 |
+
print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
|
316 |
+
if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
|
317 |
+
print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
|
318 |
+
if hasattr(torch.nn.functional, "get_flash_attention_available"):
|
319 |
+
print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
|
320 |
+
if torch_compile:
|
321 |
+
print("compiling")
|
322 |
+
torch.set_float32_matmul_precision('high')
|
323 |
+
unet = torch.compile(unet)#, mode="reduce-overhead", fullgraph=True)
|
324 |
+
print("compiling - ok")
|
325 |
+
|
326 |
+
if lora_name:
|
327 |
+
print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
|
328 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
329 |
+
from peft.tuners.lora import LoraModel
|
330 |
+
import os
|
331 |
+
# 1. Замораживаем все параметры UNet
|
332 |
+
unet.requires_grad_(False)
|
333 |
+
print("Параметры базового UNet заморожены.")
|
334 |
+
|
335 |
+
# 2. Создаем конфигурацию LoRA
|
336 |
+
lora_config = LoraConfig(
|
337 |
+
r=lora_rank,
|
338 |
+
lora_alpha=lora_alpha,
|
339 |
+
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
340 |
+
)
|
341 |
+
unet.add_adapter(lora_config)
|
342 |
+
|
343 |
+
# 3. Оборачиваем UNet в PEFT-модель
|
344 |
+
from peft import get_peft_model
|
345 |
+
|
346 |
+
peft_unet = get_peft_model(unet, lora_config)
|
347 |
+
|
348 |
+
# 4. Получаем параметры для оптимизации
|
349 |
+
params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
|
350 |
+
|
351 |
+
|
352 |
+
# 5. Выводим информацию о количестве параметров
|
353 |
+
if accelerator.is_main_process:
|
354 |
+
lora_params_count = sum(p.numel() for p in params_to_optimize)
|
355 |
+
total_params_count = sum(p.numel() for p in unet.parameters())
|
356 |
+
print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
|
357 |
+
print(f"Общее количество параметров UNet: {total_params_count:,}")
|
358 |
+
|
359 |
+
# 6. Путь для сохранения
|
360 |
+
lora_save_path = os.path.join("lora", lora_name)
|
361 |
+
os.makedirs(lora_save_path, exist_ok=True)
|
362 |
+
|
363 |
+
# 7. Функция для сохранения
|
364 |
+
def save_lora_checkpoint(model):
|
365 |
+
if accelerator.is_main_process:
|
366 |
+
print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
|
367 |
+
from peft.utils.save_and_load import get_peft_model_state_dict
|
368 |
+
# Получаем state_dict только LoRA
|
369 |
+
lora_state_dict = get_peft_model_state_dict(model)
|
370 |
+
|
371 |
+
# Сохраняем веса
|
372 |
+
torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
|
373 |
+
|
374 |
+
# Сохраняем конфиг
|
375 |
+
model.peft_config["default"].save_pretrained(lora_save_path)
|
376 |
+
# SDXL must be compatible
|
377 |
+
from diffusers import StableDiffusionXLPipeline
|
378 |
+
StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
|
379 |
+
|
380 |
+
# --------------------------- Оптимизатор ---------------------------
|
381 |
+
# Определяем параметры для оптимизации
|
382 |
+
#unet = torch.compile(unet)
|
383 |
+
if lora_name:
|
384 |
+
# Если используется LoRA, оптимизируем только параметры LoRA
|
385 |
+
trainable_params = [p for p in unet.parameters() if p.requires_grad]
|
386 |
+
else:
|
387 |
+
# Иначе оптимизируем все параметры
|
388 |
+
if fbp:
|
389 |
+
trainable_params = list(unet.parameters())
|
390 |
+
|
391 |
+
if fbp:
|
392 |
+
# [1] Создаем словарь оптимизаторов (fused backward)
|
393 |
+
if adam8bit:
|
394 |
+
optimizer_dict = {
|
395 |
+
p: bnb.optim.AdamW8bit(
|
396 |
+
[p], # Каждый параметр получает свой оптимизатор
|
397 |
+
lr=base_learning_rate,
|
398 |
+
eps=1e-8
|
399 |
+
) for p in trainable_params
|
400 |
+
}
|
401 |
+
else:
|
402 |
+
optimizer_dict = {
|
403 |
+
p: bnb.optim.Lion8bit(
|
404 |
+
[p], # Каждый параметр получает свой оптимизатор
|
405 |
+
lr=base_learning_rate,
|
406 |
+
betas=(0.9, 0.97),
|
407 |
+
weight_decay=0.01,
|
408 |
+
percentile_clipping=percentile_clipping,
|
409 |
+
) for p in trainable_params
|
410 |
+
}
|
411 |
+
|
412 |
+
# [2] Определяем hook для применения оптимизатора сразу после накопления градиента
|
413 |
+
def optimizer_hook(param):
|
414 |
+
optimizer_dict[param].step()
|
415 |
+
optimizer_dict[param].zero_grad(set_to_none=True)
|
416 |
+
|
417 |
+
# [3] Регистрируем hook для trainable параметров модели
|
418 |
+
for param in trainable_params:
|
419 |
+
param.register_post_accumulate_grad_hook(optimizer_hook)
|
420 |
+
|
421 |
+
# Подготовка через Accelerator
|
422 |
+
unet, optimizer = accelerator.prepare(unet, optimizer_dict)
|
423 |
+
else:
|
424 |
+
if adam8bit:
|
425 |
+
optimizer = bnb.optim.AdamW8bit(
|
426 |
+
params=unet.parameters(),
|
427 |
+
lr=base_learning_rate,
|
428 |
+
betas=(0.9, 0.999),
|
429 |
+
eps=1e-8,
|
430 |
+
weight_decay=0.01
|
431 |
+
)
|
432 |
+
#from torch.optim import AdamW
|
433 |
+
#optimizer = AdamW(
|
434 |
+
# params=unet.parameters(),
|
435 |
+
# lr=base_learning_rate,
|
436 |
+
# betas=(0.9, 0.999),
|
437 |
+
# eps=1e-8,
|
438 |
+
# weight_decay=0.01
|
439 |
+
#)
|
440 |
+
else:
|
441 |
+
optimizer = bnb.optim.Lion8bit(
|
442 |
+
params=unet.parameters(),
|
443 |
+
lr=base_learning_rate,
|
444 |
+
betas=(0.9, 0.97),
|
445 |
+
weight_decay=0.01,
|
446 |
+
percentile_clipping=percentile_clipping,
|
447 |
+
)
|
448 |
+
from transformers import get_constant_schedule_with_warmup
|
449 |
+
|
450 |
+
# warmup
|
451 |
+
num_warmup_steps = num_warmup_steps * world_size
|
452 |
+
|
453 |
+
#lr_scheduler = get_constant_schedule_with_warmup(
|
454 |
+
# optimizer=optimizer,
|
455 |
+
# num_warmup_steps=num_warmup_steps
|
456 |
+
#)
|
457 |
+
from torch.optim.lr_scheduler import LambdaLR
|
458 |
+
def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
|
459 |
+
# Если не используем затухание, возвращаем базовый LR
|
460 |
+
if not use_decay:
|
461 |
+
return base_lr
|
462 |
+
|
463 |
+
# Иначе используем линейный прогрев и косинусное затухание
|
464 |
+
x = step / max_steps
|
465 |
+
percent = 0.05
|
466 |
+
if x < percent:
|
467 |
+
# Линейный прогрев до percent% шагов
|
468 |
+
return min_lr + (base_lr - min_lr) * (x / percent)
|
469 |
+
else:
|
470 |
+
# Косинусное затухание
|
471 |
+
decay_ratio = (x - percent) / (1 - percent)
|
472 |
+
return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
|
473 |
+
|
474 |
+
|
475 |
+
def custom_lr_lambda(step):
|
476 |
+
return lr_schedule(step, total_training_steps*world_size,
|
477 |
+
base_learning_rate, min_learning_rate,
|
478 |
+
(num_warmup_steps>0)) / base_learning_rate
|
479 |
+
|
480 |
+
lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
|
481 |
+
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
482 |
+
|
483 |
+
# --------------------------- Фиксированные семплы для генерации ---------------------------
|
484 |
+
# Примеры фиксированных семплов по размерам
|
485 |
+
fixed_samples = get_fixed_samples_by_resolution(dataset)
|
486 |
+
|
487 |
+
@torch.compiler.disable()
|
488 |
+
@torch.no_grad()
|
489 |
+
def generate_and_save_samples(fixed_samples_cpu, step):
|
490 |
+
"""
|
491 |
+
Генерирует семплы для каждого из разрешений и сохраняет их.
|
492 |
+
|
493 |
+
Args:
|
494 |
+
fixed_samples_cpu: Словарь, где ключи - размеры (width, height),
|
495 |
+
а значения - кортежи (latents, embeddings, text) на CPU.
|
496 |
+
step: Текущий шаг обучения
|
497 |
+
"""
|
498 |
+
original_model = None # Инициализируем, чтобы finally не ругался
|
499 |
+
try:
|
500 |
+
|
501 |
+
original_model = accelerator.unwrap_model(unet)
|
502 |
+
original_model = original_model.to(dtype = dtype)
|
503 |
+
original_model.eval()
|
504 |
+
|
505 |
+
vae.to(device=device, dtype=dtype)
|
506 |
+
vae.eval()
|
507 |
+
|
508 |
+
scheduler.set_timesteps(n_diffusion_steps)
|
509 |
+
|
510 |
+
all_generated_images = []
|
511 |
+
all_captions = []
|
512 |
+
|
513 |
+
for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
|
514 |
+
width, height = size
|
515 |
+
|
516 |
+
sample_latents = sample_latents.to(dtype=dtype)
|
517 |
+
sample_text_embeddings = sample_text_embeddings.to(dtype=dtype)
|
518 |
+
|
519 |
+
# Инициализируем латенты случайным шумом
|
520 |
+
# sample_latents уже в dtype, так что noise будет создан в dtype
|
521 |
+
noise = torch.randn(
|
522 |
+
sample_latents.shape, # Используем форму от sample_latents, которые теперь на GPU и fp16
|
523 |
+
generator=gen,
|
524 |
+
device=device,
|
525 |
+
dtype=sample_latents.dtype
|
526 |
+
)
|
527 |
+
current_latents = noise.clone()
|
528 |
+
|
529 |
+
# Подготовка текстовых эмбеддингов для guidance
|
530 |
+
if guidance_scale > 0:
|
531 |
+
# empty_embeddings должны быть того же типа и на том же устройстве
|
532 |
+
empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
|
533 |
+
text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
|
534 |
+
else:
|
535 |
+
text_embeddings_batch = sample_text_embeddings
|
536 |
+
|
537 |
+
for t in scheduler.timesteps:
|
538 |
+
t_batch = t.repeat(current_latents.shape[0]).to(device) # Убедимся, что t на устройстве
|
539 |
+
|
540 |
+
if guidance_scale > 0:
|
541 |
+
latent_model_input = torch.cat([current_latents] * 2)
|
542 |
+
else:
|
543 |
+
latent_model_input = current_latents
|
544 |
+
|
545 |
+
latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
|
546 |
+
|
547 |
+
# Предсказание шума (UNet)
|
548 |
+
noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
|
549 |
+
|
550 |
+
if guidance_scale > 0:
|
551 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
552 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
553 |
+
|
554 |
+
current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
|
555 |
+
|
556 |
+
#print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
|
557 |
+
# Декодирование через VAE
|
558 |
+
latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
|
559 |
+
decoded = vae.decode(latent_for_vae).sample
|
560 |
+
|
561 |
+
# Преобразуем тензоры в PIL-изображения
|
562 |
+
# Для математики с изображением (нормализация) лучше перейти в fp32
|
563 |
+
decoded_fp32 = decoded.to(torch.float32)
|
564 |
+
for img_idx, img_tensor in enumerate(decoded_fp32):
|
565 |
+
img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
|
566 |
+
# If NaNs or infs are present, print them
|
567 |
+
if np.isnan(img).any():
|
568 |
+
print("NaNs found, saving stoped! Step:", step)
|
569 |
+
save_model = False
|
570 |
+
pil_img = Image.fromarray((img * 255).astype("uint8"))
|
571 |
+
|
572 |
+
max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
|
573 |
+
max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
|
574 |
+
max_w_overall = max(255, max_w_overall)
|
575 |
+
max_h_overall = max(255, max_h_overall)
|
576 |
+
|
577 |
+
padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
|
578 |
+
all_generated_images.append(padded_img)
|
579 |
+
|
580 |
+
caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
|
581 |
+
all_captions.append(caption_text)
|
582 |
+
|
583 |
+
sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
|
584 |
+
pil_img.save(sample_path, "JPEG", quality=96)
|
585 |
+
|
586 |
+
if use_wandb and accelerator.is_main_process:
|
587 |
+
wandb_images = [
|
588 |
+
wandb.Image(img, caption=f"{all_captions[i]}")
|
589 |
+
for i, img in enumerate(all_generated_images)
|
590 |
+
]
|
591 |
+
wandb.log({"generated_images": wandb_images, "global_step": step})
|
592 |
+
|
593 |
+
finally:
|
594 |
+
vae.to("cpu") # Перемещаем VAE обратно на CPU
|
595 |
+
original_model = original_model.to(dtype = dtype)
|
596 |
+
if original_model is not None:
|
597 |
+
del original_model
|
598 |
+
# Очистка переменных, которые являются тензорами и были созданы в функции
|
599 |
+
for var in list(locals().keys()):
|
600 |
+
if isinstance(locals()[var], torch.Tensor):
|
601 |
+
del locals()[var]
|
602 |
+
|
603 |
+
torch.cuda.empty_cache()
|
604 |
+
gc.collect()
|
605 |
+
|
606 |
+
# --------------------------- Генерация сэмплов перед обучением ---------------------------
|
607 |
+
if accelerator.is_main_process:
|
608 |
+
if save_model:
|
609 |
+
print("Генерация сэмплов до старта обучения...")
|
610 |
+
generate_and_save_samples(fixed_samples,0)
|
611 |
+
|
612 |
+
# Модифицируем функцию сохранения модели для поддержки LoRA
|
613 |
+
def save_checkpoint(unet,variant=""):
|
614 |
+
if accelerator.is_main_process:
|
615 |
+
if lora_name:
|
616 |
+
# Сохраняем только LoRA адаптеры
|
617 |
+
save_lora_checkpoint(unet)
|
618 |
+
else:
|
619 |
+
# Сохраняем полную модель
|
620 |
+
if variant!="":
|
621 |
+
accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
|
622 |
+
else:
|
623 |
+
accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
|
624 |
+
unet = unet.to(dtype=dtype)
|
625 |
+
|
626 |
+
# --------------------------- Тренировочный цикл ---------------------------
|
627 |
+
# Для логирования среднего лосса каждые % эпохи
|
628 |
+
if accelerator.is_main_process:
|
629 |
+
print(f"Total steps per GPU: {total_training_steps}")
|
630 |
+
|
631 |
+
epoch_loss_points = []
|
632 |
+
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
|
633 |
+
|
634 |
+
# Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
|
635 |
+
steps_per_epoch = len(dataloader)
|
636 |
+
sample_interval = max(1, steps_per_epoch // sample_interval_share)
|
637 |
+
min_loss = 1.
|
638 |
+
|
639 |
+
# Начинаем с указанной эпохи (полезно при возобновлении)
|
640 |
+
for epoch in range(start_epoch, start_epoch + num_epochs):
|
641 |
+
batch_losses = []
|
642 |
+
batch_grads = []
|
643 |
+
#unet = unet.to(dtype = dtype)
|
644 |
+
batch_sampler.set_epoch(epoch)
|
645 |
+
accelerator.wait_for_everyone()
|
646 |
+
unet.train()
|
647 |
+
print("epoch:",epoch)
|
648 |
+
for step, (latents, embeddings) in enumerate(dataloader):
|
649 |
+
with accelerator.accumulate(unet):
|
650 |
+
if save_model == False and step == 5 :
|
651 |
+
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
652 |
+
print(f"Шаг {step}: {used_gb:.2f} GB")
|
653 |
+
|
654 |
+
#latents = latents.to(dtype = dtype)
|
655 |
+
#embeddings = embeddings.to(dtype = dtype)
|
656 |
+
#print(f"Latents dtype: {latents.dtype}")
|
657 |
+
#print(f"Embeddings dtype: {embeddings.dtype}")
|
658 |
+
#print(f"Noise dtype: {noise.dtype}")
|
659 |
+
|
660 |
+
# Forward pass
|
661 |
+
noise = torch.randn_like(latents, dtype=latents.dtype)
|
662 |
+
|
663 |
+
timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
|
664 |
+
(latents.shape[0],), device=device).long()
|
665 |
+
|
666 |
+
# Добавляем шум к латентам
|
667 |
+
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
|
668 |
+
|
669 |
+
# Используем целевое значение
|
670 |
+
model_pred = unet(noisy_latents, timesteps, embeddings).sample
|
671 |
+
target_pred = scheduler.get_velocity(latents, noise, timesteps)
|
672 |
+
|
673 |
+
# Считаем лосс
|
674 |
+
# Проверяем model_pred на nan/inf
|
675 |
+
#if torch.isnan(model_pred.float()).any() or torch.isinf(model_pred.float()).any():
|
676 |
+
# print(f"Rank {accelerator.process_index}: Found nan/inf in model_pred",model_pred.float())
|
677 |
+
# # Обработка nan/inf значений
|
678 |
+
# model_pred = torch.nan_to_num(model_pred.float(), nan=0.0, posinf=1.0, neginf=-1.0)
|
679 |
+
loss = torch.nn.functional.mse_loss(model_pred, target_pred)
|
680 |
+
|
681 |
+
# Проверяем на nan/inf перед backward
|
682 |
+
if torch.isnan(loss) or torch.isinf(loss):
|
683 |
+
print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
|
684 |
+
loss = torch.zeros_like(loss)
|
685 |
+
|
686 |
+
# Делаем backward через Accelerator
|
687 |
+
accelerator.backward(loss)
|
688 |
+
|
689 |
+
if (global_step % 100 == 0) or (global_step % sample_interval == 0):
|
690 |
+
accelerator.wait_for_everyone()
|
691 |
+
|
692 |
+
grad = 0.0
|
693 |
+
if not fbp:
|
694 |
+
if accelerator.sync_gradients:
|
695 |
+
grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
|
696 |
+
optimizer.step()
|
697 |
+
lr_scheduler.step()
|
698 |
+
optimizer.zero_grad(set_to_none=True)
|
699 |
+
|
700 |
+
# Увеличиваем счетчик глобальных шагов
|
701 |
+
global_step += 1
|
702 |
+
|
703 |
+
# Обновляем прогресс-бар
|
704 |
+
progress_bar.update(1)
|
705 |
+
|
706 |
+
# Логируем метрики
|
707 |
+
if accelerator.is_main_process:
|
708 |
+
if fbp:
|
709 |
+
current_lr = base_learning_rate
|
710 |
+
else:
|
711 |
+
current_lr = lr_scheduler.get_last_lr()[0]
|
712 |
+
batch_losses.append(loss.detach().item())
|
713 |
+
batch_grads.append(grad)
|
714 |
+
|
715 |
+
# Логируем в Wandb
|
716 |
+
if use_wandb:
|
717 |
+
wandb.log({
|
718 |
+
"loss": loss.detach().item(),
|
719 |
+
"learning_rate": current_lr,
|
720 |
+
"epoch": epoch,
|
721 |
+
"grad": grad,
|
722 |
+
"global_step": global_step
|
723 |
+
})
|
724 |
+
|
725 |
+
# Генерируем сэмплы с заданным интервалом
|
726 |
+
if global_step % sample_interval == 0:
|
727 |
+
generate_and_save_samples(fixed_samples,global_step)
|
728 |
+
|
729 |
+
# Выводим текущий лосс
|
730 |
+
avg_loss = np.mean(batch_losses[-sample_interval:])
|
731 |
+
avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
|
732 |
+
print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}")
|
733 |
+
|
734 |
+
if save_model:
|
735 |
+
if avg_loss < min_loss:
|
736 |
+
min_loss = avg_loss
|
737 |
+
save_checkpoint(unet,"fp16")
|
738 |
+
save_checkpoint(unet)
|
739 |
+
if use_wandb:
|
740 |
+
wandb.log({"intermediate_loss": avg_loss})
|
741 |
+
wandb.log({"intermediate_grad": avg_grad})
|
742 |
+
|
743 |
+
|
744 |
+
# По окончании эпохи
|
745 |
+
#accelerator.wait_for_everyone()
|
746 |
+
if accelerator.is_main_process:
|
747 |
+
avg_epoch_loss = np.mean(batch_losses)
|
748 |
+
print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
|
749 |
+
if use_wandb:
|
750 |
+
wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
|
751 |
+
|
752 |
+
# Завершение обучения - сохраняем финальную модель
|
753 |
+
if accelerator.is_main_process:
|
754 |
+
print("Обучение завершено! Сохраняем финальную модель...")
|
755 |
+
# Сохраняем основную модель
|
756 |
+
if save_model:
|
757 |
+
save_checkpoint(unet)
|
758 |
+
print("Готово!")
|
759 |
+
|
760 |
+
# randomize ode timesteps
|
761 |
+
# input_timestep = torch.round(
|
762 |
+
# F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
|
763 |
+
# )
|
764 |
+
|
765 |
+
#def create_distribution(num_points, device=None):
|
766 |
+
# # Диапазон вероятностей на оси x
|
767 |
+
# x = torch.linspace(0, 1, num_points, device=device)
|
768 |
+
|
769 |
+
# Пользовательская функция плотности вероятности
|
770 |
+
# probabilities = -7.7 * ((x - 0.5) ** 2) + 2
|
771 |
+
|
772 |
+
# Нормализация, чтобы сумма равнялась 1
|
773 |
+
# probabilities /= probabilities.sum()
|
774 |
+
|
775 |
+
# return x, probabilities
|
776 |
+
|
777 |
+
#def sample_from_distribution(x, probabilities, n, device=None):
|
778 |
+
# Выбор индексов на основе распределения вероятностей
|
779 |
+
# indices = torch.multinomial(probabilities, n, replacement=True)
|
780 |
+
# return x[indices]
|
781 |
+
|
782 |
+
# Пример использования
|
783 |
+
#num_points = 1000 # Количество точек в диапазоне
|
784 |
+
#n = latents.shape[0] # Количество временных шагов для выборки
|
785 |
+
#x, probabilities = create_distribution(num_points, device=latents.device)
|
786 |
+
#timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
|
787 |
+
|
788 |
+
# Преобразование в формат, подходящий для вашего кода
|
789 |
+
#timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()
|