recoilme committed
Commit a68d8ab · 1 Parent(s): f067885
dataset_fromfolder.py CHANGED
@@ -27,8 +27,8 @@ empty_share = 0.05
 limit = 0
 textemb_full = False
 # Main processing procedure
-folder_path = "/workspace/d23"
-save_path = "/workspace/sdxs/datasets/ds23_576"
+folder_path = "/workspace/eshu"
+save_path = "/workspace/sdxs/datasets/eshu_576"
 os.makedirs(save_path, exist_ok=True)
 
 # Function for clearing CUDA memory
samples/unet_192x384_0.jpg CHANGED

Git LFS Details

  • SHA256: dfe87c50cb07b34f910992ce566d579ca58c081fabf25364402b92b1c2a4b392
  • Pointer size: 130 Bytes
  • Size of remote file: 48 kB

Git LFS Details

  • SHA256: 7441e68cca6ab353f2333dcb0f54c38e040770a6dc10a62f3a602efdc9bc759c
  • Pointer size: 130 Bytes
  • Size of remote file: 41.9 kB
samples/unet_256x384_0.jpg CHANGED

Git LFS Details

  • SHA256: e50cb93b86fb956c5677d892a146adfdc36600f5941aa1c13543b6dc5ea2b55e
  • Pointer size: 130 Bytes
  • Size of remote file: 47.1 kB

Git LFS Details

  • SHA256: 11b6cd3bdef6f5946593d13d3e611b3b0efc34fb4b2ac54e68cd7157afe810c2
  • Pointer size: 130 Bytes
  • Size of remote file: 51 kB
samples/unet_320x384_0.jpg CHANGED

Git LFS Details

  • SHA256: b607d732417e715403d664b88d99a8ef751147f2cfe717242329ce54a521a2e7
  • Pointer size: 130 Bytes
  • Size of remote file: 42.1 kB

Git LFS Details

  • SHA256: 5a1562843e4b7908b17934c1e524cd0a56a3e182e4d0c348414814505c84af16
  • Pointer size: 130 Bytes
  • Size of remote file: 46.7 kB
samples/unet_320x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 5e0c54615eb21da68780d3f9cd25acf896011b740593fdab126b85c140fef57e
  • Pointer size: 130 Bytes
  • Size of remote file: 96.2 kB

Git LFS Details

  • SHA256: 1cf40f1bab8e5fb72ba60450de7e83b72812a0c750163e43cbd2f1526d6d7cc1
  • Pointer size: 130 Bytes
  • Size of remote file: 38 kB
samples/unet_384x192_0.jpg CHANGED

Git LFS Details

  • SHA256: 0a9631d14f9bc0f9a5b717606ddfa539dc913604fc759e96b73ee61803c2b7ea
  • Pointer size: 130 Bytes
  • Size of remote file: 21.2 kB

Git LFS Details

  • SHA256: 8879fdcba523f7060952be1b81a3ff129111f1ded2563f7b8a6281343311b6e3
  • Pointer size: 130 Bytes
  • Size of remote file: 21.9 kB
samples/unet_384x256_0.jpg CHANGED

Git LFS Details

  • SHA256: efbf07265ebca0391aebafe32c19a17665825df71fa31ab23878b773255eda12
  • Pointer size: 130 Bytes
  • Size of remote file: 46.8 kB

Git LFS Details

  • SHA256: 1567ee2a614927c817201109afaae56f2ec3ad8728f12a8206c7c43d2e182359
  • Pointer size: 130 Bytes
  • Size of remote file: 39.8 kB
samples/unet_384x320_0.jpg CHANGED

Git LFS Details

  • SHA256: 79acc895ddd10c154eb53f8ae3235b4047ed1d1f4340c67ecb066878627bb2fa
  • Pointer size: 130 Bytes
  • Size of remote file: 50.5 kB

Git LFS Details

  • SHA256: 6f6aa97c58bbef96a078e299ca186856aec81465cb3e9064946460ac55c357f3
  • Pointer size: 130 Bytes
  • Size of remote file: 43.3 kB
samples/unet_384x384_0.jpg CHANGED

Git LFS Details

  • SHA256: e4286ad643307345a839df1e55217ebf4bf679b3c37ec806982250df4a3770ed
  • Pointer size: 130 Bytes
  • Size of remote file: 53.9 kB

Git LFS Details

  • SHA256: 16d6058221d63945efb967b4c9b1249b2b67a0e2fc789f94d2c4bfbc0c6c24e5
  • Pointer size: 130 Bytes
  • Size of remote file: 58.8 kB
samples/unet_384x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 8eacfd5ff67ddbdb0c2f5c41747824f728bbd309bc2b4eab1f49e52caaf31e46
  • Pointer size: 130 Bytes
  • Size of remote file: 72.4 kB

Git LFS Details

  • SHA256: f1b4f18bcddf4661fe80bc465e38a79e685bfabf381b0ce058682d566567e454
  • Pointer size: 130 Bytes
  • Size of remote file: 55.4 kB
samples/unet_448x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 1609ca651c2a35a866369b66699816cd39fda2dcb087ff0e9a813f03a4b3733d
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB

Git LFS Details

  • SHA256: aca4f21558d219935c4c199d9de7243325aa88b1292397e7fddfa20ece5cfb21
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB
samples/unet_512x576_0.jpg CHANGED

Git LFS Details

  • SHA256: c32f39a75d3f67361ff4cbbbb63ffe1f27dbc0ff07a9c4b472b45805c06eb8ff
  • Pointer size: 130 Bytes
  • Size of remote file: 84.6 kB

Git LFS Details

  • SHA256: 19b2a1eaf9fe3a88f6f8c946261a7344fb5b5d01a307aaf31f4ff679d6a92599
  • Pointer size: 131 Bytes
  • Size of remote file: 100 kB
samples/unet_576x320_0.jpg CHANGED

Git LFS Details

  • SHA256: 25da5efaaa2b7ea76537ec249ea0d53aa12f93daf5b02ea4a56fa2a650d3dbe7
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB

Git LFS Details

  • SHA256: 7df2e3165e5a5b39a39b438e36c63306fe76c2459d92ffab7d89abd0a9dfe449
  • Pointer size: 130 Bytes
  • Size of remote file: 42.5 kB
samples/unet_576x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 2a8d21984a042592c75bf35fd6baf68adfbbb6ce65664d6a6bc6f537e9ef48c3
  • Pointer size: 130 Bytes
  • Size of remote file: 51.7 kB

Git LFS Details

  • SHA256: 379854c10a80d1804b2984fc21c97dddd7e11ac9bb899724fd660767445456fd
  • Pointer size: 130 Bytes
  • Size of remote file: 66 kB
samples/unet_576x448_0.jpg CHANGED

Git LFS Details

  • SHA256: a1ee124ddc9809df36b2ce8e9079b74c7da391939fc95eac17ae943722a3a452
  • Pointer size: 130 Bytes
  • Size of remote file: 91 kB

Git LFS Details

  • SHA256: fdd4715477657b87ad2169346061a8484f4992e7432cdfe83b506d2019c77cbf
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
samples/unet_576x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 0607e3a57a9f1f3d0273f9fd0cd0bf4b8aa269fe64d22fde3bfa1b5fdb6d5bfc
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB

Git LFS Details

  • SHA256: 1bd576e31d5334429c8e8f9ccf17af5bfb9b8c38056f1798992e0bb12036e588
  • Pointer size: 130 Bytes
  • Size of remote file: 75.4 kB
samples/unet_576x576_0.jpg CHANGED

Git LFS Details

  • SHA256: a7c5453db91bbd9e99a078efb8b9256856427cd56b6b62b0129475a9a92d2c77
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB

Git LFS Details

  • SHA256: e7467ef68b7224f74164b88643db0ef13dc9d01d4698d44036c64cbe813bb51b
  • Pointer size: 130 Bytes
  • Size of remote file: 75 kB
src/dataset_combine.py CHANGED
@@ -65,4 +65,4 @@ def combine_datasets(main_dataset_path, datasets_to_add):
 
     return combined
 
-combine_datasets("/workspace/sdxs/datasets/mjnj_576", ["/workspace/sdxs/datasets/ds23_576"])
+combine_datasets("/workspace/sdxs/datasets/576", ["/workspace/sdxs/datasets/eshu_576"])
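The body of combine_datasets sits outside this hunk, so only the call site changes here: the eshu_576 dataset is now merged into /workspace/sdxs/datasets/576. As a hedged sketch of the kind of merge such a helper typically performs with the Hugging Face datasets library (illustrative only, not the repository's implementation; the "_combined" output path is hypothetical):

from datasets import load_from_disk, concatenate_datasets

def combine_datasets(main_dataset_path, datasets_to_add):
    # Load the main dataset and every extra dataset written with save_to_disk.
    parts = [load_from_disk(main_dataset_path)]
    parts += [load_from_disk(path) for path in datasets_to_add]
    # Concatenate row-wise; the column schemas must match.
    combined = concatenate_datasets(parts)
    combined.save_to_disk(main_dataset_path + "_combined")  # hypothetical output location
    return combined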
train.py CHANGED
@@ -25,12 +25,12 @@ import bitsandbytes as bnb
 import torch.nn.functional as F
 
 # --------------------------- Parameters ---------------------------
-ds_path = "datasets/384"
+ds_path = "datasets/576"
 project = "unet"
-batch_size = 40
-base_learning_rate = 1e-5
-min_learning_rate = 9e-6
-num_epochs = 6
+batch_size = 50
+base_learning_rate = 5e-5
+min_learning_rate = 1e-5
+num_epochs = 20
 # samples/save per epoch
 sample_interval_share = 10
 use_wandb = True
@@ -43,11 +43,14 @@ unet_gradient = True
 clip_sample = False #Scheduler
 fixed_seed = False
 shuffle = True
+dispersive_loss = True
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.enable_mem_efficient_sdp(True)
 dtype = torch.float32
 save_barrier = 1.03
+dispersive_temperature = 0.5
+dispersive_weight = 0.25
 percentile_clipping = 97 # Lion
 steps_offset = 1 # Scheduler
 limit = 0
@@ -91,6 +94,93 @@ class AccelerateDispersiveLoss:
         self.activations = []
         self.hooks = []
 
+    def register_hooks(self, model, target_layer="down_blocks.0"):
+        unwrapped_model = self.accelerator.unwrap_model(model)
+        print("=== Searching for layers in the unwrapped model ===")
+        for name, module in unwrapped_model.named_modules():
+            if target_layer in name:
+                hook = module.register_forward_hook(self.hook_fn)
+                self.hooks.append(hook)
+                print(f"✅ Hook registered on: {name}")
+                break
+
+    def hook_fn(self, module, input, output):
+
+        if isinstance(output, tuple):
+            activation = output[0]
+        else:
+            activation = output
+
+        if len(activation.shape) > 2:
+            activation = activation.view(activation.shape[0], -1)
+
+        self.activations.append(activation.detach())
+
+    def compute_dispersive_loss(self):
+        if not self.activations:
+            return torch.tensor(0.0, requires_grad=True)
+
+        local_activations = self.activations[-1].float()
+
+        batch_size = local_activations.shape[0]
+        if batch_size < 2:
+            return torch.tensor(0.0, requires_grad=True)
+
+        # Normalize and compute the loss
+        sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
+        distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
+        exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
+        dispersive_loss = torch.log(torch.mean(exp_neg_dist))
+
+        # IMPORTANT: it is negative and should decrease
+        return dispersive_loss
+
+    def compute_dispersive_loss2(self):
+        # If there are no activations, return 0
+        if not self.activations:
+            return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
+
+        # Work only with the local activations of the main process
+        activations = self.activations[-1].float()
+
+        batch_size = activations.shape[0]
+        if batch_size < 2:
+            return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
+
+        # Normalization
+        norm = torch.norm(activations, dim=1, keepdim=True).clamp(min=1e-12)
+        sf = activations / norm
+
+        # Compute pairwise distances
+        distance = torch.nn.functional.pdist(sf, p=2)
+        distance = distance.clamp(min=1e-12)
+        distance_squared = distance ** 2
+
+        # Compute the loss with clipping for numerical stability
+        exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
+        exp_neg_dist = exp_neg_dist + 1e-12
+
+        mean_exp = torch.mean(exp_neg_dist)
+        dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
+
+        return dispersive_loss
+
+    def clear_activations(self):
+        self.activations.clear()
+
+    def remove_hooks(self):
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
+
+class AccelerateDispersiveLoss2:
+    def __init__(self, accelerator, temperature=0.5, weight=0.5):
+        self.accelerator = accelerator
+        self.temperature = temperature
+        self.weight = weight
+        self.activations = []
+        self.hooks = []
+
     def register_hooks(self, model, target_layer="down_blocks.0"):
         # Get the "clean" model without the DDP wrapper
         unwrapped_model = self.accelerator.unwrap_model(model)
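For context on the mechanism the new class relies on: register_forward_hook captures a submodule's output during the forward pass without touching the UNet code, which is why the hooks are registered on the unwrapped model and removed at the end of training. A minimal self-contained sketch (the toy nn.Sequential and the layer name "0" are stand-ins, not the actual UNet or "down_blocks.2"):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
captured = []

def hook_fn(module, inputs, output):
    # Flatten to (batch, features) and detach, mirroring hook_fn above.
    captured.append(output.detach().view(output.shape[0], -1))

# Register on a submodule by name, as register_hooks does for "down_blocks.2".
handle = dict(model.named_modules())["0"].register_forward_hook(hook_fn)

_ = model(torch.randn(4, 8))   # the forward pass fills `captured`
print(captured[-1].shape)      # torch.Size([4, 16])

handle.remove()                # analogous to remove_hooks() after training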
 
@@ -455,12 +545,13 @@ if os.path.isdir(latest_checkpoint):
 if hasattr(torch.nn.functional, "get_flash_attention_available"):
     print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
 
-# Register hooks on the model
-dispersive_hook = AccelerateDispersiveLoss(
-    accelerator=accelerator,
-    temperature=2,
-    weight=0.25
-)
+# Register the hook on the model
+if dispersive_loss:
+    dispersive_hook = AccelerateDispersiveLoss(
+        accelerator=accelerator,
+        temperature=dispersive_temperature,
+        weight=dispersive_weight
+    )
 
 if torch_compile:
     print("compiling")
@@ -590,8 +681,9 @@ else:
 lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
 unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
 
-# Register hooks AFTER prepare
-dispersive_hook.register_hooks(unet, "down_blocks.2")
+# Register hooks AFTER prepare
+if dispersive_loss:
+    dispersive_hook.register_hooks(unet, "down_blocks.2")
 
 # --------------------------- Fixed samples for generation ---------------------------
 # Examples of fixed samples by size
@@ -611,9 +703,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
     original_model = None  # Initialize so the finally block does not complain
     try:
 
-        original_model = accelerator.unwrap_model(unet)
-        original_model = original_model.to(dtype = dtype)
-        original_model.eval()
+        original_model = accelerator.unwrap_model(unet).eval()
 
         vae.to(device=device, dtype=dtype)
         vae.eval()
@@ -705,9 +795,6 @@ def generate_and_save_samples(fixed_samples_cpu, step):
 
     finally:
         vae.to("cpu")  # Move the VAE back to the CPU
-        original_model = original_model.to(dtype = dtype)
-        if original_model is not None:
-            del original_model
         # Clean up variables that are tensors created inside this function
         for var in list(locals().keys()):
             if isinstance(locals()[var], torch.Tensor):
@@ -721,6 +808,7 @@ if accelerator.is_main_process:
     if save_model:
         print("Generating samples before training starts...")
         generate_and_save_samples(fixed_samples,0)
+accelerator.wait_for_everyone()
 
 # Modify the model-saving function to support LoRA
 def save_checkpoint(unet,variant=""):
@@ -775,7 +863,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
         noisy_latents = scheduler.add_noise(latents, noise, timesteps)
 
         # Clear activations before the forward pass
-        dispersive_hook.clear_activations()
+        if dispersive_loss:
+            dispersive_hook.clear_activations()
 
         # Use the target value
         model_pred = unet(noisy_latents, timesteps, embeddings).sample
@@ -787,8 +876,12 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
         # Dispersive Loss
         # Identical vectors: Loss = -0.0000
         # Orthogonal vectors: Loss = -3.9995
-        with torch.cuda.amp.autocast(enabled=False):
-            dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
+        if dispersive_loss:
+            with torch.cuda.amp.autocast(enabled=False):
+                dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
+            if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
+                print(f"Rank {accelerator.process_index}: Found nan/inf in dispersive_loss: {dispersive_loss}")
+                break
 
         # Final loss
        # dispersive_loss should go down and the total should go down - hence the plus
@@ -800,17 +893,20 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
             save_model = False
             break
 
-        # Do the backward pass through Accelerator
+        if (global_step % 100 == 0) or (global_step % sample_interval == 0):
+            accelerator.wait_for_everyone()
+
+        # Do the backward pass through Accelerator
         accelerator.backward(total_loss)
 
         if (global_step % 100 == 0) or (global_step % sample_interval == 0):
             accelerator.wait_for_everyone()
 
         grad = 0.0
         if not fbp:
             if accelerator.sync_gradients:
-                with torch.cuda.amp.autocast(enabled=False):
-                    grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
+                #with torch.cuda.amp.autocast(enabled=False):
+                grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
             optimizer.step()
             lr_scheduler.step()
             optimizer.zero_grad(set_to_none=True)
@@ -873,7 +969,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
     wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
 
 # Training finished - save the final model
-dispersive_hook.remove_hooks()
+if dispersive_loss:
+    dispersive_hook.remove_hooks()
 if accelerator.is_main_process:
     print("Training finished! Saving the final model...")
     # Save the main model
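The two reference values in the training-loop comment above (Loss ≈ -0.0000 for identical activations, ≈ -3.9995 for mutually orthogonal ones) follow directly from log(mean(exp(-‖h_i − h_j‖² / τ) + 1e-5)) at the new dispersive_temperature of 0.5. A standalone re-statement of that formula reproduces them; the dispersive_loss helper below is for illustration only, not code from train.py:

import torch
import torch.nn.functional as F

def dispersive_loss(h, temperature=0.5):
    # L2-normalize each sample, then penalize pairs that sit close together.
    h = h.view(h.shape[0], -1)
    h = h / torch.norm(h, dim=1, keepdim=True)
    dist_sq = F.pdist(h, p=2) ** 2
    return torch.log(torch.mean(torch.exp(-dist_sq / temperature) + 1e-5))

identical = torch.ones(4, 8)        # every sample collapses onto one point
orthogonal = torch.eye(4, 8)        # mutually orthogonal samples
print(dispersive_loss(identical))   # ≈  0.0000 (no dispersion, maximal penalty)
print(dispersive_loss(orthogonal))  # ≈ -3.9995 (well dispersed, low penalty)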
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7905a5f9fc1ead936613fd66ef66675ab7cde73128c49208ee1c02668dcb8527
+oid sha256:37bee8e3947ce359ec56fedc8c30322465ad3a69e62d4a3964b9c433d975c34c
 size 7014306128