Source code for aitlas.models.fasterrcnn
"""FasterRCNN model for object detection"""
from torchvision.models.detection import (
FasterRCNN_ResNet50_FPN_V2_Weights,
fasterrcnn_resnet50_fpn_v2,
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from ..base import BaseObjectDetection
class FasterRCNN(BaseObjectDetection):
    """FasterRCNN model implementation

    Wraps torchvision's ``fasterrcnn_resnet50_fpn_v2`` and replaces its box
    predictor head with a ``FastRCNNPredictor`` producing
    ``config.num_classes`` outputs.

    .. note:: Based on https://pytorch.org/vision/stable/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn_v2.html#torchvision.models.detection.fasterrcnn_resnet50_fpn_v2
    """

    def __init__(self, config):
        """Build the detector from ``config``.

        :param config: model configuration. This class reads
            ``config.pretrained`` (when truthy, load the COCO_V1 weights;
            otherwise build with ``weights=None``) and ``config.num_classes``
            (number of outputs of the replacement box predictor).
        """
        super().__init__(config)

        # load an object detection model pre-trained on COCO, if requested
        weights = (
            FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
            if self.config.pretrained
            else None
        )
        self.model = fasterrcnn_resnet50_fpn_v2(weights=weights)

        # reuse the feature width that currently feeds the classifier so the
        # new head plugs into the existing RoI head unchanged
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features

        # replace the pre-trained head with a new one sized for our classes
        # NOTE(review): torchvision's FastRCNNPredictor counts the background
        # as a class; confirm config.num_classes includes it.
        self.model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, self.config.num_classes
        )

    def forward(self, inputs, targets=None):
        """Run the wrapped torchvision detector.

        :param inputs: images, passed through unchanged to the wrapped model.
        :param targets: optional annotations, passed through unchanged to the
            wrapped model (torchvision expects them in training mode).
        :return: whatever the wrapped torchvision FasterRCNN returns.
        """
        # Invoke the module itself rather than ``.forward`` so that any
        # forward (pre-)hooks registered on the wrapped model are executed.
        return self.model(inputs, targets)