voitl commited on
Commit
823e567
·
1 Parent(s): 5c15c1d

Upload HFUnetPlusPlus

Browse files
Files changed (4) hide show
  1. config.json +21 -0
  2. hf_config.py +17 -0
  3. hf_model.py +21 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "HFUnetPlusPlus"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "hf_config.UnetConfig",
7
+ "AutoModelForImageSegmentation": "hf_model.HFUnetPlusPlus"
8
+ },
9
+ "decoder_channels": [
10
+ 1024,
11
+ 512,
12
+ 256,
13
+ 128,
14
+ 64
15
+ ],
16
+ "encoder_name": "resnet18",
17
+ "input_channels": 1,
18
+ "num_classes": 16,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.24.0"
21
+ }
hf_config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class UnetConfig(PretrainedConfig):
4
+
5
+ def __init__(
6
+ self,
7
+ encoder_name: str = "resnet18",
8
+ num_classes: int = 16,
9
+ input_channels: int = 1,
10
+ decoder_channels: tuple = (1024, 512, 256, 128, 64),
11
+ **kwargs
12
+ ):
13
+ self.encoder_name = encoder_name
14
+ self.num_classes = num_classes
15
+ self.input_channels = input_channels
16
+ self.decoder_channels = decoder_channels
17
+ super().__init__(**kwargs)
hf_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import segmentation_models_pytorch as smp
2
+ from hf_config import UnetConfig
3
+ from transformers import PreTrainedModel
4
+
5
+
6
+ class HFUnetPlusPlus(PreTrainedModel):
7
+ config_class = UnetConfig
8
+
9
+ def __init__(self, config):
10
+ super().__init__(config)
11
+
12
+ self.model = smp.UnetPlusPlus(
13
+ encoder_name=config.encoder_name,
14
+ encoder_weights="imagenet",
15
+ decoder_channels=config.decoder_channels,
16
+ in_channels=config.input_channels,
17
+ classes=config.num_classes,
18
+ decoder_attention_type="scse")
19
+
20
+ def forward(self, tensor):
21
+ return self.model(tensor)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11241ff685300ecf0d1314eabc7bf74223d2071bae43b7b3ce7859a01f599efd
3
+ size 164204733