root committed on
Commit
3fee777
·
1 Parent(s): 0604e3c
Files changed (1) hide show
  1. configuration_sa2va_chat.py +6 -10
configuration_sa2va_chat.py CHANGED
@@ -19,7 +19,6 @@ logger = logging.get_logger(__name__)
19
 
20
  class Sa2VAChatConfig(PretrainedConfig):
21
  model_type = 'sa2va_chat'
22
- is_composition = True
23
 
24
  def __init__(
25
  self,
@@ -40,25 +39,22 @@ class Sa2VAChatConfig(PretrainedConfig):
40
  **kwargs):
41
  super().__init__(**kwargs)
42
  if vision_config is None:
43
- vision_config = {}
44
  logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
45
 
46
  if llm_config is None:
47
- llm_config = {}
48
  logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
49
 
50
  self.vision_config = InternVisionConfig(**vision_config)
51
-
52
- if llm_config['architectures'][0] == 'LlamaForCausalLM':
53
  self.llm_config = LlamaConfig(**llm_config)
54
- elif llm_config['architectures'][0] == 'InternLM2ForCausalLM':
55
  self.llm_config = InternLM2Config(**llm_config)
56
- elif llm_config['architectures'][0] == 'Phi3ForCausalLM':
57
- self.llm_config = Phi3Config(**llm_config)
58
- elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
59
  self.llm_config = Qwen2Config(**llm_config)
60
  else:
61
- raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
62
  self.use_backbone_lora = use_backbone_lora
63
  self.use_llm_lora = use_llm_lora
64
  self.pad2square = pad2square
 
19
 
20
  class Sa2VAChatConfig(PretrainedConfig):
21
  model_type = 'sa2va_chat'
 
22
 
23
  def __init__(
24
  self,
 
39
  **kwargs):
40
  super().__init__(**kwargs)
41
  if vision_config is None:
42
+ vision_config = {'architectures': ['InternVisionModel']}
43
  logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
44
 
45
  if llm_config is None:
46
+ llm_config = {'architectures': ['InternLM2ForCausalLM']}
47
  logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
48
 
49
  self.vision_config = InternVisionConfig(**vision_config)
50
+ if llm_config.get('architectures')[0] == 'LlamaForCausalLM':
 
51
  self.llm_config = LlamaConfig(**llm_config)
52
+ elif llm_config.get('architectures')[0] == 'InternLM2ForCausalLM':
53
  self.llm_config = InternLM2Config(**llm_config)
54
+ elif llm_config.get('architectures')[0] == 'Qwen2ForCausalLM':
 
 
55
  self.llm_config = Qwen2Config(**llm_config)
56
  else:
57
+ raise ValueError('Unsupported architecture: {}'.format(llm_config.get('architectures')[0]))
58
  self.use_backbone_lora = use_backbone_lora
59
  self.use_llm_lora = use_llm_lora
60
  self.pad2square = pad2square