# earica-audio-1b / configuration_gemma3_omni.py
# Uploaded by voidful ("Upload configuration_gemma3_omni.py", commit dca6d25, verified).
from typing import Optional, Union, Dict, Any
from transformers import Gemma3TextConfig, SiglipVisionConfig, PretrainedConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities; used below to report
# when a sub-config is missing and a default one is constructed instead.
logger = logging.get_logger(__name__)
class Gemma3OmniConfig(PretrainedConfig):
    r"""
    Composite configuration pairing a `Gemma3TextConfig` with a `SiglipVisionConfig`,
    plus special-token ids for image and audio inputs.

    Args:
        text_config (`Gemma3TextConfig` or `dict`, *optional*):
            Text sub-config. A `dict` is expanded via `Gemma3TextConfig(**text_config)`;
            `None` produces a default `Gemma3TextConfig()` (an info message is logged).
        vision_config (`SiglipVisionConfig` or `dict`, *optional*):
            Vision sub-config, coerced the same way using `SiglipVisionConfig`.
        mm_tokens_per_image (`int`, defaults to 256):
            Stored as-is; presumably the number of tokens an image expands to —
            confirm against the modeling code.
        boi_token_index (`int`, defaults to 255_999):
            Stored as-is; also readable/writable as `boi_token_id`.
        eoi_token_index (`int`, defaults to 256_000):
            Stored as-is; also readable/writable as `eoi_token_id`.
        image_token_index (`int`, defaults to 262_144):
            Stored as-is; also readable/writable as `image_token_id`.
        audio_token_index (`int`, defaults to 262_151):
            Stored as-is; also readable/writable as `audio_token_id`.
        initializer_range (`float`, defaults to 0.02):
            Stored as-is; presumably the std used for weight initialisation — verify
            against the model's init code.
        kwargs:
            Forwarded to `PretrainedConfig.__init__` after the attributes above are set.
    """

    model_type = "gemma3omni"
    # `*_id` aliases resolve to the `*_index` attributes stored on the instance.
    attribute_map = {
        "image_token_id": "image_token_index",
        "audio_token_id": "audio_token_index",
        "boi_token_id": "boi_token_index",
        "eoi_token_id": "eoi_token_index",
    }
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
    }

    @staticmethod
    def _coerce_sub_config(value: Any, config_cls: type, default_notice: str) -> Any:
        """Return ``value`` as a ``config_cls`` instance.

        ``None`` -> a fresh ``config_cls()`` (then ``default_notice`` is logged at info
        level); ``dict`` -> ``config_cls(**value)``; anything else is returned unchanged.
        """
        if value is None:
            fallback = config_cls()
            logger.info(default_notice)
            return fallback
        if isinstance(value, dict):
            return config_cls(**value)
        return value

    def __init__(
        self,
        text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
        vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        image_token_index: int = 262_144,
        audio_token_index: int = 262_151,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # Resolve both sub-configs before touching `self`, so nothing is assigned
        # if either constructor fails.
        resolved_text = self._coerce_sub_config(
            text_config,
            Gemma3TextConfig,
            "text_config is None, using default Gemma3TextConfig text config.",
        )
        resolved_vision = self._coerce_sub_config(
            vision_config,
            SiglipVisionConfig,
            "vision_config is None, using default SiglipVisionConfig vision config.",
        )

        self.text_config = resolved_text
        self.vision_config = resolved_vision
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        # Base-class initialisation runs last, after the model-specific attributes exist.
        super().__init__(**kwargs)