root
		
	committed on
		
		
					Commit 
							
							·
						
						3fee777
	
1
								Parent(s):
							
							0604e3c
								
fix
Browse files- configuration_sa2va_chat.py +6 -10
    	
        configuration_sa2va_chat.py
    CHANGED
    
    | @@ -19,7 +19,6 @@ logger = logging.get_logger(__name__) | |
| 19 |  | 
| 20 | 
             
            class Sa2VAChatConfig(PretrainedConfig):
         | 
| 21 | 
             
                model_type = 'sa2va_chat'
         | 
| 22 | 
            -
                is_composition = True
         | 
| 23 |  | 
| 24 | 
             
                def __init__(
         | 
| 25 | 
             
                        self,
         | 
| @@ -40,25 +39,22 @@ class Sa2VAChatConfig(PretrainedConfig): | |
| 40 | 
             
                        **kwargs):
         | 
| 41 | 
             
                    super().__init__(**kwargs)
         | 
| 42 | 
             
                    if vision_config is None:
         | 
| 43 | 
            -
                        vision_config = {}
         | 
| 44 | 
             
                        logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
         | 
| 45 |  | 
| 46 | 
             
                    if llm_config is None:
         | 
| 47 | 
            -
                        llm_config = {}
         | 
| 48 | 
             
                        logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
         | 
| 49 |  | 
| 50 | 
             
                    self.vision_config = InternVisionConfig(**vision_config)
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                    if llm_config['architectures'][0] == 'LlamaForCausalLM':
         | 
| 53 | 
             
                        self.llm_config = LlamaConfig(**llm_config)
         | 
| 54 | 
            -
                    elif llm_config['architectures'][0] == 'InternLM2ForCausalLM': | 
| 55 | 
             
                        self.llm_config = InternLM2Config(**llm_config)
         | 
| 56 | 
            -
                    elif llm_config['architectures'][0] == 'Phi3ForCausalLM': | 
| 57 | 
            -
                        self.llm_config = Phi3Config(**llm_config)
         | 
| 58 | 
            -
                    elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
         | 
| 59 | 
             
                        self.llm_config = Qwen2Config(**llm_config)
         | 
| 60 | 
             
                    else:
         | 
| 61 | 
            -
                        raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0])) | 
| 62 | 
             
                    self.use_backbone_lora = use_backbone_lora
         | 
| 63 | 
             
                    self.use_llm_lora = use_llm_lora
         | 
| 64 | 
             
                    self.pad2square = pad2square
         | 
|  | |
| 19 |  | 
| 20 | 
             
            class Sa2VAChatConfig(PretrainedConfig):
         | 
| 21 | 
             
                model_type = 'sa2va_chat'
         | 
|  | |
| 22 |  | 
| 23 | 
             
                def __init__(
         | 
| 24 | 
             
                        self,
         | 
|  | |
| 39 | 
             
                        **kwargs):
         | 
| 40 | 
             
                    super().__init__(**kwargs)
         | 
| 41 | 
             
                    if vision_config is None:
         | 
| 42 | 
            +
                        vision_config = {'architectures': ['InternVisionModel']}
         | 
| 43 | 
             
                        logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
         | 
| 44 |  | 
| 45 | 
             
                    if llm_config is None:
         | 
| 46 | 
            +
                        llm_config = {'architectures': ['InternLM2ForCausalLM']}
         | 
| 47 | 
             
                        logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
         | 
| 48 |  | 
| 49 | 
             
                    self.vision_config = InternVisionConfig(**vision_config)
         | 
| 50 | 
            +
                    if llm_config.get('architectures')[0] == 'LlamaForCausalLM':
         | 
|  | |
| 51 | 
             
                        self.llm_config = LlamaConfig(**llm_config)
         | 
| 52 | 
            +
                    elif llm_config.get('architectures')[0] == 'InternLM2ForCausalLM':
         | 
| 53 | 
             
                        self.llm_config = InternLM2Config(**llm_config)
         | 
| 54 | 
            +
                    elif llm_config.get('architectures')[0] == 'Qwen2ForCausalLM':
         | 
|  | |
|  | |
| 55 | 
             
                        self.llm_config = Qwen2Config(**llm_config)
         | 
| 56 | 
             
                    else:
         | 
| 57 | 
            +
                        raise ValueError('Unsupported architecture: {}'.format(llm_config.get('architectures')[0]))
         | 
| 58 | 
             
                    self.use_backbone_lora = use_backbone_lora
         | 
| 59 | 
             
                    self.use_llm_lora = use_llm_lora
         | 
| 60 | 
             
                    self.pad2square = pad2square
         | 
