Refactor Logits Naming (#15)
Browse files- Refactor Logits Naming (5aed2aa1a7dabed45b57dbde209de31cef94b39f)
Co-authored-by: Xu <[email protected]>
modeling_moonshot_kimia.py
CHANGED
@@ -901,15 +901,15 @@ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
|
|
901 |
else:
|
902 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
903 |
|
904 |
-
|
905 |
-
|
906 |
|
907 |
if not return_dict:
|
908 |
-
output = (
|
909 |
return output
|
910 |
return CausalLMOutputWithPast(
|
911 |
loss=None,
|
912 |
-
logits=(
|
913 |
past_key_values=outputs.past_key_values,
|
914 |
hidden_states=outputs.hidden_states,
|
915 |
attentions=outputs.attentions,
|
|
|
901 |
else:
|
902 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
903 |
|
904 |
+
text_logits = self.lm_head(hidden_states)
|
905 |
+
audio_logits = self.mimo_output(mimo_hidden_states)
|
906 |
|
907 |
if not return_dict:
|
908 |
+
output = (audio_logits, text_logits) + outputs[2:]
|
909 |
return output
|
910 |
return CausalLMOutputWithPast(
|
911 |
loss=None,
|
912 |
+
logits=(audio_logits, text_logits),
|
913 |
past_key_values=outputs.past_key_values,
|
914 |
hidden_states=outputs.hidden_states,
|
915 |
attentions=outputs.attentions,
|