Models#

aitlas.models.alexnet module#

AlexNet model for multiclass and multilabel classification

class AlexNet(config)[source]#

Bases: BaseMulticlassClassifier

AlexNet model implementation

name = 'AlexNet'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class AlexNetMultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'AlexNet'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

training: bool#
freeze()[source]#

aitlas.models.cnn_rnn module#

CNNRNN model

class EncoderCNN(embed_size)[source]#

Bases: Module

forward(images)[source]#
training: bool#
class DecoderRNN(embed_size, hidden_size, num_classes, num_layers)[source]#

Bases: Module

forward(features)[source]#
training: bool#
class CNNRNN(config)[source]#

Bases: BaseMultilabelClassifier

CNNRNN model implementation.

schema#

alias of CNNRNNModelSchema

forward(inputs)[source]#
training: bool#

aitlas.models.convnext module#

ConvNeXt tiny model

class ConvNeXtTiny(config)[source]#

Bases: BaseMulticlassClassifier

ConvNeXtTiny model implementation

name = 'ConvNeXt tiny'#
forward(x)[source]#
freeze()[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

training: bool#
class ConvNeXtTinyMultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'ConvNeXt tiny'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

training: bool#
freeze()[source]#

aitlas.models.deeplabv3 module#

DeepLabV3 model

class DeepLabV3(config)[source]#

Bases: BaseSegmentationClassifier

DeepLabV3 model implementation

forward(x)[source]#
training: bool#

aitlas.models.deeplabv3plus module#

DeepLabV3Plus model

class DeepLabV3Plus(config)[source]#

Bases: BaseSegmentationClassifier

DeepLabV3Plus model implementation

forward(x)[source]#
training: bool#

aitlas.models.densenet module#

DenseNet161 model for multiclass classification

class DenseNet161(config)[source]#

Bases: BaseMulticlassClassifier

DenseNet161 model implementation

name = 'DenseNet161'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class DenseNet161MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'DenseNet161'#
training: bool#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#

aitlas.models.efficientnet module#

EfficientNetB0 (V1) for image classification

class EfficientNetB0(config)[source]#

Bases: BaseMulticlassClassifier

EfficientNetB0 model implementation

name = 'EfficientNetB0'#
forward(x)[source]#
freeze()[source]#
extract_features()[source]#
training: bool#
class EfficientNetB0MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'EfficientNetB0'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class EfficientNetB4(config)[source]#

Bases: BaseMulticlassClassifier

name = 'EfficientNetB4'#
forward(x)[source]#
freeze()[source]#
extract_features()[source]#
training: bool#
class EfficientNetB4MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'EfficientNetB4'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class EfficientNetB7(config)[source]#

Bases: BaseMulticlassClassifier

name = 'EfficientNetB7'#
forward(x)[source]#
freeze()[source]#
extract_features()[source]#
training: bool#
class EfficientNetB7MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'EfficientNetB7'#
training: bool#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#

aitlas.models.efficientnet_v2 module#

EfficientNetV2 model

class EfficientNetV2(config)[source]#

Bases: BaseMulticlassClassifier

EfficientNetV2 model implementation

name = 'EfficientNetV2'#
forward(x)[source]#
training: bool#
class EfficientNetV2MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'EfficientNetV2'#
training: bool#
forward(x)[source]#

aitlas.models.fasterrcnn module#

FasterRCNN model for object detection

class FasterRCNN(config)[source]#

Bases: BaseObjectDetection

FasterRCNN model implementation

forward(inputs, targets=None)[source]#
training: bool#

aitlas.models.fcn module#

FCN model for segmentation

class FCN(config)[source]#

Bases: BaseSegmentationClassifier

FCN model implementation

forward(x)[source]#
training: bool#

aitlas.models.hrnet module#

HRNet model for segmentation

class HRNetModule(head, pretrained=True, higher_res=False)[source]#

Bases: Module

HRNet model implementation

Pretrained backbone for HRNet.

Parameters:
  • head (nn.Module) – Output head

  • pretrained (bool) – If True, uses imagenet pretrained weights

  • higher_res (bool) – If True, retains higher resolution features
forward(x)[source]#
training: bool#
class HRNetSegHead(nclasses=3, higher_res=False)[source]#

Bases: Module

Segmentation head for HRNet. Does not have pretrained weights.

Parameters:
  • nclasses (int) – Number of output classes

  • higher_res (bool) – If True, retains higher resolution features

forward(x, yl)[source]#
training: bool#
class HRNet(config, higher_res=False)[source]#

Bases: BaseSegmentationClassifier

forward(x)[source]#
training: bool#

aitlas.models.inceptiontime module#

InceptionTime model

Note

Original implementation of InceptionTime model dl4sits/BreizhCrops

class InceptionTime(config)[source]#

Bases: BaseMulticlassClassifier

InceptionTime model implementation

Note

Based on dl4sits/BreizhCrops

schema#

alias of InceptionTimeSchema

forward(x)[source]#
load_optimizer()[source]#

Load the optimizer

training: bool#
class InceptionModule(kernel_size=32, num_filters=128, residual=True, use_bias=False, device=device(type='cpu'))[source]#

Bases: Module

forward(input_tensor)[source]#
training: bool#

aitlas.models.lstm module#

LSTM model

Note

Original implementation of LSTM model: dl4sits/BreizhCrops

class LSTM(config)[source]#

Bases: BaseMulticlassClassifier

LSTM model implementation

Note

Based on <dl4sits/BreizhCrops>

schema#

alias of LSTMSchema

logits(x)[source]#
forward(x)[source]#
load_optimizer()[source]#

Load the optimizer

training: bool#

aitlas.models.mlp_mixer module#

MLP-Mixer architecture for image classification.

class MLPMixer(config)[source]#

Bases: BaseMulticlassClassifier

MLP mixer multi-class b16_224 model implementation

name = 'MLP mixer b16_224'#
forward(x)[source]#
training: bool#
class MLPMixerMultilabel(config)[source]#

Bases: BaseMultilabelClassifier

MLP mixer multi-label b16_224 model implementation

name = 'MLP mixer b16_224'#
forward(x)[source]#
training: bool#

aitlas.models.msresnet module#

MSResNet model

Note

Adapted from dl4sits/BreizhCrops

Original implementation of MSResNet model: geekfeiw/Multi-Scale-1D-ResNet and dl4sits/BreizhCrops

conv3x3(in_planes, out_planes, stride=1)[source]#

3x3 convolution with padding

conv5x5(in_planes, out_planes, stride=1)[source]#
conv7x7(in_planes, out_planes, stride=1)[source]#
class BasicBlock3x3(inplanes3, planes, stride=1, downsample=None)[source]#

Bases: Module

expansion = 1#
forward(x)[source]#
training: bool#
class BasicBlock5x5(inplanes5, planes, stride=1, downsample=None)[source]#

Bases: Module

expansion = 1#
forward(x)[source]#
training: bool#
class BasicBlock7x7(inplanes7, planes, stride=1, downsample=None)[source]#

Bases: Module

expansion = 1#
forward(x)[source]#
training: bool#
class MSResNet(config)[source]#

Bases: BaseMulticlassClassifier

MSResNet model implementation

Note

Based on <dl4sits/BreizhCrops>

schema#

alias of MSResNetSchema

forward(x0)[source]#
load_optimizer()[source]#

Load the optimizer

training: bool#

aitlas.models.omniscalecnn module#

OmniScaleCNN model implementation

Note

Adapted from dl4sits/BreizhCrops

Original implementation of OmniScaleCNN model: dl4sits/BreizhCrops

class SampaddingConv1D_BN(in_channels, out_channels, kernel_size)[source]#

Bases: Module

forward(X)[source]#
training: bool#
class build_layer_with_layer_parameter(layer_parameters)[source]#

Bases: Module

formerly build_layer_with_layer_parameter

Note

layer_parameters format : [in_channels, out_channels, kernel_size, in_channels, out_channels, kernel_size, …, nlayers ]

forward(X)[source]#
training: bool#
class OmniScaleCNN(config)[source]#

Bases: BaseMulticlassClassifier

OmniScaleCNN model implementation

Note

Based on <dl4sits/BreizhCrops>

schema#

alias of OmniScaleCNNSchema

forward(X)[source]#
load_optimizer()[source]#

Load the optimizer

training: bool#
get_Prime_number_in_a_range(start, end)[source]#
get_out_channel_number(paramenter_layer, in_channel, prime_list)[source]#
generate_layer_parameter_list(start, end, paramenter_number_of_layer_list, in_channel=1)[source]#

aitlas.models.resnet module#

ResNet50 and ResNet152 models for multi-class and multi-label classification

class ResNet50(config)[source]#

Bases: BaseMulticlassClassifier

ResNet50 multi-class model implementation.

name = 'ResNet50'#
forward(x)[source]#
freeze()[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

training: bool#
class ResNet152(config)[source]#

Bases: BaseMulticlassClassifier

ResNet152 multi-class model implementation

name = 'ResNet152'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class ResNet50MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'ResNet50'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class ResNet152MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'ResNet152'#
training: bool#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#

aitlas.models.schemas module#

class TransformerModelSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseClassifierSchema

Schema for configuring a transformer model.

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class InceptionTimeSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseClassifierSchema

Schema for configuring a InceptionTime model.

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class LSTMSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseClassifierSchema

Schema for configuring a LSTM model.

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class MSResNetSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseClassifierSchema

Schema for configuring a MSResNet model.

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class TempCNNSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseClassifierSchema

Schema for configuring a TempCNN model.

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class StarRNNSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseClassifierSchema

Schema for configuring a StarRNN model.

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class OmniScaleCNNSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseClassifierSchema

Schema for configuring a OmniScaleCNN model.

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class UnsupervisedDeepMulticlassClassifierSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseModelSchema

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class UNetEfficientNetModelSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseSegmentationClassifierSchema

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#
class CNNRNNModelSchema(*, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)[source]#

Bases: BaseModelSchema

Parameters:
  • only (types.StrSequenceOrSet | None) –

  • exclude (types.StrSequenceOrSet) –

  • many (bool) –

  • context (dict | None) –

  • load_only (types.StrSequenceOrSet) –

  • dump_only (types.StrSequenceOrSet) –

  • partial (bool | types.StrSequenceOrSet) –

  • unknown (str | None) –

opts: SchemaOpts = <marshmallow.schema.SchemaOpts object>#
fields: Dict[str, ma_fields.Field]#

Dictionary mapping field_names -> Field objects

load_fields: Dict[str, ma_fields.Field]#
dump_fields: Dict[str, ma_fields.Field]#

aitlas.models.shallow module#

class ShallowCNNNet(config)[source]#

Bases: BaseMulticlassClassifier

Simple shallow multi-class CNN network for testing purposes

forward(x)[source]#
training: bool#
class ShallowCNNNetMultilabel(config)[source]#

Bases: BaseMultilabelClassifier

Simple shallow multi-label CNN network for testing purposes

forward(x)[source]#
training: bool#

aitlas.models.starrnn module#

StarRNN model for multiclass classification

Note

Adapted from dl4sits/BreizhCrops

Original implementation of StarRNN model: dl4sits/BreizhCrops

Author: Türkoglu Mehmet Özgür <ozgur.turkoglu@geod.baug.ethz.ch>

class StarRNN(config)[source]#

Bases: BaseMulticlassClassifier

StarRNN model implementation

schema#

alias of StarRNNSchema

forward(x)[source]#
load_optimizer()[source]#

Load the optimizer

training: bool#
class StarCell(input_size, hidden_size, bias=True)[source]#

Bases: Module

reset_parameters()[source]#
forward(x, hidden)[source]#
training: bool#
class StarLayer(input_dim, hidden_dim, bias=True, droput_factor=0.2, batch_norm=True, layer_norm=False, device=device(type='cpu'))[source]#

Bases: Module

forward(x)[source]#
training: bool#

aitlas.models.swin_transformer module#

Swin Transformer V2 model for multi-class and multi-label classification tasks.

class SwinTransformer(config)[source]#

Bases: BaseMulticlassClassifier

A Swin Transformer V2 implementation for multi-class classification tasks.

Initialize a SwinTransformer object with the given configuration.

Parameters:

config (Config schema object) – A configuration containing model-related settings.

name = 'SwinTransformerV2'#
freeze()[source]#

Freeze all the layers in the model except for the head. This prevents the gradient computation for the frozen layers during backpropagation.

forward(x)[source]#

Perform a forward pass through the model.

Parameters:

x (torch.Tensor) – Input tensor with shape (batch_size, channels, height, width).

Returns:

Output tensor with shape (batch_size, num_classes).

Return type:

torch.Tensor

training: bool#
class SwinTransformerMultilabel(config)[source]#

Bases: BaseMultilabelClassifier

A Swin Transformer V2 implementation for multi-label classification tasks.

Initialize a SwinTransformerMultilabel object with the given configuration.

Parameters:

config (Config schema object) – A configuration object containing model-related settings.

name = 'SwinTransformerV2'#
freeze()[source]#

Freeze all the layers in the model except for the head. This prevents the gradient computation for the frozen layers during backpropagation.

forward(x)[source]#

Perform a forward pass through the model.

Parameters:

x (torch.Tensor) – Input tensor with shape (batch_size, channels, height, width).

Returns:

Output tensor with shape (batch_size, num_classes).

Return type:

torch.Tensor

training: bool#

aitlas.models.tempcnn module#

Temporal Convolutional Neural Network (TempCNN) model

Note

Adapted from: dl4sits/BreizhCrops

Original implementation(s) of TempCNN model: dl4sits/BreizhCrops and charlotte-pel/temporalCNN

class TempCNN(config)[source]#

Bases: BaseMulticlassClassifier

TempCNN model implementation

Note

Based on <dl4sits/BreizhCrops>

schema#

alias of TempCNNSchema

forward(x)[source]#
load_optimizer()[source]#

Load the optimizer

training: bool#
class Conv1D_BatchNorm_Relu_Dropout(input_dim, hidden_dims, kernel_size=5, drop_probability=0.5)[source]#

Bases: Module

forward(X)[source]#
training: bool#
class FC_BatchNorm_Relu_Dropout(input_dim, hidden_dims, drop_probability=0.5)[source]#

Bases: Module

forward(X)[source]#
training: bool#
class Flatten(*args, **kwargs)[source]#

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(input)[source]#
training: bool#

aitlas.models.transformer module#

Transformer model

Note

Adapted from: dl4sits/BreizhCrops

Original implementation of Transformer model: dl4sits/BreizhCrops

class TransformerModel(config)[source]#

Bases: BaseMulticlassClassifier

Transformer model for multi-class classification model implementation

schema#

alias of TransformerModelSchema

forward(x)[source]#
load_optimizer()[source]#

Load the optimizer

training: bool#
class Flatten(*args, **kwargs)[source]#

Bases: Module

Flatten module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(input)[source]#
training: bool#

aitlas.models.unet module#

UNet model for segmentation

class Unet(config)[source]#

Bases: BaseSegmentationClassifier

UNet segmentation model implementation.

forward(x)[source]#
training: bool#

aitlas.models.unet_efficientnet module#

post_process(prediction_directory, prediction_csv)[source]#
post_process_single(sourcefile, watershed_line=True, conn=2, polygon_buffer=0.5, tolerance=0.5, seed_msk_th=0.75, area_th_for_seed=110, prediction_threshold=0.5, area_th=80, contact_weight=1.0, edge_weight=0.0, seed_contact_weight=1.0, seed_edge_weight=1.0)[source]#
evaluation(prediction_csv, gt_csv)[source]#
class FocalLoss2d(gamma=3, ignore_index=255, eps=1e-06)[source]#

Bases: Module

forward(outputs, targets, weights=1.0)[source]#
training: bool#
class DiceLoss(weight=None, per_image=False, eps=1e-06)[source]#

Bases: Module

forward(outputs, targets)[source]#
training: bool#
class GenEfficientNet(block_args, num_classes=1000, in_channels=3, num_features=1280, stem_size=32, fix_stem=False, channel_multiplier=1.0, channel_divisor=8, channel_min=None, pad_type='', act_layer=<class 'torch.nn.modules.activation.ReLU'>, drop_connect_rate=0.0, se_kwargs=None, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, norm_kwargs=None, weight_init='goog')[source]#

Bases: Module

training: bool#
class UNetEfficientNet(config)[source]#

Bases: BaseSegmentationClassifier

UNet EfficientNet model implementation.

Note

Based on <SpaceNetChallenge/SpaceNet_SAR_Buildings_Solutions>

Parameters:
  • config (UNetEfficientNetModelSchema) – The configuration for this model

schema#

alias of UNetEfficientNetModelSchema

forward(x, strip, direction, coord)[source]#
load_optimizer()[source]#
load_lr_scheduler()[source]#
train_and_evaluate_model(train_dataset, epochs=100, model_directory=None, save_epochs=10, iterations_log=100, resume_model=None, val_dataset=None, run_id=None, **kwargs)[source]#

Overridden method for training on the SpaceNet6 data set.

Parameters:
evaluate(dataset=None, model_path=None)[source]#
Parameters:
load_model(file_path, optimizer=None)[source]#
training: bool#

aitlas.models.unsupervised module#

DeepCluster model

class UnsupervisedDeepMulticlassClassifier(config)[source]#

Bases: BaseMulticlassClassifier

Unsupervised Deep Learning model implementation

Note

Based on Deep Clustering: <facebookresearch/deepcluster>

schema#

alias of UnsupervisedDeepMulticlassClassifierSchema

train_epoch(epoch, dataloader, optimizer, criterion, iterations_log)[source]#

Overriding train epoch to implement the custom logic for the unsupervised classifier

forward(x)[source]#
training: bool#
compute_features(dataloader, model, N, batch, device)[source]#

Compute features for images

class VGG(features, num_classes, sobel)[source]#

Bases: Module

forward(x)[source]#
training: bool#
make_layers(input_dim, batch_norm)[source]#
vgg16(sobel=False, bn=True, out=1000)[source]#
class UnifLabelSampler(N, images_lists)[source]#

Bases: Sampler

Samples elements uniformly across pseudolabels.

Parameters:
  • N (int) – size of returned iterator.

  • images_lists – lists of images for each pseudolabel.

generate_indexes_epoch()[source]#

aitlas.models.vgg module#

VGG16 model

class VGG16(config)[source]#

Bases: BaseMulticlassClassifier

VGG16 model implementation

name = 'VGG16'#
forward(x)[source]#
freeze()[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

training: bool#
class VGG19(config)[source]#

Bases: BaseMulticlassClassifier

name = 'VGG19'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class VGG16MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'VGG16'#
forward(x)[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

freeze()[source]#
training: bool#
class VGG19MultiLabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'VGG19'#
training: bool#
forward(x)[source]#
freeze()[source]#
extract_features()[source]#

Remove final layers if we only need to extract features

aitlas.models.vision_transformer module#

VisionTransformer model (base_patch16_224)

class VisionTransformer(config)[source]#

Bases: BaseMulticlassClassifier

VisionTransformer model implementation

name = 'ViT base_patch16_224'#
freeze()[source]#
forward(x)[source]#
training: bool#
class VisionTransformerMultilabel(config)[source]#

Bases: BaseMultilabelClassifier

name = 'ViT base_patch16_224'#
freeze()[source]#
training: bool#
forward(x)[source]#