import math
import os
import random
import typing
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from types import FrameType
from typing import Dict, List, Optional, Tuple, Union

import einops
import hydra
import hydra.utils
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import torchmetrics
import transformers
from image_utils import Im
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

import models
import wandb

from decoupled_utils import (Profiler, barrier, dprint, get_rank,
                             get_slurm_job_id, get_world_size, gprint,
                             is_local_main_process, is_main_process,
                             is_torch_cuda_available, is_torch_xla_available,
                             module_hash, mprint, parameter_hash, print_memory,
                             rank_zero_fn, rprint, save_memory_profile,
                             show_memory_usage, try_except, use_dist)

is_xla_available = is_torch_xla_available()
if is_xla_available:
    from unidisc.utils.standalone_metrics import MeanMetric, MetricCollection
else:
    from torchmetrics import MetricCollection
    from torchmetrics.aggregation import MeanMetric

LOG2 = math.log(2)

@try_except(write_error_to_file=True)
def log(*arg, **kwargs):
    # Detach tensors and move them to CPU so wandb logging does not hold GPU references.
    for key, value in arg[0].items():
        if isinstance(value, torch.Tensor):
            arg[0][key] = value.detach().cpu().float()

    if is_main_process():
        wandb.log(*arg, **kwargs)


def replace_nan_dict(x):
    return {k: v.nan_to_num(0) for k, v in x.items()}

def ddprint(*args, **kwargs):
    mprint(*args, **kwargs)


def empty_device_cache():
    if is_torch_cuda_available():
        torch.cuda.empty_cache()
    else:
        dprint("Not using CUDA, skipping cache clear")


def update_logs(_logs, _extra_logs):
    _logs.update(_extra_logs())
    for k, v in _logs.items():
        if isinstance(v, torch.Tensor):
            _logs[k] = v.detach().cpu().item()
            gprint(f"Converting {k} to item: {v}")

    log(_logs)

def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
    param_dict_src = dict(model_src.named_parameters())
    for p_name, p_dest in model_dest.named_parameters():
        if p_name not in param_dict_src:
            # Skip parameters missing from the source model rather than raising a KeyError.
            print(f"Parameter {p_name} not found in src: {list(param_dict_src)}")
            continue
        p_src = param_dict_src[p_name]
        assert p_src is not p_dest
        p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)

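# Usage sketch (illustrative, not from the original source): `ema_update` nudges an
# EMA copy of the weights toward the online model; `rate` close to 1 makes the EMA slow.
def _demo_ema_update():
    src = nn.Linear(4, 4)
    ema = nn.Linear(4, 4)
    ema.load_state_dict(src.state_dict())
    ema_update(ema, src, rate=0.999)  # ema = 0.999 * ema + 0.001 * src
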

def identity(x):
    return x

def remap_image_torch(image):
    # Map a float image in [0, 1] to uint8 in [0, 255].
    image_torch = image * 255
    image_torch = torch.clip(image_torch, 0, 255).to(torch.uint8)
    return image_torch

def _sample_categorical(categorical_probs):
    # Gumbel-max trick: dividing the probabilities by Exp(1) noise and taking the
    # argmax draws an index distributed as Categorical(categorical_probs) along the last dim.
    gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
    return (categorical_probs / gumbel_norm).argmax(dim=-1)

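# Quick demo (illustrative): each row of `probs` yields one sampled index.
def _demo_sample_categorical():
    probs = torch.softmax(torch.randn(2, 5), dim=-1)
    idx = _sample_categorical(probs)  # shape (2,), dtype long
    assert idx.shape == (2,)
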

def wrapped_batch_decode(tokenizer, tokens, disable_mask_after_eos=False, **kwargs):
    tokens = tokens.clone()
    if (tokenizer.bos_token_id != tokenizer.eos_token_id) and not disable_mask_after_eos:
        # Replace everything strictly after the first EOS token with padding.
        after_first_eos = torch.cumsum(tokens == tokenizer.eos_token_id, dim=1).bool()
        tokens[after_first_eos.cumsum(dim=1) > 1] = tokenizer.pad_token_id
    return tokenizer.batch_decode(tokens, **kwargs)

def _unsqueeze(x, reference):
    # Append singleton dims to `x` until it has the same rank as `reference`.
    return x.view(*x.shape, *((1,) * (len(reference.shape) - len(x.shape))))

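# Illustrative example: a (B,) tensor broadcast against a (B, L, V) reference
# becomes (B, 1, 1), so elementwise ops broadcast cleanly.
def _demo_unsqueeze():
    sigma = torch.ones(8)
    logits = torch.zeros(8, 16, 32)
    assert _unsqueeze(sigma, logits).shape == (8, 1, 1)
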

@dataclass
class Loss:
    loss: torch.FloatTensor
    img_loss: Optional[torch.FloatTensor] = None
    txt_loss: Optional[torch.FloatTensor] = None
    nlls: Optional[torch.FloatTensor] = None
    token_mask: Optional[torch.FloatTensor] = None
    txt_nlls: Optional[torch.FloatTensor] = None
    img_nlls: Optional[torch.FloatTensor] = None
    extra_losses: Optional[dict] = None
    modality_mask: Optional[torch.FloatTensor] = None


class NLL(MeanMetric):
    pass


class BPD(NLL):
    def compute(self) -> Tensor:
        """Computes the bits per dimension.

        Returns:
            bpd
        """
        return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
    def compute(self) -> Tensor:
        """Computes the perplexity.

        Returns:
            Perplexity
        """
        return torch.exp(self.mean_value / self.weight)

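
# Relationship between the two metrics above (sketch): both accumulate summed NLL in
# `mean_value` and token counts in `weight`, so perplexity == exp(bpd * ln 2).
def _demo_nll_metrics():
    ppl, bpd = Perplexity(), BPD()
    nll = torch.tensor([2.0, 4.0])
    ppl.update(nll)
    bpd.update(nll)
    assert torch.isclose(ppl.compute(), torch.exp(bpd.compute() * LOG2))
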

class Entropy(NLL):
    def compute(self) -> Tensor:
        """Computes the entropy.

        Returns:
            Entropy
        """
        return self.mean_value / self.weight


class MauveScore(NLL):
    def compute(self) -> Tensor:
        """Computes the MAUVE score.

        Returns:
            MAUVE score
        """
        return self.mean_value / self.weight


class CIDErScore(NLL):
    def compute(self) -> Tensor:
        """Computes the CIDEr score.

        Returns:
            CIDEr score
        """
        return self.mean_value / self.weight


class Accuracy(NLL):
    def compute(self) -> Tensor:
        """Computes the accuracy.

        Returns:
            Accuracy
        """
        return self.mean_value / self.weight


def get_coord_plot(self):
    from mup.coord_check import get_coord_data, plot_coord_data

    def gen(w):
        def f():
            from copy import deepcopy

            from omegaconf import read_write

            import models as _models
            _conf = deepcopy(self.config)
            with read_write(_conf):
                _conf.model.hidden_size = _conf.model.n_heads * w

            _backbone = _models.dit.DIT(
                _conf, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
            )
            self.get_base_shapes_for_mup(_backbone)
            return _backbone
        return f

    optimizer = 'adamw'
    widths = np.array([2**i for i in range(2, 6)])
    # Keyed by hidden size; renamed from `models` to avoid shadowing the `models` module.
    mup_models = {int(w) * self.config.model.n_heads: gen(int(w)) for w in widths}

    fake_dataloader = []
    self.validation_dataloader.num_workers = 0
    nsteps = 30
    for i, dataloader_batch in enumerate(self.validation_dataloader):
        fake_batch = self.update_batch(dataloader_batch)
        fake_batch['x0'] = fake_batch["input_ids"]
        t = self._sample_t(fake_batch['x0'].shape[0], fake_batch['x0'].device)
        sigma, dsigma = self.noise(t)
        move_chance = 1 - torch.exp(-sigma[:, None])
        xt = self.q_xt(fake_batch['x0'], move_chance)
        fake_batch['xt'] = xt

        fake_dataloader.append(fake_batch)
        if i >= nsteps:
            break

    def loss_fn(_batch, _logits):
        attention_mask = _batch['attention_mask']
        model_output = self._subs_parameterization(logits=_logits, xt=_batch['xt'])
        log_p_theta = torch.gather(input=model_output, dim=-1, index=_batch['x0'][:, :, None]).squeeze(-1)
        std_weighting = (dsigma / torch.expm1(sigma))[:, None]
        loss = -log_p_theta * std_weighting
        loss = (loss * attention_mask).sum() / attention_mask.sum()
        return loss

    mup = True
    lr = 1e-2
    prm = 'μP' if mup else 'SP'
    nseeds = 2
    with torch.autocast(device_type=self.device.type, dtype=self.dtype):
        df = get_coord_data(
            mup_models,
            fake_dataloader,
            lr=lr,
            optimizer=optimizer,
            nsteps=nsteps,
            nseeds=nseeds,
            dict_in_out=True,
            lossfn=loss_fn,
            mup=mup,
        )

    output_path = Path(__file__).parent / 'output' / f'{prm.lower()}_trsfmr_{optimizer}_coord.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plot_coord_data(
        df,
        legend='brief',
        save_to=str(output_path.resolve()),
        suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}',
        face_color='xkcd:light grey' if not mup else None,
        loglog=True,
    )
    rprint(f"Saved coord plot to {output_path.resolve()}")

    csv_path = output_path.with_suffix('.csv')
    df.to_csv(csv_path, index=False)
    rprint(f"DataFrame saved as CSV to {csv_path.resolve()}")

    result = df[df['t'] == 1].nsmallest(100, 'l1').sort_values('l1', ascending=True)
    with pd.option_context('display.max_rows', None, 'display.max_columns', None):
        print(result[['module', 'width', 'l1']])
    # One-off diagnostic: stop after writing the plot and CSV.
    exit()


def _score_entropy(self, log_score, sigma, xt, x0):
    """Computes the SEDD loss.

    Args:
        log_score: float torch.Tensor with shape (batch_size,
            diffusion_model_input_length, vocab_size),
            log score, output of the denoising network.
        xt: int torch.Tensor with shape (batch_size,
            diffusion_model_input_length), input.
        x0: int torch.Tensor with shape (batch_size,
            diffusion_model_input_length), input.
        sigma: float torch.Tensor with shape (batch_size, 1).

    Returns:
        loss with shape (batch_size, diffusion_model_input_length)
    """
    masked_indices = xt == self.mask_index

    expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
    q_ratio = 1 / expsig_minus_1[masked_indices]

    words_that_were_masked = x0[masked_indices]

    neg_term = q_ratio * torch.gather(log_score[masked_indices], -1, words_that_were_masked[..., None]).squeeze(-1)
    score = log_score[masked_indices].exp()
    if self.mask_index == self.vocab_size - 1:
        pos_term = score[:, :-1].sum(dim=-1)
    else:
        pos_term = score[:, : self.mask_index].sum(dim=-1) + score[:, self.mask_index + 1 :].sum(dim=-1)
    const = q_ratio * (q_ratio.log() - 1)

    entropy = torch.zeros(*xt.shape, device=xt.device)
    entropy[masked_indices] += pos_term - neg_term + const
    return entropy


@torch.no_grad()
def sample_subs_guidance(self, n_samples, stride_length, num_strides, dt=0.001):
    ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)

    num_steps = int(1 / dt)
    sampling_steps = 0
    intermediate_tokens = []
    target = None
    for _ in range(num_strides + 1):
        p_x0_cache = None
        x = self._sample_prior(n_samples, self.config.model.length).to(self.device)
        if target is not None:
            x[:, :-stride_length] = target
        for i in range(num_steps + 1):
            p_x0_cache, x_next, nfe_cnt = self._ddpm_caching_update(x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
            if not torch.allclose(x_next, x) or self.time_conditioning:
                # Invalidate the cache as soon as the sequence changes (or when the
                # model is explicitly conditioned on time).
                p_x0_cache = None
                sampling_steps += 1
            x = x_next
        x = self.forward(x, 0 * ones).argmax(dim=-1)
        intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
        target = x[:, stride_length:]

    intermediate_tokens.append(target.cpu().numpy())
    intermediate_text_samples = []
    sequence_lengths = ((np.concatenate(intermediate_tokens, axis=1)[:, 1:] == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
    for i in range(2, len(intermediate_tokens) + 1):
        intermediate_text_samples.append(self.tokenizer.batch_decode(np.concatenate(intermediate_tokens[:i], axis=1)))
    return (sampling_steps, intermediate_text_samples, sequence_lengths)


def restore_model_and_semi_ar_sample(self, stride_length, num_strides, dt=0.001):
    """Generate samples from the model."""
    if self.ema:
        self.ema.store(self.get_params())
        self.ema.copy_to(self.get_params())
    self.backbone.eval()
    (sampling_steps, samples, sequence_lengths) = self.sample_subs_guidance(
        n_samples=self.config.loader.eval_batch_size, stride_length=stride_length, num_strides=num_strides, dt=dt
    )
    if self.ema:
        self.ema.restore(self.get_params())
    self.backbone.train()
    self.noise.train()
    return sampling_steps, samples, sequence_lengths


def _reconstruction_loss(self, x0):
    t0 = torch.zeros(x0.shape[0], dtype=self.dtype, device=self.device)
    assert self.config.noise.type == "loglinear"

    unet_conditioning = self.noise(t0)[0][:, None]
    model_output_t0 = self.forward(x0, unet_conditioning)
    return -torch.gather(input=model_output_t0, dim=-1, index=x0[:, :, None]).squeeze(-1)


def restore_model_and_sample(self, num_steps, eps=1e-5):
    """Generate samples from the model."""
    if self.ema is not None:
        self.ema.store(self.get_params())
        self.ema.copy_to(self.get_params())
    self.backbone.eval()
    samples = self._sample(num_steps=num_steps, eps=eps)
    if self.ema is not None:
        self.ema.restore(self.get_params())
    self.backbone.train()
    return samples


def get_score(self, x, sigma, **kwargs):
    model_output = self.forward(x, sigma, **kwargs)
    if self.parameterization == "subs":
        # score(x, t) = p_t(y) / p_t(x), so log score = log p_t(y) - log p_t(x).
        # For masked positions, log score(y) = log p_theta(y) + log k with
        # k = exp(-sigma) / (1 - exp(-sigma)); the score of staying masked is 0.
        # For unmasked positions, the score is 1 at the current token, k at the
        # mask token, and 0 elsewhere.
        log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
        assert log_k.ndim == 1

        masked_score = model_output + log_k[:, None, None]
        masked_score[:, :, self.mask_index] = 0

        unmasked_score = self.neg_infinity * torch.ones_like(model_output)
        unmasked_score = torch.scatter(unmasked_score, -1, x[..., None], torch.zeros_like(unmasked_score[..., :1]))
        unmasked_score[:, :, self.mask_index] = -(log_k[:, None] * torch.ones_like(x))

        masked_indices = (x == self.mask_index).to(model_output.dtype)[:, :, None]
        model_output = masked_score * masked_indices + unmasked_score * (1 - masked_indices)
    return model_output.exp()


def _staggered_score(self, score, dsigma):
    score = score.clone()
    extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
    score *= dsigma.exp()[:, None]
    score[..., self.mask_index] += extra_const
    return score


def _analytic_update(self, x, t, step_size):
    curr_sigma, _ = self.noise(t)
    next_sigma, _ = self.noise(t - step_size)
    dsigma = curr_sigma - next_sigma
    nfe_cnt = 0
    score = self.get_score(x, curr_sigma)
    nfe_cnt += 1
    stag_score = self._staggered_score(score, dsigma)
    probs = stag_score * self._transp_transition(x, dsigma)
    return _sample_categorical(probs), nfe_cnt


def _denoiser_update(self, x, t):
    sigma, _ = self.noise(t)
    score = self.get_score(x, sigma)
    stag_score = self._staggered_score(score, sigma)
    probs = stag_score * self._transp_transition(x, sigma)
    probs[..., self.mask_index] = 0
    samples = _sample_categorical(probs)
    return samples


def _transp_transition(self, i, sigma):
    sigma = _unsqueeze(sigma, reference=i[..., None])
    edge = torch.exp(-sigma) * F.one_hot(i, num_classes=self.vocab_size)
    edge += torch.where(i == self.mask_index, 1 - torch.exp(-sigma).squeeze(-1), 0)[..., None]
    return edge


@torch.no_grad()
def eval_retokenize(self, text_samples, max_length):
    """Retokenizes samples for the eval model.

    Args:
        text_samples: List of sentences generated by the model.
    Returns:
        samples: Samples re-tokenized for the eval model
        attn_mask: Attention mask for the eval model
        eval_context_size: Size of the context for the eval model
    """
    # Note: the text is passed positionally to the tokenizer below, so it must not
    # also appear inside tokenizer_kwargs.
    if "llama2" in self.gen_ppl_eval_model_name_or_path:
        tokenizer_kwargs = {
            "return_tensors": "pt",
            "return_token_type_ids": False,
            "return_attention_mask": True,
            "truncation": True,
            "padding": True,
            "max_length": max_length,
        }
        eval_context_size = 4096
    else:
        tokenizer_kwargs = {
            "return_tensors": "pt",
            "return_token_type_ids": False,
            "return_attention_mask": True,
            "truncation": True,
            "padding": True,
            "max_length": max_length,
        }
        eval_context_size = 1024

    if getattr(self.config.eval, "force_eval_context_size_match_model", False):
        eval_context_size = self.config.model.txt_length

    samples = self.eval_model_tokenizer(text_samples, **tokenizer_kwargs)
    attn_mask = samples["attention_mask"]
    samples = samples["input_ids"]
    if "llama2" not in self.gen_ppl_eval_model_name_or_path:
        attn_mask = attn_mask.to(self.device)
        samples = samples.to(self.device)
    return samples, attn_mask, eval_context_size


@try_except(write_error_to_file=True)
@torch.no_grad()
def compute_cider(self, text_samples, gt_text_samples):
    """Compute the CIDEr score for the generated text.

    Args:
        text_samples: List of sentences generated by the model.
        gt_text_samples: List of ground truth sentences.
    Returns:
        CIDEr score for the generated text.
    """
    for text_sample, gt_text_sample in zip(text_samples, gt_text_samples):
        self.cider_score_metric += (text_sample, gt_text_sample)
    score = self.cider_score_metric.compute_cider()
    avg_score = sum(score) / len(score)
    self.cider_score.update(avg_score.item())


def get_anole_data(model, processor, prompt, image, dtype, device):
    inputs = processor(prompt, [image], padding=True, return_tensors="pt").to(device=device, dtype=dtype)
    image_tokens = model.model.get_image_tokens(inputs["pixel_values"])
    special_image_mask = inputs["input_ids"] == model.model.vocabulary_mapping.image_token_id
    image_tokens = image_tokens.to(inputs["input_ids"].device, inputs["input_ids"].dtype)
    inputs["input_ids"] = inputs["input_ids"].masked_scatter(special_image_mask, image_tokens)
    inputs.pop("pixel_values")
    # Override input_ids with token ids previously saved to disk.
    inputs['input_ids'] = torch.load('save.pth').to(device)
    return inputs


@try_except(write_error_to_file=True)
@torch.inference_mode()
def compute_generative_perplexity(self, text_samples: typing.List[str], retokenize: bool = True, max_length: typing.Optional[int] = None, gt: bool = False, return_raw_score: bool = False) -> None:
    """Compute the generative perplexity of the model.

    Args:
        text_samples: List of sentences generated by the model.
        retokenize: Whether to retokenize using eval model's tokenizer
        max_length: Maximum sequence length for tokenization
        gt: Whether these are ground truth samples
        return_raw_score: Whether to return raw NLL scores instead of updating metrics

    Returns:
        If return_raw_score is True, returns tensor of NLL scores.
        Otherwise updates internal perplexity metrics.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if not getattr(self.config.eval, 'enable_gen_pplx_cleanup', True):
        eval_model = self.gen_pplx_eval_model
    elif getattr(self.config.eval, 'gen_ppl_use_chameleon', False):
        from transformers import (ChameleonForConditionalGeneration,
                                  ChameleonProcessor)
        model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16).to("cuda")
        processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
        image = Im(Im("https://cdn.outsideonline.com/wp-content/uploads/2023/03/Funny_Dog_H.jpg").np[50:-150, 550:-900, :]).resize(256, 256).pil
        prompt = "A picture of a cat.<image>"
        device = "cuda:0"
        inputs = get_anole_data(model, processor, prompt, image, self.dtype, device)
        output = model(input_ids=inputs['input_ids'].to(device))
        attention_mask = torch.ones_like(inputs["input_ids"])
        logits = output.logits
        logits = logits.transpose(-1, -2)
        sample_chunk = inputs["input_ids"]
        nlls = F.cross_entropy(logits[..., :-1].to(device), sample_chunk[..., 1:].to(device), reduction="none")
        nlls = nlls * attention_mask[..., 1:].to(nlls.dtype)
        nlls = nlls.sum(-1) / attention_mask[..., 1:].sum(-1)
        print(torch.exp(nlls))
        # This branch only reports the Chameleon perplexity; the generic evaluation
        # below expects `eval_model`, so stop here.
        return
    else:
        eval_model = transformers.AutoModelForCausalLM.from_pretrained(self.gen_ppl_eval_model_name_or_path).eval()
    if max_length is None:
        max_length = self.config.model.txt_length

    if "llama2" not in self.gen_ppl_eval_model_name_or_path:
        eval_model = eval_model.to(self.device)

    if retokenize:
        (samples, attn_mask, eval_context_size) = self.eval_retokenize(text_samples, max_length=max_length)
    else:
        samples = text_samples
        attn_mask = torch.ones(samples.shape).to(self.device)
        eval_context_size = samples.shape[-1]

    batch_size = min(self.config.eval.perplexity_batch_size, samples.shape[0])
    num_batches = (samples.shape[0] + batch_size - 1) // batch_size
    all_nlls = []
    all_valid_mask = []
    for i in range(num_batches):
        batch_samples = samples[i * batch_size : (i + 1) * batch_size]
        batch_attn_mask = attn_mask[i * batch_size : (i + 1) * batch_size]

        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]):
            logits = eval_model(batch_samples, attention_mask=batch_attn_mask)[0]

        logits = logits.transpose(-1, -2)
        nlls = F.cross_entropy(logits[..., :-1], batch_samples[..., 1:], reduction="none")

        # Only count tokens up to (and including) the first EOS that are also
        # covered by the attention mask.
        first_eos = (batch_samples == self.eval_model_tokenizer.eos_token_id).cumsum(-1) <= 1
        token_mask = batch_attn_mask[..., 1:] > 0
        valid_mask = first_eos[..., 1:] * token_mask

        if not return_raw_score:
            if gt:
                self.gt_gen_ppl_metric.update(nlls, valid_mask)
            else:
                self.gen_ppl_metric.update(nlls, valid_mask)
        else:
            all_nlls.append(nlls)
            all_valid_mask.append(valid_mask)

    if getattr(self.config.eval, 'enable_gen_pplx_cleanup', True):
        eval_model.to(torch.device('cpu'))
        del eval_model

    if return_raw_score:
        all_nlls = torch.cat(all_nlls)
        all_valid_mask = torch.cat(all_valid_mask)

        nll = (all_nlls * all_valid_mask).sum(-1) / all_valid_mask.sum(-1)
        return nll


def _d3pm_loss(self, model_output, xt, x0, t):
    dt = 1 / self.T

    if torch.is_tensor(t):
        t = t[:, None]
        assert t.ndim == 2
        t = t.clamp(0.0, 1.0 - 1e-4)
    alpha_t = 1 - t + torch.zeros_like(xt)
    alpha_s = 1 - (t - dt) + torch.zeros_like(xt)

    log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]).squeeze(-1)
    log_x_theta_at_m = model_output[:, :, self.mask_index]
    x_theta_at_m = log_x_theta_at_m.exp()

    term_1_coef = dt / t
    term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
    term_1_log_dr = log_x_theta_at_x0

    term_2_coef = 1 - dt / t
    term_2_log_nr = term_1_log_nr
    term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)

    L_vb_masked = term_1_coef * (term_1_log_nr - term_1_log_dr) + term_2_coef * (term_2_log_nr - term_2_log_dr)

    L_vb = L_vb_masked * (xt == self.mask_index)

    return self.T * L_vb


def _d3pm_parameterization(self, logits):
    if self.subs_masking:
        logits[:, :, self.mask_index] += self.neg_infinity
    logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    return logits


def _sedd_parameterization(self, logits, xt, sigma):
    esigm1_log = torch.where(sigma < 0.5, torch.expm1(sigma), sigma.exp() - 1).log().to(logits.dtype)

    # logits shape: (batch_size, diffusion_model_input_length, vocab_size)
    logits = logits - esigm1_log[:, None, None] - np.log(logits.shape[-1] - 1)

    # The scatter sets the log score for the input token to 0.
    logits = torch.scatter(logits, -1, xt[..., None], torch.zeros_like(logits[..., :1]))
    return logits


def get_base_shapes_for_mup(self, _model):
    from copy import deepcopy

    from mup import set_base_shapes
    from omegaconf import read_write

    base_config = deepcopy(self.config)
    with read_write(base_config):
        base_config.model.hidden_size = base_config.model.n_heads

    delta_config = deepcopy(base_config)
    with read_write(delta_config):
        delta_config.model.hidden_size = base_config.model.n_heads * 2

    base_model = models.dit.DIT(
        base_config, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
    )

    delta_model = models.dit.DIT(
        delta_config, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
    )

    set_base_shapes(_model, base_model, delta=delta_model)


def update_histogram(histogram, timesteps: torch.Tensor, losses: torch.Tensor):
    for t, l in zip(timesteps, losses):
        if t.item() in histogram:
            histogram[t.item()].append(l.item())
        else:
            histogram[t.item()] = [l.item()]

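# Usage sketch (illustrative): accumulate per-timestep losses across batches.
def _demo_update_histogram():
    histogram = {}
    update_histogram(histogram, torch.tensor([1, 1, 9]), torch.tensor([1.0, 2.0, 3.0]))
    assert histogram == {1: [1.0, 2.0], 9: [3.0]}
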

def _maybe_sub_sample(self, x0, attention_mask):
    seqlen = x0.shape[1]
    if seqlen > self.config.model.length:
        if not getattr(self.config.eval, 'big_seq_len_eval', False):
            assert seqlen == 2 * self.config.model.length

        # Randomly crop a model-length window out of the longer sequence.
        start = np.random.choice(self.config.model.length)
        end = start + self.config.model.length
        input_tokens = x0[:, start:end]
        output_tokens = x0[:, start + 1 : end + 1]
        new_attention_mask = attention_mask[:, start:end]

        # Ensure the cropped window still starts with BOS and ends with EOS.
        input_tokens[:, 0] = self.tokenizer.bos_token_id
        output_tokens[:, -1] = self.tokenizer.eos_token_id
    else:
        input_tokens = x0
        output_tokens = None
        new_attention_mask = attention_mask
    return input_tokens, output_tokens, new_attention_mask

from unidisc.tokenizers.image_tokenizers import decode_latents


def viz_images_from_dataloader(self):
    _iter = iter(self.train_dataloader)
    random_elements = [next(_iter) for _ in range(10)]

    # Prepend a zeroed token column and shift ids down by text_vocab_size before
    # decoding with the VAE.
    out = decode_latents(self.config, self.get_vae(), torch.cat([torch.zeros_like(random_elements[0]['input_ids'][:, :1]), (random_elements[0]['input_ids'] - self.text_vocab_size)], dim=-1))
    from image_utils import Im
    print(Im(out[:16]).save())
    breakpoint()
    return random_elements


try:
    from torch.nn.attention.flex_attention import create_block_mask
except ImportError:
    pass


def _attn_mask(txt_batch_dropout, img_batch_dropout, txt_length):
    def mask_mod(b, h, q_idx, kv_idx):
        # When text is dropped out for sample b, text queries attend only to text
        # keys (image queries are unrestricted).
        txt_sees_txt = (q_idx < txt_length) & (kv_idx < txt_length)
        img_sees_img_and_txt = (q_idx >= txt_length)
        txt_dropout_case = ~txt_batch_dropout[b] | (txt_sees_txt | img_sees_img_and_txt)

        # Symmetrically, when the image is dropped out, image queries attend only
        # to image keys.
        img_sees_img = ((q_idx >= txt_length) & (kv_idx >= txt_length))
        txt_sees_txt_and_img = (q_idx < txt_length)
        img_dropout_case = ~img_batch_dropout[b] | (img_sees_img | txt_sees_txt_and_img)
        return txt_dropout_case & img_dropout_case
    return mask_mod


def get_block_mask(txt_batch_attn_dropout, img_batch_attn_dropout, txt_length, batch_size, seq_len, device):
    return create_block_mask(
        _attn_mask(txt_batch_attn_dropout, img_batch_attn_dropout, txt_length),
        B=batch_size, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
    )

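# Minimal usage sketch (illustrative; assumes a PyTorch build with flex attention
# and a CUDA device). Builds a block mask for a batch where sample 0 drops its
# text conditioning; H=None broadcasts the mask over all heads.
def _demo_get_block_mask():
    device = "cuda"
    txt_drop = torch.tensor([True, False], device=device)
    img_drop = torch.tensor([False, False], device=device)
    return get_block_mask(txt_drop, img_drop, txt_length=32, batch_size=2, seq_len=256, device=device)
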

def _interleaved_attn_mask(interleaved_sample_ids):
    def mask_mod(b, h, q_idx, kv_idx):
        # Tokens attend only within the same interleaved sample; id -1 marks padding.
        return (interleaved_sample_ids[b, q_idx] == interleaved_sample_ids[b, kv_idx]) & (interleaved_sample_ids[b, q_idx] != -1)
    return mask_mod


def visualize_flex_attention(mask_mod, B, SEQ_LEN, H=16, HEAD_DIM=64, device="cuda"):
    from models.archived.utils import visualize_attention_scores

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()
    visualize_attention_scores(
        query,
        key,
        mask_mod=mask_mod,
        device=device,
        name="interleaved_attn_mask",
    )


def get_interleaved_block_mask(interleaved_sample_ids, batch_size, seq_len, device, visualize=False):
    if visualize:
        visualize_flex_attention(_interleaved_attn_mask(interleaved_sample_ids), batch_size, seq_len, device=device)
    if (interleaved_sample_ids == -1).all(dim=-1).any():
        # A fully padded row would produce an empty mask, so mark one position valid.
        gprint("WARNING: Found all -1s in interleaved_sample_ids, setting one to 0")
        interleaved_sample_ids = interleaved_sample_ids.clone()
        interleaved_sample_ids[(interleaved_sample_ids == -1).all(dim=-1), 0] = 0

    return create_block_mask(
        _interleaved_attn_mask(interleaved_sample_ids),
        B=batch_size, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
    )

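# Sketch of the expected input (illustrative, assumes CUDA + flex attention):
# `interleaved_sample_ids` assigns each token to a packed document, e.g.
# [0, 0, 1, 1, -1] packs two samples plus one pad token, and the resulting mask
# blocks attention across document boundaries.
def _demo_interleaved_block_mask():
    ids = torch.tensor([[0] * 128 + [1] * 128], device="cuda")
    return get_interleaved_block_mask(ids, batch_size=1, seq_len=256, device="cuda")
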

def calculate_clip_score(
    image_paths: List[str],
    captions_mapping: Dict[str, str],
    device: torch.device = "cuda",
    seed: Optional[int] = 42,
    batch_size: int = 128,
    dataloader_workers: int = 16,
    verbose: bool = True,
):
    import clip
    from T2IBenchmark.feature_extractors import (BaseFeatureExtractor,
                                                 InceptionV3FE)
    from T2IBenchmark.loaders import CaptionImageDataset
    from T2IBenchmark.model_wrapper import (ModelWrapperDataloader,
                                            T2IModelWrapper)
    from T2IBenchmark.utils import dprint, set_all_seeds

    if seed:
        set_all_seeds(seed)

    model, preprocess = clip.load("ViT-B/32", device=device)
    dataset = CaptionImageDataset(
        images_paths=image_paths,
        captions=list(map(lambda x: captions_mapping[x], image_paths)),
        preprocess_fn=preprocess,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=dataloader_workers,
    )

    score_acc = 0.0
    num_samples = 0.0

    for image, caption in tqdm(dataloader):
        # Encode under no_grad so evaluation does not build an autograd graph.
        with torch.no_grad():
            image_embedding = model.encode_image(image.to(device))
            caption_embedding = model.encode_text(clip.tokenize(caption, truncate=True).to(device))

        image_features = image_embedding / image_embedding.norm(dim=1, keepdim=True).to(
            torch.float32
        )
        caption_features = caption_embedding / caption_embedding.norm(
            dim=1, keepdim=True
        ).to(torch.float32)

        score = (image_features * caption_features).sum()
        score_acc += score
        num_samples += image.shape[0]

    clip_score = score_acc / num_samples
    dprint(verbose, f"CLIP score is {clip_score}")

    return clip_score


def get_chameleon_txt_indices(vae, include_special_tokens=True):
    image_indices = set(vae.chameleon_ori_translation.bpe2img.keys())
    if include_special_tokens:
        h_grids, w_grids = 32, 32
        image_start_token = vae.token2id(vae.image_start_token)
        n_grids_token = vae.token2id(vae.get_n_grids_token(h_grids))
        image_end_token = vae.token2id(vae.image_end_token)
        image_indices.add(image_start_token)
        image_indices.add(n_grids_token)
        image_indices.add(image_end_token)
        image_indices.add(-100)
        image_indices.add(1)
        image_indices.update(range(8192, 8820 + 1))

    return image_indices