sync changes from github
- modeling_aria.py +2 -2
- moe_lm.py +2 -2
modeling_aria.py
CHANGED
@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from torch import nn
-from transformers import PreTrainedModel
+from transformers import GenerationMixin, PreTrainedModel
 from transformers.modeling_outputs import ModelOutput
 from transformers.utils import logging
 
@@ -122,7 +122,7 @@ def build_mm_projector(config: AriaConfig):
 
 
 # adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
-class AriaForConditionalGeneration(AriaPretrainedModel):
+class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
     """
     Aria model for conditional generation tasks.
 
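Context for this edit (my reading; the commit itself gives no rationale): recent transformers releases deprecate the implicit GenerationMixin inheritance on PreTrainedModel, so remote-code models that call .generate() are expected to list the mixin explicitly among their bases. A minimal sketch of that pattern, with a hypothetical ToyForConditionalGeneration standing in for the Aria class:

from transformers import GenerationMixin, PreTrainedModel

class ToyForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """Hypothetical stand-in for AriaForConditionalGeneration."""

    def forward(self, input_ids=None, **kwargs):
        ...  # a real model would compute and return logits here

# generate() and its helpers now come from the explicit GenerationMixin base,
# independent of whether PreTrainedModel still provides them implicitly.
print(issubclass(ToyForConditionalGeneration, GenerationMixin))  # True
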
moe_lm.py
CHANGED
@@ -25,7 +25,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import nn
-from transformers import LlamaConfig
+from transformers import GenerationMixin, LlamaConfig
 from transformers.models.llama.modeling_llama import (
     ACT2FN,
     LLAMA_ATTENTION_CLASSES,
@@ -634,7 +634,7 @@ class AriaMoELMModel(LlamaModel):
         self.post_init()
 
 
-class AriaMoELMForCausalLM(LlamaForCausalLM):
+class AriaMoELMForCausalLM(LlamaForCausalLM, GenerationMixin):
    """
     AriaMoE model for causal language modeling tasks.
 
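A usage sketch of the synced remote code after the change; the repo id "rhymes-ai/Aria", the text-only prompt, and the dtype/device settings are assumptions on my part, not taken from the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",           # assumed repo id hosting this remote code
    trust_remote_code=True,     # loads modeling_aria.py / moe_lm.py from the repo
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("rhymes-ai/Aria", trust_remote_code=True)

inputs = tokenizer("Aria is a mixture-of-experts model that", return_tensors="pt").to(model.device)
# generate() resolves through the GenerationMixin base made explicit above.
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))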