bigmoyan codecho commited on
Commit
f8ba0d0
·
verified ·
1 Parent(s): 4b4b7bf

Refactor Logits Naming (#15)

Browse files

- Refactor Logits Naming (5aed2aa1a7dabed45b57dbde209de31cef94b39f)


Co-authored-by: Xu <[email protected]>

Files changed (1) hide show
  1. modeling_moonshot_kimia.py +4 -4
modeling_moonshot_kimia.py CHANGED
@@ -901,15 +901,15 @@ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
901
  else:
902
  hidden_states, mimo_hidden_states = outputs[0], outputs[1]
903
 
904
- audio_logits = self.lm_head(hidden_states)
905
- text_logits = self.mimo_output(mimo_hidden_states)
906
 
907
  if not return_dict:
908
- output = (text_logits, audio_logits) + outputs[2:]
909
  return output
910
  return CausalLMOutputWithPast(
911
  loss=None,
912
- logits=(text_logits, audio_logits),
913
  past_key_values=outputs.past_key_values,
914
  hidden_states=outputs.hidden_states,
915
  attentions=outputs.attentions,
 
901
  else:
902
  hidden_states, mimo_hidden_states = outputs[0], outputs[1]
903
 
904
+ text_logits = self.lm_head(hidden_states)
905
+ audio_logits = self.mimo_output(mimo_hidden_states)
906
 
907
  if not return_dict:
908
+ output = (audio_logits, text_logits) + outputs[2:]
909
  return output
910
  return CausalLMOutputWithPast(
911
  loss=None,
912
+ logits=(audio_logits, text_logits),
913
  past_key_values=outputs.past_key_values,
914
  hidden_states=outputs.hidden_states,
915
  attentions=outputs.attentions,