Upload BD3LM
modeling_bd3lm.py  CHANGED  (+23 -18)
@@ -299,7 +299,7 @@ class DDiTBlock(nn.Module):
       return bias_dropout_add_scale_fused_inference


-  def get_qkv(self, x, rotary_cos_sin,
+  def get_qkv(self, x, rotary_cos_sin, store_kv=False):
     # compute qkv (potentially use cache)
     if self.kv_cache is not None:
       block_len = x.shape[1] - self.kv_cache.shape[1]
@@ -308,8 +308,8 @@ class DDiTBlock(nn.Module):
     else:
       qkv = self.attn_qkv(x)

-    #
-    if
+    # store kv cache in a sliding window (can't exceed context len)
+    if store_kv:
       if self.kv_cache is not None:
         cache_len = min(x.shape[1], self.n - block_len)
         self.kv_cache = qkv[:, -cache_len:]
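For context, the sketch below illustrates the sliding-window rule these two hunks introduce: with store_kv enabled, get_qkv keeps at most self.n positions of qkv in self.kv_cache. This is a minimal stand-alone sketch, not the model code; the tensor shape and the concrete values of n, x_len and block_len are illustrative assumptions.

import torch

# Hypothetical sizes; the real cache lives on the DDiTBlock as self.kv_cache.
n = 1024                                  # context length the cache may not exceed (self.n)
x_len = 256                               # length of the current input x
block_len = 16                            # x.shape[1] - kv_cache.shape[1] in the real code
qkv = torch.randn(1, x_len, 3, 12, 64)    # stand-in for the block's qkv tensor

cache_len = min(x_len, n - block_len)     # same rule as in get_qkv above
kv_cache = qkv[:, -cache_len:]            # keep only the most recent cache_len positions
assert kv_cache.shape[1] <= n             # the cache can never exceed the context length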
@@ -347,7 +347,8 @@ class DDiTBlock(nn.Module):
     x = einops.rearrange(x, 'b s h d -> b s (h d)')
     return x

-  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None,
+  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None,
+              sample_mode=False, store_kv=False):
     bias_dropout_scale_fn = self._get_bias_dropout_scale()

     (shift_msa, scale_msa, gate_msa, shift_mlp,
@@ -358,12 +359,12 @@ class DDiTBlock(nn.Module):
     x = modulate_fused(self.norm1(x), shift_msa, scale_msa)

     # get qkvs
-    if cross_attn_mask is not None and not
+    if cross_attn_mask is not None and not sample_mode:
       qkv_x = self.get_qkv(x[:,:self.n], rotary_cos_sin)
       qkv_x0 = self.get_qkv(x[:,self.n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
-      qkv = self.get_qkv(x, rotary_cos_sin,
+      qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)

     if cross_attn_mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
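As a reading aid for this hunk: in the training branch the block's input holds the noisy tokens in positions [:self.n] and the clean tokens x0 in positions [self.n:], so qkv is computed per half and re-concatenated; otherwise (no cross-attention mask, or sample_mode) the whole input goes through get_qkv once, optionally storing the cache. The sketch below only mirrors the concatenation with stand-in tensors; the head and dim sizes are made up.

import torch

# Stand-ins for the two halves of the training input (hypothetical shapes):
# positions [:n] are the noisy tokens x_t, positions [n:] are the clean tokens x_0.
n, heads, head_dim = 4, 2, 8
qkv_xt = torch.randn(1, n, 3, heads, head_dim)   # ~ get_qkv(x[:, :n], rotary_cos_sin)
qkv_x0 = torch.randn(1, n, 3, heads, head_dim)   # ~ get_qkv(x[:, n:], rotary_cos_sin)

# Same concatenation as the training branch above: both streams share one attention call.
qkv = torch.cat((qkv_xt, qkv_x0), dim=1)
print(qkv.shape)   # torch.Size([1, 8, 3, 2, 8])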
@@ -470,9 +471,8 @@ class DITBackbone(nn.Module):
     x0_attn_mask = torch.cat((torch.zeros_like(self_attn_mask), x0_attn_mask), dim=1)
     self.cross_attn_mask = torch.cat((cross_attn_mask, x0_attn_mask), dim=0)

-  def forward(self, indices, sigma,
-
-    cross_attn = self.cross_attn and not disable_cross_attn
+  def forward(self, indices, sigma, sample_mode=False,
+              store_kv=False, output_hidden_states=False):
     if not self.config.time_conditioning:
       sigma = torch.zeros_like(sigma)
     all_hidden_states = []
@@ -480,11 +480,13 @@ class DITBackbone(nn.Module):
     if output_hidden_states:
       all_hidden_states.append(x)
     c = F.silu(self.sigma_map(sigma))
-    if cross_attn:
-      cross_attn_mask = self.cross_attn_mask.to(x.device)
-      if save_kv:
-        cross_attn_mask = cross_attn_mask[:x.shape[1], :x.shape[1]]
+    if self.cross_attn:
       rotary_cos_sin = self.rotary_emb(x[:, :self.n])
+      cross_attn_mask = self.cross_attn_mask.to(x.device)
+      # use block-causal mask only during sampling
+      if sample_mode:
+        cross_attn_mask = cross_attn_mask[
+          self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]]
     else:
       cross_attn_mask = None
       rotary_cos_sin = self.rotary_emb(x)
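The slicing added here can be pictured on a toy mask: self.cross_attn_mask is prebuilt for the concatenated (x_t, x_0) sequence of length 2*self.n, and in sample_mode only the corner offset by self.n that covers the x.shape[1] tokens currently being generated is kept (the block-causal part, per the comment in the diff). The sizes below are illustrative assumptions, not values from the repository.

import torch

# Toy stand-in for the precomputed mask over the concatenated (x_t, x_0) sequence.
n = 8                                                    # block-sequence length (self.n)
full_mask = torch.ones(2 * n, 2 * n, dtype=torch.bool)   # stand-in for self.cross_attn_mask

# In sample_mode, keep only the part covering the current x.shape[1] tokens, offset by n --
# the same slice as cross_attn_mask[self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]] above.
x_len = 5
sampling_mask = full_mask[n:n + x_len, n:n + x_len]
print(sampling_mask.shape)   # torch.Size([5, 5])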
@@ -495,11 +497,12 @@ class DITBackbone(nn.Module):
         rotary_cos_sin,
         c,
         cross_attn_mask=cross_attn_mask,
-
+        sample_mode=sample_mode,
+        store_kv=store_kv)
       if output_hidden_states:
         all_hidden_states.append(x)
     logits = self.output_layer(x, c)
-    if cross_attn and not
+    if self.cross_attn and not sample_mode:
       logits = logits[:, :self.n]
       all_hidden_states = [hidden_states[:, :self.n] for hidden_states in all_hidden_states]
     return logits, all_hidden_states
@@ -526,7 +529,8 @@ class BD3LM(transformers.PreTrainedModel):
       self,
       input_ids: torch.LongTensor = None,
       timesteps: torch.FloatTensor = None,
-
+      sample_mode: typing.Optional[bool] = None,
+      store_kv: typing.Optional[bool] = None,
       output_hidden_states: typing.Optional[bool] = None,
       return_dict: typing.Optional[bool] = None,
   ) -> typing.Union[
@@ -545,8 +549,9 @@ class BD3LM(transformers.PreTrainedModel):
     logits, all_hidden_states = self.backbone(
       indices=input_ids,
       sigma=timesteps,
-
-
+      sample_mode=sample_mode,
+      store_kv=store_kv,
+      output_hidden_states=output_hidden_states,
     )
     if return_dict:
       return modeling_outputs.MaskedLMOutput(
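Taken together, the new keywords are threaded from BD3LM.forward through DITBackbone.forward into each DDiTBlock. Below is a hedged usage sketch, not code from the repository: the checkpoint path is a placeholder, loading via AutoModelForMaskedLM is an assumption based on the MaskedLMOutput return type, and the input shapes are guesses.

import torch
import transformers

MODEL_ID = "path/to/a-bd3lm-checkpoint"   # placeholder, not a real repo id
model = transformers.AutoModelForMaskedLM.from_pretrained(
    MODEL_ID, trust_remote_code=True)

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))  # assumes config exposes vocab_size
timesteps = torch.zeros(1)

# Training / scoring call: no sampling flags, cross-attention path as before.
out = model(input_ids=input_ids, timesteps=timesteps, return_dict=True)

# Sampling-style call: sample_mode selects the block-causal mask slice and
# store_kv asks each block to keep its qkv in the sliding-window cache.
out = model(input_ids=input_ids, timesteps=timesteps,
            sample_mode=True, store_kv=True, return_dict=True)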