recoilme committed
Commit 94a2309 · 0 Parent(s): Fresh start

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +39 -0
  2. .gitignore +13 -0
  3. README.md +178 -0
  4. TRAIN.md +44 -0
  5. budget.jpg +3 -0
  6. cherrypick-vavae.ipynb +3 -0
  7. dataset_fromfolder.py +386 -0
  8. model_index.json +3 -0
  9. pipeline_sdxs.py +295 -0
  10. promo.png +3 -0
  11. requirements.txt +11 -0
  12. result_grid.jpg +3 -0
  13. samples/unet_192x384_0.jpg +3 -0
  14. samples/unet_256x384_0.jpg +3 -0
  15. samples/unet_320x384_0.jpg +3 -0
  16. samples/unet_384x192_0.jpg +3 -0
  17. samples/unet_384x256_0.jpg +3 -0
  18. samples/unet_384x320_0.jpg +3 -0
  19. samples/unet_384x384_0.jpg +3 -0
  20. scheduler/scheduler_config.json +3 -0
  21. src/captions_moondream2.ipynb +3 -0
  22. src/captions_moondream2_wd3.ipynb +3 -0
  23. src/captions_qwen2-vl-7b.py +261 -0
  24. src/captions_wd.ipynb +3 -0
  25. src/cherrypick.ipynb +3 -0
  26. src/cuda.ipynb +3 -0
  27. src/dataset_clean.ipynb +3 -0
  28. src/dataset_combine.py +68 -0
  29. src/dataset_fromzip.ipynb +3 -0
  30. src/dataset_imagenet.ipynb +3 -0
  31. src/dataset_laion_coco.ipynb +3 -0
  32. src/dataset_mjnj.ipynb +3 -0
  33. src/dataset_mnist-te.ipynb +3 -0
  34. src/dataset_mnist.ipynb +3 -0
  35. src/dataset_sample.ipynb +3 -0
  36. src/inference.ipynb +3 -0
  37. src/sdxs_create-vavae.ipynb +3 -0
  38. src/sdxs_create.ipynb +3 -0
  39. src/sdxs_create_simple.ipynb +3 -0
  40. src/sdxs_create_unet.ipynb +3 -0
  41. src/sdxs_sdxxs_transfer.ipynb +3 -0
  42. test.ipynb +3 -0
  43. text_encoder/config.json +3 -0
  44. text_encoder/model.fp16.safetensors +3 -0
  45. text_projector/config.json +3 -0
  46. text_projector/model.safetensors +3 -0
  47. tokenizer/special_tokens_map.json +3 -0
  48. tokenizer/tokenizer.json +3 -0
  49. tokenizer/tokenizer_config.json +3 -0
  50. train-Copy1.py +789 -0
.gitattributes ADDED
@@ -0,0 +1,39 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.ipynb filter=lfs diff=lfs merge=lfs -text
*.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,13 @@
# Jupyter Notebook
__pycache__/
*.pyc
.ipynb_checkpoints/
*.ipynb_checkpoints/*
.ipynb_checkpoints/*
src/samples
# cache
cache
datasets
test
wandb
nohup.out
README.md ADDED
@@ -0,0 +1,178 @@
---
license: apache-2.0
pipeline_tag: text-to-image
---

# Simple Diffusion XS

*XS Size, Excess Quality*

Train status, in progress: [wandb](https://wandb.ai/recoilme/unet)

![result](result_grid.jpg)

## Example

```python
import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from tqdm.auto import tqdm
import os

def encode_prompt(prompt, negative_prompt, device, dtype):
    if negative_prompt is None:
        negative_prompt = ""

    with torch.no_grad():
        positive_inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True,
        ).to(device)
        positive_embeddings = text_model.encode_texts(
            positive_inputs.input_ids, positive_inputs.attention_mask
        )
        if positive_embeddings.ndim == 2:
            positive_embeddings = positive_embeddings.unsqueeze(1)
        positive_embeddings = positive_embeddings.to(device, dtype=dtype)

        negative_inputs = tokenizer(
            negative_prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True,
        ).to(device)
        negative_embeddings = text_model.encode_texts(negative_inputs.input_ids, negative_inputs.attention_mask)
        if negative_embeddings.ndim == 2:
            negative_embeddings = negative_embeddings.unsqueeze(1)
        negative_embeddings = negative_embeddings.to(device, dtype=dtype)
    return torch.cat([negative_embeddings, positive_embeddings], dim=0)

def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5):
    with torch.no_grad():
        device, dtype = embeddings.device, embeddings.dtype
        half = embeddings.shape[0] // 2
        latent_shape = (half, 16, height // 8, width // 8)
        latents = torch.randn(latent_shape, device=device, dtype=dtype)
        embeddings = embeddings.repeat_interleave(half, dim=0)

        scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(scheduler.timesteps, desc="Generating"):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = unet(latent_model_input, t, embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def decode_latents(latents, vae, output_type="pil"):
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    with torch.no_grad():
        images = vae.decode(latents).sample
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    if output_type == "pil":
        images = (images * 255).round().astype("uint8")
        images = [Image.fromarray(image) for image in images]
    return images

# Example usage:
if __name__ == "__main__":
    device = "cuda"
    dtype = torch.float16

    prompt = "кот"  # "cat" in Russian; the encoder is multilingual
    negative_prompt = "bad quality"
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
    text_model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
    ).to(device, dtype=dtype).eval()

    embeddings = encode_prompt(prompt, negative_prompt, device, dtype)

    pipeid = "AiArtLab/sdxs"
    variant = "fp16"

    unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
    vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
    scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")

    height, width = 576, 576
    num_inference_steps = 40
    output_folder, project_name = "samples", "sdxs"
    latents = generate_latents(
        embeddings=embeddings,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps
    )

    images = decode_latents(latents, vae)

    os.makedirs(output_folder, exist_ok=True)
    for idx, image in enumerate(images):
        image.save(f"{output_folder}/{project_name}_{idx}.jpg")

    print("Images generated and saved to:", output_folder)
```

## Introduction
*Fast, Lightweight & Multilingual Diffusion for Everyone*

We are **AiArtLab**, a small team of enthusiasts with a limited budget. Our goal is to create a compact and fast model that can be trained on consumer graphics cards (a full training cycle, not LoRA). We chose U-Net for its ability to handle small datasets efficiently and to train quickly even on a 16GB GPU (e.g., an RTX 4080). Our budget was limited to a few thousand dollars, significantly less than competitors like SDXL (tens of millions), so we decided to create a small but efficient model, similar to SD1.5 but built for 2025.

## Encoder Architecture (Text and Images)
We experimented with various encoders and concluded that large models like LLaMA or T5 XXL are unnecessary for high-quality generation. However, we needed an encoder that understands the context of the query, focusing on "prompt understanding" rather than "prompt following." We chose the multilingual encoder Mexma-SigLIP, which supports 80 languages and processes sentences rather than individual tokens. Mexma accepts up to 512 tokens, which creates a large matrix that slows down training, so we used a pooling step to collapse the 512x1152 matrix into a plain 1x1152 vector; specifically, we passed it through a linear text projector to make it compatible with SigLIP image embeddings. This allowed us to align text embeddings with images, potentially leading to a unified multimodal model. This functionality makes it possible to mix image embeddings with textual descriptions in queries. Moreover, the model can be trained without text descriptions, using only images. This should simplify training on videos, where annotation is challenging, and enable more consistent and seamless video generation by feeding in embeddings of previous frames with decay. In the future, we aim to extend the model to 3D/video generation.
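
For illustration, a minimal sketch of this conditioning path (pool MEXMA's 512x1152 output to a single vector, then project it into the SigLIP space), following the repo's own dataset and pipeline code. The 1152-to-1152 projector shape and the CLS-token pooling are assumptions; the real weights live in `text_projector/`.

```python
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

device, dtype = "cuda", torch.float16
tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
encoder = AutoModel.from_pretrained(
    "visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
).to(device).eval()

# Linear projector into the SigLIP embedding space (bias-free, as in text_projector/config.json).
text_projector = nn.Linear(1152, 1152, bias=False).to(device, dtype)

with torch.no_grad():
    tokens = tokenizer(["a cat on a windowsill"], return_tensors="pt",
                       padding="max_length", max_length=512, truncation=True).to(device)
    hidden = encoder.text_model(
        input_ids=tokens.input_ids, attention_mask=tokens.attention_mask
    ).last_hidden_state                         # [batch, 512, 1152]: the "large matrix"
    pooled = hidden[:, 0]                       # pool to a single 1x1152 vector (CLS token, an assumption)
    cond = text_projector(pooled).unsqueeze(1)  # [batch, 1, 1152], fed to the U-Net as cross-attention context
```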

## U-Net Architecture
We chose a smooth channel pyramid: [384, 576, 768, 960] with two layers per block and [4, 6, 8, 10] transformer blocks, using 1152/48 = 24 attention heads. This architecture gives the highest training speed at a model size of around 2 billion parameters (and fits comfortably on an RTX 4080). We believe that thanks to its greater 'depth,' the quality will be on par with SDXL despite the smaller 'size.' The model can be expanded to 4 billion parameters by adding a 1152-channel level, achieving perfect symmetry with the embedding size (which we value for its elegance) and probably 'Flux/MJ level' quality.
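
The configuration below is a rough sketch of how such a U-Net could be declared with diffusers. The block types and `sample_size` are assumptions reconstructed from this description, not the shipped config; load the real one with `UNet2DConditionModel.from_pretrained("AiArtLab/sdxs", subfolder="unet")`.

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=72,                           # 576 px / 8 (VAE downscale factor), assumed
    in_channels=16, out_channels=16,          # 16-channel latents (see VAE section)
    block_out_channels=(384, 576, 768, 960),  # the "smooth channel pyramid"
    layers_per_block=2,
    transformer_layers_per_block=(4, 6, 8, 10),
    down_block_types=("CrossAttnDownBlock2D",) * 4,
    up_block_types=("CrossAttnUpBlock2D",) * 4,
    cross_attention_dim=1152,                 # matches the 1x1152 text/image embedding
    attention_head_dim=48,                    # 1152 / 48 = 24 heads at the widest level
)
print(sum(p.numel() for p in unet.parameters()) / 1e9, "B params")
```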

## VAE Architecture
We chose an unconventional 8x, 16-channel AuraDiffusion VAE, which preserves details, text, and anatomy without the 'haze' characteristic of SD3/Flux. We used a fast version with FFN convolution and observed minor texture damage on fine patterns, which may lower its rating on benchmarks. Upscalers like ESRGAN can address these artifacts. Overall, we believe this VAE is highly underrated.
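
The latent convention used throughout this repo (encode: subtract `shift_factor`, then multiply by `scaling_factor`; decode: invert) looks roughly like the sketch below, assuming the published 16-channel VAE:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs", subfolder="vae", variant="fp16",
                                    torch_dtype=torch.float16).to("cuda").eval()

@torch.no_grad()
def to_latents(pixels):   # pixels: [B, 3, H, W] in [-1, 1]
    z = vae.encode(pixels).latent_dist.mode()
    return (z - vae.config.shift_factor) * vae.config.scaling_factor   # [B, 16, H/8, W/8]

@torch.no_grad()
def to_pixels(latents):
    z = latents / vae.config.scaling_factor + vae.config.shift_factor
    return vae.decode(z).sample.clamp(-1, 1)
```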

## Training Process
### Optimizer
We tested several optimizers (AdamW, Lion, Optimi-AdamW, Adafactor, and AdamW-8bit) and chose AdamW-8bit. Optimi-AdamW demonstrated the smoothest gradient-decay curve, although AdamW-8bit behaves more chaotically. However, its smaller optimizer-state footprint allows for larger batch sizes, maximizing training speed on low-cost GPUs (we used 4xA6000 and 5xL40S for training).
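
For reference, a typical AdamW-8bit setup with bitsandbytes looks like the sketch below; the exact hyperparameters we trained with are not listed here, so treat these values as placeholders.

```python
import bitsandbytes as bnb

optimizer = bnb.optim.AdamW8bit(
    unet.parameters(),   # the U-Net from the previous section
    lr=1e-4,             # starting LR, decayed towards 1e-6 (see below)
    betas=(0.9, 0.999),
    weight_decay=0.01,
)
```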

### Learning Rate
We found that manipulating the decay/warm-up curve has an effect, but not a significant one. The optimal learning rate is often overestimated. Our experiments showed that Adam tolerates a wide range of learning rates. We started at 1e-4 and gradually decreased it to 1e-6 during training. In other words, choosing the correct model architecture is far more critical than tweaking hyperparameters.
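
One simple way to express the 1e-4 to 1e-6 decay is a linear LambdaLR schedule, sketched below; the schedule actually used in training is not specified here.

```python
import torch

total_steps = 100_000  # placeholder
scheduler_lr = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # linear decay of the 1e-4 base LR, floored at a 0.01 multiplier (i.e., 1e-6)
    lr_lambda=lambda step: max(1e-6 / 1e-4, 1.0 - step / total_steps),
)
```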

### Dataset
We trained the model on approximately 1 million images: 60 epochs on ImageNet at 256 resolution (wasted time because of low-quality annotations) and 8 epochs on CaptionEmporium/midjourney-niji-1m-llavanext, plus realistic photos and anime/art at 576 resolution. For annotation we used human prompts, the prompts provided by CaptionEmporium, SmilingWolf's WD-Tagger, and Moondream2, varying prompt length and composition so the model learns different prompting styles. The dataset is extremely small, so the model misses many entities and struggles with unseen concepts such as 'a goose on a bicycle.' The dataset also included many waifu-style images, as we were more interested in how well the model learns human anatomy than in its 'astronaut on horseback' skills. While most descriptions were in English, our tests indicate the model is multilingual.
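
As a hypothetical illustration of the caption mixing described above: each image can carry several candidate captions (human, CaptionEmporium/LLaVA, WD-Tagger, Moondream2) and one is sampled per pass to vary prompt style and length. The field names below are made up.

```python
import random

def pick_caption(sample: dict) -> str:
    candidates = [c for c in (sample.get("human"), sample.get("llava"),
                              sample.get("wd_tags"), sample.get("moondream")) if c]
    return random.choice(candidates) if candidates else ""
```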

## Limitations
- Limited concept coverage due to the extremely small dataset.
- The Image2Image functionality needs further training (we reduced the SigLIP share to 5% to focus on text-to-image training).

## Acknowledgments
- **[Stan](https://t.me/Stangle)** — Key investor. Primary financial support; thank you for believing in us when others called it madness.
- **Captainsaturnus** — Material support.
- **Lovescape** & **Whargarbl** — Moral support.
- **[CaptionEmporium](https://huggingface.co/CaptionEmporium)** — Datasets.

> "We believe the future lies in efficient, compact models. We are grateful for the donations and hope for your continued support."

## Training budget

![budget](budget.jpg)

## Donations

Please contact us if you can provide GPUs or funding for training.

DOGE: DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83
BTC: 3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN

## Contacts

[recoilme](https://t.me/recoilme)
TRAIN.md ADDED
@@ -0,0 +1,44 @@
---
license: apache-2.0
---

Quick setup guide.
Update the system and install git-lfs:

```
apt update
apt install git-lfs
git config --global credential.helper store
```
Upgrade pip and install the required packages:

```
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 -U
pip install flash-attn --no-build-isolation # optional
```
Clone the repository:

```
git clone https://huggingface.co/AiArtLab/sdxs
cd sdxs/
pip install -r requirements.txt
```
Prepare the dataset:

```
mkdir datasets
cd datasets
huggingface-cli download AiArtLab/384 --local-dir 384 --repo-type dataset
```
Log in to the services:

```
huggingface-cli login
wandb login
```
Start training!

```
nohup accelerate launch train.py &
```
budget.jpg ADDED

Git LFS Details

  • SHA256: 9e635ae6cd283805338a7dfa6e1bc90089d612d7f30463668034f3b256fa22a5
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB
cherrypick-vavae.ipynb ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d76d1eb4ba99e75f234a07e86ef0503ebe6ddec695c61308b6d20a5a6eca9f0f
size 16287
dataset_fromfolder.py ADDED
@@ -0,0 +1,386 @@
# pip install flash-attn --no-build-isolation
from datasets import Dataset, load_from_disk, concatenate_datasets
from diffusers import AutoencoderKL
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer
import torch
import os
import gc
import numpy as np
from PIL import Image
from tqdm import tqdm
import random
import json
import shutil
import time
from datetime import timedelta

# ---------------- 1️⃣ Settings ----------------
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 10
min_size = 192
max_size = 384
step = 64
img_share = 0.05
empty_share = 0.05
limit = 0
textemb_full = False
# Main processing procedure
folder_path = "/workspace/d3"
save_path = "/workspace/sdxs/datasets/ds3_384"
os.makedirs(save_path, exist_ok=True)

# Helper to free CUDA memory
def clear_cuda_memory():
    if torch.cuda.is_available():
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"used_gb: {used_gb:.2f} GB")
        torch.cuda.empty_cache()
    gc.collect()

# ---------------- 2️⃣ Load models ----------------
def load_models():
    print("Loading models...")
    vae = AutoencoderKL.from_pretrained("/workspace/sdxs/vae", variant="fp16", torch_dtype=dtype).to(device).eval()
    model = AutoModel.from_pretrained("visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True, optimized=True).to(device).eval()
    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip", use_fast=True)
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
    return vae, model, processor, tokenizer

vae, model, processor, tokenizer = load_models()


# ---------------- 3️⃣ Transforms ----------------
def get_image_transform(min_size=256, max_size=512, step=64):
    def transform(img, dry_run=False):
        # Remember the original image size
        original_width, original_height = img.size

        # 0. Resize so that the longest side equals max_size
        if original_width >= original_height:
            new_width = max_size
            new_height = int(max_size * original_height / original_width)
        else:
            new_height = max_size
            new_width = int(max_size * original_width / original_height)

        if new_height < min_size or new_width < min_size:
            # 1. Resize so that the shortest side equals min_size
            if original_width <= original_height:
                new_width = min_size
                new_height = int(min_size * original_height / original_width)
            else:
                new_height = min_size
                new_width = int(min_size * original_width / original_height)

        # 2. If a side exceeds max_size, prepare to crop (snap to the step grid)
        crop_width = min(max_size, (new_width // step) * step)
        crop_height = min(max_size, (new_height // step) * step)

        # Make sure the crop is not smaller than min_size
        crop_width = max(min_size, crop_width)
        crop_height = max(min_size, crop_height)

        # If only a size pre-computation was requested
        if dry_run:
            return crop_width, crop_height

        # Convert to RGB and resize
        img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)

        # Crop coordinates (account for watermarks: offset one third from the top)
        top = (new_height - crop_height) // 3
        left = 0

        # Crop the image
        img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))

        # Final size after all transforms
        final_width, final_height = img_cropped.size

        # To tensor
        img_tensor = ToTensor()(img_cropped)
        img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
        return img_tensor, img_cropped, final_width, final_height

    return transform

# ---------------- 4️⃣ Encoding helpers ----------------
def encode_images_batch(images, processor, model):
    pixel_values = torch.stack([processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0) for img in images]).to(device, dtype)

    with torch.inference_mode():
        image_embeddings = model.vision_model(pixel_values).pooler_output

    return image_embeddings.unsqueeze(1).cpu().numpy()

def encode_texts_batch(texts, tokenizer, model):
    with torch.inference_mode():
        text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",
                                   max_length=512,
                                   truncation=True).to(device)
        text_embeddings = model.encode_texts(text_tokenized.input_ids, text_tokenized.attention_mask)
    return text_embeddings.unsqueeze(1).cpu().numpy()

def encode_texts_batch_full(texts, tokenizer, model):
    with torch.inference_mode():
        text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length", max_length=512, truncation=True).to(device)
        features = model.text_model(
            input_ids=text_tokenized.input_ids, attention_mask=text_tokenized.attention_mask
        ).last_hidden_state
        features_proj = model.text_projector(features)
    return features_proj.cpu().numpy()

def clean_label(label):
    label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
    return label

def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """
    Prepares a list of labels for classifier-free guidance.

    With probability prob_to_make_empty:
    - The label in the first list is replaced with an empty string.
    - The label in the second list gets a "zero:" prefix.

    Otherwise the labels in both lists stay unchanged.
    """
    labels_for_model = []
    labels_for_logging = []

    for label in original_labels:
        if random.random() < prob_to_make_empty:
            labels_for_model.append("")  # empty string for the model
            labels_for_logging.append(f"zero: {label}")  # prefixed copy for logging
        else:
            labels_for_model.append(label)  # keep the original label for the model
            labels_for_logging.append(label)  # keep the original label for logging

    return labels_for_model, labels_for_logging

def encode_to_latents(images, texts):
    transform = get_image_transform(min_size, max_size, step)

    try:
        # Process the images (all of the same size)
        transformed_tensors = []
        pil_images = []
        widths, heights = [], []

        # Apply the transform to every image
        for img in images:
            try:
                t_img, pil_img, w, h = transform(img)
                transformed_tensors.append(t_img)
                pil_images.append(pil_img)
                widths.append(w)
                heights.append(h)
            except Exception as e:
                print(f"Transform error: {e}")
                continue

        if not transformed_tensors:
            return None

        # Build the batch
        batch_tensor = torch.stack(transformed_tensors).to(device, dtype)

        # Encode the batch
        with torch.no_grad():
            posteriors = vae.encode(batch_tensor).latent_dist.mode()
            latents = (posteriors - vae.config.shift_factor) * vae.config.scaling_factor

        latents_np = latents.cpu().numpy()

        # Check shape consistency
        base_shape = latents_np.shape[1:]  # shape without the batch dimension
        valid_indices = []
        valid_latents = []

        for idx, latent in enumerate(latents_np):
            if latent.shape != base_shape:
                print(f"❌ Shape mismatch at index {idx}: {latent.shape} vs {base_shape}")
                continue
            valid_indices.append(idx)
            valid_latents.append(latent)

        # Filter the data
        valid_pil = [pil_images[i] for i in valid_indices]
        valid_widths = [widths[i] for i in valid_indices]
        valid_heights = [heights[i] for i in valid_indices]

        # Process the texts
        text_labels = [clean_label(texts[i]) for i in valid_indices]
        if random.random() < img_share:
            embeddings = encode_images_batch(valid_pil, processor, model)
            text_labels = [f"img: {text_labels[i]}" for i in valid_indices]
        else:
            model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
            if textemb_full:
                embeddings = encode_texts_batch_full(model_prompts, tokenizer, model)
            else:
                embeddings = encode_texts_batch(model_prompts, tokenizer, model)

        return {
            "vae": np.array(valid_latents),
            "embeddings": embeddings,
            "text": text_labels,
            "width": valid_widths,
            "height": valid_heights
        }

    except Exception as e:
        print(f"Fatal error in encode_to_latents: {e}")
        raise

# ---------------- 5️⃣ Process a folder of images and texts ----------------
def process_folder(folder_path, limit=None):
    """
    Recursively walks the given directory and all subdirectories,
    collecting paths to images and their matching text files.
    """
    image_paths = []
    text_paths = []
    width = []
    height = []
    transform = get_image_transform(min_size, max_size, step)

    # Use os.walk for a recursive traversal
    for root, dirs, files in os.walk(folder_path):
        for filename in files:
            # Check whether the file is an image
            if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                image_path = os.path.join(root, filename)
                try:
                    img = Image.open(image_path)
                except Exception as e:
                    print(f"Error opening {image_path}: {e}")
                    os.remove(image_path)
                    text_path = os.path.splitext(image_path)[0] + ".txt"
                    if os.path.exists(text_path):
                        os.remove(text_path)
                    continue
                # Run the transform only to get the target size
                w, h = transform(img, dry_run=True)
                # Path to the matching text file
                text_path = os.path.splitext(image_path)[0] + ".txt"

                # Add the paths if the text file exists
                if os.path.exists(text_path) and min(w, h) > 0:
                    image_paths.append(image_path)
                    text_paths.append(text_path)
                    width.append(w)
                    height.append(h)

                # Respect the limit on the number of images
                if limit and limit > 0 and len(image_paths) >= limit:
                    print(f"Reached the limit of {limit} images")
                    return image_paths, text_paths, width, height

    print(f"Found {len(image_paths)} images with text descriptions")
    return image_paths, text_paths, width, height

def process_in_chunks(image_paths, text_paths, width, height, chunk_size=50000, batch_size=1):
    total_files = len(image_paths)
    start_time = time.time()
    chunks = range(0, total_files, chunk_size)

    for chunk_idx, start in enumerate(chunks, 1):
        end = min(start + chunk_size, total_files)
        chunk_image_paths = image_paths[start:end]
        chunk_text_paths = text_paths[start:end]
        chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
        chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)

        # Read the texts
        chunk_texts = []
        for text_path in chunk_text_paths:
            try:
                with open(text_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                chunk_texts.append(text)
            except Exception as e:
                print(f"Error reading {text_path}: {e}")
                chunk_texts.append("")

        # Group the images by size
        size_groups = {}
        for i in range(len(chunk_image_paths)):
            size_key = (chunk_widths[i], chunk_heights[i])
            if size_key not in size_groups:
                size_groups[size_key] = {"image_paths": [], "texts": []}
            size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
            size_groups[size_key]["texts"].append(chunk_texts[i])

        # Process each size group separately
        for size_key, group_data in size_groups.items():
            print(f"Processing group of size {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} images")

            group_dataset = Dataset.from_dict({
                "image_path": group_data["image_paths"],
                "text": group_data["texts"]
            })

            # The requested batch_size can be used here, since all images share one size
            processed_group = group_dataset.map(
                lambda examples: encode_to_latents(
                    [Image.open(path) for path in examples["image_path"]],
                    examples["text"]
                ),
                batched=True,
                batch_size=batch_size,
                remove_columns=["image_path"],
                desc=f"Processing size group {size_key[0]}x{size_key[1]}"
            )

            # Save the group results
            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()
            elapsed = time.time() - start_time
            processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]])
            if processed > 0:
                remaining = (elapsed / processed) * (total_files - processed)
                elapsed_str = str(timedelta(seconds=int(elapsed)))
                remaining_str = str(timedelta(seconds=int(remaining)))
                print(f"ETA: elapsed {elapsed_str}, remaining {remaining_str}, progress {processed}/{total_files} ({processed/total_files:.1%})")

# ---------------- 7️⃣ Combine the chunks ----------------
def combine_chunks(temp_path, final_path):
    """Merge the processed chunks into the final dataset"""
    chunks = sorted([
        os.path.join(temp_path, d)
        for d in os.listdir(temp_path)
        if d.startswith("chunk_")
    ])

    datasets = [load_from_disk(chunk) for chunk in chunks]
    combined = concatenate_datasets(datasets)
    combined.save_to_disk(final_path)

    print(f"✅ Dataset successfully saved to: {final_path}")



# Create a temporary folder for the chunks
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

# Collect the file lists
image_paths, text_paths, width, height = process_folder(folder_path, limit)
print(f"Found {len(image_paths)} images in total")

# Chunked processing
process_in_chunks(image_paths, text_paths, width, height, chunk_size=100000, batch_size=batch_size)

# Combine the chunks into the final dataset
combine_chunks(temp_path, save_path)

# Remove the temporary folder
try:
    shutil.rmtree(temp_path)
    print(f"✅ Temporary folder {temp_path} removed")
except Exception as e:
    print(f"⚠️ Error removing the temporary folder: {e}")
model_index.json ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85e884f29f7a6282a634d90350933aa68021326035a0980072667757c3bc9112
size 476
pipeline_sdxs.py ADDED
@@ -0,0 +1,295 @@
from diffusers import DiffusionPipeline
import torch
import torch.nn as nn
import os
from diffusers.utils import BaseOutput
from dataclasses import dataclass
from typing import List, Union, Optional
from PIL import Image
import numpy as np
import json
from safetensors.torch import load_file
from tqdm import tqdm

@dataclass
class SdxsPipelineOutput(BaseOutput):
    images: Union[List[Image.Image], np.ndarray]

class SdxsPipeline(DiffusionPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
        super().__init__()

        # Register components
        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler
        )

        # Get the model path, which is either provided directly or from internal dict
        model_path = None
        if hasattr(self, '_internal_dict') and self._internal_dict.get('_name_or_path'):
            model_path = self._internal_dict.get('_name_or_path')

        # Get device and dtype from existing components
        device = "cuda"
        dtype = torch.float16

        # Always load text_projector, regardless of whether one was provided
        projector_path = None

        # Try to find projector path
        if model_path and os.path.exists(f"{model_path}/text_projector"):
            projector_path = f"{model_path}/text_projector"
        elif os.path.exists("./text_projector"):
            projector_path = "./text_projector"

        if projector_path:
            # Create and load projector
            try:
                with open(f"{projector_path}/config.json", "r") as f:
                    projector_config = json.load(f)

                # Create Linear layer with bias=False
                self.text_projector = nn.Linear(
                    in_features=projector_config["in_features"],
                    out_features=projector_config["out_features"],
                    bias=False
                )

                # Load the state dict using safetensors
                self.text_projector.load_state_dict(load_file(f"{projector_path}/model.safetensors"))
                self.text_projector.to(device=device, dtype=dtype)
                print(f"Successfully loaded text_projector from {projector_path}", device, dtype)
            except Exception as e:
                print(f"Error loading text_projector: {e}")

        self.vae_scale_factor = 8



    def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
        """Encode text prompts into embeddings.

        Returns:
        - text_embeddings: embedding tensor [batch_size, 1, dim], or [2*batch_size, 1, dim] with guidance
        """
        if prompt is None and negative_prompt is None:
            raise ValueError("At least one of prompt or negative_prompt is required")

        # Resolve device and dtype
        device = device or self.device
        dtype = dtype or next(self.unet.parameters()).dtype

        with torch.no_grad():
            # Handle the positive prompt
            if prompt is not None:
                if isinstance(prompt, str):
                    prompt = [prompt]

                text_inputs = self.tokenizer(
                    prompt, return_tensors="pt", padding="max_length",
                    max_length=512, truncation=True
                ).to(device)

                # Get the embeddings
                outputs = self.text_encoder(text_inputs.input_ids, text_inputs.attention_mask)
                last_hidden_state = outputs.last_hidden_state.to(device, dtype=dtype)
                pos_embeddings = self.text_projector(last_hidden_state[:, 0])

                # Add a sequence dimension for batch processing
                if pos_embeddings.ndim == 2:
                    pos_embeddings = pos_embeddings.unsqueeze(1)
            else:
                # Create empty embeddings when there is no positive prompt
                # (useful for some unconditional generation scenarios)
                batch_size = len(negative_prompt) if isinstance(negative_prompt, list) else 1
                pos_embeddings = torch.zeros(
                    batch_size, 1, self.unet.config.cross_attention_dim,
                    device=device, dtype=dtype
                )

            # Handle the negative prompt
            if negative_prompt is not None:
                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt]

                # Make sure the negative and positive prompt batch sizes match
                if prompt is not None and len(negative_prompt) != len(prompt):
                    neg_batch_size = len(prompt)
                    if len(negative_prompt) == 1:
                        negative_prompt = negative_prompt * neg_batch_size
                    else:
                        negative_prompt = negative_prompt[:neg_batch_size]

                neg_inputs = self.tokenizer(
                    negative_prompt, return_tensors="pt", padding="max_length",
                    max_length=512, truncation=True
                ).to(device)

                neg_outputs = self.text_encoder(neg_inputs.input_ids, neg_inputs.attention_mask)
                neg_last_hidden_state = neg_outputs.last_hidden_state.to(device, dtype=dtype)
                neg_embeddings = self.text_projector(neg_last_hidden_state[:, 0])

                if neg_embeddings.ndim == 2:
                    neg_embeddings = neg_embeddings.unsqueeze(1)

                # Concatenate for classifier-free guidance
                text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
            else:
                # Without a negative prompt, use zero embeddings
                batch_size = pos_embeddings.shape[0]
                neg_embeddings = torch.zeros_like(pos_embeddings)
                text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)

        return text_embeddings.to(device=device, dtype=dtype)

    @torch.no_grad()
    def generate_latents(
        self,
        text_embeddings,
        height: int = 576,
        width: int = 576,
        num_inference_steps: int = 40,
        guidance_scale: float = 5.0,
        latent_channels: int = 16,
        batch_size: int = 1,
        generator=None,
    ):
        """Generate latents from the prompt embeddings."""
        device = self.device
        dtype = next(self.unet.parameters()).dtype

        # Check the embedding batch size
        do_classifier_free_guidance = guidance_scale > 0
        embedding_dim = text_embeddings.shape[0] // 2 if do_classifier_free_guidance else text_embeddings.shape[0]

        if batch_size > embedding_dim:
            # Repeat the embeddings up to the requested batch size
            if do_classifier_free_guidance:
                neg_embeds, pos_embeds = text_embeddings.chunk(2)
                neg_embeds = neg_embeds.repeat(batch_size // embedding_dim, 1, 1)
                pos_embeds = pos_embeds.repeat(batch_size // embedding_dim, 1, 1)
                text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
            else:
                text_embeddings = text_embeddings.repeat(batch_size // embedding_dim, 1, 1)

        # Set up the timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        # Initialize the latents with the given seed
        latent_shape = (
            batch_size,
            latent_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor
        )
        latents = torch.randn(
            latent_shape,
            device=device,
            dtype=dtype,
            generator=generator
        )

        # Diffusion loop
        for t in tqdm(self.scheduler.timesteps, desc="Generating"):
            # Prepare the model input
            if do_classifier_free_guidance:
                latent_input = torch.cat([latents] * 2)
            else:
                latent_input = latents

            latent_input = self.scheduler.scale_model_input(latent_input, t)

            # Predict the noise
            noise_pred = self.unet(latent_input, t, text_embeddings).sample

            # Apply guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # Update the latents
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents, output_type="pil"):
        """Decode latents into images."""
        # Undo the latent normalization
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        # Decode
        with torch.no_grad():
            images = self.vae.decode(latents).sample

        # Normalize the images
        images = (images / 2 + 0.5).clamp(0, 1)

        # Convert to the requested format
        if output_type == "pil":
            images = images.cpu().permute(0, 2, 3, 1).float().numpy()
            images = (images * 255).round().astype("uint8")
            return [Image.fromarray(image) for image in images]
        else:
            return images.cpu().permute(0, 2, 3, 1).float().numpy()

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 576,
        width: int = 576,
        num_inference_steps: int = 40,
        guidance_scale: float = 5.0,
        latent_channels: int = 16,
        output_type: str = "pil",
        return_dict: bool = True,
        batch_size: int = 1,
        seed: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        text_embeddings: Optional[torch.FloatTensor] = None,
    ):
        """Generate images from text prompts or precomputed embeddings."""
        device = self.device

        # Set up a seeded generator for reproducibility
        generator = None
        if seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)

        # Compute the embeddings if they were not provided
        if text_embeddings is None:
            if prompt is None and negative_prompt is None:
                raise ValueError("Either prompt, negative_prompt or text_embeddings must be provided")

            # Compute the embeddings
            text_embeddings = self.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                device=device
            )
        else:
            # Make sure the embeddings are on the right device
            text_embeddings = text_embeddings.to(device)

        # Generate the latents
        latents = self.generate_latents(
            text_embeddings=text_embeddings,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            latent_channels=latent_channels,
            batch_size=batch_size,
            generator=generator
        )

        # Decode the latents into images
        images = self.decode_latents(latents, output_type=output_type)

        if not return_dict:
            return images

        return SdxsPipelineOutput(images=images)
promo.png ADDED

Git LFS Details

  • SHA256: 73b330e4d3677d91a81220b50c230733bf0167e536f4148d280cd79861ecc161
  • Pointer size: 132 Bytes
  • Size of remote file: 4.74 MB
requirements.txt ADDED
@@ -0,0 +1,11 @@
# torch>=2.6.0
# torchvision>=0.21.0
# torchaudio>=2.6.0
diffusers>=0.32.2
accelerate>=1.5.2
datasets>=3.5.0
matplotlib>=3.10.1
wandb>=0.19.8
huggingface_hub>=0.29.3
bitsandbytes>=0.45.4
transformers
result_grid.jpg ADDED

Git LFS Details

  • SHA256: 61e4d9c26e8629fc743c1bc1aee9fe6fe7ddb995a8dd77f74f19a77c14011c62
  • Pointer size: 132 Bytes
  • Size of remote file: 6.64 MB
samples/unet_192x384_0.jpg ADDED

Git LFS Details

  • SHA256: 2d23201df4a727a74237b908e00a387e9f47b3147431de5f0da52a2e92676b0c
  • Pointer size: 130 Bytes
  • Size of remote file: 34.4 kB
samples/unet_256x384_0.jpg ADDED

Git LFS Details

  • SHA256: adf02a8971642efe76ff7ab3acdd9d3c4783f5f58a770f567c627d86abd1ea5d
  • Pointer size: 130 Bytes
  • Size of remote file: 46.9 kB
samples/unet_320x384_0.jpg ADDED

Git LFS Details

  • SHA256: 86179dcdab6b10dc43a4907e4966c38089d6ef30fd6253afe7b40afba0ea73a5
  • Pointer size: 130 Bytes
  • Size of remote file: 48.2 kB
samples/unet_384x192_0.jpg ADDED

Git LFS Details

  • SHA256: e6bd2b91ae9abede7ec2dfebf7031a353ece981b82559b65a04b1e263d9dd46b
  • Pointer size: 130 Bytes
  • Size of remote file: 38.4 kB
samples/unet_384x256_0.jpg ADDED

Git LFS Details

  • SHA256: 48c2cedea26fe197993fffadaa36477c077fb827850370178756bc6f15a6cfa8
  • Pointer size: 130 Bytes
  • Size of remote file: 42 kB
samples/unet_384x320_0.jpg ADDED

Git LFS Details

  • SHA256: 511a1a4bac01631b023e46e09b6054e8173fc0237bc04f2554cb6c107da25518
  • Pointer size: 130 Bytes
  • Size of remote file: 61.6 kB
samples/unet_384x384_0.jpg ADDED

Git LFS Details

  • SHA256: cbd11687741b968e66aee70793455765cfbc6a036ae30e59f14c6fd816a8da1b
  • Pointer size: 130 Bytes
  • Size of remote file: 37.2 kB
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e919ad3cde5f0bdf9529c68ee7c3306b1ceef40245778d29050d58ebc074158
size 507
src/captions_moondream2.ipynb ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd389dacb701c76713fa256b68a05972b89cf80c8d2fafef341abccbba826765
size 4999
src/captions_moondream2_wd3.ipynb ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c1519f45fc46f644f46631acc4250c5b558a9d697e47548a00bb3eeedbb14e75
size 9956
src/captions_qwen2-vl-7b.py ADDED
@@ -0,0 +1,261 @@
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
5
+ import numpy as np
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import argparse
9
+ import gc
10
+
11
+ # Configuration options
12
+ PRINT_CAPTIONS = False # Print captions to the console during inference
13
+ PRINT_CAPTIONING_STATUS = False # Print captioning file status to the console
14
+ OVERWRITE = True # Allow overwriting existing caption files
15
+ PREPEND_STRING = "" # Prefix string to prepend to the generated caption
16
+ APPEND_STRING = "" # Suffix string to append to the generated caption
17
+ STRIP_LINEBREAKS = True # Remove line breaks from generated captions before saving
18
+ DEFAULT_SAVE_FORMAT = ".txt" # Default format for saving captions
19
+
20
+ # Image resizing options
21
+ MAX_WIDTH = 512 # Set to 0 or less to ignore
22
+ MAX_HEIGHT = 512 # Set to 0 or less to ignore
23
+
24
+ # Generation parameters
25
+ REPETITION_PENALTY = 1.3 # Penalty for repeating phrases, float ~1.5
26
+ TEMPERATURE = 0.7 # Sampling temperature to control randomness
27
+ TOP_K = 50 # Top-k sampling to limit number of potential next tokens
28
+
29
+ # Default values for input folder, output folder, prompt, and save format
30
+ DEFAULT_INPUT_FOLDER = Path(__file__).parent / "input"
31
+ DEFAULT_OUTPUT_FOLDER = DEFAULT_INPUT_FOLDER
32
+ DEFAULT_PROMPT = "In two medium sentences, caption the key aspects of this image."
33
+
34
+ #os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
35
+
36
+ # Function to parse command-line arguments
37
+ def parse_arguments():
38
+ parser = argparse.ArgumentParser(description="Process images and generate captions using Qwen model.")
39
+ parser.add_argument("--input_folder", type=str, default=DEFAULT_INPUT_FOLDER, help="Path to the input folder containing images.")
40
+ parser.add_argument("--output_folder", type=str, default=DEFAULT_OUTPUT_FOLDER, help="Path to the output folder for saving captions.")
41
+ parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for generating the caption.")
42
+ parser.add_argument("--save_format", type=str, default=DEFAULT_SAVE_FORMAT, help="Format for saving captions (e.g., .txt, .md, .json).")
43
+ parser.add_argument("--max_width", type=int, default=MAX_WIDTH, help="Maximum width for resizing images (default: no resizing).")
44
+ parser.add_argument("--max_height", type=int, default=MAX_HEIGHT, help="Maximum height for resizing images (default: no resizing).")
45
+ parser.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY, help="Penalty for repetition during caption generation (default: 1.10).")
46
+ parser.add_argument("--temperature", type=float, default=TEMPERATURE, help="Sampling temperature for generation (default: 0.7).")
47
+ parser.add_argument("--top_k", type=int, default=TOP_K, help="Top-k sampling during generation (default: 50).")
48
+ return parser.parse_args()
49
+
50
+ # Function to ignore images that don't have output files yet
51
+ def filter_images_without_output(input_folder, save_format):
52
+ images_to_caption = []
53
+ skipped_images = 0
54
+ total_images = 0
55
+
56
+ for root, _, files in os.walk(input_folder):
57
+ for file in files:
58
+ if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
59
+ total_images += 1
60
+ image_path = os.path.join(root, file)
61
+ output_path = os.path.splitext(image_path)[0] + save_format
62
+ if not OVERWRITE and os.path.exists(output_path):
63
+ skipped_images += 1
64
+ else:
65
+ images_to_caption.append(image_path)
66
+
67
+ return images_to_caption, total_images, skipped_images
68
+
69
+ # Function to save caption to a file
70
+ def save_caption_to_file(image_path, caption, save_format):
71
+ txt_file_path = os.path.splitext(image_path)[0] + save_format # Same name, but with chosen save format
72
+ caption = PREPEND_STRING + caption + APPEND_STRING # Apply prepend/append strings
73
+
74
+ with open(txt_file_path, "w") as txt_file:
75
+ txt_file.write(caption)
76
+
77
+ if PRINT_CAPTIONING_STATUS:
78
+ print(f"Caption for {os.path.abspath(image_path)} saved in {save_format} format.")
79
+
80
+
81
+ # Function to process all images recursively in a folder
82
+ def process_images_in_folder(images_to_caption, prompt, save_format, max_width=MAX_WIDTH, max_height=MAX_HEIGHT, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
83
+ for image_path in tqdm(images_to_caption, desc="Processing Images"):
84
+ try:
85
+ # Load and preprocess the image
86
+ image = Image.open(image_path).convert("RGB")
87
+ image = resize_image_proportionally(image, max_width, max_height)
88
+
89
+ # Read the tags
90
+ tags_filename = str(Path(image_path).with_suffix('.tag'))
91
+ with open(tags_filename, "r") as file:
92
+ tag_caption = file.read()
93
+ #prompt = f'In one medium sentence, caption the key aspects of this image. Be direct for NSFW. You may use following tags as context: {tag_caption}'
94
+ prompt = f'In two medium sentence, caption the key aspects of this image. Be direct for NSFW, Don\'t be shy in your expressions like pussy, anal. Add approximate age for humans like: a child,young girl or 40 years old woman. You may use following tags as context if they are relevant: {tag_caption}'
95
+
96
+ #print("prompt",prompt)
97
+
98
+ # Generate the caption
99
+ caption = qwen_caption(image, prompt, repetition_penalty, temperature, top_k)
100
+ save_caption_to_file(image_path, caption, save_format)
101
+
102
+ if PRINT_CAPTIONS:
103
+ print(f"Caption for {os.path.abspath(image_path)}: {caption}")
104
+
105
+ # Free memory
106
+ del image, tag_caption, caption
107
+ torch.cuda.empty_cache()
108
+ gc.collect()
109
+
110
+ except Exception as e:
111
+ print(f"Error processing {os.path.abspath(image_path)}: {str(e)}")
112
+ torch.cuda.empty_cache()
113
+ gc.collect()
114
+
115
+ # Resize the image proportionally based on max width and/or max height.
116
+ def resize_image_proportionally(image, max_width=None, max_height=None):
117
+ """
118
+ If both max_width and max_height are provided, the image is resized to fit within both dimensions,
119
+ keeping the aspect ratio intact. If only one dimension is provided, the image is resized based on that dimension.
120
+ """
121
+ if (max_width is None or max_width <= 0) and (max_height is None or max_height <= 0):
122
+ return image # No resizing if both dimensions are not provided or set to 0 or less
123
+
124
+ original_width, original_height = image.size
125
+ aspect_ratio = original_width / original_height
126
+
127
+ # Determine the new dimensions
128
+ if max_width and not max_height:
129
+ # Resize based on width
130
+ new_width = max_width
131
+ new_height = int(new_width / aspect_ratio)
132
+ elif max_height and not max_width:
133
+ # Resize based on height
134
+ new_height = max_height
135
+ new_width = int(new_height * aspect_ratio)
136
+ else:
137
+ # Resize based on both width and height, keeping the aspect ratio
138
+ new_width = max_width
139
+ new_height = max_height
140
+
141
+ # Adjust the dimensions proportionally to the aspect ratio
142
+ if new_width / aspect_ratio > new_height:
143
+ new_width = int(new_height * aspect_ratio)
144
+ else:
145
+ new_height = int(new_width / aspect_ratio)
146
+
147
+ # Resize the image using LANCZOS (equivalent to ANTIALIAS in older versions)
148
+ resized_image = image.resize((new_width, new_height))
149
+ return resized_image
150
+
151
+ # Generate a caption for the provided image using the Ertugrul/Qwen2-VL-7B-Captioner-Relaxed model
152
+ def qwen_caption(image, prompt, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
153
+ if not isinstance(image, Image.Image):
154
+ image = Image.fromarray(np.uint8(image))
155
+
156
+ # Prepare the conversation content, which includes the image and the text prompt
157
+ conversation = [
158
+ {
159
+ "role": "user",
160
+ "content": [
161
+ {
162
+ "type": "image",
163
+ },
164
+ {"type": "text", "text": prompt},
165
+ ],
166
+ }
167
+ ]
168
+
169
+ # Apply the chat template to format the message for processing
170
+ text_prompt = qwen_processor.apply_chat_template(
171
+ conversation, add_generation_prompt=True
172
+ )
173
+
174
+ # Prepare the inputs for the model, padding as necessary and converting to tensors
175
+ inputs = qwen_processor(
176
+ text=[text_prompt],
177
+ images=[image],
178
+ padding=True,
179
+ return_tensors="pt",
180
+ )
181
+ inputs = inputs.to("cuda")
182
+
183
+ with torch.no_grad():
184
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
185
+ output_ids = qwen_model.generate(
186
+ **inputs,
187
+ max_new_tokens=384,
188
+ do_sample=True,
189
+ temperature=temperature,
190
+ use_cache=True,
191
+ top_k=top_k,
192
+ repetition_penalty=repetition_penalty,
193
+ )
194
+
195
+ # Trim the generated IDs to remove the input part from the output
196
+ generated_ids_trimmed = [
197
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
198
+ ]
199
+
200
+ # Decode the trimmed output into text, skipping special tokens
201
+ output_text = qwen_processor.batch_decode(
202
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
203
+ )
204
+
205
+ # Strip line breaks if the option is enabled
206
+ if STRIP_LINEBREAKS:
207
+ output_text[0] = output_text[0].replace('\n', ' ')
208
+
209
+ # Free memory
210
+ del inputs, output_ids, generated_ids_trimmed
211
+ torch.cuda.empty_cache()
212
+ gc.collect()
213
+
214
+ return output_text[0]
215
+
216
+ # Run the script
217
+ if __name__ == "__main__":
218
+ args = parse_arguments()
219
+ input_folder = args.input_folder
220
+ output_folder = args.output_folder
221
+ prompt = args.prompt
222
+ save_format = args.save_format
223
+ max_width = args.max_width
224
+ max_height = args.max_height
225
+ repetition_penalty = args.repetition_penalty
226
+ temperature = args.temperature
227
+ top_k = args.top_k
228
+
229
+ # Define model_id
230
+ model_id = "Ertugrul/Qwen2-VL-7B-Captioner-Relaxed"
231
+
232
+ # Filter images before loading the model
233
+ images_to_caption, total_images, skipped_images = filter_images_without_output(input_folder, save_format)
234
+
235
+ # Print summary of found, skipped, and to-be-processed images
236
+ print(f"\nFound {total_images} image{'s' if total_images != 1 else ''}.")
237
+ if not OVERWRITE:
238
+ print(f"{skipped_images} image{'s' if skipped_images != 1 else ''} already have captions with format {save_format}, skipping.")
239
+ print(f"\nCaptioning {len(images_to_caption)} image{'s' if len(images_to_caption) != 1 else ''}.\n\n")
240
+
241
+ # Only load the model if there are images to caption
242
+ if len(images_to_caption) == 0:
243
+ print("No images to process. Exiting.\n\n")
244
+ else:
245
+ # Initialize the Ertugrul/Qwen2-VL-7B-Captioner-Relaxed model
246
+ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
247
+ model_id, torch_dtype=torch.bfloat16, device_map="auto"
248
+ )
249
+ qwen_processor = AutoProcessor.from_pretrained(model_id)
250
+
251
+ # Process the images with optional resizing and caption generation
252
+ process_images_in_folder(
253
+ images_to_caption,
254
+ prompt,
255
+ save_format,
256
+ max_width=max_width,
257
+ max_height=max_height,
258
+ repetition_penalty=repetition_penalty,
259
+ temperature=temperature,
260
+ top_k=top_k
261
+ )
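+
+ # Example invocation (a sketch only: the actual flag names come from parse_arguments(),
+ # defined earlier in this file, and may differ from the attribute names used above):
+ #   python src/captions_qwen2-vl-7b.py --input_folder ./images --output_folder ./images \
+ #       --save_format .txt --max_width 1024 --max_height 1024 --temperature 0.7 --top_k 50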
src/captions_wd.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b46993285995bdc52e69d82e900f239ebafd9dc924be046e80372639c8796ed8
3
+ size 29850
src/cherrypick.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c94efa8bf16993a1d15ef5c455fadf92084dd9349fdd8a9b5a9ca66fe869f565
3
+ size 48464
src/cuda.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cba78a89ed8649cb384e15b3b241df0a1aa35b36ca89a5453e59fc4b875ec0f1
3
+ size 1503
src/dataset_clean.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1f0686dd3fe000d5bc00ff5d676f173793699630b0017f53530f3dcb1ec474e
3
+ size 5085
src/dataset_combine.py ADDED
@@ -0,0 +1,68 @@
1
+ import os
2
+ import shutil
3
+ from datasets import load_from_disk, concatenate_datasets
4
+
5
+ def combine_datasets(main_dataset_path, datasets_to_add):
6
+ """
7
+ Merges the given datasets into the main dataset.
8
+
9
+ Args:
10
+ main_dataset_path (str): Path to the main dataset that the data should be added to
11
+ datasets_to_add (list): List of paths to the datasets to add
12
+
13
+ Returns:
14
+ Dataset: The combined dataset
15
+ """
16
+ # Load the main dataset
17
+ try:
18
+ main_dataset = load_from_disk(main_dataset_path)
19
+ print(f"Загружен основной датасет: {main_dataset_path} ({len(main_dataset)} записей)")
20
+ except Exception as e:
21
+ print(f"Ошибка загрузки основного датасета: {e}")
22
+ return None
23
+
24
+ # All datasets to merge
25
+ all_datasets = [main_dataset]
26
+
27
+ # Load and append each additional dataset
28
+ for path in datasets_to_add:
29
+ try:
30
+ ds = load_from_disk(path)
31
+ all_datasets.append(ds)
32
+ print(f"Добавлен датасет: {path} ({len(ds)} записей)")
33
+ except Exception as e:
34
+ print(f"Ошибка загрузки датасета {path}: {e}")
35
+
36
+ # Merge all datasets
37
+ print(f"Merging {len(all_datasets)} datasets...")
38
+ combined = concatenate_datasets(all_datasets)
39
+
40
+ # Create a temporary directory based on the main dataset name
41
+ temp_dir = f"{main_dataset_path}_temp"
42
+
43
+ # Remove the temporary directory if it already exists
44
+ if os.path.exists(temp_dir):
45
+ shutil.rmtree(temp_dir)
46
+
47
+ try:
48
+ # Save to the temporary directory
49
+ print(f"Saving to the temporary directory {temp_dir}...")
50
+ combined.save_to_disk(temp_dir)
51
+
52
+ # Remove the old directory and move the new one into its place
53
+ print(f"Updating the main dataset...")
54
+ #if os.path.exists(main_dataset_path):
55
+ # shutil.rmtree(main_dataset_path)
56
+ #shutil.copytree(temp_dir, main_dataset_path)
57
+
58
+ # Remove the temporary directory after a successful copy
59
+ #shutil.rmtree(temp_dir)
60
+
61
+ print(f"✅ Объединенный датасет ({len(combined)} записей) успешно сохранен в: {main_dataset_path}")
62
+ except Exception as e:
63
+ print(f"Ошибка при сохранении датасета: {e}")
64
+ print(f"Временные данные сохранены в: {temp_dir}")
65
+
66
+ return combined
67
+
68
+ combine_datasets("/workspace/sdxs/datasets/384_temp", ["/workspace/sdxs/datasets/ds3_384"])
src/dataset_fromzip.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d12f225412aeed4f5ff6cd2dc23db6bcdfd944ff00fe6f0d9c2e8fe0ec426ee
3
+ size 6167
src/dataset_imagenet.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92ea6bc9e4033a778b9e36defff6adf481baae0d8b0a3fa537313df6fb5b4472
3
+ size 318505
src/dataset_laion_coco.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31395e7f40ef370971b523fb9d9ab56b404ca8cc1e8e932cc602beaf72140411
3
+ size 25403
src/dataset_mjnj.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cf317c438de242a8cc0c7d710c00ceec53e887108b081235a1fb05dae0074b0
3
+ size 23158
src/dataset_mnist-te.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac1571369244de9ff15d4b1785e962e06521630fa1be32f0471175e42ef00630
3
+ size 34388
src/dataset_mnist.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c644e111748cb374d2fb9fec28ef99a5ed616898100e689cd02c6ba80b3431a7
3
+ size 33829
src/dataset_sample.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac0384c01b5ed29625df6ab7c2da36bbf9b7b9beb4ba83746eb6c00fbd6046e1
3
+ size 1986940
src/inference.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fdead0e35dd039c20314c1f7f8579c92b2a891a310965ffdec6002fd8a78c00
3
+ size 2147113
src/sdxs_create-vavae.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11c2151ba855c0c0fda1e58c295f56612843c5b42aecd779cdb3a03b3802b991
3
+ size 9794
src/sdxs_create.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fd812a7bc5233c0c2e3932fe11c5ba132e5d0389a505726fa54a95c26b42edf
3
+ size 7417
src/sdxs_create_simple.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93e0cdf53493f39cd1b8b76f41055aee1e8377e128446d03ccf524e0bb0dcd00
3
+ size 51335
src/sdxs_create_unet.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e2dec0ba7a9a8d4aaaeaad2201dd660ec266cb887d7f8eb127ffbe8c7d80c4f
3
+ size 35930
src/sdxs_sdxxs_transfer.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57db74be42dbf73bc551cf86b4302dd2717555280e26325e599ca89f51b4916e
3
+ size 168192
test.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bca9de56c7bdda4a032e2c84d15fd5dfc8108aea08fc3186f8203f428b966f8
3
+ size 5148457
text_encoder/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87131a858ee394af6afae023f733cdebc36eda2ccbed27c36bc887cfae427392
3
+ size 721
text_encoder/model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:107fe15da52fe6d13d877512fa36861d1100534d1b9b88015ad9fd017db095a7
3
+ size 1119825680
text_projector/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae2f211593cd2cc736bf8617bcb0a5e6abd4db0265170de82ae03b7a6664feda
3
+ size 83
text_projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e7060a387b4a6419f9d1d852759cb5b94541a1845e996f6062a07462d8b7b6a
3
+ size 2359384
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c785abebea9ae3257b61681b4e6fd8365ceafde980c21970d001e834cf10835
3
+ size 964
tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ffb37461c391f096759f4a9bbbc329da0f36952f88bab061fcf84940c022e98
3
+ size 17082999
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccf223ba3d5b3cc7fa6c3bf451f3bb40557a5c92b0aa33f63d17802ff1a96fd9
3
+ size 1178
train-Copy1.py ADDED
@@ -0,0 +1,789 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader, Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from collections import defaultdict
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
11
+ from accelerate import Accelerator
12
+ from datasets import load_from_disk
13
+ from tqdm import tqdm
14
+ from PIL import Image,ImageOps
15
+ import wandb
16
+ import random
17
+ import gc
18
+ from accelerate.state import DistributedType
19
+ from torch.distributed import broadcast_object_list
20
+ from torch.utils.checkpoint import checkpoint
21
+ from diffusers.models.attention_processor import AttnProcessor2_0
22
+ from datetime import datetime
23
+ import bitsandbytes as bnb
24
+
25
+ # --------------------------- Parameters ---------------------------
26
+ ds_path = "datasets/384"
27
+ batch_size = 50
28
+ base_learning_rate = 3e-5
29
+ min_learning_rate = 3e-6
30
+ num_epochs = 10
31
+ num_warmup_steps = 1000
32
+ project = "unet"
33
+ use_wandb = True
34
+ save_model = True
35
+ sample_interval_share = 5 # samples/save per epoch
36
+ fbp = False # fused backward pass
37
+ adam8bit = True
38
+ percentile_clipping = 97 # Lion
39
+ torch_compile = False
40
+ unet_gradient = True
41
+ clip_sample = False #Scheduler
42
+ fixed_seed = False
43
+ shuffle = True
44
+ dtype = torch.float32
45
+ steps_offset = 1 # Scheduler
46
+ limit = 0
47
+ checkpoints_folder = ""
48
+ mixed_precision = "no"
49
+ accelerator = Accelerator(mixed_precision=mixed_precision)
50
+ device = accelerator.device
51
+
52
+ # Diffusion parameters
53
+ n_diffusion_steps = 50
54
+ samples_to_generate = 12
55
+ guidance_scale = 5
56
+
57
+ # Folders for saving results
58
+ generated_folder = "samples"
59
+ os.makedirs(generated_folder, exist_ok=True)
60
+
61
+ # Seed setup for reproducibility
62
+ current_date = datetime.now()
63
+ seed = int(current_date.strftime("%Y%m%d"))
64
+ if fixed_seed:
65
+ torch.manual_seed(seed)
66
+ np.random.seed(seed)
67
+ random.seed(seed)
68
+ if torch.cuda.is_available():
69
+ torch.cuda.manual_seed_all(seed)
70
+
71
+ #torch.backends.cuda.matmul.allow_tf32 = True
72
+ #torch.backends.cudnn.allow_tf32 = True
73
+ # --------------------------- LoRA parameters ---------------------------
74
+ # pip install peft
75
+ lora_name = "" #"nusha" # Name used for saving/loading LoRA adapters
76
+ lora_rank = 32 # LoRA rank (lower = more compact model)
77
+ lora_alpha = 64 # LoRA alpha, controls the scaling
78
+
79
+ print("init")
80
+
81
+ # --------------------------- WandB initialization ---------------------------
82
+ if use_wandb and accelerator.is_main_process:
83
+ wandb.init(project=project+lora_name, config={
84
+ "batch_size": batch_size,
85
+ "base_learning_rate": base_learning_rate,
86
+ "num_epochs": num_epochs,
87
+ "fbp": fbp,
88
+ "adam8bit": adam8bit,
89
+ })
90
+
91
+ # Enable Flash Attention 2/SDPA
92
+ torch.backends.cuda.enable_flash_sdp(True)
93
+ # --------------------------- Accelerator initialization --------------------
94
+ gen = torch.Generator(device=device)
95
+ gen.manual_seed(seed)
96
+
97
+ # --------------------------- Model loading ---------------------------
98
+ # The VAE is loaded on the CPU to save GPU memory
99
+ vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
100
+
101
+ # DDPMScheduler with V-Prediction and Zero-SNR
102
+ scheduler = DDPMScheduler(
103
+ num_train_timesteps=1000, # Full timestep schedule for training
104
+ prediction_type="v_prediction", # V-Prediction
105
+ rescale_betas_zero_snr=True, # Enable Zero-SNR
106
+ clip_sample = clip_sample,
107
+ steps_offset = steps_offset
108
+ )
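+ # Note: rescale_betas_zero_snr=True makes the final training timestep carry zero signal
+ # (pure noise); pairing it with prediction_type="v_prediction" keeps the training target
+ # well-defined at that step, which is the combination recommended for zero-terminal-SNR schedules.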
109
+
110
+ class DistributedResolutionBatchSampler(Sampler):
111
+ def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
112
+ self.dataset = dataset
113
+ self.batch_size = max(1, batch_size // num_replicas)
114
+ self.num_replicas = num_replicas
115
+ self.rank = rank
116
+ self.shuffle = shuffle
117
+ self.drop_last = drop_last
118
+ self.epoch = 0
119
+
120
+ # Use numpy for speed
121
+ try:
122
+ widths = np.array(dataset["width"])
123
+ heights = np.array(dataset["height"])
124
+ except KeyError:
125
+ widths = np.zeros(len(dataset))
126
+ heights = np.zeros(len(dataset))
127
+
128
+ # Build unique keys for the sizes
129
+ self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
130
+
131
+ # Group indices by size using numpy
132
+ self.size_groups = {}
133
+ for w, h in self.size_keys:
134
+ mask = (widths == w) & (heights == h)
135
+ self.size_groups[(w, h)] = np.where(mask)[0]
136
+
137
+ # Precompute the number of full batches for each group
138
+ self.group_num_batches = {}
139
+ total_batches = 0
140
+ for size, indices in self.size_groups.items():
141
+ num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
142
+ self.group_num_batches[size] = num_full_batches
143
+ total_batches += num_full_batches
144
+
145
+ # Round down to a number divisible by num_replicas
146
+ self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
147
+
148
+ def __iter__(self):
149
+ # print(f"Rank {self.rank}: Starting iteration")
150
+ # Clear the CUDA cache before building new batches
151
+ if torch.cuda.is_available():
152
+ torch.cuda.empty_cache()
153
+ all_batches = []
154
+ rng = np.random.RandomState(self.epoch)
155
+
156
+ for size, indices in self.size_groups.items():
157
+ # print(f"Rank {self.rank}: Processing size {size}, {len(indices)} samples")
158
+ indices = indices.copy()
159
+ if self.shuffle:
160
+ rng.shuffle(indices)
161
+
162
+ num_full_batches = self.group_num_batches[size]
163
+ if num_full_batches == 0:
164
+ continue
165
+
166
+ # Keep only the indices that form full batches
167
+ valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
168
+
169
+ # Reshape for fast splitting into batches
170
+ batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
171
+
172
+ # Select the slice for the current GPU
173
+ start_idx = self.rank * self.batch_size
174
+ end_idx = start_idx + self.batch_size
175
+ gpu_batches = batches[:, start_idx:end_idx]
176
+
177
+ all_batches.extend(gpu_batches)
178
+
179
+ if self.shuffle:
180
+ rng.shuffle(all_batches)
181
+
182
+ # Synchronize all processes after the batches are built
183
+ accelerator.wait_for_everyone()
184
+ # print(f"Rank {self.rank}: Created {len(all_batches)} batches")
185
+ return iter(all_batches)
186
+
187
+ def __len__(self):
188
+ return self.num_batches
189
+
190
+ def set_epoch(self, epoch):
191
+ self.epoch = epoch
192
+
193
+ # Helper that picks fixed samples for each resolution
194
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
195
+ """Выбирает фиксированные семплы для каждого уникального разрешения"""
196
+ # Group by size
197
+ size_groups = defaultdict(list)
198
+ try:
199
+ widths = dataset["width"]
200
+ heights = dataset["height"]
201
+ except KeyError:
202
+ widths = [0] * len(dataset)
203
+ heights = [0] * len(dataset)
204
+ for i, (w, h) in enumerate(zip(widths, heights)):
205
+ size = (w, h)
206
+ size_groups[size].append(i)
207
+
208
+ # Pick fixed examples from each group
209
+ fixed_samples = {}
210
+ for size, indices in size_groups.items():
211
+ # Decide how many samples to take from this group
212
+ n_samples = min(samples_per_group, len(indices))
213
+ if len(size_groups)==1:
214
+ n_samples = samples_to_generate
215
+ if n_samples == 0:
216
+ continue
217
+
218
+ # Pick random indices
219
+ sample_indices = random.sample(indices, n_samples)
220
+ samples_data = [dataset[idx] for idx in sample_indices]
221
+
222
+ # Gather the data
223
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
224
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
225
+ texts = [item["text"] for item in samples_data]
226
+
227
+ # Store them for this size
228
+ fixed_samples[size] = (latents, embeddings, texts)
229
+
230
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
231
+ return fixed_samples
232
+
233
+ if limit > 0:
234
+ dataset = load_from_disk(ds_path).select(range(limit))
235
+ else:
236
+ dataset = load_from_disk(ds_path)
237
+
238
+ def collate_fn_simple(batch):
239
+ # Convert the list to tensors and move them to the device
240
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
241
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
242
+ return latents, embeddings
243
+
244
+ def collate_fn(batch):
245
+ if not batch:
246
+ return [], []
247
+
248
+ # Take the reference shapes
249
+ ref_vae_shape = np.array(batch[0]["vae"]).shape
250
+ ref_embed_shape = np.array(batch[0]["embeddings"]).shape
251
+
252
+ # Filter out mismatched items
253
+ valid_latents = []
254
+ valid_embeddings = []
255
+ for item in batch:
256
+ if (np.array(item["vae"]).shape == ref_vae_shape and
257
+ np.array(item["embeddings"]).shape == ref_embed_shape):
258
+ valid_latents.append(item["vae"])
259
+ valid_embeddings.append(item["embeddings"])
260
+
261
+ # Build tensors
262
+ latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype)
263
+ embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
264
+
265
+ return latents, embeddings
266
+
267
+ # Use our ResolutionBatchSampler
268
+ #batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
269
+ #dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
270
+
271
+ # Build a ResolutionBatchSampler on top of the indices from DistributedSampler
272
+ batch_sampler = DistributedResolutionBatchSampler(
273
+ dataset=dataset,
274
+ batch_size=batch_size,
275
+ num_replicas=accelerator.num_processes,
276
+ rank=accelerator.process_index,
277
+ shuffle=shuffle
278
+ )
279
+
280
+ # Create the DataLoader
281
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
282
+
283
+ print("Total samples",len(dataloader))
284
+ dataloader = accelerator.prepare(dataloader)
285
+
286
+ # Variables for resuming training
287
+ start_epoch = 0
288
+ global_step = 0
289
+
290
+ # Compute the total number of training steps
291
+ total_training_steps = (len(dataloader) * num_epochs)
292
+ # Get the world size
293
+ world_size = accelerator.state.num_processes
294
+ #print(f"World Size: {world_size}")
295
+
296
+ # Optionally load the model from the latest checkpoint (if it exists)
297
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
298
+ if os.path.isdir(latest_checkpoint):
299
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
300
+ if dtype == torch.float32:
301
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
302
+ else:
303
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
304
+ if unet_gradient:
305
+ unet.enable_gradient_checkpointing()
306
+ unet.set_use_memory_efficient_attention_xformers(False) # disable xformers
307
+ try:
308
+ unet.set_attn_processor(AttnProcessor2_0()) # Use the standard SDPA attention processor
309
+ except Exception as e:
310
+ print(f"Ошибка при включении SDPA: {e}")
311
+ print("Попытка использовать enable_xformers_memory_efficient_attention.")
312
+ unet.set_use_memory_efficient_attention_xformers(True)
313
+
314
+ if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
315
+ print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
316
+ if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
317
+ print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
318
+ if hasattr(torch.nn.functional, "get_flash_attention_available"):
319
+ print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
320
+ if torch_compile:
321
+ print("compiling")
322
+ torch.set_float32_matmul_precision('high')
323
+ unet = torch.compile(unet)#, mode="reduce-overhead", fullgraph=True)
324
+ print("compiling - ok")
325
+
326
+ if lora_name:
327
+ print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
328
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
329
+ from peft.tuners.lora import LoraModel
330
+ import os
331
+ # 1. Freeze all UNet parameters
332
+ unet.requires_grad_(False)
333
+ print("Параметры базового UNet заморожены.")
334
+
335
+ # 2. Create the LoRA configuration
336
+ lora_config = LoraConfig(
337
+ r=lora_rank,
338
+ lora_alpha=lora_alpha,
339
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
340
+ )
341
+ unet.add_adapter(lora_config)
342
+
343
+ # 3. Wrap the UNet in a PEFT model
344
+ from peft import get_peft_model
345
+
346
+ peft_unet = get_peft_model(unet, lora_config)
347
+
348
+ # 4. Collect the parameters to optimize
349
+ params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
350
+
351
+
352
+ # 5. Report parameter counts
353
+ if accelerator.is_main_process:
354
+ lora_params_count = sum(p.numel() for p in params_to_optimize)
355
+ total_params_count = sum(p.numel() for p in unet.parameters())
356
+ print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
357
+ print(f"Общее количество параметров UNet: {total_params_count:,}")
358
+
359
+ # 6. Save path
360
+ lora_save_path = os.path.join("lora", lora_name)
361
+ os.makedirs(lora_save_path, exist_ok=True)
362
+
363
+ # 7. Save function
364
+ def save_lora_checkpoint(model):
365
+ if accelerator.is_main_process:
366
+ print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
367
+ from peft.utils.save_and_load import get_peft_model_state_dict
368
+ # Get the LoRA-only state_dict
369
+ lora_state_dict = get_peft_model_state_dict(model)
370
+
371
+ # Save the weights
372
+ torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
373
+
374
+ # Save the config
375
+ model.peft_config["default"].save_pretrained(lora_save_path)
376
+ # SDXL must be compatible
377
+ from diffusers import StableDiffusionXLPipeline
378
+ StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
379
+
380
+ # --------------------------- Optimizer ---------------------------
381
+ # Define the parameters to optimize
382
+ #unet = torch.compile(unet)
383
+ if lora_name:
384
+ # With LoRA, optimize only the LoRA parameters
385
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
386
+ else:
387
+ # Иначе оптимизируем все параметры
388
+ if fbp:
389
+ trainable_params = list(unet.parameters())
390
+
391
+ if fbp:
392
+ # [1] Build a dict of per-parameter optimizers (fused backward)
393
+ if adam8bit:
394
+ optimizer_dict = {
395
+ p: bnb.optim.AdamW8bit(
396
+ [p], # Each parameter gets its own optimizer
397
+ lr=base_learning_rate,
398
+ eps=1e-8
399
+ ) for p in trainable_params
400
+ }
401
+ else:
402
+ optimizer_dict = {
403
+ p: bnb.optim.Lion8bit(
404
+ [p], # Each parameter gets its own optimizer
405
+ lr=base_learning_rate,
406
+ betas=(0.9, 0.97),
407
+ weight_decay=0.01,
408
+ percentile_clipping=percentile_clipping,
409
+ ) for p in trainable_params
410
+ }
411
+
412
+ # [2] Define a hook that applies the optimizer as soon as a gradient is accumulated
413
+ def optimizer_hook(param):
414
+ optimizer_dict[param].step()
415
+ optimizer_dict[param].zero_grad(set_to_none=True)
416
+
417
+ # [3] Register the hook on the model's trainable parameters
418
+ for param in trainable_params:
419
+ param.register_post_accumulate_grad_hook(optimizer_hook)
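+
+ # With a separate optimizer per parameter stepped inside the post-accumulate-grad hook,
+ # each gradient is applied and zeroed right after backward produces it, so the full set of
+ # gradients never has to be held in memory at once (the "fused backward pass" behind the fbp flag).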
420
+
421
+ # Prepare via Accelerator
422
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
423
+ else:
424
+ if adam8bit:
425
+ optimizer = bnb.optim.AdamW8bit(
426
+ params=unet.parameters(),
427
+ lr=base_learning_rate,
428
+ betas=(0.9, 0.999),
429
+ eps=1e-8,
430
+ weight_decay=0.01
431
+ )
432
+ #from torch.optim import AdamW
433
+ #optimizer = AdamW(
434
+ # params=unet.parameters(),
435
+ # lr=base_learning_rate,
436
+ # betas=(0.9, 0.999),
437
+ # eps=1e-8,
438
+ # weight_decay=0.01
439
+ #)
440
+ else:
441
+ optimizer = bnb.optim.Lion8bit(
442
+ params=unet.parameters(),
443
+ lr=base_learning_rate,
444
+ betas=(0.9, 0.97),
445
+ weight_decay=0.01,
446
+ percentile_clipping=percentile_clipping,
447
+ )
448
+ from transformers import get_constant_schedule_with_warmup
449
+
450
+ # warmup
451
+ num_warmup_steps = num_warmup_steps * world_size
452
+
453
+ #lr_scheduler = get_constant_schedule_with_warmup(
454
+ # optimizer=optimizer,
455
+ # num_warmup_steps=num_warmup_steps
456
+ #)
457
+ from torch.optim.lr_scheduler import LambdaLR
458
+ def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
459
+ # If decay is disabled, return the base LR
460
+ if not use_decay:
461
+ return base_lr
462
+
463
+ # Otherwise use linear warmup followed by cosine decay
464
+ x = step / max_steps
465
+ percent = 0.05
466
+ if x < percent:
467
+ # Линейный прогрев до percent% шагов
468
+ return min_lr + (base_lr - min_lr) * (x / percent)
469
+ else:
470
+ # Cosine decay
471
+ decay_ratio = (x - percent) / (1 - percent)
472
+ return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
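+
+ # Worked example with the settings above (base_lr=3e-5, min_lr=3e-6): the LR starts at 3e-6,
+ # rises linearly to 3e-5 over the first 5% of max_steps, then follows a cosine back down to
+ # 3e-6 at the final step; with use_decay=False the function simply returns base_lr.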
473
+
474
+
475
+ def custom_lr_lambda(step):
476
+ return lr_schedule(step, total_training_steps*world_size,
477
+ base_learning_rate, min_learning_rate,
478
+ (num_warmup_steps>0)) / base_learning_rate
479
+
480
+ lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
481
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
482
+
483
+ # --------------------------- Fixed samples for generation ---------------------------
484
+ # Fixed sample examples per resolution
485
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
486
+
487
+ @torch.compiler.disable()
488
+ @torch.no_grad()
489
+ def generate_and_save_samples(fixed_samples_cpu, step):
490
+ """
491
+ Generates samples for each resolution and saves them.
492
+
493
+ Args:
494
+ fixed_samples_cpu: Dict whose keys are sizes (width, height)
495
+ and whose values are tuples (latents, embeddings, text) on the CPU.
496
+ step: Current training step
497
+ """
498
+ original_model = None # Initialize so the finally block doesn't fail
499
+ try:
500
+
501
+ original_model = accelerator.unwrap_model(unet)
502
+ original_model = original_model.to(dtype = dtype)
503
+ original_model.eval()
504
+
505
+ vae.to(device=device, dtype=dtype)
506
+ vae.eval()
507
+
508
+ scheduler.set_timesteps(n_diffusion_steps)
509
+
510
+ all_generated_images = []
511
+ all_captions = []
512
+
513
+ for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
514
+ width, height = size
515
+
516
+ sample_latents = sample_latents.to(dtype=dtype)
517
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype)
518
+
519
+ # Initialize the latents with random noise
520
+ # sample_latents are already in dtype, so the noise is created in dtype too
521
+ noise = torch.randn(
522
+ sample_latents.shape, # Use the shape of sample_latents, which are now on the GPU in the working dtype
523
+ generator=gen,
524
+ device=device,
525
+ dtype=sample_latents.dtype
526
+ )
527
+ current_latents = noise.clone()
528
+
529
+ # Prepare the text embeddings for guidance
530
+ if guidance_scale > 0:
531
+ # empty_embeddings must have the same dtype and be on the same device
532
+ empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
533
+ text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
534
+ else:
535
+ text_embeddings_batch = sample_text_embeddings
536
+
537
+ for t in scheduler.timesteps:
538
+ t_batch = t.repeat(current_latents.shape[0]).to(device) # Make sure t is on the device
539
+
540
+ if guidance_scale > 0:
541
+ latent_model_input = torch.cat([current_latents] * 2)
542
+ else:
543
+ latent_model_input = current_latents
544
+
545
+ latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
546
+
547
+ # Noise prediction (UNet)
548
+ noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
549
+
550
+ if guidance_scale > 0:
551
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
552
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
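+ # Classifier-free guidance: extrapolate from the unconditional prediction toward the
+ # text-conditioned one; guidance_scale=1 reproduces the conditional prediction, larger
+ # values push the sample harder toward the prompt.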
553
+
554
+ current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
555
+
556
+ #print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
557
+ # Decode through the VAE
558
+ latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
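+ # This inverts the assumed encoding convention latents = (x - shift_factor) * scaling_factor
+ # (a VAE config with both scaling_factor and shift_factor), so vae.decode receives raw latent values.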
559
+ decoded = vae.decode(latent_for_vae).sample
560
+
561
+ # Convert the tensors to PIL images
562
+ # For the image math (normalization) it is safer to switch to fp32
563
+ decoded_fp32 = decoded.to(torch.float32)
564
+ for img_idx, img_tensor in enumerate(decoded_fp32):
565
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
566
+ # If NaNs or infs are present, print them
567
+ if np.isnan(img).any():
568
+ print("NaNs found, saving stoped! Step:", step)
569
+ save_model = False
570
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
571
+
572
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
573
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
574
+ max_w_overall = max(255, max_w_overall)
575
+ max_h_overall = max(255, max_h_overall)
576
+
577
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
578
+ all_generated_images.append(padded_img)
579
+
580
+ caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
581
+ all_captions.append(caption_text)
582
+
583
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
584
+ pil_img.save(sample_path, "JPEG", quality=96)
585
+
586
+ if use_wandb and accelerator.is_main_process:
587
+ wandb_images = [
588
+ wandb.Image(img, caption=f"{all_captions[i]}")
589
+ for i, img in enumerate(all_generated_images)
590
+ ]
591
+ wandb.log({"generated_images": wandb_images, "global_step": step})
592
+
593
+ finally:
594
+ vae.to("cpu") # Перемещаем VAE обратно на CPU
595
+ if original_model is not None:
596
+ original_model = original_model.to(dtype=dtype)
597
+ del original_model
598
+ # Clean up any tensor variables created inside the function
599
+ for var in list(locals().keys()):
600
+ if isinstance(locals()[var], torch.Tensor):
601
+ del locals()[var]
602
+
603
+ torch.cuda.empty_cache()
604
+ gc.collect()
605
+
606
+ # --------------------------- Sample generation before training ---------------------------
607
+ if accelerator.is_main_process:
608
+ if save_model:
609
+ print("Генерация сэмплов до старта обучения...")
610
+ generate_and_save_samples(fixed_samples,0)
611
+
612
+ # Model-saving function extended to support LoRA
613
+ def save_checkpoint(unet,variant=""):
614
+ if accelerator.is_main_process:
615
+ if lora_name:
616
+ # Save only the LoRA adapters
617
+ save_lora_checkpoint(unet)
618
+ else:
619
+ # Save the full model
620
+ if variant!="":
621
+ accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
622
+ else:
623
+ accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
624
+ unet = unet.to(dtype=dtype)
625
+
626
+ # --------------------------- Training loop ---------------------------
627
+ # For logging the average loss several times per epoch
628
+ if accelerator.is_main_process:
629
+ print(f"Total steps per GPU: {total_training_steps}")
630
+
631
+ epoch_loss_points = []
632
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
633
+
634
+ # Interval for sampling and logging within an epoch (every 1/sample_interval_share of the epoch)
635
+ steps_per_epoch = len(dataloader)
636
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
637
+ min_loss = 1.
638
+
639
+ # Start from the given epoch (useful when resuming)
640
+ for epoch in range(start_epoch, start_epoch + num_epochs):
641
+ batch_losses = []
642
+ batch_grads = []
643
+ #unet = unet.to(dtype = dtype)
644
+ batch_sampler.set_epoch(epoch)
645
+ accelerator.wait_for_everyone()
646
+ unet.train()
647
+ print("epoch:",epoch)
648
+ for step, (latents, embeddings) in enumerate(dataloader):
649
+ with accelerator.accumulate(unet):
650
+ if save_model == False and step == 5 :
651
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
652
+ print(f"Шаг {step}: {used_gb:.2f} GB")
653
+
654
+ #latents = latents.to(dtype = dtype)
655
+ #embeddings = embeddings.to(dtype = dtype)
656
+ #print(f"Latents dtype: {latents.dtype}")
657
+ #print(f"Embeddings dtype: {embeddings.dtype}")
658
+ #print(f"Noise dtype: {noise.dtype}")
659
+
660
+ # Forward pass
661
+ noise = torch.randn_like(latents, dtype=latents.dtype)
662
+
663
+ timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
664
+ (latents.shape[0],), device=device).long()
665
+
666
+ # Add noise to the latents
667
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
668
+
669
+ # Use the v-prediction target
670
+ model_pred = unet(noisy_latents, timesteps, embeddings).sample
671
+ target_pred = scheduler.get_velocity(latents, noise, timesteps)
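+ # get_velocity returns the v-prediction target v_t = sqrt(alpha_bar_t) * noise
+ # - sqrt(1 - alpha_bar_t) * x_0 (Salimans & Ho, 2022), matching prediction_type="v_prediction".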
672
+
673
+ # Compute the loss
674
+ # Check model_pred for nan/inf
675
+ #if torch.isnan(model_pred.float()).any() or torch.isinf(model_pred.float()).any():
676
+ # print(f"Rank {accelerator.process_index}: Found nan/inf in model_pred",model_pred.float())
677
+ # # Handle nan/inf values
678
+ # model_pred = torch.nan_to_num(model_pred.float(), nan=0.0, posinf=1.0, neginf=-1.0)
679
+ loss = torch.nn.functional.mse_loss(model_pred, target_pred)
680
+
681
+ # Check for nan/inf before backward
682
+ if torch.isnan(loss) or torch.isinf(loss):
683
+ print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
684
+ loss = torch.zeros_like(loss)
685
+
686
+ # Backward pass via Accelerator
687
+ accelerator.backward(loss)
688
+
689
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
690
+ accelerator.wait_for_everyone()
691
+
692
+ grad = 0.0
693
+ if not fbp:
694
+ if accelerator.sync_gradients:
695
+ grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
696
+ optimizer.step()
697
+ lr_scheduler.step()
698
+ optimizer.zero_grad(set_to_none=True)
699
+
700
+ # Increment the global step counter
701
+ global_step += 1
702
+
703
+ # Update the progress bar
704
+ progress_bar.update(1)
705
+
706
+ # Log metrics
707
+ if accelerator.is_main_process:
708
+ if fbp:
709
+ current_lr = base_learning_rate
710
+ else:
711
+ current_lr = lr_scheduler.get_last_lr()[0]
712
+ batch_losses.append(loss.detach().item())
713
+ batch_grads.append(grad)
714
+
715
+ # Log to Wandb
716
+ if use_wandb:
717
+ wandb.log({
718
+ "loss": loss.detach().item(),
719
+ "learning_rate": current_lr,
720
+ "epoch": epoch,
721
+ "grad": grad,
722
+ "global_step": global_step
723
+ })
724
+
725
+ # Generate samples at the configured interval
726
+ if global_step % sample_interval == 0:
727
+ generate_and_save_samples(fixed_samples,global_step)
728
+
729
+ # Print the current loss
730
+ avg_loss = np.mean(batch_losses[-sample_interval:])
731
+ avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
732
+ print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}")
733
+
734
+ if save_model:
735
+ if avg_loss < min_loss:
736
+ min_loss = avg_loss
737
+ save_checkpoint(unet,"fp16")
738
+ save_checkpoint(unet)
739
+ if use_wandb:
740
+ wandb.log({"intermediate_loss": avg_loss})
741
+ wandb.log({"intermediate_grad": avg_grad})
742
+
743
+
744
+ # At the end of the epoch
745
+ #accelerator.wait_for_everyone()
746
+ if accelerator.is_main_process:
747
+ avg_epoch_loss = np.mean(batch_losses)
748
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
749
+ if use_wandb:
750
+ wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
751
+
752
+ # Training finished - save the final model
753
+ if accelerator.is_main_process:
754
+ print("Обучение завершено! Сохраняем финальную модель...")
755
+ # Сохраняем основную модель
756
+ if save_model:
757
+ save_checkpoint(unet)
758
+ print("Готово!")
759
+
760
+ # randomize ode timesteps
761
+ # input_timestep = torch.round(
762
+ # F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
763
+ # )
764
+
765
+ #def create_distribution(num_points, device=None):
766
+ # # Probability range on the x axis
767
+ # x = torch.linspace(0, 1, num_points, device=device)
768
+
769
+ # Custom probability density function
770
+ # probabilities = -7.7 * ((x - 0.5) ** 2) + 2
771
+
772
+ # Normalize so the values sum to 1
773
+ # probabilities /= probabilities.sum()
774
+
775
+ # return x, probabilities
776
+
777
+ #def sample_from_distribution(x, probabilities, n, device=None):
778
+ # Pick indices according to the probability distribution
779
+ # indices = torch.multinomial(probabilities, n, replacement=True)
780
+ # return x[indices]
781
+
782
+ # Usage example
783
+ #num_points = 1000 # Number of points in the range
784
+ #n = latents.shape[0] # Number of timesteps to sample
785
+ #x, probabilities = create_distribution(num_points, device=latents.device)
786
+ #timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
787
+
788
+ # Convert to the format expected by this code
789
+ #timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()