File size: 32,090 Bytes
131da64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 |
import math
import os
import random
import typing
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from types import FrameType
from typing import Dict, List, Optional, Tuple, Union
import einops
import hydra
import hydra.utils
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import torchmetrics
import transformers
from image_utils import Im
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from tqdm.auto import tqdm
import models
import wandb
from decoupled_utils import (Profiler, barrier, dprint, get_rank,
get_slurm_job_id, get_world_size, gprint,
is_local_main_process, is_main_process,
is_torch_cuda_available, is_torch_xla_available,
module_hash, mprint, parameter_hash, print_memory,
rank_zero_fn, rprint, save_memory_profile,
show_memory_usage, try_except, use_dist)
is_xla_available = is_torch_xla_available()
if is_xla_available:
from unidisc.utils.standalone_metrics import MeanMetric, MetricCollection
else:
from torchmetrics import MetricCollection
from torchmetrics.aggregation import MeanMetric
LOG2 = math.log(2)
@try_except(write_error_to_file=True)
def log(*arg, **kwargs):
for key, value in arg[0].items():
if isinstance(value, torch.Tensor):
arg[0][key] = value.detach().cpu().float()
if is_main_process():
wandb.log(*arg, **kwargs)
def replace_nan_dict(x):
return {k: v.nan_to_num(0) for k, v in x.items()}
def ddprint(*args, **kwargs):
mprint(*args, **kwargs)
def empty_device_cache():
if is_torch_cuda_available():
torch.cuda.empty_cache()
else:
dprint("Not using cuda, skipping cache clear")
def update_logs(_logs, _extra_logs):
_logs.update(_extra_logs())
for k, v in _logs.items():
if isinstance(v, torch.Tensor):
_logs[k] = v.detach().cpu().item()
gprint(f"Converting {k} to item: {v}")
log(_logs)
def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
param_dict_src = dict(model_src.named_parameters())
for p_name, p_dest in model_dest.named_parameters():
if p_name not in param_dict_src:
print(f"Parameter {p_name} not found in src: {param_dict_src}")
p_src = param_dict_src[p_name]
assert p_src is not p_dest
p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
def identity(x):
return x
def remap_image_torch(image):
image_torch = image * 255
image_torch = torch.clip(image_torch, 0, 255).to(torch.uint8)
return image_torch
def _sample_categorical(categorical_probs):
gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
return (categorical_probs / gumbel_norm).argmax(dim=-1)
def wrapped_batch_decode(tokenizer, tokens, disable_mask_after_eos=False, **kwargs):
tokens = tokens.clone()
if (tokenizer.bos_token_id != tokenizer.eos_token_id) and not disable_mask_after_eos:
after_first_five = torch.cumsum(tokens == tokenizer.eos_token_id, dim=1).bool()
tokens[after_first_five.cumsum(dim=1) > 1] = tokenizer.pad_token_id
return tokenizer.batch_decode(tokens, **kwargs)
def _unsqueeze(x, reference):
return x.view(*x.shape, *((1,) * (len(reference.shape) - len(x.shape))))
@dataclass
class Loss:
loss: torch.FloatTensor
img_loss: torch.FloatTensor = None
txt_loss: torch.FloatTensor = None
nlls: torch.FloatTensor = None
token_mask: torch.FloatTensor = None
txt_nlls: torch.FloatTensor = None
img_nlls: torch.FloatTensor = None
extra_losses: dict = None
modality_mask: torch.FloatTensor = None
class NLL(MeanMetric):
pass
class BPD(NLL):
def compute(self) -> Tensor:
"""Computes the bits per dimension.
Returns:
bpd
"""
return self.mean_value / self.weight / LOG2
class Perplexity(NLL):
def compute(self) -> Tensor:
"""Computes the Perplexity.
Returns:
Perplexity
"""
return torch.exp(self.mean_value / self.weight)
class Entropy(NLL):
def compute(self) -> Tensor:
"""Computes the Entropy.
Returns:
Entropy
"""
return self.mean_value / self.weight
class MauveScore(NLL):
def compute(self) -> Tensor:
"""Computes the Mauve Score.
Returns:
Mauve Score
"""
return self.mean_value / self.weight
class CIDErScore(NLL):
def compute(self) -> Tensor:
"""Computes the CIDEr Score.
Returns:
CIDEr Score
"""
return self.mean_value / self.weight
class Accuracy(NLL):
def compute(self) -> Tensor:
"""Computes the Accuracy.
Returns:
Accuracy
"""
return self.mean_value / self.weight
def get_coord_plot(self):
from mup.coord_check import get_coord_data, plot_coord_data
def gen(w):
def f():
from copy import deepcopy
from omegaconf import read_write
import models as _models
_conf = deepcopy(self.config)
with read_write(_conf):
_conf.model.hidden_size = _conf.model.n_heads * w
_backbone = _models.dit.DIT(
_conf, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
)
self.get_base_shapes_for_mup(_backbone)
return _backbone
return f
optimizer = 'adamw'
widths = np.array([2**i for i in range(2, 6)])
models = {int(w) * self.config.model.n_heads: gen(int(w)) for w in widths}
fake_dataloader = []
self.validation_dataloader.num_workers = 0
nsteps = 30
for i, dataloader_batch in enumerate(self.validation_dataloader):
fake_batch = self.update_batch(dataloader_batch)
fake_batch['x0'] = fake_batch["input_ids"]
t = self._sample_t(fake_batch['x0'].shape[0], fake_batch['x0'].device)
sigma, dsigma = self.noise(t)
move_chance = 1 - torch.exp(-sigma[:, None])
xt = self.q_xt(fake_batch['x0'], move_chance)
fake_batch['xt'] = xt
fake_dataloader.append(fake_batch)
if i >= nsteps:
break
def loss_fn(_batch, _logits):
attention_mask = _batch['attention_mask']
model_output = self._subs_parameterization(logits=_logits, xt=_batch['xt'])
log_p_theta = torch.gather(input=model_output, dim=-1, index=_batch['x0'][:, :, None]).squeeze(-1)
std_weighting = (dsigma / torch.expm1(sigma))[:, None]
loss = -log_p_theta * std_weighting
loss = (loss * attention_mask).sum() / attention_mask.sum()
return loss
mup = True
lr = 1e-2
prm = 'μP' if mup else 'SP'
nseeds = 2
with torch.autocast(device_type=self.device.type, dtype=self.dtype):
df = get_coord_data(
models,
fake_dataloader,
lr=lr,
optimizer=optimizer,
nsteps=nsteps,
nseeds=nseeds,
dict_in_out=True,
lossfn=loss_fn,
mup=mup,
)
output_path = Path(__file__).parent / 'output' / f'{prm.lower()}_trsfmr_{optimizer}_coord.png'
output_path.parent.mkdir(parents=True, exist_ok=True)
plot_coord_data(
df,
legend='brief',
save_to=str(output_path.resolve()),
suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}',
face_color='xkcd:light grey' if not mup else None,
loglog=True
)
rprint(f"Saved coord plot to {output_path.resolve()}")
csv_path = output_path.with_suffix('.csv')
df.to_csv(csv_path, index=False)
rprint(f"DataFrame saved as CSV to {csv_path.resolve()}")
result = df[df['t'] == 1].nsmallest(100, 'l1').sort_values('l1', ascending=True)
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
print(result[['module', 'width', 'l1']])
exit()
def _score_entropy(self, log_score, sigma, xt, x0):
"""Computes the SEDD loss.
Args:
log_score: float torch.Tensor with shape (batch_size,
diffusion_model_input_length, vocab_size),
log score, output of the denoising network.
xt: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
x0: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
sigma: float torch.Tensor with shape (batch_size, 1).
Returns:
loss with shape (batch_size, diffusion_model_input_length)
"""
masked_indices = xt == self.mask_index
expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
q_ratio = 1 / expsig_minus_1[masked_indices]
words_that_were_masked = x0[masked_indices]
neg_term = q_ratio * torch.gather(log_score[masked_indices], -1, words_that_were_masked[..., None]).squeeze(-1)
score = log_score[masked_indices].exp()
if self.mask_index == self.vocab_size - 1:
pos_term = score[:, :-1].sum(dim=-1)
else:
pos_term = score[:, : self.mask_index].sum(dim=-1) + score[:, self.mask_index + 1 :].sum(dim=-1)
const = q_ratio * (q_ratio.log() - 1)
entropy = torch.zeros(*xt.shape, device=xt.device)
entropy[masked_indices] += pos_term - neg_term + const
return entropy
@torch.no_grad
def sample_subs_guidance(self, n_samples, stride_length, num_strides, dt=0.001):
ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
num_steps = int(1 / dt)
sampling_steps = 0
intermediate_tokens = []
target = None
for _ in range(num_strides + 1):
p_x0_cache = None
x = self._sample_prior(n_samples, self.config.model.length).to(self.device)
if target is not None:
x[:, :-stride_length] = target
for i in range(num_steps + 1):
p_x0_cache, x_next, nfe_cnt = self._ddpm_caching_update(x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
if not torch.allclose(x_next, x) or self.time_conditioning:
p_x0_cache = None
sampling_steps += 1
x = x_next
x = self.forward(x, 0 * ones).argmax(dim=-1)
intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
target = x[:, stride_length:]
intermediate_tokens.append(target.cpu().numpy())
intermediate_text_samples = []
sequence_lengths = ((np.concatenate(intermediate_tokens, axis=1)[:, 1:] == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
for i in range(2, len(intermediate_tokens) + 1):
intermediate_text_samples.append(self.tokenizer.batch_decode(np.concatenate(intermediate_tokens[:i], axis=1)))
return (sampling_steps, intermediate_text_samples, sequence_lengths)
def restore_model_and_semi_ar_sample(self, stride_length, num_strides, dt=0.001):
"""Generate samples from the model."""
# Lightning auto-casting is not working in this method for some reason
if self.ema:
self.ema.store(self.get_params())
self.ema.copy_to(self.get_params())
self.backbone.eval()
(sampling_steps, samples, sequence_lengths) = self.sample_subs_guidance(
n_samples=self.config.loader.eval_batch_size, stride_length=stride_length, num_strides=num_strides, dt=dt
)
if self.ema:
self.ema.restore(self.get_params())
self.backbone.train()
self.noise.train()
return sampling_steps, samples, sequence_lengths
def _reconstruction_loss(self, x0):
t0 = torch.zeros(x0.shape[0], dtype=self.dtype, device=self.device)
assert self.config.noise.type == "loglinear"
# The above assert is for d3pm parameterization
unet_conditioning = self.noise(t0)[0][:, None]
model_output_t0 = self.forward(x0, unet_conditioning)
return -torch.gather(input=model_output_t0, dim=-1, index=x0[:, :, None]).squeeze(-1)
def restore_model_and_sample(self, num_steps, eps=1e-5):
"""Generate samples from the model."""
# Lightning auto-casting is not working in this method for some reason
if self.ema is not None:
self.ema.store(self.get_params())
self.ema.copy_to(self.get_params())
self.backbone.eval()
samples = self._sample(num_steps=num_steps, eps=eps)
if self.ema is not None:
self.ema.restore(self.get_params())
self.backbone.train()
return samples
def get_score(self, x, sigma, **kwargs):
model_output = self.forward(x, sigma, **kwargs)
if self.parameterization == "subs":
# score(x, t) = p_t(y) / p_t(x)
# => log score(x, t) = log p_t(y) - log p_t(x)
# case 1: x = masked
# (i) y = unmasked
# log score(x, t) = log p_\theta(x)|_y + log k
# where k = exp(- sigma) / (1 - exp(- sigma))
# (ii) y = masked
# log score(x, t) = 0
# case 2: x = unmasked
# (i) y != masked, y != x
# log score(x_i, t) = - inf
# (ii) y = x
# log score(x_i, t) = 0
# (iii) y = masked token
# log score(x_i, t) = - log k
# where k = exp(- sigma) / (1 - exp(- sigma))
log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
assert log_k.ndim == 1
masked_score = model_output + log_k[:, None, None]
masked_score[:, :, self.mask_index] = 0
unmasked_score = self.neg_infinity * torch.ones_like(model_output)
unmasked_score = torch.scatter(unmasked_score, -1, x[..., None], torch.zeros_like(unmasked_score[..., :1]))
unmasked_score[:, :, self.mask_index] = -(log_k[:, None] * torch.ones_like(x))
masked_indices = (x == self.mask_index).to(model_output.dtype)[:, :, None]
model_output = masked_score * masked_indices + unmasked_score * (1 - masked_indices)
return model_output.exp()
def _staggered_score(self, score, dsigma):
score = score.clone()
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
score *= dsigma.exp()[:, None]
score[..., self.mask_index] += extra_const
return score
def _analytic_update(self, x, t, step_size):
curr_sigma, _ = self.noise(t)
next_sigma, _ = self.noise(t - step_size)
dsigma = curr_sigma - next_sigma
nfe_cnt = 0
score = self.get_score(x, curr_sigma)
nfe_cnt += 1
stag_score = self._staggered_score(score, dsigma)
probs = stag_score * self._transp_transition(x, dsigma)
return _sample_categorical(probs), nfe_cnt
def _denoiser_update(self, x, t):
sigma, _ = self.noise(t)
score = self.get_score(x, sigma)
stag_score = self._staggered_score(score, sigma)
probs = stag_score * self._transp_transition(x, sigma)
probs[..., self.mask_index] = 0
samples = _sample_categorical(probs)
return samples
def _transp_transition(self, i, sigma):
sigma = _unsqueeze(sigma, reference=i[..., None])
edge = torch.exp(-sigma) * F.one_hot(i, num_classes=self.vocab_size)
edge += torch.where(i == self.mask_index, 1 - torch.exp(-sigma).squeeze(-1), 0)[..., None]
return edge
@torch.no_grad()
def eval_retokenize(self, text_samples, max_length):
"""Retokenizes samples for the eval model.
Args:
text_samples: List of sentences generated by the model.
Returns:
samples: Samples re-tokenized for the eval model
attn_mask: Attention mask for the eval model
eval_context_size: Size of the context for the eval model
"""
if "llama2" in self.gen_ppl_eval_model_name_or_path:
tokenizer_kwargs = {
"text_samples": text_samples,
"return_tensors": "pt",
"return_token_type_ids": False,
"return_attention_mask": True,
"truncation": True,
"padding": True,
"max_length": max_length,
}
eval_context_size = 4096
else:
tokenizer_kwargs = {
"return_tensors": "pt",
"return_token_type_ids": False,
"return_attention_mask": True,
"truncation": True,
"padding": True,
"max_length": max_length,
}
eval_context_size = 1024
if getattr(self.config.eval, "force_eval_context_size_match_model", False):
eval_context_size = self.config.model.txt_length
samples = self.eval_model_tokenizer(text_samples, **tokenizer_kwargs)
attn_mask = samples["attention_mask"]
samples = samples["input_ids"]
if "llama2" not in self.gen_ppl_eval_model_name_or_path:
attn_mask = attn_mask.to(self.device)
samples = samples.to(self.device)
return samples, attn_mask, eval_context_size
@try_except(write_error_to_file=True)
@torch.no_grad()
def compute_cider(self, text_samples, gt_text_samples):
"""Compute the CIDEr score for the generated text.
Args:
text_samples: List of sentences generated by the model.
gt_text_samples: List of ground truth sentences.
Returns:
CIDEr score for the generated text.
"""
for text_sample, gt_text_sample in zip(text_samples, gt_text_samples):
self.cider_score_metric += (text_sample, gt_text_sample)
score = self.cider_score_metric.compute_cider() # list of np.float64
avg_score = sum(score) / len(score)
self.cider_score.update(avg_score.item()) # weight=len(text_samples))
def get_anole_data(model, processor, prompt, image, device):
inputs = processor(prompt, [image], padding=True, return_tensors="pt").to(device=device, dtype=dtype)
image_tokens = model.model.get_image_tokens(inputs["pixel_values"])
special_image_mask = inputs["input_ids"] == model.model.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(inputs["input_ids"].device, inputs["input_ids"].dtype)
inputs["input_ids"] = inputs["input_ids"].masked_scatter(special_image_mask, image_tokens)
inputs.pop("pixel_values")
inputs['input_ids'] = torch.load('save.pth').to(device)
return inputs
@try_except(write_error_to_file=True)
@torch.inference_mode()
def compute_generative_perplexity(self, text_samples: typing.List[str], retokenize: bool = True, max_length: typing.Optional[int] = None, gt: bool = False, return_raw_score: bool = False) -> None:
"""Compute the generative perplexity of the model.
Args:
text_samples: List of sentences generated by the model.
retokenize: Whether to retokenize using eval model's tokenizer
max_length: Maximum sequence length for tokenization
gt: Whether these are ground truth samples
return_raw_score: Whether to return raw NLL scores instead of updating metrics
Returns:
If return_raw_score is True, returns tensor of NLL scores.
Otherwise updates internal perplexity metrics.
"""
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if not getattr(self.config.eval, 'enable_gen_pplx_cleanup', True):
eval_model = self.gen_pplx_eval_model
elif getattr(self.config.eval, 'gen_ppl_use_chameleon', False):
from transformers import (ChameleonForConditionalGeneration,
ChameleonProcessor)
model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16).to("cuda")
processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
image = Im(Im("https://cdn.outsideonline.com/wp-content/uploads/2023/03/Funny_Dog_H.jpg").np[50:-150, 550:-900, :]).resize(256, 256).pil
prompt = "A picture of a cat.<image>"
device = "cuda:0"
inputs = get_anole_data(model, processor, prompt, image, self.dtype, device)
output = model(input_ids=inputs['input_ids'].to(device))
attention_mask = torch.ones_like(inputs["input_ids"])
logits = output.logits
logits = logits.transpose(-1, -2)
sample_chunk = inputs["input_ids"]
nlls = F.cross_entropy(logits[..., :-1].to(device), sample_chunk[..., 1:].to(device), reduction="none")
nlls = nlls * attention_mask[..., 1:].to(nlls.dtype)
nlls = nlls.sum(-1) / attention_mask[..., 1:].sum(-1)
print(torch.exp(nlls))
else:
eval_model = transformers.AutoModelForCausalLM.from_pretrained(self.gen_ppl_eval_model_name_or_path).eval()
if max_length is None:
max_length = self.config.model.txt_length
if "llama2" not in self.gen_ppl_eval_model_name_or_path:
eval_model = eval_model.to(self.device)
# Re-tokenize using eval model's tokenizer
if retokenize:
(samples, attn_mask, eval_context_size) = self.eval_retokenize(text_samples, max_length=max_length)
else:
samples = text_samples
attn_mask = torch.ones(samples.shape).to(self.device)
eval_context_size = samples.shape[-1]
batch_size = min(self.config.eval.perplexity_batch_size, samples.shape[0])
num_batches = (samples.shape[0] + batch_size - 1) // batch_size
all_nlls = []
all_valid_mask = []
for i in range(num_batches):
batch_samples = samples[i * batch_size : (i + 1) * batch_size]
batch_attn_mask = attn_mask[i * batch_size : (i + 1) * batch_size]
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]):
logits = eval_model(batch_samples, attention_mask=batch_attn_mask)[0]
logits = logits.transpose(-1, -2)
nlls = F.cross_entropy(logits[..., :-1], batch_samples[..., 1:], reduction="none")
# Only consider tokens up to first EOS or padding
first_eos = (batch_samples == self.eval_model_tokenizer.eos_token_id).cumsum(-1) <= 1
token_mask = batch_attn_mask[..., 1:] > 0
valid_mask = first_eos[..., 1:] * token_mask
if not return_raw_score:
if gt:
self.gt_gen_ppl_metric.update(nlls, valid_mask)
else:
self.gen_ppl_metric.update(nlls, valid_mask)
else:
all_nlls.append(nlls)
all_valid_mask.append(valid_mask)
if getattr(self.config.eval, 'enable_gen_pplx_cleanup', True):
eval_model.to(torch.device('cpu'))
del eval_model
if return_raw_score:
all_nlls = torch.cat(all_nlls)
all_valid_mask = torch.cat(all_valid_mask)
# Compute mean NLL per sequence, ignoring padding/post-EOS tokens
nll = (all_nlls * all_valid_mask).sum(-1) / all_valid_mask.sum(-1)
return nll
def _d3pm_loss(self, model_output, xt, x0, t):
dt = 1 / self.T
if torch.is_tensor(t):
t = t[:, None]
assert t.ndim == 2
t = t.clamp(0.0, 1.0 - 1e-4)
alpha_t = 1 - t + torch.zeros_like(xt)
alpha_s = 1 - (t - dt) + torch.zeros_like(xt)
log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]).squeeze(-1)
log_x_theta_at_m = model_output[:, :, self.mask_index]
x_theta_at_m = log_x_theta_at_m.exp()
term_1_coef = dt / t
term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
term_1_log_dr = log_x_theta_at_x0
term_2_coef = 1 - dt / t
term_2_log_nr = term_1_log_nr
term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)
L_vb_masked = term_1_coef * (term_1_log_nr - term_1_log_dr) + term_2_coef * (term_2_log_nr - term_2_log_dr)
L_vb = L_vb_masked * (xt == self.mask_index)
return self.T * L_vb
def _d3pm_parameterization(self, logits):
if self.subs_masking:
logits[:, :, self.mask_index] += self.neg_infinity
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
return logits
def _sedd_parameterization(self, logits, xt, sigma):
esigm1_log = torch.where(sigma < 0.5, torch.expm1(sigma), sigma.exp() - 1).log().to(logits.dtype)
# logits shape
# (batch_size, diffusion_model_input_length, vocab_size)
logits = logits - esigm1_log[:, None, None] - np.log(logits.shape[-1] - 1)
# The below scatter operation sets the log score
# for the input word to 0.
logits = torch.scatter(logits, -1, xt[..., None], torch.zeros_like(logits[..., :1]))
return logits
def get_base_shapes_for_mup(self, _model):
from copy import deepcopy
from mup import set_base_shapes
from omegaconf import read_write
base_config = deepcopy(self.config)
with read_write(base_config):
base_config.model.hidden_size = base_config.model.n_heads # We need at least n_heads dim
delta_config = deepcopy(base_config)
with read_write(delta_config):
delta_config.model.hidden_size = base_config.model.n_heads * 2
base_model = models.dit.DIT(
base_config, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
)
delta_model = models.dit.DIT(
delta_config, vocab_size=self.vocab_size, mask_index=self.mask_index, text_vocab_size=self.text_vocab_size, dtype=self.dtype
)
set_base_shapes(_model, base_model, delta=delta_model)
def update_histogram(histogram, timesteps: torch.Tensor, losses: torch.Tensor):
for t, l in zip(timesteps, losses):
if t.item() in histogram:
histogram[t.item()].append(l.item())
else:
histogram[t.item()] = [l.item()]
def _maybe_sub_sample(self, x0, attention_mask):
seqlen = x0.shape[1]
if seqlen > self.config.model.length:
if not getattr(self.config.eval, 'big_seq_len_eval', False):
assert seqlen == 2 * self.config.model.length
# cropping is needed for text8-crop dataset
# try the same starting point for now
start = np.random.choice(self.config.model.length)
end = start + self.config.model.length
input_tokens = x0[:, start:end]
output_tokens = x0[:, start + 1 : end + 1]
new_attention_mask = attention_mask[:, start:end]
# Helps with validation PPL, since the val
# examples will all start and end with BOS/EOS
input_tokens[:, 0] = self.tokenizer.bos_token_id
output_tokens[:, -1] = self.tokenizer.eos_token_id
else:
input_tokens = x0
output_tokens = None
new_attention_mask = attention_mask
return input_tokens, output_tokens, new_attention_mask
from unidisc.tokenizers.image_tokenizers import decode_latents
def viz_images_from_dataloader(self):
_iter = iter(self.train_dataloader)
random_elements = [next(_iter) for _ in range(10)]
# random_elements[0]['input_ids'] - self.text_vocab_size
out = decode_latents(self.config, self.get_vae(), torch.cat([torch.zeros_like(random_elements[0]['input_ids'][:, :1]), (random_elements[0]['input_ids'] - self.text_vocab_size)], dim=-1))
from image_utils import Im
print(Im(out[:16]).save())
breakpoint()
return random_elements
try:
from torch.nn.attention.flex_attention import create_block_mask
except:
pass
def _attn_mask(txt_batch_dropout, img_batch_dropout, txt_length):
def mask_mod(b, h, q_idx, kv_idx):
txt_sees_txt = (q_idx < txt_length) & (kv_idx < txt_length)
img_sees_img_and_txt = (q_idx >= txt_length)
txt_dropout_case = ~txt_batch_dropout[b] | (txt_sees_txt | img_sees_img_and_txt)
img_sees_img = ((q_idx >= txt_length) & (kv_idx >= txt_length))
txt_sees_txt_and_img = (q_idx < txt_length)
img_dropout_case = ~img_batch_dropout[b] | (img_sees_img | txt_sees_txt_and_img)
return txt_dropout_case & img_dropout_case
return mask_mod
def get_block_mask(txt_batch_attn_dropout, img_batch_attn_dropout, txt_length, batch_size, seq_len, device):
return create_block_mask(
_attn_mask(txt_batch_attn_dropout, img_batch_attn_dropout, txt_length),
B = batch_size, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device
)
def _interleaved_attn_mask(interleaved_sample_ids):
def mask_mod(b, h, q_idx, kv_idx):
return (interleaved_sample_ids[b, q_idx] == interleaved_sample_ids[b, kv_idx]) & (interleaved_sample_ids[b, q_idx] != -1)
return mask_mod
def visualize_flex_attention(mask_mod, B, SEQ_LEN, H=16, HEAD_DIM=64, device="cuda"):
from models.archived.utils import visualize_attention_scores
def make_tensor():
return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)
query, key = make_tensor(), make_tensor()
visualize_attention_scores(
query,
key,
mask_mod=mask_mod,
device=device,
name="interleaved_attn_mask",
)
def get_interleaved_block_mask(interleaved_sample_ids, batch_size, seq_len, device, visualize=False):
# Uncomment this to visualize the mask
if visualize:
visualize_flex_attention(_interleaved_attn_mask(interleaved_sample_ids), batch_size, seq_len, device=device)
if (interleaved_sample_ids == -1).all(dim=-1).any():
gprint(f"WARNING: Found all -1s in interleaved_sample_ids, setting one to 0")
interleaved_sample_ids = interleaved_sample_ids.clone()
interleaved_sample_ids[(interleaved_sample_ids == -1).all(dim=-1), 0] = 0
return create_block_mask(
_interleaved_attn_mask(interleaved_sample_ids),
B = batch_size, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device
)
def calculate_clip_score(
image_paths: List[str],
captions_mapping: Dict[str, str],
device: torch.device = "cuda",
seed: Optional[int] = 42,
batch_size: int = 128,
dataloader_workers: int = 16,
verbose: bool = True,
):
import clip
from T2IBenchmark.feature_extractors import (BaseFeatureExtractor,
InceptionV3FE)
from T2IBenchmark.loaders import CaptionImageDataset
from T2IBenchmark.model_wrapper import (ModelWrapperDataloader,
T2IModelWrapper)
from T2IBenchmark.utils import dprint, set_all_seeds
if seed:
set_all_seeds(seed)
model, preprocess = clip.load("ViT-B/32", device=device)
dataset = CaptionImageDataset(
images_paths=image_paths,
captions=list(map(lambda x: captions_mapping[x], image_paths)),
preprocess_fn=preprocess,
)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=dataloader_workers,
)
score_acc = 0.0
num_samples = 0.0
for image, caption in tqdm(dataloader):
image_embedding = model.encode_image(image.to(device))
caption_embedding = model.encode_text(clip.tokenize(caption, truncate=True).to(device))
image_features = image_embedding / image_embedding.norm(dim=1, keepdim=True).to(
torch.float32
)
caption_features = caption_embedding / caption_embedding.norm(
dim=1, keepdim=True
).to(torch.float32)
score = (image_features * caption_features).sum()
score_acc += score
num_samples += image.shape[0]
clip_score = score_acc / num_samples
dprint(verbose, f"CLIP score is {clip_score}")
return clip_score
def get_chameleon_txt_indices(vae, include_special_tokens=True):
image_indices = set(vae.chameleon_ori_translation.bpe2img.keys())
if include_special_tokens:
h_grids, w_grids = 32, 32
image_start_token = vae.token2id(vae.image_start_token)
n_grids_token = vae.token2id(vae.get_n_grids_token(h_grids))
image_end_token = vae.token2id(vae.image_end_token)
image_indices.add(image_start_token)
image_indices.add(n_grids_token)
image_indices.add(image_end_token)
image_indices.add(-100)
image_indices.add(1)
image_indices.update(range(8192, 8820 + 1))
return image_indices |