aria-dev commited on
Commit
6a3f5b6
1 Parent(s): fb5de81

sync changes from github

Browse files
Files changed (2) hide show
  1. modeling_aria.py +2 -2
  2. moe_lm.py +2 -2
modeling_aria.py CHANGED
@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple, Union
23
  import torch
24
  import torch.nn as nn
25
  from torch import nn
26
- from transformers import PreTrainedModel
27
  from transformers.modeling_outputs import ModelOutput
28
  from transformers.utils import logging
29
 
@@ -122,7 +122,7 @@ def build_mm_projector(config: AriaConfig):
122
 
123
 
124
  # adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
125
- class AriaForConditionalGeneration(AriaPretrainedModel):
126
  """
127
  Aria model for conditional generation tasks.
128
 
 
23
  import torch
24
  import torch.nn as nn
25
  from torch import nn
26
+ from transformers import GenerationMixin, PreTrainedModel
27
  from transformers.modeling_outputs import ModelOutput
28
  from transformers.utils import logging
29
 
 
122
 
123
 
124
  # adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
125
+ class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
126
  """
127
  Aria model for conditional generation tasks.
128
 
moe_lm.py CHANGED
@@ -25,7 +25,7 @@ import torch
25
  import torch.nn as nn
26
  import torch.nn.functional as F
27
  from torch import nn
28
- from transformers import LlamaConfig
29
  from transformers.models.llama.modeling_llama import (
30
  ACT2FN,
31
  LLAMA_ATTENTION_CLASSES,
@@ -634,7 +634,7 @@ class AriaMoELMModel(LlamaModel):
634
  self.post_init()
635
 
636
 
637
- class AriaMoELMForCausalLM(LlamaForCausalLM):
638
  """
639
  AriaMoE model for causal language modeling tasks.
640
 
 
25
  import torch.nn as nn
26
  import torch.nn.functional as F
27
  from torch import nn
28
+ from transformers import GenerationMixin, LlamaConfig
29
  from transformers.models.llama.modeling_llama import (
30
  ACT2FN,
31
  LLAMA_ATTENTION_CLASSES,
 
634
  self.post_init()
635
 
636
 
637
+ class AriaMoELMForCausalLM(LlamaForCausalLM, GenerationMixin):
638
  """
639
  AriaMoE model for causal language modeling tasks.
640