Source code for aitlas.models.unet
"""UNet model for segmentation"""
import segmentation_models_pytorch as smp
from ..base import BaseSegmentationClassifier
class Unet(BaseSegmentationClassifier):
    """U-Net segmentation model backed by ``segmentation_models_pytorch``.

    The network is an ``smp.Unet`` with a ``resnet50`` encoder. The number of
    output classes comes from ``config.num_classes``. Whether the encoder
    starts from ImageNet weights is controlled by ``config.pretrained``.

    .. note:: Based on <https://github.com/qubvel/segmentation_models.pytorch>
    """

    def __init__(self, config):
        """Build the wrapped ``smp.Unet`` from *config*.

        :param config: model configuration. This method reads
            ``pretrained`` (truthy selects ImageNet encoder weights) and
            ``num_classes`` from ``self.config`` after the base-class
            initializer has run.
        """
        super().__init__(config)

        # Only the encoder is initialised from pretrained weights; a falsy
        # ``pretrained`` leaves the encoder randomly initialised (None).
        encoder_weights = None
        if self.config.pretrained:
            encoder_weights = "imagenet"

        self.model = smp.Unet(
            encoder_name="resnet50",
            encoder_weights=encoder_weights,
            classes=self.config.num_classes,
        )

    def forward(self, x):
        """Run the wrapped ``smp.Unet`` on *x* and return its output unchanged.

        NOTE(review): *x* is presumably an image batch shaped
        (batch, channels, height, width). Confirm this against the callers
        and the smp encoder's expected input.
        """
        output = self.model(x)
        return output