BigEarthNet (19 labels) dataset: Example of the aitlas toolbox for multi-label image classification#

This notebook shows a sample implementation of multi-label image classification with the aitlas toolbox, using the BigEarthNet multi-label dataset with 19 labels.

[ ]:
from aitlas.datasets import BigEarthNetDataset
from aitlas.models import ResNet50MultiLabel
from aitlas.transforms import ResizeCenterCropFlipHVToTensor, ResizeCenterCropToTensor
from aitlas.utils import image_loader

Load the dataset#

[ ]:
# Configuration for browsing the BigEarthNet data (RGB selection, 19-label version).
# All paths are relative to the directory the notebook is run from.
dataset_config = {
    "lmdb_path": "./data/BigEarthNet/lmdb",
    # Python booleans are capitalised; the lowercase `false` raised a NameError.
    "import_to_lmdb": False,
    "csv_file": "./data/BigEarthNet/splits/train.csv",
    "data_dir": "./data/BigEarthNet/BigEarthNet-v1.0",
    "selection": "rgb",
    "version": "19 labels",
}
dataset = BigEarthNetDataset(dataset_config)

Show images from the dataset#

[ ]:
# Render two individual samples (indices 1000 and 30), in that order.
fig1, fig2 = (dataset.show_image(sample_index) for sample_index in (1000, 30))

# Render a batch view; the argument 15 is passed straight to show_batch.
fig3 = dataset.show_batch(15)

Inspect the data#

[ ]:
# Ask the dataset object to report its statistics; as the last expression of the
# cell, whatever show_stats() returns is displayed as the cell output.
dataset.show_stats()

Load train and test splits#

[ ]:
# Configuration for the training split: shuffled batches of 16, RGB bands,
# 19-label version of BigEarthNet.
train_dataset_config = {
    "batch_size": 16,
    "shuffle": True,
    "num_workers": 4,
    "lmdb_path": "./data/BigEarthNet/lmdb",
    # Python booleans are capitalised; the lowercase `false` raised a NameError.
    "import_to_lmdb": False,
    "csv_file": "./data/BigEarthNet/splits/train.csv",
    "data_dir": "./data/BigEarthNet/BigEarthNet-v1.0",
    "transforms": ["aitlas.transforms.ToTensorRGB", "aitlas.transforms.NormalizeRGB"],
    # Per-band mean / standard deviation, one value per selected band.
    "bands10_mean": [429.9430203, 614.21682446, 590.23569706],
    "bands10_std": [572.41639287, 582.87945694, 675.88746967],
    "selection": "rgb",
    "version": "19 labels",
}

train_dataset = BigEarthNetDataset(train_dataset_config)
# NOTE(review): this assignment sets the dataset's `transform` attribute after
# construction, so it presumably supersedes the "transforms" list given in the
# config above - confirm that this is the intended augmentation pipeline.
train_dataset.transform = ResizeCenterCropFlipHVToTensor()

# Configuration for the held-out evaluation split: small, unshuffled batches so
# that evaluation order is deterministic.
test_dataset_config = {
    "batch_size": 4,
    "shuffle": False,
    "num_workers": 4,
    "lmdb_path": "./data/BigEarthNet/lmdb",
    # Python booleans are capitalised; the lowercase `false` raised a NameError.
    "import_to_lmdb": False,
    # The evaluation split must not reuse the training CSV, otherwise the
    # reported validation metrics are computed on data the model was trained on.
    # Adjust the file name if your split files are named differently.
    "csv_file": "./data/BigEarthNet/splits/test.csv",
    "data_dir": "./data/BigEarthNet/BigEarthNet-v1.0",
    "transforms": ["aitlas.transforms.ToTensorRGB", "aitlas.transforms.NormalizeRGB"],
    # Per-band mean / standard deviation, one value per selected band.
    "bands10_mean": [429.9430203, 614.21682446, 590.23569706],
    "bands10_std": [572.41639287, 582.87945694, 675.88746967],
    "selection": "rgb",
    "version": "19 labels",
}

test_dataset = BigEarthNetDataset(test_dataset_config)
# NOTE(review): ResizeCenterCropToTensor is imported at the top of the notebook
# but never applied; the test dataset may need it so that evaluation inputs are
# resized the same way as the training inputs - confirm.

# Display the number of samples in each split as the cell output.
len(train_dataset), len(test_dataset)

Setup and create the model for training#

[ ]:
# Training hyper-parameters and output location for checkpoints / logs.
epochs = 10
model_directory = "./data/BigEarthNet/experiments"

model_config = {
    # Must equal the number of labels in the dataset: this notebook uses the
    # "19 labels" version of BigEarthNet, so the classifier needs 19 outputs
    # (the previous value of 17 did not match the label vectors).
    "num_classes": 19,
    "learning_rate": 0.0001,
    "pretrained": False,
    # Cut-off passed to the model for turning per-label scores into predictions.
    "threshold": 0.5,
    "metrics": ["accuracy", "precision", "recall", "f1_score"],
}
model = ResNet50MultiLabel(model_config)
model.prepare()

Training and evaluation#

[ ]:
# Train for `epochs` epochs on the training split, using the test split as the
# validation set; outputs for this run go under `model_directory` with run id "1".
model.train_and_evaluate_model(
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    epochs=epochs,
    model_directory=model_directory,
    run_id="1",
)

Predictions#

[ ]:
# Load a saved checkpoint and run the model on a few stand-alone images.
# NOTE(review): this checkpoint path is not inside `model_directory` used for
# training above - confirm it points at the checkpoint you actually want to load.
model_path = "./data/BigEarthNet/checkpoint.pth.tar"
labels = BigEarthNetDataset.labels

model.load_model(model_path)

# One entry per image to classify; extend this list to predict on more images.
image_paths = [
    "./data/predict/image1.tif",
    "./data/predict/image2.tif",
    "./data/predict/image3.tif",
    "./data/predict/image4.tif",
]

# Same load -> predict sequence for every image, in list order.
# After the loop, `image` and `fig` refer to the last image and its figure.
for image_path in image_paths:
    image = image_loader(image_path)
    fig = model.predict_image(image, labels)