Resisc45 dataset: Example of the aitlas
toolbox for multi-class image classification#
This notebook shows a sample implementation of multi-class image classification using the aitlas
toolbox and the Resisc45 dataset.
[ ]:
import torch

from aitlas.datasets import Resisc45Dataset
from aitlas.models import ResNet50
from aitlas.transforms import ResizeCenterCropFlipHVToTensor, ResizeCenterCropToTensor
from aitlas.utils import image_loader
Load the dataset#
[ ]:
# Dataset built from the training CSV without any dataloader settings;
# the inspection cells below (show_image, show_samples, ...) read from it.
dataset_config = dict(
    data_dir="./data/RESISC45",
    csv_file="./data/RESISC45/train.csv",
)
dataset = Resisc45Dataset(dataset_config)
Show images from the dataset#
[ ]:
# Render two individual samples (the integer arguments are presumably
# dataset indices) and a batch view of 15 samples.
fig1 = dataset.show_image(1000)
fig2 = dataset.show_image(80)
# NOTE(review): show_batch(15) presumably draws a grid of 15 samples — confirm in aitlas.
fig3 = dataset.show_batch(15)
Inspect the data#
[ ]:
# Display a preview of the dataset's entries (content/format defined by aitlas).
dataset.show_samples()
[ ]:
# Tabulate how the samples are distributed across labels (presumably per-class counts).
dataset.data_distribution_table()
[ ]:
# Bar-chart view of the label distribution (presumably the same data as the table above).
fig = dataset.data_distribution_barchart()
Load train and test splits#
[ ]:
# Train split: shuffled, larger batches, and the flip-augmenting transform.
# The transform is declared through the config's "transforms" key, the same
# way the test split declares its transform below, instead of overwriting
# `train_dataset.transform` after construction.
train_dataset_config = {
    "batch_size": 16,
    "shuffle": True,
    "num_workers": 4,
    "data_dir": "./data/RESISC45",
    "csv_file": "./data/RESISC45/train.csv",
    "transforms": ["aitlas.transforms.ResizeCenterCropFlipHVToTensor"]
}
train_dataset = Resisc45Dataset(train_dataset_config)

# Test split: same image directory, fixed order, and the deterministic
# (no-flip) transform.
test_dataset_config = {
    "batch_size": 4,
    "shuffle": False,
    "num_workers": 4,
    "data_dir": "./data/RESISC45",
    "csv_file": "./data/RESISC45/test.csv",
    "transforms": ["aitlas.transforms.ResizeCenterCropToTensor"]
}
test_dataset = Resisc45Dataset(test_dataset_config)

# Show split sizes as the cell's output.
len(train_dataset), len(test_dataset)
Set up and create the model for training#
[ ]:
# Training length and the directory passed to the training routine below.
epochs = 10
model_directory = "./experiments/RESISC45"

# Seed torch's global RNG before building the model and training, so the
# parts of the run driven by it (presumably loader shuffling, random flips
# and any freshly initialised layers) are repeatable. GPU kernels may still
# introduce some non-determinism.
SEED = 42
torch.manual_seed(SEED)

model_config = {
    "num_classes": 45,  # one output per RESISC45 class (45 labels listed in the prediction cell)
    "learning_rate": 0.0001,
    "pretrained": True,  # presumably loads pretrained ResNet50 weights — confirm in aitlas
    "metrics": ["accuracy", "precision", "recall", "f1_score"]
}
model = ResNet50(model_config)
model.prepare()
Training and evaluation#
[ ]:
# Train on the train split and validate on the held-out test split for
# `epochs` epochs. Run artefacts are presumably written under
# `model_directory`, keyed by run_id — confirm the layout in aitlas.
model.train_and_evaluate_model(
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    epochs=epochs,
    model_directory=model_directory,
    run_id='1',
)
Predictions#
[ ]:
# Checkpoint to load, derived from the directory used for training above.
# NOTE(review): confirm the exact filename/location that
# train_and_evaluate_model writes for run_id='1'.
model_path = f"{model_directory}/checkpoint.pth.tar"

# Class names in the order the model's outputs are assumed to follow — TODO
# confirm against the dataset's label encoding.
# NOTE(review): if Resisc45Dataset exposes a `labels` attribute, prefer it
# over this hard-coded copy.
labels = ["airplane", "airport", "baseball_diamond", "basketball_court", "beach", "bridge", "chaparral", "church",
          "circular_farmland", "cloud", "commercial_area", "dense_residential", "desert", "forest", "freeway",
          "golf_course", "ground_track_field", "harbor", "industrial_area", "intersection", "island", "lake",
          "meadow", "medium_residential", "mobile_home_park", "mountain", "overpass", "palace", "parking_lot",
          "railway", "railway_station", "rectangular_farmland", "river", "roundabout", "runway", "sea_ice",
          "ship", "snowberg", "sparse_residential", "stadium", "storage_tank", "tennis_court", "terrace",
          "thermal_power_station", "wetland"]
# Guard against a label list that drifts out of sync with the classifier head.
assert len(labels) == model_config["num_classes"], "label list does not match num_classes"

# Inference uses the same deterministic (no-flip) transform as the test split.
transform = ResizeCenterCropToTensor()
model.load_model(model_path)

# Predict each example image in turn; one loop replaces four copy-pasted pairs.
predict_paths = [
    "./data/predict/image1.tif",
    "./data/predict/image2.tif",
    "./data/predict/image3.tif",
    "./data/predict/image4.tif",
]
for image_path in predict_paths:
    image = image_loader(image_path)
    fig = model.predict_image(image, labels, transform)