Sync changes from GitHub
Browse files
- modeling_aria.py +2 -2
- moe_lm.py +2 -2
modeling_aria.py
CHANGED
|
@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple, Union
|
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
| 25 |
from torch import nn
|
| 26 |
-
from transformers import PreTrainedModel
|
| 27 |
from transformers.modeling_outputs import ModelOutput
|
| 28 |
from transformers.utils import logging
|
| 29 |
|
|
@@ -122,7 +122,7 @@ def build_mm_projector(config: AriaConfig):
|
|
| 122 |
|
| 123 |
|
| 124 |
# adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
|
| 125 |
-
class AriaForConditionalGeneration(AriaPretrainedModel):
|
| 126 |
"""
|
| 127 |
Aria model for conditional generation tasks.
|
| 128 |
|
|
|
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
| 25 |
from torch import nn
|
| 26 |
+
from transformers import GenerationMixin, PreTrainedModel
|
| 27 |
from transformers.modeling_outputs import ModelOutput
|
| 28 |
from transformers.utils import logging
|
| 29 |
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
# adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
|
| 125 |
+
class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
|
| 126 |
"""
|
| 127 |
Aria model for conditional generation tasks.
|
| 128 |
|
moe_lm.py
CHANGED
|
@@ -25,7 +25,7 @@ import torch
|
|
| 25 |
import torch.nn as nn
|
| 26 |
import torch.nn.functional as F
|
| 27 |
from torch import nn
|
| 28 |
-
from transformers import LlamaConfig
|
| 29 |
from transformers.models.llama.modeling_llama import (
|
| 30 |
ACT2FN,
|
| 31 |
LLAMA_ATTENTION_CLASSES,
|
|
@@ -634,7 +634,7 @@ class AriaMoELMModel(LlamaModel):
|
|
| 634 |
self.post_init()
|
| 635 |
|
| 636 |
|
| 637 |
-
class AriaMoELMForCausalLM(LlamaForCausalLM):
|
| 638 |
"""
|
| 639 |
AriaMoE model for causal language modeling tasks.
|
| 640 |
|
|
|
|
| 25 |
import torch.nn as nn
|
| 26 |
import torch.nn.functional as F
|
| 27 |
from torch import nn
|
| 28 |
+
from transformers import GenerationMixin, LlamaConfig
|
| 29 |
from transformers.models.llama.modeling_llama import (
|
| 30 |
ACT2FN,
|
| 31 |
LLAMA_ATTENTION_CLASSES,
|
|
|
|
| 634 |
self.post_init()
|
| 635 |
|
| 636 |
|
| 637 |
+
class AriaMoELMForCausalLM(LlamaForCausalLM, GenerationMixin):
|
| 638 |
"""
|
| 639 |
AriaMoE model for causal language modeling tasks.
|
| 640 |
|