akhauriyash committed
Commit 3658e2c · 1 Parent(s): ded7955
Files changed (2):
  1. config.json +4 -4
  2. modeling_llama_butler.py +6 -6
config.json CHANGED
@@ -1,14 +1,14 @@
 {
   "architectures": [
-    "modeling_llama_butler.LlamaButlerForCausalLM"
+    "LlamaButlerForCausalLM"
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
   "attn_reduce_factor": 8,
   "auto_map": {
-    "AutoConfig": "modeling_llama_butler.LlamaButlerConfig",
-    "AutoModel": "modeling_llama_butler.LlamaButlerForCausalLM",
-    "AutoModelForCausalLM": "modeling_llama_butler.LlamaButlerForCausalLM"
+    "AutoConfig": "modeling_llama_butler:LlamaButlerConfig",
+    "AutoModel": "modeling_llama_butler:LlamaButlerForCausalLM",
+    "AutoModelForCausalLM": "modeling_llama_butler:LlamaButlerForCausalLM"
   },
   "bos_token_id": 128000,
   "dDash": 16,
modeling_llama_butler.py CHANGED
@@ -1266,12 +1266,12 @@ class LlamaAttentionExperimental(nn.Module):
         else:
             self.head_importances = torch.cat([self.head_importances, head_importances], dim=1)

-        if self.layer_idx == 31:
-            if q_len == 1:
-                self.dtok += 1
-                print(f"Primary Key-Value Shape: {past_key_value.predictor_primary_key[0].shape}, Importance: {past_key_value.predictor_importance_key[0].shape}, Tok-Decoded: {self.dtok}")
-            else:
-                self.dtok = 0
+        # if self.layer_idx == 31:
+        #     if q_len == 1:
+        #         self.dtok += 1
+        #         print(f"Primary Key-Value Shape: {past_key_value.predictor_primary_key[0].shape}, Importance: {past_key_value.predictor_importance_key[0].shape}, Tok-Decoded: {self.dtok}")
+        #     else:
+        #         self.dtok = 0

         if not output_attentions:
             attn_weights = None
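This hunk comments out the per-token debug print that fired on layer 31 during decoding. If that tracing is needed again, a common alternative is to keep it behind an environment flag and route it through logging instead of toggling comments; the sketch below follows that pattern, and the flag name and helper function are hypothetical, not part of this repo.

    # Generic pattern (not from this repo): gate per-step debug output behind an env flag.
    import logging
    import os

    logger = logging.getLogger("llama_butler.debug")
    DEBUG_KV = os.getenv("BUTLER_DEBUG_KV", "0") == "1"  # hypothetical flag name

    def log_kv_shapes(layer_idx, q_len, past_key_value, dtok):
        """Log predictor KV-cache shapes once per decoded token on the last layer."""
        if not DEBUG_KV or layer_idx != 31:
            return dtok
        if q_len == 1:
            dtok += 1
            logger.debug(
                "Primary Key-Value Shape: %s, Importance: %s, Tok-Decoded: %d",
                tuple(past_key_value.predictor_primary_key[0].shape),
                tuple(past_key_value.predictor_importance_key[0].shape),
                dtok,
            )
            return dtok
        return 0  # reset the decoded-token counter on prefill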