marriola committed (verified)
Commit d25f94c · 1 Parent(s): ebaee33

Upload BD3LM

Files changed (1): modeling_bd3lm.py (+23, -18)
modeling_bd3lm.py CHANGED
@@ -299,7 +299,7 @@ class DDiTBlock(nn.Module):
     return bias_dropout_add_scale_fused_inference
 
 
-  def get_qkv(self, x, rotary_cos_sin, save_kv=False):
+  def get_qkv(self, x, rotary_cos_sin, store_kv=False):
     # compute qkv (potentially use cache)
     if self.kv_cache is not None:
       block_len = x.shape[1] - self.kv_cache.shape[1]
@@ -308,8 +308,8 @@ class DDiTBlock(nn.Module):
     else:
       qkv = self.attn_qkv(x)
 
-    # save kv cache in a sliding window (can't exceed context len)
-    if save_kv:
+    # store kv cache in a sliding window (can't exceed context len)
+    if store_kv:
       if self.kv_cache is not None:
         cache_len = min(x.shape[1], self.n - block_len)
         self.kv_cache = qkv[:, -cache_len:]
@@ -347,7 +347,8 @@ class DDiTBlock(nn.Module):
     x = einops.rearrange(x, 'b s h d -> b s (h d)')
     return x
 
-  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None, save_kv=False):
+  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None,
+              sample_mode=False, store_kv=False):
     bias_dropout_scale_fn = self._get_bias_dropout_scale()
 
     (shift_msa, scale_msa, gate_msa, shift_mlp,
@@ -358,12 +359,12 @@ class DDiTBlock(nn.Module):
     x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
 
     # get qkvs
-    if cross_attn_mask is not None and not save_kv:
+    if cross_attn_mask is not None and not sample_mode:
       qkv_x = self.get_qkv(x[:,:self.n], rotary_cos_sin)
       qkv_x0 = self.get_qkv(x[:,self.n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
-      qkv = self.get_qkv(x, rotary_cos_sin, save_kv=save_kv)
+      qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
 
     if cross_attn_mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
@@ -470,9 +471,8 @@ class DITBackbone(nn.Module):
       x0_attn_mask = torch.cat((torch.zeros_like(self_attn_mask), x0_attn_mask), dim=1)
       self.cross_attn_mask = torch.cat((cross_attn_mask, x0_attn_mask), dim=0)
 
-  def forward(self, indices, sigma, disable_cross_attn=False,
-              output_hidden_states=False, save_kv=False):
-    cross_attn = self.cross_attn and not disable_cross_attn
+  def forward(self, indices, sigma, sample_mode=False,
+              store_kv=False, output_hidden_states=False):
     if not self.config.time_conditioning:
       sigma = torch.zeros_like(sigma)
     all_hidden_states = []
@@ -480,11 +480,13 @@ class DITBackbone(nn.Module):
     if output_hidden_states:
       all_hidden_states.append(x)
     c = F.silu(self.sigma_map(sigma))
-    if cross_attn:
-      cross_attn_mask = self.cross_attn_mask.to(x.device)
-      if save_kv:
-        cross_attn_mask = cross_attn_mask[:x.shape[1], :x.shape[1]]
+    if self.cross_attn:
       rotary_cos_sin = self.rotary_emb(x[:, :self.n])
+      cross_attn_mask = self.cross_attn_mask.to(x.device)
+      # use block-causal mask only during sampling
+      if sample_mode:
+        cross_attn_mask = cross_attn_mask[
+          self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]]
     else:
       cross_attn_mask = None
       rotary_cos_sin = self.rotary_emb(x)
@@ -495,11 +497,12 @@ class DITBackbone(nn.Module):
         rotary_cos_sin,
         c,
         cross_attn_mask=cross_attn_mask,
-        save_kv=save_kv)
+        sample_mode=sample_mode,
+        store_kv=store_kv)
       if output_hidden_states:
         all_hidden_states.append(x)
     logits = self.output_layer(x, c)
-    if cross_attn and not save_kv:
+    if self.cross_attn and not sample_mode:
       logits = logits[:, :self.n]
       all_hidden_states = [hidden_states[:, :self.n] for hidden_states in all_hidden_states]
     return logits, all_hidden_states
@@ -526,7 +529,8 @@ class BD3LM(transformers.PreTrainedModel):
     self,
     input_ids: torch.LongTensor = None,
     timesteps: torch.FloatTensor = None,
-    disable_cross_attn: typing.Optional[bool] = None,
+    sample_mode: typing.Optional[bool] = None,
+    store_kv: typing.Optional[bool] = None,
     output_hidden_states: typing.Optional[bool] = None,
    return_dict: typing.Optional[bool] = None,
   ) -> typing.Union[
@@ -545,8 +549,9 @@ class BD3LM(transformers.PreTrainedModel):
     logits, all_hidden_states = self.backbone(
       indices=input_ids,
       sigma=timesteps,
-      disable_cross_attn=disable_cross_attn,
-      output_hidden_states=output_hidden_states
+      sample_mode=sample_mode,
+      store_kv=store_kv,
+      output_hidden_states=output_hidden_states,
     )
     if return_dict:
       return modeling_outputs.MaskedLMOutput(
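
For orientation, a minimal usage sketch of the new flags follows. It is not part of the commit: the checkpoint id, block length, vocab-size fallback, and the stand-in "noisy" block are illustrative placeholders, and a real sampler would run a full denoising loop; only the forward() keywords (sample_mode, store_kv, return_dict) and the MaskedLMOutput return value come from modeling_bd3lm.py above.

# Usage sketch (assumptions flagged inline); not part of the commit.
import torch
import transformers

model = transformers.AutoModelForMaskedLM.from_pretrained(
  'kuleshov-group/bd3lm-owt-block_size16',  # placeholder repo id
  trust_remote_code=True,
).eval()

block_size = 16                                           # assumed sampling block length
vocab_size = getattr(model.config, 'vocab_size', 50258)   # fall back to a placeholder
prefix = torch.randint(0, vocab_size, (1, block_size))    # an already-decoded context block
sigma = torch.zeros(1)                                    # noise level; zeroed anyway without time conditioning

with torch.no_grad():
  # Cache the finished block: store_kv=True writes its keys/values into the
  # sliding-window kv_cache, which the comment above bounds by the context length.
  model(input_ids=prefix, timesteps=sigma, sample_mode=True, store_kv=True)

  # Score the next (noised) block: sample_mode=True makes the backbone slice a
  # block-causal window out of its precomputed cross-attention mask and skip the
  # logits[:, :self.n] crop, so the logits keep the block's length.
  noisy_block = torch.randint(0, vocab_size, (1, block_size))  # stand-in for a masked block
  out = model(input_ids=noisy_block, timesteps=torch.ones(1),
              sample_mode=True, store_kv=False, return_dict=True)
  print(out.logits.shape)  # expected: (1, block_size, vocab size)

Keeping store_kv independent of sample_mode presumably lets a sampler decide per call whether the tokens it just passed in should enter the cache (e.g., only once a block is fully denoised), while sample_mode alone controls the mask slicing and logit cropping.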