|
from transformers import PretrainedConfig
|
|
|
|
|
|
class SatDINOConfig(PretrainedConfig):
    """Configuration container for models registered under ``model_type = "satdino"``.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        img_size: Stored as ``self.img_size``. When ``None`` (the default) a
            fresh ``[224]`` list is created for this instance.
            NOTE(review): presumably the input image size(s) in pixels -- confirm
            against the model code.
        patch_size: Stored as ``self.patch_size`` (default ``16``).
        in_chans: Stored as ``self.in_chans`` (default ``3``).
        num_classes: Stored as ``self.num_classes`` (default ``0``).
        embed_dim: Stored as ``self.embed_dim`` (default ``768``).
        depth: Stored as ``self.depth`` (default ``12``).
        num_heads: Stored as ``self.num_heads`` (default ``12``).
        mlp_ratio: Stored as ``self.mlp_ratio`` (default ``4.0``).
        qkv_bias: Stored as ``self.qkv_bias`` (default ``False``).
        qk_scale: Stored as ``self.qk_scale`` (default ``None``).
        drop_rate: Stored as ``self.drop_rate`` (default ``0.0``).
        attn_drop_rate: Stored as ``self.attn_drop_rate`` (default ``0.0``).
        drop_path_rate: Stored as ``self.drop_path_rate`` (default ``0.0``).
        norm_layer: Stored as ``self.norm_layer`` (default ``1e-6``).
            NOTE(review): despite the name, the default is a float; it looks
            like a normalization epsilon rather than a layer -- confirm with
            the consuming model code.
        use_xformers: Stored as ``self.use_xformers`` (default ``False``).
        pos_encoding_method: Stored as ``self.pos_encoding_method``
            (default ``"learnable"``).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "satdino"

    def __init__(
        self,
        img_size=None,
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer=1e-6,
        use_xformers=False,
        pos_encoding_method="learnable",
        **kwargs
    ):
        # A list literal as the default would be shared by every instance
        # created without an explicit ``img_size``; build a new list per
        # instance so mutating one config cannot leak into another.
        self.img_size = [224] if img_size is None else img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.norm_layer = norm_layer
        self.use_xformers = use_xformers
        self.pos_encoding_method = pos_encoding_method

        # Called last, as in the original, so the base class handles the
        # remaining keyword arguments after the model-specific fields are set.
        super().__init__(**kwargs)
|
|
|
|
|