Ole-Christian Galbo Engstrøm commited on
Commit
becf17e
·
1 Parent(s): 9e3752d

Cleanup and update README.

Browse files
Files changed (5) hide show
  1. README.md +9 -1
  2. config.json +0 -9
  3. requirements.txt +0 -1
  4. unet_config.py +0 -21
  5. unet_hf.py +0 -20
README.md CHANGED
@@ -16,7 +16,7 @@ pipeline_tag: image-segmentation
16
  This repository contains an implementation of U-Net [[1]](#references). [unet.py](./unet.py) implements the class UNet. The implementation has been tested with PyTorch 2.7.1 and CUDA 12.6.
17
  ![](./images/unet_diagram.png)
18
 
19
- You can also load the U-Net from PyTorch Hub.
20
  ```python
21
  import torch
22
 
@@ -32,6 +32,14 @@ model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, in_channels=3,
32
  # model = torch.hub.load('sm00thix/unet', 'unet_transconv', **kwargs) # Convenience function equivalent to torch.hub.load('sm00thix/unet', 'unet', bilinear=False, **kwargs)
33
  ```
34
 
 
 
 
 
 
 
 
 
35
  ## Options
36
  The UNet class provides the following options for customization.
37
 
 
16
  This repository contains an implementation of U-Net [[1]](#references). [unet.py](./unet.py) implements the class UNet. The implementation has been tested with PyTorch 2.7.1 and CUDA 12.6.
17
  ![](./images/unet_diagram.png)
18
 
19
+ You can load the U-Net from PyTorch Hub.
20
  ```python
21
  import torch
22
 
 
32
  # model = torch.hub.load('sm00thix/unet', 'unet_transconv', **kwargs) # Convenience function equivalent to torch.hub.load('sm00thix/unet', 'unet', bilinear=False, **kwargs)
33
  ```
34
 
35
+ You can also clone this repository to access the U-Net directly.
36
+ ```python
37
+ import torch
38
+ from unet import UNet
39
+
40
+ model = UNet(in_channels=3, out_channels=1, pad=True, bilinear=True, normalization=None)
41
+ ```
42
+
43
  ## Options
44
  The UNet class provides the following options for customization.
45
 
config.json DELETED
@@ -1,9 +0,0 @@
1
- {
2
- "model_type": "unet",
3
- "architectures": ["UNet"],
4
- "in_channels": 3,
5
- "out_channels": 1,
6
- "pad": true,
7
- "bilinear": true,
8
- "normalization": null
9
- }
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1 @@
1
  torch >= 2.7.1
2
- transformers >= 4.55.2
 
1
  torch >= 2.7.1
 
unet_config.py DELETED
@@ -1,21 +0,0 @@
1
- from transformers import PretrainedConfig
2
-
3
-
4
- class UNetConfig(PretrainedConfig):
5
- model_type = "unet"
6
-
7
- def __init__(
8
- self,
9
- in_channels=3,
10
- out_channels=1,
11
- pad=True,
12
- bilinear=True,
13
- normalization=None,
14
- **kwargs,
15
- ):
16
- super().__init__(**kwargs)
17
- self.in_channels = in_channels
18
- self.out_channels = out_channels
19
- self.pad = pad
20
- self.bilinear = bilinear
21
- self.normalization = normalization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unet_hf.py DELETED
@@ -1,20 +0,0 @@
1
- from transformers import PreTrainedModel
2
- from .unet import UNet
3
- from .unet_config import UNetConfig
4
-
5
-
6
- class UNetModel(PreTrainedModel):
7
- config_class = UNetConfig
8
-
9
- def __init__(self, config: UNetConfig):
10
- super().__init__(config)
11
- self.model = UNet(
12
- in_channels=config.in_channels,
13
- out_channels=config.out_channels,
14
- pad=config.pad,
15
- bilinear=config.bilinear,
16
- normalization=config.normalization,
17
- )
18
-
19
- def forward(self, x):
20
- return self.model(x)