kcz358 committed on
Commit
36071d9
·
verified ·
1 Parent(s): 3d429c6

Upload configuration_aero.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_aero.py +73 -0
configuration_aero.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.models.auto import CONFIG_MAPPING, AutoConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class AeroConfig(PretrainedConfig):
    """Composite configuration holding a text config and an audio config.

    Args:
        text_config: Configuration of the text backbone. A ``dict`` is turned
            into a config object through ``CONFIG_MAPPING``; when it has no
            ``"model_type"`` key, ``"qwen2"`` is written into the dict (the
            caller's dict is modified in place). ``None`` loads the config
            of ``"Qwen/Qwen2.5-1.5B-Instruct"`` via
            ``AutoConfig.from_pretrained``. Any other value is stored as-is.
        audio_config: Configuration of the audio encoder. A ``dict`` is
            handled like ``text_config``, with ``"qwen2_audio_encoder"`` as
            the fallback ``"model_type"``. ``None`` builds a
            ``"qwen2_audio_encoder"`` config from the fixed values listed in
            ``__init__``. Any other value is stored as-is.
        audio_token_index: Stored unchanged on the instance. Presumably the
            token id that marks audio positions in the input -- confirm
            against the modeling code.
        tie_word_embeddings: Forwarded to ``PretrainedConfig.__init__``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "aero"
    sub_configs = {"text_config": AutoConfig, "audio_config": AutoConfig}

    def __init__(
        self,
        text_config=None,
        audio_config=None,
        audio_token_index=151648,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.audio_token_index = audio_token_index

        # --- text backbone ---------------------------------------------
        if text_config is None:
            # NOTE(review): this resolves a named checkpoint, so it likely
            # needs Hub access or a local cache -- confirm.
            text_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
        elif isinstance(text_config, dict):
            # setdefault writes the fallback into the caller's dict, exactly
            # like the explicit assignment it replaces.
            text_model_type = text_config.setdefault("model_type", "qwen2")
            text_config = CONFIG_MAPPING[text_model_type](**text_config)
        self.text_config = text_config

        # --- audio encoder ---------------------------------------------
        if audio_config is None:
            default_audio_kwargs = {
                "d_model": 1280,
                "encoder_attention_heads": 20,
                "encoder_ffn_dim": 5120,
                "encoder_layerdrop": 0.0,
                "encoder_layers": 32,
                "num_mel_bins": 128,
                "max_source_positions": 1500,
                "scale_embedding": False,
                "activation_function": "gelu",
            }
            audio_config = CONFIG_MAPPING["qwen2_audio_encoder"](
                **default_audio_kwargs
            )
        elif isinstance(audio_config, dict):
            audio_model_type = audio_config.setdefault(
                "model_type", "qwen2_audio_encoder"
            )
            audio_config = CONFIG_MAPPING[audio_model_type](**audio_config)
        self.audio_config = audio_config

        # Base-class initialisation runs last, after the sub-configs are set.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)