Upload BD3LM
Browse files- modeling_bd3lm.py +2 -4
    	
        modeling_bd3lm.py
    CHANGED
    
    | @@ -299,7 +299,6 @@ class DDiTBlock(nn.Module): | |
| 299 | 
             
                else:
         | 
| 300 | 
             
                  return bias_dropout_add_scale_fused_inference
         | 
| 301 |  | 
| 302 | 
            -
             | 
| 303 | 
             
              def get_qkv(self, x, rotary_cos_sin, store_kv=False):
         | 
| 304 | 
             
                # compute qkv (potentially use cache)
         | 
| 305 | 
             
                if self.kv_cache is not None:
         | 
| @@ -307,11 +306,10 @@ class DDiTBlock(nn.Module): | |
| 307 | 
             
                  qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
         | 
| 308 | 
             
                else:
         | 
| 309 | 
             
                  qkv = self.attn_qkv(x)
         | 
| 310 | 
            -
                
         | 
| 311 | 
             
                # store kv cache in a sliding window (can't exceed context len)
         | 
| 312 | 
             
                if store_kv:
         | 
| 313 | 
            -
                  self.kv_cache = qkv
         | 
| 314 | 
            -
             | 
| 315 | 
             
                qkv = einops.rearrange(
         | 
| 316 | 
             
                  qkv,
         | 
| 317 | 
             
                  'b s (three h d) -> b s three h d',
         | 
|  | |
| 299 | 
             
                else:
         | 
| 300 | 
             
                  return bias_dropout_add_scale_fused_inference
         | 
| 301 |  | 
|  | |
| 302 | 
             
              def get_qkv(self, x, rotary_cos_sin, store_kv=False):
         | 
| 303 | 
             
                # compute qkv (potentially use cache)
         | 
| 304 | 
             
                if self.kv_cache is not None:
         | 
|  | |
| 306 | 
             
                  qkv = torch.cat((self.kv_cache, new_qkv), dim=1)
         | 
| 307 | 
             
                else:
         | 
| 308 | 
             
                  qkv = self.attn_qkv(x)
         | 
|  | |
| 309 | 
             
                # store kv cache in a sliding window (can't exceed context len)
         | 
| 310 | 
             
                if store_kv:
         | 
| 311 | 
            +
                  self.kv_cache = qkv[:, -(self.n-self.block_size):]
         | 
| 312 | 
            +
             | 
| 313 | 
             
                qkv = einops.rearrange(
         | 
| 314 | 
             
                  qkv,
         | 
| 315 | 
             
                  'b s (three h d) -> b s three h d',
         | 
