from transformers import PreTrainedModel

from .unet import UNet
from .unet_config import UNetConfig


class UNetModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a :class:`UNet`.

    The wrapped network lives on ``self.model`` and is constructed from the
    fields of a :class:`UNetConfig`. ``config_class`` tells
    ``PreTrainedModel`` which configuration class this model is paired with.
    """

    config_class = UNetConfig

    def __init__(self, config: UNetConfig) -> None:
        """Build the underlying UNet from ``config``.

        Args:
            config: Supplies ``in_channels``, ``out_channels``, ``pad``,
                ``bilinear`` and ``normalization``, which are forwarded
                unchanged as keyword arguments to the ``UNet`` constructor.
        """
        super().__init__(config)
        unet_options = {
            "in_channels": config.in_channels,
            "out_channels": config.out_channels,
            "pad": config.pad,
            "bilinear": config.bilinear,
            "normalization": config.normalization,
        }
        # The attribute name ``model`` becomes the prefix of every parameter
        # key in the state dict, so renaming it would break saved checkpoints.
        self.model = UNet(**unet_options)

    def forward(self, x):
        """Run the wrapped UNet on ``x`` and return its output unchanged.

        NOTE(review): ``x`` is presumably a batch of images with
        ``config.in_channels`` channels — confirm against ``UNet.forward``.
        """
        prediction = self.model(x)
        return prediction