from transformers import PretrainedConfig | |
from typing import List | |
class CetaceanClassifierConfig(PretrainedConfig): | |
model_type = "cetaceanet" | |
def __init__( | |
self, | |
# block_type="bottleneck", | |
# layers: List[int] = [3, 4, 6, 3], | |
# num_classes: int = 1000, | |
# input_channels: int = 3, | |
# cardinality: int = 1, | |
# base_width: int = 64, | |
# stem_width: int = 64, | |
# stem_type: str = "", | |
# avg_down: bool = False, | |
**kwargs, | |
): | |
# if block_type not in ["basic", "bottleneck"]: | |
# raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.") | |
# if stem_type not in ["", "deep", "deep-tiered"]: | |
# raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.") | |
# self.block_type = block_type | |
# self.layers = layers | |
# self.num_classes = num_classes | |
# self.input_channels = input_channels | |
# self.cardinality = cardinality | |
# self.base_width = base_width | |
# self.stem_width = stem_width | |
# self.stem_type = stem_type | |
# self.avg_down = avg_down | |
super().__init__(**kwargs) | |