File size: 2,383 Bytes
05d640e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional


@dataclass(frozen=True)
class TextConfig:
    """Hyperparameters of the text (language-model) component."""

    dim: int = 2048  # hidden/embedding dimension
    n_layers: int = 24  # number of transformer layers
    vocab_size: int = 51200
    max_context: int = 2048  # maximum sequence length in tokens
    n_heads: int = 32  # attention heads per layer
    prefix_attn: int = 730  # NOTE(review): presumably the length of a prefix that attends bidirectionally — confirm against the attention code


@dataclass(frozen=True)
class VisionConfig:
    """Hyperparameters of the vision encoder and its projection head."""

    enc_dim: int = 1152  # encoder hidden dimension
    enc_patch_size: int = 14  # square patch edge in pixels
    enc_n_layers: int = 27  # encoder transformer layers
    enc_ff_dim: int = 4304  # encoder feed-forward (MLP) width
    enc_n_heads: int = 16  # encoder attention heads
    proj_out_dim: int = 2048  # projector output size (matches TextConfig.dim — presumably feeds the text model; verify)
    crop_size: int = 378  # input crop edge in pixels
    in_channels: int = 3  # RGB input
    max_crops: int = 12  # NOTE(review): looks like an upper bound on image tiles per input — confirm
    overlap_margin: int = 4  # NOTE(review): presumably crop-overlap margin, units unclear from here — confirm
    proj_inner_dim: int = 8192  # projector hidden width


@dataclass(frozen=True)
class RegionConfig:
    """Hyperparameters of the region (coordinate/size prediction) head."""

    dim: int = 2048  # input dimension (matches TextConfig.dim — presumably the text hidden state; verify)
    coord_feat_dim: int = 256  # coordinate feature encoding size
    coord_out_dim: int = 1024  # coordinate output size
    size_feat_dim: int = 512  # size feature encoding size
    size_out_dim: int = 2048  # size output size
    inner_dim: int = 8192  # hidden width of the head's MLP


@dataclass(frozen=True)
class TokenizerConfig:
    """Tokenizer special ids and pre-tokenized prompt templates.

    The template values are literal token-id sequences inserted around task
    inputs (NOTE(review): ids look like GPT-2 BPE, e.g. 50256 ==
    '<|endoftext|>' and 198 == '\\n' — confirm against the tokenizer).
    """

    bos_id: int = 50256  # beginning-of-sequence token id (same id as eos)
    eos_id: int = 50256  # end-of-sequence token id
    # Per-task prompt templates, keyed by task name. `default_factory` is
    # required because a dict is a mutable default.
    templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
        default_factory=lambda: {
            "caption": {
                "short": [198, 198, 16438, 8305, 25],
                "normal": [198, 198, 24334, 1159, 25],
            },
            "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
            "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
            "point": {"prefix": [198, 198, 12727, 25], "suffix": [628]},
        }
    )


@dataclass(frozen=True)
class MoondreamConfig:
    """Top-level model configuration aggregating all component configs.

    Defaults use ``field(default_factory=...)`` so each instance gets its own
    default sub-config objects instead of one shared class-level instance
    (the original shared-instance defaults only worked because the
    sub-configs are frozen and hashable).
    """

    text: TextConfig = field(default_factory=TextConfig)
    vision: VisionConfig = field(default_factory=VisionConfig)
    region: RegionConfig = field(default_factory=RegionConfig)
    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)

    @classmethod
    def from_dict(cls, config_dict: dict) -> "MoondreamConfig":
        """Build a config from a (possibly partial) nested dict.

        Missing sections fall back to that sub-config's defaults; unknown
        keys inside a section raise TypeError from the dataclass constructor.
        """
        return cls(
            text=TextConfig(**config_dict.get("text", {})),
            vision=VisionConfig(**config_dict.get("vision", {})),
            region=RegionConfig(**config_dict.get("region", {})),
            tokenizer=TokenizerConfig(**config_dict.get("tokenizer", {})),
        )

    def to_dict(self) -> dict:
        """Return a nested-dict representation (inverse of ``from_dict``).

        Uses ``dataclasses.asdict`` so callers receive fresh copies; the
        original returned each sub-config's live ``__dict__``, which a caller
        could mutate and thereby silently corrupt a "frozen" config.
        """
        return {
            "text": asdict(self.text),
            "vision": asdict(self.vision),
            "region": asdict(self.region),
            "tokenizer": asdict(self.tokenizer),
        }