| from transformers import PreTrainedModel | |
| from .unet import UNet | |
| from .unet_config import UNetConfig | |
class UNetModel(PreTrainedModel):
    """Expose the project's ``UNet`` network as a ``transformers`` model.

    The actual network lives in ``self.model``. Its constructor arguments
    are taken one-for-one from attributes of the configuration object.
    ``config_class`` tells ``transformers`` which configuration type this
    model is paired with.
    """

    config_class = UNetConfig

    def __init__(self, config: UNetConfig):
        """Create the wrapped ``UNet`` from ``config``.

        Args:
            config: Must provide ``in_channels``, ``out_channels``, ``pad``,
                ``bilinear`` and ``normalization``. They are forwarded
                unchanged as keyword arguments to ``UNet``.
        """
        super().__init__(config)
        # Gather the UNet hyper-parameters from the config in one place.
        unet_kwargs = {
            "in_channels": config.in_channels,
            "out_channels": config.out_channels,
            "pad": config.pad,
            "bilinear": config.bilinear,
            "normalization": config.normalization,
        }
        self.model = UNet(**unet_kwargs)
        # NOTE(review): this constructor never calls ``self.post_init()``.
        # Models shipped with transformers typically call it at the end of
        # __init__. Confirm that skipping it here is intentional.

    def forward(self, x):
        """Pass ``x`` through the wrapped ``UNet`` and return its output as-is.

        NOTE(review): the expected type and shape of ``x`` are defined by
        ``UNet``, which is not visible here. Presumably it is an image
        tensor; confirm against ``UNet.forward``.
        """
        output = self.model(x)
        return output