|
from typing import Optional, Union, Dict, Any |
|
|
|
from transformers import Gemma3TextConfig, SiglipVisionConfig, PretrainedConfig |
|
from transformers.utils import logging |
|
|
|
# Module-level logger (transformers logging wrapper); Gemma3OmniConfig uses it
# to report when a sub-config falls back to its default.
logger = logging.get_logger(name=__name__)
|
|
|
|
|
class Gemma3OmniConfig(PretrainedConfig):
    """Composite configuration for the ``gemma3omni`` model type.

    Holds a text sub-config (``Gemma3TextConfig``) and a vision sub-config
    (``SiglipVisionConfig``) together with special-token indices and a few
    model-wide settings. Only an audio *token index* is stored here; no audio
    sub-config is declared in ``sub_configs``.

    Args:
        text_config: A ``Gemma3TextConfig`` or a dict of its constructor
            kwargs. ``None`` builds a default ``Gemma3TextConfig()`` and logs
            an info message. Any other object is stored unchanged.
        vision_config: Same contract as *text_config*, for
            ``SiglipVisionConfig``.
        mm_tokens_per_image: Stored as-is; presumably the number of
            placeholder tokens each image expands to (confirm in model code).
        boi_token_index: Stored as-is; also reachable as ``boi_token_id``
            through ``attribute_map``.
        eoi_token_index: Stored as-is; alias ``eoi_token_id``.
        image_token_index: Stored as-is; alias ``image_token_id``.
        audio_token_index: Stored as-is; alias ``audio_token_id``.
        initializer_range: Stored as-is; presumably the std-dev used by the
            model code for weight initialisation.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "gemma3omni"

    # Alias name -> real attribute name; PretrainedConfig redirects attribute
    # access on the alias to the real attribute.
    attribute_map = {
        "image_token_id": "image_token_index",
        "audio_token_id": "audio_token_index",
        "boi_token_id": "boi_token_index",
        "eoi_token_id": "eoi_token_index",
    }

    # Nested config attributes and the config class each one holds.
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
    }

    @staticmethod
    def _resolve_sub_config(candidate: Any, config_cls: type, default_notice: str) -> Any:
        """Coerce *candidate* into a *config_cls* instance where possible.

        ``None`` becomes ``config_cls()`` (built first, then *default_notice*
        is logged at info level); a ``dict`` becomes ``config_cls(**candidate)``;
        anything else is returned untouched.
        """
        if candidate is None:
            fallback = config_cls()
            logger.info(default_notice)
            return fallback
        if isinstance(candidate, dict):
            return config_cls(**candidate)
        return candidate

    def __init__(
        self,
        text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
        vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        image_token_index: int = 262_144,
        audio_token_index: int = 262_151,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # Resolve text before vision so any default-config log lines keep
        # their order.
        resolved_text = self._resolve_sub_config(
            text_config,
            Gemma3TextConfig,
            "text_config is None, using default Gemma3TextConfig text config.",
        )
        resolved_vision = self._resolve_sub_config(
            vision_config,
            SiglipVisionConfig,
            "vision_config is None, using default SiglipVisionConfig vision config.",
        )

        self.text_config = resolved_text
        self.vision_config = resolved_vision
        self.mm_tokens_per_image = mm_tokens_per_image
        # Special-token indices; each is also exposed as ``*_token_id``
        # through ``attribute_map``.
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        # Base-class initialisation runs last, after the model-specific
        # fields above have been assigned.
        super().__init__(**kwargs)
|
|