# Scrape artifacts from a code-viewer page (file size, commit hash,
# line-number gutter) — commented out so the module parses as Python.
# File size: 2,383 Bytes
# 05d640e
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional
@dataclass(frozen=True)
class TextConfig:
    """Hyperparameters of the text (language-model) side of the model.

    Frozen so instances are immutable and hashable, which also makes them
    safe to use as dataclass defaults elsewhere in this file.
    """

    dim: int = 2048           # model width; matches VisionConfig.proj_out_dim and RegionConfig.dim
    n_layers: int = 24        # number of transformer layers
    vocab_size: int = 51200
    max_context: int = 2048   # maximum context length in tokens
    n_heads: int = 32         # attention heads
    prefix_attn: int = 730    # NOTE(review): presumably the length of a bidirectional
                              # attention prefix (e.g. image tokens) — confirm against the model code
@dataclass(frozen=True)
class VisionConfig:
    """Hyperparameters of the vision encoder and its projection MLP."""

    enc_dim: int = 1152        # encoder hidden width
    enc_patch_size: int = 14   # square patch edge; 378 / 14 = 27 patches per side
    enc_n_layers: int = 27
    enc_ff_dim: int = 4304     # encoder feed-forward inner width
    enc_n_heads: int = 16
    proj_out_dim: int = 2048   # projection output; matches TextConfig.dim
    crop_size: int = 378       # input crop edge in pixels
    in_channels: int = 3       # RGB
    max_crops: int = 12        # NOTE(review): presumably an upper bound on tiled
                               # crops per image — confirm against the cropping code
    overlap_margin: int = 4
    proj_inner_dim: int = 8192 # projection MLP hidden width
@dataclass(frozen=True)
class RegionConfig:
    """Hyperparameters of the region (coordinate/size prediction) head."""

    dim: int = 2048            # matches TextConfig.dim
    coord_feat_dim: int = 256  # coordinate feature width
    coord_out_dim: int = 1024  # coordinate output width
    size_feat_dim: int = 512   # size feature width
    size_out_dim: int = 2048   # size output width
    inner_dim: int = 8192      # MLP hidden width
@dataclass(frozen=True)
class TokenizerConfig:
    """Tokenizer special ids and pre-tokenized prompt templates.

    bos/eos share id 50256 — the GPT-2 ``<|endoftext|>`` id, so this
    presumably targets a GPT-2-family BPE tokenizer (verify against the
    tokenizer actually loaded at runtime).

    ``templates`` maps a task name to pre-encoded token-id fragments.
    Note the asymmetry: "query"/"detect"/"point" use prefix/suffix keys,
    while "caption" keys two alternative prompts ("short"/"normal").
    A ``default_factory`` is required here because a dict is a mutable
    default and would otherwise be rejected by the dataclass machinery.
    """

    bos_id: int = 50256
    eos_id: int = 50256
    templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
        default_factory=lambda: {
            "caption": {
                "short": [198, 198, 16438, 8305, 25],
                "normal": [198, 198, 24334, 1159, 25],
            },
            "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
            "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
            "point": {"prefix": [198, 198, 12727, 25], "suffix": [628]},
        }
    )
@dataclass(frozen=True)
class MoondreamConfig:
    """Top-level configuration bundling text, vision, region, and tokenizer
    sub-configs, with dict (de)serialization helpers."""

    # The sub-configs are frozen (immutable, hashable) dataclasses, so
    # sharing single default instances across MoondreamConfig instances
    # is safe.
    text: TextConfig = TextConfig()
    vision: VisionConfig = VisionConfig()
    region: RegionConfig = RegionConfig()
    tokenizer: TokenizerConfig = TokenizerConfig()

    @classmethod
    def from_dict(cls, config_dict: dict) -> "MoondreamConfig":
        """Build a config from a (possibly partial) nested dict.

        Missing sections fall back to each sub-config's defaults; unknown
        keys within a section raise TypeError via the dataclass constructor.
        """
        return cls(
            text=TextConfig(**config_dict.get("text", {})),
            vision=VisionConfig(**config_dict.get("vision", {})),
            region=RegionConfig(**config_dict.get("region", {})),
            tokenizer=TokenizerConfig(**config_dict.get("tokenizer", {})),
        )

    def to_dict(self) -> dict:
        """Return a nested plain-dict snapshot of this config.

        Uses dataclasses.asdict, which recursively deep-copies container
        fields. The previous implementation returned each sub-config's live
        ``__dict__``, leaking a mutable reference to the tokenizer's
        ``templates`` dict — callers could mutate the nominally frozen
        config through the returned value.
        """
        return asdict(self)
# (stray line-number-gutter artifact removed)