recoilme commited on
Commit
5b9c339
·
1 Parent(s): 4aa2866
samples/unet_320x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 2cd448dd5de462d52f70980f87bdea2083c250621894ffc868e4735b20adc694
  • Pointer size: 130 Bytes
  • Size of remote file: 59.2 kB

Git LFS Details

  • SHA256: 13dc91c061b6b36b5fa419c372bd33bd4e02bf05f1f3a1092ac4d207d97889c0
  • Pointer size: 130 Bytes
  • Size of remote file: 55.7 kB
samples/unet_384x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 4ea2aa8fb41888b9d47019f72d6e4c3df27b1e24ac5835093a12649d2edc4064
  • Pointer size: 130 Bytes
  • Size of remote file: 52.2 kB

Git LFS Details

  • SHA256: 5b81ef0ed8250322f7df1d9e9290af27f4a127f80e4ca267b5b0d6011bac6df8
  • Pointer size: 130 Bytes
  • Size of remote file: 79.9 kB
samples/unet_448x576_0.jpg CHANGED

Git LFS Details

  • SHA256: cf7beba7d5450c4640f2848f4455055615a183207a77df6b78cf6922af78de37
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB

Git LFS Details

  • SHA256: a6822518d2b7c6bd3423ef05e57b197ab5fb41a311ebad69380bf86bea0834fe
  • Pointer size: 131 Bytes
  • Size of remote file: 154 kB
samples/unet_512x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 1ebe6d6dc58bb7cde2e1826dd346c367d7a0d8d46e0ebb999209df2831259879
  • Pointer size: 130 Bytes
  • Size of remote file: 82.5 kB

Git LFS Details

  • SHA256: 8d7e797b2113fdbe0108e6c77f17ad0045249738098b823e2bd96963ff58f7bd
  • Pointer size: 130 Bytes
  • Size of remote file: 72.5 kB
samples/unet_576x320_0.jpg CHANGED

Git LFS Details

  • SHA256: 244f1830c98c5ed0d373c20a6f925219468ac90111b037c7a7292095c70b2278
  • Pointer size: 130 Bytes
  • Size of remote file: 80.2 kB

Git LFS Details

  • SHA256: f1c4a5c7dbd07f2d9fa8c7fa992d369b9fd465202a5809c9a5f6ba2cbc8ca583
  • Pointer size: 130 Bytes
  • Size of remote file: 90.2 kB
samples/unet_576x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 7919d5f28322e3de17f26dc1f809596540cf523470f525d4ba9afd94b39effba
  • Pointer size: 130 Bytes
  • Size of remote file: 66.7 kB

Git LFS Details

  • SHA256: 2be73a6efba7447bc6ff203dcd63f23461f473d9ff6f19a79b2fa7dae2e6f1d6
  • Pointer size: 130 Bytes
  • Size of remote file: 72.1 kB
samples/unet_576x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 64fe9c9a9956fa78a477983587d742e832eea0a24a881113ff6d2ef11ba42565
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB

Git LFS Details

  • SHA256: 12f796c06170182ff2770a1feae4400f5a01f73725047e78ae7d809581a99e42
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
samples/unet_576x512_0.jpg CHANGED

Git LFS Details

  • SHA256: aa1d853283088b5df8b392bbba27a940cf13ae4df5c1c5d6ef27453b2fc0d9ea
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB

Git LFS Details

  • SHA256: f54fd3334405b6189f6578263902e3d3cb5b58b688e775f9dffebfbd528b2ab0
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
samples/unet_576x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 21d34e928dfa69c4972f82ccde7381636625b64ff3f25a5e6a84d6ff7f2ae372
  • Pointer size: 131 Bytes
  • Size of remote file: 185 kB

Git LFS Details

  • SHA256: fb81309b1c64c3d4b25828193d673298aa72d90192bf0d1d442c7d569dc276b8
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
train-Copy1.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  import matplotlib.pyplot as plt
6
  from torch.utils.data import DataLoader, Sampler
7
  from torch.utils.data.distributed import DistributedSampler
 
8
  from collections import defaultdict
9
  from torch.optim.lr_scheduler import LambdaLR
10
  from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
@@ -21,31 +22,40 @@ from torch.utils.checkpoint import checkpoint
21
  from diffusers.models.attention_processor import AttnProcessor2_0
22
  from datetime import datetime
23
  import bitsandbytes as bnb
 
24
 
25
  # --------------------------- Параметры ---------------------------
26
- ds_path = "datasets/384"
27
- batch_size = 50
28
- base_learning_rate = 3e-5
29
- min_learning_rate = 3e-6
30
- num_epochs = 10
31
- num_warmup_steps = 1000
32
  project = "unet"
 
 
 
 
 
 
33
  use_wandb = True
34
  save_model = True
35
- sample_interval_share = 5 # samples/save per epoch
36
  fbp = False # fused backward pass
37
- adam8bit = True
38
- percentile_clipping = 97 # Lion
39
  torch_compile = False
40
  unet_gradient = True
41
  clip_sample = False #Scheduler
42
  fixed_seed = False
43
  shuffle = True
 
 
 
 
44
  dtype = torch.float32
 
 
 
 
45
  steps_offset = 1 # Scheduler
46
  limit = 0
47
  checkpoints_folder = ""
48
- mixed_precision = "no"
49
  accelerator = Accelerator(mixed_precision=mixed_precision)
50
  device = accelerator.device
51
 
@@ -68,8 +78,6 @@ if fixed_seed:
68
  if torch.cuda.is_available():
69
  torch.cuda.manual_seed_all(seed)
70
 
71
- #torch.backends.cuda.matmul.allow_tf32 = True
72
- #torch.backends.cudnn.allow_tf32 = True
73
  # --------------------------- Параметры LoRA ---------------------------
74
  # pip install peft
75
  lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
@@ -78,6 +86,228 @@ lora_alpha = 64 # Альфа параметр LoRA, определяющий м
78
 
79
  print("init")
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # --------------------------- Инициализация WandB ---------------------------
82
  if use_wandb and accelerator.is_main_process:
83
  wandb.init(project=project+lora_name, config={
@@ -85,7 +315,7 @@ if use_wandb and accelerator.is_main_process:
85
  "base_learning_rate": base_learning_rate,
86
  "num_epochs": num_epochs,
87
  "fbp": fbp,
88
- "adam8bit": adam8bit,
89
  })
90
 
91
  # Включение Flash Attention 2/SDPA
@@ -107,6 +337,7 @@ scheduler = DDPMScheduler(
107
  steps_offset = steps_offset
108
  )
109
 
 
110
  class DistributedResolutionBatchSampler(Sampler):
111
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
112
  self.dataset = dataset
@@ -263,10 +494,6 @@ def collate_fn(batch):
263
  embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
264
 
265
  return latents, embeddings
266
-
267
- # Используем наш ResolutionBatchSampler
268
- #batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
269
- #dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
270
 
271
  # Создаем ResolutionBatchSampler на основе индексов от DistributedSampler
272
  batch_sampler = DistributedResolutionBatchSampler(
@@ -297,10 +524,10 @@ world_size = accelerator.state.num_processes
297
  latest_checkpoint = os.path.join(checkpoints_folder, project)
298
  if os.path.isdir(latest_checkpoint):
299
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
300
- if dtype == torch.float32:
301
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
302
- else:
303
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
304
  if unet_gradient:
305
  unet.enable_gradient_checkpointing()
306
  unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
@@ -317,6 +544,15 @@ if os.path.isdir(latest_checkpoint):
317
  print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
318
  if hasattr(torch.nn.functional, "get_flash_attention_available"):
319
  print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
 
 
 
 
 
 
 
 
 
320
  if torch_compile:
321
  print("compiling")
322
  torch.set_float32_matmul_precision('high')
@@ -388,98 +624,66 @@ else:
388
  if fbp:
389
  trainable_params = list(unet.parameters())
390
 
391
- if fbp:
392
- # [1] Создаем словарь оптимизаторов (fused backward)
393
- if adam8bit:
394
- optimizer_dict = {
395
- p: bnb.optim.AdamW8bit(
396
- [p], # Каждый параметр получает свой оптимизатор
397
- lr=base_learning_rate,
398
- eps=1e-8
399
- ) for p in trainable_params
400
- }
 
 
 
 
 
 
 
 
 
 
 
 
401
  else:
402
- optimizer_dict = {
403
- p: bnb.optim.Lion8bit(
404
- [p], # Каждый параметр получает свой оптимизатор
405
- lr=base_learning_rate,
406
- betas=(0.9, 0.97),
407
- weight_decay=0.01,
408
- percentile_clipping=percentile_clipping,
409
- ) for p in trainable_params
410
- }
411
 
412
- # [2] Определяем hook для применения оптимизатора сразу после накопления градиента
413
  def optimizer_hook(param):
414
  optimizer_dict[param].step()
415
  optimizer_dict[param].zero_grad(set_to_none=True)
416
 
417
- # [3] Регистрируем hook для trainable параметров модели
418
  for param in trainable_params:
419
  param.register_post_accumulate_grad_hook(optimizer_hook)
420
 
421
- # Подготовка через Accelerator
422
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
423
  else:
424
- if adam8bit:
425
- optimizer = bnb.optim.AdamW8bit(
426
- params=unet.parameters(),
427
- lr=base_learning_rate,
428
- betas=(0.9, 0.999),
429
- eps=1e-8,
430
- weight_decay=0.01
431
- )
432
- #from torch.optim import AdamW
433
- #optimizer = AdamW(
434
- # params=unet.parameters(),
435
- # lr=base_learning_rate,
436
- # betas=(0.9, 0.999),
437
- # eps=1e-8,
438
- # weight_decay=0.01
439
- #)
440
- else:
441
- optimizer = bnb.optim.Lion8bit(
442
- params=unet.parameters(),
443
- lr=base_learning_rate,
444
- betas=(0.9, 0.97),
445
- weight_decay=0.01,
446
- percentile_clipping=percentile_clipping,
447
- )
448
- from transformers import get_constant_schedule_with_warmup
449
-
450
- # warmup
451
- num_warmup_steps = num_warmup_steps * world_size
452
-
453
- #lr_scheduler = get_constant_schedule_with_warmup(
454
- # optimizer=optimizer,
455
- # num_warmup_steps=num_warmup_steps
456
- #)
457
- from torch.optim.lr_scheduler import LambdaLR
458
- def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
459
- # Если не используем затухание, возвращаем базовый LR
460
- if not use_decay:
461
- return base_lr
462
-
463
- # Иначе используем линейный прогрев и косинусное затухание
464
- x = step / max_steps
465
- percent = 0.05
466
- if x < percent:
467
- # Линейный прогрев до percent% шагов
468
- return min_lr + (base_lr - min_lr) * (x / percent)
469
- else:
470
- # Косинусное затухание
471
- decay_ratio = (x - percent) / (1 - percent)
472
- return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
473
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
- def custom_lr_lambda(step):
476
- return lr_schedule(step, total_training_steps*world_size,
477
- base_learning_rate, min_learning_rate,
478
- (num_warmup_steps>0)) / base_learning_rate
479
-
480
- lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
481
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
482
 
 
 
 
 
483
  # --------------------------- Фиксированные семплы для генерации ---------------------------
484
  # Примеры фиксированных семплов по размерам
485
  fixed_samples = get_fixed_samples_by_resolution(dataset)
@@ -498,9 +702,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
498
  original_model = None # Инициализируем, чтобы finally не ругался
499
  try:
500
 
501
- original_model = accelerator.unwrap_model(unet)
502
- original_model = original_model.to(dtype = dtype)
503
- original_model.eval()
504
 
505
  vae.to(device=device, dtype=dtype)
506
  vae.eval()
@@ -592,9 +794,6 @@ def generate_and_save_samples(fixed_samples_cpu, step):
592
 
593
  finally:
594
  vae.to("cpu") # Перемещаем VAE обратно на CPU
595
- original_model = original_model.to(dtype = dtype)
596
- if original_model is not None:
597
- del original_model
598
  # Очистка переменных, которые являются тензорами и были созданы в функции
599
  for var in list(locals().keys()):
600
  if isinstance(locals()[var], torch.Tensor):
@@ -608,6 +807,7 @@ if accelerator.is_main_process:
608
  if save_model:
609
  print("Генерация сэмплов до старта обучения...")
610
  generate_and_save_samples(fixed_samples,0)
 
611
 
612
  # Модифицируем функцию сохранения модели для поддержки LoRA
613
  def save_checkpoint(unet,variant=""):
@@ -639,6 +839,7 @@ min_loss = 1.
639
  # Начинаем с указанной эпохи (полезно при возобновлении)
640
  for epoch in range(start_epoch, start_epoch + num_epochs):
641
  batch_losses = []
 
642
  batch_grads = []
643
  #unet = unet.to(dtype = dtype)
644
  batch_sampler.set_epoch(epoch)
@@ -650,12 +851,6 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
650
  if save_model == False and step == 5 :
651
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
652
  print(f"Шаг {step}: {used_gb:.2f} GB")
653
-
654
- #latents = latents.to(dtype = dtype)
655
- #embeddings = embeddings.to(dtype = dtype)
656
- #print(f"Latents dtype: {latents.dtype}")
657
- #print(f"Embeddings dtype: {embeddings.dtype}")
658
- #print(f"Noise dtype: {noise.dtype}")
659
 
660
  # Forward pass
661
  noise = torch.randn_like(latents, dtype=latents.dtype)
@@ -665,34 +860,51 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
665
 
666
  # Добавляем шум к латентам
667
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
 
 
 
 
668
 
669
  # Используем целевое значение
670
  model_pred = unet(noisy_latents, timesteps, embeddings).sample
671
  target_pred = scheduler.get_velocity(latents, noise, timesteps)
672
 
673
  # Считаем лосс
674
- # Проверяем model_pred на nan/inf
675
- #if torch.isnan(model_pred.float()).any() or torch.isinf(model_pred.float()).any():
676
- # print(f"Rank {accelerator.process_index}: Found nan/inf in model_pred",model_pred.float())
677
- # # Обработка nan/inf значений
678
- # model_pred = torch.nan_to_num(model_pred.float(), nan=0.0, posinf=1.0, neginf=-1.0)
679
- loss = torch.nn.functional.mse_loss(model_pred, target_pred)
 
 
 
 
 
 
 
 
680
 
681
  # Проверяем на nan/inf перед backward
682
  if torch.isnan(loss) or torch.isinf(loss):
683
  print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
684
- loss = torch.zeros_like(loss)
 
685
 
686
- # Делаем backward через Accelerator
687
- accelerator.backward(loss)
 
 
 
688
 
689
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
690
  accelerator.wait_for_everyone()
691
-
692
  grad = 0.0
693
  if not fbp:
694
- if accelerator.sync_gradients:
695
- grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
 
696
  optimizer.step()
697
  lr_scheduler.step()
698
  optimizer.zero_grad(set_to_none=True)
@@ -710,16 +922,19 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
710
  else:
711
  current_lr = lr_scheduler.get_last_lr()[0]
712
  batch_losses.append(loss.detach().item())
 
713
  batch_grads.append(grad)
714
 
715
  # Логируем в Wandb
716
  if use_wandb:
717
  wandb.log({
718
- "loss": loss.detach().item(),
719
  "learning_rate": current_lr,
720
  "epoch": epoch,
721
  "grad": grad,
722
- "global_step": global_step
 
 
723
  })
724
 
725
  # Генерируем сэмплы с заданным интервалом
@@ -728,17 +943,19 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
728
 
729
  # Выводим текущий лосс
730
  avg_loss = np.mean(batch_losses[-sample_interval:])
 
731
  avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
732
  print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}")
733
 
734
  if save_model:
735
- if avg_loss < min_loss:
 
736
  min_loss = avg_loss
737
- save_checkpoint(unet,"fp16")
738
- save_checkpoint(unet)
739
  if use_wandb:
740
- wandb.log({"intermediate_loss": avg_loss})
741
- wandb.log({"intermediate_grad": avg_grad})
 
742
 
743
 
744
  # По окончании эпохи
@@ -750,11 +967,13 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
750
  wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
751
 
752
  # Завершение обучения - сохраняем финальную модель
 
 
753
  if accelerator.is_main_process:
754
  print("Обучение завершено! Сохраняем финальную модель...")
755
  # Сохраняем основную модель
756
  if save_model:
757
- save_checkpoint(unet)
758
  print("Готово!")
759
 
760
  # randomize ode timesteps
 
5
  import matplotlib.pyplot as plt
6
  from torch.utils.data import DataLoader, Sampler
7
  from torch.utils.data.distributed import DistributedSampler
8
+ from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
  from torch.optim.lr_scheduler import LambdaLR
11
  from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
 
22
  from diffusers.models.attention_processor import AttnProcessor2_0
23
  from datetime import datetime
24
  import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
 
27
  # --------------------------- Параметры ---------------------------
28
+ ds_path = "datasets/576"
 
 
 
 
 
29
  project = "unet"
30
+ batch_size = 50
31
+ base_learning_rate = 9e-6
32
+ min_learning_rate = 8e-6
33
+ num_epochs = 5
34
+ # samples/save per epoch
35
+ sample_interval_share = 5
36
  use_wandb = True
37
  save_model = True
38
+ use_decay = True
39
  fbp = False # fused backward pass
40
+ optimizer_type = "adam8bit"
 
41
  torch_compile = False
42
  unet_gradient = True
43
  clip_sample = False #Scheduler
44
  fixed_seed = False
45
  shuffle = True
46
+ dispersive_loss = True
47
+ torch.backends.cuda.matmul.allow_tf32 = True
48
+ torch.backends.cudnn.allow_tf32 = True
49
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
50
  dtype = torch.float32
51
+ save_barrier = 1.03
52
+ dispersive_temperature=0.5
53
+ dispersive_weight=0.05
54
+ percentile_clipping = 90 # 8bit optim
55
  steps_offset = 1 # Scheduler
56
  limit = 0
57
  checkpoints_folder = ""
58
+ mixed_precision = "fp16"
59
  accelerator = Accelerator(mixed_precision=mixed_precision)
60
  device = accelerator.device
61
 
 
78
  if torch.cuda.is_available():
79
  torch.cuda.manual_seed_all(seed)
80
 
 
 
81
  # --------------------------- Параметры LoRA ---------------------------
82
  # pip install peft
83
  lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
 
86
 
87
  print("init")
88
 
89
+ class AccelerateDispersiveLoss:
90
+ def __init__(self, accelerator, temperature=0.5, weight=0.5):
91
+ self.accelerator = accelerator
92
+ self.temperature = temperature
93
+ self.weight = weight
94
+ self.activations = []
95
+ self.hooks = []
96
+
97
+ def register_hooks(self, model, target_layer="down_blocks.0"):
98
+ unwrapped_model = self.accelerator.unwrap_model(model)
99
+ print("=== Поиск слоев в unwrapped модели ===")
100
+ for name, module in unwrapped_model.named_modules():
101
+ if target_layer in name:
102
+ hook = module.register_forward_hook(self.hook_fn)
103
+ self.hooks.append(hook)
104
+ print(f"✅ Хук зарегистрирован на: {name}")
105
+ break
106
+
107
+ def hook_fn(self, module, input, output):
108
+
109
+ if isinstance(output, tuple):
110
+ activation = output[0]
111
+ else:
112
+ activation = output
113
+
114
+ if len(activation.shape) > 2:
115
+ activation = activation.view(activation.shape[0], -1)
116
+
117
+ self.activations.append(activation.detach())
118
+
119
+ def compute_dispersive_loss(self):
120
+ if not self.activations:
121
+ return torch.tensor(0.0, requires_grad=True)
122
+
123
+ local_activations = self.activations[-1].float()
124
+
125
+ batch_size = local_activations.shape[0]
126
+ if batch_size < 2:
127
+ return torch.tensor(0.0, requires_grad=True)
128
+
129
+ # Нормализация и вычисление loss
130
+ sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
131
+ distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
132
+ exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
133
+ dispersive_loss = torch.log(torch.mean(exp_neg_dist))
134
+
135
+ # ВАЖНО: он отриц и должен падать
136
+ return dispersive_loss
137
+
138
+ def compute_dispersive_loss2(self):
139
+ # Если нет активаций, возвращаем 0
140
+ if not self.activations:
141
+ return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
142
+
143
+ # Работаем только с локальными активациями главного процесса
144
+ activations = self.activations[-1].float()
145
+
146
+ batch_size = activations.shape[0]
147
+ if batch_size < 2:
148
+ return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
149
+
150
+ # Нормализация
151
+ norm = torch.norm(activations, dim=1, keepdim=True).clamp(min=1e-12)
152
+ sf = activations / norm
153
+
154
+ # Вычисляем расстояния
155
+ distance = torch.nn.functional.pdist(sf, p=2)
156
+ distance = distance.clamp(min=1e-12)
157
+ distance_squared = distance ** 2
158
+
159
+ # Вычисляем loss с клиппингом для стабильности
160
+ exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
161
+ exp_neg_dist = exp_neg_dist + 1e-12
162
+
163
+ mean_exp = torch.mean(exp_neg_dist)
164
+ dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
165
+
166
+ return dispersive_loss
167
+
168
+ def clear_activations(self):
169
+ self.activations.clear()
170
+
171
+ def remove_hooks(self):
172
+ for hook in self.hooks:
173
+ hook.remove()
174
+ self.hooks.clear()
175
+
176
+ class AccelerateDispersiveLoss2:
177
+ def __init__(self, accelerator, temperature=0.5, weight=0.5):
178
+ self.accelerator = accelerator
179
+ self.temperature = temperature
180
+ self.weight = weight
181
+ self.activations = []
182
+ self.hooks = []
183
+
184
+ def register_hooks(self, model, target_layer="down_blocks.0"):
185
+ # Получаем "чистую" модель без DDP wrapper'а
186
+ unwrapped_model = self.accelerator.unwrap_model(model)
187
+
188
+ print("=== Поиск слоев в unwrapped модели ===")
189
+ for name, module in unwrapped_model.named_modules():
190
+ if target_layer in name:
191
+ hook = module.register_forward_hook(self.hook_fn)
192
+ self.hooks.append(hook)
193
+ print(f"✅ Хук зарегистрирован на: {name}")
194
+ break
195
+
196
+ def hook_fn(self, module, input, output):
197
+ if isinstance(output, tuple):
198
+ activation = output[0]
199
+ else:
200
+ activation = output
201
+
202
+ if len(activation.shape) > 2:
203
+ activation = activation.view(activation.shape[0], -1)
204
+
205
+ self.activations.append(activation.detach())
206
+
207
+ def compute_dispersive_loss_fix(self):
208
+ if not self.activations:
209
+ return torch.tensor(0.0, requires_grad=True)
210
+
211
+ local_activations = self.activations[-1]
212
+
213
+ # Собираем активации со всех GPU
214
+ if self.accelerator.num_processes > 1:
215
+ gathered_activations = self.accelerator.gather(local_activations)
216
+ else:
217
+ gathered_activations = local_activations
218
+
219
+ batch_size = gathered_activations.shape[0]
220
+ if batch_size < 2:
221
+ return torch.tensor(0.0, requires_grad=True)
222
+
223
+ # Переводим в float32 для стабильности
224
+ gathered_activations = gathered_activations.float()
225
+
226
+ # Нормализация с eps для стабильности
227
+ norm = torch.norm(gathered_activations, dim=1, keepdim=True).clamp(min=1e-12)
228
+ sf = gathered_activations / norm
229
+
230
+ # Вычисляем расстояния
231
+ distance = torch.nn.functional.pdist(sf, p=2)
232
+ distance = distance.clamp(min=1e-12) # избегаем слишком маленьких значений
233
+ distance_squared = distance ** 2
234
+
235
+ # Экспонента с клиппингом
236
+ exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
237
+ exp_neg_dist = exp_neg_dist + 1e-12 # избегаем нулей
238
+
239
+ # Среднее и лог
240
+ mean_exp = torch.mean(exp_neg_dist)
241
+ dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
242
+
243
+ return dispersive_loss
244
+
245
+ def compute_dispersive_loss(self):
246
+ if not self.activations:
247
+ return torch.tensor(0.0, requires_grad=True)
248
+
249
+ local_activations = self.activations[-1].float()
250
+
251
+ # Собираем активации со всех GPU
252
+ if self.accelerator.num_processes > 1:
253
+ gathered_activations = self.accelerator.gather(local_activations)
254
+ else:
255
+ gathered_activations = local_activations
256
+
257
+ batch_size = gathered_activations.shape[0]
258
+ if batch_size < 2:
259
+ return torch.tensor(0.0, requires_grad=True)
260
+
261
+ # Нормализация и вычисление loss
262
+ sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
263
+ sf = sf.float()
264
+ distance = torch.nn.functional.pdist(sf, p=2) ** 2
265
+ exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
266
+ dispersive_loss = torch.log(torch.mean(exp_neg_dist))
267
+
268
+ # ВАЖНО: он отриц и должен падать
269
+ return dispersive_loss
270
+
271
+
272
+ def compute_dispersive_loss_single(self):
273
+ if not self.activations:
274
+ return torch.tensor(0.0, requires_grad=True)
275
+
276
+ local_activations = self.activations[-1] # Активации с текущего GPU
277
+
278
+ # Собираем активации со всех GPU
279
+ if self.accelerator.num_processes > 1:
280
+ # Используем accelerate для сбора
281
+ gathered_activations = self.accelerator.gather(local_activations)
282
+ else:
283
+ gathered_activations = local_activations
284
+
285
+ # На главном процессе вычисляем loss
286
+ if self.accelerator.is_main_process:
287
+ batch_size = gathered_activations.shape[0]
288
+ if batch_size < 2:
289
+ return torch.tensor(0.0, requires_grad=True)
290
+
291
+ # Нормализация и вычисление loss
292
+ sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
293
+ distance = torch.nn.functional.pdist(sf, p=2) ** 2
294
+ exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
295
+ dispersive_loss = torch.log(torch.mean(exp_neg_dist))
296
+
297
+ return dispersive_loss
298
+ else:
299
+ # На не-главных процессах возвращаем 0
300
+ return torch.tensor(0.0, requires_grad=True)
301
+
302
+ def clear_activations(self):
303
+ self.activations.clear()
304
+
305
+ def remove_hooks(self):
306
+ for hook in self.hooks:
307
+ hook.remove()
308
+ self.hooks.clear()
309
+
310
+
311
  # --------------------------- Инициализация WandB ---------------------------
312
  if use_wandb and accelerator.is_main_process:
313
  wandb.init(project=project+lora_name, config={
 
315
  "base_learning_rate": base_learning_rate,
316
  "num_epochs": num_epochs,
317
  "fbp": fbp,
318
+ "optimizer_type": optimizer_type,
319
  })
320
 
321
  # Включение Flash Attention 2/SDPA
 
337
  steps_offset = steps_offset
338
  )
339
 
340
+
341
  class DistributedResolutionBatchSampler(Sampler):
342
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
343
  self.dataset = dataset
 
494
  embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
495
 
496
  return latents, embeddings
 
 
 
 
497
 
498
  # Создаем ResolutionBatchSampler на основе индексов от DistributedSampler
499
  batch_sampler = DistributedResolutionBatchSampler(
 
524
  latest_checkpoint = os.path.join(checkpoints_folder, project)
525
  if os.path.isdir(latest_checkpoint):
526
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
527
+ #if dtype == torch.float32:
528
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
529
+ #else:
530
+ #unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
531
  if unet_gradient:
532
  unet.enable_gradient_checkpointing()
533
  unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
 
544
  print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
545
  if hasattr(torch.nn.functional, "get_flash_attention_available"):
546
  print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
547
+
548
+ # Регистрируем хук на модел
549
+ if dispersive_loss:
550
+ dispersive_hook = AccelerateDispersiveLoss(
551
+ accelerator=accelerator,
552
+ temperature=dispersive_temperature,
553
+ weight=dispersive_weight
554
+ )
555
+
556
  if torch_compile:
557
  print("compiling")
558
  torch.set_float32_matmul_precision('high')
 
624
  if fbp:
625
  trainable_params = list(unet.parameters())
626
 
627
+ def create_optimizer(name, params):
628
+ if name == "adam8bit":
629
+ return bnb.optim.AdamW8bit(
630
+ params, lr=base_learning_rate, betas=(0.9, 0.97), eps=1e-5, weight_decay=0.001,
631
+ percentile_clipping=percentile_clipping
632
+ )
633
+ elif name == "adam":
634
+ return torch.optim.AdamW(
635
+ params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
636
+ )
637
+ elif name == "lion8bit":
638
+ return bnb.optim.Lion8bit(
639
+ params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
640
+ percentile_clipping=percentile_clipping
641
+ )
642
+ elif name == "adafactor":
643
+ from transformers import Adafactor
644
+ return Adafactor(
645
+ params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
646
+ warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
647
+ beta1=0.9, weight_decay=0.01
648
+ )
649
  else:
650
+ raise ValueError(f"Unknown optimizer: {name}")
651
+
652
+ if fbp:
653
+ # Создаем отдельный оптимизатор для каждого параметра
654
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
 
 
 
 
655
 
 
656
  def optimizer_hook(param):
657
  optimizer_dict[param].step()
658
  optimizer_dict[param].zero_grad(set_to_none=True)
659
 
 
660
  for param in trainable_params:
661
  param.register_post_accumulate_grad_hook(optimizer_hook)
662
 
 
663
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
664
  else:
665
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
 
667
+ def lr_schedule(step):
668
+ x = step / (total_training_steps * world_size)
669
+ warmup = 0.05
670
+
671
+ if not use_decay:
672
+ return base_learning_rate
673
+ if x < warmup:
674
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
675
+
676
+ decay_ratio = (x - warmup) / (1 - warmup)
677
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
678
+ (1 + math.cos(math.pi * decay_ratio))
679
 
680
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
 
 
 
 
 
681
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
682
 
683
+ # Регистрация хуков ПОСЛЕ prepare
684
+ if dispersive_loss:
685
+ dispersive_hook.register_hooks(unet, "down_blocks.2")
686
+
687
  # --------------------------- Фиксированные семплы для генерации ---------------------------
688
  # Примеры фиксированных семплов по размерам
689
  fixed_samples = get_fixed_samples_by_resolution(dataset)
 
702
  original_model = None # Инициализируем, чтобы finally не ругался
703
  try:
704
 
705
+ original_model = accelerator.unwrap_model(unet).eval()
 
 
706
 
707
  vae.to(device=device, dtype=dtype)
708
  vae.eval()
 
794
 
795
  finally:
796
  vae.to("cpu") # Перемещаем VAE обратно на CPU
 
 
 
797
  # Очистка переменных, которые являются тензорами и были созданы в функции
798
  for var in list(locals().keys()):
799
  if isinstance(locals()[var], torch.Tensor):
 
807
  if save_model:
808
  print("Генерация сэмплов до старта обучения...")
809
  generate_and_save_samples(fixed_samples,0)
810
+ accelerator.wait_for_everyone()
811
 
812
  # Модифицируем функцию сохранения модели для поддержки LoRA
813
  def save_checkpoint(unet,variant=""):
 
839
  # Начинаем с указанной эпохи (полезно при возобновлении)
840
  for epoch in range(start_epoch, start_epoch + num_epochs):
841
  batch_losses = []
842
+ batch_tlosses = []
843
  batch_grads = []
844
  #unet = unet.to(dtype = dtype)
845
  batch_sampler.set_epoch(epoch)
 
851
  if save_model == False and step == 5 :
852
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
853
  print(f"Шаг {step}: {used_gb:.2f} GB")
 
 
 
 
 
 
854
 
855
  # Forward pass
856
  noise = torch.randn_like(latents, dtype=latents.dtype)
 
860
 
861
  # Добавляем шум к латентам
862
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
863
+
864
+ # Очищаем активации перед forward pass
865
+ if dispersive_loss:
866
+ dispersive_hook.clear_activations()
867
 
868
  # Используем целевое значение
869
  model_pred = unet(noisy_latents, timesteps, embeddings).sample
870
  target_pred = scheduler.get_velocity(latents, noise, timesteps)
871
 
872
  # Считаем лосс
873
+ loss = torch.nn.functional.mse_loss(model_pred.float(), target_pred.float())
874
+
875
+ # Dispersive Loss
876
+ #Идентичные векторы: Loss = -0.0000
877
+ #Ортогональные векторы: Loss = -3.9995
878
+ if dispersive_loss:
879
+ with torch.amp.autocast('cuda', enabled=False):
880
+ dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
881
+ if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
882
+ print(f"Rank {accelerator.process_index}: Found nan/inf in dispersive_loss: {total_loss}")
883
+
884
+ # Итоговый loss
885
+ # dispersive_loss должен падать и тотал падать - поэтому плюс
886
+ total_loss = loss + dispersive_loss
887
 
888
  # Проверяем на nan/inf перед backward
889
  if torch.isnan(loss) or torch.isinf(loss):
890
  print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
891
+ save_model = False
892
+ break
893
 
894
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
895
+ accelerator.wait_for_everyone()
896
+
897
+ # Делаем backward через Accelerator
898
+ accelerator.backward(total_loss)
899
 
900
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
901
  accelerator.wait_for_everyone()
902
+
903
  grad = 0.0
904
  if not fbp:
905
+ if accelerator.sync_gradients:
906
+ with torch.amp.autocast('cuda', enabled=False):
907
+ grad = accelerator.clip_grad_norm_(unet.parameters(), 0.25)
908
  optimizer.step()
909
  lr_scheduler.step()
910
  optimizer.zero_grad(set_to_none=True)
 
922
  else:
923
  current_lr = lr_scheduler.get_last_lr()[0]
924
  batch_losses.append(loss.detach().item())
925
+ batch_tlosses.append(total_loss.detach().item())
926
  batch_grads.append(grad)
927
 
928
  # Логируем в Wandb
929
  if use_wandb:
930
  wandb.log({
931
+ "mse_loss": loss.detach().item(),
932
  "learning_rate": current_lr,
933
  "epoch": epoch,
934
  "grad": grad,
935
+ "global_step": global_step,
936
+ "dispersive_loss": dispersive_loss,
937
+ "total_loss": total_loss
938
  })
939
 
940
  # Генерируем сэмплы с заданным интервалом
 
943
 
944
  # Выводим текущий лосс
945
  avg_loss = np.mean(batch_losses[-sample_interval:])
946
+ avg_tloss = np.mean(batch_tlosses[-sample_interval:])
947
  avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
948
  print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}")
949
 
950
  if save_model:
951
+ print("saving:",avg_loss < min_loss*save_barrier)
952
+ if avg_loss < min_loss*save_barrier:
953
  min_loss = avg_loss
954
+ save_checkpoint(unet)
 
955
  if use_wandb:
956
+ wandb.log({"interm_loss": avg_loss})
957
+ wandb.log({"interm_totalloss": avg_tloss})
958
+ wandb.log({"interm_grad": avg_grad})
959
 
960
 
961
  # По окончании эпохи
 
967
  wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
968
 
969
  # Завершение обучения - сохраняем финальную модель
970
+ if dispersive_loss:
971
+ dispersive_hook.remove_hooks()
972
  if accelerator.is_main_process:
973
  print("Обучение завершено! Сохраняем финальную модель...")
974
  # Сохраняем основную модель
975
  if save_model:
976
+ save_checkpoint(unet,"fp16")
977
  print("Готово!")
978
 
979
  # randomize ode timesteps
train.py CHANGED
@@ -27,12 +27,12 @@ import torch.nn.functional as F
27
  # --------------------------- Параметры ---------------------------
28
  ds_path = "datasets/576"
29
  project = "unet"
30
- batch_size = 50
31
- base_learning_rate = 9e-6
32
- min_learning_rate = 8e-6
33
- num_epochs = 5
34
  # samples/save per epoch
35
- sample_interval_share = 5
36
  use_wandb = True
37
  save_model = True
38
  use_decay = True
@@ -56,7 +56,11 @@ steps_offset = 1 # Scheduler
56
  limit = 0
57
  checkpoints_folder = ""
58
  mixed_precision = "fp16"
59
- accelerator = Accelerator(mixed_precision=mixed_precision)
 
 
 
 
60
  device = accelerator.device
61
 
62
  # Параметры для диффузии
@@ -905,15 +909,16 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
905
  if accelerator.sync_gradients:
906
  with torch.amp.autocast('cuda', enabled=False):
907
  grad = accelerator.clip_grad_norm_(unet.parameters(), 0.25)
908
- optimizer.step()
909
- lr_scheduler.step()
910
- optimizer.zero_grad(set_to_none=True)
911
 
912
  # Увеличиваем счетчик глобальных шагов
913
- global_step += 1
 
914
 
915
- # Обновляем прогресс-бар
916
- progress_bar.update(1)
917
 
918
  # Логируем метрики
919
  if accelerator.is_main_process:
 
27
  # --------------------------- Параметры ---------------------------
28
  ds_path = "datasets/576"
29
  project = "unet"
30
+ batch_size = 40
31
+ base_learning_rate = 9.5e-6
32
+ min_learning_rate = 9e-6
33
+ num_epochs = 2
34
  # samples/save per epoch
35
+ sample_interval_share = 10
36
  use_wandb = True
37
  save_model = True
38
  use_decay = True
 
56
  limit = 0
57
  checkpoints_folder = ""
58
  mixed_precision = "fp16"
59
+ gradient_accumulation_steps = 2
60
+ accelerator = Accelerator(
61
+ mixed_precision=mixed_precision,
62
+ gradient_accumulation_steps=gradient_accumulation_steps
63
+ )
64
  device = accelerator.device
65
 
66
  # Параметры для диффузии
 
909
  if accelerator.sync_gradients:
910
  with torch.amp.autocast('cuda', enabled=False):
911
  grad = accelerator.clip_grad_norm_(unet.parameters(), 0.25)
912
+ optimizer.step()
913
+ lr_scheduler.step()
914
+ optimizer.zero_grad(set_to_none=True)
915
 
916
  # Увеличиваем счетчик глобальных шагов
917
+ if accelerator.sync_gradients:
918
+ global_step += 1
919
 
920
+ # Обновляем прогресс-бар
921
+ progress_bar.update(1)
922
 
923
  # Логируем метрики
924
  if accelerator.is_main_process:
train_dispersive.py DELETED
@@ -1,898 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
12
- from accelerate import Accelerator
13
- from datasets import load_from_disk
14
- from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
- import wandb
17
- import random
18
- import gc
19
- from accelerate.state import DistributedType
20
- from torch.distributed import broadcast_object_list
21
- from torch.utils.checkpoint import checkpoint
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from datetime import datetime
24
- import bitsandbytes as bnb
25
- import torch.nn.functional as F
26
-
27
- # --------------------------- Параметры ---------------------------
28
- ds_path = "datasets/384"
29
- project = "unet"
30
- batch_size = 30
31
- base_learning_rate = 3e-5
32
- min_learning_rate = 1e-6
33
- num_epochs = 15
34
- # samples/save per epoch
35
- sample_interval_share = 10
36
- use_wandb = True
37
- save_model = True
38
- use_decay = True
39
- fbp = False # fused backward pass
40
- adam8bit = True
41
- torch_compile = False
42
- unet_gradient = True
43
- clip_sample = False #Scheduler
44
- fixed_seed = False
45
- shuffle = True
46
- torch.backends.cuda.matmul.allow_tf32 = True
47
- torch.backends.cudnn.allow_tf32 = True
48
- torch.backends.cuda.enable_mem_efficient_sdp(False)
49
- dtype = torch.float32
50
- save_barrier = 1.03
51
- percentile_clipping = 97 # Lion
52
- steps_offset = 1 # Scheduler
53
- limit = 0
54
- checkpoints_folder = ""
55
- mixed_precision = "no"
56
- accelerator = Accelerator(mixed_precision=mixed_precision)
57
- device = accelerator.device
58
-
59
- # Параметры для диффузии
60
- n_diffusion_steps = 50
61
- samples_to_generate = 12
62
- guidance_scale = 5
63
-
64
- # Папки для сохранения результатов
65
- generated_folder = "samples"
66
- os.makedirs(generated_folder, exist_ok=True)
67
-
68
- # Настройка seed для воспроизводимости
69
- current_date = datetime.now()
70
- seed = int(current_date.strftime("%Y%m%d"))
71
- if fixed_seed:
72
- torch.manual_seed(seed)
73
- np.random.seed(seed)
74
- random.seed(seed)
75
- if torch.cuda.is_available():
76
- torch.cuda.manual_seed_all(seed)
77
-
78
- # --------------------------- Параметры LoRA ---------------------------
79
- # pip install peft
80
- lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
81
- lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
82
- lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
83
-
84
- print("init")
85
-
86
- class AccelerateDispersiveLoss:
87
- def __init__(self, accelerator, temperature=0.5, weight=0.5):
88
- self.accelerator = accelerator
89
- self.temperature = temperature
90
- self.weight = weight
91
- self.activations = []
92
- self.hooks = []
93
-
94
- def register_hooks(self, model, target_layer="down_blocks.0"):
95
- # Получаем "чистую" модель без DDP wrapper'а
96
- unwrapped_model = self.accelerator.unwrap_model(model)
97
-
98
- print("=== Поиск слоев в unwrapped модели ===")
99
- for name, module in unwrapped_model.named_modules():
100
- if target_layer in name:
101
- hook = module.register_forward_hook(self.hook_fn)
102
- self.hooks.append(hook)
103
- print(f"✅ Хук зарегистрирован на: {name}")
104
- break
105
-
106
- def hook_fn(self, module, input, output):
107
- if isinstance(output, tuple):
108
- activation = output[0]
109
- else:
110
- activation = output
111
-
112
- if len(activation.shape) > 2:
113
- activation = activation.view(activation.shape[0], -1)
114
-
115
- self.activations.append(activation.detach())
116
-
117
- def compute_dispersive_loss(self):
118
- if not self.activations:
119
- return torch.tensor(0.0, requires_grad=True)
120
-
121
- local_activations = self.activations[-1]
122
-
123
- # Собираем активации со всех GPU
124
- if self.accelerator.num_processes > 1:
125
- gathered_activations = self.accelerator.gather(local_activations)
126
- else:
127
- gathered_activations = local_activations
128
-
129
- batch_size = gathered_activations.shape[0]
130
- if batch_size < 2:
131
- return torch.tensor(0.0, requires_grad=True)
132
-
133
- # Нормализация и вычисление loss
134
- sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
135
- distance = torch.nn.functional.pdist(sf, p=2) ** 2
136
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
137
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
138
-
139
- # ВАЖНО: он отриц и должен падать
140
- return dispersive_loss
141
-
142
-
143
- def compute_dispersive_loss_single(self):
144
- if not self.activations:
145
- return torch.tensor(0.0, requires_grad=True)
146
-
147
- local_activations = self.activations[-1] # Активации с текущего GPU
148
-
149
- # Собираем активации со всех GPU
150
- if self.accelerator.num_processes > 1:
151
- # Используем accelerate для сбора
152
- gathered_activations = self.accelerator.gather(local_activations)
153
- else:
154
- gathered_activations = local_activations
155
-
156
- # На главном процессе вычисляем loss
157
- if self.accelerator.is_main_process:
158
- batch_size = gathered_activations.shape[0]
159
- if batch_size < 2:
160
- return torch.tensor(0.0, requires_grad=True)
161
-
162
- # Нормализация и вычисление loss
163
- sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
164
- distance = torch.nn.functional.pdist(sf, p=2) ** 2
165
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
166
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
167
-
168
- return dispersive_loss
169
- else:
170
- # На не-главных процессах возвращаем 0
171
- return torch.tensor(0.0, requires_grad=True)
172
-
173
- def clear_activations(self):
174
- self.activations.clear()
175
-
176
- def remove_hooks(self):
177
- for hook in self.hooks:
178
- hook.remove()
179
- self.hooks.clear()
180
-
181
-
182
- # --------------------------- Инициализация WandB ---------------------------
183
- if use_wandb and accelerator.is_main_process:
184
- wandb.init(project=project+lora_name, config={
185
- "batch_size": batch_size,
186
- "base_learning_rate": base_learning_rate,
187
- "num_epochs": num_epochs,
188
- "fbp": fbp,
189
- "adam8bit": adam8bit,
190
- })
191
-
192
- # Включение Flash Attention 2/SDPA
193
- torch.backends.cuda.enable_flash_sdp(True)
194
- # --------------------------- Инициализация Accelerator --------------------
195
- gen = torch.Generator(device=device)
196
- gen.manual_seed(seed)
197
-
198
- # --------------------------- Загрузка моделей ---------------------------
199
- # VAE загружается на CPU для экономии GPU-памяти
200
- vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
201
-
202
- # DDPMScheduler с V_Prediction и Zero-SNR
203
- scheduler = DDPMScheduler(
204
- num_train_timesteps=1000, # Полный график шагов для обучения
205
- prediction_type="v_prediction", # V-Prediction
206
- rescale_betas_zero_snr=True, # Включение Zero-SNR
207
- clip_sample = clip_sample,
208
- steps_offset = steps_offset
209
- )
210
-
211
-
212
- class DistributedResolutionBatchSampler(Sampler):
213
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
214
- self.dataset = dataset
215
- self.batch_size = max(1, batch_size // num_replicas)
216
- self.num_replicas = num_replicas
217
- self.rank = rank
218
- self.shuffle = shuffle
219
- self.drop_last = drop_last
220
- self.epoch = 0
221
-
222
- # Используем numpy для ускорения
223
- try:
224
- widths = np.array(dataset["width"])
225
- heights = np.array(dataset["height"])
226
- except KeyError:
227
- widths = np.zeros(len(dataset))
228
- heights = np.zeros(len(dataset))
229
-
230
- # Создаем уникальные ключи для размеров
231
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
232
-
233
- # Группируем индексы по размерам используя numpy
234
- self.size_groups = {}
235
- for w, h in self.size_keys:
236
- mask = (widths == w) & (heights == h)
237
- self.size_groups[(w, h)] = np.where(mask)[0]
238
-
239
- # Предварительно вычисляем количество полных батчей для каждой группы
240
- self.group_num_batches = {}
241
- total_batches = 0
242
- for size, indices in self.size_groups.items():
243
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
244
- self.group_num_batches[size] = num_full_batches
245
- total_batches += num_full_batches
246
-
247
- # Округляем до числа, делящегося на num_replicas
248
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
249
-
250
- def __iter__(self):
251
- # print(f"Rank {self.rank}: Starting iteration")
252
- # Очищаем CUDA кэш перед формированием новых батчей
253
- if torch.cuda.is_available():
254
- torch.cuda.empty_cache()
255
- all_batches = []
256
- rng = np.random.RandomState(self.epoch)
257
-
258
- for size, indices in self.size_groups.items():
259
- # print(f"Rank {self.rank}: Processing size {size}, {len(indices)} samples")
260
- indices = indices.copy()
261
- if self.shuffle:
262
- rng.shuffle(indices)
263
-
264
- num_full_batches = self.group_num_batches[size]
265
- if num_full_batches == 0:
266
- continue
267
-
268
- # Берем только индексы для полных батчей
269
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
270
-
271
- # Reshape для быстрого разделения на батчи
272
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
273
-
274
- # Выбираем часть для текущего GPU
275
- start_idx = self.rank * self.batch_size
276
- end_idx = start_idx + self.batch_size
277
- gpu_batches = batches[:, start_idx:end_idx]
278
-
279
- all_batches.extend(gpu_batches)
280
-
281
- if self.shuffle:
282
- rng.shuffle(all_batches)
283
-
284
- # Синхронизируем все процессы после формирования батчей
285
- accelerator.wait_for_everyone()
286
- # print(f"Rank {self.rank}: Created {len(all_batches)} batches")
287
- return iter(all_batches)
288
-
289
- def __len__(self):
290
- return self.num_batches
291
-
292
- def set_epoch(self, epoch):
293
- self.epoch = epoch
294
-
295
- # Функция для выборки фиксированных семплов по размерам
296
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
297
- """Выбирает фиксированные семплы для каждого уникального разрешения"""
298
- # Группируем по размерам
299
- size_groups = defaultdict(list)
300
- try:
301
- widths = dataset["width"]
302
- heights = dataset["height"]
303
- except KeyError:
304
- widths = [0] * len(dataset)
305
- heights = [0] * len(dataset)
306
- for i, (w, h) in enumerate(zip(widths, heights)):
307
- size = (w, h)
308
- size_groups[size].append(i)
309
-
310
- # Выбираем фиксированные примеры из каждой группы
311
- fixed_samples = {}
312
- for size, indices in size_groups.items():
313
- # Определяем сколько семплов брать из этой группы
314
- n_samples = min(samples_per_group, len(indices))
315
- if len(size_groups)==1:
316
- n_samples = samples_to_generate
317
- if n_samples == 0:
318
- continue
319
-
320
- # Выбираем случайные индексы
321
- sample_indices = random.sample(indices, n_samples)
322
- samples_data = [dataset[idx] for idx in sample_indices]
323
-
324
- # Собираем данные
325
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
326
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
327
- texts = [item["text"] for item in samples_data]
328
-
329
- # Сохраняем для этого размера
330
- fixed_samples[size] = (latents, embeddings, texts)
331
-
332
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
333
- return fixed_samples
334
-
335
- if limit > 0:
336
- dataset = load_from_disk(ds_path).select(range(limit))
337
- else:
338
- dataset = load_from_disk(ds_path)
339
-
340
- def collate_fn_simple(batch):
341
- # Преобразуем список в тензоры и перемещаем на девайс
342
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
343
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
344
- return latents, embeddings
345
-
346
- def collate_fn(batch):
347
- if not batch:
348
- return [], []
349
-
350
- # Берем эталонную форму
351
- ref_vae_shape = np.array(batch[0]["vae"]).shape
352
- ref_embed_shape = np.array(batch[0]["embeddings"]).shape
353
-
354
- # Фильтруем
355
- valid_latents = []
356
- valid_embeddings = []
357
- for item in batch:
358
- if (np.array(item["vae"]).shape == ref_vae_shape and
359
- np.array(item["embeddings"]).shape == ref_embed_shape):
360
- valid_latents.append(item["vae"])
361
- valid_embeddings.append(item["embeddings"])
362
-
363
- # Создаем тензоры
364
- latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype)
365
- embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
366
-
367
- return latents, embeddings
368
-
369
- # Создаем ResolutionBatchSampler на основе индексов от DistributedSampler
370
- batch_sampler = DistributedResolutionBatchSampler(
371
- dataset=dataset,
372
- batch_size=batch_size,
373
- num_replicas=accelerator.num_processes,
374
- rank=accelerator.process_index,
375
- shuffle=shuffle
376
- )
377
-
378
- # Создаем DataLoader
379
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
380
-
381
- print("Total samples",len(dataloader))
382
- dataloader = accelerator.prepare(dataloader)
383
-
384
- # Инициализация переменных для возобновления обучения
385
- start_epoch = 0
386
- global_step = 0
387
-
388
- # Расчёт общего количества шагов
389
- total_training_steps = (len(dataloader) * num_epochs)
390
- # Get the world size
391
- world_size = accelerator.state.num_processes
392
- #print(f"World Size: {world_size}")
393
-
394
- # Опция загрузки модели из последнего чекпоинта (если существует)
395
- latest_checkpoint = os.path.join(checkpoints_folder, project)
396
- if os.path.isdir(latest_checkpoint):
397
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
398
- #if dtype == torch.float32:
399
- # unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
400
- #else:
401
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
402
- if unet_gradient:
403
- unet.enable_gradient_checkpointing()
404
- unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
405
- try:
406
- unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
407
- except Exception as e:
408
- print(f"Ошибка при включении SDPA: {e}")
409
- print("Попытка использовать enable_xformers_memory_efficient_attention.")
410
- unet.set_use_memory_efficient_attention_xformers(True)
411
-
412
- if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
413
- print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
414
- if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
415
- print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
416
- if hasattr(torch.nn.functional, "get_flash_attention_available"):
417
- print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
418
-
419
- # Регистрируем хуки на модел
420
- dispersive_hook = AccelerateDispersiveLoss(
421
- accelerator=accelerator,
422
- temperature=2,
423
- weight=0.25
424
- )
425
-
426
- if torch_compile:
427
- print("compiling")
428
- torch.set_float32_matmul_precision('high')
429
- unet = torch.compile(unet)#, mode="reduce-overhead", fullgraph=True)
430
- print("compiling - ok")
431
-
432
- if lora_name:
433
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
434
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
435
- from peft.tuners.lora import LoraModel
436
- import os
437
- # 1. Замораживаем все параметры UNet
438
- unet.requires_grad_(False)
439
- print("Параметры базового UNet заморожены.")
440
-
441
- # 2. Создаем конфигурацию LoRA
442
- lora_config = LoraConfig(
443
- r=lora_rank,
444
- lora_alpha=lora_alpha,
445
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
446
- )
447
- unet.add_adapter(lora_config)
448
-
449
- # 3. Оборачиваем UNet в PEFT-модель
450
- from peft import get_peft_model
451
-
452
- peft_unet = get_peft_model(unet, lora_config)
453
-
454
- # 4. Получаем параметры для оптимизации
455
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
456
-
457
-
458
- # 5. Выводим информацию о количестве параметров
459
- if accelerator.is_main_process:
460
- lora_params_count = sum(p.numel() for p in params_to_optimize)
461
- total_params_count = sum(p.numel() for p in unet.parameters())
462
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
463
- print(f"Общее количество параметров UNet: {total_params_count:,}")
464
-
465
- # 6. Путь для сохранения
466
- lora_save_path = os.path.join("lora", lora_name)
467
- os.makedirs(lora_save_path, exist_ok=True)
468
-
469
- # 7. Функция для сохранения
470
- def save_lora_checkpoint(model):
471
- if accelerator.is_main_process:
472
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
473
- from peft.utils.save_and_load import get_peft_model_state_dict
474
- # Получаем state_dict только LoRA
475
- lora_state_dict = get_peft_model_state_dict(model)
476
-
477
- # Сохраняем веса
478
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
479
-
480
- # Сохраняем конфиг
481
- model.peft_config["default"].save_pretrained(lora_save_path)
482
- # SDXL must be compatible
483
- from diffusers import StableDiffusionXLPipeline
484
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
485
-
486
- # --------------------------- Оптимизатор ---------------------------
487
- # Определяем параметры для оптимизации
488
- #unet = torch.compile(unet)
489
- if lora_name:
490
- # Если используется LoRA, оптимизируем только параметры LoRA
491
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
492
- else:
493
- # Иначе оптимизируем все параметры
494
- if fbp:
495
- trainable_params = list(unet.parameters())
496
-
497
- if fbp:
498
- # [1] Создаем словарь оптимизаторов (fused backward)
499
- if adam8bit:
500
- optimizer_dict = {
501
- p: bnb.optim.AdamW8bit(
502
- [p], # Каждый параметр получает свой оптимизатор
503
- lr=base_learning_rate,
504
- eps=1e-8
505
- ) for p in trainable_params
506
- }
507
- else:
508
- optimizer_dict = {
509
- p: bnb.optim.Lion8bit(
510
- [p], # Каждый параметр получает свой оптимизатор
511
- lr=base_learning_rate,
512
- betas=(0.9, 0.97),
513
- weight_decay=0.01,
514
- percentile_clipping=percentile_clipping,
515
- ) for p in trainable_params
516
- }
517
-
518
- # [2] Определяем hook для применения оптимизатора сразу после накопления градиента
519
- def optimizer_hook(param):
520
- optimizer_dict[param].step()
521
- optimizer_dict[param].zero_grad(set_to_none=True)
522
-
523
- # [3] Регистрируем hook для trainable параметров модели
524
- for param in trainable_params:
525
- param.register_post_accumulate_grad_hook(optimizer_hook)
526
-
527
- # Подготовка через Accelerator
528
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
529
- else:
530
- if adam8bit:
531
- optimizer = bnb.optim.AdamW8bit(
532
- params=unet.parameters(),
533
- lr=base_learning_rate,
534
- betas=(0.9, 0.999),
535
- eps=1e-8,
536
- weight_decay=0.01
537
- )
538
- #from torch.optim import AdamW
539
- #optimizer = AdamW(
540
- # params=unet.parameters(),
541
- # lr=base_learning_rate,
542
- # betas=(0.9, 0.999),
543
- # eps=1e-8,
544
- # weight_decay=0.01
545
- #)
546
- else:
547
- optimizer = bnb.optim.Lion8bit(
548
- params=unet.parameters(),
549
- lr=base_learning_rate,
550
- betas=(0.9, 0.97),
551
- weight_decay=0.01,
552
- percentile_clipping=percentile_clipping,
553
- )
554
-
555
- def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
556
- # Если не используем затухание, возвращаем базовый LR
557
- if not use_decay:
558
- return base_lr
559
-
560
- # Иначе используем линейный прогрев и косинусное затухание
561
- x = step / max_steps
562
- percent = 0.05
563
- if x < percent:
564
- # Линейный прогрев до percent% шагов
565
- return min_lr + (base_lr - min_lr) * (x / percent)
566
- else:
567
- # Косинусное затухание
568
- decay_ratio = (x - percent) / (1 - percent)
569
- return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
570
-
571
-
572
- def custom_lr_lambda(step):
573
- return lr_schedule(step, total_training_steps*world_size,
574
- base_learning_rate, min_learning_rate,
575
- use_decay) / base_learning_rate
576
-
577
- lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
578
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
579
-
580
- # Регистрация хуков ПОСЛЕ prepare
581
- dispersive_hook.register_hooks(unet, "down_blocks.2")
582
-
583
- # --------------------------- Фиксированные семплы для генерации ---------------------------
584
- # Примеры фиксированных семплов по размерам
585
- fixed_samples = get_fixed_samples_by_resolution(dataset)
586
-
587
- @torch.compiler.disable()
588
- @torch.no_grad()
589
- def generate_and_save_samples(fixed_samples_cpu, step):
590
- """
591
- Генерирует семплы для каждого из разрешений и сохраняет их.
592
-
593
- Args:
594
- fixed_samples_cpu: Словарь, где ключи - размеры (width, height),
595
- а значения - кортежи (latents, embeddings, text) на CPU.
596
- step: Текущий шаг обучения
597
- """
598
- original_model = None # Инициализируем, чтобы finally не ругался
599
- try:
600
-
601
- original_model = accelerator.unwrap_model(unet)
602
- original_model = original_model.to(dtype = dtype)
603
- original_model.eval()
604
-
605
- vae.to(device=device, dtype=dtype)
606
- vae.eval()
607
-
608
- scheduler.set_timesteps(n_diffusion_steps)
609
-
610
- all_generated_images = []
611
- all_captions = []
612
-
613
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
614
- width, height = size
615
-
616
- sample_latents = sample_latents.to(dtype=dtype)
617
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype)
618
-
619
- # Инициализируем латенты случайным шумом
620
- # sample_latents уже в dtype, так что noise будет создан в dtype
621
- noise = torch.randn(
622
- sample_latents.shape, # Используем форму от sample_latents, которые теперь на GPU и fp16
623
- generator=gen,
624
- device=device,
625
- dtype=sample_latents.dtype
626
- )
627
- current_latents = noise.clone()
628
-
629
- # Подготовка текстовых эмбеддингов для guidance
630
- if guidance_scale > 0:
631
- # empty_embeddings должны быть того же типа и на том же устройстве
632
- empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
633
- text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
634
- else:
635
- text_embeddings_batch = sample_text_embeddings
636
-
637
- for t in scheduler.timesteps:
638
- t_batch = t.repeat(current_latents.shape[0]).to(device) # Убедимся, что t на устройстве
639
-
640
- if guidance_scale > 0:
641
- latent_model_input = torch.cat([current_latents] * 2)
642
- else:
643
- latent_model_input = current_latents
644
-
645
- latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
646
-
647
- # Предсказание шума (UNet)
648
- noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
649
-
650
- if guidance_scale > 0:
651
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
652
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
653
-
654
- current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
655
-
656
- #print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
657
- # Декодирование через VAE
658
- latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
659
- decoded = vae.decode(latent_for_vae).sample
660
-
661
- # Преобразуем тензоры в PIL-изображения
662
- # Для математики с изображением (нормализация) лучше перейти в fp32
663
- decoded_fp32 = decoded.to(torch.float32)
664
- for img_idx, img_tensor in enumerate(decoded_fp32):
665
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
666
- # If NaNs or infs are present, print them
667
- if np.isnan(img).any():
668
- print("NaNs found, saving stoped! Step:", step)
669
- save_model = False
670
- pil_img = Image.fromarray((img * 255).astype("uint8"))
671
-
672
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
673
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
674
- max_w_overall = max(255, max_w_overall)
675
- max_h_overall = max(255, max_h_overall)
676
-
677
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
678
- all_generated_images.append(padded_img)
679
-
680
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
681
- all_captions.append(caption_text)
682
-
683
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
684
- pil_img.save(sample_path, "JPEG", quality=96)
685
-
686
- if use_wandb and accelerator.is_main_process:
687
- wandb_images = [
688
- wandb.Image(img, caption=f"{all_captions[i]}")
689
- for i, img in enumerate(all_generated_images)
690
- ]
691
- wandb.log({"generated_images": wandb_images, "global_step": step})
692
-
693
- finally:
694
- vae.to("cpu") # Перемещаем VAE обратно на CPU
695
- original_model = original_model.to(dtype = dtype)
696
- if original_model is not None:
697
- del original_model
698
- # Очистка переменных, которые являются тензорами и были созданы в функции
699
- for var in list(locals().keys()):
700
- if isinstance(locals()[var], torch.Tensor):
701
- del locals()[var]
702
-
703
- torch.cuda.empty_cache()
704
- gc.collect()
705
-
706
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
707
- if accelerator.is_main_process:
708
- if save_model:
709
- print("Генерация сэмплов до старта обучения...")
710
- generate_and_save_samples(fixed_samples,0)
711
-
712
- # Модифицируем функцию сохранения модели для поддержки LoRA
713
- def save_checkpoint(unet,variant=""):
714
- if accelerator.is_main_process:
715
- if lora_name:
716
- # Сохраняем только LoRA адаптеры
717
- save_lora_checkpoint(unet)
718
- else:
719
- # Сохраняем полную модель
720
- if variant!="":
721
- accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
722
- else:
723
- accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
724
- unet = unet.to(dtype=dtype)
725
-
726
- # --------------------------- Тренировочный цикл ---------------------------
727
- # Для логирования среднего лосса каждые % эпохи
728
- if accelerator.is_main_process:
729
- print(f"Total steps per GPU: {total_training_steps}")
730
-
731
- epoch_loss_points = []
732
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
733
-
734
- # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
735
- steps_per_epoch = len(dataloader)
736
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
737
- min_loss = 1.
738
-
739
- # Начинаем с указанной эпохи (полезно при возобновлении)
740
- for epoch in range(start_epoch, start_epoch + num_epochs):
741
- batch_losses = []
742
- batch_tlosses = []
743
- batch_grads = []
744
- #unet = unet.to(dtype = dtype)
745
- batch_sampler.set_epoch(epoch)
746
- accelerator.wait_for_everyone()
747
- unet.train()
748
- print("epoch:",epoch)
749
- for step, (latents, embeddings) in enumerate(dataloader):
750
- with accelerator.accumulate(unet):
751
- if save_model == False and step == 5 :
752
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
753
- print(f"Шаг {step}: {used_gb:.2f} GB")
754
-
755
- # Forward pass
756
- noise = torch.randn_like(latents, dtype=latents.dtype)
757
-
758
- timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
759
- (latents.shape[0],), device=device).long()
760
-
761
- # Добавляем шум к латентам
762
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
763
-
764
- # Очищаем активации перед forward pass
765
- dispersive_hook.clear_activations()
766
-
767
- # Используем целевое значение
768
- model_pred = unet(noisy_latents, timesteps, embeddings).sample
769
- target_pred = scheduler.get_velocity(latents, noise, timesteps)
770
-
771
- # Считаем лосс
772
- loss = torch.nn.functional.mse_loss(model_pred, target_pred)
773
-
774
- # Dispersive Loss
775
- #Идентичные векторы: Loss = -0.0000
776
- #Ортогональные векторы: Loss = -3.9995
777
- dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
778
-
779
- # Итоговый loss
780
- # dispersive_loss должен падать и тотал падать - поэтому плюс
781
- total_loss = loss + dispersive_loss
782
-
783
- # Проверяем на nan/inf перед backward
784
- if torch.isnan(loss) or torch.isinf(loss):
785
- print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
786
- save_model = False
787
- break
788
-
789
- # Делаем backward через Accelerator
790
- accelerator.backward(total_loss)
791
-
792
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
793
- accelerator.wait_for_everyone()
794
-
795
- grad = 0.0
796
- if not fbp:
797
- if accelerator.sync_gradients:
798
- grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
799
- optimizer.step()
800
- lr_scheduler.step()
801
- optimizer.zero_grad(set_to_none=True)
802
-
803
- # Увеличиваем счетчик глобальных шагов
804
- global_step += 1
805
-
806
- # Обновляем прогресс-бар
807
- progress_bar.update(1)
808
-
809
- # Логируем метрики
810
- if accelerator.is_main_process:
811
- if fbp:
812
- current_lr = base_learning_rate
813
- else:
814
- current_lr = lr_scheduler.get_last_lr()[0]
815
- batch_losses.append(loss.detach().item())
816
- batch_tlosses.append(total_loss.detach().item())
817
- batch_grads.append(grad)
818
-
819
- # Логируем в Wandb
820
- if use_wandb:
821
- wandb.log({
822
- "mse_loss": loss.detach().item(),
823
- "learning_rate": current_lr,
824
- "epoch": epoch,
825
- "grad": grad,
826
- "global_step": global_step,
827
- "dispersive_loss": dispersive_loss,
828
- "total_loss": total_loss
829
- })
830
-
831
- # Генерируем сэмплы с заданным интервалом
832
- if global_step % sample_interval == 0:
833
- generate_and_save_samples(fixed_samples,global_step)
834
-
835
- # Выводим текущий лосс
836
- avg_loss = np.mean(batch_losses[-sample_interval:])
837
- avg_tloss = np.mean(batch_tlosses[-sample_interval:])
838
- avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
839
- print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}")
840
-
841
- if save_model:
842
- print("saving:",avg_loss < min_loss*save_barrier)
843
- if avg_loss < min_loss*save_barrier:
844
- min_loss = avg_loss
845
- save_checkpoint(unet)
846
- if use_wandb:
847
- wandb.log({"interm_loss": avg_loss})
848
- wandb.log({"interm_totalloss": avg_tloss})
849
- wandb.log({"interm_grad": avg_grad})
850
-
851
-
852
- # По окончании эпохи
853
- #accelerator.wait_for_everyone()
854
- if accelerator.is_main_process:
855
- avg_epoch_loss = np.mean(batch_losses)
856
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
857
- if use_wandb:
858
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
859
-
860
- # Завершение обучения - сохраняем финальную модель
861
- dispersive_hook.remove_hooks()
862
- if accelerator.is_main_process:
863
- print("Обучение завершено! Сохраняем финальную модель...")
864
- # Сохраняем основную модель
865
- if save_model:
866
- save_checkpoint(unet,"fp16")
867
- print("Готово!")
868
-
869
- # randomize ode timesteps
870
- # input_timestep = torch.round(
871
- # F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
872
- # )
873
-
874
- #def create_distribution(num_points, device=None):
875
- # # Диапазон вероятностей на оси x
876
- # x = torch.linspace(0, 1, num_points, device=device)
877
-
878
- # Пользовательская функция плотности вероятности
879
- # probabilities = -7.7 * ((x - 0.5) ** 2) + 2
880
-
881
- # Нормализация, чтобы сумма равнялась 1
882
- # probabilities /= probabilities.sum()
883
-
884
- # return x, probabilities
885
-
886
- #def sample_from_distribution(x, probabilities, n, device=None):
887
- # Выбор индексов на основе распределения вероятностей
888
- # indices = torch.multinomial(probabilities, n, replacement=True)
889
- # return x[indices]
890
-
891
- # Пример использования
892
- #num_points = 1000 # Количество точек в диапазоне
893
- #n = latents.shape[0] # Количество временных шагов для выборки
894
- #x, probabilities = create_distribution(num_points, device=latents.device)
895
- #timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
896
-
897
- # Преобразование в формат, подходящий для вашего кода
898
- #timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_dist.py DELETED
@@ -1,713 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from collections import defaultdict
9
- from torch.optim.lr_scheduler import LambdaLR
10
- from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
11
- from accelerate import Accelerator
12
- from datasets import load_from_disk
13
- from tqdm import tqdm
14
- from PIL import Image,ImageOps
15
- import wandb
16
- import random
17
- import gc
18
- from accelerate.state import DistributedType
19
- from torch.distributed import broadcast_object_list
20
- from torch.utils.checkpoint import checkpoint
21
- from diffusers.models.attention_processor import AttnProcessor2_0
22
- from datetime import datetime
23
- import bitsandbytes as bnb
24
-
25
- # --------------------------- Параметры ---------------------------
26
- ds_path = "datasets/384"
27
- batch_size = 30
28
- base_learning_rate = 3e-5
29
- num_epochs = 4
30
- num_warmup_steps = 500
31
- project = "unet"
32
- use_wandb = True
33
- save_model = True
34
- sample_interval_share = 10 # samples/save per epoch
35
- fbp = False # fused backward pass
36
- adam8bit = True
37
- percentile_clipping = 97 # Lion
38
- torch_compile = False
39
- unet_gradient = True
40
- clip_sample = False #Scheduler
41
- fixed_seed = False
42
- dtype_unet = torch.float32
43
- dtype_embed = torch.float32
44
- dtype_infer = torch.float16
45
- steps_offset = 1 # Scheduler
46
- limit = 0
47
- checkpoints_folder = ""
48
- mixed_precision = "no"
49
- accelerator = Accelerator(mixed_precision=mixed_precision)
50
- device = accelerator.device
51
-
52
- # Параметры для диффузии
53
- n_diffusion_steps = 50
54
- samples_to_generate = 12
55
- guidance_scale = 5
56
-
57
- # Папки для сохранения результатов
58
- generated_folder = "samples"
59
- os.makedirs(generated_folder, exist_ok=True)
60
-
61
- # Настройка seed для воспроизводимости
62
- current_date = datetime.now()
63
- seed = int(current_date.strftime("%Y%m%d"))
64
- if fixed_seed:
65
- torch.manual_seed(seed)
66
- np.random.seed(seed)
67
- random.seed(seed)
68
- if torch.cuda.is_available():
69
- torch.cuda.manual_seed_all(seed)
70
-
71
- # --------------------------- Параметры LoRA ---------------------------
72
- # pip install peft
73
- lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
74
- lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
75
- lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
76
-
77
- print("init")
78
-
79
- # --------------------------- Инициализация WandB ---------------------------
80
- if use_wandb and accelerator.is_main_process:
81
- wandb.init(project=project+lora_name, config={
82
- "batch_size": batch_size,
83
- "base_learning_rate": base_learning_rate,
84
- "num_epochs": num_epochs,
85
- "fbp": fbp,
86
- "adam8bit": adam8bit,
87
- })
88
-
89
- # Включение Flash Attention 2/SDPA
90
- torch.backends.cuda.enable_flash_sdp(True)
91
- # --------------------------- Инициализация Accelerator --------------------
92
- gen = torch.Generator(device=device)
93
- gen.manual_seed(seed)
94
-
95
- # --------------------------- Загрузка моделей ---------------------------
96
- # VAE загружается на CPU для экономии GPU-памяти
97
- vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
98
-
99
- # DDPMScheduler с V_Prediction и Zero-SNR
100
- scheduler = DDPMScheduler(
101
- num_train_timesteps=1000, # Полный график шагов для обучения
102
- prediction_type="v_prediction", # V-Prediction
103
- rescale_betas_zero_snr=True, # Включение Zero-SNR
104
- clip_sample = clip_sample,
105
- steps_offset = steps_offset
106
- )
107
-
108
- # --------------------------- Загрузка датасета ---------------------------
109
- class ResolutionBatchSampler(Sampler):
110
- """Сэмплер, который группирует примеры по одинаковым размерам"""
111
- def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):
112
- self.dataset = dataset
113
- self.batch_size = batch_size
114
- self.shuffle = shuffle
115
- self.drop_last = drop_last
116
-
117
- # Группируем примеры по размерам
118
- self.size_groups = defaultdict(list)
119
-
120
- try:
121
- widths = dataset["width"]
122
- heights = dataset["height"]
123
- except KeyError:
124
- widths = [0] * len(dataset)
125
- heights = [0] * len(dataset)
126
-
127
- for i, (w, h) in enumerate(zip(widths, heights)):
128
- size = (w, h)
129
- self.size_groups[size].append(i)
130
-
131
- # Печатаем статистику по размерам
132
- print(f"Найдено {len(self.size_groups)} уникальных размеров:")
133
- for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True):
134
- width, height = size
135
- print(f" {width}x{height}: {len(indices)} пример��в")
136
-
137
- # Формируем батчи
138
- self.reset()
139
-
140
- def reset(self):
141
- """Сбрасывает и перемешивает индексы"""
142
- self.batches = []
143
-
144
- for size, indices in self.size_groups.items():
145
- if self.shuffle:
146
- indices_copy = indices.copy()
147
- random.shuffle(indices_copy)
148
- else:
149
- indices_copy = indices
150
-
151
- # Разбиваем на батчи
152
- for i in range(0, len(indices_copy), self.batch_size):
153
- batch_indices = indices_copy[i:i + self.batch_size]
154
-
155
- # Пропускаем неполные батчи если drop_last=True
156
- if self.drop_last and len(batch_indices) < self.batch_size:
157
- continue
158
-
159
- self.batches.append(batch_indices)
160
-
161
- # Перемешиваем батчи между собой
162
- if self.shuffle:
163
- random.shuffle(self.batches)
164
-
165
- def __iter__(self):
166
- self.reset() # Сбрасываем и перемешиваем в начале каждой эпохи
167
- return iter(self.batches)
168
-
169
- def __len__(self):
170
- return len(self.batches)
171
-
172
- # Функция для выборки фиксированных семплов по размерам
173
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
174
- """Выбирает фиксированные семплы для каждого уникального разрешения"""
175
- # Группируем по размерам
176
- size_groups = defaultdict(list)
177
- try:
178
- widths = dataset["width"]
179
- heights = dataset["height"]
180
- except KeyError:
181
- widths = [0] * len(dataset)
182
- heights = [0] * len(dataset)
183
- for i, (w, h) in enumerate(zip(widths, heights)):
184
- size = (w, h)
185
- size_groups[size].append(i)
186
-
187
- # Выбираем фиксированные примеры из каждой группы
188
- fixed_samples = {}
189
- for size, indices in size_groups.items():
190
- # Определяем сколько семплов брать из этой группы
191
- n_samples = min(samples_per_group, len(indices))
192
- if len(size_groups)==1:
193
- n_samples = samples_to_generate
194
- if n_samples == 0:
195
- continue
196
-
197
- # Выбираем случайные индексы
198
- sample_indices = random.sample(indices, n_samples)
199
- samples_data = [dataset[idx] for idx in sample_indices]
200
-
201
- # Собираем данные
202
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype_embed)
203
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype_embed)
204
- texts = [item["text"] for item in samples_data]
205
-
206
- # Сохраняем для этого размера
207
- fixed_samples[size] = (latents, embeddings, texts)
208
-
209
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
210
- return fixed_samples
211
-
212
- if limit > 0:
213
- dataset = load_from_disk(ds_path).select(range(limit))
214
- else:
215
- dataset = load_from_disk(ds_path)
216
-
217
- # Создаем DistributedSampler
218
- if accelerator.num_processes > 1:
219
- dist_sampler = DistributedSampler(dataset, shuffle=False)
220
- else:
221
- dist_sampler = None
222
-
223
- def collate_fn_simple(batch):
224
- # Преобразуем список в тензоры и перемещаем на девайс
225
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype_embed)
226
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype_embed)
227
- return latents, embeddings
228
-
229
- def collate_fn(batch):
230
- if not batch:
231
- return [], []
232
-
233
- # Берем эталонную форму
234
- ref_vae_shape = np.array(batch[0]["vae"]).shape
235
- ref_embed_shape = np.array(batch[0]["embeddings"]).shape
236
-
237
- # Фильтруем
238
- valid_latents = []
239
- valid_embeddings = []
240
- for item in batch:
241
- if (np.array(item["vae"]).shape == ref_vae_shape and
242
- np.array(item["embeddings"]).shape == ref_embed_shape):
243
- valid_latents.append(item["vae"])
244
- valid_embeddings.append(item["embeddings"])
245
-
246
- # Создаем тензоры
247
- latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype_embed)
248
- embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype_embed)
249
-
250
- return latents, embeddings
251
-
252
- # Используем наш ResolutionBatchSampler
253
- #batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
254
- #dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
255
-
256
- # Создаем ResolutionBatchSampler на осн��ве индексов от DistributedSampler
257
- if dist_sampler is not None:
258
- batch_sampler = ResolutionBatchSampler(list(dist_sampler), dataset, batch_size=batch_size, shuffle=True)
259
- else:
260
- batch_sampler = ResolutionBatchSampler(list(range(len(dataset))), dataset, batch_size=batch_size, shuffle=True)
261
-
262
- # Создаем DataLoader
263
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
264
-
265
-
266
- print("Total samples",len(dataloader))
267
- dataloader = accelerator.prepare(dataloader)
268
-
269
- # Инициализация переменных для возобновления обучения
270
- start_epoch = 0
271
- global_step = 0
272
-
273
- # Расчёт общего количества шагов
274
- total_training_steps = (len(dataloader) * num_epochs)
275
- # Get the world size
276
- world_size = accelerator.state.num_processes
277
- #print(f"World Size: {world_size}")
278
-
279
- # Опция загрузки модели из последнего чекпоинта (если существует)
280
- latest_checkpoint = os.path.join(checkpoints_folder, project)
281
- if os.path.isdir(latest_checkpoint):
282
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
283
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype_unet)
284
- if unet_gradient:
285
- unet.enable_gradient_checkpointing()
286
- unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
287
- try:
288
- unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
289
- except Exception as e:
290
- print(f"Ошибка при включении SDPA: {e}")
291
- print("Попытка использовать enable_xformers_memory_efficient_attention.")
292
- unet.set_use_memory_efficient_attention_xformers(True)
293
-
294
- if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
295
- print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
296
- if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
297
- print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
298
- if hasattr(torch.nn.functional, "get_flash_attention_available"):
299
- print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
300
- if torch_compile:
301
- print("compiling")
302
- torch.set_float32_matmul_precision('high')
303
- unet = torch.compile(unet)#, mode="reduce-overhead", fullgraph=True)
304
- print("compiling - ok")
305
-
306
- if lora_name:
307
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
308
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
309
- from peft.tuners.lora import LoraModel
310
- import os
311
- # 1. Замораживаем все параметры UNet
312
- unet.requires_grad_(False)
313
- print("Параметры базового UNet заморожены.")
314
-
315
- # 2. Создаем конфигурацию LoRA
316
- lora_config = LoraConfig(
317
- r=lora_rank,
318
- lora_alpha=lora_alpha,
319
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
320
- )
321
- unet.add_adapter(lora_config)
322
-
323
- # 3. Оборачиваем UNet в PEFT-модель
324
- from peft import get_peft_model
325
-
326
- peft_unet = get_peft_model(unet, lora_config)
327
-
328
- # 4. Получаем параметры для оптимизации
329
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
330
-
331
-
332
- # 5. Выводим информацию о количестве параметров
333
- if accelerator.is_main_process:
334
- lora_params_count = sum(p.numel() for p in params_to_optimize)
335
- total_params_count = sum(p.numel() for p in unet.parameters())
336
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
337
- print(f"Общее количество параметров UNet: {total_params_count:,}")
338
-
339
- # 6. Путь для сохранения
340
- lora_save_path = os.path.join("lora", lora_name)
341
- os.makedirs(lora_save_path, exist_ok=True)
342
-
343
- # 7. Функция для сохранения
344
- def save_lora_checkpoint(model):
345
- if accelerator.is_main_process:
346
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
347
- from peft.utils.save_and_load import get_peft_model_state_dict
348
- # Получаем state_dict только LoRA
349
- lora_state_dict = get_peft_model_state_dict(model)
350
-
351
- # Сохраняем веса
352
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
353
-
354
- # Сохраняем конфиг
355
- model.peft_config["default"].save_pretrained(lora_save_path)
356
- # SDXL must be compatible
357
- from diffusers import StableDiffusionXLPipeline
358
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
359
-
360
- # --------------------------- Оптимизатор ---------------------------
361
- # Определяем параметры для оптимизации
362
- #unet = torch.compile(unet)
363
- if lora_name:
364
- # Если используется LoRA, оптимизируем только параметры LoRA
365
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
366
- else:
367
- # Иначе оптимизируем все параметры
368
- if fbp:
369
- trainable_params = list(unet.parameters())
370
-
371
- if fbp:
372
- # [1] Создаем словарь оптимизаторов (fused backward)
373
- if adam8bit:
374
- optimizer_dict = {
375
- p: bnb.optim.AdamW8bit(
376
- [p], # Каждый параметр получает свой оптимизатор
377
- lr=base_learning_rate,
378
- eps=1e-8
379
- ) for p in trainable_params
380
- }
381
- else:
382
- optimizer_dict = {
383
- p: bnb.optim.Lion8bit(
384
- [p], # Каждый параметр получает свой оптимизатор
385
- lr=base_learning_rate,
386
- betas=(0.9, 0.97),
387
- weight_decay=0.01,
388
- percentile_clipping=percentile_clipping,
389
- ) for p in trainable_params
390
- }
391
-
392
- # [2] Определяем hook для применения оптимизатора сразу после накопления градиента
393
- def optimizer_hook(param):
394
- optimizer_dict[param].step()
395
- optimizer_dict[param].zero_grad(set_to_none=True)
396
-
397
- # [3] Регистрируем hook для trainable параметров модели
398
- for param in trainable_params:
399
- param.register_post_accumulate_grad_hook(optimizer_hook)
400
-
401
- # Подготовка через Accelerator
402
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
- else:
404
- if adam8bit:
405
- optimizer = bnb.optim.AdamW8bit(
406
- params=unet.parameters(),
407
- lr=base_learning_rate,
408
- eps=1e-8
409
- )
410
- else:
411
- optimizer = bnb.optim.Lion8bit(
412
- params=unet.parameters(),
413
- lr=base_learning_rate,
414
- betas=(0.9, 0.97),
415
- weight_decay=0.01,
416
- percentile_clipping=percentile_clipping,
417
- )
418
- from transformers import get_constant_schedule_with_warmup
419
-
420
- # warmup
421
- num_warmup_steps = num_warmup_steps * world_size
422
-
423
- lr_scheduler = get_constant_schedule_with_warmup(
424
- optimizer=optimizer,
425
- num_warmup_steps=num_warmup_steps
426
- )
427
-
428
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
429
-
430
- # --------------------------- Фиксированные семплы для генерации ---------------------------
431
- # Примеры фиксированных семплов по размерам
432
- fixed_samples = get_fixed_samples_by_resolution(dataset)
433
-
434
- @torch.compiler.disable()
435
- @torch.no_grad()
436
- def generate_and_save_samples(fixed_samples_cpu, step):
437
- """
438
- Генерирует семплы для каждого из разрешений и сохраняет их.
439
-
440
- Args:
441
- fixed_samples_cpu: Словарь, где ключи - размеры (width, height),
442
- а значения - кортежи (latents, embeddings, text) на CPU.
443
- step: Текущий шаг обучения
444
- """
445
- original_model = None # Инициализируем, чтобы finally не ругался
446
- try:
447
-
448
- original_model = accelerator.unwrap_model(unet)
449
- original_model = original_model.to(dtype = dtype_infer)
450
- original_model.eval()
451
-
452
- vae.to(device=device, dtype=dtype_infer)
453
- vae.eval()
454
-
455
- scheduler.set_timesteps(n_diffusion_steps)
456
-
457
- all_generated_images = []
458
- all_captions = []
459
-
460
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
461
- width, height = size
462
-
463
- sample_latents = sample_latents.to(dtype=dtype_infer)
464
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype_infer)
465
-
466
- # Инициализируем латенты случайным шумом
467
- # sample_latents уже в dtype_infer, так что noise будет создан в dtype_infer
468
- noise = torch.randn(
469
- sample_latents.shape, # Используем форму от sample_latents, которые теперь на GPU и fp16
470
- generator=gen,
471
- device=device,
472
- dtype=sample_latents.dtype
473
- )
474
- current_latents = noise.clone()
475
-
476
- # Подготовка текстовых эмбеддингов для guidance
477
- if guidance_scale > 0:
478
- # empty_embeddings должны быть того же типа и на том же устройстве
479
- empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
480
- text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
481
- else:
482
- text_embeddings_batch = sample_text_embeddings
483
-
484
- for t in scheduler.timesteps:
485
- t_batch = t.repeat(current_latents.shape[0]).to(device) # Убедимся, что t на устройстве
486
-
487
- if guidance_scale > 0:
488
- latent_model_input = torch.cat([current_latents] * 2)
489
- else:
490
- latent_model_input = current_latents
491
-
492
- latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
493
-
494
- # Предсказание шума (UNet)
495
- noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
496
-
497
- if guidance_scale > 0:
498
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
499
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
500
-
501
- current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
502
-
503
- #print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
504
- # Декодирование через VAE
505
- latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
506
- decoded = vae.decode(latent_for_vae).sample
507
-
508
- # Преобразуем тензоры в PIL-изображения
509
- # Для математики с изображением (нормализация) лучше перейти в fp32
510
- decoded_fp32 = decoded.to(torch.float32)
511
- for img_idx, img_tensor in enumerate(decoded_fp32):
512
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
513
- # If NaNs or infs are present, print them
514
- if np.isnan(img).any():
515
- print("NaNs found, saving stoped! Step:", step)
516
- save_model = False
517
- pil_img = Image.fromarray((img * 255).astype("uint8"))
518
-
519
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
520
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
521
- max_w_overall = max(255, max_w_overall)
522
- max_h_overall = max(255, max_h_overall)
523
-
524
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
525
- all_generated_images.append(padded_img)
526
-
527
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
528
- all_captions.append(caption_text)
529
-
530
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
531
- pil_img.save(sample_path, "JPEG", quality=96)
532
-
533
- if use_wandb and accelerator.is_main_process:
534
- wandb_images = [
535
- wandb.Image(img, caption=f"{all_captions[i]}")
536
- for i, img in enumerate(all_generated_images)
537
- ]
538
- wandb.log({"generated_images": wandb_images, "global_step": step})
539
-
540
- finally:
541
- vae.to("cpu") # Перемещаем VAE обратно на CPU
542
- original_model = original_model.to(dtype = dtype_unet)
543
- if original_model is not None:
544
- del original_model
545
- # Очистка переменных, которые являются тензорами и были созданы в функции
546
- for var in list(locals().keys()):
547
- if isinstance(locals()[var], torch.Tensor):
548
- del locals()[var]
549
-
550
- torch.cuda.empty_cache()
551
- gc.collect()
552
-
553
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
554
- if accelerator.is_main_process:
555
- if save_model:
556
- print("Генерация сэмплов до старта обучения...")
557
- generate_and_save_samples(fixed_samples,0)
558
-
559
- # Модифицируем функцию сохранения модели для поддержки LoRA
560
- def save_checkpoint(unet,variant=""):
561
- if accelerator.is_main_process:
562
- if lora_name:
563
- # Сохраняем только LoRA адаптеры
564
- save_lora_checkpoint(unet)
565
- else:
566
- # Сохраняем полную модель
567
- if variant!="":
568
- accelerator.unwrap_model(unet.to(dtype=dtype_infer)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
569
- else:
570
- accelerator.unwrap_model(unet.to(dtype=dtype_infer)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
571
- unet = unet.to(dtype=dtype_unet)
572
-
573
- # --------------------------- Тренировочный цикл ---------------------------
574
- # Для логирования среднего лосса каждые % эпохи
575
- if accelerator.is_main_process:
576
- print(f"Total steps per GPU: {total_training_steps}")
577
-
578
- epoch_loss_points = []
579
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
580
-
581
- # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
582
- steps_per_epoch = len(dataloader)
583
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
584
-
585
- # Начинаем с указанной эпохи (полезно при возобновлении)
586
- for epoch in range(start_epoch, start_epoch + num_epochs):
587
- batch_losses = []
588
- batch_grads = []
589
- unet = unet.to(dtype = dtype_unet)
590
- if dist_sampler is not None:
591
- dist_sampler.set_epoch(epoch) # Важно для правильного shuffling
592
- unet.train()
593
- for step, (latents, embeddings) in enumerate(dataloader):
594
- with accelerator.accumulate(unet):
595
- if save_model == False and step == 5 :
596
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
597
- print(f"Шаг {step}: {used_gb:.2f} GB")
598
-
599
- # Forward pass
600
- noise = torch.randn_like(latents, dtype=latents.dtype)
601
-
602
- timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
603
- (latents.shape[0],), device=device).long()
604
-
605
- # Добавляем шум к латентам
606
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
607
-
608
- # Используем целевое значение
609
- model_pred = unet(noisy_latents, timesteps, embeddings).sample
610
- target_pred = scheduler.get_velocity(latents, noise, timesteps)
611
-
612
- # Считаем лосс
613
- loss = torch.nn.functional.mse_loss(model_pred, target_pred)
614
-
615
- # Делаем backward через Accelerator
616
- accelerator.backward(loss)
617
-
618
- grad = 0.0
619
- if not fbp:
620
- if accelerator.sync_gradients:
621
- grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
622
- accelerator.wait_for_everyone()
623
- optimizer.step()
624
- lr_scheduler.step()
625
- optimizer.zero_grad(set_to_none=True)
626
-
627
- # Увеличиваем счетчик глобальных шагов
628
- global_step += 1
629
-
630
- # Обновляем прогресс-бар
631
- progress_bar.update(1)
632
-
633
- # Логируем метрики
634
- if accelerator.is_main_process:
635
- if fbp:
636
- current_lr = base_learning_rate
637
- else:
638
- current_lr = lr_scheduler.get_last_lr()[0]
639
- batch_losses.append(loss.detach().item())
640
- batch_grads.append(grad)
641
-
642
- # Логируем в Wandb
643
- if use_wandb:
644
- wandb.log({
645
- "loss": loss.detach().item(),
646
- "learning_rate": current_lr,
647
- "epoch": epoch,
648
- "grad": grad,
649
- "global_step": global_step
650
- })
651
-
652
- # Генерируем сэмплы с заданным интервалом
653
- if global_step % sample_interval == 0:
654
- generate_and_save_samples(fixed_samples,global_step)
655
- if save_model:
656
- save_checkpoint(unet)
657
-
658
- # Выводим текущий лосс
659
- avg_loss = np.mean(batch_losses[-sample_interval:])
660
- avg_grad = np.mean(batch_grads[-sample_interval:])
661
- #print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}")
662
- if use_wandb:
663
- wandb.log({"intermediate_loss": avg_loss})
664
- wandb.log({"intermediate_grad": avg_grad})
665
-
666
-
667
- # По окончании эпохи
668
- if accelerator.is_main_process:
669
- avg_epoch_loss = np.mean(batch_losses)
670
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
671
- if use_wandb:
672
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
673
-
674
- # Завершение обучения - сохраняем финальную модель
675
- if accelerator.is_main_process:
676
- print("Обучение завершено! Сохраняем финальную модель...")
677
- # Сохраняем осно��ную модель
678
- if save_model:
679
- save_checkpoint(accelerator.unwrap_model(unet).to(dtype = torch.float16))
680
-
681
- save_checkpoint(accelerator.unwrap_model(unet).to(dtype = torch.float16),"fp16")
682
- print("Готово!")
683
-
684
- # randomize ode timesteps
685
- # input_timestep = torch.round(
686
- # F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
687
- # )
688
-
689
- #def create_distribution(num_points, device=None):
690
- # # Диапазон вероятностей на оси x
691
- # x = torch.linspace(0, 1, num_points, device=device)
692
-
693
- # Пользовательская функция плотности вероятности
694
- # probabilities = -7.7 * ((x - 0.5) ** 2) + 2
695
-
696
- # Нормализация, чтобы сумма равнялась 1
697
- # probabilities /= probabilities.sum()
698
-
699
- # return x, probabilities
700
-
701
- #def sample_from_distribution(x, probabilities, n, device=None):
702
- # Выбор индексов на основе распределения вероятностей
703
- # indices = torch.multinomial(probabilities, n, replacement=True)
704
- # return x[indices]
705
-
706
- # Пример использования
707
- #num_points = 1000 # Количество точек в диапазоне
708
- #n = latents.shape[0] # Количество временных шагов для выборки
709
- #x, probabilities = create_distribution(num_points, device=latents.device)
710
- #timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
711
-
712
- # Преобразование в формат, подходящий для вашего кода
713
- #timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_nofbp.py DELETED
@@ -1,695 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from collections import defaultdict
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
10
- from accelerate import Accelerator
11
- from datasets import load_from_disk
12
- from tqdm import tqdm
13
- from PIL import Image,ImageOps
14
- import wandb
15
- import random
16
- import gc
17
- from accelerate.state import DistributedType
18
- from torch.distributed import broadcast_object_list
19
- from torch.utils.checkpoint import checkpoint
20
- from diffusers.models.attention_processor import AttnProcessor2_0
21
- from datetime import datetime
22
- import bitsandbytes as bnb
23
-
24
- # --------------------------- Параметры ---------------------------
25
- ds_path = "datasets/384"
26
- batch_size = 25
27
- base_learning_rate = 5e-5
28
- percentile_clipping = 97
29
- num_epochs = 5
30
- num_warmup_steps = 300
31
- project = "unet"
32
- use_wandb = True
33
- save_model = True
34
- adam8bit = True
35
- torch_compile = False
36
- unet_gradient = True
37
- clip_sample = False
38
- fixed_seed = True
39
- fbp = False
40
- sample_interval_share = 10 # samples/save per epoch
41
- dtype_unet = torch.float32
42
- dtype_embed = torch.float32
43
- dtype_infer = torch.float16
44
- steps_offset = 1
45
- limit = 0
46
- checkpoints_folder = ""
47
- mixed_precision = "no"
48
- accelerator = Accelerator(mixed_precision=mixed_precision)
49
- device = accelerator.device
50
-
51
- # Параметры для диффузии
52
- n_diffusion_steps = 50
53
- samples_to_generate = 12
54
- guidance_scale = 5
55
-
56
- # Папки для сохранения результатов
57
- generated_folder = "samples"
58
- os.makedirs(generated_folder, exist_ok=True)
59
-
60
- # Настройка seed для воспроизводимости
61
- current_date = datetime.now()
62
- seed = int(current_date.strftime("%Y%m%d"))
63
- if fixed_seed:
64
- torch.manual_seed(seed)
65
- np.random.seed(seed)
66
- random.seed(seed)
67
- if torch.cuda.is_available():
68
- torch.cuda.manual_seed_all(seed)
69
-
70
- # --------------------------- Параметры LoRA ---------------------------
71
- # pip install peft
72
- lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
73
- lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
74
- lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
75
-
76
- print("init")
77
-
78
- # --------------------------- Инициализация WandB ---------------------------
79
- if use_wandb and accelerator.is_main_process:
80
- wandb.init(project=project+lora_name, config={
81
- "batch_size": batch_size,
82
- "base_learning_rate": base_learning_rate,
83
- "num_epochs": num_epochs,
84
- "fbp": fbp,
85
- "adam8bit": adam8bit,
86
- })
87
-
88
-
89
- # Включение Flash Attention 2/SDPA
90
- torch.backends.cuda.enable_flash_sdp(True)
91
- # --------------------------- Инициализация Accelerator --------------------
92
- gen = torch.Generator(device=device)
93
- gen.manual_seed(seed)
94
-
95
- # --------------------------- Загрузка моделей ---------------------------
96
- # VAE загружается на CPU для экономии GPU-памяти
97
- vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
98
- #vae = AutoencoderKL.from_pretrained("vae_flux").to("cpu").eval()
99
-
100
- # DDPMScheduler с V_Prediction и Zero-SNR
101
- scheduler = DDPMScheduler(
102
- num_train_timesteps=1000, # Полный график шагов для обучения
103
- prediction_type="v_prediction", # V-Prediction
104
- rescale_betas_zero_snr=True, # Включение Zero-SNR
105
- clip_sample = clip_sample,
106
- steps_offset = steps_offset
107
- )
108
-
109
- # --------------------------- Загрузка датасета ---------------------------
110
- class ResolutionBatchSampler(Sampler):
111
- """Сэмплер, который группирует примеры по одинаковым размерам"""
112
- def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):
113
- self.dataset = dataset
114
- self.batch_size = batch_size
115
- self.shuffle = shuffle
116
- self.drop_last = drop_last
117
-
118
- # Группируем примеры по размерам
119
- self.size_groups = defaultdict(list)
120
-
121
- try:
122
- widths = dataset["width"]
123
- heights = dataset["height"]
124
- except KeyError:
125
- widths = [0] * len(dataset)
126
- heights = [0] * len(dataset)
127
-
128
- for i, (w, h) in enumerate(zip(widths, heights)):
129
- size = (w, h)
130
- self.size_groups[size].append(i)
131
-
132
- # Печатаем статистику по размерам
133
- print(f"Найдено {len(self.size_groups)} уникальных размеров:")
134
- for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True):
135
- width, height = size
136
- print(f" {width}x{height}: {len(indices)} примеров")
137
-
138
- # Формируем батчи
139
- self.reset()
140
-
141
- def reset(self):
142
- """Сбрасывает и перемешивает индексы"""
143
- self.batches = []
144
-
145
- for size, indices in self.size_groups.items():
146
- if self.shuffle:
147
- indices_copy = indices.copy()
148
- random.shuffle(indices_copy)
149
- else:
150
- indices_copy = indices
151
-
152
- # Разбиваем на батчи
153
- for i in range(0, len(indices_copy), self.batch_size):
154
- batch_indices = indices_copy[i:i + self.batch_size]
155
-
156
- # Пропускаем неполные батчи если drop_last=True
157
- if self.drop_last and len(batch_indices) < self.batch_size:
158
- continue
159
-
160
- self.batches.append(batch_indices)
161
-
162
- # Перемешиваем батчи между собой
163
- if self.shuffle:
164
- random.shuffle(self.batches)
165
-
166
- def __iter__(self):
167
- self.reset() # Сбрасываем и перемешиваем в начале каждой эпохи
168
- return iter(self.batches)
169
-
170
- def __len__(self):
171
- return len(self.batches)
172
-
173
- # Функция для выборки фиксированных семплов по размерам
174
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
175
- """Выбирает фиксированные семплы для каждого уникального разрешения"""
176
- # Группируем по размерам
177
- size_groups = defaultdict(list)
178
- try:
179
- widths = dataset["width"]
180
- heights = dataset["height"]
181
- except KeyError:
182
- widths = [0] * len(dataset)
183
- heights = [0] * len(dataset)
184
- for i, (w, h) in enumerate(zip(widths, heights)):
185
- size = (w, h)
186
- size_groups[size].append(i)
187
-
188
- # Выбираем фиксированные примеры из каждой группы
189
- fixed_samples = {}
190
- for size, indices in size_groups.items():
191
- # Определяем сколько семплов брать из этой группы
192
- n_samples = min(samples_per_group, len(indices))
193
- if len(size_groups)==1:
194
- n_samples = samples_to_generate
195
- if n_samples == 0:
196
- continue
197
-
198
- # Выбираем случайные индексы
199
- sample_indices = random.sample(indices, n_samples)
200
- samples_data = [dataset[idx] for idx in sample_indices]
201
-
202
- # Собираем данные
203
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype_embed)
204
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype_embed)
205
- texts = [item["text"] for item in samples_data]
206
-
207
- # Сохраняем для этого размера
208
- fixed_samples[size] = (latents, embeddings, texts)
209
-
210
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
211
- return fixed_samples
212
-
213
- if limit > 0:
214
- dataset = load_from_disk(ds_path).select(range(limit))
215
- else:
216
- dataset = load_from_disk(ds_path)
217
-
218
-
219
- def collate_fn_simple(batch):
220
- # Преобразуем список в тензоры и перемещаем на девайс
221
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype_embed)
222
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype_embed)
223
- return latents, embeddings
224
-
225
- def collate_fn(batch):
226
- if not batch:
227
- return [], []
228
-
229
- # Берем эталонную форму
230
- ref_vae_shape = np.array(batch[0]["vae"]).shape
231
- ref_embed_shape = np.array(batch[0]["embeddings"]).shape
232
-
233
- # Фильтруем
234
- valid_latents = []
235
- valid_embeddings = []
236
- for item in batch:
237
- if (np.array(item["vae"]).shape == ref_vae_shape and
238
- np.array(item["embeddings"]).shape == ref_embed_shape):
239
- valid_latents.append(item["vae"])
240
- valid_embeddings.append(item["embeddings"])
241
-
242
- # Создаем тензоры
243
- latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype_embed)
244
- embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype_embed)
245
-
246
- return latents, embeddings
247
-
248
- # Используем наш ResolutionBatchSampler
249
- batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True)
250
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
251
-
252
- print("Total samples",len(dataloader))
253
- dataloader = accelerator.prepare(dataloader)
254
-
255
- # Инициализация переменных для возобновления обучения
256
- start_epoch = 0
257
- global_step = 0
258
-
259
- # Расчёт общего количества шагов
260
- total_training_steps = (len(dataloader) * num_epochs)
261
- # Get the world size
262
- world_size = accelerator.state.num_processes
263
- #print(f"World Size: {world_size}")
264
-
265
- # Опция загрузки модели из последнего чекпоинта (если существует)
266
- latest_checkpoint = os.path.join(checkpoints_folder, project)
267
- if os.path.isdir(latest_checkpoint):
268
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
269
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype_unet)
270
- if unet_gradient:
271
- unet.enable_gradient_checkpointing()
272
- unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
273
- try:
274
- unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
275
- except Exception as e:
276
- print(f"Ошибка при включении SDPA: {e}")
277
- print("Попытка использовать enable_xformers_memory_efficient_attention.")
278
- unet.set_use_memory_efficient_attention_xformers(True)
279
-
280
- if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
281
- print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
282
- if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
283
- print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
284
- if hasattr(torch.nn.functional, "get_flash_attention_available"):
285
- print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
286
- if torch_compile:
287
- print("compiling")
288
- torch.set_float32_matmul_precision('high')
289
- unet = torch.compile(unet)#, mode="reduce-overhead", fullgraph=True)
290
- print("compiling - ok")
291
-
292
- if lora_name:
293
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
294
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
295
- from peft.tuners.lora import LoraModel
296
- import os
297
- # 1. Замораживаем все параметры UNet
298
- unet.requires_grad_(False)
299
- print("Параметры базового UNet заморожены.")
300
-
301
- # 2. Создаем конфигурацию LoRA
302
- lora_config = LoraConfig(
303
- r=lora_rank,
304
- lora_alpha=lora_alpha,
305
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
306
- )
307
- unet.add_adapter(lora_config)
308
-
309
- # 3. Оборачиваем UNet в PEFT-модель
310
- from peft import get_peft_model
311
-
312
- peft_unet = get_peft_model(unet, lora_config)
313
-
314
- # 4. Получаем параметры для оптимизации
315
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
316
-
317
-
318
- # 5. Выводим информацию о количестве параметров
319
- if accelerator.is_main_process:
320
- lora_params_count = sum(p.numel() for p in params_to_optimize)
321
- total_params_count = sum(p.numel() for p in unet.parameters())
322
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
323
- print(f"Общее количество параметров UNet: {total_params_count:,}")
324
-
325
- # 6. Путь для сохранения
326
- lora_save_path = os.path.join("lora", lora_name)
327
- os.makedirs(lora_save_path, exist_ok=True)
328
-
329
- # 7. Функция для сохранения
330
- def save_lora_checkpoint(model):
331
- if accelerator.is_main_process:
332
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
333
- from peft.utils.save_and_load import get_peft_model_state_dict
334
- # Получаем state_dict только LoRA
335
- lora_state_dict = get_peft_model_state_dict(model)
336
-
337
- # Сохраняем веса
338
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
339
-
340
- # Сохраняем конфиг
341
- model.peft_config["default"].save_pretrained(lora_save_path)
342
- # SDXL must be compatible
343
- from diffusers import StableDiffusionXLPipeline
344
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
345
-
346
- # --------------------------- Оптимизатор ---------------------------
347
- # Определяем параметры для оптимизации
348
- #unet = torch.compile(unet)
349
- if lora_name:
350
- # Если используется LoRA, оптимизируем только параметры LoRA
351
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
352
- else:
353
- # Иначе оптимизируем все параметры
354
- if fbp:
355
- trainable_params = list(unet.parameters())
356
-
357
- if fbp:
358
- # [1] Создаем словарь оптимизаторов (fused backward)
359
- if adam8bit:
360
- optimizer_dict = {
361
- p: bnb.optim.AdamW8bit(
362
- [p], # К��ждый параметр получает свой оптимизатор
363
- lr=base_learning_rate,
364
- eps=1e-8
365
- ) for p in trainable_params
366
- }
367
- else:
368
- optimizer_dict = {
369
- p: bnb.optim.Lion8bit(
370
- [p], # Каждый параметр получает свой оптимизатор
371
- lr=base_learning_rate,
372
- betas=(0.9, 0.97),
373
- weight_decay=0.01,
374
- percentile_clipping=percentile_clipping,
375
- ) for p in trainable_params
376
- }
377
-
378
- # [2] Определяем hook для применения оптимизатора сразу после накопления градиента
379
- def optimizer_hook(param):
380
- optimizer_dict[param].step()
381
- optimizer_dict[param].zero_grad(set_to_none=True)
382
-
383
- # [3] Регистрируем hook для trainable параметров модели
384
- for param in trainable_params:
385
- param.register_post_accumulate_grad_hook(optimizer_hook)
386
-
387
- # Подготовка через Accelerator
388
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
389
- else:
390
- if adam8bit:
391
- optimizer = bnb.optim.AdamW8bit(
392
- params=unet.parameters(),
393
- lr=base_learning_rate,
394
- eps=1e-8
395
- )
396
- else:
397
- optimizer = bnb.optim.Lion8bit(
398
- params=unet.parameters(),
399
- lr=base_learning_rate,
400
- betas=(0.9, 0.97),
401
- weight_decay=0.01,
402
- percentile_clipping=percentile_clipping,
403
- )
404
- from transformers import get_constant_schedule_with_warmup
405
-
406
- # warmup
407
- num_warmup_steps = num_warmup_steps * world_size
408
-
409
- lr_scheduler = get_constant_schedule_with_warmup(
410
- optimizer=optimizer,
411
- num_warmup_steps=num_warmup_steps
412
- )
413
-
414
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
415
-
416
- # --------------------------- Фиксированные семплы для генерации ---------------------------
417
- # Примеры фиксированных семплов по размерам
418
- fixed_samples = get_fixed_samples_by_resolution(dataset)
419
-
420
- @torch.compiler.disable()
421
- @torch.no_grad()
422
- def generate_and_save_samples(fixed_samples_cpu, step):
423
- """
424
- Генерирует семплы для каждого из разрешений и сохраняет их.
425
-
426
- Args:
427
- fixed_samples_cpu: Словарь, где ключи - размеры (width, height),
428
- а значения - кортежи (latents, embeddings, text) на CPU.
429
- step: Текущий шаг обучения
430
- """
431
- original_model = None # Инициализируем, чтобы finally не ругался
432
- try:
433
-
434
- original_model = accelerator.unwrap_model(unet)
435
- original_model = original_model.to(dtype = dtype_infer)
436
- original_model.eval()
437
-
438
- vae.to(device=device, dtype=dtype_infer)
439
- vae.eval()
440
-
441
- scheduler.set_timesteps(n_diffusion_steps)
442
-
443
- all_generated_images = []
444
- all_captions = []
445
-
446
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
447
- width, height = size
448
-
449
- sample_latents = sample_latents.to(dtype=dtype_infer)
450
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype_infer)
451
-
452
- # Инициализируем латенты случайным шумом
453
- # sample_latents уже в dtype_infer, так что noise будет создан в dtype_infer
454
- noise = torch.randn(
455
- sample_latents.shape, # Используем форму от sample_latents, которые теперь на GPU и fp16
456
- generator=gen,
457
- device=device,
458
- dtype=sample_latents.dtype
459
- )
460
- current_latents = noise.clone()
461
-
462
- # Подготовка текстовых эмбеддингов для guidance
463
- if guidance_scale > 0:
464
- # empty_embeddings должны быть того же типа и на том же устройстве
465
- empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
466
- text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
467
- else:
468
- text_embeddings_batch = sample_text_embeddings
469
-
470
- for t in scheduler.timesteps:
471
- t_batch = t.repeat(current_latents.shape[0]).to(device) # Убедимся, что t на устройстве
472
-
473
- if guidance_scale > 0:
474
- latent_model_input = torch.cat([current_latents] * 2)
475
- else:
476
- latent_model_input = current_latents
477
-
478
- latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
479
-
480
- # Предсказание шума (UNet)
481
- noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
482
-
483
- if guidance_scale > 0:
484
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
485
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
486
-
487
- current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
488
-
489
- #print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
490
- # Декодирование через VAE
491
- latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
492
- decoded = vae.decode(latent_for_vae).sample
493
-
494
- # Преобразуем тензоры в PIL-изображения
495
- # Для математики с изображением (нормализация) лучше перейти в fp32
496
- decoded_fp32 = decoded.to(torch.float32)
497
- for img_idx, img_tensor in enumerate(decoded_fp32):
498
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
499
- # If NaNs or infs are present, print them
500
- if np.isnan(img).any():
501
- print("NaNs found, saving stoped! Step:", step)
502
- save_model = False
503
- pil_img = Image.fromarray((img * 255).astype("uint8"))
504
-
505
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
506
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
507
- max_w_overall = max(255, max_w_overall)
508
- max_h_overall = max(255, max_h_overall)
509
-
510
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
511
- all_generated_images.append(padded_img)
512
-
513
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
514
- all_captions.append(caption_text)
515
-
516
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
517
- pil_img.save(sample_path, "JPEG", quality=96)
518
-
519
- if use_wandb and accelerator.is_main_process:
520
- wandb_images = [
521
- wandb.Image(img, caption=f"{all_captions[i]}")
522
- for i, img in enumerate(all_generated_images)
523
- ]
524
- wandb.log({"generated_images": wandb_images, "global_step": step})
525
-
526
- finally:
527
- vae.to("cpu") # Перемещаем VAE обратно на CPU
528
- original_model = original_model.to(dtype = dtype_unet)
529
- if original_model is not None:
530
- del original_model
531
- # Очистка переменных, которые являются тензорами и были созданы в функции
532
- for var in list(locals().keys()):
533
- if isinstance(locals()[var], torch.Tensor):
534
- del locals()[var]
535
-
536
- torch.cuda.empty_cache()
537
- gc.collect()
538
-
539
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
540
- if accelerator.is_main_process:
541
- if save_model:
542
- print("Генерация сэмплов до старта обучения...")
543
- generate_and_save_samples(fixed_samples,0)
544
-
545
- # Модифицируем функцию сохранения модели для поддержки LoRA
546
- def save_checkpoint(unet,variant=""):
547
- if accelerator.is_main_process:
548
- if lora_name:
549
- # Сохраняем только LoRA адаптеры
550
- save_lora_checkpoint(unet)
551
- else:
552
- # Сохраняем полную модель
553
- if variant!="":
554
- accelerator.unwrap_model(unet.to(dtype=dtype_infer)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
555
- else:
556
- accelerator.unwrap_model(unet.to(dtype=dtype_infer)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
557
- unet = unet.to(dtype=dtype_unet)
558
-
559
- # --------------------------- Тренировочный цикл ---------------------------
560
- # Для логирования среднего лосса каждые % эпохи
561
- if accelerator.is_main_process:
562
- print(f"Total steps per GPU: {total_training_steps}")
563
-
564
- epoch_loss_points = []
565
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
566
-
567
- # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
568
- steps_per_epoch = len(dataloader)
569
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
570
-
571
- # Начинаем �� указанной эпохи (полезно при возобновлении)
572
- for epoch in range(start_epoch, start_epoch + num_epochs):
573
- batch_losses = []
574
- batch_grads = []
575
- unet = unet.to(dtype = dtype_unet)
576
- unet.train()
577
- for step, (latents, embeddings) in enumerate(dataloader):
578
- if save_model == False and step == 5 :
579
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
580
- print(f"Шаг {step}: {used_gb:.2f} GB")
581
-
582
- # Forward pass
583
- noise = torch.randn_like(latents, dtype=latents.dtype)
584
-
585
- timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
586
- (latents.shape[0],), device=device).long()
587
-
588
- # Добавляем шум к латентам
589
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
590
-
591
- # Используем целевое значение
592
- model_pred = unet(noisy_latents, timesteps, embeddings).sample
593
- target_pred = scheduler.get_velocity(latents, noise, timesteps)
594
-
595
- # Считаем лосс
596
- loss = torch.nn.functional.mse_loss(model_pred, target_pred)
597
-
598
- # Делаем backward через Accelerator
599
- accelerator.backward(loss)
600
-
601
- grad = 0.0
602
- if not fbp:
603
- #if accelerator.sync_gradients:
604
- grad = accelerator.clip_grad_norm_(unet.parameters(), 1.)
605
- optimizer.step()
606
- lr_scheduler.step()
607
- optimizer.zero_grad(set_to_none=True)
608
-
609
- # Увеличиваем счетчик глобальных шагов
610
- global_step += 1
611
-
612
- # Обновляем прогресс-бар
613
- progress_bar.update(1)
614
-
615
- # Логируем метрики
616
- if accelerator.is_main_process:
617
- if fbp:
618
- current_lr = base_learning_rate
619
- else:
620
- current_lr = lr_scheduler.get_last_lr()[0]
621
- batch_losses.append(loss.detach().item())
622
- batch_grads.append(loss.detach().item())
623
-
624
- # Логируем в Wandb
625
- if use_wandb:
626
- wandb.log({
627
- "loss": loss.detach().item(),
628
- "learning_rate": current_lr,
629
- "epoch": epoch,
630
- "grad": grad,
631
- "global_step": global_step
632
- })
633
-
634
- # Генерируем сэмплы с заданным интервалом
635
- if global_step % sample_interval == 0:
636
- generate_and_save_samples(fixed_samples,global_step)
637
- if save_model:
638
- save_checkpoint(unet)
639
-
640
- # Выводим текущий лосс
641
- avg_loss = np.mean(batch_losses[-sample_interval:])
642
- #print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, LR: {current_lr:.8f}")
643
- if use_wandb:
644
- wandb.log({"intermediate_loss": avg_loss})
645
-
646
-
647
- # По окончании эпохи
648
- if accelerator.is_main_process:
649
- avg_epoch_loss = np.mean(batch_losses)
650
- avg_epoch_grad = np.mean(batch_grads)
651
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
652
- if use_wandb:
653
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
654
- wandb.log({"epoch_grad": avg_epoch_grad, "epoch": epoch+1})
655
-
656
- # Завершение обучения - сохраняем финальную модель
657
- if accelerator.is_main_process:
658
- print("Обучение завершено! Сохраняем финальную модель...")
659
- # Сохраняем основную модель
660
- if save_model:
661
- save_checkpoint(accelerator.unwrap_model(unet).to(dtype = torch.float16))
662
-
663
- save_checkpoint(accelerator.unwrap_model(unet).to(dtype = torch.float16),"fp16")
664
- print("Готово!")
665
-
666
- # randomize ode timesteps
667
- # input_timestep = torch.round(
668
- # F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
669
- # )
670
-
671
- #def create_distribution(num_points, device=None):
672
- # # Диапазон вероятностей на оси x
673
- # x = torch.linspace(0, 1, num_points, device=device)
674
-
675
- # Пользовательская функция плотности вероятности
676
- # probabilities = -7.7 * ((x - 0.5) ** 2) + 2
677
-
678
- # Нормализация, чтобы сумма равнялась 1
679
- # probabilities /= probabilities.sum()
680
-
681
- # return x, probabilities
682
-
683
- #def sample_from_distribution(x, probabilities, n, device=None):
684
- # Выбор индексов на основе распределения вероятностей
685
- # indices = torch.multinomial(probabilities, n, replacement=True)
686
- # return x[indices]
687
-
688
- # Пример использования
689
- #num_points = 1000 # Количество точек в диапазоне
690
- #n = latents.shape[0] # Количество временных шагов для выборки
691
- #x, probabilities = create_distribution(num_points, device=latents.device)
692
- #timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
693
-
694
- # Преобразование в формат, подходящий для вашего кода
695
- #timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e2987b5660219328cd1c22e5c4072a561d8aa8dabb3b488c55fd06e9d9059229
3
  size 7014306128
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16c60b36c5f772a0393282bc30c777777c26683f30859e4ae680762628338af7
3
  size 7014306128