"""Source code for aitlas.models.unet_efficientnet."""

import copy
import glob
import os
import shutil
from functools import partial
from math import ceil
from multiprocessing import Pool, Queue

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from geffnet.conv2d_layers import select_conv2d
from geffnet.efficientnet_builder import (
    BN_EPS_TF_DEFAULT,
    EfficientNetBuilder,
    decode_arch_def,
    initialize_weight_default,
    initialize_weight_goog,
    resolve_bn_args,
    round_channels,
)

try:
    import gdal
except ModuleNotFoundError as err:
    from osgeo import gdal

from rasterio import features
from shapely.geometry import shape
from shapely.wkt import dumps
from skimage import io, measure
from skimage.segmentation import watershed
from solaris.eval.base import Evaluator
from torch.hub import load_state_dict_from_url
from tqdm import tqdm

from ..base import BaseSegmentationClassifier
from ..datasets import SpaceNet6Dataset
from ..models.schemas import UNetEfficientNetModelSchema


def post_process(prediction_directory, prediction_csv):
    """Convert every predicted mask in a directory into building polygons saved as CSV.

    Each file in ``prediction_directory`` is processed by ``post_process_single`` in a
    worker pool. The resulting per-tile data frames are concatenated and written to
    ``prediction_csv``. Nothing is written when the directory contains no files.

    :param prediction_directory: folder containing the predicted masks (one file per tile)
    :param prediction_csv: destination path of the proposal CSV
    """
    # The sigmoid in post_process_single overflows for large negative logits; the
    # result (probability 0) is correct, so the warning is silenced.
    np.seterr(over="ignore")
    source_files = sorted(glob.glob(os.path.join(prediction_directory, "*")))
    if not source_files:
        return
    with Pool() as pool:
        # ``imap`` keeps results in the sorted file order, so the written CSV is
        # identical from run to run (``imap_unordered`` shuffled the rows).
        proposals = list(
            tqdm(pool.imap(post_process_single, source_files), total=len(source_files))
        )
    if proposals:
        pd.concat(proposals).to_csv(prediction_csv, index=False)
def post_process_single(
    sourcefile,
    watershed_line=True,
    conn=2,
    polygon_buffer=0.5,
    tolerance=0.5,
    seed_msk_th=0.75,
    area_th_for_seed=110,
    prediction_threshold=0.5,
    area_th=80,
    contact_weight=1.0,
    edge_weight=0.0,
    seed_contact_weight=1.0,
    seed_edge_weight=1.0,
):
    """Turn one predicted logit raster into a table of building polygons.

    Steps:

    1. The raster's bands are read with GDAL and passed through a sigmoid.
    2. Band 0 is attenuated by band 2 (scaled by ``contact_weight``) and band 1
       (scaled by ``edge_weight``).
    3. Seeds are derived from band 0. They are attenuated again with
       ``seed_contact_weight`` / ``seed_edge_weight`` and thresholded at
       ``seed_msk_th``. Their connected components are labelled, and components
       smaller than ``area_th_for_seed`` pixels are discarded.
    4. A watershed on ``-band0`` grows the seeds inside ``band0 > prediction_threshold``.
    5. Each resulting instance is vectorised and buffered by ``polygon_buffer``.
       Instances with area below ``area_th`` are dropped; the rest are simplified
       with ``tolerance``.

    :param sourcefile: path of the raster; the tile name is the last four
        ``_``-separated parts of its file stem
    :param watershed_line: passed to ``skimage.segmentation.watershed``
    :param conn: connectivity used for all labelling steps
    :return: data frame with columns ``ImageId``, ``BuildingId``, ``PolygonWKT_Pix``
        and ``Confidence`` (always 1), one row per polygon
    """
    mask = gdal.Open(sourcefile).ReadAsArray()
    # Local error state: worker processes started with "spawn" do not inherit the
    # parent's np.seterr, and exp overflow here is harmless (probability -> 0).
    with np.errstate(over="ignore"):
        mask = 1.0 / (1 + np.exp(-mask))
    mask[0] = mask[0] * (1 - contact_weight * mask[2]) * (1 - edge_weight * mask[1])
    seed_msk = (
        mask[0] * (1 - seed_contact_weight * mask[2]) * (1 - seed_edge_weight * mask[1])
    )
    seed_msk = measure.label((seed_msk > seed_msk_th), connectivity=conn, background=0)

    # Remove seeds smaller than area_th_for_seed pixels in a single pass:
    # bincount gives the pixel count of every label (label 0 is background).
    areas = np.bincount(seed_msk.ravel())
    too_small = areas < area_th_for_seed
    too_small[0] = False
    seed_msk[too_small[seed_msk]] = 0
    seed_msk = measure.label(seed_msk, connectivity=conn, background=0)

    labels = watershed(
        -mask[0],
        seed_msk,
        mask=(mask[0] > prediction_threshold),
        watershed_line=watershed_line,
    )
    # int32 rather than uint8: with more than 255 instances uint8 wraps around,
    # dropping every 256th building (label -> 0) and aliasing the others.
    labels = measure.label(labels, connectivity=conn, background=0).astype("int32")

    polygons = []
    for geometry, _ in features.shapes(labels, mask=labels > 0):
        polygon = shape(geometry).buffer(polygon_buffer)
        if polygon.area >= area_th:
            polygons.append(
                dumps(polygon.simplify(tolerance=tolerance), rounding_precision=0)
            )

    tile_name = "_".join(
        os.path.splitext(os.path.basename(sourcefile))[0].split("_")[-4:]
    )
    return pd.DataFrame(
        {
            "ImageId": tile_name,
            "BuildingId": range(len(polygons)),
            "PolygonWKT_Pix": polygons,
            "Confidence": 1,
        }
    )
def evaluation(prediction_csv, gt_csv):
    """Score building proposals against ground truth with the SpaceNet F1 metric.

    Proposals are matched to ground-truth polygons by solaris' ``Evaluator`` at
    IoU >= 0.5, and the true/false positive and false negative counts are summed
    over all entries of the report.

    :param prediction_csv: proposal CSV (as written by ``post_process``)
    :param gt_csv: ground-truth CSV
    :return: F1 score ``2tp / (2tp + fp + fn)``; ``0.0`` when there are neither
        ground-truth buildings nor proposals (the ratio is undefined there)
    """
    evaluator = Evaluator(gt_csv)
    evaluator.load_proposal(prediction_csv, proposalCSV=True, conf_field_list=[])
    report = evaluator.eval_iou_spacenet_csv(miniou=0.5)

    tp = sum(entry["TruePos"] for entry in report)
    fp = sum(entry["FalsePos"] for entry in report)
    fn = sum(entry["FalseNeg"] for entry in report)

    denominator = 2 * tp + fp + fn
    if denominator == 0:
        return 0.0
    return (2 * tp) / denominator
class FocalLoss2d(nn.Module):
    """Binary focal loss on logits with per-pixel weights and an ignore label.

    :param gamma: focusing exponent applied to ``(1 - p_t)``
    :param ignore_index: target value whose pixels are excluded from the loss
    :param eps: bound used to clamp probabilities and targets away from 0 and 1
    """

    def __init__(self, gamma=3, ignore_index=255, eps=1e-6):
        super().__init__()
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.eps = eps

    def forward(self, outputs, targets, weights=1.0):
        """Compute the weighted mean focal loss over non-ignored pixels.

        :param outputs: logits; a sigmoid is applied here
        :param targets: labels with the same number of elements as ``outputs``;
            entries equal to ``ignore_index`` are skipped
        :param weights: per-pixel weights with the same number of elements as
            ``targets``, or a scalar applied to every pixel
        :return: scalar loss; zero (still attached to the graph) when every pixel
            is ignored, instead of the NaN an empty mean would produce
        """
        probs = torch.sigmoid(outputs)
        if not torch.is_tensor(weights):
            # The scalar default has no tensor methods; broadcast it explicitly.
            weights = torch.full_like(probs, float(weights))

        keep = targets.reshape(-1) != self.ignore_index
        targets = targets.reshape(-1)[keep].float()
        probs = probs.reshape(-1)[keep]
        weights = weights.reshape(-1)[keep]
        if probs.numel() == 0:
            return outputs.sum() * 0.0

        probs = torch.clamp(probs, self.eps, 1.0 - self.eps)
        targets = torch.clamp(targets, self.eps, 1.0 - self.eps)
        # p_t: probability assigned to the (soft) target class.
        pt = (1 - targets) * (1 - probs) + targets * probs
        return ((-((1.0 - pt) ** self.gamma) * torch.log(pt)) * weights).mean()
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    :param weight: optional tensor registered as the ``weight`` buffer
        (``forward`` does not use it)
    :param per_image: when True, one Dice score is computed per batch item and the
        losses are averaged; otherwise the whole batch is scored as one sample
    :param eps: smoothing term added to both numerator and denominator
    """

    def __init__(self, weight=None, per_image=False, eps=1e-6):
        super().__init__()
        self.register_buffer("weight", weight)
        self.per_image = per_image
        self.eps = eps

    def forward(self, outputs, targets):
        """Return ``mean(1 - (2|P∩T| + eps) / (|P| + |T| + eps))`` over the rows."""
        probs = torch.sigmoid(outputs)
        rows = probs.size()[0] if self.per_image else 1

        flat_probs = probs.reshape(rows, -1)
        flat_targets = targets.reshape(rows, -1).float()

        overlap = (flat_probs * flat_targets).sum(dim=1)
        denominator = flat_probs.sum(dim=1) + flat_targets.sum(dim=1) + self.eps
        dice = (2 * overlap + self.eps) / denominator
        return (1 - dice).mean()
class GenEfficientNet(nn.Module):
    """EfficientNet assembled from geffnet building blocks.

    The module contains:

    * a stem (``conv_stem``, ``bn1``, ``act1``);
    * the blocks produced by ``EfficientNetBuilder``;
    * a head (``conv_head``, ``bn2``, ``act2``);
    * a global average pool and a linear classifier.

    No ``forward`` is defined. In this module the network only serves as a container
    that pretrained weights are loaded into before its parts are reused.

    :param block_args: decoded architecture definition (``decode_arch_def`` output)
    :param fix_stem: when True, ``stem_size`` is used as-is instead of being scaled by
        ``channel_multiplier``
    :param norm_kwargs: extra keyword arguments for ``norm_layer``; ``None`` means none
    :param weight_init: ``"goog"`` for ``initialize_weight_goog``, anything else for
        ``initialize_weight_default``
    """

    def __init__(
        self,
        block_args,
        num_classes=1000,
        in_channels=3,
        num_features=1280,
        stem_size=32,
        fix_stem=False,
        channel_multiplier=1.0,
        channel_divisor=8,
        channel_min=None,
        pad_type="",
        act_layer=nn.ReLU,
        drop_connect_rate=0.0,
        se_kwargs=None,
        norm_layer=nn.BatchNorm2d,
        norm_kwargs=None,
        weight_init="goog",
    ):
        super().__init__()
        # ``**None`` raises TypeError, so normalise the documented default.
        norm_kwargs = norm_kwargs or {}

        if not fix_stem:
            stem_size = round_channels(
                stem_size, channel_multiplier, channel_divisor, channel_min
            )
        self.conv_stem = select_conv2d(
            in_channels, stem_size, 3, stride=2, padding=pad_type
        )
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        builder = EfficientNetBuilder(
            channel_multiplier,
            channel_divisor,
            channel_min,
            pad_type,
            act_layer,
            se_kwargs,
            norm_layer,
            norm_kwargs,
            drop_connect_rate,
        )
        self.blocks = nn.Sequential(*builder(stem_size, block_args))

        # builder.in_chs holds the channel count of the last built block.
        self.conv_head = select_conv2d(builder.in_chs, num_features, 1, padding=pad_type)
        self.bn2 = norm_layer(num_features, **norm_kwargs)
        self.act2 = act_layer(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(num_features, num_classes)

        for module in self.modules():
            if weight_init == "goog":
                initialize_weight_goog(module)
            else:
                initialize_weight_default(module)
class UNetEfficientNet(BaseSegmentationClassifier):
    """Unet EfficientNet model implementation.

    .. note:: Based on <https://github.com/SpaceNetChallenge/SpaceNet_SAR_Buildings_Solutions/blob/master/1-zbigniewwojna/main.py#L178>
    """

    schema = UNetEfficientNetModelSchema

    # Supported ``config.net`` values -> (channel_multiplier, depth_multiplier,
    # pretrained weights URL, output widths of encoder stages enc0..enc4).
    _BACKBONES = {
        "b4": (
            1.4,
            1.8,
            "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth",
            (24, 32, 56, 160, 1792),
        ),
        "b5": (
            1.6,
            2.2,
            "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth",
            (24, 40, 64, 176, 2048),
        ),
        "b6": (
            1.8,
            2.6,
            "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth",
            (32, 40, 72, 200, 2304),
        ),
        "b7": (
            2.0,
            3.1,
            "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth",
            (32, 48, 80, 224, 2560),
        ),
        "l2": (
            4.3,
            5.3,
            "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth",
            (72, 104, 176, 480, 5504),
        ),
    }

    # Channel count of torch.cat([strip, direction, coord], 1), the input of the
    # ``bot*extra`` layers.
    _EXTRA_CHANNELS = 206

    # Side length of the tile that cropped predictions are padded back to.
    _TILE_SIZE = 900

    def __init__(self, config):
        """
        :param config : the configuration for this model
        :type config : UNetEfficientNetModelSchema
        :raises ValueError: if ``config.net`` is not one of b4, b5, b6, b7, l2
        """
        super().__init__(config)

        try:
            channel_multiplier, depth_multiplier, url, enc_ch = self._BACKBONES[
                self.config.net
            ]
        except KeyError:
            raise ValueError(
                "Unsupported EfficientNet backbone {!r}; expected one of {}".format(
                    self.config.net, sorted(self._BACKBONES)
                )
            ) from None

        dec_ch = [32, 64, 128, 256, 1024]
        # With stride 16 (8) forward() skips the deepest one (two) decoder stages, so
        # the first decoder stage that is used must accept the encoder's width.
        if self.config.stride == 16:
            dec_ch[4] = enc_ch[4]
        elif self.config.stride == 8:
            dec_ch[3] = enc_ch[4]

        def mod(cin, cout, k=3):
            """Same-padding convolution followed by an in-place ReLU."""
            return nn.Sequential(
                nn.Conv2d(cin, cout, k, padding=k // 2), nn.ReLU(inplace=True)
            )

        extra = self._EXTRA_CHANNELS
        self.model.bot0extra = mod(extra, enc_ch[4])
        self.model.bot1extra = mod(extra, dec_ch[4])
        self.model.bot2extra = mod(extra, dec_ch[3])
        self.model.bot3extra = mod(extra, dec_ch[2])
        self.model.bot4extra = mod(extra, dec_ch[1])
        self.model.bot5extra = mod(extra, 6)
        self.model.dec0 = mod(enc_ch[4], dec_ch[4])
        self.model.dec1 = mod(dec_ch[4], dec_ch[3])
        self.model.dec2 = mod(dec_ch[3], dec_ch[2])
        self.model.dec3 = mod(dec_ch[2], dec_ch[1])
        self.model.dec4 = mod(dec_ch[1], dec_ch[0])
        self.model.bot0 = mod(enc_ch[3] + dec_ch[4], dec_ch[4])
        self.model.bot1 = mod(enc_ch[2] + dec_ch[3], dec_ch[3])
        self.model.bot2 = mod(enc_ch[1] + dec_ch[2], dec_ch[2])
        self.model.bot3 = mod(enc_ch[0] + dec_ch[1], dec_ch[1])
        self.model.up = nn.Upsample(scale_factor=2)
        self.model.upps = nn.PixelShuffle(upscale_factor=2)
        self.model.final = nn.Conv2d(dec_ch[0], 6, 1)
        # Called before the pretrained encoder is attached below, presumably so only
        # the decoder gets re-initialised. NOTE(review): relies on self.modules()
        # reaching the layers stored on self.model — confirm in the base class.
        self._initialize_weights()

        arch_def = [
            ["ds_r1_k3_s1_e1_c16_se0.25"],
            ["ir_r2_k3_s2_e6_c24_se0.25"],
            ["ir_r2_k5_s2_e6_c40_se0.25"],
            ["ir_r3_k3_s2_e6_c80_se0.25"],
            ["ir_r3_k5_s1_e6_c112_se0.25"],
            ["ir_r4_k5_s2_e6_c192_se0.25"],
            ["ir_r1_k3_s1_e6_c320_se0.25"],
        ]
        enc = GenEfficientNet(
            in_channels=3,
            block_args=decode_arch_def(arch_def, depth_multiplier),
            num_features=round_channels(1280, channel_multiplier, 8, None),
            stem_size=32,
            channel_multiplier=channel_multiplier,
            norm_kwargs=resolve_bn_args({"bn_eps": BN_EPS_TF_DEFAULT}),
            pad_type="same",
        )
        state_dict = load_state_dict_from_url(url)
        enc.load_state_dict(state_dict, strict=True)

        # Replace the 3-channel pretrained stem by a 4-channel one: the pretrained
        # kernel is kept and its input channel 1 is repeated as the 4th channel.
        stem_size = round_channels(32, channel_multiplier, 8, None)
        conv_stem = select_conv2d(4, stem_size, 3, stride=2, padding="same")
        _w = enc.conv_stem.state_dict()
        _w["weight"] = torch.cat([_w["weight"], _w["weight"][:, 1:2]], 1)
        conv_stem.load_state_dict(_w)

        self.model.enc0 = nn.Sequential(conv_stem, enc.bn1, enc.act1, enc.blocks[0])
        self.model.enc1 = nn.Sequential(enc.blocks[1])
        self.model.enc2 = nn.Sequential(enc.blocks[2])
        self.model.enc3 = nn.Sequential(enc.blocks[3], enc.blocks[4])
        self.model.enc4 = nn.Sequential(
            enc.blocks[5], enc.blocks[6], enc.conv_head, enc.bn2, enc.act2
        )

    def forward(self, x, strip, direction, coord):
        """Run the encoder/decoder and return 6-channel logits.

        :param x: input images (the stem convolution expects 4 channels)
        :param strip: extra input; concatenated with ``direction`` and ``coord``
            along dim 1 and, through ``bot0extra``/``bot1extra``, added to the
            deepest decoder inputs when ``config.stride`` is 32 or 16
        :return: output of the final 1x1 convolution (6 channels)

        NOTE: ``bot2extra``..``bot5extra`` and ``upps`` are built in ``__init__``
        but not used here.
        """
        net = self.model
        enc0 = net.enc0(x)
        enc1 = net.enc1(enc0)
        enc2 = net.enc2(enc1)
        enc3 = net.enc3(enc2)
        enc4 = net.enc4(enc3)
        ex = torch.cat([strip, direction, coord], 1)
        stride = self.config.stride

        x = enc4
        if stride == 32:
            x = net.dec0(net.up(x + net.bot0extra(ex)))
            x = net.bot0(torch.cat([x, enc3], dim=1))
        if stride == 32 or stride == 16:
            x = net.dec1(net.up(x + net.bot1extra(ex)))
            x = net.bot1(torch.cat([x, enc2], dim=1))
        x = net.bot2(torch.cat([net.dec2(net.up(x)), enc1], dim=1))
        x = net.bot3(torch.cat([net.dec3(net.up(x)), enc0], dim=1))
        x = net.dec4(net.up(x))
        return net.final(x)

    def _initialize_weights(self):
        """Kaiming-normal weights and zero bias for every Conv2d/ConvTranspose2d, and
        unit weight / zero bias for every BatchNorm2d found by ``self.modules()``."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                module.weight.data = nn.init.kaiming_normal_(module.weight.data)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def load_optimizer(self):
        """Return a new AdamW optimizer (lr 2e-4, weight decay 1e-2) over the model."""
        return torch.optim.AdamW(self.model.parameters(), lr=2e-4, weight_decay=1e-2)

    def load_lr_scheduler(self, optimizer=None):
        """Return a MultiStepLR halving the learning rate at epochs 80, 100 and 120.

        :param optimizer: optimizer to schedule. It must be the optimizer that is
            actually stepped during training, otherwise the schedule has no effect.
            When omitted, a new one is created with ``load_optimizer``.
        """
        if optimizer is None:
            optimizer = self.load_optimizer()
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[80, 100, 120], gamma=0.5
        )

    def train_and_evaluate_model(
        self,
        train_dataset: SpaceNet6Dataset,
        epochs: int = 100,
        model_directory: str = None,
        save_epochs: int = 10,
        iterations_log: int = 100,
        resume_model: str = None,
        val_dataset: SpaceNet6Dataset = None,
        run_id: str = None,
        **kwargs,
    ):
        """Overridden method for training on the SpaceNet6 data set.

        Trains on all folds except fold 0. After every epoch the weights are saved
        to ``<model_directory>/last_model``.

        When ``val_dataset`` is given, validation runs from epoch
        ``val_dataset.config.start_val_epoch`` onwards. Fold 0 is predicted with
        test-time augmentation, post-processed into polygons and scored with
        ``evaluation``. The weights with the best F1 so far are saved to
        ``<model_directory>/best_model``.

        ``save_epochs``, ``iterations_log``, ``resume_model`` and ``run_id`` are
        accepted for interface compatibility but not used.

        :raises ValueError: if ``model_directory`` is not given
        """
        if model_directory is None:
            raise ValueError("model_directory is required to store checkpoints")
        os.makedirs(model_directory, exist_ok=True)

        contact_weight = train_dataset.config.contact_weight
        edge_weight = train_dataset.config.edge_weight
        fold = 0
        pred_folder = train_dataset.config.pred_folder.format(fold)

        dice_loss = DiceLoss().to(self.device)
        focal_loss = FocalLoss2d().to(self.device)

        def channel_loss(output, target, weights, channel):
            """Focal + Dice loss of one output channel."""
            return focal_loss(
                output[:, channel], target[:, channel], weights[:, channel]
            ) + dice_loss(output[:, channel], target[:, channel])

        train_dataset.load_other_folds(fold)
        train_data_loader = train_dataset.dataloader()

        val_data_loader = None
        if val_dataset is not None:
            val_dataset.load_fold(fold)
            val_data_loader = val_dataset.dataloader()

        optimizer = self.load_optimizer()
        # The scheduler must wrap the optimizer stepped below so that the
        # milestones actually change the training learning rate.
        scheduler = self.load_lr_scheduler(optimizer)

        best_f1_score = -1
        self.model.to(self.device)
        for epoch in range(epochs):
            iterator = tqdm(train_data_loader)
            self.model.train()
            for sample in iterator:
                images = sample["image"].to(self.device)
                strip = sample["strip"].to(self.device)
                direction = sample["direction"].to(self.device)
                coord = sample["coordinate"].to(self.device)
                target = sample["mask"].to(self.device)

                # Pixel weights: halved where the target is positive, then scaled
                # per image by 0.5 + 0.5 * (b_count / 8).
                building_count = sample["b_count"].to(self.device) / 8
                building_weight = building_count * 0.5 + 0.5
                weights = torch.ones(size=target.shape).to(self.device)
                weights[target > 0.0] *= 0.5
                for i in range(weights.shape[0]):
                    weights[i] = weights[i] * building_weight[i]

                output = self.forward(images, strip, direction, coord)
                if isinstance(output, tuple):
                    output = output[0]

                loss = (
                    channel_loss(output, target, weights, 0)
                    + edge_weight * channel_loss(output, target, weights, 1)
                    + contact_weight * channel_loss(output, target, weights, 2)
                )

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.2)
                optimizer.step()
                iterator.set_description(
                    "epoch: {}; lr {:.5f}; loss {:.4f}".format(
                        epoch, scheduler.get_last_lr()[-1], loss.item()
                    )
                )
            scheduler.step()
            torch.save(
                {"epoch": epoch, "state_dict": self.model.state_dict()},
                os.path.join(model_directory, "last_model"),
            )

            if val_dataset is None or epoch < val_dataset.config.start_val_epoch:
                continue

            self._predict_to_folder(val_data_loader, pred_folder)
            to_save = {
                k: copy.deepcopy(v.cpu()) for k, v in self.model.state_dict().items()
            }
            pred_csv = val_dataset.config.pred_csv.format(fold)
            gt_csv = val_dataset.config.gt_csv.format(fold)
            post_process(pred_folder, pred_csv)
            val_f1 = evaluation(pred_csv, gt_csv)
            if best_f1_score < val_f1:
                torch.save(
                    {"epoch": epoch, "state_dict": to_save, "best_score": val_f1},
                    os.path.join(model_directory, "best_model"),
                )
            best_f1_score = max(best_f1_score, val_f1)

    def _predict_to_folder(self, data_loader, pred_folder):
        """Run test-time-augmented inference and save one prediction image per sample.

        ``pred_folder`` is deleted and re-created first. For every sample:

        1. The model output is averaged over 3 scales x 2 horizontal flips x 1
           rotation, with each output mapped back to the input size.
        2. The first three channels are kept.
        3. The image is rotated by 180 degrees (flipud + fliplr) when the sample's
           ``direction`` flag is set.
        4. It is zero-padded to a ``_TILE_SIZE`` square at offset (ymin, xmin).
        5. It is written under the input image's file name.
        """
        shutil.rmtree(pred_folder, ignore_errors=True)
        os.makedirs(pred_folder, exist_ok=True)

        scales = [0.8, 1.0, 1.5]
        flips = [lambda x: x, lambda x: torch.flip(x, (3,))]
        # ``k=k`` binds each rotation amount to its own lambda (no late binding).
        rots = [(lambda x, k=k: torch.rot90(x, k, (2, 3))) for k in range(0, 1)]
        rots2 = [(lambda x, k=k: torch.rot90(x, 4 - k, (2, 3))) for k in range(0, 1)]
        n_views = len(scales) * len(flips) * len(rots)
        tile = self._TILE_SIZE

        self.model.eval()
        with torch.no_grad():
            for sample in tqdm(data_loader):
                images = sample["image"].to(self.device)
                ymins = sample["ymin"].reshape(-1).tolist()
                xmins = sample["xmin"].reshape(-1).tolist()
                strip = sample["strip"].to(self.device)
                direction = sample["direction"].to(self.device)
                coord = sample["coordinate"].to(self.device)
                _, _, h, w = images.shape

                summed = torch.zeros((images.shape[0], 6, h, w)).to(self.device)
                for sc in scales:
                    # Rescale to a multiple of 32 so all encoder strides divide it.
                    im = F.interpolate(
                        images,
                        size=(ceil(h * sc / 32) * 32, ceil(w * sc / 32) * 32),
                        mode="bilinear",
                        align_corners=True,
                    )
                    for fl in flips:
                        for rot, unrot in zip(rots, rots2):
                            o = self.forward(rot(fl(im)), strip, direction, coord)
                            if isinstance(o, tuple):
                                o = o[0]
                            summed += F.interpolate(
                                fl(unrot(o)),
                                size=(h, w),
                                mode="bilinear",
                                align_corners=True,
                            )
                averaged = np.moveaxis((summed / n_views).cpu().numpy(), 1, 3)

                for i, prediction in enumerate(averaged):
                    img = prediction[:, :, :3]
                    if direction[i].item():
                        img = np.fliplr(np.flipud(img))
                    ymin, xmin = int(ymins[i]), int(xmins[i])
                    img = cv2.copyMakeBorder(
                        img,
                        ymin,
                        tile - h - ymin,
                        xmin,
                        tile - w - xmin,
                        cv2.BORDER_CONSTANT,
                        0.0,
                    )
                    io.imsave(
                        os.path.join(
                            pred_folder, os.path.split(sample["image_path"][i])[1]
                        ),
                        img,
                    )

    def evaluate(
        self, dataset: SpaceNet6Dataset = None, model_path: str = None, fold: int = 3
    ):
        """Predict one fold of ``dataset`` with the weights stored at ``model_path``.

        Predictions are written with test-time augmentation to
        ``dataset.config.pred_folder.format(fold)``, which is wiped first.

        :param fold: fold to predict (other folds mentioned by the original code:
            0, 6, 9, 1, 2, 7, 8)
        """
        self.load_model(model_path)
        self.model.to(self.device)

        dataset.load_fold(fold)
        data_loader = dataset.dataloader()
        pred_folder = dataset.config.pred_folder.format(fold)
        self._predict_to_folder(data_loader, pred_folder)

        # Disabled: averaging the predictions of several folds and post-processing
        # the merged masks into the solution file.
        # shutil.rmtree(dataset.config.merged_pred_dir, ignore_errors=True)
        # os.makedirs(dataset.config.merged_pred_dir, exist_ok=True)
        # merge_folds = [0, 1, 2, 3, 6, 7, 8, 9]
        # predictions_folders = [dataset.config.pred_folder.format(i) for i in merge_folds]
        # for filename in tqdm(os.listdir(predictions_folders[0])):
        #     used_masks = list()
        #     for ff in predictions_folders:
        #         if os.path.exists(os.path.join(ff, filename)):
        #             used_masks.append(io.imread(os.path.join(ff, filename)))
        #     mask = np.zeros_like(used_masks[0], dtype="float")
        #     for used_mask in used_masks:
        #         mask += used_mask.astype("float") / len(used_masks)
        #     io.imsave(os.path.join(dataset.config.merged_pred_dir, filename), mask)
        # post_process(dataset.config.merged_pred_dir, dataset.config.solution_file)

    def load_model(self, file_path, optimizer=None):
        """Load the ``state_dict`` entry of a checkpoint into ``self.model``.

        Loading is lenient:

        * missing and unexpected keys are collected but not reported;
        * errors reported through the ``error_msgs`` list are discarded.

        ``optimizer`` is accepted for interface compatibility and not used.
        """
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts;
        # the tensors are then copied into the parameters wherever those live.
        loaded = torch.load(file_path, map_location="cpu")
        missing_keys = []
        unexpected_keys = []
        metadata = getattr(loaded["state_dict"], "_metadata", None)
        state_dict = loaded["state_dict"].copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                [],
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(self.model)