JonasGeiping committed on
Commit ddc2bd9 · verified · 1 Parent(s): e28a94f

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +116 -37
raven_modeling_minimal.py CHANGED
@@ -11,7 +11,7 @@ from .raven_config_minimal import RavenConfig
  from transformers.cache_utils import Cache, DynamicCache
 
  ###################### Huggingface Glue code I ##################################################################
- from transformers import PreTrainedModel
+ from transformers import PreTrainedModel, GenerationMixin
  from transformers.utils import ModelOutput
  from transformers.generation.utils import GenerateDecoderOnlyOutput
 
@@ -32,7 +32,8 @@ class RavenPreTrainedModel(PreTrainedModel):
  _supports_static_cache = False
 
  def _init_weights(self, module):
- print("Random Initialization not implemented.")
+ if not torch.rand((1,)).is_meta:
+ print("Random Initialization not implemented.")
 
 
  @dataclass
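Note on the `_init_weights` change above: the new guard only prints the warning when the freshly sampled tensor is not a meta tensor, presumably to keep instantiation on the meta device (e.g. low-memory loading) quiet. A minimal sketch of the check itself, assuming PyTorch 2.x where a device can be used as a context manager; this is illustrative and not part of the commit:

import torch

# Under the meta device, tensors carry no storage and report .is_meta == True,
# so the "Random Initialization not implemented." message would be suppressed.
with torch.device("meta"):
    print(torch.rand((1,)).is_meta)  # True -> print is skipped

print(torch.rand((1,)).is_meta)  # False on a regular device -> message would be printed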
@@ -70,7 +71,7 @@ class RMSNorm(torch.nn.Module):
 
 
  class HuginnDynamicCache(DynamicCache):
- def __init__(self, lookup_strategy: str = "latest") -> None:
+ def __init__(self, lookup_strategy: str = "full") -> None:
  super().__init__()
  self._seen_tokens = 0
  self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
@@ -89,6 +90,14 @@ class HuginnDynamicCache(DynamicCache):
  lookup_strategy: Optional[str] = None,
  ) -> tuple[torch.Tensor, torch.Tensor]:
  lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
+ if "compress-" in self.lookup_strategy and step_idx > 1: # hardcode for current model!
+ compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
+ if "compress-s" in self.lookup_strategy:
+ new_step_idx = (step_idx - 2) % compression_stage + 2
+ else:
+ new_step_idx = (step_idx - 2) // compression_stage + 2
+ # @ print(step_idx, new_step_idx, compression_stage)
+ step_idx = new_step_idx
  # Init
  if step_idx not in self.key_cache:
  self.key_cache[step_idx] = {}
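The `compress-` branch added here folds recurrent step indices (steps >= 2; steps 0 and 1 are the prelude) onto fewer cache slots, either cyclically ("compress-s<N>", shared slots) or by grouping consecutive steps (the non-"s" variant). A standalone sketch of the same index arithmetic; the name "compress-d4" for the division variant is an assumption, since only the "compress-s" prefix is spelled out in the code:

def remap_step(step_idx: int, lookup_strategy: str) -> int:
    # Mirrors the remapping above; prelude steps (< 2) are never remapped.
    if "compress-" not in lookup_strategy or step_idx <= 1:
        return step_idx
    compression_stage = int(lookup_strategy.split("compress-")[1][1:])
    if "compress-s" in lookup_strategy:  # share slots cyclically
        return (step_idx - 2) % compression_stage + 2
    return (step_idx - 2) // compression_stage + 2  # group consecutive steps

print([remap_step(s, "compress-s4") for s in range(2, 10)])  # [2, 3, 4, 5, 2, 3, 4, 5]
print([remap_step(s, "compress-d4") for s in range(2, 10)])  # [2, 2, 2, 2, 3, 3, 3, 3]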
@@ -98,32 +107,49 @@ class HuginnDynamicCache(DynamicCache):
  self._seen_tokens += key_states.shape[-2]
  # Add entries to cache
  for idx, entry in enumerate(key_states.unbind(dim=-2)):
- assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
+ if "compress-" not in self.lookup_strategy:
+ assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
  # print(f"Overwrote cache entry for step_idx {step_idx}") # likely the head
  self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
  for idx, entry in enumerate(value_states.unbind(dim=-2)):
  self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
 
  # Materialize past state based on lookup strategy:
- if len(self.key_cache[step_idx]) == self._seen_tokens:
+ if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full":
  # All entries are present, materialize cache as normal
  return (
  torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
  torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
  )
  else: # some entries where not previously computed
- if lookup_strategy == "latest":
+ # if lookup_strategy.startswith("latest"):
+ # latest_keys = []
+ # latest_values = []
+ # for token_pos in range(self._seen_tokens):
+ # # Find the latest step that has this token position
+ # max_step = max((s for s in range(step_idx + 1) if token_pos in self.key_cache[s]), default=None)
+ # if max_step is None:
+ # raise ValueError(f"No cache entry found for token position {token_pos}")
+ # latest_keys.append(self.key_cache[max_step][token_pos])
+ # latest_values.append(self.value_cache[max_step][token_pos])
+ # return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
+ if lookup_strategy.startswith("latest-m4"):
  latest_keys = []
  latest_values = []
  for token_pos in range(self._seen_tokens):
- # Find the latest step that has this token position
- max_step = max((s for s in range(step_idx + 1) if token_pos in self.key_cache[s]), default=None)
+ # For steps >= 2, use modulo 4
+ if step_idx >= 2:
+ # Find valid steps for this token position
+ valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
+ max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
+ else:
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
  if max_step is None:
  raise ValueError(f"No cache entry found for token position {token_pos}")
  latest_keys.append(self.key_cache[max_step][token_pos])
  latest_values.append(self.value_cache[max_step][token_pos])
  return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
- elif lookup_strategy == "skip":
+ elif lookup_strategy.startswith("skip"):
  existing_keys = []
  existing_values = []
  for token_pos in range(self._seen_tokens):
@@ -131,15 +157,22 @@ class HuginnDynamicCache(DynamicCache):
  existing_keys.append(self.key_cache[step_idx][token_pos])
  existing_values.append(self.value_cache[step_idx][token_pos])
  return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
- elif lookup_strategy == "randomized": # sanity check
+ elif lookup_strategy.startswith("randomized"): # sanity check
  rand_keys = []
  rand_values = []
  for token_pos in range(self._seen_tokens):
- # Find steps that have this token position
- steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
- rand_step = steps[torch.randint(len(steps), (1,))]
- rand_keys.append(self.key_cache[rand_step][token_pos])
- rand_values.append(self.value_cache[rand_step][token_pos])
+ if step_idx < 2: # For prelude steps
+ max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
+ else: # Get all steps from same block position
+ curr_modulo = (step_idx - 2) % 4 + 2
+ valid_steps = [
+ s
+ for s in range(2, step_idx + 1)
+ if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s]
+ ]
+ max_step = valid_steps[torch.randint(len(valid_steps), (1,))]
+ rand_keys.append(self.key_cache[max_step][token_pos])
+ rand_values.append(self.value_cache[max_step][token_pos])
  return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
  else:
  raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
@@ -153,6 +186,18 @@ class HuginnDynamicCache(DynamicCache):
  def get_seq_length(self, step_idx: int = 0) -> int:
  return self._seen_tokens
 
+ def get_memory_usage(self) -> float:
+ total_bytes = 0
+ # For each recurrent step/layer index
+ for step_idx in self.key_cache:
+ # Get the sequence cache for this step
+ key_seq_cache = self.key_cache[step_idx]
+ for seq_idx in key_seq_cache:
+ key_tensor = key_seq_cache[seq_idx]
+ # Add memory for of key tensors, assuming value is the same
+ total_bytes += key_tensor.nelement() * key_tensor.element_size()
+ return total_bytes * 2 / (1024 * 1024)
+
 
  class CausalSelfAttention(torch.nn.Module):
  def __init__(self, config: RavenConfig) -> None:
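The new `get_memory_usage` walks the per-step key caches, sums tensor bytes, doubles the total to account for the value caches, and reports MiB. A rough standalone check of that arithmetic on a toy cache entry (shapes and dtype are illustrative):

import torch

key_cache = {0: {0: torch.zeros(8, 128, dtype=torch.float16)}}  # step_idx -> {token_pos: key tensor}
total_bytes = sum(
    t.nelement() * t.element_size()
    for pos_cache in key_cache.values()
    for t in pos_cache.values()
)
print(total_bytes * 2 / (1024 * 1024))  # keys + values in MiB: 8 * 128 * 2 bytes * 2 / 2**20 ≈ 0.0039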
@@ -265,7 +310,7 @@ class SandwichBlock(torch.nn.Module):
  return x, attn_map
 
 
- class RavenForCausalLM(RavenPreTrainedModel):
+ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
  def __init__(
  self,
  config: RavenConfig,
@@ -323,7 +368,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
  "return_latents": True,
  "return_attention": False,
  "return_head": False,
- "return_stats": True,
+ "return_stats": False,
  },
  use_cache: bool = False,
  cache_position: Optional[torch.Tensor] = None,
@@ -351,7 +396,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
  # Non-recurrent prelude
  for block_idx, block in enumerate(self.transformer.prelude):
  input_embeds, attn_map = block(
- input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
+ input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn=return_attn
  )
  attn_maps[block_idx] = attn_map
 
@@ -365,12 +410,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
  past_key_values,
  num_steps,
  attn_maps,
+ return_attn=return_attn,
  )
  latent_states = x.clone().detach()
 
  # Coda layers
  for block_idx, block in enumerate(self.transformer.coda, start=1):
- x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
+ x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn=return_attn)
  attn_maps[-block_idx] = attn_map
  x = self.transformer.ln_f(x)
 
@@ -407,6 +453,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
  past_key_values: Optional[Cache] = None,
  num_steps: Optional[torch.Tensor] = None,
  attn_maps: dict = {},
+ return_attn: bool = False,
  ):
  x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
  if num_steps is None:
@@ -424,13 +471,13 @@ class RavenForCausalLM(RavenPreTrainedModel):
  for step in range(num_steps_no_grad):
  xk = x
  x, block_idx, attn_maps = self.core_block_forward(
- xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
  )
 
  for step in range(num_steps_with_grad):
  xk = x
  x, block_idx, attn_maps = self.core_block_forward(
- xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps, return_attn
  )
  return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
 
@@ -443,10 +490,11 @@ class RavenForCausalLM(RavenPreTrainedModel):
  past_key_values,
  block_idx: Union[torch.Tensor, int],
  attn_maps: dict = {},
+ return_attn: bool = False,
  ):
  x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
  for idx, block in enumerate(self.transformer.core_block, start=1):
- x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=len(attn_maps) > 0)
+ x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
  attn_maps[block_idx + idx] = attn_map
  return x, block_idx + idx, attn_maps
 
@@ -579,7 +627,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
  model_inputs["cache_position"] = cache_position
  current_input_length = input_ids.shape[1]
  if past_key_values is not None:
- if type(past_key_values) == DynamicCache:
+ if type(past_key_values) != HuginnDynamicCache:
  # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
  assert past_key_values.get_seq_length() == 0
  past_key_values = HuginnDynamicCache()
@@ -599,6 +647,18 @@ class RavenForCausalLM(RavenPreTrainedModel):
  model_inputs[key] = value
  return model_inputs
 
+ @torch.no_grad()
+ def generate(self, *args, **kwargs):
+ """Dispatcher - use HF generate in all normal cases."""
+ if any(
+ k in kwargs
+ for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
+ ):
+ print("Dispatching to custom generate function call")
+ return self.generate_with_adaptive_compute(*args, **kwargs)
+ else:
+ return super().generate(*args, **kwargs)
+
  @torch.no_grad()
  def generate_minimal(
  self,
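With `GenerationMixin` added to the class bases and the `generate` dispatcher above, a plain call runs the stock Hugging Face generation loop, while passing any of the adaptive-compute keywords routes to `generate_with_adaptive_compute`. A hedged usage sketch: the checkpoint id below is assumed (any checkpoint shipping this modeling file via `trust_remote_code` behaves the same), and only keyword names that appear in this diff are used:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tomg-group-umd/huginn-0125"  # assumed checkpoint name, for illustration only
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The capital of Westphalia is", return_tensors="pt")

# No adaptive-compute kwargs -> falls through to GenerationMixin.generate.
out = model.generate(inputs.input_ids, max_new_tokens=16)

# Any adaptive-compute kwarg -> dispatched to generate_with_adaptive_compute.
out = model.generate(inputs.input_ids, criterion="entropy-diff", exit_threshold="auto")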
@@ -693,6 +753,7 @@ class RavenForCausalLM(RavenPreTrainedModel):
  continuous_compute=False, # warm-start state / continuous CoT
  latent_dampening=False,
  criterion="entropy-diff",
+ exit_threshold: Union[str, float, int] = "auto",
  cache_kwargs: dict = {},
  **model_kwargs,
  ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
@@ -725,46 +786,64 @@ class RavenForCausalLM(RavenPreTrainedModel):
  # Prep criterions:
  if criterion == "entropy-diff":
  entropy = torch.tensor(100.0, device=input_ids.device)
+ exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
  elif criterion in ["latent-diff", "none"]:
- pass
- elif criterion == "kl":
+ exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
+ elif "kl" in criterion:
  V = self.config.padded_vocab_size
  log_probs = (1 / V * torch.ones(V, device=input_ids.device)).log()
+ if criterion == "minp-kl":
+ exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
+ else:
+ exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
  elif criterion == "argmax-stability":
  stable_for_n_steps = 0
  current_argmax = torch.tensor(-1, dtype=torch.long, device=input_ids.device)
+ exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
  else:
  raise ValueError("Invalid adaptive compute strategy.")
 
  all_latents = []
- for compute_step in range(1, model_inputs["num_steps"]):
+ exit_values = []
+ for compute_step in range(model_inputs["num_steps"]):
  prev_latents = current_latents.clone()
  current_latents, block_idx, _ = self.iterate_one_step(
  embedded_inputs, current_latents, block_idx=block_idx, **aux_inputs
  )
  all_latents.append(current_latents if latent_dampening else None)
- if compute_step > 1 and step > 0: # do not exit in prefill:
+ if step > 0: # do not exit in prefill:
  if criterion == "entropy-diff":
  prev_entropy = entropy.clone()
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
  probs = F.softmax(outputs.logits[:, -1, :], dim=-1) # type: ignore
  entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1).mean()
  entropy_diff = (entropy - prev_entropy).abs()
- if entropy_diff < 1e-3:
- compute_steps.append([compute_step, entropy_diff.item()])
+ exit_values.append(entropy_diff.item())
+ if entropy_diff < exit_threshold:
  break
  elif criterion == "latent-diff":
- norm_diff = (prev_latents - current_latents).norm()
- if norm_diff < 1:
- compute_steps.append([compute_step, norm_diff.item()])
+ norm_diff = (prev_latents - current_latents).norm() / current_latents.norm()
+ exit_values.append(norm_diff.item())
+ if norm_diff < exit_threshold:
  break
  elif criterion == "kl":
  prev_log_probs = log_probs.clone()
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
  log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1) # type: ignore
  kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
- if kl < 2e-4:
- compute_steps.append([compute_step, kl.item()])
+ exit_values.append(kl.item())
+ if kl < exit_threshold:
+ break
+ elif criterion == "minp-kl":
+ prev_log_probs = log_probs.clone()
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
+ probs = F.softmax(outputs.logits[:, -1, :], dim=-1) # type: ignore
+ probs[probs < 0.1 * probs.max()] = 1 / V
+ probs = probs / probs.sum()
+ log_probs = probs.log()
+ kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
+ exit_values.append(kl.item())
+ if kl < exit_threshold:
  break
  elif criterion == "argmax-stability":
  prev_argmax = current_argmax.clone()
@@ -774,19 +853,19 @@ class RavenForCausalLM(RavenPreTrainedModel):
  stable_for_n_steps += 1
  else:
  stable_for_n_steps = 0
- if stable_for_n_steps >= 10:
- compute_steps.append([compute_step, stable_for_n_steps])
+ exit_values.append(stable_for_n_steps)
+ if stable_for_n_steps >= exit_threshold:
  break
  elif criterion == "none":
  pass
 
  else:
- compute_steps.append([compute_step, float("NaN")])
  if not latent_dampening:
  outputs = self.predict_from_latents(current_latents, **aux_inputs)
  else:
  dampened_latents = torch.sum(torch.cat(all_latents, dim=0), dim=0, keepdim=True)
  outputs = self.predict_from_latents(dampened_latents, **aux_inputs)
+ compute_steps.append([compute_step + 1, exit_values])
 
  next_token_logits = outputs.logits[0, -1, :] # type: ignore
  if continuous_compute: # Save last latent
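For reference, every exit criterion above compares consecutive recurrence steps and stops once the measured change drops below `exit_threshold`; with "auto" the thresholds from this diff are 1e-3 (entropy-diff), 0.03 (relative latent-diff), 5e-4 (kl), 1e-6 (minp-kl), and 5 consecutive stable steps (argmax-stability). A self-contained sketch of the entropy-diff and kl tests on dummy logits (vocabulary size is illustrative, not the model's):

import torch
import torch.nn.functional as F

logits_prev, logits_curr = torch.randn(1, 65536), torch.randn(1, 65536)

# entropy-diff: absolute change of next-token entropy between two recurrence steps
def entropy(p):
    return -torch.sum(p * torch.log(p + 1e-10), dim=-1).mean()

probs_prev, probs_curr = F.softmax(logits_prev, dim=-1), F.softmax(logits_curr, dim=-1)
entropy_diff = (entropy(probs_curr) - entropy(probs_prev)).abs()
print(bool(entropy_diff < 1e-3))  # True would trigger an exit under the "auto" threshold

# kl: divergence between the current and previous next-token distributions
log_prev, log_curr = F.log_softmax(logits_prev, dim=-1), F.log_softmax(logits_curr, dim=-1)
kl = F.kl_div(log_curr, log_prev, reduction="none", log_target=True).sum(dim=-1)
print(bool(kl < 5e-4))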
 