512res
- dataset_fromfolder.py +2 -2
- samples/unet_192x384_0.jpg +2 -2
- samples/unet_256x384_0.jpg +2 -2
- samples/unet_320x384_0.jpg +2 -2
- samples/unet_320x576_0.jpg +2 -2
- samples/unet_384x192_0.jpg +2 -2
- samples/unet_384x256_0.jpg +2 -2
- samples/unet_384x320_0.jpg +2 -2
- samples/unet_384x384_0.jpg +2 -2
- samples/unet_384x576_0.jpg +2 -2
- samples/unet_448x576_0.jpg +2 -2
- samples/unet_512x576_0.jpg +2 -2
- samples/unet_576x320_0.jpg +2 -2
- samples/unet_576x384_0.jpg +2 -2
- samples/unet_576x448_0.jpg +2 -2
- samples/unet_576x512_0.jpg +2 -2
- samples/unet_576x576_0.jpg +2 -2
- src/dataset_combine.py +1 -1
- train.py +124 -27
- unet/diffusion_pytorch_model.safetensors +1 -1
dataset_fromfolder.py
CHANGED
@@ -27,8 +27,8 @@ empty_share = 0.05
 limit = 0
 textemb_full = False
 # Main processing routine
-folder_path = "/workspace/
-save_path = "/workspace/sdxs/datasets/
+folder_path = "/workspace/eshu"
+save_path = "/workspace/sdxs/datasets/eshu_576"
 os.makedirs(save_path, exist_ok=True)
 
 # Function for clearing CUDA memory
samples/unet_*.jpg (16 files)
CHANGED
Binary Git LFS sample images updated (previews omitted; see the file list above).
src/dataset_combine.py
CHANGED
@@ -65,4 +65,4 @@ def combine_datasets(main_dataset_path, datasets_to_add):
 
     return combined
 
-combine_datasets("/workspace/sdxs/datasets/
+combine_datasets("/workspace/sdxs/datasets/576", ["/workspace/sdxs/datasets/eshu_576"])
train.py
CHANGED
@@ -25,12 +25,12 @@ import bitsandbytes as bnb
 import torch.nn.functional as F
 
 # --------------------------- Parameters ---------------------------
-ds_path = "datasets/
+ds_path = "datasets/576"
 project = "unet"
-batch_size =
-base_learning_rate =
-min_learning_rate =
-num_epochs =
+batch_size = 50
+base_learning_rate = 5e-5
+min_learning_rate = 1e-5
+num_epochs = 20
 # samples/save per epoch
 sample_interval_share = 10
 use_wandb = True
@@ -43,11 +43,14 @@ unet_gradient = True
 clip_sample = False #Scheduler
 fixed_seed = False
 shuffle = True
+dispersive_loss = True
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.enable_mem_efficient_sdp(True)
 dtype = torch.float32
 save_barrier = 1.03
+dispersive_temperature=0.5
+dispersive_weight=0.25
 percentile_clipping = 97 # Lion
 steps_offset = 1 # Scheduler
 limit = 0
@@ -91,6 +94,93 @@ class AccelerateDispersiveLoss:
         self.activations = []
         self.hooks = []
 
+    def register_hooks(self, model, target_layer="down_blocks.0"):
+        unwrapped_model = self.accelerator.unwrap_model(model)
+        print("=== Searching for layers in the unwrapped model ===")
+        for name, module in unwrapped_model.named_modules():
+            if target_layer in name:
+                hook = module.register_forward_hook(self.hook_fn)
+                self.hooks.append(hook)
+                print(f"✅ Hook registered on: {name}")
+                break
+
+    def hook_fn(self, module, input, output):
+
+        if isinstance(output, tuple):
+            activation = output[0]
+        else:
+            activation = output
+
+        if len(activation.shape) > 2:
+            activation = activation.view(activation.shape[0], -1)
+
+        self.activations.append(activation.detach())
+
+    def compute_dispersive_loss(self):
+        if not self.activations:
+            return torch.tensor(0.0, requires_grad=True)
+
+        local_activations = self.activations[-1].float()
+
+        batch_size = local_activations.shape[0]
+        if batch_size < 2:
+            return torch.tensor(0.0, requires_grad=True)
+
+        # Normalize and compute the loss
+        sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
+        distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
+        exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
+        dispersive_loss = torch.log(torch.mean(exp_neg_dist))
+
+        # IMPORTANT: it is negative and should decrease
+        return dispersive_loss
+
+    def compute_dispersive_loss2(self):
+        # If there are no activations, return 0
+        if not self.activations:
+            return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
+
+        # Work only with the local activations of the main process
+        activations = self.activations[-1].float()
+
+        batch_size = activations.shape[0]
+        if batch_size < 2:
+            return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
+
+        # Normalization
+        norm = torch.norm(activations, dim=1, keepdim=True).clamp(min=1e-12)
+        sf = activations / norm
+
+        # Compute pairwise distances
+        distance = torch.nn.functional.pdist(sf, p=2)
+        distance = distance.clamp(min=1e-12)
+        distance_squared = distance ** 2
+
+        # Compute the loss with clamping for stability
+        exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
+        exp_neg_dist = exp_neg_dist + 1e-12
+
+        mean_exp = torch.mean(exp_neg_dist)
+        dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
+
+        return dispersive_loss
+
+    def clear_activations(self):
+        self.activations.clear()
+
+    def remove_hooks(self):
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
+
+class AccelerateDispersiveLoss2:
+    def __init__(self, accelerator, temperature=0.5, weight=0.5):
+        self.accelerator = accelerator
+        self.temperature = temperature
+        self.weight = weight
+        self.activations = []
+        self.hooks = []
+
     def register_hooks(self, model, target_layer="down_blocks.0"):
         # Get the "clean" model without the DDP wrapper
         unwrapped_model = self.accelerator.unwrap_model(model)
@@ -455,12 +545,13 @@ if os.path.isdir(latest_checkpoint):
 if hasattr(torch.nn.functional, "get_flash_attention_available"):
     print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
 
-# Register
-
-
-
-
-
+# Register the hook on the model
+if dispersive_loss:
+    dispersive_hook = AccelerateDispersiveLoss(
+        accelerator=accelerator,
+        temperature=dispersive_temperature,
+        weight=dispersive_weight
+    )
 
 if torch_compile:
     print("compiling")
@@ -590,8 +681,9 @@ else:
 lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
 unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
 
-
-
+# Register the hooks AFTER prepare
+if dispersive_loss:
+    dispersive_hook.register_hooks(unet, "down_blocks.2")
 
 # --------------------------- Fixed samples for generation ---------------------------
 # Examples of fixed samples by size
@@ -611,9 +703,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
     original_model = None  # Initialize so the finally block does not complain
     try:
 
-        original_model = accelerator.unwrap_model(unet)
-        original_model = original_model.to(dtype = dtype)
-        original_model.eval()
+        original_model = accelerator.unwrap_model(unet).eval()
 
         vae.to(device=device, dtype=dtype)
         vae.eval()
@@ -705,9 +795,6 @@ def generate_and_save_samples(fixed_samples_cpu, step):
 
    finally:
        vae.to("cpu")  # Move the VAE back to the CPU
-       original_model = original_model.to(dtype = dtype)
-       if original_model is not None:
-           del original_model
        # Clean up variables that are tensors created inside this function
        for var in list(locals().keys()):
            if isinstance(locals()[var], torch.Tensor):
@@ -721,6 +808,7 @@ if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
        generate_and_save_samples(fixed_samples,0)
+accelerator.wait_for_everyone()
 
 # Modify the model-saving function to support LoRA
 def save_checkpoint(unet,variant=""):
@@ -775,7 +863,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
 
        # Clear activations before the forward pass
-
+       if dispersive_loss:
+           dispersive_hook.clear_activations()
 
        # Use the target value
        model_pred = unet(noisy_latents, timesteps, embeddings).sample
@@ -787,8 +876,12 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
        # Dispersive Loss
        # Identical vectors:  Loss = -0.0000
        # Orthogonal vectors: Loss = -3.9995
-
-
+       if dispersive_loss:
+           with torch.cuda.amp.autocast(enabled=False):
+               dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
+               if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
+                   print(f"Rank {accelerator.process_index}: Found nan/inf in dispersive_loss: {total_loss}")
+                   break
 
        # Final loss
        # dispersive_loss should decrease and the total should decrease - hence the plus
@@ -800,17 +893,20 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
            save_model = False
            break
 
-
+       if (global_step % 100 == 0) or (global_step % sample_interval == 0):
+           accelerator.wait_for_everyone()
+
+       # Do the backward pass through Accelerator
        accelerator.backward(total_loss)
 
        if (global_step % 100 == 0) or (global_step % sample_interval == 0):
            accelerator.wait_for_everyone()
-
+
        grad = 0.0
        if not fbp:
            if accelerator.sync_gradients:
-               with torch.cuda.amp.autocast(enabled=False):
-
+               #with torch.cuda.amp.autocast(enabled=False):
+               grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
@@ -873,7 +969,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
        wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
 
 # End of training - save the final model
-
+if dispersive_loss:
+    dispersive_hook.remove_hooks()
 if accelerator.is_main_process:
     print("Training finished! Saving the final model...")
     # Save the main model
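Note on the Dispersive Loss values quoted in the comments above (identical vectors: Loss = -0.0000, orthogonal vectors: Loss = -3.9995): they can be checked in isolation. The sketch below is not part of the repository; it just re-implements the same formula used by compute_dispersive_loss (L2-normalize the captured activations, take squared pairwise distances, then the log of the mean of exp(-d^2 / tau)) with the committed temperature of 0.5.

import torch

def dispersive_loss(activations, temperature=0.5):
    # Same formula as AccelerateDispersiveLoss.compute_dispersive_loss:
    # L2-normalize rows, squared pairwise distances, log-mean of exp(-d^2 / tau)
    sf = activations / torch.norm(activations, dim=1, keepdim=True)
    distance = torch.nn.functional.pdist(sf, p=2) ** 2
    exp_neg_dist = torch.exp(-distance / temperature) + 1e-5
    return torch.log(torch.mean(exp_neg_dist))

identical = torch.ones(4, 8)    # four identical activation vectors
orthogonal = torch.eye(4, 8)    # four mutually orthogonal unit vectors

print(dispersive_loss(identical))   # ~0.0000: no dispersion, worst case
print(dispersive_loss(orthogonal))  # ~-3.9995: log(exp(-4) + 1e-5), fully dispersed

Because the term is negative and decreases as activations spread apart, adding the already-weighted dispersive_hook.weight * compute_dispersive_loss() to the main loss (the "plus" the comment refers to) rewards dispersion without changing the direction of optimization.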
unet/diffusion_pytorch_model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:37bee8e3947ce359ec56fedc8c30322465ad3a69e62d4a3964b9c433d975c34c
 size 7014306128