import ast
import json
import math
import os
import pickle
import random
import shutil
import string
import time
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional

import einops
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import utils
import wandb
from accelerate.utils import gather, gather_object
from image_utils import Im
from jaxtyping import Bool, Float, Integer
from PIL import Image
from tensordict import TensorDict, tensorclass
from torch import Tensor, nn
from tqdm import tqdm

from constants import UNIDISC_DIR
from data_defs import InterleavedBatch
from decoupled_utils import (barrier, dprint, get_num_gpus, get_rank,
                             get_world_size, gprint, is_main_process,
                             print_memory_summary, rprint, sanitize_filename,
                             save_memory_profile, show_memory_usage, try_except)
from model_utils import (Entropy, MauveScore, _sample_categorical,
                         empty_device_cache, get_block_mask,
                         get_chameleon_txt_indices, get_interleaved_block_mask,
                         log, remap_image_torch, replace_nan_dict,
                         wrapped_batch_decode)
from unidisc.tokenizers.chameleon_tokenizers import (decode_ids_batched,
                                                     get_chameleon_images)
from unidisc.tokenizers.image_tokenizers import decode_latents, get_image_batch
from unidisc.utils.simple_llm import get_llm
from unidisc.utils.throughput_monitor import get_available_flops
from unidisc.utils.viz_utils import (augment_image_with_random_object_coco,
                                     create_text_image)

def get_anole_data(self, model, processor, prompt, image, dtype, device):
    # Tokenize the prompt/image pair and splice the VQ image tokens into the text sequence
    # in place of the "<image>" placeholder, so the model sees one fully tokenized sequence.
    inputs = processor(text=prompt, images=[image], padding=True, return_tensors="pt").to(device=device, dtype=dtype)
    image_tokens = model.model.get_image_tokens(inputs["pixel_values"])
    special_image_mask = inputs["input_ids"] == model.model.vocabulary_mapping.image_token_id
    image_tokens = image_tokens.to(inputs["input_ids"].device, inputs["input_ids"].dtype)
    inputs["input_ids"] = inputs["input_ids"].masked_scatter(special_image_mask, image_tokens)
    inputs.pop("pixel_values")
    return inputs
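
# Hypothetical usage sketch (added; not part of the original file). `get_anole_data` is bound
# to the trainer, and the prompt must contain the processor's "<image>" placeholder so the VQ
# image tokens can be scattered into the text sequence:
#
#   img = Image.open("example.png")  # assumed local file
#   inputs = self.get_anole_data(self.chameleon_model, self.chameleon_processor,
#                                "A photo of a dog <image>", img,
#                                dtype=torch.bfloat16, device=self.device)
#   out = self.chameleon_model(**inputs, labels=inputs["input_ids"])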

def calculate_chameleon_perplexity(self, model, processor, prompts, images, dtype=torch.bfloat16, return_all=False, standalone=False):
    """
    Calculate perplexities for multiple prompt-image pairs using the Chameleon model.

    Args:
        model (ChameleonForConditionalGeneration): The Chameleon model. If None, a cached
            (or lazily loaded) `self.chameleon_model` is used.
        processor (ChameleonProcessor): The Chameleon processor. If None, `self.chameleon_processor` is used.
        prompts (List[str]): List of prompt strings.
        images (List[Image.Image]): List of PIL Image objects.
        dtype (torch.dtype): The data type to use (default: torch.bfloat16).
        return_all (bool): If True, return (total PPL, total loss, image loss, text loss)
            tuples instead of scalar perplexities.
        standalone (bool): If True, score the text and the image separately instead of
            averaging over both interleaving orders.

    Returns:
        List[float] or List[tuple]: Perplexities (or per-modality losses) for each prompt-image pair.
    """
    device = self.device
    if model is None or processor is None:
        model = getattr(self, "chameleon_model", None)
        processor = getattr(self, "chameleon_processor", None)
        if model is None:
            from transformers import (ChameleonForConditionalGeneration, ChameleonProcessor)
            self.chameleon_model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16).to("cuda")
            self.chameleon_processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")

        model = self.chameleon_model
        processor = self.chameleon_processor
    assert len(prompts) == len(images), "Number of prompts and images must match"

    perplexities = []

    for prompt, image in zip(prompts, images):
        if not standalone:
            txt_first_prompt = f"{prompt} <image>"
            img_first_prompt = f"<image> {prompt}"
        else:
            txt_first_prompt = prompt
            img_first_prompt = "<image>"
        tot_ppl = 0.0
        tot_loss = 0.0
        img_loss = 0.0
        txt_loss = 0.0
        for i, _prompt in enumerate([txt_first_prompt, img_first_prompt]):
            inputs = self.get_anole_data(model, processor, _prompt, image, dtype, device)
            img_start_tok_id = processor.tokenizer(processor.image_start_token)['input_ids'][1]
            img_end_tok_id = processor.tokenizer(processor.image_end_token)['input_ids'][1]
            if i == 0:
                # Text first: mask positions from the image-start token onward.
                mod_mask = torch.cumsum(inputs['input_ids'] == img_start_tok_id, dim=1).bool()
            else:
                # Image first: mask positions strictly after the image-end token.
                mod_mask = torch.cumsum(inputs['input_ids'] == img_end_tok_id, dim=1).bool()
                mod_mask = mod_mask.cumsum(dim=1) > 1
            output = model(
                input_ids=inputs['input_ids'].to(device),
                attention_mask=inputs['attention_mask'].to(device),
                labels=inputs['input_ids'].to(device)
            )
            loss = output.loss
            perplexity = torch.exp(loss).item()
            tot_ppl += perplexity
            logits = output.logits
            logits = logits.transpose(-1, -2)
            sample_chunk = inputs["input_ids"]
            nlls = F.cross_entropy(logits[..., :-1].to(self.device), sample_chunk[..., 1:].to(self.device), reduction="none")
            mod_mask = mod_mask[:, 1:]
            # Split the token-level NLLs using the modality mask.
            zeros = torch.zeros_like(nlls)
            img_nll = torch.where(mod_mask, nlls, zeros).mean().item()
            txt_nll = torch.where(~mod_mask, nlls, zeros).mean().item()
            tot_loss += loss.item()
            if not standalone:
                txt_loss += txt_nll
                img_loss += img_nll
            else:
                if i == 0:
                    txt_loss += loss.item()
                else:
                    img_loss += loss.item()

        if not standalone:
            # Average over the two interleaving orders (text-first and image-first).
            tot_ppl /= 2
            tot_loss /= 2
            img_loss /= 2
            txt_loss /= 2

        if return_all:
            perplexities.append((tot_ppl, tot_loss, img_loss, txt_loss))
        else:
            perplexities.append(tot_ppl)

        print(f"Total PPL: {tot_ppl} | Total Loss: {tot_loss} | Img Loss: {img_loss} | Txt Loss: {txt_loss}")
    return perplexities
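
# Added note: perplexity above is exp(mean token NLL), i.e. torch.exp(output.loss). A minimal
# standalone sketch, assuming a CUDA machine and the leloy/Anole-7b-v0.1-hf checkpoint
# (model and processor are lazily loaded when None is passed):
#
#   ppl = self.calculate_chameleon_perplexity(
#       None, None,
#       prompts=["A red bicycle leaning against a brick wall"],
#       images=[Image.open("bike.png")],  # hypothetical file
#   )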

def get_every_n_evals(self, n):
    return (
        self.config.mode == "eval"
        or ((self.num_evals > 0 or getattr(self.config.eval, "log_on_start", False)) and n > 0 and self.num_evals % n == 0)
    ) and n != -1
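
# Added note: with the default cadence, `self.get_every_n_evals(10)` is True on every 10th
# validation epoch (and always when `config.mode == "eval"`), while `n == -1` disables the
# corresponding eval entirely.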

@try_except(write_error_to_file=True)
def on_validation_epoch_start(self):
    rprint("on_validation_epoch_start")

    if self.ema is not None and not self.config.trainer.use_custom_ema:
        rprint(" [WARNING] USING EMA IN on_validation_epoch_start - THIS MIGHT RESET LOADED WEIGHTS ".center(100, "!"))
        self.ema.store(self.get_params())
        self.ema.copy_to(self.get_params())

    self.backbone.eval()
    self.reset_validation_metrics()

    if getattr(self.config.trainer, "disable_torchmetrics", False) is False:
        assert self.valid_metrics.nll.mean_value == 0
        assert self.valid_metrics.nll.weight == 0
    if self.non_embedding_params < 1e9:
        self.print_hashes()
    if (
        self.image_model
        and getattr(self.config.model, "image_model_fid_eval", False)
        and self.get_every_n_evals(getattr(self.config.eval, "log_every_n_fid", 10))
    ):
        self.fid_eval = True
        if self.config.eval.fid_mode == "inline":
            from vqgan.inception_metrics import MultiInceptionMetrics
            self.inception_metrics = MultiInceptionMetrics(
                reset_real_features=False,
                compute_unconditional_metrics=True,
                compute_conditional_metrics=False,
                compute_conditional_metrics_per_class=False,
                num_classes=1000,
                num_inception_chunks=10,
                manifold_k=3,
            )
            if self.config.mode == "eval":
                self.computed_tokens = []
        else:
            if getattr(self.config.eval, "force_fid_output_dir", None) is None:
                shm_path = Path("/dev/shm") / os.getenv("USER")
                fid_save_path = shm_path / Path(self.config.output_dir).parent.stem / Path(self.config.output_dir).stem / f"{self.num_evals}_{self.global_step}" / "fid_gen"
            else:
                fid_save_path = Path(getattr(self.config.eval, "force_fid_output_dir", None)) / "fid_gen"
            fid_save_path.mkdir(parents=True, exist_ok=True)
            fid_gt_path = fid_save_path.parent / (fid_save_path.name.replace("gen", "gt"))
            fid_gt_path.mkdir(parents=True, exist_ok=True)
            self.fid_gen_dir = fid_save_path
            self.fid_gt_dir = fid_gt_path
            rprint(f"FID eval output dir: {self.fid_gen_dir}, FID GT dir: {self.fid_gt_dir}")

        rprint(f"Setting FID eval for epoch {self.num_evals}")
    else:
        self.fid_eval = False
        if self.image_model and getattr(self.config.model, "image_model_fid_eval", False):
            rprint(f"Not setting FID eval: num_evals: {self.num_evals} % {getattr(self.config.eval, 'log_every_n_fid', 10)}")

    if self.config.eval.compute_img_to_txt_mauve_clip:
        shm_path = Path("/dev/shm") / os.getenv("USER")
        img_to_txt_mauve_save_path = shm_path / Path(self.config.output_dir).parent.stem / Path(self.config.output_dir).stem / f"{self.num_evals}_{self.global_step}" / "img_to_txt_mauve_gen"
        img_to_txt_mauve_save_path.mkdir(parents=True, exist_ok=True)
        img_to_txt_mauve_gt_path = img_to_txt_mauve_save_path.parent / (img_to_txt_mauve_save_path.name.replace("gen", "gt"))
        img_to_txt_mauve_gt_path.mkdir(parents=True, exist_ok=True)
        self.img_to_txt_mauve_gen_dir = img_to_txt_mauve_save_path
        self.img_to_txt_mauve_gt_dir = img_to_txt_mauve_gt_path
        rprint(f"Img to txt mauve eval gen dir: {self.img_to_txt_mauve_gen_dir}, gt dir: {self.img_to_txt_mauve_gt_dir}")

    self.saved_tokens = defaultdict(list)
    self.validation_start_time = time.time()

    if getattr(self.config.trainer, "attach_oom_observer_eval", False):
        from torchtnt.utils.oom import attach_oom_observer
        attach_oom_observer(output_dir=str(self.config.output_dir), trace_max_entries=1000000)
        rprint(f"Attached OOM observer to {self.config.output_dir}")
        self.gpu_memory_reserved = torch.cuda.memory_reserved()
|
| |
|
| | def sample(self, return_input_ids=False, **kwargs): |
| | continuous_mode = self.config.trainer.image_mode == "continuous" |
| | text_only = kwargs.get("text_only", False) |
| | kwargs.pop("text_only", None) |
| | assert not continuous_mode |
| | txt_tokens, img_tokens = self._sample(text_only=text_only, **kwargs) |
| | if img_tokens is not None: |
| | img_pred = decode_latents(self.config, self.get_vae(), img_tokens) |
| | else: |
| | img_pred = None |
| | if txt_tokens is not None: |
| | txt_pred = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True) |
| | else: |
| | txt_pred = None |
| | if return_input_ids: |
| | return txt_pred, img_pred, txt_tokens, img_tokens |
| | else: |
| | return txt_pred, img_pred |
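
# Hypothetical usage sketch (added): decode both modalities from a partially masked batch,
# keeping the raw token ids for downstream metrics:
#
#   txt_pred, img_pred, txt_ids, img_ids = self.sample(
#       x0=batch["input_ids"], x0_unmask=batch["x0_unmask"], return_input_ids=True)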

@torch.no_grad()
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    batch = self.update_batch(batch)
    assert (batch["input_ids"][~batch["x0_unmask"]] == self.mask_index).all()
    txt_pred, img_pred, txt_tokens, img_tokens = self.sample(x0=batch["input_ids"], x0_unmask=batch["x0_unmask"], return_input_ids=True)
    batch.update(dict(txt_pred=txt_pred, img_pred=img_pred, txt_tokens=txt_tokens, img_tokens=(img_tokens + self.text_vocab_size)))
    return batch

@torch.no_grad()
def zero_shot_eval_step(self, batch, batch_idx):
    batch = self.zero_shot_update_batch(batch)
    dataset_name = self.config.data.train

    def get_similarity(x0, batch, num_timesteps=None, txt_cond=True, return_unweighed=False, do_unconditional=False):
        # Score x0 by its (weighted) diffusion NLL, averaged over a grid of noise levels.
        return_unweighed = return_unweighed or getattr(self.config.eval, "return_unweighed_sim", False)
        class_log_probs = []
        unweighed_class_log_probs = []
        num_timesteps = num_timesteps or self.config.sampling.steps
        effective_batch_size = batch['modality'].shape[0]
        empty_device_cache()
        times = torch.linspace(0, 1, steps=num_timesteps + 2)[1:-1].to(self.device).to(torch.float32)

        if getattr(self.config.eval, "use_random_timesteps_same_batch", False):
            times = torch.rand(num_timesteps, device=x0.device)
            times = torch.sort(times)[0]

        if getattr(self.config.eval, "use_random_timesteps_diff_batch", False):
            # A different random timestep schedule for each element of the batch.
            times = torch.rand(effective_batch_size, num_timesteps, device=x0.device)
            times = torch.sort(times)[0]
            print(f'Times: {times}')

        do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)

        cond_mask = torch.full_like(x0, False, device=x0.device).bool()
        if txt_cond:
            # Condition on the text tokens and score the rest.
            cond_mask[:, :self.config.model.txt_length] = True
        else:
            # Condition on the image tokens and score the text.
            cond_mask[:, self.config.model.txt_length:] = True
        full_mask = torch.full_like(x0, self.mask_index, device=x0.device)
        pad_mask = x0 == self.tokenizer.pad_token_id
        rprint(f'Getting similarity with {times.shape[0]} timesteps, {effective_batch_size} samples, {do_unconditional} unconditional, {self.parameterization} parameterization, {self.config.eval.cfg} cfg, {num_timesteps} num_timesteps, {txt_cond} txt_cond')

        for i in range(num_timesteps):
            empty_device_cache()
            if getattr(self.config.eval, "use_random_timesteps_diff_batch", False):
                t = times[:, i]
            else:
                t = times[i]
                t = t.expand(effective_batch_size)
            sigma, dsigma = self.noise(t)

            unet_conditioning = None
            move_chance = 1 - torch.exp(-sigma[:, None])

            xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, _, __ = self.q_xt(x0, move_chance, return_ignore_batch_mask_for_metrics=True, batch=batch)
            if not do_unconditional:
                cond = torch.where(cond_mask, x0, xt)
                if self.config.eval.cfg is not None:
                    uncond = torch.where(cond_mask, full_mask, xt)
                    cond_output = self.forward(
                        cond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    uncond_output = self.forward(
                        uncond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    cat_output = torch.stack([cond_output, uncond_output])
                    logits = cfg(self.config, t, cat_output).squeeze(0)
                    model_output = self._subs_parameterization(logits, xt=xt, batch=batch, modality=batch['modality'])
                else:
                    model_output = self.forward(
                        cond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality']
                    )
            else:
                if self.config.eval.cfg is not None:
                    uncond = torch.where(cond_mask, full_mask, xt)
                    cond_output = self.forward(
                        xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    uncond_output = self.forward(
                        uncond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    cat_output = torch.stack([cond_output, uncond_output])
                    logits = cfg(self.config, t, cat_output).squeeze(0)
                    model_output = self._subs_parameterization(logits, xt=xt, batch=batch, modality=batch['modality'])
                else:
                    model_output = self.forward(
                        xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality']
                    )

            # Log-probability of the clean tokens under the model.
            log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1)
            # Zero out padding positions.
            zeros = torch.zeros_like(log_p_theta)
            log_p_theta = torch.where(pad_mask, zeros, log_p_theta)
            # When conditioning, score only the non-conditioned positions.
            if not do_unconditional:
                log_p_theta = torch.where(cond_mask, zeros, log_p_theta)
            # Standard ELBO weighting: dsigma / (e^sigma - 1).
            std_weighting = (dsigma / torch.expm1(sigma))[:, None]
            unweighed_log_p_theta = -log_p_theta
            loss = -log_p_theta * std_weighting
            log_probs = loss.sum(dim=-1) / (~pad_mask).sum(dim=-1)
            unweighed_log_probs = unweighed_log_p_theta.sum(dim=-1) / (~pad_mask).sum(dim=-1)

            class_log_probs.append(log_probs)
            unweighed_class_log_probs.append(unweighed_log_probs)
        overall_time_log_probs = torch.stack(class_log_probs)
        unweighed_overall_time_log_probs = torch.stack(unweighed_class_log_probs)
        if return_unweighed:
            return unweighed_overall_time_log_probs.mean(dim=0)
        return overall_time_log_probs.mean(dim=0)
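
    # Added note: `get_similarity` returns a (weighted) mean NLL per batch element, so lower
    # means more likely; the callers below therefore rank with `topk(..., largest=False)`.
    # A minimal sketch, assuming x0_a / x0_b are two candidate pairings of the same batch:
    #
    #   nll_a = get_similarity(x0_a, batch, txt_cond=True)
    #   nll_b = get_similarity(x0_b, batch, txt_cond=True)
    #   prefers_a = nll_a < nll_b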

    def get_similarity_ar(x0, batch, txt_cond=True, do_unconditional=False, **kwargs):
        # Optionally move the image tokens to the front so the AR model scores text-given-image.
        img_first = kwargs.get("img_first", False)
        if img_first:
            x0 = torch.cat([x0[:, self.config.model.txt_length:], x0[:, :self.config.model.txt_length]], dim=1)
            mod = batch['modality']
            mod = torch.cat([mod[:, self.config.model.txt_length:], mod[:, :self.config.model.txt_length]], dim=1)
        else:
            mod = batch['modality']
        empty_device_cache()
        do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)

        if getattr(self.config.eval, "cfg", None):
            rprint('NOT SETTING CFG for AR')

        model_output = self.forward(x=x0, sigma=None, modality=mod)
        # Next-token prediction: shift the targets by one.
        x0 = x0[:, 1:]
        attention_mask = x0 != self.tokenizer.pad_token_id
        log_p_theta = model_output.gather(-1, x0[:, :, None])[:, :, 0]
        if img_first:
            txt_sl = slice(self.config.model.img_length - 1, None)
            img_sl = slice(None, self.config.model.img_length - 1)
        else:
            txt_sl = slice(None, self.config.model.txt_length - 1)
            img_sl = slice(self.config.model.txt_length - 1, None)
        nll = (-log_p_theta * attention_mask).sum(dim=-1) / attention_mask.sum(dim=-1)
        txt_nll = (-log_p_theta[:, txt_sl] * attention_mask[:, txt_sl]).sum(dim=-1) / attention_mask[:, txt_sl].sum(dim=-1)
        img_nll = (-log_p_theta[:, img_sl] * attention_mask[:, img_sl]).sum(dim=-1) / attention_mask[:, img_sl].sum(dim=-1)
        if do_unconditional:
            return nll
        return img_nll if txt_cond else txt_nll
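
    # Added note: unlike `get_similarity`, which averages masked-token NLLs over many noise
    # levels, the AR variant needs a single causal forward pass since every next-token
    # conditional is available at once. Sketch of the retrieval-style use below:
    #
    #   nll = get_similarity_ar(x0, batch, do_unconditional=True)  # [B]; lower = better match
    #   best_row = nll.topk(k=1, dim=0, largest=False).indices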

    def get_similarity_chameleon(zipp, batch, txt_cond=True, do_unconditional=False, prompts=None, images=None, **kwargs):
        # Score caption-image pairs with the off-the-shelf Chameleon model as a baseline.
        empty_device_cache()
        img_first = kwargs.get("img_first", False)
        img_start_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_start_token)['input_ids'][1]
        img_end_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_end_token)['input_ids'][1]
        do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)
        if not prompts and not images:
            prompt, image = zipp
            if img_first:
                _prompt = f"<image> {prompt}"
            else:
                _prompt = f"{prompt} <image>"
            inputs = self.get_anole_data(self.chameleon_model, self.chameleon_processor, _prompt, image, dtype=self.dtype, device=self.device)
        else:
            inputs = self.get_anole_data(self.chameleon_model, self.chameleon_processor, prompts, images, dtype=self.dtype, device=self.device)

        # Build a modality mask over the token sequence.
        if img_first:
            mod_mask = torch.cumsum(inputs['input_ids'] == img_end_tok_id, dim=1).bool()
        else:
            mod_mask = torch.cumsum(inputs['input_ids'] == img_start_tok_id, dim=1).bool()

        mod_mask = mod_mask.cumsum(dim=1) > 1
        output = self.chameleon_model(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            labels=inputs['input_ids'].to(self.device)
        )
        loss = output.loss
        logits = output.logits
        logits = logits.transpose(-1, -2)
        sample_chunk = inputs["input_ids"]
        nlls = F.cross_entropy(logits[..., :-1].to(self.device), sample_chunk[..., 1:].to(self.device), reduction="none")
        mod_mask = mod_mask[:, 1:]
        # Split the token-level NLLs using the modality mask.
        zeros = torch.zeros_like(nlls)
        img_nll = torch.where(mod_mask, nlls, zeros)
        txt_nll = torch.where(~mod_mask, nlls, zeros)
        if do_unconditional:
            return nlls.mean(dim=-1)
        return img_nll.mean(dim=-1) if txt_cond else txt_nll.mean(dim=-1)

    if dataset_name == "nlphuji/flickr30k":
        txt_tokens, img_tokens = self._sample(
            text_only=False,
            x0=batch["input_ids"],
            x0_unmask=batch["attention_mask"],
            modality=batch["modality"],
        )
        img_samples = decode_latents(self.config, self.get_vae(), img_tokens[:, :self.config.model.img_length])
        txt_samples = wrapped_batch_decode(self.tokenizer, txt_tokens[:, self.config.model.img_length:], clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        gt_text_samples = wrapped_batch_decode(self.tokenizer, batch['gt_input_ids'][:, :self.config.model.txt_length], skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        self.compute_cider(txt_samples, gt_text_samples)
    elif dataset_name == "facebook/winoground":
        # Winoground provides two captions and two images per example; we score all four
        # caption-image pairings. s{c}_i{j} is the score of caption c paired with image j.
        a0_0 = batch["input_ids_0_0"]
        a0_1 = batch["input_ids_0_1"]
        a1_0 = batch["input_ids_1_0"]
        a1_1 = batch["input_ids_1_1"]

        text_correct_count = 0
        image_correct_count = 0
        group_correct_count = 0

        wino_chameleon = getattr(self.config.eval, "wino_chameleon", False)

        s0_0, s0_1, s1_0, s1_1 = None, None, None, None
        modes = ['image', 'text', 'group']

        if wino_chameleon:
            txt0 = wrapped_batch_decode(tokens=batch['caption_0_input_ids'], tokenizer=self.tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)[0]
            txt1 = wrapped_batch_decode(tokens=batch['caption_1_input_ids'], tokenizer=self.tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)[0]
            img0 = Im(batch['img_0']).pil
            img1 = Im(batch['img_1']).pil
            prompts = [txt0, txt0, txt1, txt1]
            images = [img0, img1, img0, img1]
            zipp = list(zip(prompts, images))

        # Scores are NLLs, so a pairing is "correct" when the matching pair scores lower.
        def text_correct(result):
            return torch.logical_and(result["s0_i0"] < result["s1_i0"], result["s1_i1"] < result["s0_i1"])

        def image_correct(result):
            return torch.logical_and(result["s0_i0"] < result["s0_i1"], result["s1_i1"] < result["s1_i0"])

        def group_correct(result):
            return torch.logical_and(image_correct(result), text_correct(result))
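
        # Added worked example (hypothetical NLLs): with s0_i0=2.1, s1_i0=2.7, s1_i1=2.0,
        # and s0_i1=2.9, each caption scores lower on its matching image, so `text_correct`,
        # `image_correct`, and hence `group_correct` are all True for that element.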
        results_cond = {}
        for mode in modes:
            do_unconditional = (mode == 'group')
            txt_cond = not (mode == 'text')
            img_first = mode == 'text'
            if wino_chameleon:
                do_unconditional = True
                s0_0 = get_similarity_chameleon(zipp[0], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s0_1 = get_similarity_chameleon(zipp[1], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_0 = get_similarity_chameleon(zipp[2], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_1 = get_similarity_chameleon(zipp[3], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
            elif self.parameterization == "ar":
                s0_0 = get_similarity_ar(a0_0, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s0_1 = get_similarity_ar(a0_1, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_0 = get_similarity_ar(a1_0, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_1 = get_similarity_ar(a1_1, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
            else:
                s0_0 = get_similarity(a0_0, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
                s0_1 = get_similarity(a0_1, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
                s1_0 = get_similarity(a1_0, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
                s1_1 = get_similarity(a1_1, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
            result = {
                "s0_i0": s0_0,
                "s0_i1": s0_1,
                "s1_i0": s1_0,
                "s1_i1": s1_1,
            }
            if mode == 'text':
                results_cond['text'] = text_correct(result)
                text_correct_count += text_correct(result).sum().item()
            elif mode == 'image':
                results_cond['image'] = image_correct(result)
                image_correct_count += image_correct(result).sum().item()
            elif mode == 'group':
                if getattr(self.config.eval, "wino_group_conditional", False):
                    rprint('[Winoground] Using conditional group accuracy')
                    group_correct_count = (torch.logical_and(results_cond['text'], results_cond['image'])).sum().item()
                else:
                    rprint('[Winoground] Using unconditional group accuracy')
                    group_correct_count += group_correct(result).sum().item()
        bsz = a0_0.shape[0]
        txt_acc = text_correct_count / bsz
        img_acc = image_correct_count / bsz
        group_acc = group_correct_count / bsz

        self.win_text_accuracy.update(txt_acc)
        self.win_image_accuracy.update(img_acc)
        self.win_group_accuracy.update(group_acc)
        running_avg_txt = self.win_text_accuracy.compute()
        running_avg_img = self.win_image_accuracy.compute()
        running_avg_group = self.win_group_accuracy.compute()
        rprint(f"[{batch_idx}] Winoground Text Accuracy: {txt_acc} ({running_avg_txt}), Image Accuracy: {img_acc} ({running_avg_img}), Group Accuracy: {group_acc} ({running_avg_group})")
    else:
        # Datacomp-style retrieval: perturb the pairings within the batch and check whether
        # the model assigns the lowest NLL to the true pairing.
        x0 = batch['input_ids']
        img_first = getattr(self.config.model, "img_first", False)
        only_one_correct = getattr(self.config.eval, "only_one_correct", False)
        wino_chameleon = getattr(self.config.eval, "wino_chameleon", False)

        x0_txt = x0.clone()
        x0_img = x0.clone()
        if only_one_correct:
            # Keep row 0 as the true pair; cyclically shift the second modality of all other rows.
            x0c = x0.clone()
            if img_first:
                second_half = x0c[1:, self.config.model.img_length:]
            else:
                second_half = x0c[1:, self.config.model.txt_length:]

            second_half = torch.cat([second_half[1:], second_half[0].unsqueeze(0)], dim=0)

            if img_first:
                x0c[1:, self.config.model.img_length:] = second_half
            else:
                x0c[1:, self.config.model.txt_length:] = second_half
            if wino_chameleon:
                if img_first:
                    img_tokens = x0c[:, :self.config.model.img_length]
                    txt_tokens = x0c[:, self.config.model.img_length:]
                else:
                    txt_tokens = x0c[:, :self.config.model.txt_length]
                    img_tokens = x0c[:, self.config.model.txt_length:]
                dec_txt = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
                dec_imgs = decode_latents(self.config, self.get_vae(), img_tokens - self.text_vocab_size)
                dec_imgs = [Im(img).pil for img in dec_imgs]
                if img_first:
                    dec_txt = ['<image> ' + txt for txt in dec_txt]
                else:
                    dec_txt = [txt + ' <image>' for txt in dec_txt]
                class_sim = get_similarity_chameleon(None, batch, do_unconditional=True, img_first=img_first, prompts=dec_txt, images=dec_imgs)
                if torch.isinf(class_sim).any():
                    rprint(f'[Chameleon] Inf found in class_sim, check transformers version')
                    breakpoint()
            elif self.parameterization == "ar":
                class_sim = get_similarity_ar(x0c, batch, do_unconditional=True)
            else:
                class_sim = get_similarity(x0c, batch, do_unconditional=True)

            # The true pair sits at index 0; retrieval succeeds if it has the lowest NLL.
            topk = class_sim.topk(k=1, dim=0, largest=False)
            topk_indices = topk.indices
            topk_acc = (topk_indices == 0).float().mean().item()
            rprint(f"[{batch_idx}] Datacomp Correct Pair Retrieval Acc: {topk_acc} ({self.datacomp_img_acc.compute()})")
            self.datacomp_img_acc.update(topk_acc)
        else:
            if img_first:
                # Same text (row 0) paired with every image in the batch.
                x0_txt[:, self.config.model.img_length:] = x0[0, self.config.model.img_length:]
                # Same image (row 0) paired with every text in the batch.
                x0_img[:, :self.config.model.img_length] = x0[0, :self.config.model.img_length]
            else:
                x0_txt[:, :self.config.model.txt_length] = x0[0, :self.config.model.txt_length]
                x0_img[:, self.config.model.txt_length:] = x0[0, self.config.model.txt_length:]

            if self.parameterization == "ar":
                txt_class_sim = get_similarity_ar(x0_txt, batch, txt_cond=True)
                img_class_sim = get_similarity_ar(x0_img, batch, txt_cond=True)
            else:
                txt_class_sim = get_similarity(x0_txt, batch, txt_cond=True)
                img_class_sim = get_similarity(x0_img, batch, txt_cond=False)

            img_topk = img_class_sim.topk(k=1, dim=0, largest=False)
            txt_topk = txt_class_sim.topk(k=1, dim=0, largest=False)

            img_topk_indices = img_topk.indices
            txt_topk_indices = txt_topk.indices

            img_acc = (img_topk_indices == 0).float().mean().item()
            txt_acc = (txt_topk_indices == 0).float().mean().item()
            rprint(f"[{batch_idx}] Datacomp Text Retrieval Acc: {img_acc}, Datacomp Image Retrieval Acc: {txt_acc}")
            self.datacomp_img_acc.update(img_acc)
            self.datacomp_txt_acc.update(txt_acc)

@torch.no_grad()
def validation_step(self, batch, batch_idx):
    batch = self.update_batch(batch)
    continuous_mode = self.config.trainer.image_mode == "continuous"

    if self.config.mode == "eval":
        logs = dict()
        logs["gpu_max_mem_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
        logs["gpu_cur_mem_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
        logs["gpu_max_mem_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
        logs["gpu_cur_mem_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)
        log({**logs, **self.get_step_metrics()})

    if (self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
            and self.image_model
            and (batch_idx == 0 or self.config.eval.visualize_data_only)
            and not continuous_mode):
        self.visualize_samples(batch, batch_idx)
    if self.config.eval.visualize_data_only:
        return

    if batch_idx < self.config.eval.num_sample_batches and self.config.eval.compute_generative_perplexity:
        if continuous_mode:
            gt_text_samples = wrapped_batch_decode(self.tokenizer, batch['text_tokens'][:, :self.config.model.txt_length], skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        else:
            input_ids = batch["input_ids"]
            pad_tokens = torch.full_like(input_ids, self.tokenizer.pad_token_id)
            text_tokens = torch.where(batch["modality"] == 0, input_ids, pad_tokens)
            gt_text_samples = wrapped_batch_decode(self.tokenizer, text_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        if getattr(self.config.trainer, "disable_text_modality", False):
            gt_text_samples = [' ']
        self.compute_generative_perplexity(gt_text_samples, gt=True)

    if (getattr(self.config.trainer, "log_flops", False)
            and batch_idx == 0
            and self.current_run_global_step <= 1
            and self.config.trainer.fsdp is False):
        self.log_flops(batch=batch, batch_idx=batch_idx)
    if self.fid_eval:
        if self.config.eval.fid_mode == "inline":
            self.update_inline_fid(batch, batch_idx)
        elif self.config.eval.fid_mode == "clean":
            self.update_clean_fid(batch, batch_idx)
        else:
            raise ValueError(f"Invalid FID mode: {self.config.eval.fid_mode}")

    if getattr(self.config.eval, "get_top_k", False) and self.config.parameterization == "ar":
        self.get_top_k(batch, batch_idx)

    try:
        if self.config.eval.compute_img_to_txt_mauve_clip and not self.config.eval.unconditional_fid:
            self.update_img_to_txt_mauve_clip(batch, batch_idx)
    except Exception as e:
        empty_device_cache()
        rprint(f"Error in update_img_to_txt_mauve_clip: {e}")

    if (self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
            and continuous_mode
            and self.config.eval.generate_samples
            and not self.config.eval.test_eval_speed):
        # Transfusion-style continuous sampling path.
        data = self.sample_transfusion(batch_size_per_gpu=batch['input_ids'].shape[0])
        rec_embs = [data.xt_img_embed[i, data.modality[i] == 1] for i in range(data.shape[0])]
        rec_embs = torch.stack(rec_embs)
        rec_txt = data.xt_ids[data.modality == 0][None]
        recon_image = decode_latents(self.config, self.get_vae(), rec_embs, batched=True)
        txt = wrapped_batch_decode(self.tokenizer, rec_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        rprint(f"Sampled {len(txt)} text samples:\n {txt[:1][:50]}")
        image_list = [wandb.Image(img) for img in recon_image]
        val_loss = self.compute_loss(batch, prefix="val")
        log({"val/gen_img": image_list, "val/loss": val_loss, **self.get_step_metrics()})

    if (
        self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
        and (self.unified_model or self.cub_model or self.vggface_model)
        and batch_idx < getattr(self.config.eval, "num_masking_viz_batches", 1)
        and not continuous_mode
    ):
        self.sample_masking(batch=batch, batch_idx=batch_idx)

    return self.compute_loss(batch, prefix="val", batch_idx=batch_idx)

@try_except(write_error_to_file=True)
@torch.no_grad()
def zero_shot_eval_epoch_end(self, example_batch=None):
    dataset_name = self.config.data.train
    dprint("zero_shot_eval_epoch_end")
    if dataset_name == "nlphuji/flickr30k":
        cider_score = self.cider_score.compute()
        rprint('Flickr30k CIDEr score: ', cider_score)
        log({
            'val/cider_score': cider_score
        })
    elif dataset_name == "facebook/winoground":
        win_text_accuracy = self.win_text_accuracy.compute()
        win_image_accuracy = self.win_image_accuracy.compute()
        win_group_accuracy = self.win_group_accuracy.compute()
        rprint(f'Winoground Text Accuracy: {win_text_accuracy}')
        rprint(f'Winoground Image Accuracy: {win_image_accuracy}')
        rprint(f'Winoground Group Accuracy: {win_group_accuracy}')
        log({
            'val/win_text_accuracy': win_text_accuracy,
            'val/win_image_accuracy': win_image_accuracy,
            'val/win_group_accuracy': win_group_accuracy
        })
    else:
        # Note: `datacomp_img_acc` accumulates the *text* retrieval accuracy (and vice versa),
        # matching how the metrics are updated in `zero_shot_eval_step`.
        datacomp_img_acc = self.datacomp_img_acc.compute()
        datacomp_txt_acc = self.datacomp_txt_acc.compute()
        rprint(f'Datacomp Text Retrieval Accuracy: {datacomp_img_acc}')
        rprint(f'Datacomp Image Retrieval Accuracy: {datacomp_txt_acc}')
        log({
            'val/datacomp_text_retr_acc': datacomp_img_acc,
            'val/datacomp_img_retr_acc': datacomp_txt_acc
        })

@try_except(write_error_to_file=True)
@torch.no_grad()
def get_img_text_saturation_batch(self, example_batch):
    max_sampling_steps = self.config.model.length
    batch_size_per_gpu = example_batch["input_ids"].shape[0]
    do_standalone = getattr(self.config.eval, "cham_standalone", False)
    pplx_per_step = []
    # Sweep over an increasing number of denoising steps.
    steps = [1, 2, 4, 8, 16, 32, 64]

    rprint(f"do_standalone: {do_standalone} with steps: {steps}")
    dec_txt_list = []
    dec_img_list = []
    for step in steps:
        rprint(f"Step: {step}")
        (txt_tokens, img_tokens), nfe_cnt = self._sample(text_only=False, batch_size_per_gpu=batch_size_per_gpu, sample_modality=example_batch["modality"], return_nfe=True, num_steps=step)
        decoded_img = Im(decode_latents(self.config, self.get_vae(), img_tokens)).pil
        decoded_txt = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        if not isinstance(decoded_img, list):
            decoded_img = [decoded_img]
        if not isinstance(decoded_txt, list):
            decoded_txt = [decoded_txt]
        dec_txt_list.append(decoded_txt)
        dec_img_list.append(decoded_img)
        tot_ppl, tot_loss, img_loss, txt_loss = self.calculate_chameleon_perplexity(self.chameleon_model, self.chameleon_processor, prompts=decoded_txt, images=decoded_img, return_all=True, standalone=do_standalone)[0]
        rprint(f"Step {step} - Total PPL: {tot_ppl} | Total Loss: {tot_loss} | Img Loss: {img_loss} | Txt Loss: {txt_loss}")
        pplx_per_step.append((step, tot_ppl, tot_loss, img_loss, txt_loss))
        empty_device_cache()
    return dec_txt_list, dec_img_list, pplx_per_step
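
# Added note: the saturation sweep doubles the denoising step count each time, so the CSVs
# written by the caller below can be plotted directly to see where Chameleon PPL stops
# improving, e.g.:
#
#   df = pd.read_csv("img_text_saturation_batch_0.csv")  # path as written by the caller
#   print(df[["step", "tot_ppl"]])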

@torch.no_grad()
@try_except(write_error_to_file=True)
def on_validation_epoch_end(self, example_batch=None):
    dprint("on_validation_epoch_end")

    if self.config.eval.compute_val_metrics_standalone:
        self.compute_val_metrics_standalone()

    all_val_metrics = self.get_step_metrics()
    all_val_metrics.update(self.valid_metrics.compute())
    if hasattr(self, "valid_txt_metrics"):
        valid_txt_metrics = self.valid_txt_metrics.compute()
        valid_img_metrics = self.valid_img_metrics.compute()
        all_val_metrics.update({
            **{f"val/txt_{k.split('/')[-1]}": v for k, v in replace_nan_dict(valid_txt_metrics).items()},
            **{f"val/img_{k.split('/')[-1]}": v for k, v in replace_nan_dict(valid_img_metrics).items()},
        })

    log(all_val_metrics)

    gprint("example_batch['input_ids'].ndim: ", example_batch['input_ids'].ndim)
    if example_batch['input_ids'].ndim == 3:
        combined_batches = example_batch
        example_batch = self.update_batch(example_batch[0])
    else:
        example_batch = self.update_batch(example_batch)

    if self.config.eval.auto_enhance:
        self.auto_enhance(combined_batches)

    continuous_mode = self.config.trainer.image_mode == "continuous"
    compute_chameleon_perplexity = getattr(self.config.eval, "compute_chameleon_perplexity", False)
    all_images = []
    with try_except(write_error_to_file=True, clear_cuda_cache=True):
        if self.fid_eval:
            if self.config.eval.fid_mode == "inline":
                self.compute_inline_fid_eval()
            elif self.config.eval.fid_mode == "clean":
                self.compute_clean_fid_eval()
            else:
                raise ValueError(f"Invalid FID mode: {self.config.eval.fid_mode}")

            if self.config.eval.calculate_clip_score:
                prefix = "unconditional" if self.config.eval.unconditional_fid else "fid"
                self.compute_clip_score(self.fid_gen_dir, f"{prefix}_gen")
                self.compute_clip_score(self.fid_gt_dir, f"{prefix}_gt")
                if self.config.trainer.ar_inpainting:
                    target_dir = Path(self.fid_gt_dir).parent / "fid_inpainting"
                    target_dir.mkdir(parents=True, exist_ok=True)
                    # Pair GT images with the generated captions for the inpainting CLIP score.
                    for img_file in Path(self.fid_gt_dir).rglob("*.png"):
                        shutil.copy2(img_file, target_dir / img_file.name)

                    for json_file in Path(self.fid_gen_dir).rglob("*.json"):
                        shutil.copy2(json_file, target_dir / json_file.name)

                    self.compute_clip_score(target_dir, f"{prefix}_inpainting")

        if self.config.eval.unconditional_fid and \
                self.config.eval.compute_img_to_txt_mauve_during_unconditional_fid and self.config.eval.compute_img_to_txt_mauve_clip:
            rprint("Computing img to txt mauve during unconditional fid")
            # In the unconditional setting, the text for mauve comes from the FID generations.
            gen_txt_tokens = self.gather_tokens(self.saved_tokens["unconditional_gen_txt_tokens"])
            gt_txt_tokens = self.gather_tokens(self.saved_tokens["unconditional_gt_txt_tokens"])
            if not getattr(self.config.eval, "global_disable_mauve", False):
                self.compute_mauve_entropy(self.fid_gen_dir, self.fid_gt_dir, gen_txt_tokens, gt_txt_tokens, "unconditional")
        elif self.config.eval.compute_img_to_txt_mauve_clip:
            gen_txt_tokens = self.gather_tokens(self.saved_tokens["img_to_txt_gen_txt_tokens"])
            gt_txt_tokens = self.gather_tokens(self.saved_tokens["img_to_txt_gt_txt_tokens"])
            if not getattr(self.config.eval, "global_disable_mauve", False):
                self.compute_mauve_entropy(self.img_to_txt_mauve_gen_dir, self.img_to_txt_mauve_gt_dir, gen_txt_tokens, gt_txt_tokens, "img_to_txt")
            if self.config.eval.calculate_clip_score:
                self.compute_clip_score(self.img_to_txt_mauve_gen_dir, "img_to_txt_mauve_gen")
                self.compute_clip_score(self.img_to_txt_mauve_gt_dir, "img_to_txt_mauve_gt")

    should_eval_speed = getattr(self.config.eval, "test_eval_speed", False)
    if self.config.eval.generate_samples:
        with try_except(write_error_to_file=True):
            empty_device_cache()
            if getattr(self.config.eval, 'set_random_gen_seed', False):
                new_seed = get_rank() * 10 + 32
                torch.manual_seed(new_seed)
                torch.cuda.manual_seed(new_seed)
                random.seed(new_seed)
                np.random.seed(new_seed)

            tot_time_per_sample = []
            tot_token_time_per_token = []
            tot_nfe_cnt = 0
            batch_size_per_gpu = self.config.loader.eval_batch_size
            sampling_steps = self.config.sampling.steps
            num_batches = self.config.eval.num_sample_batches
            gen_ppl_max_batches = 1e8
            compute_entropy = getattr(self.config.eval, "compute_entropy", False)
            compute_gen_ppl = self.config.eval.compute_generative_perplexity
            entropies = []

            if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
                mauve_N = self.config.eval.mauve_num_samples
                # Generate enough batches across all GPUs to reach the requested sample count.
                num_batches = math.ceil(mauve_N / (batch_size_per_gpu * get_num_gpus()))
                should_eval_speed = True
                gen_ppl_max_batches = getattr(self.config.eval, "gen_ppl_max_batches", 1e8)
                compute_entropy = True
                compute_gen_ppl = True
                rprint(f"[MAUVE] Generating {mauve_N} samples with batch size {batch_size_per_gpu}, sampling steps {sampling_steps}, total length {self.config.model.length}, num_batches: {num_batches}, max_gen_ppl_batches: {gen_ppl_max_batches}")

            rprint(f"Generating {num_batches} samples with batch size {batch_size_per_gpu}, sampling steps {sampling_steps}, total length {self.config.model.length}, compute_entropy: {compute_entropy}, compute_gen_ppl: {compute_gen_ppl}")
            all_samples = []
            get_img_text_saturation = getattr(self.config.eval, "get_img_text_saturation", False)
            for i in tqdm(range(num_batches), desc="Generating samples"):
                if get_img_text_saturation:
                    dec_txt_list, dec_img_list, all_vals = self.get_img_text_saturation_batch(example_batch)
                    # Persist the per-step sweep so it can be plotted offline.
                    df = pd.DataFrame(all_vals, columns=["step", "tot_ppl", "tot_loss", "img_loss", "txt_loss"])
                    df.to_csv(Path(self.config.output_dir) / f"img_text_saturation_batch_{i}.csv", index=False)
                    rprint(f"Saved img_text_saturation_batch_{i}.csv to {Path(self.config.output_dir) / f'img_text_saturation_batch_{i}.csv'}")

                    log_data = []
                    for (step, tot_ppl, tot_loss, img_loss, txt_loss), dec_txt, dec_img in zip(all_vals, dec_txt_list, dec_img_list):
                        concatenated_text = ' | '.join(dec_txt)
                        concatenated_image = dec_img[0]
                        log_data.append([step, tot_ppl, tot_loss, img_loss, txt_loss, concatenated_text, wandb.Image(concatenated_image)])

                    log_table = wandb.Table(columns=["Step", "Total PPL", "Total Loss", "Image Loss", "Text Loss", "Generated Text", "Generated Image"], data=log_data)
                    wandb.log({"img_text_saturation": log_table, "trainer/global_step": self.global_step})
                    rprint("Logged img_text_saturation table to wandb")

                    break
                if should_eval_speed:
                    start_time = start_timing(sync=True, enable=True, message="Evaluating inference speed")

                if self.parameterization == "ar" and continuous_mode:
                    data = self.sample_transfusion(text_only=True, batch_size_per_gpu=batch_size_per_gpu)
                    txt_tokens = data.xt_ids[:, self.static_txt_sl]
                else:
                    (txt_tokens, img_tokens), nfe_cnt = self._sample(
                        text_only=False,
                        batch_size_per_gpu=batch_size_per_gpu,
                        sample_modality=example_batch["modality"],
                        return_nfe=True,
                    )
                    tot_nfe_cnt += nfe_cnt
                if should_eval_speed:
                    tot_time = end_timing(start_time, enable=True, sync=True)
                    if continuous_mode:
                        assert (data.modality == 0).all()
                    tot_time_per_sample.append(tot_time)
                    tot_token_time_per_token.append(tot_time / self.config.model.length)

                if compute_entropy:
                    entropies.append(self.compute_entropy(txt_tokens).item())

                if compute_chameleon_perplexity:
                    all_images.extend(Im(decode_latents(self.config, self.get_vae(), img_tokens)).pil)
                text_samples = wrapped_batch_decode(self.tokenizer, txt_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
                    self.mauve_predictions.extend(text_samples)
                if len(text_samples) > 0 and len(text_samples[0]) > 0 and self.config.eval.compute_generative_perplexity and i <= gen_ppl_max_batches:
                    self.compute_generative_perplexity(text_samples)

                rprint(f"Generated {len(text_samples)} samples - {[text_samples[i][:200] for i in range(min(len(text_samples), 5))]}")
                all_samples.extend(text_samples)

            avg_nfe_cnt = tot_nfe_cnt / num_batches
            if should_eval_speed:
                data_dict = {
                    "samples": wandb.Table(columns=["Generated Samples", "Generated Images"], data=[[s, wandb.Image(img)] for s, img in zip(all_samples[:self.config.sampling.num_sample_log], all_images[:self.config.sampling.num_sample_log])]),
                    "trainer/global_step": self.global_step,
                }
                assert len(tot_time_per_sample) == len(tot_token_time_per_token)
                if len(tot_time_per_sample) > 1:
                    # Drop the first (warmup) measurement before averaging.
                    tot_time_per_sample = tot_time_per_sample[1:]
                    tot_token_time_per_token = tot_token_time_per_token[1:]
                    print(f'Have {len(tot_time_per_sample)} samples')
                    print(f'tot_time_per_sample: {tot_time_per_sample}')
                    print(f'tot_token_time_per_token: {tot_token_time_per_token}')
                    avg_time_per_sample = sum(tot_time_per_sample) / len(tot_time_per_sample)
                    avg_time_per_token = sum(tot_token_time_per_token) / len(tot_token_time_per_token)
                    data_dict["val/avg_time_per_sample"] = avg_time_per_sample
                    data_dict["val/avg_time_per_token"] = avg_time_per_token
                    data_dict["val/avg_nfe_cnt"] = avg_nfe_cnt
                    rprint(f"Time per sample: avg (excluding warmup): {avg_time_per_sample} - {tot_time_per_sample}")
                    rprint(f"Time per token: avg (excluding warmup): {avg_time_per_token} - {tot_token_time_per_token}")
                    with open(Path(self.config.output_dir) / "times.txt", "a") as f:
                        f.write(f"{avg_time_per_sample}, {avg_time_per_token}\n")
                        f.write(f"{tot_time_per_sample}\n")
                        f.write(f"{tot_token_time_per_token}\n")
                    rprint(f"Logged time per sample and time per token to {Path(self.config.output_dir) / 'times.txt'}")
            else:
                if len(text_samples) > 0 and isinstance(text_samples[0], list):
                    text_samples = [[item] for sublist in text_samples for item in sublist]
                else:
                    text_samples = [[item] for item in text_samples]

                data_dict = {
                    "samples": wandb.Table(columns=["Generated Samples"], data=text_samples),
                    **self.get_step_metrics()
                }

            if compute_gen_ppl:
                data_dict["val/gen_ppl"] = self.gen_ppl_metric.compute()
                data_dict["val/gt_gen_ppl"] = self.gt_gen_ppl_metric.compute()
                self.gen_ppl_metric.reset()
                self.gt_gen_ppl_metric.reset()

            if compute_entropy:
                data_dict["val/val_entropy"] = sum(entropies) / len(entropies) if len(entropies) > 0 else 0

            if compute_chameleon_perplexity:
                if getattr(self.config.eval, "max_chameleon_samples", False):
                    all_images = all_images[:self.config.eval.max_chameleon_samples]
                    all_samples = all_samples[:self.config.eval.max_chameleon_samples]
                pplxs = self.calculate_chameleon_perplexity(self.chameleon_model, self.chameleon_processor, images=all_images, prompts=all_samples)
                avg_pplx = sum(pplxs) / len(pplxs)
                data_dict["val/chameleon_ppl"] = avg_pplx

            if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
                all_mauve_preds = gather_object(self.mauve_predictions)
                all_mauve_refs = gather_object(self.mauve_references)
                data_dict["val/mauve_score"] = self.get_mauve_score(all_mauve_preds, all_mauve_refs, "standalone")

            log(data_dict)

    if (
        ((self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
          and (self.image_model or self.config.trainer.multimodal_batches)
          and not getattr(self.config.model, "img_cond", False)
          and not should_eval_speed) or getattr(self.config.eval, "force_eval_uncond", False))
        and not getattr(self.config.eval, "force_disable_eval_uncond", False)
    ):
        dprint("Generating samples")
        with try_except(write_error_to_file=True):
            has_label = getattr(self.config.model, "cond_label", False)
            sample_kwargs = dict()

            if has_label:
                label = torch.randint(0, self.config.model.label_vocab_size, (self.config.loader.eval_batch_size,)).to(device=self.device, dtype=torch.int64)
                sample_kwargs["label"] = label
            else:
                # Placeholder labels so the logging zip below still works.
                label = torch.randint(0, 1, (self.config.loader.eval_batch_size * 20,))

            text_samples_list = []
            img_samples_list = []
            for j in range(getattr(self.config.eval, "num_uncond_sample_batches", 1)):
                if continuous_mode:
                    data = self.sample_transfusion(batch_size_per_gpu=self.config.loader.eval_batch_size)
                    text_samples = data.xt_ids[:, self.static_txt_sl]
                    img_samples = data.xt_img_embed[:, self.static_img_sl]
                    img_samples = decode_latents(self.config, self.get_vae(), img_samples)
                else:
                    if getattr(self.config.eval, "eval_large_batch", None) is not None:
                        data = combined_batches[j]
                        data = self.update_batch(data)
                        rprint(f"Taken slice {j} of {getattr(self.config.eval, 'eval_large_batch', None)}")
                    else:
                        data = example_batch

                    _modality = data.get("modality", None)
                    _bs = min(self.config.eval.perplexity_batch_size, self.config.loader.eval_batch_size)
                    if _bs < _modality.shape[0]:
                        _modality = _modality[:_bs]

                    text_samples, img_samples = self._sample(
                        text_only=False,
                        num_steps=self.config.sampling.max_sampling_steps,
                        batch_size_per_gpu=_bs,
                        example_batch=data,
                        sample_batch_idx=j,
                        modality=_modality,
                        sample_ids=data.get("sample_ids", None),
                        allow_interleaved_conditional=True,
                        **sample_kwargs
                    )
                num_text_tokens = self.config.model.txt_length if self.config.model.txt_length > 0 else 128
                if text_samples is None:
                    text_samples = [torch.zeros((self.config.loader.eval_batch_size, num_text_tokens), dtype=torch.int64, device=self.device)]
                elif isinstance(text_samples, list):
                    # Pad or trim each sample to a fixed text length so they can be concatenated.
                    new_text_samples = []
                    for text_sample in text_samples:
                        text_samples_padded = torch.nn.functional.pad(text_sample, (0, num_text_tokens - text_sample.shape[-1]), value=self.tokenizer.pad_token_id) if text_sample.shape[-1] < num_text_tokens else text_sample[..., :num_text_tokens]
                        new_text_samples.append(text_samples_padded)
                    text_samples = new_text_samples
                else:
                    text_samples = [torch.nn.functional.pad(text_samples, (0, num_text_tokens - text_samples.shape[-1]), value=self.tokenizer.pad_token_id) if text_samples.shape[-1] < num_text_tokens else text_samples[..., :num_text_tokens]]

                text_samples_list.extend(text_samples)
                if img_samples is not None:
                    if isinstance(img_samples, list):
                        img_samples_list.extend(img_samples)
                    else:
                        img_samples_list.append(img_samples)

            if len(text_samples_list) > 0 and any(text_samples is not None for text_samples in text_samples_list):
                text_samples = torch.cat(text_samples_list, dim=0)
            else:
                text_samples = None
            has_img = any(img_samples is not None for img_samples in img_samples_list)
            log_dict = {}
            try:
                if has_img:
                    if isinstance(img_samples_list[0], Tensor):
                        img_samples = torch.cat(img_samples_list, dim=0)
                        if img_samples.ndim == 2:
                            pred_img = decode_latents(self.config, self.get_vae(), img_samples)
                        else:
                            pred_img = img_samples

                        log_dict.update({"val/gen_images": wandb.Image(pred_img)})
                    else:
                        pred_img = img_samples_list
                        for i, img in enumerate(img_samples_list):
                            log_dict[f"val/gen_images_{i}"] = wandb.Image(img)
                else:
                    pred_img = img_samples_list
            except Exception as e:
                rprint(f"Error during gather: {e}")
                pred_img = [None] * len(img_samples_list)
                has_img = False
            with try_except(write_error_to_file=True):
                if text_samples is not None:
                    text_samples = gather(text_samples)
                    pred_txt = wrapped_batch_decode(self.tokenizer, text_samples, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
                prefix = "class_cond" if has_label else ("cond" if self.config.trainer.interleaved else "uncond")

                if isinstance(pred_img, Tensor):
                    pred_img = pred_img.float().cpu()

                pred_img = gather_object(pred_img)
                gen_table = wandb.Table(columns=[*([f"{prefix}_sampled_image"] if has_img else []), f"{prefix}_sampled_caption", *(["Label"] if has_label else [])])
                for img, caption, lbl in zip(pred_img, pred_txt, label):
                    gen_table.add_data(*([wandb.Image(img)] if has_img else []), caption, *([lbl] if has_label else []))
                log_dict[f"{prefix}_sample_table"] = gen_table
                log({**log_dict, **self.get_step_metrics()})

    if getattr(self.config.trainer, "print_llm_loss", False) and hasattr(self, 'histogram') and not should_eval_speed:
        avg_losses = {t: sum(l) / len(l) for t, l in self.histogram.items()}
        timesteps, avg_losses = zip(*sorted(avg_losses.items()))

        from io import BytesIO

        import matplotlib.pyplot as plt

        plt.plot(timesteps, avg_losses)
        plt.xlabel('Timesteps')
        plt.ylabel('Average Loss')
        plt.title('Loss over Time')
        plt.show()

        buf = BytesIO()
        plt.savefig(buf, format='png')
        plt.close()
        buf.seek(0)
        img = Image.open(buf)
        log({"loss_over_time": wandb.Image(img)})
        rprint("Logged loss over time")

    if hasattr(self, "valid_txt_metrics"):
        self.valid_metrics.reset()
        self.valid_txt_metrics.reset()
        self.valid_img_metrics.reset()

    if (time.time() - getattr(self, "validation_start_time", time.time())) > 15:
        rprint(f"Validation took: {time.time() - self.validation_start_time} seconds")

    dprint("on_validation_epoch_end finished")

def on_validation_epoch_cleanup(self):
    self.reset_validation_metrics()
    self.fid_eval = False
    self.saved_tokens = defaultdict(list)
    if hasattr(self, "inception_metrics"):
        del self.inception_metrics

    if "tokens" in self.config.data.train and hasattr(self, "vae"):
        del self.vae
        self.vae = None

    if is_main_process() and not getattr(self.config.eval, "disable_fid_cleanup", False):
        self.cleanup_fid_output()
    empty_device_cache()

    if getattr(self.config.trainer, "attach_oom_observer_eval", False):
        if hasattr(self, "gpu_memory_reserved") and self.gpu_memory_reserved is not None:
            cur_gpu_memory_reserved = torch.cuda.memory_reserved()
            if getattr(self.config.trainer, "force_save_eval_memory_profile", False) or (cur_gpu_memory_reserved - self.gpu_memory_reserved > 4 * 1024**3):
                rprint(f"Warning: GPU memory usage increased by more than 4GB during validation. Initial: {self.gpu_memory_reserved / 1024**3:.2f}GB, Current: {cur_gpu_memory_reserved / 1024**3:.2f}GB")
                oom_dir = Path(self.config.output_dir) / "oom_profile"
                oom_dir.mkdir(parents=True, exist_ok=True)
                save_memory_profile(oom_dir)
        self.gpu_memory_reserved = None
        dprint("Disabled memory history")
        torch.cuda.memory._record_memory_history(enabled=None)

    dprint("on_validation_epoch_cleanup finished")
| |
|
| | def gather_tokens(self, tokens): |
| | tokens = torch.cat(tokens, dim=0).to(device=self.device, dtype=torch.int64) |
| | tokens = gather(tokens) |
| | return tokens |
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def get_top_k(self, batch, batch_idx): |
| | if batch_idx == 0: |
| | all_top_k = {1: [], 2: [], 5: []} |
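| | # Retrieval probe: for each of the first 16 samples, broadcast that sample's
| | # text across the whole batch (each row keeps its own image tokens), score
| | # every row by mean NLL, and count a hit when row 0 lands in the top-k.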
| | for i in range(16): |
| | mod_input_ids = batch['input_ids'].clone() |
| | mod_input_ids[:, self.static_txt_sl] = mod_input_ids[i, self.static_txt_sl] |
| | mod_attention_mask = batch['attention_mask'].clone() |
| | mod_attention_mask[:, self.static_txt_sl] = mod_attention_mask[i, self.static_txt_sl] |
| |
|
| | if getattr(self.config.eval, "cfg", None): |
| | cat_mod_input_ids = torch.cat([mod_input_ids, torch.where(batch['modality'] == 1, self.mask_index, mod_input_ids)], dim=0) |
| | cat_p_x0 = self.forward( |
| | cat_mod_input_ids, |
| | sigma=None, |
| | attention_mask=mod_attention_mask, |
| | batch=dict(modality=batch['modality']), modality=batch['modality'] |
| | ) |
| | logit_c, logit_u = cat_p_x0.chunk(2, dim=0) |
| | _w = getattr(self.config.eval, "cfg", None) |
| | model_output = (1 + _w) * logit_c - _w * logit_u |
| | else: |
| | model_output = self.forward(mod_input_ids, sigma=None, attention_mask=mod_attention_mask, batch=dict(modality=batch['modality']), modality=batch['modality']) |
| | |
| | log_p_theta = torch.gather(input=model_output, dim=-1, index=mod_input_ids[:, 1:, None]).squeeze(-1) |
| | mean_nll = (-log_p_theta * mod_attention_mask[:, 1:]).sum(dim=-1) / mod_attention_mask[:, 1:].sum(dim=-1) |
| | |
| | for k in [1, 2, 5]: |
| | topk_values, topk_indices = torch.topk(mean_nll, k, dim=0) |
| | all_top_k[k].append(0 in topk_indices.tolist()) |
| |
|
| | for k in [1, 2, 5]: |
| | retrieval_rate = sum(all_top_k[k]) / len(all_top_k[k]) |
| | rprint(f"{retrieval_rate:.2%} retrieved in top {k}") |
| | log({f"val/top_{k}": retrieval_rate}) |
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def compute_clip_score(self, output_dir, prefix): |
| | from model_utils import calculate_clip_score |
| | caption_paths = [str(x.as_posix()) for x in Path(output_dir).glob('*.png') if x.is_file() and x.with_suffix('.json').exists()] |
| | captions_mapping = {str(x): json.load(Path(x).with_suffix('.json').open())['caption'] for x in caption_paths} |
| | clip_score = calculate_clip_score(caption_paths, captions_mapping=captions_mapping) |
| | clip_score *= 100 |
| | rprint(f"{prefix} CLIP score: {clip_score}") |
| | log({f"val/{prefix}_clip_score": clip_score, **self.get_step_metrics()}) |
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def compute_inline_fid(self): |
| | rprint(f"FID Eval. We have {len(self.inception_metrics.fake_uncond_features)} batches.") |
| | try: |
| | if self.config.mode == "eval" and not self.config.trainer.image_mode == "continuous": |
| | output_dir = Path("eval_tokens").resolve() |
| | output_dir.mkdir(parents=True, exist_ok=True) |
| | dataset_size = sum(x[-1].shape[0] for x in self.computed_tokens) |
| | data = TensorDict( |
| | { |
| | "txt_input_ids": torch.cat([x[1] for x in self.computed_tokens]).to(device="cpu", dtype=torch.int32), |
| | "img_input_ids": torch.cat([x[2] for x in self.computed_tokens]).to(device="cpu", dtype=torch.int16), |
| | "gt_img_input_ids": torch.cat([x[3] for x in self.computed_tokens]).to(device="cpu", dtype=torch.int16), |
| | }, |
| | batch_size=[dataset_size], |
| | ) |
| | save_loc = str(output_dir / f"{get_rank()}") |
| | data.memmap(save_loc) |
| | gprint(f"Saved tokens to {save_loc}") |
| |
|
| | rank = get_rank() |
| | output_folder = Path("fid_metrics") |
| | output_folder.mkdir(parents=True, exist_ok=True) |
| | torch.save(self.inception_metrics.fake_uncond_features, output_folder / f"rank_{rank}_fake_uncond_features.pt") |
| | torch.save(self.inception_metrics.fake_uncond_logits, output_folder / f"rank_{rank}_fake_uncond_logits.pt") |
| | torch.save(self.inception_metrics.real_features, output_folder / f"rank_{rank}_real_features.pt") |
| | rprint(f"Saved rank_{rank} tensors.") |
| | except Exception as e: |
| | gprint(f"Error during all_gather_object or saving tensors: {e}") |
| |
|
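| | # FID statistics are accumulated in full precision; autocast is disabled so
| | # the feature means/covariances are not degraded by low-precision rounding.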
| | with torch.autocast(device_type=self.device.type, enabled=False): |
| | metrics = self.inception_metrics.compute() |
| |
|
| | rprint(f"Computed metrics: {metrics}") |
| | metrics = {f"val/{k}": v for k, v in metrics.items()} |
| | log({**metrics, "trainer/global_step": self.global_step}) |
| | output_folder = Path("fid_metrics") |
| | output_folder.mkdir(parents=True, exist_ok=True) |
| | with open(output_folder / f'metrics_{get_rank()}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt', "w") as f: |
| | for k, v in metrics.items(): |
| | f.write(f"val/{k}: {v}\n") |
| |
|
| | self.fid_eval = False |
| | del self.inception_metrics |
| | rprint("Finished FID eval") |
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def compute_clean_fid_eval(self): |
| | with try_except(write_error_to_file=True): |
| | images = [] |
| | for i, filename in enumerate(sorted(Path(self.fid_gen_dir).iterdir(), key=lambda x: random.random())):
| | if i >= self.config.loader.eval_batch_size * get_world_size():
| | break
| | if filename.is_file() and filename.suffix == ".png":
| | img = None
| | for _attempt in range(3):
| | try:
| | img = Image.open(filename)
| | break
| | except Exception as e:
| | time.sleep(0.1)
| | rprint(f"Error opening image {filename}: {e}")
| | if img is not None:
| | images.append(np.array(img))
| | images = np.stack(images) |
| | log({"val/fid_gen_img_at_compute": wandb.Image(Im(images).torch)}) |
| |
|
| | from cleanfid import fid |
| | kwargs = dict() |
| | if self.config.eval.clean_fid_use_precomputed_stats: |
| | kwargs.update(dict( |
| | dataset_name=self.config.eval.clean_fid_precomputed_name, |
| | dataset_res=self.config.eval.clean_fid_precomputed_res, |
| | dataset_split=self.config.eval.clean_fid_precomputed_split, |
| | )) |
| | else: |
| | kwargs.update(dict(fdir2=str(self.fid_gt_dir))) |
| | |
| | score = fid.compute_fid( |
| | fdir1=str(self.fid_gen_dir), |
| | use_dataparallel=False, |
| | **kwargs |
| | ) |
| |
|
| | rprint(f"FID score: {score}") |
| | metrics = {"val/fid_unconditional": score, **self.get_step_metrics()} |
| | log(metrics) |
| |
|
| | metrics = {f"val/{k}": v for k, v in metrics.items()} |
| | output_folder = Path("fid_metrics") |
| | output_folder.mkdir(parents=True, exist_ok=True) |
| | with open(output_folder / f'metrics_{get_rank()}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt', "w") as f: |
| | for k, v in metrics.items(): |
| | f.write(f"{k}: {v}\n") |
| |
|
| | self.fid_eval = False |
| | |
| | def sample_for_fid(self, batch, batch_idx, return_gt_img=False, return_gt_txt=False, img_to_txt_gen=False): |
| | """This function is also used for img -> txt generation.""" |
| | continuous_mode = self.config.trainer.image_mode == "continuous" |
| | sample_kwargs = self.get_cond_dict(batch) |
| | orig_modality, orig_input_ids = None, None |
| | if img_to_txt_gen: |
| | if self.config.parameterization == "ar": |
| | txt_first_sl = slice(None, self.config.model.txt_length) |
| | img_first_sl = slice(None, self.config.model.img_length) |
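| | # AR decoding is strictly left-to-right, so for img -> txt generation the
| | # image block is moved to the front of the sequence ("flipped" below) and the
| | # text region is generated conditioned on it.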
| | if (batch["modality"][:, txt_first_sl] == 0).all(): |
| | assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() and (batch["modality"][:, self.config.model.txt_length:] == 1).all() |
| | flipped_batch = dict() |
| | img_slice = slice(-self.config.model.img_length, None) |
| | txt_slice = slice(None, self.config.model.txt_length) |
| | for key in ["modality", "attention_mask", "input_ids"]: |
| | flipped_batch[key] = torch.cat([batch[key][:, img_slice], batch[key][:, txt_slice]], dim=1) |
| |
|
| | batch = flipped_batch |
| | else: |
| | assert (batch["modality"][:, img_first_sl] == 1).all() |
| |
|
| | assert (batch["modality"][:, :self.config.model.img_length] == 1).all(), "Img tokens should be 0" |
| | else: |
| | assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() |
| |
|
| | sample_kwargs["sample_modality"] = batch["modality"] |
| | _x0_unmask = (batch["modality"] == 1) |
| | elif getattr(self.config.eval, "unconditional_fid", False): |
| | sample_kwargs["x0_unmask"] = None |
| | sample_kwargs["x0"] = None |
| | sample_kwargs["sample_modality"] = batch["modality"] |
| | elif self.config.trainer.ar_inpainting: |
| | assert getattr(self.config.eval, "txt_conditional_fid", False) |
| | min_val, max_val = getattr(self.config.eval, "ar_inpainting_min_val", 0.9), getattr(self.config.eval, "ar_inpainting_max_val", 1.0) |
| | n = batch["modality"].shape[0] |
| | _eps_t = torch.rand(n, device=self.device) |
| | t = (max_val - min_val) * _eps_t + min_val |
| | if getattr(self.config.eval, "ar_inpainting_force_val", None) is not None: |
| | t = torch.full_like(t, getattr(self.config.eval, "ar_inpainting_force_val"), dtype=t.dtype, device=t.device) |
| | if self.config.parameterization == "ar": |
| | orig_modality, orig_input_ids = batch["modality"].clone(), batch["input_ids"].clone() |
| | del batch["batch_contains_img"] |
| | batch.auto_batch_size_() |
| | batch = torch.cat([batch, batch], dim=1) |
| | x0 = batch["input_ids"] |
| | move_indices = torch.rand(*x0.shape, device=x0.device) < t[:, None] |
| | move_indices[:, x0.shape[1] // 2:] = False |
| | batch["input_ids"] = torch.where(move_indices, self.mask_index, x0) |
| | _x0_unmask = torch.zeros_like(batch["input_ids"], dtype=torch.bool) |
| | _x0_unmask[:, :batch["input_ids"].shape[1] // 2] = True |
| | else: |
| | _x0_unmask = torch.rand(*batch["modality"].shape, device=batch["modality"].device) > t[:, None] |
| | sample_kwargs["sample_modality"] = batch["modality"] |
| | sample_kwargs["x0_unmask"] = _x0_unmask |
| | sample_kwargs["x0"] = batch["input_ids"] |
| | elif getattr(self.config.eval, "class_conditional_fid", False) or getattr(self.config.eval, "txt_conditional_fid", False): |
| | sample_kwargs["x0"] = batch["input_ids"] |
| | if getattr(self.config.eval, "class_conditional_fid", False): |
| | sample_kwargs["sample_modality"] = torch.full_like(batch["modality"], 1) |
| | sample_kwargs["sample_modality"][:, 0] = 0 |
| | _x0_unmask = torch.zeros_like(batch["input_ids"], dtype=torch.bool) |
| | _x0_unmask[..., 0] = True |
| | elif getattr(self.config.eval, "txt_conditional_fid", False): |
| | assert ((batch["modality"] == 1).sum(dim=-1) > 0).all(), "No img samples provided" |
| | sample_kwargs["sample_modality"] = batch["modality"] |
| | _x0_unmask = (batch["modality"] == 0) |
| | sample_kwargs["x0_unmask"] = _x0_unmask |
| |
|
| | if continuous_mode: |
| | data = self.sample_transfusion(batch_size_per_gpu=self.config.loader.eval_batch_size) |
| | gen_txt_tokens = data.xt_ids[:, self.static_txt_sl] |
| | gen_img_tokens = data.xt_img_embed[:, self.static_img_sl] |
| | gen_img = decode_latents(self.config, self.get_vae(), gen_img_tokens) |
| | else: |
| | gen_txt_tokens, gen_img_tokens = self._sample(text_only=False, **sample_kwargs) |
| | gen_img = decode_latents(self.config, self.get_vae(), gen_img_tokens) |
| |
|
| | fid_rec_img, gt_img_tokens, gt_txt_tokens = None, None, None |
| | if return_gt_img: |
| | if "img" in batch: |
| | fid_rec_img = batch["img"] |
| | else: |
| | if orig_modality is None: |
| | orig_modality = batch.get("modality", None) |
| | if orig_input_ids is None: |
| | orig_input_ids = batch["input_ids"] |
| |
|
| | _, gt_img_tokens = self.decode_batch(orig_input_ids, text_only=False, sample_modality=orig_modality) |
| | if gt_img_tokens.shape[0] == 0: |
| | rprint(f"{gt_img_tokens.shape} {batch['input_ids'].shape}") |
| | fid_rec_img = decode_latents(self.config, self.get_vae(), gt_img_tokens) |
| |
|
| | if return_gt_txt: |
| | if orig_input_ids is None: |
| | orig_input_ids = batch["input_ids"] |
| | if orig_modality is None: |
| | orig_modality = batch.get("modality", None) |
| | gt_txt_tokens, _ = self.decode_batch(orig_input_ids, text_only=False, sample_modality=orig_modality) |
| |
|
| | _prefix = "img_to_txt" if img_to_txt_gen else ("unconditional" if getattr(self.config.eval, "unconditional_fid", False) else "txt_to_img") |
| | self.saved_tokens[_prefix + "_gen_img_tokens"].append(gen_img_tokens.detach().cpu().to(torch.int32)) |
| | self.saved_tokens[_prefix + "_gen_txt_tokens"].append(gen_txt_tokens.detach().cpu().to(torch.int32)) |
| | if gt_img_tokens is not None: self.saved_tokens[_prefix + "_gt_img_tokens"].append(gt_img_tokens.detach().cpu().to(torch.int32)) |
| | if gt_txt_tokens is not None: self.saved_tokens[_prefix + "_gt_txt_tokens"].append(gt_txt_tokens.detach().cpu().to(torch.int32)) |
| |
|
| | return gen_img, gen_txt_tokens, gt_img_tokens, gt_txt_tokens, gen_img_tokens, fid_rec_img |
| |
|
| |
|
| | def update_inline_fid(self, batch, batch_idx): |
| | gen_img, txt_tokens, gt_img_tokens, gt_txt_tokens, gen_img_tokens, fid_rec_img = self.sample_for_fid(batch, batch_idx, return_gt_img=True, return_gt_txt=True) |
| |
|
| | if self.config.mode == "eval": |
| | self.computed_tokens.append((txt_tokens, gen_img_tokens, gt_img_tokens)) |
| | with torch.autocast(device_type=self.device.type, enabled=False): |
| | self.inception_metrics.update(remap_image_torch(fid_rec_img).to(self.device), None, image_type="real") |
| | self.inception_metrics.update(remap_image_torch(gen_img).to(self.device), None, image_type="unconditional") |
| | |
| | if batch_idx == 0: |
| | log({"val/fid_gen": wandb.Image(gen_img), "val/fid_gt": wandb.Image(fid_rec_img), **self.get_step_metrics()}) |
| |
|
| | if batch_idx > 0 and batch_idx % 5 == 0 and self.config.mode == "eval": |
| | gprint(f"Saving rank_{get_rank()} tensors.") |
| | try: |
| | rank = get_rank() |
| | torch.save(self.inception_metrics.fake_uncond_features, f"{batch_idx}_rank_{rank}_fake_uncond_features.pt") |
| | torch.save(self.inception_metrics.fake_uncond_logits, f"{batch_idx}_rank_{rank}_fake_uncond_logits.pt") |
| | torch.save(self.inception_metrics.real_features, f"{batch_idx}_rank_{rank}_real_features.pt") |
| | gprint(f"Saved rank_{rank} tensors.") |
| | except Exception as e: |
| | gprint(f"Error during all_gather_object or saving tensors: {e}") |
| |
|
| | def update_clean_fid(self, batch, batch_idx): |
| | assert hasattr(self, "fid_gen_dir") |
| | save_gt_img = not self.config.eval.clean_fid_use_precomputed_stats |
| | gen_img, txt_tokens, gt_img_tokens, gt_txt_tokens, img_samples, fid_rec_img = self.sample_for_fid(batch, batch_idx, return_gt_img=save_gt_img, return_gt_txt=True) |
| |
|
| | if self.config.model.image_model_fid_eval: |
| | txt_samples = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True) |
| | gt_txt_samples = wrapped_batch_decode(self.tokenizer, gt_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True) |
| |
|
| | save_loc = Path(self.fid_gen_dir) |
| | save_loc.mkdir(parents=True, exist_ok=True) |
| | quantized_img = remap_image_torch(gen_img).permute(0, 2, 3, 1).cpu().numpy() |
| |
|
| | if save_gt_img: |
| | gt_quantized_img = remap_image_torch(fid_rec_img).permute(0, 2, 3, 1).cpu().numpy() |
| | save_loc_gt = Path(self.fid_gt_dir) |
| | save_loc_gt.mkdir(parents=True, exist_ok=True) |
| | |
| | for i in range(gen_img.shape[0]): |
| | gen_img_pil = Image.fromarray(quantized_img[i]) |
| | suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4)) |
| | filename = f"{batch_idx}_{get_rank()}_{i}_{suffix}.png" |
| | out_file_path = save_loc / filename |
| | gen_img_pil.save(out_file_path) |
| |
|
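| | # Each generated PNG gets a sidecar .json caption so that compute_clip_score
| | # can later pair images with their prompts.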
| | if self.config.eval.txt_conditional_fid: |
| | with open(out_file_path.with_suffix(".json"), 'w') as json_file: |
| | json.dump({"caption": txt_samples[i]}, json_file) |
| |
|
| | if save_gt_img: |
| | gt_img_pil = Image.fromarray(gt_quantized_img[i]) |
| | gt_out_file_path = save_loc_gt / filename |
| | gt_img_pil.save(gt_out_file_path) |
| |
|
| | if self.config.eval.txt_conditional_fid: |
| | with open(gt_out_file_path.with_suffix(".json"), 'w') as json_file: |
| | json.dump({"caption": gt_txt_samples[i]}, json_file) |
| |
|
| | if batch_idx == 0: |
| | rprint(f"Logging at batch idx {batch_idx}") |
| | time.sleep(0.2) |
| | with try_except(write_error_to_file=True): |
| | images = [] |
| | for i, filename in enumerate(sorted(Path(self.fid_gen_dir).iterdir(), key=lambda x: random.random())): |
| | if i >= self.config.loader.eval_batch_size * get_world_size(): |
| | break |
| | if filename.is_file() and filename.suffix == ".png": |
| | img = Image.open(filename) |
| | images.append(np.array(img)) |
| | images = np.stack(images) |
| | log({"val/fid_gen_img": wandb.Image(Im(images).torch)}) |
| | rprint(f"FID Txt: {txt_samples[0]}") |
| |
|
| | def update_img_to_txt_mauve_clip(self, batch, batch_idx): |
| | assert hasattr(self, "img_to_txt_mauve_gen_dir") |
| | save_gt_img = True |
| | empty_device_cache() |
| | gen_img, gen_txt_tokens, gt_img_tokens, gt_txt_tokens, gen_img_tokens, fid_rec_img = self.sample_for_fid(batch, batch_idx, return_gt_img=save_gt_img, return_gt_txt=True, img_to_txt_gen=True) |
| |
|
| | gen_txt_samples = wrapped_batch_decode(self.tokenizer, gen_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True) |
| | gt_txt_samples = wrapped_batch_decode(self.tokenizer, gt_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True) |
| |
|
| | save_loc = Path(self.img_to_txt_mauve_gen_dir) |
| | save_loc.mkdir(parents=True, exist_ok=True) |
| | quantized_img = remap_image_torch(gen_img).permute(0, 2, 3, 1).cpu().numpy() |
| |
|
| | if save_gt_img: |
| | gt_quantized_img = remap_image_torch(fid_rec_img).permute(0, 2, 3, 1).cpu().numpy() |
| | save_loc_gt = Path(self.img_to_txt_mauve_gt_dir) |
| | save_loc_gt.mkdir(parents=True, exist_ok=True) |
| | |
| | for i in range(gen_img.shape[0]): |
| | gen_img_pil = Image.fromarray(quantized_img[i]) |
| | suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4)) |
| | filename = f"{batch_idx}_{get_rank()}_{i}_{suffix}.png" |
| | out_file_path = save_loc / filename |
| | gen_img_pil.save(out_file_path) |
| | with open(out_file_path.with_suffix(".json"), 'w') as json_file: |
| | json.dump({"caption": gen_txt_samples[i]}, json_file) |
| |
|
| | if save_gt_img: |
| | gt_img_pil = Image.fromarray(gt_quantized_img[i]) |
| | gt_out_file_path = save_loc_gt / filename |
| | gt_img_pil.save(gt_out_file_path) |
| | with open(gt_out_file_path.with_suffix(".json"), 'w') as json_file: |
| | json.dump({"caption": gt_txt_samples[i]}, json_file) |
| |
|
| | if batch_idx == 0: |
| | rprint(f"GT img -> txt mauve: {gt_txt_samples[0]}") |
| | rprint(f"Gen img -> txt mauve: {gen_txt_samples[0]}") |
| | |
| | def compute_mauve_entropy(self, img_to_txt_mauve_gen_dir, img_to_txt_mauve_gt_dir, gen_txt_tokens, gt_txt_tokens, prefix): |
| | gt_txt = [] |
| | gt_img = [] |
| | gt_dir = Path(img_to_txt_mauve_gt_dir) |
| | gen_dir = Path(img_to_txt_mauve_gen_dir) |
| | stems = [f.stem for f in gt_dir.iterdir() if f.suffix == '.json' and (gen_dir / f.name.replace("gt", "gen")).exists()] |
| | assert len(stems) > 0, f"No stems found in {gt_dir} and {gen_dir}" |
| | rprint(f"Found {len(stems)} unique stems") |
| |
|
| | gt_img = [] |
| | gt_txt = [] |
| | gen_txt = [] |
| | gen_img = [] |
| | data_dict = {} |
| | for stem in stems: |
| | gt_img_path = gt_dir / f"{stem}.png" |
| | gt_img.append(Image.open(gt_img_path)) |
| |
|
| | gen_img_path = gen_dir / f"{stem}.png" |
| | gen_img.append(Image.open(gen_img_path)) |
| |
|
| | with open(gt_dir / f"{stem}.json", 'r') as f: |
| | gt_txt.append(json.load(f)["caption"]) |
| |
|
| | with open(gen_dir / f"{stem}.json", 'r') as f: |
| | gen_txt.append(json.load(f)["caption"]) |
| |
|
| | table = wandb.Table(columns=["GT Image", "GT Text", "Generated Image", "Generated Text"]) |
| | num_samples_to_display = min(20, len(stems)) |
| | for i in range(num_samples_to_display): |
| | table.add_data( |
| | wandb.Image(gt_img[i]), |
| | gt_txt[i], |
| | wandb.Image(gen_img[i]), |
| | gen_txt[i] |
| | ) |
| | |
| | data_dict[f"val/{prefix}_mauve_samples"] = table |
| | if not getattr(self.config.eval, "global_disable_mauve", False): |
| | data_dict[f"val/{prefix}_mauve_score"] = self.get_mauve_score(gen_txt, gt_txt, prefix) |
| | data_dict[f"val/{prefix}_gt_entropy"] = self.compute_entropy(gt_txt_tokens) |
| | data_dict[f"val/{prefix}_gen_entropy"] = self.compute_entropy(gen_txt_tokens) |
| | data_dict[f"val/{prefix}_percent_valid_txt_tokens"] = self.count_valid_tokens(gen_txt_tokens).float().mean(dim=-1) / gen_txt_tokens.shape[-1] |
| | log({**data_dict, **self.get_step_metrics()}) |
| |
|
| | def count_valid_tokens(self, text_tokens): |
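| | # Valid tokens run up to and including the first EOS; the second cumsum marks
| | # positions strictly after it (cumulative EOS count > 1).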
| | after_first_eos = torch.cumsum(text_tokens == self.tokenizer.eos_token_id, dim=1).bool() |
| | after_first_eos_mask = after_first_eos.cumsum(dim=1) > 1 |
| | return ~after_first_eos_mask |
| |
|
| | def get_valid_seq(self, text_tokens): |
| | assert self.tokenizer.bos_token_id != self.tokenizer.eos_token_id, "BOS and EOS are the same."
| | |
| | eos_positions = (text_tokens == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0] |
| | if len(eos_positions) > 0: |
| | return text_tokens[..., :eos_positions[0] + 1] |
| | else: |
| | return text_tokens |
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def compute_entropy(self, text_tokens): |
| | """Compute the entropy of the generated text. |
| | Definition Pg 33 of https://arxiv.org/pdf/2409.02908 |
| | |
| | Args: |
| | text_tokens: Tensor of generated text tokens. (B, L) |
| | Returns: |
| | Entropy of the generated text. |
| | """ |
| | val_entropy = Entropy(sync_on_compute=False).to(self.device) |
| | B, L = text_tokens.shape |
| | K = self.tokenizer.vocab_size |
| | |
| | |
| | entropies = [] |
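| | # Unigram entropy per sequence: H = -sum_k p_k * log(p_k), where p_k is the
| | # empirical frequency of token k within the EOS-truncated sequence.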
| | for seq in text_tokens:
| | valid_seq = self.get_valid_seq(seq)
| | seq_length = valid_seq.numel()
| | token_frequencies = torch.bincount(valid_seq, minlength=K)
| | p_k = token_frequencies.float() / seq_length
| | p_k = p_k.to(self.device)
| | entropy = -torch.sum(p_k * torch.log(p_k + 1e-10))
| | entropies.append(entropy)
| | |
| | |
| | avg_entropy = torch.stack(entropies).mean()
| | |
| | |
| | val_entropy.update(avg_entropy, weight=B) |
| | return val_entropy.compute() |
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def get_mauve_score(self, pred, gt, prefix): |
| | from evaluate import load |
| | mauve = load('mauve') |
| | |
| | |
| | mauve_metric = MauveScore(sync_on_compute=False).to(self.device) |
| | rprint(f"Generated {len(pred)} MAUVE predictions") |
| | assert len(pred) >= self.config.eval.mauve_num_samples |
| | rprint(f'Before removing duplicates: {len(pred)}') |
| | pred_text = list(set(pred)) |
| | rprint(f'After removing duplicates: {len(pred_text)}') |
| | ref_text = list(set(gt)) |
| | store_path = os.path.join(self.config.output_dir, f"{prefix}_mauve_predictions.pkl") |
| | with open(store_path, "wb") as f: |
| | pickle.dump(pred_text, f) |
| | |
| | rprint(f"Stored {len(pred_text)} unique MAUVE predictions to {store_path}") |
| | |
| | min_len = min(len(pred_text), len(ref_text)) |
| | pred_text = pred_text[:min_len] |
| | ref_text = ref_text[:min_len] |
| | |
| | rprint(f"Computing img to txt MAUVE score for {len(pred_text)} unique predictions and {len(ref_text)} references") |
| | |
| | |
| | device_id = 0 |
| | mauve_divergence_curve_discretization_size = self.config.eval.mauve_divergence_curve_discretization_size |
| | mauve_scaling_factor = self.config.eval.mauve_scaling_factor |
| | avg_over_seed = self.config.eval.mauve_average_over_seeds |
| | |
| | |
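| | # MAUVE's feature quantization (k-means) is stochastic, so the score is
| | # computed under several random seeds and averaged.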
| | random_seeds = [random.randint(0, 100000) for _ in range(avg_over_seed)] |
| | for seed in random_seeds: |
| | mauve_score = mauve.compute( |
| | references=ref_text, |
| | predictions=pred_text, |
| | device_id=device_id, |
| | divergence_curve_discretization_size=mauve_divergence_curve_discretization_size, |
| | mauve_scaling_factor=mauve_scaling_factor |
| | ) |
| | mauve_metric.update(mauve_score.mauve) |
| | rprint(f"MAUVE score for seed {seed}: {mauve_score.mauve}") |
| | store_path = os.path.join(self.config.output_dir, f"{prefix}_mauve_score_seed_{seed}.txt") |
| | with open(store_path, "w") as f: |
| | f.write(str(mauve_score)) |
| | |
| | rprint(f"Stored MAUVE score for seed {seed} to {store_path}") |
| | |
| | avg_mauve_score = mauve_metric.compute() |
| | return avg_mauve_score |
| |
|
| |
|
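| | # The prior of the absorbing-state diffusion is a sequence of all [MASK] tokens.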
| | def _sample_prior(self, *batch_dims): |
| | return torch.full(batch_dims, self.mask_index, dtype=torch.int64)
| |
|
| | def get_cfg_weight(self, t): |
| | _cfg = self.config.eval.cfg
| | cfg_min_timestep = getattr(self.config.eval, "cfg_min_timestep", None)
| | cfg_max_timestep = getattr(self.config.eval, "cfg_max_timestep", None)
| | if not getattr(self.config.eval, "force_cfg_value", False):
| | if _cfg == -1:
| | _cfg = torch.linspace(0, 10, t.shape[0]).to(t.device)
| |
| | if cfg_min_timestep is not None and cfg_max_timestep is not None:
| | _w = (_cfg * ((t - cfg_max_timestep) / (cfg_min_timestep - cfg_max_timestep)))[:, None]
| | else:
| | _w = (_cfg * (1 - t))[:, None]
| | else:
| | _w = _cfg
| |
| | if cfg_min_timestep is not None:
| | _w = torch.where(t > cfg_min_timestep, _w, torch.tensor(0.0))
| |
| | if cfg_max_timestep is not None:
| | _w = torch.where(t < cfg_max_timestep, _w, torch.tensor(0.0))
| |
| | if not isinstance(_w, torch.Tensor):
| | _w = torch.tensor(_w)
| |
|
| | return _w |
| |
|
| | def _ddpm_forward(self, x, t, sigma_t, x0=None, x0_unmask=None, force_cfg=None, **kwargs): |
| | _w = None |
| | if getattr(self.config.eval, "cfg", None) is not None and x0_unmask is not None and x0_unmask.sum() > 0: |
| | _w = self.get_cfg_weight(t) |
| | |
| | orig_modality, orig_sample_ids = None, None |
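| | # Classifier-free guidance: the conditional input runs alongside a copy whose
| | # conditioning tokens are re-masked; logits are mixed as (1 + w) * cond - w * uncond.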
| | if _w is not None and (_w > 0).any(): |
| | x_uncond = x.clone() |
| | x_uncond[x0_unmask] = self.mask_index |
| | if getattr(self.config.eval, "split_cfg_batches", False): |
| | cat_p_x0 = torch.cat([ |
| | self.forward( |
| | x=x, |
| | sigma=sigma_t, |
| | return_logits=True, |
| | **kwargs |
| | ), |
| | self.forward( |
| | x=x_uncond, |
| | sigma=sigma_t, |
| | return_logits=True, |
| | **kwargs |
| | ) |
| | ], dim=0) |
| | else: |
| | orig_modality = kwargs.get("modality", None) |
| | if orig_modality is not None: |
| | orig_modality = orig_modality.clone() |
| | kwargs["modality"] = torch.cat([orig_modality, orig_modality], dim=0) |
| |
|
| | orig_sample_ids = kwargs.get("sample_ids", None) |
| | if orig_sample_ids is not None: |
| | orig_sample_ids = orig_sample_ids.clone() |
| | kwargs["sample_ids"] = torch.cat([orig_sample_ids, orig_sample_ids], dim=0) |
| |
|
| | if self.config.trainer.interleaved_training_flex_attention: |
| | assert 'sample_ids' in kwargs |
| | kwargs['block_mask'] = get_interleaved_block_mask(kwargs['sample_ids'], x.shape[0], x.shape[-1], self.device) |
| |
|
| | cat_p_x0 = self.forward( |
| | x=torch.cat([x, x_uncond], dim=0), |
| | sigma=torch.cat([sigma_t, sigma_t], dim=0) if sigma_t is not None else None, |
| | return_logits=True, |
| | **kwargs |
| | ) |
| | kwargs["modality"] = orig_modality |
| | kwargs["sample_ids"] = orig_sample_ids |
| |
|
| | logit_c, logit_u = cat_p_x0.chunk(2, dim=0) |
| | if isinstance(_w, torch.Tensor) and _w.ndim == 2 and logit_c.ndim == 3: |
| | _w = _w.unsqueeze(-1) |
| | output_logits = (1 + _w) * logit_c - _w * logit_u |
| | _modality = kwargs.get("modality", None) |
| | if self.config.trainer.ar_shift: |
| | _modality = _modality[:, 1:] |
| | |
| | p_x0 = self._subs_parameterization(output_logits, xt=None, batch=None, modality=_modality) |
| | p_x0 = p_x0.exp() |
| | del logit_c, logit_u, cat_p_x0, output_logits, orig_modality, orig_sample_ids, x, x_uncond |
| | else: |
| | p_x0 = self.forward(x=x, sigma=sigma_t, **kwargs) |
| | p_x0 = p_x0.exp() |
| |
|
| | if self.config.trainer.force_bf16_eval: |
| | p_x0 = p_x0.to(torch.bfloat16) |
| |
|
| | kwargs.pop("attention_caching", None) |
| | kwargs.pop("block_mask", None) |
| |
|
| | if getattr(self.config.eval, "force_empty_cache", False): |
| | empty_device_cache() |
| |
|
| | return p_x0 |
| |
|
| |
|
| | def sample_masking(self, batch, batch_idx): |
| | assert (self.config.loader.batch_size == self.config.loader.eval_batch_size) or self.config.mode == 'eval' |
| | if getattr(self.config.model, "img_cond", False): |
| | text_samples, img_samples = self._sample(text_only=False, **self.get_cond_dict(batch)) |
| | pred_img = decode_latents(self.config, self.get_vae(), img_samples) |
| | log({"val/gen_images_": wandb.Image(pred_img), "trainer/global_step": self.global_step}) |
| |
|
| | orig_bs = batch["input_ids"].shape[0] |
| | bs = min(10, max(1, int(orig_bs // 2))) |
| | bs = getattr(self.config.eval, "masking_batch_size", bs) |
| | bs = min(bs, orig_bs) |
| | |
| | if getattr(self.config.eval, "num_random_masking", None) is not None: |
| | num_random_masking = getattr(self.config.eval, "num_random_masking", 1) |
| | bs = max(bs, num_random_masking) |
| | else: |
| | num_random_masking = max((bs + 1) // 4, 1)
| | |
| | _attention_mask = (batch["attention_mask"] if "attention_mask" in batch else None)[:bs] |
| | _input_ids = (batch["input_ids"])[:bs] |
| | _x_modality = (batch["modality"])[:bs] if "modality" in batch else None |
| |
|
| | if _x_modality is not None and _x_modality.shape[0] != bs:
| | _x_modality = _x_modality[[0]].repeat(bs, 1) |
| |
|
| | (input_tokens, output_tokens, _attention_mask) = self._maybe_sub_sample(_input_ids, _attention_mask) |
| | x0 = input_tokens |
| | forward_kwargs = self.get_cond_dict(batch) |
| | forward_kwargs['is_sample_masking'] = True |
| | |
| | if "x_cond" in forward_kwargs: |
| | forward_kwargs["x_cond"] = forward_kwargs["x_cond"][:bs] |
| |
|
| | assert output_tokens is None |
| | assert self.T == 0 and self.change_of_variables is False |
| |
|
| | random_masking_ratio = getattr(self.config.eval, "random_masking_ratio", 0.95) |
| | t = random_masking_ratio + (1 - random_masking_ratio) * torch.rand(num_random_masking, device=x0.device) |
| | sigma, dsigma = self.noise(t) |
| | unet_conditioning = sigma[:, None] |
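| | # Log-linear noise schedule: each token is masked independently with
| | # probability 1 - exp(-sigma(t)).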
| | move_chance = 1 - torch.exp(-sigma[:, None]) |
| |
|
| | unet_conditioning = torch.cat([unet_conditioning, unet_conditioning.new_full((bs - num_random_masking, 1), torch.nan)], dim=0) |
| | move_chance = torch.cat([move_chance, move_chance.new_full((bs - num_random_masking, move_chance.shape[1]), 1)], dim=0) |
| |
|
| | uniform_mask = torch.full(x0.shape, True, device=x0.device, dtype=torch.bool) |
| | text_only_mask = uniform_mask.clone() |
| | text_only_mask = torch.where(_x_modality == 1, False, text_only_mask) |
| |
|
| | image_only_mask = uniform_mask.clone() |
| | image_only_mask = torch.where(_x_modality == 0, False, image_only_mask) |
| | image_only_mask = torch.where(batch["batch_contains_img"][:bs, None], image_only_mask, True) |
| | mask_dict = dict(mask_all=uniform_mask, mask_text_only=text_only_mask, mask_image_only=image_only_mask) |
| |
|
| | if getattr(self.config.eval, "mask_img_only", False): |
| | uniform_mask = torch.full(x0.shape, True, device=x0.device, dtype=torch.bool) |
| | image_only_mask = torch.where(_x_modality == 0, False, uniform_mask) |
| | move_chance = torch.ones_like(move_chance) |
| | mask_dict = dict(mask_image_only=image_only_mask) |
| | elif getattr(self.config.eval, "mask_img_only_keep_partial", False): |
| | mask_dict = dict(mask_image_only=image_only_mask) |
| | elif getattr(self.config.eval, "mask_all_only", False): |
| | mask_dict = dict(mask_all=uniform_mask) |
| |
|
| | only_uniform_mask = getattr(self.config.eval, "only_uniform_mask", False) |
| |
|
| | table_dict = dict() |
| | for mask_name, allow_move_mask in mask_dict.items(): |
| | if mask_name == "mask_all" and not only_uniform_mask: |
| | _move_chance = 0.5 + (1 - 0.5) * torch.rand_like(move_chance) |
| | elif mask_name == "mask_text_only": |
| | _move_chance = torch.zeros_like(move_chance) |
| | else: |
| | _move_chance = move_chance |
| |
|
| | xt = self.q_xt( |
| | x0, |
| | _move_chance, |
| | allow_move_mask, |
| | mask_image_square=(mask_name != "mask_text_only") and not only_uniform_mask, |
| | mask_text_region=(mask_name != 'mask_image_only') and not only_uniform_mask |
| | ) |
| |
|
| | if getattr(self.config.eval, "single_step_denoising", False): |
| | forward_kwargs.pop("is_sample_masking", None) |
| | model_output = self.forward(xt, unet_conditioning, **forward_kwargs) |
| | if not self.is_compiled: |
| | utils.print_nans(model_output, "model_output") |
| | model_output = model_output.exp() |
| | pred_tokens = model_output.argmax(dim=-1) |
| | pred_tokens = torch.where(xt == self.mask_index, pred_tokens, xt) |
| | pred_text, pred_img = self.decode_batch(pred_tokens, text_only=False, sample_modality=_x_modality) |
| | pred_img = decode_latents(self.config, self.get_vae(), pred_img) |
| | pred_txt = wrapped_batch_decode(self.tokenizer, pred_text, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos) |
| | else: |
| | xt_unmasked = xt != self.mask_index |
| | pred_txt, pred_img = self.sample(x0=xt, x0_unmask=xt_unmasked, sample_modality=_x_modality, **forward_kwargs) |
| |
|
| | gen_table = wandb.Table(columns=["GT Img", "GT Caption", "Masked Img", "Masked Caption", "Pred Img", "Pred Caption", "Move chance"]) |
| | masked_txt, masked_img, mask_text_mask, mask_img_mask = self.decode_batch( |
| | xt, text_only=False, return_masks=True, allow_mask_index=True, sample_modality=_x_modality |
| | ) |
| |
|
| | downscale_ratio = self.config.model.downscale_ratio |
| | latent_dim = self.config.data.resolution // downscale_ratio |
| |
|
| | img_mask = einops.repeat( |
| | einops.rearrange(mask_img_mask[:, self.static_img_sl], "b (h w) -> b h w", h=latent_dim, w=latent_dim), |
| | "b h w -> b (h na) (w nb)", |
| | na=downscale_ratio, |
| | nb=downscale_ratio, |
| | ) |
| |
|
| | gt_txt, gt_img = self.decode_batch(_input_ids, text_only=False, sample_modality=_x_modality) |
| | gt_txt = wrapped_batch_decode(self.tokenizer, gt_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos) |
| | gt_img = decode_latents(self.config, self.get_vae(), gt_img) |
| |
|
| | masked_txt = wrapped_batch_decode(self.tokenizer, masked_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos) |
| | masked_img = gt_img.clone().permute(0, 2, 3, 1) |
| | masked_img[img_mask] = torch.tensor([0.5, 0.5, 0.5], dtype=masked_img.dtype, device=masked_img.device) |
| | masked_img = masked_img.permute(0, 3, 1, 2) |
| | for _gt_img, _gt_txt, _masked_img, _masked_txt, _pred_img, _pred_txt, _move_chance in zip( |
| | gt_img, gt_txt, masked_img, masked_txt, pred_img, pred_txt, move_chance |
| | ): |
| | gen_table.add_data( |
| | wandb.Image(_gt_img), _gt_txt, wandb.Image(_masked_img), _masked_txt, wandb.Image(_pred_img), _pred_txt, _move_chance |
| | ) |
| |
|
| | table_suffix = f"_{batch_idx}" |
| | table_dict[f"{mask_name}_sample_table{table_suffix}"] = gen_table |
| |
|
| | log({**table_dict, "trainer/global_step": self.global_step}) |
| |
|
| | def log_flops(self, batch, batch_idx): |
| | use_torch_tnt = False |
| | use_native_torch = True |
| | use_fvcore = False |
| | with torch.enable_grad(): |
| | with torch.autocast(self.device.type, dtype=self.dtype): |
| | new_batch_idxs = batch["input_ids"].new_ones((self.config.loader.batch_size, self.config.model.length)) |
| | if use_fvcore: |
| | |
| | from fvcore.nn import (ActivationCountAnalysis, |
| | FlopCountAnalysis, flop_count_str, |
| | flop_count_table) |
| | example_input = (new_batch_idxs, None) |
| | fca = FlopCountAnalysis(self.accelerator.unwrap_model(self.backbone), example_input) |
| | aca = ActivationCountAnalysis(self.accelerator.unwrap_model(self.backbone), example_input) |
| | print(flop_count_table(fca, max_depth=1)) |
| | print(flop_count_str(fca)) |
| | print(fca.total()) |
| |
|
| | if use_torch_tnt: |
| | from torchtnt.utils.module_summary import get_module_summary |
| | module_summary = get_module_summary(self.backbone, module_args=(new_batch_idxs, None), module_kwargs={}) |
| | rprint(module_summary) |
| | rprint(f"TorchTNT Forward FLOPs: {module_summary.flops_forward / 1e12:.2f} FLOPs") |
| | rprint(f"TorchTNT Backward FLOPs: {module_summary.flops_backward / 1e12:.2f} FLOPs") |
| | rprint(f"TorchTNT Total FLOPs: {(module_summary.flops_forward + module_summary.flops_backward) / 1e12:.2f} FLOPs") |
| |
|
| | if use_native_torch: |
| | from torch.utils.flop_counter import FlopCounterMode |
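| | # FlopCounterMode intercepts ATen ops via torch dispatch and tallies FLOPs
| | # for the synthetic forward + backward pass traced below.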
| | flop_counter = FlopCounterMode(self.backbone, display=True, depth=3) |
| | with flop_counter: |
| | fake_batch = {} |
| | fake_batch["input_ids"] = new_batch_idxs |
| | fake_batch['attention_mask'] = batch['attention_mask'].new_ones(new_batch_idxs.shape) |
| | if 'modality' in batch: |
| | fake_batch['modality'] = batch['modality'].new_ones(new_batch_idxs.shape) |
| | fake_batch['x0'] = fake_batch["input_ids"] |
| | t = self._sample_t(fake_batch['x0'].shape[0], fake_batch['x0'].device) |
| | sigma, dsigma = self.noise(t) |
| | move_chance = 1 - torch.exp(-sigma[:, None]) |
| | xt = self.q_xt(fake_batch['x0'], move_chance) |
| | fake_batch['xt'] = xt |
| | if self.config.trainer.image_mode == "continuous": |
| | B, T = fake_batch["input_ids"].shape |
| | indices = fake_batch["input_ids"].to(batch['text_tokens'].dtype) |
| | fake_sigma = torch.ones(B, T, device=self.device).long() |
| | fake_x_img_emb = torch.randn(B, T, 4 * (self.config.model.patching_downscale ** 2), device=self.device) |
| | fake_modality = torch.zeros(B, T, device=self.device, dtype=torch.long) |
| | fake_modality[:, self.config.model.txt_length:] = True |
| | logits = self.backbone(indices=indices, sigma=fake_sigma, continuous_mode=True, x_img_emb=fake_x_img_emb, modality=fake_modality) |
| | else: |
| | logits = self.backbone(fake_batch["input_ids"], sigma=None, modality=fake_batch.get("modality", None)) |
| | from transformers.modeling_outputs import CausalLMOutputWithPast
| | if isinstance(logits, tuple):
| | logits = logits[0]
| | elif isinstance(logits, CausalLMOutputWithPast):
| | logits = logits.logits
| |
|
| | loss = logits.mean().to(torch.float32) |
| | loss.backward() |
| |
|
| | total_flops = flop_counter.get_total_flops() |
| | rprint(f"Total FLOPs Per Sample Fwd+Bwd: {(total_flops / self.config.loader.batch_size) / 1e12:.2f} TFLOPs") |
| | rprint(f"Total FLOPs Per Fwd+Bwd: {total_flops / 1e12:.2f} TFLOPs") |
| | rprint(f"Total FLOPs Per Global Step: {(total_flops / 1e12) * self.world_size * self.gradient_accumulation_steps:.2f} TFLOPs") |
| |
|
| | rprint(f"GPU available FLOP/s: {get_available_flops(new_batch_idxs.device, self.dtype) / 1e12:.2f} TFLOP/s") |
| | rprint(f"Total available FLOP/s: {(get_available_flops(new_batch_idxs.device, self.dtype) / 1e12) * self.world_size * self.gradient_accumulation_steps:.2f} TFLOP/s") |
| | rprint(f"Used Batch Size: {self.config.loader.batch_size} for FLOP Calculations") |
| |
|
| | @torch.inference_mode() |
| | def _ddpm_update(self, x, t, dt, **kwargs): |
| | sigma_t, _ = self.noise(t) |
| | sigma_s, _ = self.noise(t - dt) |
| | if sigma_t.ndim > 1: |
| | sigma_t = sigma_t.squeeze(-1) |
| | if sigma_s.ndim > 1: |
| | sigma_s = sigma_s.squeeze(-1) |
| | assert sigma_t.ndim == 1, sigma_t.shape |
| | assert sigma_s.ndim == 1, sigma_s.shape |
| | move_chance_t = 1 - torch.exp(-sigma_t) |
| | move_chance_s = 1 - torch.exp(-sigma_s) |
| | move_chance_t = move_chance_t[:, None, None] |
| | move_chance_s = move_chance_s[:, None, None] |
| | nfe_cnt = 0 |
| | _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t |
| | p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs) |
| | nfe_cnt += 1 |
| | assert move_chance_t.ndim == p_x0.ndim |
| | |
| | |
| | |
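| | # Absorbing-state reverse step: a masked position stays masked with weight
| | # move_chance_s and commits to a token from p(x0 | xt) with weight
| | # (move_chance_t - move_chance_s); q_xs is an unnormalized categorical.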
| | q_xs = p_x0 * (move_chance_t - move_chance_s) |
| | q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] |
| | _x = _sample_categorical(q_xs) |
| |
|
| | copy_flag = (x != self.mask_index).to(x.dtype) |
| | del p_x0, q_xs, move_chance_t, move_chance_s |
| | return copy_flag * x + (1 - copy_flag) * _x, nfe_cnt |
| |
|
| | @torch.inference_mode() |
| | def _ddpm_caching_update(self, x, t, dt, p_x0=None, x0=None, x0_unmask=None, modality=None,**kwargs): |
| | assert self.config.noise.type == "loglinear" |
| | sigma_t, _ = self.noise(t) |
| | if t.ndim > 1: |
| | t = t.squeeze(-1) |
| | |
| | nfe_cnt = 0 |
| | assert t.ndim == 1 |
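| | # Under the log-linear schedule asserted above, the masking probability equals
| | # t itself, so move_chance is taken directly from t instead of 1 - exp(-sigma).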
| | move_chance_t = t[:, None, None] |
| | move_chance_s = (t - dt)[:, None, None] |
| | assert move_chance_t.ndim == 3, move_chance_t.shape |
| |
|
| | if p_x0 is None: |
| | _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t |
| | p_x0 = self._ddpm_forward(x, t, _sigma, x0=x0, x0_unmask=x0_unmask, modality=modality, **kwargs) |
| | nfe_cnt += 1 |
| | assert move_chance_t.ndim == p_x0.ndim |
| | if self.config.trainer.force_bf16_eval: empty_device_cache() |
| | q_xs = p_x0 * (move_chance_t - move_chance_s) |
| | q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] |
| | _x = _sample_categorical(q_xs) |
| | copy_flag = (x != self.mask_index).to(x.dtype) |
| | if self.config.trainer.force_bf16_eval: empty_device_cache() |
| |
|
| | if self.config.trainer.ar_shift: |
| | if x0 is not None: |
| | _x = torch.cat([x0[:, [0]], _x], dim=1) |
| | else: |
| | _x = torch.cat([torch.full_like(_x[..., :1], fill_value=self.tokenizer.pad_token_id), _x], dim=1) |
| |
|
| | del q_xs, move_chance_t, move_chance_s |
| | return p_x0, copy_flag * x + (1 - copy_flag) * _x, nfe_cnt |
| |
|
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | @torch.inference_mode() |
| | def _sample( |
| | self, |
| | num_steps=None, |
| | eps=1e-5, |
| | text_only=True, |
| | x0=None, |
| | x0_unmask=None, |
| | batch_size_per_gpu=None, |
| | example_batch=None, |
| | sample_batch_idx=None, |
| | sample_modality=None, |
| | sample_ids=None, |
| | return_raw_data=False, |
| | **kwargs, |
| | ): |
| | """Generate samples from the model.""" |
| | assert (x0 is None) == (x0_unmask is None), f"x0: {x0} x0_unmask: {x0_unmask}"
| | batch_size_per_gpu = (x0.shape[0] if x0 is not None else self.config.loader.eval_batch_size) if batch_size_per_gpu is None else batch_size_per_gpu |
| | sample_modality = kwargs.get("modality", None) if sample_modality is None else sample_modality |
| | kwargs['modality'] = sample_modality |
| | kwargs['sample_ids'] = sample_ids |
| | return_nfe = kwargs.pop('return_nfe', False) |
| | is_sample_masking = kwargs.pop('is_sample_masking', False) |
| | allow_interleaved_conditional = kwargs.pop('allow_interleaved_conditional', False) |
| | nfe_cnt = 0 |
| | assert batch_size_per_gpu > 0 |
| | if num_steps is None: |
| | num_steps = self.config.sampling.steps |
| | if getattr(self.config.eval, "test_eval_speed", False) and getattr(self.config.eval, 'eval_at_ratio_length', False): |
| | num_steps = self.config.model.length |
| | if getattr(self.config.eval, "num_steps_ratio", None) is not None: |
| | num_steps = int(num_steps * self.config.eval.num_steps_ratio) |
| | |
| | decode_kwargs = dict(sample_modality=sample_modality, return_raw_data=return_raw_data, is_sample_masking=is_sample_masking) |
| |
|
| | if x0 is not None and x0_unmask is not None: |
| | x = self._sample_prior(batch_size_per_gpu, x0.shape[1]).to(self.device) |
| | decode_kwargs['x0_unmask'] = x0_unmask |
| | if getattr(self.config.eval, "visualize_sample", False): |
| | x_viz = x.clone() |
| | x_viz = torch.where(x0_unmask, x0, x) |
| | _mask_id = self.tokenizer("mask")['input_ids'] |
| | assert len(_mask_id) == 3 |
| | x_viz[x_viz == self.mask_index] = _mask_id[1] |
| | ret_txt, ret_img = self.decode_sampling(x_viz, text_only, **kwargs, **decode_kwargs, image_save_postfix="_masked_input") |
| | print(ret_txt) |
| |
|
| | elif self.config.trainer.interleaved and self.config.backbone != "chameleon" and allow_interleaved_conditional:
| | assert self.config.trainer.interleaved_training_flex_attention |
| | x0 = example_batch['input_ids'].to(self.device) |
| | total_samples = getattr(self.config.eval, "num_uncond_sample_batches", 1) - 1 |
| | half_uncond = getattr(self.config.eval, "half_uncond", False) |
| | if not half_uncond or sample_batch_idx >= total_samples // 2: |
| | unmask_modality = getattr(self.config.eval, "unmask_modality", sample_batch_idx % 2) |
| | x0_unmask = sample_modality == unmask_modality |
| | if x0_unmask.sum() == x0.numel(): |
| | unmask_modality = 1 - unmask_modality |
| | x0_unmask = sample_modality == unmask_modality |
| |
|
| | assert x0.shape == sample_modality.shape, f"x0: {x0.shape}, sample_modality: {sample_modality.shape}"
| | |
| | if unmask_modality == 1: |
| | x0_unmask = torch.zeros_like(x0_unmask) |
| | for i in range(x0.shape[0]): |
| | eos_pos = (x0[i] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0] |
| | if len(eos_pos) > 0: |
| | idx = random.randint(0, max(len(eos_pos) - 2, 0))
| | x0_unmask[i, :] = True |
| | if len(eos_pos) >= idx + 1: |
| | _sl = slice(eos_pos[idx], None) |
| | else: |
| | _sl = slice(eos_pos[idx] + 2, eos_pos[idx+1] - 1) |
| | |
| | x0_unmask[i, _sl] = (sample_modality[i, _sl] == 1) |
| |
|
| | |
| | for i in range(x0.shape[0]): |
| | eos_pos = (x0[i] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0] |
| | if len(eos_pos) > 0: |
| | assert (eos_pos[0] < 48) or (sample_modality[i].sum() == 0), f"eos_pos: {eos_pos}" |
| | x0_unmask[i, :eos_pos[0]+1] = True |
| |
|
| | if unmask_modality == 1 and x0_unmask.sum() == 0: |
| | x0_unmask = torch.ones_like(x0_unmask) |
| | print(f"Found no umasked tokens, unmasking random sequences") |
| | for i in range(x0.shape[0]): |
| | seq_len = (x0[i] != self.tokenizer.pad_token_id).sum() |
| | if seq_len == 0: |
| | continue |
| | |
| | start_pos = random.randint(0, seq_len-1) |
| | max_len = min(seq_len - start_pos, 200) |
| | unmask_len = random.randint(1, max_len) |
| | x0_unmask[i, start_pos:start_pos+unmask_len] = False |
| |
|
| | gprint(f"Unmasking modality: {unmask_modality}, Unmasking {(x0_unmask.sum() / x0_unmask.numel()):.2%} of image tokens. Txt tokens: {(sample_modality == 0).sum()}, Img tokens: {(sample_modality == 1).sum()}") |
| |
|
| | x0_unmask[~example_batch['attention_mask']] = True |
| | x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device) |
| | decode_kwargs['x0_unmask'] = x0_unmask |
| | x = torch.where(x0_unmask, x0, x) |
| |
|
| | if getattr(self.config.eval, "visualize_sample", False): |
| | x_viz = x.clone() |
| | _mask_id = self.tokenizer("mask")['input_ids'] |
| | assert len(_mask_id) == 3 |
| | _mask_id = _mask_id[1] |
| | x_viz[x == self.mask_index] = _mask_id |
| | self.decode_sampling(x_viz, text_only, **kwargs, **decode_kwargs, image_save_postfix="_x0_unmasked") |
| |
|
| | if self.parameterization == "ar" or getattr(self.config.eval, "eval_large_batch", None) is not None: |
| | rprint(f"Masking all tokens by default.") |
| | x0_unmask = torch.zeros(*x0.shape, device=x0.device).to(torch.bool) |
| | else: |
| | rprint(f"Hit chamelon sample") |
| | if sample_batch_idx == getattr(self.config.eval, "num_uncond_sample_batches", 1) - 1: |
| | x0_unmask = torch.zeros(*x0.shape, device=x0.device, dtype=torch.bool) |
| | x0_unmask[..., -20:] = True |
| | rprint(f"Unmasking first {x0_unmask.shape[-1] // 2} tokens") |
| | else: |
| | x0_unmask = torch.rand(*x0.shape, device=x0.device) < (sample_batch_idx / 60) |
| | rprint(f"Unmasking {(sample_batch_idx / 60)} of image_tokens, {x0_unmask.sum()}") |
| |
|
| | x = self._sample_prior(batch_size_per_gpu, x0.shape[1]).to(self.device) |
| | _img_indices = torch.isin(x0, torch.tensor(list(image_indices), device=self.device)) |
| | if getattr(self.config.eval, "unmask_chameleon_txt", False): |
| | rprint(f"Unmasking all text tokens") |
| | x0_unmask |= _img_indices |
| | x0_unmask[:, :4] = True |
| | rprint(f"All tokens: {x0_unmask.tolist()}") |
| | |
| | |
| | else: |
| | x0_unmask |= (~_img_indices) |
| |
|
| | kwargs['forward_attention_mask'] = attention_mask |
| | decode_kwargs['image_indices'] = image_indices |
| | decode_kwargs['x0_unmask'] = x0_unmask |
| | rprint(f"Unmasking: {torch.sum(x0_unmask)}") |
| | else: |
| | x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device) |
| | decode_kwargs['x0_unmask'] = x0_unmask |
| |
|
| | if self.config.trainer.interleaved_training_flex_attention: |
| | assert 'sample_ids' in kwargs |
| | kwargs['block_mask'] = get_interleaved_block_mask(kwargs['sample_ids'], x.shape[0], x.shape[-1], self.device) |
| |
|
| | if x0_unmask is not None and num_steps > (~x0_unmask).sum(dim=-1).min():
| | rprint(f"num_steps {num_steps} > sequence length {(~x0_unmask).sum(dim=-1).min()}, setting num_steps to sequence length") |
| | num_steps = (~x0_unmask).sum(dim=-1).min() |
| |
|
| | if self.parameterization == "ar": |
| | with show_memory_usage(empty_cache=True): |
| | out, nfe_cnt = self._ar_sampler(batch_size_per_gpu, x0=x0, x0_unmask=x0_unmask, **kwargs) |
| | res = self.decode_sampling(out, text_only, **kwargs, **decode_kwargs) |
| | if return_nfe: |
| | return res, nfe_cnt |
| | return res |
| |
|
| | if x0 is not None and x0_unmask is not None: |
| | x = torch.where(x0_unmask, x0, x) |
| |
|
| | if self.sampler == "maskgit" or self.sampler == "first_hitting" or self.sampler == "maskgit_nucleus": |
| | sampling_schedule = 'arccos' if self.sampler in ['maskgit', 'maskgit_nucleus'] else 'linear' |
| |
| | schedule = adap_sche(x=x, step=num_steps, mask_index=self.mask_index, mode=sampling_schedule) |
| | print(f"schedule: {schedule}") |
| | |
| | timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device) |
| | dt = (1 - eps) / num_steps |
| | p_x0_cache = None |
| | |
| | is_x_sliced = False |
| | attention_caching = self.config.eval.attention_caching |
| | attention_caching_txt_to_img_ratio = getattr(self.config.eval, "attention_caching_txt_to_img_ratio", 10) |
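| | # Attention caching: a full txt+img denoising pass (which refreshes the
| | # flex-attention KV cache) runs only every attention_caching_txt_to_img_ratio
| | # steps; in between, x is sliced down to the text region and denoised against
| | # the cached image keys/values.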
| | if attention_caching: |
| | backbone = self.accelerator.unwrap_model(self.backbone) |
| | backbone.set_flex_attention_cache(x.shape[0], x.shape[1], self.device, self.dtype) |
| | full_data = dict() |
| | x_next = None |
| |
|
| | |
| | if getattr(self.config.eval, "visualize_denoising", False): |
| | denoising_steps = [x.clone()] |
| | |
| | for i in range(num_steps): |
| | t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device, dtype=self.dtype if self.config.trainer.force_bf16_eval else torch.float32) |
| | if attention_caching: |
| | if i % attention_caching_txt_to_img_ratio == 0: |
| | if is_x_sliced: |
| | def replace_new_data(_key, _new_data): |
| | if full_data[_key] is not None: |
| | full_data[_key][:,self.static_txt_sl] = _new_data |
| | return full_data[_key] |
| | |
| | x = replace_new_data("x", x) |
| | x0 = replace_new_data("x0", x0) |
| | x0_unmask = replace_new_data("x0_unmask", x0_unmask) |
| | p_x0_cache = replace_new_data("p_x0_cache", p_x0_cache) |
| | kwargs["modality"] = replace_new_data("modality", kwargs.get("modality", None)) |
| | del full_data |
| | full_data = dict() |
| | is_x_sliced = False |
| |
|
| | update_cache_slice = None |
| | block_mask = True |
| | elif (i - 1) % attention_caching_txt_to_img_ratio == 0: |
| | update_cache_slice = slice(0, x.shape[1]) |
| | block_mask = get_block_mask( |
| | txt_batch_attn_dropout=torch.zeros(x.shape[0], dtype=torch.bool, device=x.device), |
| | img_batch_attn_dropout=torch.ones(x.shape[0], dtype=torch.bool, device=x.device), |
| | txt_length=self.config.model.txt_length, |
| | batch_size=x.shape[0], |
| | seq_len=x.shape[1], |
| | device=x.device |
| | ) |
| | else: |
| | update_cache_slice = self.static_txt_sl |
| | block_mask = True |
| | if not is_x_sliced: |
| | is_x_sliced = True |
| | |
| | def clone_if_valid(_data): |
| | if _data is not None: |
| | return _data.clone() |
| | else: |
| | return None |
| | |
| | def sl_if_valid(_data): |
| | if _data is not None: |
| | return _data[:, self.static_txt_sl] |
| | else: |
| | return None |
| | |
| | full_data.update(x=clone_if_valid(x), x0=clone_if_valid(x0), x0_unmask=clone_if_valid(x0_unmask), modality=clone_if_valid(kwargs.get("modality", None)), p_x0_cache=clone_if_valid(p_x0_cache)) |
| | x = sl_if_valid(x) |
| | x0 = sl_if_valid(x0) |
| | x0_unmask = sl_if_valid(x0_unmask) |
| | x_next = sl_if_valid(x_next) |
| | p_x0_cache = sl_if_valid(p_x0_cache) |
| | kwargs["modality"] = sl_if_valid(kwargs.get("modality", None)) |
| |
|
| | kwargs["update_cache_slice"] = update_cache_slice |
| | kwargs["block_mask"] = block_mask |
| |
|
| | if self.sampler == "maskgit": |
| | x, nfe_step_cnt = self._maskgit_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, schedule=schedule, step=i, **kwargs) |
| | elif self.sampler == "maskgit_nucleus": |
| | x, nfe_step_cnt = self._maskgit_nucleus_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, schedule=schedule, step=i, **kwargs) |
| | elif self.sampler == "first_hitting": |
| | x, nfe_step_cnt = self._first_hitting_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, schedule=schedule, step=i, **kwargs) |
| | elif self.sampler == "ddpm": |
| | x, nfe_step_cnt = self._ddpm_update(x, t, dt, x0=x0, x0_unmask=x0_unmask, **kwargs) |
| | elif self.sampler == "ddpm_tweedie": |
| | assert not return_nfe, "Tweedie sampler does not support return_nfe" |
| | x = self._ddpm_update_finetune_controlled_tweedie(x, t, dt, sampling_step=i, **kwargs) |
| | nfe_step_cnt = 0 |
| | elif self.sampler == "ddpm_cache": |
| | p_x0_cache, x_next, nfe_step_cnt = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache, x0=x0, x0_unmask=x0_unmask, **kwargs) |
| | if not torch.allclose(x_next, x) or self.time_conditioning: |
| | p_x0_cache = None |
| | x = x_next |
| | else: |
| | x, nfe_step_cnt = self._analytic_update(x, t, dt) |
| |
|
| | nfe_cnt += nfe_step_cnt |
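| | # Once a text stream has emitted EOS, all later text positions are forced to
| | # PAD so subsequent denoising steps cannot resurrect stray tokens (mask tokens
| | # and existing padding are left untouched).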
| | if self.tokenizer.eos_token_id in x and getattr(self.config.trainer, "force_after_eos_padding", False) and (self.tokenizer.eos_token_id != self.tokenizer.bos_token_id) and not attention_caching: |
| | after_first_eos = torch.cumsum(x == self.tokenizer.eos_token_id, dim=1).bool() |
| | after_first_eos_mask = after_first_eos.cumsum(dim=1) > 1 |
| | to_mask = ((after_first_eos_mask & (sample_modality == 0)) & (x != self.tokenizer.pad_token_id)) & (x != self.mask_index) |
| | x[to_mask] = self.tokenizer.pad_token_id |
| |
|
| | if to_mask.sum() > 0: |
| | rprint(f"Masked an avg of {torch.sum(to_mask, dim=1).float().mean()} tokens due to EOS.") |
| |
|
| | if x0 is not None and x0_unmask is not None: x = torch.where(x0_unmask, x0, x) |
| | |
| | |
| | if getattr(self.config.eval, "visualize_denoising", False) and i % getattr(self.config.eval, "visualize_step_interval", max(1, num_steps // 10)) == 0: |
| | denoising_steps.append(x.clone()) |
| | |
| | clear_gpu_memory_if_needed() |
| | |
| | if getattr(self.config.eval, "visualize_denoising", False) and denoising_steps: |
| | if denoising_steps[-1] is not x: |
| | denoising_steps.append(x.clone()) |
| | |
| | step_images = [] |
| | for step_x in denoising_steps: |
| | _, step_res = self.decode_sampling(step_x, text_only=False, bypass_return_interleaved_modalities_split=True, **kwargs, **decode_kwargs) |
| | if not isinstance(step_res, Image.Image): |
| | step_res = step_res[0] |
| | step_images.append(step_res) |
| | |
| | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
| | date_folder = datetime.now().strftime("%Y-%m-%d")
| | save_path = Path("/dev/shm") / os.getenv("USER", 'user') / "denoise_vis" / date_folder / f"{timestamp}.png"
| | save_path.parent.mkdir(parents=True, exist_ok=True)
| | Im.concat_horizontal(step_images).save(save_path)
| | rprint(f"Saved denoising visualization to {save_path}")
| |
|
| | if is_x_sliced: |
| | def replace_new_data(_key, _new_data): |
| | if full_data[_key] is not None: |
| | full_data[_key][:,self.static_txt_sl] = _new_data |
| | return full_data[_key] |
| | |
| | x = replace_new_data("x", x) |
| | x0 = replace_new_data("x0", x0) |
| | x0_unmask = replace_new_data("x0_unmask", x0_unmask) |
| | p_x0_cache = replace_new_data("p_x0_cache", p_x0_cache) |
| | kwargs["modality"] = replace_new_data("modality", kwargs.get("modality", None)) |
| | del full_data |
| | full_data = dict() |
| | is_x_sliced = False |
| |
|
| | if self.config.sampling.noise_removal: |
| | t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device) |
| | if self.sampler == "analytic": |
| | x = self._denoiser_update(x, t) |
| | else: |
| | unet_conditioning = self.noise(t)[0] |
| | x = self.forward(x=x, sigma=unet_conditioning, **kwargs).argmax(dim=-1) |
| |
|
| | if x0 is not None and x0_unmask is not None: |
| | x = torch.where(x0_unmask, x0, x) |
| | res = self.decode_sampling(x, text_only, **kwargs, **decode_kwargs) |
| |
|
| | if return_nfe: |
| | return res, nfe_cnt |
| | return res |
| |
|
| | def decode_sampling(self, x, text_only, is_sample_masking=False, bypass_return_interleaved_modalities_split=False, **kwargs): |
| | if self.config.trainer.interleaved and getattr(self.config.eval, "return_interleaved_modalities_split", False) and not bypass_return_interleaved_modalities_split: |
| | decoded_data = self.decode_batch({"input_ids": x, **kwargs}, text_only=False) |
| | image_save_postfix = kwargs.get("image_save_postfix", None) |
| | assert len(decoded_data) == 1 |
| | all_imgs = [] |
| | all_txt = [] |
| | for i in range(min(len(decoded_data), 64)): |
| | sample_data, sample_modalities = decoded_data[i].to_list() |
| | ret = self.get_interleaved_image(sample_data, sample_modalities, image_save_postfix=image_save_postfix) |
| | all_txt_in_sample = [] |
| | all_img_in_sample = [] |
| | for j in range(len(sample_data)): |
| | if sample_modalities[j] == 0: |
| | text_samples = sample_data[j] |
| | pred_txt = wrapped_batch_decode( |
| | self.tokenizer, text_samples[None], clean_up_tokenization_spaces=False, skip_special_tokens=False, disable_mask_after_eos=True |
| | ) |
| | all_txt_in_sample.extend(pred_txt) |
| | else: |
| | img_samples = sample_data[j] |
| | pred_img = decode_latents(self.config, self.get_vae(), img_samples[None]) |
| | all_img_in_sample.extend([Im(x).pil for x in pred_img]) |
| |
|
| | |
| | if len(all_txt_in_sample) >= 2 and all_txt_in_sample[-1] == self.tokenizer.eos_token: |
| | all_txt_in_sample[-2] += all_txt_in_sample[-1] |
| | all_txt_in_sample.pop() |
| |
|
| | all_txt.extend(all_txt_in_sample) |
| | all_imgs.extend(all_img_in_sample) |
| |
|
| | print(f"Returning... all_txt: {all_txt}, all_imgs: {all_imgs}") |
| | for i in range(len(all_imgs)): |
| | filename = f"img_{get_rank()}_{str(time.time()).replace('.', '__')}.png" |
| | Im(all_imgs[i]).save(filename) |
| | return all_txt, all_imgs |
| | elif (self.config.trainer.interleaved and not is_sample_masking) or getattr(self.config.eval, "fake_interleaved", False): |
| | image_save_postfix = kwargs.get("image_save_postfix", None) |
| | decoded_data = self.decode_batch({"input_ids": x, **kwargs}, text_only=False) |
| | all_imgs = [] |
| | all_txt_ids = [] |
| | num_text_tokens = self.config.model.txt_length |
| | for i in range(min(len(decoded_data), 64)): |
| | sample_data, sample_modalities = decoded_data[i].to_list() |
| | all_imgs.append(self.get_interleaved_image(sample_data, sample_modalities, image_save_postfix=image_save_postfix)) |
| | all_txt_ids_in_sample = [] |
| | for j in range(len(sample_data)): |
| | if sample_modalities[j] == 0: |
| | text_samples = sample_data[j] |
| | if text_samples.shape[-1] < num_text_tokens: |
| | text_samples = torch.nn.functional.pad( |
| | text_samples, |
| | (0, num_text_tokens - text_samples.shape[-1]), |
| | value=self.tokenizer.pad_token_id |
| | ) |
| | else: |
| | text_samples = text_samples[..., :num_text_tokens] |
| | all_txt_ids_in_sample.append(text_samples) |
| |
|
| | if len(all_txt_ids_in_sample) == 0: |
| | all_txt_ids_in_sample.append(torch.zeros((num_text_tokens), dtype=torch.long, device=self.device)) |
| |
|
| | all_txt_ids.append(torch.cat(all_txt_ids_in_sample, dim=0)) |
| |
|
| | if kwargs.get("return_raw_data", False): |
| | return all_txt_ids, all_imgs, x |
| | |
| | return all_txt_ids, all_imgs |
| | else: |
| | ret = self.decode_batch(x, text_only=text_only, **kwargs) |
| | if getattr(self.config.eval, "visualize_sample", False): |
| | self.save_image_text_pair(ret[1], ret[0][:, self.static_txt_sl]) |
| | return ret |
| |
|
| |
|
| | @tensorclass
| | class InputData:
| |     # Discrete token ids over the full sequence; image positions hold placeholder
| |     # ids until diffusion fills them in (see sample_transfusion below).
| |     xt_ids: Integer[Tensor, "b t"]
| |
| |     # Continuous image embeddings, zero outside image spans.
| |     xt_img_embed: Optional[Float[Tensor, "b t c"]] = None
| |     # Per-position modality flag: 0 = text, 1 = image.
| |     modality: Integer[Tensor, "b t"] = False
| |     sigma: Optional[Float[Tensor, "b t"]] = None
| |
|
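| | # Illustrative sketch (not part of the model): how `InputData` is populated for a
| | # toy batch, mirroring the initialization in `sample_transfusion` below. The sizes
| | # (B=2, T=8, C=4) and the pad/bos ids used here are arbitrary assumptions.
| | def _example_input_data():
| |     B, T, C, pad_id, bos_id = 2, 8, 4, 0, 1
| |     xt_ids = torch.full((B, T), fill_value=pad_id, dtype=torch.long)
| |     xt_ids[:, 0] = bos_id
| |     data = InputData(
| |         xt_ids=xt_ids,
| |         xt_img_embed=torch.zeros((B, T, C)),
| |         modality=torch.zeros((B, T), dtype=torch.long),
| |         sigma=torch.zeros((B, T)),
| |         batch_size=[B],
| |     )
| |     return data
| |
|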
| | @torch.no_grad() |
| | def sample_transfusion( |
| | self, |
| | batch_size_per_gpu=None, |
| | text_only=False, |
| | ): |
| | """Generate samples from the model in autoregressive discrete mode for text and diffusion for image.""" |
| | |
| | |
| | B = batch_size_per_gpu if batch_size_per_gpu is not None else self.config.loader.eval_batch_size |
| | T = self.config.model.length |
| | C = self.config.model.downscale_ratio |
| | |
| | num_img_tokens = self.config.model.img_length |
| | num_img_diffusion_steps = self.config.sampling.steps |
| |
|
| | |
| | xt_ids = torch.full((B, T), fill_value=self.tokenizer.pad_token_id, dtype=torch.long, device=self.device) |
| | xt_ids[:, 0] = self.tokenizer.bos_token_id |
| | xt_img_embed = torch.zeros((B, T, C), device=self.device) |
| | modality = torch.zeros((B, T), dtype=torch.long, device=self.device) |
| | sigma = torch.zeros((B, T), dtype=self.dtype, device=self.device) |
| | data = InputData(xt_ids=xt_ids, xt_img_embed=xt_img_embed, modality=modality, sigma=sigma, batch_size=[B]) |
| |
|
| | noise = torch.distributions.Gumbel(0, 1).sample((data.shape[0], T, self.vocab_size)).to(self.device) |
| | img_start_token_id = self.tokenizer.eos_token_id |
| | i = 1 |
| | continuous_diffusion_mode = False |
| | while i < T: |
| | if continuous_diffusion_mode: |
| | |
| | img_sl = slice(i, i+num_img_tokens) |
| | data.modality[:, img_sl] = 1 |
| | data.xt_img_embed[:, img_sl] = self.sample_continuous_image(data, img_sl=img_sl, num_steps=num_img_diffusion_steps, return_embeddings=True) |
| | i += num_img_tokens |
| | continuous_diffusion_mode = False |
| | break  # stop generation once the single image block has been diffused
| | else: |
| | |
| | ar_sl = slice(None, i) |
| | if self.use_kv_cache: |
| | start_pos = i - 1 |
| | kv_sl = slice(start_pos, i) |
| | else: |
| | kv_sl = ar_sl |
| | start_pos=None |
| | pred_logits, pred_noise = self.forward(x=data.xt_ids[:, kv_sl], sigma=data.sigma[:, ar_sl], modality=data.modality[:, ar_sl], x_img_emb=data.xt_img_embed[:, ar_sl], disable_ar_shift=True, continuous_mode=True, start_pos=start_pos) |
| | pred_logits = pred_logits[:, -1] |
| | y = (pred_logits + noise[:, i]).argmax(-1) |
| | |
| |
|
| | data.xt_ids[:, i] = y |
| | |
| | i += 1 |
| | if not text_only and (i == self.config.model.txt_length-1 or torch.all(y == img_start_token_id)): |
| | continuous_diffusion_mode = True |
| |
|
| | if self.config.model.use_kv_cache: |
| | backbone = self.accelerator.unwrap_model(self.backbone) |
| | backbone.reset_kv_cache(batch_size=self.config.model.inference_max_batch_size, seq_len=self.config.model.inference_max_seq_len, dtype=self.dtype, device=self.device) |
| |
|
| | return data |
| |
|
| | def sample_continuous_image(self, data: InputData, img_sl, num_steps=None, return_embeddings=False): |
| | if num_steps is None: |
| | num_steps = self.config.sampling.steps |
| | B = data.xt_img_embed.shape[0] |
| | noise_scheduler = self.vae.scheduler |
| | noise_scheduler.set_timesteps(num_steps, device=self.device) |
| | timesteps = noise_scheduler.timesteps |
| | data.xt_img_embed[:, img_sl] = torch.randn_like(data.xt_img_embed[:, img_sl]) |
| |
|
| | visible_sl = slice(None, img_sl.stop) |
| | for i in range(num_steps):  # noise_scheduler.timesteps has exactly num_steps entries
| | data.sigma[:, img_sl] = (timesteps[i] * torch.ones(B, device=self.device)).unsqueeze(-1) |
| | pred_logits, pred_noise = self.forward( |
| | x=data.xt_ids[:, visible_sl], sigma=data.sigma[:, visible_sl], x_img_emb=data.xt_img_embed[:, visible_sl], modality=data.modality[:, visible_sl], disable_ar_shift=True, continuous_mode=True |
| | ) |
| | data.xt_img_embed[:, img_sl] = noise_scheduler.step(pred_noise[:, img_sl], timesteps[i], data.xt_img_embed[:, img_sl]).prev_sample |
| |
|
| | if return_embeddings: return data.xt_img_embed[:, img_sl] |
| |
|
| | |
| | |
| | text_tokens, img_tokens = self.decode_batch(data.xt_ids[:, img_sl], text_only=False) |
| | return text_tokens, img_tokens |
| |
|
| |
|
| | def cfg(config, t, cat_p_x0):
| |     logit_c, logit_u = cat_p_x0.chunk(2, dim=0)
| |     _cfg = config.eval.cfg
| |     if not getattr(config.eval, "force_cfg_value", False) and _cfg == -1:
| |         # cfg == -1 requests a per-sample guidance sweep from 0 to 10, annealed by (1 - t).
| |         _cfg = torch.linspace(0, 10, t.shape[0]).to(t.device)
| |         _w = (_cfg * (1 - t))[:, None, None]
| |     else:
| |         _w = _cfg
| |
| |     return (1 + _w) * logit_c - _w * logit_u
| |
|
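| | # Minimal sketch of the classifier-free-guidance combination used by `cfg` above,
| | # on dummy logits. With weight w, the guided logits are (1 + w) * cond - w * uncond;
| | # w = 0 recovers the conditional logits. Shapes and values here are arbitrary.
| | def _example_cfg_combination():
| |     logit_c = torch.randn(1, 4, 16)   # conditional logits (B, N, V)
| |     logit_u = torch.randn(1, 4, 16)   # unconditional logits
| |     w = 2.0
| |     guided = (1 + w) * logit_c - w * logit_u
| |     assert torch.allclose((1 + 0.0) * logit_c - 0.0 * logit_u, logit_c)
| |     return guided
| |
|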
| | def nucleus_sampling_batch(probs, top_p=0.9, temperature=1.0):
| |     """
| |     Perform nucleus (top-p) sampling on batched, per-position distributions.
| |
| |     Args:
| |         probs (torch.Tensor): A tensor of shape (B, N, C) of probabilities, where B is
| |             the batch size, N is the sequence length, and C is the number of classes.
| |         top_p (float): The cumulative probability threshold for nucleus sampling.
| |         temperature (float): Temperature applied to the distribution in log space.
| |
| |     Returns:
| |         torch.Tensor: Indices sampled from the filtered distribution for each position,
| |             with shape (B, N).
| |     """
| |     B, N, C = probs.shape
| |
| |     # Apply temperature in log space and renormalize; the cumulative sums below are
| |     # only meaningful over a normalized distribution (a no-op when temperature == 1).
| |     probs = torch.softmax(torch.log(probs.clamp_min(1e-20)) / temperature, dim=-1)
| |
|
| | # Sort probabilities in descending order along the vocabulary dimension.
| | sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
| |
| | # Cumulative mass of the sorted distribution.
| | cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
| |
| | # Keep the smallest prefix whose cumulative mass stays within top_p,
| | # but always retain at least the single most likely token.
| | mask = cumulative_probs <= top_p
| | mask[:, :, 0] = True
| |
| | # Zero out tokens outside the nucleus and renormalize.
| | filtered_probs = sorted_probs * mask.float()
| | filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)
| |
| | # Sample within the nucleus, then map sorted positions back to vocabulary ids.
| | sampled_indices = torch.multinomial(filtered_probs.view(-1, C), num_samples=1).squeeze(-1)
| | sampled_indices = sampled_indices.view(B, N)
| | final_indices = torch.gather(sorted_indices, -1, sampled_indices.unsqueeze(-1)).squeeze(-1)
| |
|
| | return final_indices |
| |
|
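| | # Usage sketch for `nucleus_sampling_batch` on a random distribution; shapes are
| | # arbitrary. With top_p = 1.0 the nucleus covers the whole vocabulary and the call
| | # reduces to plain temperature-adjusted categorical sampling.
| | def _example_nucleus_sampling_batch():
| |     probs = torch.softmax(torch.randn(2, 5, 32), dim=-1)   # (B, N, C)
| |     ids = nucleus_sampling_batch(probs, top_p=0.9, temperature=1.0)
| |     assert ids.shape == (2, 5) and int(ids.max()) < 32
| |     return ids
| |
|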
| | def nucleus_sampling(logits, top_p=0.9, temperature=1.0): |
| | """ |
| | Perform nucleus (top-p) sampling on the given logits. |
| | |
| | Args: |
| | logits (torch.Tensor): A tensor of shape (B, C) where B is the batch size |
| | and C is the number of classes. |
| | top_p (float): The cumulative probability threshold for nucleus sampling.
| | temperature (float): Temperature value for scaling logits.
| | |
| | Returns: |
| | torch.Tensor: Indices sampled from the filtered distribution. |
| | """ |
| | # Temperature-scale the logits and normalize to probabilities.
| | probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
| |
|
| | # Sort, accumulate, and build the nucleus mask (always keep the top token).
| | sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
| | cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
| | mask = cumulative_probs <= top_p
| | mask[..., 0] = True
| |
| | # Restrict to the nucleus, renormalize, sample, and unsort back to vocabulary ids.
| | filtered_probs = sorted_probs * mask.float()
| | filtered_probs /= (filtered_probs.sum(dim=-1, keepdim=True))
| | sampled_indices = torch.multinomial(filtered_probs, num_samples=1)[:, 0]
| | final_indices = sorted_indices.gather(dim=-1, index=sampled_indices.unsqueeze(-1)).squeeze(-1)
| |
|
| | return final_indices |
| |
|
| | def clear_gpu_memory_if_needed():
| |     if torch.cuda.is_available():
| |         # Free cached blocks once reserved memory exceeds half of the current device's capacity.
| |         device = torch.cuda.current_device()
| |         current_memory = torch.cuda.memory_reserved(device) / torch.cuda.get_device_properties(device).total_memory
| |         if current_memory >= 0.50:
| |             torch.cuda.empty_cache()
| |
|
| | def _ar_sampler(self, B, x0=None, x0_unmask=None, modality=None, **kwargs): |
| | assert B > 0 |
| | assert (x0 is None) == (x0_unmask is None), f"x0: {x0} x0_unmask: {x0_unmask}" |
| | num_pred_tokens = self.config.model.length - 1 |
| | x = torch.zeros((B, num_pred_tokens + 1), dtype=torch.long, device=self.device) |
| | x[:, 0] = self.tokenizer.bos_token_id |
| | if x0 is not None: x = torch.where(x0_unmask, x0, x) |
| | split_cfg_batches = getattr(self.config.eval, "split_cfg_batches", False) and not self.config.model.use_kv_cache |
| | effective_bs = B * 2 if ((self.config.eval.cfg is not None and x0 is not None) and split_cfg_batches is False) else B |
| | top_p = getattr(self.config.eval, "top_p", None) |
| | temperature = getattr(self.config.eval, "temperature", 1.0) |
| | if self.config.model.use_kv_cache: |
| | assert getattr(self.config.model, "inference_max_batch_size", None) is None |
| | assert getattr(self.config.model, "inference_max_seq_len", None) is None |
| | self.accelerator.unwrap_model(self.backbone).reset_kv_cache( |
| | batch_size=effective_bs, |
| | seq_len=num_pred_tokens, |
| | dtype=self.dtype, |
| | device=self.device |
| | ) |
| |
|
| | _x, _modality = None, None |
| | if self.config.eval.cfg is not None and x0 is not None: |
| | if split_cfg_batches is False: |
| | _x = torch.cat([x, torch.where(x0_unmask, self.mask_index, x)], dim=0) |
| | _modality = torch.cat([modality, modality], dim=0) |
| |
|
| | nfe_cnt = 0 |
| | noise = torch.distributions.Gumbel(0, 1).sample((B, num_pred_tokens, self.vocab_size)).to(self.device) |
| | for i in range(num_pred_tokens): |
| | start_pos = i if self.use_kv_cache else None |
| | ar_sl = slice(start_pos, i+1) |
| |
|
| | if self.config.eval.cfg is not None and x0 is not None: |
| | if split_cfg_batches: |
| | logit_c = self.forward( |
| | x=x[:, ar_sl], sigma=None, modality=modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True |
| | )[:, -1] |
| | logit_u = self.forward( |
| | x=torch.where(x0_unmask, self.mask_index, x)[:, ar_sl], sigma=None, modality=modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True |
| | )[:, -1] |
| | else: |
| | _x[:B] = x |
| | _x[B:] = torch.where(x0_unmask, self.mask_index, x) |
| | next_logits = self.forward(x=_x[:, ar_sl], sigma=None, modality=_modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True)[:, -1] |
| | logit_c, logit_u = next_logits.chunk(2, dim=0) |
| |
|
| | _w = self.get_cfg_weight(1 - (i / num_pred_tokens)) |
| | next_logits = (1 + _w) * logit_c - _w * logit_u |
| | else: |
| | next_logits = self.forward(x=x[:, ar_sl], sigma=None, modality=modality[:, ar_sl], start_pos=start_pos, disable_ar_shift=True)[:, -1] |
| |
|
| | if getattr(self.config.model, "force_argmax_valid_indices", False): |
| | |
| | next_sl = slice(i + 1, i + 2) |
| | try: |
| | next_logits[..., self.text_vocab_size:] = torch.where((modality[:, next_sl] == 0), torch.finfo(next_logits.dtype).min, next_logits[..., self.text_vocab_size:]) |
| | next_logits[..., :self.text_vocab_size] = torch.where((modality[:, next_sl] == 1), torch.finfo(next_logits.dtype).min, next_logits[..., :self.text_vocab_size]) |
| | except: |
| | breakpoint() |
| | nfe_cnt += 1  # count this step's forward pass regardless of the sampling rule
| | if top_p is not None:
| |     # Nucleus sampling over the next-token logits.
| |     y = nucleus_sampling(next_logits, top_p=top_p, temperature=temperature)
| | else:
| |     # Gumbel-max trick: argmax over (logits + Gumbel noise) draws a categorical sample.
| |     y = (next_logits + noise[:, i]).argmax(-1)
| | x[:, i + 1] = y |
| | if x0 is not None: x = torch.where(x0_unmask, x0, x) |
| | if not self.config.model.use_kv_cache: |
| | empty_device_cache() |
| | |
| | if getattr(self.config.eval, "force_empty_cache", False): |
| | empty_device_cache() |
| |
|
| | if self.config.model.use_kv_cache: |
| | |
| | del noise, next_logits, _x, _modality |
| | self.accelerator.unwrap_model(self.backbone).reset_kv_cache( |
| | batch_size=effective_bs, |
| | seq_len=num_pred_tokens, |
| | dtype=self.dtype, |
| | device=self.device, |
| | set_to_none=True |
| | ) |
| |
|
| | return x, nfe_cnt |
| |
|
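| | # Standalone sketch of the Gumbel-max trick that `_ar_sampler` uses when nucleus
| | # sampling is disabled: argmax over (logits + Gumbel noise) draws an exact sample
| | # from softmax(logits). Verified empirically here on a toy distribution; the
| | # logits and draw count are arbitrary.
| | def _example_gumbel_max_trick(num_draws=20000):
| |     logits = torch.tensor([1.0, 2.0, 0.5])
| |     noise = torch.distributions.Gumbel(0, 1).sample((num_draws, 3))
| |     draws = (logits + noise).argmax(-1)
| |     empirical = torch.bincount(draws, minlength=3).float() / num_draws
| |     assert torch.allclose(empirical, torch.softmax(logits, dim=-1), atol=0.02)
| |     return empirical
| |
|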
| | def handle_interleaved_decode(self, sample, allow_mask_index=False, new_mask_index=None, **kwargs): |
| | batch = sample |
| | sample_modality = sample.get("modality", None) |
| | sample = sample.get("input_ids", None) |
| |
|
| | text_tokens = torch.where(sample_modality == 0, sample, self.tokenizer.pad_token_id) |
| | img_tokens = torch.where((sample_modality == 1), sample, self.mask_index) |
| |
|
| | invalid_text_mask = (text_tokens >= self.text_vocab_size) & (sample_modality == 0) |
| | invalid_img_mask = (img_tokens < self.text_vocab_size) & (sample_modality == 1) |
| | mask_img_mask = (img_tokens == self.mask_index) & (sample_modality == 1) |
| |
|
| | if invalid_text_mask.sum() > 0: |
| | assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_text_mask.sum(): {invalid_text_mask.sum()}, {invalid_text_mask.nonzero()[:4]}" |
| | text_tokens[invalid_text_mask] = self.mask_index |
| |
|
| | if new_mask_index is not None: |
| | img_invalid_mask_v2 = ((img_tokens < self.text_vocab_size) & (img_tokens != self.mask_index)) |
| |
|
| | sample = torch.where(sample_modality == 1, img_tokens - self.text_vocab_size, text_tokens) |
| | if invalid_img_mask.sum() > 0 or mask_img_mask.sum() > 0: |
| | if new_mask_index is not None: |
| | assert img_invalid_mask_v2.sum().item() == 0 |
| | sample[mask_img_mask] = new_mask_index |
| | else: |
| | sample[mask_img_mask] = 0 |
| | sample[invalid_img_mask] = 0 |
| |
|
| | new_batch = {**batch, "input_ids": sample} |
| | new_batch = InterleavedBatch.custom_from_dict(new_batch) |
| | new_batch = new_batch.to_elements() |
| | return new_batch |
| |
|
| | def decode_batch(self, |
| | sample, |
| | text_only=True, |
| | return_masks: bool = False, |
| | allow_mask_index: bool = False, |
| | new_mask_index=None, |
| | sample_modality=None, |
| | **kwargs |
| | ): |
| |
|
| | if isinstance(sample, dict) or isinstance(sample, TensorDict): |
| | if self.config.trainer.interleaved or getattr(self.config.eval, "fake_interleaved", False): |
| | return handle_interleaved_decode(self, sample, allow_mask_index=allow_mask_index, new_mask_index=new_mask_index, **kwargs) |
| | else: |
| | sample_modality = sample.get("modality", None) |
| | sample = sample.get("input_ids", None) |
| |
|
| | img_tokens = None |
| | continuous_mode = self.config.trainer.image_mode == "continuous" |
| | if continuous_mode: |
| | text_tokens, img_tokens = sample[..., self.static_txt_sl], sample[..., self.static_img_sl] |
| | elif self.unified_model and self.config.trainer.multimodal_batches and sample_modality is not None: |
| | if (sample_modality == 0).all(dim=-1).sum() > 0: |
| | text_tokens = torch.where(sample_modality == 0, sample, self.tokenizer.pad_token_id) |
| | img_tokens = torch.where((sample_modality == 1)[:, self.static_img_sl], sample[:, self.static_img_sl], self.mask_index) |
| | else: |
| | text_tokens = torch.where(sample_modality == 0, sample, self.tokenizer.pad_token_id) |
| | img_tokens = torch.where((sample_modality == 1), sample, self.mask_index) |
| |
|
| | invalid_text_mask = text_tokens >= self.text_vocab_size |
| | if getattr(self.config.model, "add_labels", None) is not None: |
| | invalid_img_mask = (img_tokens < self.text_vocab_size) | (img_tokens >= (self.vocab_size - self.config.model.add_labels)) |
| | else: |
| | invalid_img_mask = (img_tokens < self.text_vocab_size) |
| | mask_text_mask = text_tokens == self.mask_index |
| | mask_img_mask = img_tokens == self.mask_index |
| | if invalid_text_mask.sum() > 0: |
| | assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_text_mask.sum(): {invalid_text_mask.sum()}, {invalid_text_mask.nonzero()[:4]}" |
| | text_tokens[invalid_text_mask] = self.mask_index |
| |
|
| | if new_mask_index is not None: |
| | img_invalid_mask_v2 = ((img_tokens < self.text_vocab_size) & (img_tokens != self.mask_index)) |
| |
|
| | img_tokens = img_tokens - self.text_vocab_size |
| | if invalid_img_mask.sum() > 0 or mask_img_mask.sum() > 0: |
| | if new_mask_index is not None: |
| | assert img_invalid_mask_v2.sum().item() == 0 |
| | img_tokens[mask_img_mask] = new_mask_index |
| | else: |
| | img_tokens[mask_img_mask] = 0 |
| | img_tokens[invalid_img_mask] = 0 |
| | |
| | if img_tokens.shape[-1] != self.config.model.img_length: |
| | if (sample_modality[:, -self.config.model.img_length:].sum(dim=-1) == self.config.model.img_length).all(): |
| | img_tokens = img_tokens[:, -self.config.model.img_length:] |
| | elif (sample_modality[:, :self.config.model.img_length].sum(dim=-1) == self.config.model.img_length).all(): |
| | img_tokens = img_tokens[:, :self.config.model.img_length] |
| |
|
| | elif self.unified_model: |
| | text_tokens, img_tokens = sample[..., self.static_txt_sl], sample[..., self.static_img_sl] |
| | invalid_text_mask = text_tokens >= self.text_vocab_size |
| | invalid_img_mask = img_tokens < self.text_vocab_size |
| | mask_text_mask = text_tokens == self.mask_index |
| | mask_img_mask = img_tokens == self.mask_index |
| |
|
| | if invalid_text_mask.sum() > 0: |
| | assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_text_mask.sum(): {invalid_text_mask.sum()}" |
| | text_tokens[invalid_text_mask] = self.mask_index |
| |
|
| | if new_mask_index is not None: |
| | img_invalid_mask_v2 = ((img_tokens < self.text_vocab_size) & (img_tokens != self.mask_index)) |
| |
|
| | img_tokens = img_tokens - self.text_vocab_size |
| | if invalid_img_mask.sum() > 0 or mask_img_mask.sum() > 0: |
| | assert allow_mask_index or self.config.model.force_argmax_valid_indices is False or self.config.sampling.predictor == "ddpm_tweedie" or self.config.parameterization == "ar", f"invalid_img_mask.sum(): {invalid_img_mask.sum()}" |
| | if new_mask_index is not None: |
| | assert img_invalid_mask_v2.sum().item() == 0 |
| | img_tokens[mask_img_mask] = new_mask_index |
| | else: |
| | img_tokens[mask_img_mask] = 0 |
| | img_tokens[invalid_img_mask] = 0 |
| |
|
| | try:
| |     assert img_tokens.shape[-1] == self.config.model.img_length, f"img_tokens.shape[-1]: {img_tokens.shape[-1]}, config.model.img_length: {self.config.model.img_length}, sample_modality: {sample_modality}"
| | except Exception:  # debug hook: drop into the debugger on an unexpected image-token length
| |     breakpoint()
| | |
| | elif self.image_model: |
| | text_tokens, img_tokens = None, sample |
| | else: |
| | text_tokens, img_tokens = sample, None |
| | if text_only: |
| | return text_tokens |
| | else: |
| | if return_masks: |
| | return text_tokens, img_tokens, mask_text_mask, mask_img_mask |
| | else: |
| | return text_tokens, img_tokens |
| |
|
| | def optional_add_bos(self, _x, x0): |
| | if self.config.trainer.ar_shift: |
| | if x0 is not None: |
| | _x = torch.cat([x0[:, [0]], _x], dim=1) |
| | else: |
| | _x = torch.cat([torch.full_like(_x[..., :1], fill_value=self.tokenizer.pad_token_id), _x], dim=1) |
| | return _x |
| |
|
| | def adap_sche(x, step, mask_index, mode="arccos"):
| |     """Create a 2D sampling schedule.
| |     :param
| |         x -> torch.Tensor: input tensor of shape (B, seq_len)
| |         step -> int: number of prediction steps during inference
| |         mask_index -> int: id of the mask token
| |         mode -> str: curve controlling the rate at which tokens are unmasked
| |     :return
| |         scheduler -> torch.LongTensor: tensor of shape (B, step) giving the number of tokens to unmask at each step for each sample
| |     """
| | num_masked = (x == mask_index).sum(dim=-1).to(x.device) |
| | |
| | r = torch.linspace(1, 0, step) |
| | |
| | if mode == "root": |
| | val_to_mask = 1 - (r ** .5) |
| | elif mode == "linear": |
| | val_to_mask = 1 - r |
| | elif mode == "square": |
| | val_to_mask = 1 - (r ** 2) |
| | elif mode == "cosine": |
| | val_to_mask = torch.cos(r * math.pi * 0.5) |
| | elif mode == "arccos": |
| | val_to_mask = torch.arccos(r) / (math.pi * 0.5) |
| | else: |
| | return None |
| | val_to_mask = val_to_mask.to(x.device) |
| | schedules = [] |
| | for seq_len in num_masked:
| | sche = (val_to_mask / val_to_mask.sum()) * seq_len |
| | sche = sche.round() |
| | sche[sche == 0] = 1 |
| | sche[-1] += seq_len - sche.sum() |
| | sche[-1] = max(sche[-1], 0) |
| | schedules.append(sche.int()) |
| |
|
| | return torch.stack(schedules, dim=0) |
| |
|
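| | # Usage sketch for `adap_sche`: a toy batch with differing mask counts. Each row of
| | # the returned schedule sums to that row's number of masked tokens, so every masked
| | # position is unmasked exactly once over `step` steps. mask_index=99 is arbitrary.
| | def _example_adap_sche():
| |     mask_index = 99
| |     x = torch.tensor([[99, 99, 99, 99, 1, 2], [99, 99, 1, 2, 3, 4]])
| |     sche = adap_sche(x, step=3, mask_index=mask_index, mode="arccos")
| |     assert (sche.sum(dim=-1) == (x == mask_index).sum(dim=-1)).all()
| |     return sche
| |
|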
| | |
| | @torch.no_grad() |
| | def _first_hitting_update(self, x, t, dt, schedule=None, step=None, **kwargs): |
| | sigma_t, _ = self.noise(t) |
| | sigma_s, _ = self.noise(t - dt) |
| | if sigma_t.ndim > 1: |
| | sigma_t = sigma_t.squeeze(-1) |
| | if sigma_s.ndim > 1: |
| | sigma_s = sigma_s.squeeze(-1) |
| | assert sigma_t.ndim == 1, sigma_t.shape |
| | assert sigma_s.ndim == 1, sigma_s.shape |
| | move_chance_t = 1 - torch.exp(-sigma_t) |
| | move_chance_s = 1 - torch.exp(-sigma_s) |
| | move_chance_t = move_chance_t[:, None, None] |
| | move_chance_s = move_chance_s[:, None, None] |
| |
|
| | _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t |
| | nfe_cnt = 0 |
| | p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs) |
| | nfe_cnt += 1 |
| |
|
| | copy_flag = (x != self.mask_index) |
| |
|
| | |
| | _x = _sample_categorical(p_x0) |
| |
|
| | num_unmask = schedule[:, step] |
| | num_unmask = torch.minimum(num_unmask, (~copy_flag).sum(dim=-1)) |
| | if torch.all(num_unmask <= 0): |
| | return x, nfe_cnt |
| |
|
| | # Randomly choose exactly num_unmask masked positions per row: score masked
| | # positions with uniform noise (unmasked positions get -1), sort descending,
| | # and keep each row's top num_unmask.
| | random_values = torch.rand_like(copy_flag, dtype=torch.float32)
| | random_values = torch.where(~copy_flag, random_values, -1)
| | _, indices = torch.sort(random_values, dim=-1, descending=True)
| | range_tensor = torch.arange(copy_flag.shape[-1], device=copy_flag.device).expand(copy_flag.shape)
| | final_mask = range_tensor < num_unmask[:, None]
| | result = torch.zeros_like(copy_flag)
| | result.scatter_(-1, indices, final_mask)
| |
|
| | return torch.where(result, _x, x), nfe_cnt |
| |
|
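| | # Standalone sketch of the selection trick in `_first_hitting_update`: to unmask
| | # exactly k random masked positions per row, score masked positions with uniform
| | # noise (unmasked get -1), sort descending, and keep each row's top k. The mask
| | # pattern and k values below are toy inputs.
| | def _example_random_unmask_selection():
| |     masked = torch.tensor([[True, True, False, True], [False, True, True, True]])
| |     k = torch.tensor([2, 1])
| |     scores = torch.where(masked, torch.rand(masked.shape), torch.tensor(-1.0))
| |     _, order = torch.sort(scores, dim=-1, descending=True)
| |     keep_rank = torch.arange(masked.shape[-1]).expand(masked.shape) < k[:, None]
| |     selected = torch.zeros_like(masked)
| |     selected.scatter_(-1, order, keep_rank)
| |     assert (selected.sum(dim=-1) == k).all() and (selected <= masked).all()
| |     return selected
| |
|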
| | @torch.no_grad() |
| | def _maskgit_update(self, x, t, dt, schedule=None, step=None, **kwargs): |
| | sigma_t, _ = self.noise(t) |
| | sigma_s, _ = self.noise(t - dt) |
| | if sigma_t.ndim > 1: |
| | sigma_t = sigma_t.squeeze(-1) |
| | if sigma_s.ndim > 1: |
| | sigma_s = sigma_s.squeeze(-1) |
| | assert sigma_t.ndim == 1, sigma_t.shape |
| | assert sigma_s.ndim == 1, sigma_s.shape |
| | move_chance_t = 1 - torch.exp(-sigma_t) |
| | move_chance_s = 1 - torch.exp(-sigma_s) |
| | move_chance_t = move_chance_t[:, None, None] |
| | move_chance_s = move_chance_s[:, None, None] |
| | nfe_cnt = 0 |
| | _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t |
| |
|
| | copy_flag = (x != self.mask_index) |
| | r_temp = getattr(self.config.eval, 'maskgit_r_temp', 10) |
| | num_unmask = schedule[:, step] |
| | |
| | num_unmask = torch.minimum(num_unmask, (~copy_flag).sum(dim=-1)) |
| | if torch.all(num_unmask <= 0): |
| | return x, nfe_cnt |
| |
|
| | p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs) |
| | nfe_cnt += 1 |
| | pred_code = torch.multinomial(p_x0.view(-1, p_x0.shape[-1]), 1)[:, 0].view(p_x0.shape[:-1]) |
| | conf = torch.gather(p_x0, -1, pred_code.unsqueeze(-1)).squeeze(-1) |
| | |
| | rand = r_temp * torch.from_numpy(np.random.gumbel(size=pred_code.shape)).to(self.device) * t |
| | conf = torch.log(conf.squeeze()) + rand |
| |
|
| | if self.config.trainer.ar_shift: |
| | copy_flag = copy_flag[:, 1:] |
| |
|
| | |
| | conf = torch.where(copy_flag, -torch.inf, conf) |
| |
|
| | # Unmask the num_unmask highest-confidence predictions per row. num_unmask varies
| | # across the batch, so take the top max(num_unmask) confidences and read off each
| | # row's own threshold at rank num_unmask - 1.
| | max_num_unmask = num_unmask.max().item()
| | tresh_conf, indice_mask = torch.topk(conf, k=max_num_unmask, dim=-1)
| | gather_indices = torch.clamp(num_unmask - 1, min=0)[:, None]
| | tresh_conf = tresh_conf.gather(-1, gather_indices)
| | # Rows with nothing left to unmask get an infinite threshold so no position passes.
| | tresh_conf = torch.where((num_unmask <= 0)[:, None], torch.inf, tresh_conf)
| | # Boolean map of positions to replace with the model's prediction.
| | conf = (conf >= tresh_conf.expand_as(conf))
| | if self.config.trainer.ar_shift: |
| | out = torch.where(conf, pred_code, x[:, 1:]) |
| | out = optional_add_bos(self, out, x0=kwargs.get("x0", None)) |
| | else: |
| | out = torch.where(conf, pred_code, x) |
| |
|
| | if getattr(self.config.eval, "allow_token_updates", False): |
| | out = torch.where(copy_flag, p_x0.argmax(dim=-1), out) |
| |
|
| | del conf, indice_mask, gather_indices, tresh_conf, pred_code, p_x0 |
| | if getattr(self.config.eval, "force_empty_cache", False): |
| | empty_device_cache() |
| |
|
| | return out, nfe_cnt |
| |
|
| |
|
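| | # Standalone sketch of the per-row confidence thresholding used by the MaskGIT
| | # updates above: take the top max(k) confidences, read each row's threshold at
| | # rank k - 1, and unmask every position at or above it. Values are arbitrary
| | # (and chosen without ties).
| | def _example_maskgit_threshold():
| |     conf = torch.tensor([[0.9, 0.1, 0.5, 0.7], [0.2, 0.8, 0.6, 0.4]])
| |     k = torch.tensor([2, 1])                      # tokens to unmask per row
| |     thresh, _ = torch.topk(conf, k=int(k.max()), dim=-1)
| |     thresh = thresh.gather(-1, torch.clamp(k - 1, min=0)[:, None])
| |     unmask = conf >= thresh
| |     assert (unmask.sum(dim=-1) == k).all()
| |     return unmask
| |
|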
| | @torch.no_grad() |
| | def _maskgit_nucleus_update(self, x, t, dt, schedule=None, step=None, **kwargs): |
| | nfe_cnt = 0 |
| | _sigma = None |
| |
|
| | copy_flag = (x != self.mask_index) |
| | if self.config.trainer.ar_shift: |
| | copy_flag = copy_flag[:, 1:] |
| |
|
| | assert getattr(self.config.eval, 'maskgit_r_temp', None) is not None
| | r_temp = getattr(self.config.eval, "maskgit_r_temp", 10) |
| | num_unmask = schedule[:, step] |
| | num_unmask = torch.minimum(num_unmask, (~copy_flag).sum(dim=-1)) |
| | if torch.all(num_unmask <= 0):  # num_unmask is per-sample; mirrors the check in _maskgit_update
| | return x, nfe_cnt |
| | |
| | p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs) |
| | nfe_cnt += 1 |
| | top_p = getattr(self.config.eval, "top_p", 0.95) |
| | temperature = getattr(self.config.eval, "temperature", 0.9) |
| | if top_p is not None: |
| | pred_code = nucleus_sampling_batch(p_x0, top_p=top_p, temperature=temperature) |
| | else: |
| | pred_code = torch.multinomial(p_x0.view(-1, p_x0.shape[-1]), 1)[:, 0].view(p_x0.shape[:-1]) |
| | conf = torch.gather(p_x0, -1, pred_code.unsqueeze(-1)).squeeze(-1) |
| |
|
| | rand = r_temp * torch.from_numpy(np.random.gumbel(size=pred_code.shape)).to(self.device) * t |
| | conf = torch.log(conf.squeeze()) + rand |
| |
|
| | |
| | conf = torch.where(copy_flag, -torch.inf, conf) |
| |
|
| | # Same per-row confidence thresholding as in _maskgit_update above.
| | max_num_unmask = num_unmask.max().item()
| | tresh_conf, indice_mask = torch.topk(conf, k=max_num_unmask, dim=-1)
| | gather_indices = torch.clamp(num_unmask - 1, min=0)[:, None]
| | tresh_conf = tresh_conf.gather(-1, gather_indices.long())
| | tresh_conf = torch.where((num_unmask <= 0)[:, None], torch.inf, tresh_conf)
| | conf = (conf >= tresh_conf)
| | if self.config.trainer.ar_shift: |
| | out = torch.where(conf, pred_code, x[:, 1:]) |
| | out = optional_add_bos(self, out, x0=kwargs.get("x0", None)) |
| | else: |
| | out = torch.where(conf, pred_code, x) |
| | return out, nfe_cnt |
| |
|
| |
|
| |
|
| | @torch.no_grad() |
| | def _ddpm_update_finetune_controlled_tweedie(self, x, t, dt, reward_model=None, repeats=10, sampling_step=None, **kwargs): |
| | sigma_t, _ = self.noise(t) |
| | sigma_s, _ = self.noise(t - dt) |
| | if sigma_t.ndim > 1: |
| | sigma_t = sigma_t.squeeze(-1) |
| | if sigma_s.ndim > 1: |
| | sigma_s = sigma_s.squeeze(-1) |
| | assert sigma_t.ndim == 1, sigma_t.shape |
| | assert sigma_s.ndim == 1, sigma_s.shape |
| | move_chance_t = 1 - torch.exp(-sigma_t) |
| | move_chance_s = 1 - torch.exp(-sigma_s) |
| | move_chance_t = move_chance_t[:, None, None] |
| | move_chance_s = move_chance_s[:, None, None] |
| | _sigma = None if getattr(self.config.trainer, "force_null_sigma", False) else sigma_t |
| | p_x0 = self._ddpm_forward(x, t, _sigma, **kwargs) |
| | assert move_chance_t.ndim == p_x0.ndim |
| |
|
| | if self.config.trainer.force_bf16_eval: empty_device_cache() |
| | q_xs = p_x0 * (move_chance_t - move_chance_s) |
| | q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] |
| | copy_flag = (x != self.mask_index).to(x.dtype) |
| |
|
| | del p_x0, move_chance_t, move_chance_s |
| | resample_interval = getattr(self.config.eval, "tweedie_resample_interval", None) |
| | return_single_sample = False |
| | _repeats = repeats |
| | if resample_interval is not None and sampling_step % resample_interval != 0: |
| | _repeats = 1 |
| | return_single_sample = True |
| |
|
| | |
| | samples = [copy_flag * x + (1 - copy_flag) * optional_add_bos(self, _sample_categorical(q_xs), x0=kwargs.get("x0", None)) for _ in range(_repeats)] |
| |
|
| | if return_single_sample: |
| | return samples[0] |
| |
|
| | if not hasattr(self, "reward_model"): |
| | from unidisc.tokenizers.laion_aesthetic_v2 import get_predictor_func |
| | self.reward_model = get_predictor_func(self.device) |
| | rprint("Using reward model. Should delete this after eval.") |
| |
|
| | |
| | |
| | scores = [] |
| | expected_x0_args = [] |
| | for i in range(repeats): |
| | |
| | expected_x0 = self._ddpm_forward(samples[i], t, sigma_s, **kwargs) |
| | if getattr(self.config.eval, "use_generic_tweedie_rewards", False): |
| | assert self.config.trainer.interleaved |
| | expected_x0_arg = torch.argmax(expected_x0, dim=-1) |
| | expected_x0_args.append(expected_x0_arg) |
| | assert samples[0].shape[0] == 1 |
| | else: |
| | expected_x0[..., :self.text_vocab_size] = 0 |
| | expected_x0[..., self.mask_index] = 0 |
| | expected_x0[..., self.text_vocab_size:] = expected_x0[..., self.text_vocab_size:] + 1e-6 |
| | expected_x0_arg = torch.argmax(expected_x0, dim=-1) |
| | expected_x0_arg = expected_x0_arg - self.text_vocab_size |
| | expected_x0_img_pred = decode_latents(self.config, self.get_vae(), expected_x0_arg[:, self.static_img_sl]) |
| | scorer = self.reward_model(expected_x0_img_pred) |
| |
|
| | scorer = scorer.squeeze() |
| | if scorer.ndim == 0: |
| | scorer = scorer[None] |
| | scores.append(torch.from_numpy(scorer)) |
| |
|
| | if getattr(self.config.eval, "use_generic_tweedie_rewards", False): |
| | orig_modality = kwargs.get("modality", None) |
| | if orig_modality is not None: |
| | orig_modality = orig_modality.clone() |
| | kwargs["modality"] = orig_modality.repeat(len(expected_x0_args), 1) |
| |
|
| | orig_sample_ids = kwargs.get("sample_ids", None) |
| | if orig_sample_ids is not None: |
| | orig_sample_ids = orig_sample_ids.clone() |
| | kwargs["sample_ids"] = orig_sample_ids.repeat(len(expected_x0_args), 1) |
| |
|
| | decoded_data = self.decode_batch({"input_ids": torch.cat(expected_x0_args, dim=0), **kwargs}, text_only=False) |
| | kwargs["modality"] = orig_modality |
| | kwargs["sample_ids"] = orig_sample_ids |
| |
|
| | all_imgs = [] |
| | all_txt_ids = [] |
| | for i in range(len(decoded_data)): |
| | sample_data, sample_modalities = decoded_data[i].to_list() |
| | assert len(sample_data) == 2 |
| | assert sample_modalities == [0, 1] |
| | sample_text = wrapped_batch_decode( |
| | self.tokenizer, |
| | sample_data[0][None], |
| | clean_up_tokenization_spaces=True, |
| | skip_special_tokens=False, |
| | disable_mask_after_eos=True |
| | ) |
| | assert len(sample_text) == 1 |
| | all_txt_ids.append(sample_text[0]) |
| | all_imgs.append(self.get_interleaved_image(sample_data, sample_modalities, single_image_only=True, disable_img_save=True)) |
| |
|
| | all_imgs = torch.cat(all_imgs, dim=0) |
| | reward_config = getattr(self.config.eval, "tweedie_reward_config") |
| | scores = self.get_rewards(reward_config, all_imgs, all_txt_ids).float().cpu() |
| | scores = torch.softmax(scores, dim=0)[None] |
| | else: |
| | scores = torch.stack(scores, dim=1) |
| | scores = torch.softmax(scores, dim=1) |
| |
|
| | |
| | |
| | final_sample_indices = torch.argmax(scores, dim=1) |
| | final_samples = [samples[final_sample_indices[j]][j,:] for j in range(x.size(0))] |
| | final_samples = torch.stack(final_samples, dim=0) |
| | return final_samples |
| |
|
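| | # Minimal sketch of the best-of-N reranking idea behind
| | # `_ddpm_update_finetune_controlled_tweedie`: draw several candidate denoised
| | # samples, score each with a reward model, and keep the argmax per batch element.
| | # `reward_fn` and the candidate tensor here are stand-ins, not the real models.
| | def _example_reward_reranking():
| |     candidates = torch.randn(4, 2, 6)                 # (repeats, B, L) stand-in samples
| |     reward_fn = lambda s: s.mean(dim=-1)              # stand-in scorer -> (repeats, B)
| |     scores = torch.softmax(reward_fn(candidates), dim=0)
| |     best = scores.argmax(dim=0)                       # (B,) index of the best candidate
| |     picked = torch.stack([candidates[best[j], j] for j in range(candidates.shape[1])])
| |     assert picked.shape == (2, 6)
| |     return picked
| |
|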
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def visualize_samples(self, batch, batch_idx, split='val'): |
| | split = split.removesuffix("/") |
| | gt_txt = None |
| | step_metrics = self.get_step_metrics() |
| | step_metrics["trainer/global_step"] = (batch_idx if self.config.eval.visualize_data_only else self.global_step) |
| | rprint('[IMPORTANT] Visualizing ground truth samples, verify tokenization') |
| |
|
| | if getattr(self.config.eval, "disable_visualization", False): |
| | return |
| |
|
| | if self.config.trainer.interleaved: |
| | decoded_data = self.decode_batch(batch, text_only=False) |
| | all_imgs = [] |
| | max_num = 10000 if getattr(self.config.eval, "visualize_data_only", False) else 32 |
| | for i in range(min(len(decoded_data), max_num)): |
| | sample_data, sample_modalities = decoded_data[i].to_list() |
| | all_imgs.append(self.get_interleaved_image(sample_data, sample_modalities)) |
| |
|
| | if not getattr(self.config.eval, "visualize_data_only", False): |
| | log({f"{split}/rec_img": wandb.Image(Im.concat_horizontal(*all_imgs).pil), **step_metrics}) |
| | else: |
| | gt_txt, gt_img = self.decode_batch(batch["input_ids"], text_only=False, sample_modality=batch.get("modality", None)) |
| | if gt_img is not None: |
| | rec_img = decode_latents(self.config, self.get_vae(), gt_img) |
| | log({f"{split}/rec_img": wandb.Image(rec_img), **step_metrics}) |
| | |
| | gt_txt = gt_txt[:4] |
| | if self.config.trainer.multimodal_batches: |
| | txt_batch = batch["input_ids"][~self.img_txt_pair_batch_mask(batch)] |
| | if txt_batch.shape[0] > 0: |
| | rprint(f"Txt Only (GT): {wrapped_batch_decode(self.tokenizer, txt_batch[:4], clean_up_tokenization_spaces=True, skip_special_tokens=True)}") |
| | else: |
| | rprint(f"GT Captions: {wrapped_batch_decode(self.tokenizer, gt_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True)}") |
| | else: |
| | if gt_txt is not None: |
| | rprint(f"GT Captions: {wrapped_batch_decode(self.tokenizer, gt_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True)}") |
| |
|
| | if getattr(self.config.eval, "visualize_data_only", False): |
| | exit() |
| |
|
| | if split == "train": |
| | if hasattr(self, "vae"): |
| | del self.vae |
| | empty_device_cache() |
| |
|
| |
|
| | @try_except(write_error_to_file=True, clear_cuda_cache=True) |
| | def mauve_store_references(self, dataloader): |
| | total_batches = len(dataloader) |
| | sample_batch = next(iter(dataloader)) |
| | batch_size = sample_batch["input_ids"].shape[0] |
| | |
| | N = self.config.eval.mauve_num_samples |
| | if not is_main_process(): |
| | return |
| | if N is None or N <= 0 or batch_size * total_batches < N: |
| | rprint(f"[WARNING] Skipping Mauve reference storage. N: {N}, batch_size: {batch_size}, total_batches: {total_batches}") |
| | return |
| | |
| | |
| | num_batches = math.ceil(N / batch_size) |
| | |
| | for i, batch in tqdm(enumerate(dataloader), total=num_batches, desc="Mauve storing references"): |
| | if i >= num_batches: |
| | break |
| | reference_txt_tokens, _ = self.decode_batch(batch["input_ids"], text_only=False, sample_modality=batch.get("modality", None)) |
| | reference_txt = wrapped_batch_decode(self.tokenizer, reference_txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True) |
| | self.mauve_references.extend(reference_txt) |
| | |
| | assert len(self.mauve_references) >= N, f"len(self.mauve_references) ({len(self.mauve_references)}) < N ({N})" |
| | self.mauve_references = self.mauve_references[:N] |
| | save_path = os.path.join(self.config.output_dir, f'mauve_references_{N}.pkl') |
| | with open(save_path, 'wb') as f: |
| | pickle.dump(self.mauve_references, f) |
| | rprint(f"[MAUVE] Stored {N} references in {save_path}") |
| | |
| |
|
| | @try_except(write_error_to_file=True) |
| | def cleanup_fid_output(self): |
| | if getattr(self.config.eval, "force_fid_output_dir", None) is not None: |
| | return |
| | if hasattr(self, "fid_gen_dir"): |
| | fid_output_dir_path = Path(self.fid_gen_dir) |
| | if fid_output_dir_path.exists() and fid_output_dir_path.is_dir(): |
| | rprint(f"Removing fid output dir: {fid_output_dir_path}") |
| | shutil.rmtree(fid_output_dir_path) |
| |
|
| | if hasattr(self, "fid_gt_dir"): |
| | fid_gt_dir_path = Path(self.fid_gt_dir) |
| | if fid_gt_dir_path.exists() and fid_gt_dir_path.is_dir(): |
| | rprint(f"Removing fid gt dir: {fid_gt_dir_path}") |
| | shutil.rmtree(fid_gt_dir_path) |
| |
|
| | if hasattr(self, "img_to_txt_mauve_gen_dir"): |
| | img_to_txt_mauve_gen_dir_path = Path(self.img_to_txt_mauve_gen_dir) |
| | if img_to_txt_mauve_gen_dir_path.exists() and img_to_txt_mauve_gen_dir_path.is_dir(): |
| | rprint(f"Removing img to txt mauve gen dir: {img_to_txt_mauve_gen_dir_path}") |
| | shutil.rmtree(img_to_txt_mauve_gen_dir_path) |
| |
|
| | if hasattr(self, "img_to_txt_mauve_gt_dir"): |
| | img_to_txt_mauve_gt_dir_path = Path(self.img_to_txt_mauve_gt_dir) |
| | if img_to_txt_mauve_gt_dir_path.exists() and img_to_txt_mauve_gt_dir_path.is_dir(): |
| | rprint(f"Removing img to txt mauve gt dir: {img_to_txt_mauve_gt_dir_path}") |
| | shutil.rmtree(img_to_txt_mauve_gt_dir_path) |
| |
|
| | def compute_val_metrics_standalone(self): |
| | rprint("Computing validation metrics standalone") |
| | self.reset_validation_metrics() |
| | num_samples = 0 |
| | for i, batch in tqdm(enumerate(self.validation_dataloader), desc="Standalone validation steps", disable=not is_main_process(), leave=False): |
| | batch = self.update_batch(batch) |
| | num_samples += batch["input_ids"].shape[0] |
| | self.compute_loss(batch, prefix="val", batch_idx=i) |
| | if i >= self.config.eval.num_val_metrics_standalone_batches_per_device: |
| | break |
| |
|
| | log({**self.get_step_metrics(), "num_samples": num_samples * get_world_size()}) |
| | rprint(f"Finished computing validation metrics standalone.") |
| |
|
| |
|
| | def compute_val_metrics_constant_per_batch(self):
| |     rprint("Computing validation metrics (constant per batch)")
| | self.reset_validation_metrics() |
| | if self.config.eval.num_val_metrics_standalone_batches_per_device is None or self.config.eval.num_val_metrics_standalone_batches_per_device <= 0: |
| | return |
| | num_samples = 0 |
| | for i, batch in tqdm(enumerate(self.validation_dataloader), desc="Standalone validation steps", disable=not is_main_process(), leave=False): |
| | batch = self.update_batch(batch) |
| | num_samples += batch["input_ids"].shape[0] |
| | self.compute_loss(batch, prefix="val", batch_idx=i) |
| | if i >= self.config.eval.num_val_metrics_standalone_batches_per_device: |
| | break |
| |
|
| | log({**self.get_step_metrics(), "num_samples": num_samples * get_world_size()}) |
| | rprint(f"Finished computing validation metrics standalone.") |
| |
|
| | def get_interleaved_image(self, sample_data, sample_modalities, single_image_only=False, disable_img_save=False, image_save_postfix=None): |
| | all_sample_imgs = [] |
| | single_image_only = self.config.eval.auto_enhance or single_image_only or getattr(self.config.eval, "fake_interleaved", False) |
| | if getattr(self.config.eval, "disable_shm_save", False): |
| | disable_img_save = True |
| |
|
| | if not disable_img_save: |
| | date_folder = datetime.now().strftime("%Y-%m-%d") |
| | save_dir = Path("/dev/shm") / os.getenv("USER", 'user') / "imgs" / date_folder |
| | save_dir.mkdir(exist_ok=True, parents=True) |
| |
|
| | for j in range(len(sample_data)): |
| | if sample_modalities[j] == 0 and not single_image_only: |
| | sample_text = wrapped_batch_decode( |
| | self.tokenizer, |
| | sample_data[j][None], |
| | clean_up_tokenization_spaces=True, |
| | skip_special_tokens=False, |
| | disable_mask_after_eos=True |
| | ) |
| | txt_image = create_text_image(text=sample_text[0], desired_width=self.config.data.resolution) |
| | all_sample_imgs.append(txt_image) |
| | elif sample_modalities[j] == 1: |
| | sample_img = decode_latents(self.config, self.get_vae(), sample_data[j][None]) |
| | all_sample_imgs.append(sample_img) |
| |
|
| | if not disable_img_save: |
| | image_save_postfix = image_save_postfix or "" |
| | filename = f"img_{get_rank()}_{str(time.time()).replace('.', '__')}"[:100] + f"{image_save_postfix}.png" |
| | save_path = save_dir / filename |
| | if single_image_only: |
| | if not disable_img_save: |
| | gprint(Im(all_sample_imgs[0]).save(save_path)) |
| | assert len(all_sample_imgs) == 1, "Expected single image only" |
| | return all_sample_imgs[0] |
| | else: |
| | img = Im.concat_vertical(*all_sample_imgs).pil |
| | if not disable_img_save: |
| | gprint(Im(img).save(save_path)) |
| | return img |
| | |
| |
|
| | def get_hpsv2_score( |
| | self, |
| | images, |
| | prompts |
| | ): |
| | from unidisc.tokenizers.hpsv2_img_score import score, initialize_model |
| | if not hasattr(self, "hpsv2_model_dict"): |
| | self.hpsv2_model_dict = initialize_model(self.device, "v2.1") |
| |
|
| | if isinstance(images, Tensor): |
| | images = [Im(x).pil for x in images] |
| |
|
| | with torch.inference_mode(mode=False), torch.no_grad(): |
| | scores = [] |
| | for img, prompt in zip(images, prompts): |
| | scores.append(score(self.hpsv2_model_dict, img, prompt)[0].item()) |
| | return torch.tensor(scores) |
| |
|
| | def get_dfn_score( |
| | self, |
| | images, |
| | prompts |
| | ): |
| | if isinstance(images, Tensor): |
| | images = [Im(x).pil for x in images] |
| |
|
| | from open_clip import create_model_from_pretrained, get_tokenizer |
| |
|
| | if not hasattr(self, "dfn_model"): |
| | self.dfn_model, self.dfn_preprocess = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384') |
| | self.dfn_tokenizer = get_tokenizer('ViT-H-14') |
| | self.dfn_model.to(str(self.device)) |
| |
|
| | assert len(images) == len(prompts), "Expected same number of images and prompts" |
| | images = torch.stack([self.dfn_preprocess(x) for x in images]) |
| | text = self.dfn_tokenizer(prompts, context_length=self.dfn_model.context_length) |
| | dfn_dtype = next(iter(self.dfn_model.parameters())).dtype |
| |
|
| | with torch.no_grad(), torch.cuda.amp.autocast(): |
| | image_features = self.dfn_model.encode_image(images.to(device=self.device, dtype=dfn_dtype)) |
| | text_features = self.dfn_model.encode_text(text.to(device=self.device)) |
| | image_features = F.normalize(image_features, dim=-1) |
| | text_features = F.normalize(text_features, dim=-1) |
| | sim = (image_features * text_features).sum(dim=-1) |
| |
|
| | return sim |
| |
|
| |
|
| | def get_clip_score( |
| | self, |
| | images, |
| | prompts |
| | ): |
| |
|
| | if isinstance(images, Tensor): |
| | images = [Im(x).pil for x in images] |
| |
|
| | from transformers import ( |
| | CLIPTokenizer, |
| | CLIPTextModelWithProjection, |
| | CLIPVisionModelWithProjection, |
| | CLIPImageProcessor, |
| | ) |
| |
|
| | if not hasattr(self, "clip_tokenizer"): |
| | clip_id = "openai/clip-vit-large-patch14" |
| | self.clip_tokenizer = CLIPTokenizer.from_pretrained(clip_id) |
| | self.clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(self.device) |
| | self.clip_image_processor = CLIPImageProcessor.from_pretrained(clip_id) |
| | self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(self.device) |
| |
|
| | assert len(images) == len(prompts), "Expected same number of images and prompts" |
| |
|
| | with torch.no_grad(), torch.cuda.amp.autocast(): |
| | preprocessed_images = self.clip_image_processor(images, return_tensors="pt")["pixel_values"] |
| | image_features = self.clip_image_encoder(pixel_values=preprocessed_images.to(self.device)).image_embeds |
| | image_features = image_features / image_features.norm(dim=1, keepdim=True) |
| |
|
| | tokenized_text = self.clip_tokenizer( |
| | prompts, |
| | max_length=self.clip_tokenizer.model_max_length, |
| | padding="max_length", |
| | truncation=True, |
| | return_tensors="pt" |
| | ) |
| | text_features = self.clip_text_encoder(input_ids=tokenized_text.input_ids.to(self.device)).text_embeds |
| | text_features = text_features / text_features.norm(dim=1, keepdim=True) |
| |
|
| | sim = (image_features * text_features).sum(dim=-1) |
| |
|
| | return sim |
| |
|
| | def get_laion_aesthetic_score( |
| | self, |
| | images, |
| | prompts |
| | ): |
| | from unidisc.tokenizers.laion_aesthetic_v2 import get_predictor_func |
| | if not hasattr(self, "laion_aesthetic_model"): |
| | self.laion_aesthetic_model = get_predictor_func(self.device) |
| |
|
| | return torch.from_numpy(self.laion_aesthetic_model(images)).squeeze(-1) |
| |
|
| | def get_model_likelihood_score(self, batch, num_timesteps=100, return_unweighed=True): |
| | class_log_probs = [] |
| | unweighed_class_log_probs = [] |
| | effective_batch_size = batch['modality'].shape[0] |
| | empty_device_cache() |
| | times = torch.linspace(0, 1, steps=num_timesteps + 2)[1:-1].to(self.device).to(torch.float32) |
| | attention_mask = batch['attention_mask'] |
| | |
| | for i in range(num_timesteps): |
| | empty_device_cache() |
| | t = times[i] |
| | t = t.expand(effective_batch_size) |
| | sigma, dsigma = self.noise(t) |
| |
|
| | unet_conditioning = None |
| | move_chance = 1 - torch.exp(-sigma[:, None]) |
| |
|
| | x0 = batch['input_ids'] |
| | xt = self.q_xt(x0, move_chance) |
| |
|
| | model_output = self.forward( |
| | xt, unet_conditioning, return_additional_loss=True, batch=batch, modality=batch['modality'] |
| | ) |
| |
|
| | log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) |
| | log_p_theta = torch.where(attention_mask, log_p_theta, 0) |
| | std_weighting = (dsigma / torch.expm1(sigma))[:, None] |
| | unweighed_log_p_theta = -log_p_theta |
| | loss = -log_p_theta * std_weighting |
| | log_probs = loss.sum(dim=-1) / attention_mask.sum(dim=-1) |
| | unweighed_log_probs = unweighed_log_p_theta.sum(dim=-1) / attention_mask.sum(dim=-1) |
| | |
| | class_log_probs.append(log_probs) |
| | unweighed_class_log_probs.append(unweighed_log_probs) |
| |
|
| | overall_time_log_probs = torch.stack(class_log_probs) |
| | unweighed_overall_time_log_probs = torch.stack(unweighed_class_log_probs) |
| |
|
| | if return_unweighed: |
| | return unweighed_overall_time_log_probs.mean(dim=0) |
| | return overall_time_log_probs.mean(dim=0) |
| |
|
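| | # Numeric sketch of the loss weighting in `get_model_likelihood_score` above: each
| | # timestep's token NLL is scaled by dsigma / expm1(sigma), which blows up as
| | # sigma -> 0 (low noise) and decays at high noise. The sigma values and the
| | # stand-in dsigma below are illustrative only.
| | def _example_elbo_weighting():
| |     sigma = torch.tensor([0.01, 0.1, 1.0, 3.0])
| |     dsigma = torch.ones_like(sigma)                   # stand-in derivative d(sigma)/dt
| |     weight = dsigma / torch.expm1(sigma)
| |     assert (weight[:-1] > weight[1:]).all()           # monotonically decreasing in sigma
| |     return weight
| |
|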
| | def get_chameleon_score(self, images, prompts): |
| | return torch.tensor(self.calculate_chameleon_perplexity(None, None, prompts, images)) |
| |
|
| | def get_text_likelihood_score(self, images, prompts): |
| | return self.compute_generative_perplexity(prompts, return_raw_score=True) |
| |
|
| | @torch.inference_mode() |
| | def get_text_reward_model_score( |
| | self, |
| | images, |
| | prompts |
| | ): |
| | if not hasattr(self, "text_reward_model"): |
| | from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| | model_name = "Skywork/Skywork-Reward-Llama-3.1-8B" |
| | self.text_reward_model = AutoModelForSequenceClassification.from_pretrained( |
| | model_name, |
| | torch_dtype=torch.bfloat16, |
| | device_map=self.device, |
| | num_labels=1, |
| | ) |
| | self.text_reward_tokenizer = AutoTokenizer.from_pretrained(model_name) |
| |
|
| | prompt = "Please generate a realistic caption for a text-to-image generator. The caption should have proper grammar and describe a realistic scene that a user might ask for. The caption should not be non-sensical. The caption does not need to be elaborate, but should be descriptive and realistic. Penalize improper grammar and spelling." |
| |
|
| | batch_size = 4 |
| | formatted_conversations = [] |
| | for resp in prompts: |
| | conv = [{"role": "user", "content": prompt}, {"role": "assistant", "content": resp}] |
| | formatted = self.text_reward_tokenizer.apply_chat_template(conv, tokenize=False) |
| | formatted_conversations.append(formatted) |
| | |
| | all_scores = [] |
| | for i in range(0, len(formatted_conversations), batch_size): |
| | batch_texts = formatted_conversations[i : i + batch_size] |
| | batch_inputs = self.text_reward_tokenizer( |
| | batch_texts, return_tensors="pt", padding=True, truncation=True |
| | ).to(self.device) |
| | |
| | with torch.no_grad(): |
| | batch_logits = self.text_reward_model(**batch_inputs).logits.squeeze(-1) |
| |
|
| | all_scores.extend(batch_logits.cpu().tolist()) |
| |
|
| | return torch.tensor(all_scores).to(self.device) |
| |
|
| |
|
| | def get_rewards(self, reward_config, images, prompts, batch=None, return_raw_rewards=False): |
| | assert isinstance(images, Tensor) and isinstance(prompts, list), "Expected images to be a Tensor and prompts to be a list" |
| | assert images.ndim == 4 and 0 <= images.min() and images.max() <= 1, "Expected images to be in [0, 1]" |
| | assert len(prompts) == images.shape[0], "Expected same number of images and prompts" |
| | reward_name_to_fn = dict( |
| | dfn_score=self.get_dfn_score, |
| | clip_score=self.get_clip_score, |
| | hpsv2_score=self.get_hpsv2_score, |
| | laion_aesthetic_score=self.get_laion_aesthetic_score, |
| | model_likelihood_score=self.get_model_likelihood_score, |
| | chameleon_score=self.get_chameleon_score, |
| | text_likelihood_score=self.get_text_likelihood_score, |
| | text_reward_model_score=self.get_text_reward_model_score |
| | ) |
| |
|
| | rewards = [] |
| | raw_rewards = dict() |
| | for reward_name, reward_weight in reward_config.items(): |
| | start_time = time.time() |
| | assert reward_name in reward_name_to_fn, f"Invalid reward name: {reward_name}" |
| | reward_fn = reward_name_to_fn[reward_name] |
| | if reward_name == "model_likelihood_score" or reward_name == "chameleon_score" or reward_name == "text_likelihood_score": |
| | assert batch is not None, "Expected batch to be provided for model likelihood score" |
| | if reward_name == "chameleon_score" or reward_name == "text_likelihood_score": |
| | reward = reward_fn(images, prompts).cpu() |
| | else: |
| | reward = reward_fn(batch=batch).cpu() |
| | raw_rewards[reward_name] = reward |
| | rprint(f"Orig {reward_name}: {reward}") |
| | reward = -reward |
| | reward = (reward - reward.min()) / (reward.max() - reward.min()) |
| | rprint(f"Normalized {reward_name}: {reward}") |
| | else: |
| | reward = reward_fn(images, prompts).cpu() |
| | raw_rewards[reward_name] = reward |
| | |
| | reward = (reward - reward.min()) / (reward.max() - reward.min()) |
| |
|
| | reward = torch.nan_to_num(reward, nan=0.0) |
| | rewards.append(reward * reward_weight) |
| | print(f"Processed {reward_name} in {time.time() - start_time:.2f} seconds") |
| |
|
| | rewards = torch.stack(rewards, dim=-1).sum(dim=-1) |
| |
|
| | if return_raw_rewards: |
| | return rewards, raw_rewards |
| |
|
| | return rewards |
| |
|
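| | # Usage sketch for `get_rewards`: reward_config maps reward names to weights; each
| | # reward is min-max normalized to [0, 1] before the weighted sum. The config values
| | # and tensors below are hypothetical placeholders, not recommended settings.
| | #
| | #     reward_config = {"clip_score": 0.5, "laion_aesthetic_score": 0.5}
| | #     images = torch.rand(4, 3, 256, 256)      # (B, C, H, W) in [0, 1]
| | #     prompts = ["a cat", "a dog", "a house", "a tree"]
| | #     rewards = self.get_rewards(reward_config, images, prompts)   # shape (B,)
| |
|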
| | def clear_reward_models(self): |
| | if hasattr(self, "laion_aesthetic_model"): |
| | del self.laion_aesthetic_model |
| | if hasattr(self, "dfn_model"): |
| | del self.dfn_model |
| | if hasattr(self, "dfn_tokenizer"): |
| | del self.dfn_tokenizer |
| | if hasattr(self, "clip_tokenizer"): |
| | del self.clip_tokenizer |
| | if hasattr(self, "clip_text_encoder"): |
| | del self.clip_text_encoder |
| | if hasattr(self, "clip_image_processor"): |
| | del self.clip_image_processor |
| | if hasattr(self, "clip_image_encoder"): |
| | del self.clip_image_encoder |
| | if hasattr(self, "text_reward_model"): |
| | del self.text_reward_model |
| | if hasattr(self, "text_reward_tokenizer"): |
| | del self.text_reward_tokenizer |
| | if hasattr(self, "hpsv2_model_dict"): |
| | del self.hpsv2_model_dict |
| |
|
| | def auto_enhance(self, batch): |
| | gprint(f"Auto enhancing") |
| | from dataloader import tokenize_text |
| | assert isinstance(batch, TensorDict), "Expected batch to be a TensorDict" |
| | batch = batch.squeeze(1) |
| | assert batch['input_ids'].ndim == 2, "Expected batch to be 2D" |
| |
|
| | |
| | |
| | |
| |
|
| | x0 = batch["input_ids"].clone() |
| | add_object = getattr(self.config.eval, "auto_enhance_add_object", False) |
| | if add_object: |
| | img_tokens = x0[:, self.static_img_sl] - self.text_vocab_size |
| | assert 0 <= img_tokens.min() and img_tokens.max() <= self.image_vocab_size, "Expected img tokens to be in [0, img_vocab_size]" |
| | orig_imgs = decode_latents(self.config, self.get_vae(), img_tokens) |
| | orig_imgs = [Im(img).pil for img in orig_imgs] |
| | aug_imgs = [augment_image_with_random_object_coco(img, str(UNIDISC_DIR / "archive" / "objects")) for img in orig_imgs] |
| | gprint(f"Augmented {len(aug_imgs)} images") |
| | aug_imgs = torch.stack([Im(img).torch for img in aug_imgs]).to(self.device) |
| | image_ids = get_image_batch(self.config, self.get_vae(), {"img": aug_imgs}, self.device) |
| | x0[:, self.static_img_sl] = image_ids + self.text_vocab_size |
| |
|
| | gen_batch = batch.clone() |
| | if 'interleaved_metadata' in gen_batch: |
| | del gen_batch['interleaved_metadata'] |
| | gen_batch.auto_batch_size_() |
| |
|
| | orig_caption = wrapped_batch_decode(self.tokenizer, batch['input_ids'][:, self.static_txt_sl], clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=True) |
| |
|
| | max_num_augmentations = getattr(self.config.eval, "max_num_auto_enhance_augmentations", 10) |
| |
|
| | llm_func = get_llm(llm_model_type="") |
| | llm_augmented_captions = [llm_func(cap, fake_openai_failure=False)[0] for cap in orig_caption] |
| | _augmented_captions = [] |
| | for caps in llm_augmented_captions: |
| | _shuf = deepcopy(caps) |
| | random.shuffle(_shuf) |
| | assert len(_shuf) >= max_num_augmentations, "Expected at least max_num_augmentations augmentations" |
| | _augmented_captions.append(_shuf[:max_num_augmentations]) |
| | gprint(f"Augmented {len(_augmented_captions)} captions") |
| |
|
| | _orig_imgs = Im(decode_latents(self.config, self.get_vae(), x0[:, self.static_img_sl] - self.text_vocab_size)).pil |
| | if not isinstance(_orig_imgs, list): |
| | _orig_imgs = [_orig_imgs] |
| |
|
| | num_iter_per_sample = self.config.eval.num_auto_enhance_iter |
| | num_iter = num_iter_per_sample * max_num_augmentations |
| | bs = 1 |
| | n = num_iter * bs * len(_augmented_captions) |
| | _gen_batch = [] |
| | for i in range(len(_augmented_captions)): |
| | for j in range(num_iter): |
| | _gen_batch.append(gen_batch[[i]]) |
| | gen_batch = torch.cat(_gen_batch, dim=0) |
| |
|
| | txt_data = [tokenize_text(self.tokenizer, self.config.data.block_size, caps) for caps in _augmented_captions] |
| | txt_sl = slice(None, self.config.data.block_size) |
| | real_captions = [] |
| | augmented_captions = [] |
| | orig_images = [] |
| |
|
| | gprint(f"Generating {num_iter} samples, gen_batch shape: {gen_batch.shape}") |
| |
|
| | for j in range(len(_augmented_captions)):
| | for k in range(max_num_augmentations):
| | # Each original sample occupies `num_iter` consecutive rows (see the layout sketch
| | # below), so the per-augmentation slice must stride by `num_iter`, not
| | # `max_num_augmentations` as the original code did.
| | sl = slice(j * num_iter + k * num_iter_per_sample, j * num_iter + (k + 1) * num_iter_per_sample)
| | gen_batch[sl]['input_ids'][:, txt_sl] = txt_data[j]['input_ids'][k]
| | gen_batch[sl]['attention_mask'][:, txt_sl] = txt_data[j]['attention_mask'][k]
| | augmented_captions.extend([_augmented_captions[j][k]] * num_iter_per_sample)
| | real_captions.extend([orig_caption[j]] * num_iter_per_sample)
| | orig_images.extend([_orig_imgs[j]] * num_iter_per_sample)
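| |
|
| | # Layout sketch: `gen_batch` holds `num_iter = max_num_augmentations * num_iter_per_sample`
| | # consecutive rows per original sample, so sample j / augmentation k occupies rows
| | #   [j * num_iter + k * num_iter_per_sample, j * num_iter + (k + 1) * num_iter_per_sample).
| | # E.g. with max_num_augmentations=10 and num_iter_per_sample=2, sample 1 / augmentation 3
| | # covers rows 26-27.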
| |
|
| |
|
| | if getattr(self.config.eval, "auto_enhance_use_low_masking", False): |
| | mean_txt, std_txt = 0.85, 0.2 / 0.8416 |
| | mean_img, std_img = 0.75, 0.04 / 1.645 |
| | else: |
| | mean_txt, std_txt = 0.85, 0.2 / 0.8416 |
| | mean_img, std_img = 0.95, 0.04 / 1.645 |
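| |
|
| | # Where the divisors come from (an assumption based on standard-normal quantiles):
| | # z_{0.80} ≈ 0.8416 and z_{0.95} ≈ 1.645, so std_txt = 0.2 / 0.8416 puts roughly 80% of
| | # sampled text masking rates below mean_txt + 0.2, and std_img = 0.04 / 1.645 puts
| | # roughly 95% of image masking rates below mean_img + 0.04 (before clamping at 1.0).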
| |
|
| | def slice_len(_sl, _seq_len):
| | """Return the number of positions a basic slice covers in a length-`_seq_len` sequence."""
| | assert _sl.step is None
| | if _sl.start is not None and _sl.start < 0:
| | # A negative start with an open end (e.g. slice(-128, None)) covers the last -start positions.
| | assert _sl.stop is None
| | return -_sl.start
| | else:
| | return (_sl.stop if _sl.stop is not None else _seq_len) - (_sl.start if _sl.start is not None else 0)
| | |
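| | # Worked examples (added for clarity): slice_len(slice(-128, None), 1024) == 128;
| | # slice_len(slice(None, 256), 1024) == 256; slice_len(slice(None, None), 1024) == 1024.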
| | seq_len = x0.shape[1] |
| |
|
| | # Draw a per-sample masking rate for each modality, clamped to at most 1.0.
| | t_txt = torch.normal(mean=mean_txt, std=std_txt, size=(n,), device=self.device)
| | t_img = torch.normal(mean=mean_img, std=std_img, size=(n,), device=self.device)
| | t_txt = torch.clamp(t_txt, max=1.0)
| | t_img = torch.clamp(t_img, max=1.0)
| |
| | # Bernoulli-mask each position independently at its sample's rate.
| | move_indices = torch.zeros(n, seq_len, device=self.device, dtype=torch.bool)
| | move_indices[:, self.static_txt_sl] = torch.rand(move_indices.shape[0], slice_len(self.static_txt_sl, seq_len), device=self.device) < t_txt.unsqueeze(1)
| | move_indices[:, self.static_img_sl] = torch.rand(move_indices.shape[0], slice_len(self.static_img_sl, seq_len), device=self.device) < t_img.unsqueeze(1)
| | |
| | x0_unmask = ~move_indices |
| | rprint(f"Text masking ratio: {move_indices[:, self.static_txt_sl].sum() / move_indices[:, self.static_txt_sl].numel():.3f}") |
| | rprint(f"Image masking ratio: {move_indices[:, self.static_img_sl].sum() / move_indices[:, self.static_img_sl].numel():.3f}") |
| | rprint(f"Num unmasked: {x0_unmask.sum(dim=-1).float().mean():.1f}") |
| |
|
| | text_samples_list = [] |
| | img_samples_list = [] |
| |
|
| | x0 = x0.to(self.device) |
| | x0_unmask = x0_unmask.to(self.device) |
| | |
| | idx = 0 |
| | for i in range(len(_augmented_captions)): |
| | # `gen_batch` holds `num_iter` rows per original sample (all augmentations times
| | # `num_iter_per_sample`), so iterate over all of them, not just `num_iter_per_sample`.
| | for j in range(num_iter):
| | _modality = gen_batch[[idx]].get("modality", None) |
| | _sample_ids = gen_batch[[idx]].get("sample_ids", None) |
| | if _modality is not None: |
| | _modality = _modality.to(self.device) |
| | if _sample_ids is not None: |
| | _sample_ids = _sample_ids.to(self.device) |
| | else: |
| | _sample_ids = torch.zeros_like(_modality) |
| | text_samples, img_samples, x = self._sample( |
| | text_only=False, |
| | num_steps=self.config.sampling.max_sampling_steps, |
| | batch_size_per_gpu=bs, |
| | modality=_modality, |
| | sample_ids=_sample_ids, |
| | x0=gen_batch["input_ids"][[idx]].to(self.device), |
| | x0_unmask=x0_unmask[[idx]].to(self.device), |
| | return_raw_data=True, |
| | allow_interleaved_conditional=True |
| | ) |
| | gen_batch[[idx]]['input_ids'] = x |
| | text_samples_list.extend(text_samples) |
| | img_samples_list.extend(img_samples) |
| | rprint(f"Sampled {j + 1} / {num_iter}") |
| | idx += 1 |
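| |
|
| | # Hedged note on the `_sample` call above: `x0` provides the token grid and `x0_unmask`
| | # marks positions to hold fixed during generation, so this behaves like inpainting:
| | # unmasked text/image tokens condition the model while masked positions are resampled.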
| |
|
| |
|
| | text_samples_list = wrapped_batch_decode( |
| | self.tokenizer, |
| | torch.stack(text_samples_list, dim=0), |
| | clean_up_tokenization_spaces=True, |
| | skip_special_tokens=True, |
| | disable_mask_after_eos=True |
| | ) |
| |
|
| | img_samples_list = torch.cat(img_samples_list, dim=0) |
| |
|
| | reward_config = self.config.eval.auto_enhance_reward_config |
| | rewards, raw_rewards = self.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True) |
| |
|
| | gprint(f"Avg Rewards: {rewards}") |
| |
|
| | sorted_indices = torch.argsort(rewards, descending=True).tolist() |
| | sorted_text_samples = [text_samples_list[i] for i in sorted_indices] |
| | sorted_augmented_captions = [augmented_captions[i] for i in sorted_indices] |
| | sorted_real_captions = [real_captions[i] for i in sorted_indices] |
| | sorted_img_samples = [img_samples_list[i] for i in sorted_indices] |
| | sorted_orig_images = [orig_images[i] for i in sorted_indices] |
| | sorted_avg_rewards = [rewards[i] for i in sorted_indices] |
| | sorted_raw_rewards = {k: [raw_rewards[k][i] for i in sorted_indices] for k in raw_rewards} |
| |
|
| | text_samples_list = sorted_text_samples |
| | real_captions = sorted_real_captions |
| | augmented_captions = sorted_augmented_captions |
| | img_samples_list = sorted_img_samples |
| | orig_images = sorted_orig_images |
| | raw_rewards = sorted_raw_rewards |
| |
|
| | # Free reward models before the gather/logging phase to reduce memory pressure.
| | self.clear_reward_models()
| |
|
| | log_dict = {} |
| | with try_except(write_error_to_file=True): |
| | if text_samples_list is not None: |
| | gprint(f"Gathering {len(text_samples_list)} text samples") |
| | text_samples_list = gather_object(text_samples_list) |
| |
|
| | real_captions = gather_object(real_captions) |
| | augmented_captions = gather_object(augmented_captions) |
| | prefix = "auto_enhance" |
| |
|
| | if isinstance(img_samples_list, Tensor): img_samples_list = img_samples_list.float().cpu() |
| | img_samples_list = [Im(img).pil for img in img_samples_list] |
| | img_samples_list = gather_object(img_samples_list) |
| | orig_images = gather_object(orig_images) |
| |
|
| | dprint(f"Gathered {len(text_samples_list)} text samples") |
| |
|
| | sorted_avg_rewards = gather_object(sorted_avg_rewards)
| |
|
| | raw_rewards = {k: gather_object(v) for k, v in raw_rewards.items()}
| | rprint(f"Finished gathering, length: {len(orig_images)}") |
| |
|
| | gen_table = wandb.Table(columns=["real_caption", "original_image", "augmented_caption", "sampled_caption", "sampled_image", "avg_reward", *reward_config.keys()])
| | assert len(img_samples_list) == len(text_samples_list) == len(augmented_captions) == len(real_captions) == len(sorted_avg_rewards) |
| | for real_caption, orig_img, augmented_caption, sampled_caption, sampled_img, avg_reward, *rewards in zip(real_captions, orig_images, augmented_captions, text_samples_list, img_samples_list, sorted_avg_rewards, *raw_rewards.values()): |
| | gen_table.add_data(real_caption, wandb.Image(Im(orig_img).pil), augmented_caption, sampled_caption, wandb.Image(Im(sampled_img).pil), avg_reward, *rewards) |
| |
|
| | log_dict[f"{prefix}_sample_table"] = gen_table |
| |
|
| | log({**log_dict, **self.get_step_metrics()}) |
| |
|
| | def save_image_text_pair(self, image_tensor, text_tensor, single_image_only=False, disable_img_save=False, image_save_postfix=None): |
| | """ |
| | Take separate image and text tensors and save them as paired visualizations. |
| | |
| | Args: |
| | image_tensor: Tensor [B, N] of image tokens |
| | text_tensor: Tensor [B, M] of text tokens |
| | single_image_only: If True, only return the image without text visualization |
| | disable_img_save: If True, don't save to disk |
| | image_save_postfix: Optional postfix for the saved image filename |
| | |
| | Returns: |
| | PIL Image of the paired visualization if batch_size == 1, otherwise a list of PIL Images
| | """ |
| | batch_size = image_tensor.shape[0] |
| | assert batch_size == text_tensor.shape[0], "Batch sizes must match between image and text tensors" |
| | |
| | all_paired_imgs = [] |
| |
| | # Config flags may force image-only output or disable saving to /dev/shm.
| | if hasattr(self, 'config') and hasattr(self.config, 'eval'):
| | single_image_only = getattr(self.config.eval, "auto_enhance", False) or single_image_only or getattr(self.config.eval, "fake_interleaved", False)
| | disable_img_save = disable_img_save or getattr(self.config.eval, "disable_shm_save", False)
| |
|
| | # Date-stamped save directory under /dev/shm for quick inspection.
| | if not disable_img_save: |
| | date_folder = datetime.now().strftime("%Y-%m-%d") |
| | save_dir = Path("/dev/shm") / os.getenv("USER", 'user') / "paired_imgs" / date_folder |
| | save_dir.mkdir(exist_ok=True, parents=True) |
| |
|
| | for i in range(batch_size): |
| | pair_imgs = [] |
| |
| | # Optionally render the decoded text as an image so it can be stacked with the sample.
| | if not single_image_only: |
| | sample_text = wrapped_batch_decode( |
| | self.tokenizer, |
| | text_tensor[i:i+1], |
| | clean_up_tokenization_spaces=True, |
| | skip_special_tokens=False, |
| | disable_mask_after_eos=True |
| | ) |
| | txt_image = create_text_image(text=sample_text[0], desired_width=self.config.data.resolution) |
| | pair_imgs.append(txt_image) |
| |
| | # Decode this sample's image tokens back to pixels.
| | img_tokens = image_tensor[i:i+1] |
| | sample_img = decode_latents(self.config, self.get_vae(), img_tokens) |
| | pair_imgs.append(sample_img) |
| |
| | # Keep just the image, or stack the text visualization above it.
| | if single_image_only: |
| | all_paired_imgs.append(pair_imgs[0]) |
| | else: |
| | paired_img = Im.concat_vertical(*pair_imgs).pil |
| | all_paired_imgs.append(paired_img) |
| |
|
| | # Save each visualization with a rank- and timestamp-based filename.
| | if not disable_img_save: |
| | image_save_postfix = image_save_postfix or "" |
| | for i, img in enumerate(all_paired_imgs): |
| | filename = f"pair_{get_rank()}_{i}_{str(time.time()).replace('.', '__')}"[:100] + f"{image_save_postfix}.png" |
| | save_path = save_dir / filename |
| | gprint(Im(img).save(save_path)) |
| |
| | # Unwrap the singleton case for convenience.
| | if batch_size == 1: |
| | return all_paired_imgs[0] |
| | else: |
| | return all_paired_imgs |
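| |
|
| | # Hedged usage sketch (slice names `img_sl`/`txt_sl` are placeholders, shapes follow the
| | # docstring above):
| | #   img_tokens = x[:, img_sl] - self.text_vocab_size   # [B, N] VQ ids for the decoder
| | #   txt_tokens = x[:, txt_sl]                          # [B, M] tokenizer ids
| | #   pair = self.save_image_text_pair(img_tokens, txt_tokens, image_save_postfix="_demo")
| | # With B == 1 this returns a single PIL image; otherwise a list of PIL images.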