CLRS Dataset: Example of using the aitlas toolbox for multi-class image classification#

This notebook shows a sample implementation of multi-class image classification using the aitlas toolbox and the CLRS dataset.

[1]:
from aitlas.datasets import CLRSDataset
from aitlas.models import VisionTransformer
from aitlas.transforms import ResizeCenterCropFlipHVToTensor, ResizeCenterCropToTensor, ResizeRandomCropFlipHVToTensor
from aitlas.utils import image_loader

Load the dataset#

[2]:
# Settings used to open the CLRS data for browsing: the image root folder,
# the CSV file handed to the dataset, and the batching options.
# NOTE(review): both paths are absolute and machine-specific; adjust locally.
dataset_config = dict(
    data_dir="/media/hdd/multi-class/CLRS",
    csv_file="/media/hdd/multi-class/CLRS/train.csv",
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# The dataset object that the exploration cells below operate on.
dataset = CLRSDataset(dataset_config)

Show images from the dataset#

[3]:
# Two individual images first (arguments 1000 and 80 — presumably sample
# indices; confirm against CLRSDataset.show_image), evaluated in that order.
fig1, fig2 = [dataset.show_image(sample_id) for sample_id in (1000, 80)]

# Then a batch view; 15 is presumably the number of images shown — confirm.
fig3 = dataset.show_batch(15)
../_images/examples_multiclass_classification_example_clrs_5_0.png
../_images/examples_multiclass_classification_example_clrs_5_1.png
../_images/examples_multiclass_classification_example_clrs_5_2.png

Inspect the data#

[4]:
# Preview a table of file names and their labels. The call is left as the
# cell's last bare expression so the notebook renders its return value.
dataset.show_samples()
[4]:
File name Label
0 residential/residential_543_Level2_0.84m.tif residential
1 beach/beach_172_Level1_0.48m.tif beach
2 residential/residential_530_Level3_1.86m.tif residential
3 highway/highway_62_Level3_0.93m.tif highway
4 parking/parking_309_Level1_0.46m.tif parking
5 port/port_275_Level3_2.44m.tif port
6 port/port_357_Level3_1.15m.tif port
7 railway/railway_432_Level2_0.73m.tif railway
8 meadow/meadow_596_Level1_0.39m.tif meadow
9 storage-tank/storage-tank_469_Level1_0.60m.tif storage-tank
10 stadium/stadium_145_Level3_1.95m.tif stadium
11 stadium/stadium_32_Level3_3.66m.tif stadium
12 forest/forest_551_Level2_0.73m.tif forest
13 storage-tank/storage-tank_258_Level3_1.62m.tif storage-tank
14 airport/airport_444_Level2_0.71m.tif airport
15 stadium/stadium_540_Level2_1.41m.tif stadium
16 residential/residential_499_Level3_1.86m.tif residential
17 railway-station/railway-station_432_Level2_0.7... railway-station
18 playground/playground_170_Level2_0.93m.tif playground
19 bridge/bridge_84_Level1_0.73m.tif bridge
[5]:
# Per-label sample counts for the loaded CSV. Kept as the last bare
# expression so the notebook renders the returned table.
dataset.data_distribution_table()
[5]:
Label Count
0 airport 360
1 bare-land 360
2 beach 360
3 bridge 360
4 commercial 360
5 desert 360
6 farmland 360
7 forest 360
8 golf-course 360
9 highway 360
10 industrial 360
11 meadow 360
12 mountain 360
13 overpass 360
14 park 360
15 parking 360
16 playground 360
17 port 360
18 railway 360
19 railway-station 360
20 residential 360
21 river 360
22 runway 360
23 stadium 360
24 storage-tank 360
[6]:
# Same label distribution as a bar chart; the returned object is kept in
# `fig` (presumably a matplotlib figure — confirm) so it can be reused.
fig = dataset.data_distribution_barchart()
../_images/examples_multiclass_classification_example_clrs_9_0.png

Load train and validation splits#

[7]:
# --- Training split -------------------------------------------------------
# Shuffled batches of 32 read from train.csv.
# NOTE(review): paths are absolute and machine-specific; adjust locally.
train_dataset_config = dict(
    batch_size=32,
    shuffle=True,
    num_workers=4,
    data_dir="/media/hdd/multi-class/CLRS",
    csv_file="/media/hdd/multi-class/CLRS/train.csv",
)
train_dataset = CLRSDataset(train_dataset_config)

# The training transform is attached to the instance after construction,
# whereas the validation split below names its transform in the config.
# NOTE(review): confirm both routes are equivalent before unifying them.
train_dataset.transform = ResizeRandomCropFlipHVToTensor()

# --- Validation split -----------------------------------------------------
# Unshuffled batches of 32 read from val.csv, with the transform given by
# its dotted path in the config.
validation_dataset_config = dict(
    batch_size=32,
    shuffle=False,
    num_workers=4,
    data_dir="/media/hdd/multi-class/CLRS",
    csv_file="/media/hdd/multi-class/CLRS/val.csv",
    transforms=["aitlas.transforms.ResizeCenterCropToTensor"],
)
validation_dataset = CLRSDataset(validation_dataset_config)

# Display the size of each split as the cell output.
(len(train_dataset), len(validation_dataset))
[7]:
(9000, 3000)

Setup and create the model for training#

[8]:
# Training length and the directory handed to the training call later on.
epochs = 100
model_directory = "./experiments/CLRS"

# Model settings: 25 output classes (the distribution table above lists 25
# labels), the learning rate, the pretrained flag and the metrics to report.
model_config = dict(
    num_classes=25,
    learning_rate=0.0001,
    pretrained=True,
    metrics=["accuracy", "precision", "recall", "f1_score"],
)

# Build the model from the config, then run its prepare() step before use.
model = VisionTransformer(model_config)
model.prepare()

Training and validation#

[ ]:
# Train on the training split and evaluate on the validation split, using
# the epoch count and model directory defined in the setup cell.
# run_id='1' tags this run (NOTE(review): exact use of run_id is defined by
# aitlas — confirm how it affects output locations).
model.train_and_evaluate_model(
    train_dataset=train_dataset, val_dataset=validation_dataset,
    epochs=epochs,
    model_directory=model_directory, run_id='1',
)

Test the model#

[ ]:
# Test split: unshuffled batches of 16 read from test.csv, using the same
# transform as the validation split.
# NOTE(review): paths are absolute and machine-specific; adjust locally.
test_dataset_config = dict(
    batch_size=16,
    shuffle=False,
    num_workers=4,
    data_dir="/media/hdd/multi-class/CLRS",
    csv_file="/media/hdd/multi-class/CLRS/test.csv",
    transforms=["aitlas.transforms.ResizeCenterCropToTensor"],
)
test_dataset = CLRSDataset(test_dataset_config)

# Checkpoint to evaluate, given as a path relative to the working directory.
model_path = "best_checkpoint_1654596678_20.pth.tar"

# Set the metrics to report and clear the running metrics before evaluating,
# so the scores below come from this evaluation only.
model.metrics = ["accuracy", "precision", "recall", "f1_score"]
model.running_metrics.reset()
model.evaluate(dataset=test_dataset, model_path=model_path)

# Display the scores for the selected metrics as the cell output.
model.running_metrics.get_scores(model.metrics)

Predictions#

[24]:
# Load the checkpoint weights into the model before predicting.
model_path = "best_checkpoint_1654596678_20.pth.tar"
model.load_model(model_path)

# Class names passed to predict_image, written out in full.
# NOTE(review): an earlier draft read these from `CLRSDataset.labels`;
# confirm that attribute exists and matches before switching to it.
labels = [
    "airport",
    "bare-land",
    "beach",
    "bridge",
    "commercial",
    "desert",
    "farmland",
    "forest",
    "golf-course",
    "highway",
    "industrial",
    "meadow",
    "mountain",
    "overpass",
    "park",
    "parking",
    "playground",
    "port",
    "railway",
    "railway-station",
    "residential",
    "river",
    "runway",
    "stadium",
    "storage-tank",
]

# Same transform class that the validation and test splits name in their configs.
transform = ResizeCenterCropToTensor()

# Read one stadium image from disk and run a single-image prediction on it.
image = image_loader('/media/hdd/multi-class/CLRS/stadium/stadium_353_Level1_0.95m.tif')
fig = model.predict_image(image, labels, transform)
2022-10-19 11:55:19,750 INFO Loading checkpoint best_checkpoint_1654596678_20.pth.tar
2022-10-19 11:55:20,083 INFO Loaded checkpoint best_checkpoint_1654596678_20.pth.tar at epoch 21
../_images/examples_multiclass_classification_example_clrs_19_1.png