RSD46-WHU dataset: Example of the aitlas
toolbox for multi-class image classification
This notebook shows a sample implementation of multi-class image classification using the aitlas
toolbox and the RSD46-WHU dataset.
[ ]:
from aitlas.datasets import RSD46WHUDataset
from aitlas.models import ResNet50
from aitlas.transforms import ResizeCenterCropFlipHVToTensor, ResizeCenterCropToTensor
from aitlas.utils import image_loader
Load the dataset
[ ]:
# Dataset instance used by the inspection cells below (image previews,
# sample listing, class distribution). It reads the training CSV.
dataset_config = dict(
    data_dir="./data/RSD46-WHU",
    csv_file="./data/RSD46-WHU/train.csv",
)
dataset = RSD46WHUDataset(dataset_config)
Show images from the dataset
[ ]:
# Render two individual samples, selected by index, then a batch of 15 images.
fig1, fig2 = [dataset.show_image(index) for index in (1000, 80)]
fig3 = dataset.show_batch(15)
Inspect the data
[ ]:
# Preview sample entries of the dataset. The return value is left as the
# last expression so the notebook renders it as the cell output.
dataset.show_samples()
[ ]:
# Tabular view of the label distribution (presumably per-class sample counts),
# displayed as the cell output.
dataset.data_distribution_table()
[ ]:
# Bar chart of the label distribution. The returned figure is kept in `fig`.
fig = dataset.data_distribution_barchart()
Load train and test splits
[ ]:
# Train and test datasets. Both splits declare their preprocessing through
# the config "transforms" key, so each split's pipeline is visible in its
# config. It is no longer patched onto the dataset object after construction.
train_dataset_config = {
    "batch_size": 16,
    "shuffle": True,
    "num_workers": 4,
    "data_dir": "./data/RSD46-WHU",
    "csv_file": "./data/RSD46-WHU/train.csv",
    # Training-time transform; per its name it adds horizontal/vertical
    # flip augmentation on top of resize + center crop.
    "transforms": ["aitlas.transforms.ResizeCenterCropFlipHVToTensor"]
}
train_dataset = RSD46WHUDataset(train_dataset_config)

test_dataset_config = {
    "batch_size": 4,
    "shuffle": False,
    "num_workers": 4,
    "data_dir": "./data/RSD46-WHU",
    "csv_file": "./data/RSD46-WHU/test.csv",
    # Evaluation-time transform: resize + center crop, without the flips.
    "transforms": ["aitlas.transforms.ResizeCenterCropToTensor"]
}
test_dataset = RSD46WHUDataset(test_dataset_config)

# Show the split sizes as the cell output.
len(train_dataset), len(test_dataset)
Set up and create the model for training
[ ]:
# Training length and the directory where training writes its outputs
# (checkpoints are loaded from here in the prediction cell).
epochs = 10
model_directory = "./experiments/RSD46-WHU"

# ResNet-50 classifier with one output per RSD46-WHU class (46 in total).
# `pretrained` presumably loads ImageNet weights — confirm in aitlas docs.
model_config = dict(
    num_classes=46,
    learning_rate=0.0001,
    pretrained=True,
    metrics=["accuracy", "precision", "recall", "f1_score"],
)
model = ResNet50(model_config)
model.prepare()
Training and evaluation
[ ]:
# Train for `epochs` epochs and evaluate on the held-out split after training.
# NOTE(review): the test split doubles as the validation set here, so the
# reported validation scores are not an independent held-out estimate.
training_kwargs = dict(
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    epochs=epochs,
    model_directory=model_directory,
    run_id='1',
)
model.train_and_evaluate_model(**training_kwargs)
Predictions
[ ]:
# Load the trained RSD46-WHU weights and run prediction on example images.
# The checkpoint must come from the directory the training cell wrote to
# (`model_directory`), not from another experiment's folder.
# NOTE(review): assumes training leaves `checkpoint.pth.tar` at the root of
# model_directory — confirm the filename aitlas writes.
model_path = f"{model_directory}/checkpoint.pth.tar"

# Class names passed to `predict_image` for display. Presumably they must
# follow the same index order as the dataset's labels.
# NOTE(review): "Lrregular farmland" and "Graff" look like typos in the class
# names; kept as-is in case they mirror the dataset's own label names.
labels = ["Airplane", "Airport", "Artificial dense forest land", "Artificial sparse forest land", "Bare land",
          "Basketball court", "Blue structured factory building", "Building", "Construction site", "Cross river bridge",
          "Crossroads", "Dense tall building", "Dock", "Fish pond", "Footbridge", "Graff", "Grassland",
          "Low scattered building", "Lrregular farmland", "Medium density scattered building",
          "Medium density structured building", "Natural dense forest land", "Natural sparse forest land", "Oiltank",
          "Overpass", "Parking lot", "Plasticgreenhouse", "Playground", "Railway", "Red structured factory building",
          "Refinery", "Regular farmland", "Scattered blue roof factory building", "Scattered red roof factory building",
          "Sewage plant-type-one", "Sewage plant-type-two", "Ship", "Solar power station", "Sparse residential area",
          "Square", "Steelsmelter", "Storage land", "Tennis court", "Thermal power plant", "Vegetable plot", "Water"]
# Guard against a label list that drifts from the model's output size.
assert len(labels) == model_config["num_classes"], "label list does not match num_classes"

# Same evaluation-time preprocessing as the test split.
transform = ResizeCenterCropToTensor()
model.load_model(model_path)

prediction_images = [
    './data/predict/image1.tif',
    './data/predict/image2.tif',
    './data/predict/image3.tif',
    './data/predict/image4.tif',
]
for image_path in prediction_images:
    image = image_loader(image_path)
    fig = model.predict_image(image, labels, transform)