import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
import wandb
import random
import gc
from accelerate.state import DistributedType
from torch.distributed import broadcast_object_list
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from datetime import datetime
import bitsandbytes as bnb
import torch.nn.functional as F

|
ds_path = "datasets/384"
project = "micro"
batch_size = 64
base_learning_rate = 1e-4
min_learning_rate = 5e-5
num_epochs = 50

sample_interval_share = 10
use_wandb = True
save_model = True
use_decay = True
fbp = False                      # step per-parameter optimizers inside backward (see below)
optimizer_type = "adam8bit"
torch_compile = False
unet_gradient = True             # enable gradient checkpointing in the UNet
clip_sample = False
fixed_seed = False
shuffle = True
dispersive_loss_enabled = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(False)
dtype = torch.float32
save_barrier = 1.03              # save while the average loss is within 3% of the best seen
warmup_percent = 0.01
dispersive_temperature = 0.5
dispersive_weight = 0.05
percentile_clipping = 95
beta2 = 0.97
eps = 1e-6
clip_grad_norm = 1.0
steps_offset = 0
limit = 0                        # 0 = use the full dataset
checkpoints_folder = ""
mixed_precision = "no"
gradient_accumulation_steps = 1

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

|
n_diffusion_steps = 50
samples_to_generate = 12
guidance_scale = 5

generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d"))
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

lora_name = ""
lora_rank = 32
lora_alpha = 64

print("init")

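# The class below collects intermediate UNet activations via forward hooks and turns them
# into a repulsion ("dispersive") regularizer. As implemented in compute_dispersive_loss(),
# activations are L2-normalized per sample and the loss is
#     L_disp = log( mean_{i<j} exp( -||z_i - z_j||^2 / temperature ) ),
# which decreases as representations within a batch spread apart (an InfoNCE-style
# repulsion term). Note that hook_fn stores activation.detach(), so as written the term is
# tracked and added to the reported total loss but does not feed gradients back into the UNet.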
|
class AccelerateDispersiveLoss:
    def __init__(self, accelerator, temperature=0.5, weight=0.5):
        self.accelerator = accelerator
        self.temperature = temperature
        self.weight = weight
        self.activations = []
        self.hooks = []

    def register_hooks(self, model, target_layer="down_blocks.0"):
        unwrapped_model = self.accelerator.unwrap_model(model)
        print("=== Searching for layers in the unwrapped model ===")
        for name, module in unwrapped_model.named_modules():
            if target_layer in name:
                hook = module.register_forward_hook(self.hook_fn)
                self.hooks.append(hook)
                print(f"✅ Hook registered on: {name}")
                break

    def hook_fn(self, module, input, output):
        if isinstance(output, tuple):
            activation = output[0]
        else:
            activation = output

        if len(activation.shape) > 2:
            activation = activation.view(activation.shape[0], -1)

        self.activations.append(activation.detach())

    def compute_dispersive_loss(self):
        if not self.activations:
            return torch.tensor(0.0, requires_grad=True)

        local_activations = self.activations[-1].float()

        batch_size = local_activations.shape[0]
        if batch_size < 2:
            return torch.tensor(0.0, requires_grad=True)

        sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
        distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
        exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
        dispersive_loss = torch.log(torch.mean(exp_neg_dist))

        return dispersive_loss

    def clear_activations(self):
        self.activations.clear()

    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

|
if use_wandb and accelerator.is_main_process:
    wandb.init(project=project + lora_name, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "fbp": fbp,
        "optimizer_type": optimizer_type,
    })

torch.backends.cuda.enable_flash_sdp(True)

gen = torch.Generator(device=device)
gen.manual_seed(seed)

vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()

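# DDPM noise scheduler configured for v-prediction with zero terminal SNR
# (rescale_betas_zero_snr=True). With v-prediction the UNet is trained to predict
#     v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0,
# which is what scheduler.get_velocity() returns as the training target below.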
|
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    prediction_type="v_prediction",
    rescale_betas_zero_snr=True,
    clip_sample=clip_sample,
    steps_offset=steps_offset
)

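# Resolution-bucketed batch sampler: every batch contains only samples that share the same
# (width, height), so latents can be stacked without padding. Batches are built globally and
# then sliced per rank. Illustrative example: with batch_size=64 and 2 processes, each rank
# gets 64 // 2 = 32 indices per batch; a group of same-size samples is cut into chunks of 64
# and rank r takes columns [r*32 : (r+1)*32] of each chunk.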
|
class DistributedResolutionBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)

        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        rng = np.random.RandomState(self.epoch)

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)

            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue

            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]

            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)

            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]

            all_batches.extend(gpu_batches)

        if self.shuffle:
            rng.shuffle(all_batches)

        accelerator.wait_for_everyone()

        return iter(all_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch

|
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick a fixed set of samples for every unique resolution in the dataset."""
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size = (w, h)
        size_groups[size].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            n_samples = samples_to_generate
        if n_samples == 0:
            continue

        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device, dtype=dtype)
        texts = [item["text"] for item in samples_data]

        fixed_samples[size] = (latents, embeddings, texts)

    print(f"Created {len(fixed_samples)} fixed-sample groups by resolution")
    return fixed_samples


if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

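# Two collate functions: collate_fn_simple assumes every item in the batch already has
# identical latent/embedding shapes (guaranteed by the resolution-bucketed sampler), while
# collate_fn additionally drops items whose shapes differ from the first element.
# The DataLoader below uses collate_fn_simple.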
|
def collate_fn_simple(batch):
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)
    embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device, dtype=dtype)
    return latents, embeddings


def collate_fn(batch):
    if not batch:
        return [], []

    ref_vae_shape = np.array(batch[0]["vae"]).shape
    ref_embed_shape = np.array(batch[0]["embeddings"]).shape

    valid_latents = []
    valid_embeddings = []
    for item in batch:
        if (np.array(item["vae"]).shape == ref_vae_shape and
                np.array(item["embeddings"]).shape == ref_embed_shape):
            valid_latents.append(item["vae"])
            valid_embeddings.append(item["embeddings"])

    latents = torch.tensor(np.array(valid_latents)).to(device, dtype=dtype)
    embeddings = torch.tensor(np.array(valid_embeddings)).to(device, dtype=dtype)

    return latents, embeddings

|
batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)

print("Total batches", len(dataloader))
dataloader = accelerator.prepare(dataloader)

|
start_epoch = 0
global_step = 0

total_training_steps = (len(dataloader) * num_epochs)

world_size = accelerator.state.num_processes

latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)

if unet_gradient:
    unet.enable_gradient_checkpointing()
unet.set_use_memory_efficient_attention_xformers(False)
try:
    unet.set_attn_processor(AttnProcessor2_0())
except Exception as e:
    print(f"Failed to enable SDPA: {e}")
    print("Falling back to enable_xformers_memory_efficient_attention.")
    unet.set_use_memory_efficient_attention_xformers(True)

if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
    print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
    print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
if hasattr(torch.nn.functional, "get_flash_attention_available"):
    print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")

|
if dispersive_loss_enabled:
    dispersive_hook = AccelerateDispersiveLoss(
        accelerator=accelerator,
        temperature=dispersive_temperature,
        weight=dispersive_weight
    )

if torch_compile:
    print("compiling")
    torch.set_float32_matmul_precision('high')
    unet = torch.compile(unet, mode="reduce-overhead", fullgraph=False)
    print("compiling - ok")

|
if lora_name:
    print(f"--- Setting up LoRA via PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from peft.tuners.lora import LoraModel

    unet.requires_grad_(False)
    print("Base UNet parameters frozen.")

    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )
    unet.add_adapter(lora_config)

    peft_unet = get_peft_model(unet, lora_config)

    params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)

    if accelerator.is_main_process:
        lora_params_count = sum(p.numel() for p in params_to_optimize)
        total_params_count = sum(p.numel() for p in unet.parameters())
        print(f"Number of trainable parameters (LoRA): {lora_params_count:,}")
        print(f"Total UNet parameters: {total_params_count:,}")

    lora_save_path = os.path.join("lora", lora_name)
    os.makedirs(lora_save_path, exist_ok=True)

    def save_lora_checkpoint(model):
        if accelerator.is_main_process:
            print(f"Saving LoRA adapters to {lora_save_path}")
            from peft.utils.save_and_load import get_peft_model_state_dict

            lora_state_dict = get_peft_model_state_dict(model)

            torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))

            model.peft_config["default"].save_pretrained(lora_save_path)

            from diffusers import StableDiffusionXLPipeline
            StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)

|
if lora_name:
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
else:
    if fbp:
        trainable_params = list(unet.parameters())

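# Optimizer factory. "adam8bit" and "lion8bit" come from bitsandbytes and keep optimizer
# state in 8-bit to save memory; percentile_clipping (a bitsandbytes option) adaptively
# clips gradients based on a history of recent gradient norms rather than a fixed
# threshold. "adafactor" uses the transformers implementation with a fixed learning rate
# (relative_step=False).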
|
def create_optimizer(name, params):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, beta2), eps=eps, weight_decay=0.001,
            percentile_clipping=percentile_clipping
        )
    elif name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
        )
    elif name == "lion8bit":
        return bnb.optim.Lion8bit(
            params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
            percentile_clipping=percentile_clipping
        )
    elif name == "adafactor":
        from transformers import Adafactor
        return Adafactor(
            params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
            warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
            beta1=0.9, weight_decay=0.01
        )
    else:
        raise ValueError(f"Unknown optimizer: {name}")

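# When fbp is enabled, a separate optimizer is created for every parameter and stepped from
# register_post_accumulate_grad_hook() as soon as that parameter's gradient is ready during
# backward; the gradient is zeroed immediately afterwards, so full-model gradients never have
# to be held in memory at once. In this mode the module-level clip_grad_norm_ / optimizer.step()
# / lr_scheduler.step() path in the training loop is skipped (guarded by "if not fbp").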
|
if fbp:
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)

    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    optimizer = create_optimizer(optimizer_type, unet.parameters())

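# Learning-rate schedule: linear warmup from min_learning_rate to base_learning_rate over the
# first warmup_percent of training, then cosine decay back down to min_learning_rate.
# lr_schedule() returns an absolute learning rate, so the LambdaLR lambda divides by
# base_learning_rate to produce the multiplicative factor PyTorch expects. The progress value
# x is normalized by total_training_steps * world_size, which appears to compensate for the
# accelerate-prepared scheduler advancing the underlying LambdaLR once per process per step.
# With the defaults (base 1e-4, min 5e-5): lr ramps 5e-5 -> 1e-4 over the first 1% of steps,
# then follows min + 0.5 * (base - min) * (1 + cos(pi * t)) down to 5e-5 at the end.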
|
def lr_schedule(step):
    x = step / (total_training_steps * world_size)
    warmup = warmup_percent

    if not use_decay:
        return base_learning_rate
    if x < warmup:
        return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)

    decay_ratio = (x - warmup) / (1 - warmup)
    return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
        (1 + math.cos(math.pi * decay_ratio))


lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)

|
num_params = sum(p.numel() for p in unet.parameters())
print(f"[rank {accelerator.process_index}] total params: {num_params}")

for name, param in unet.named_parameters():
    if torch.isnan(param).any() or torch.isinf(param).any():
        print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")

unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

if dispersive_loss_enabled:
    dispersive_hook.register_hooks(unet, "down_blocks.2")

fixed_samples = get_fixed_samples_by_resolution(dataset)

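# Sampling helper. For every fixed resolution group it runs the reverse diffusion loop for
# n_diffusion_steps with classifier-free guidance: the batch is duplicated with zeroed text
# embeddings as the unconditional branch and the prediction is combined as
#     noise_pred = uncond + guidance_scale * (text - uncond).
# The resulting latents are decoded with the VAE (moved to the GPU only for the duration of
# the call), saved as JPEGs, and optionally logged to wandb.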
|
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, step):
    """
    Generates samples for every resolution group and saves them.

    Args:
        fixed_samples_cpu: dict mapping sizes (width, height) to tuples
            (latents, embeddings, text) kept on the CPU.
        step: current training step.
    """
    global save_model
    original_model = None
    try:
        original_model = accelerator.unwrap_model(unet).eval()

        vae.to(device=device, dtype=dtype)
        vae.eval()

        scheduler.set_timesteps(n_diffusion_steps)

        all_generated_images = []
        all_captions = []

        for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
            width, height = size

            sample_latents = sample_latents.to(dtype=dtype)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype)

            noise = torch.randn(
                sample_latents.shape,
                generator=gen,
                device=device,
                dtype=sample_latents.dtype
            )
            current_latents = noise.clone()

            if guidance_scale > 0:
                empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
                text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
            else:
                text_embeddings_batch = sample_text_embeddings

            for t in scheduler.timesteps:
                t_batch = t.repeat(current_latents.shape[0]).to(device)

                if guidance_scale > 0:
                    latent_model_input = torch.cat([current_latents] * 2)
                else:
                    latent_model_input = current_latents

                latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)

                noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample

                if guidance_scale > 0:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample

            latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
            decoded = vae.decode(latent_for_vae).sample

            decoded_fp32 = decoded.to(torch.float32)
            for img_idx, img_tensor in enumerate(decoded_fp32):
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)

                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                    save_model = False
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images, "global_step": step})

    finally:
        vae.to("cpu")
        # Tensors created above go out of scope when the function returns;
        # just release cached GPU memory here.
        torch.cuda.empty_cache()
        gc.collect()

|
if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
        generate_and_save_samples(fixed_samples, 0)
accelerator.wait_for_everyone()

|
def save_checkpoint(unet, variant=""):
    if accelerator.is_main_process:
        if lora_name:
            save_lora_checkpoint(unet)
        else:
            if variant != "":
                accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"), variant=variant)
            else:
                accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
            unet = unet.to(dtype=dtype)

|
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
sample_interval = max(1, steps_per_epoch // sample_interval_share)
min_loss = 1.

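# Training loop. Each step: sample Gaussian noise, draw random timesteps, build noisy latents
# with scheduler.add_noise(), have the UNet predict v, and regress it against
# scheduler.get_velocity(latents, noise, timesteps) with MSE. When enabled, the dispersive
# term collected by the forward hook is added on top: total_loss = mse + dispersive_weight * L_disp.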
|
for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_tlosses = []
    batch_grads = []

    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()
    print("epoch:", epoch)
    for step, (latents, embeddings) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            if not save_model and step == 5:
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Step {step}: {used_gb:.2f} GB")

            noise = torch.randn_like(latents, dtype=latents.dtype)

            timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
                                      (latents.shape[0],), device=device).long()

            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            if dispersive_loss_enabled:
                dispersive_hook.clear_activations()

            model_pred = unet(noisy_latents, timesteps, embeddings).sample
            target_pred = scheduler.get_velocity(latents, noise, timesteps)

            loss = torch.nn.functional.mse_loss(model_pred.float(), target_pred.float())

            if dispersive_loss_enabled:
                with torch.amp.autocast('cuda', enabled=False):
                    dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
                if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
                    print(f"Rank {accelerator.process_index}: Found nan/inf in dispersive_loss: {dispersive_loss}")

            if dispersive_loss_enabled:
                total_loss = loss + dispersive_loss
            else:
                total_loss = loss

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
                save_model = False
                break

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                print(f"Rank {accelerator.process_index}: Found nan/inf in total_loss: {total_loss}")
                print(f"Problematic batch: step={step}, latents.shape={latents.shape}, embeddings.shape={embeddings.shape}")
                continue

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            accelerator.backward(total_loss)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            grad = torch.tensor(0.0, device=device)
            if not fbp:
                if accelerator.sync_gradients:
                    with torch.amp.autocast('cuda', enabled=False):
                        grad = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            global_step += 1

            progress_bar.update(1)

            if accelerator.is_main_process:
                if fbp:
                    current_lr = base_learning_rate
                else:
                    current_lr = lr_scheduler.get_last_lr()[0]
                batch_losses.append(loss.detach().item())
                batch_tlosses.append(total_loss.detach().item())
                batch_grads.append(grad)

                if use_wandb and accelerator.sync_gradients:
                    wandb.log({
                        "mse_loss": loss.detach().item(),
                        "learning_rate": current_lr,
                        "epoch": epoch,
                        "grad": grad,
                        "global_step": global_step,
                        **({"dispersive_loss": dispersive_loss} if dispersive_loss_enabled else {}),
                        **({"total_loss": total_loss} if dispersive_loss_enabled else {})
                    })

                if global_step % sample_interval == 0:
                    generate_and_save_samples(fixed_samples, global_step)

                    avg_loss = np.mean(batch_losses[-sample_interval:])
                    avg_tloss = np.mean(batch_tlosses[-sample_interval:])
                    avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
                    print(f"Epoch {epoch}, step {global_step}, average loss: {avg_loss:.6f}, grad: {avg_grad:.6f}")

                    if save_model:
                        print("saving:", avg_loss < min_loss * save_barrier)
                        if avg_loss < min_loss * save_barrier:
                            min_loss = avg_loss
                            save_checkpoint(unet)
                    if use_wandb:
                        wandb.log({"interm_loss": avg_loss})
                        wandb.log({"interm_grad": avg_grad})
                        if dispersive_loss_enabled:
                            wandb.log({"interm_totalloss": avg_tloss})

    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses)
        print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
        if use_wandb:
            wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch + 1})

|
if dispersive_loss_enabled:
    dispersive_hook.remove_hooks()
if accelerator.is_main_process:
    print("Training finished! Saving the final model...")

    if save_model:
        save_checkpoint(unet, "fp16")
accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

print("Done!")
|