AWS SageMaker Object Detection Training Gotchas

As part of updates to arachn.io, I’ve started tinkering with object detection machine learning models. During my experiments on AWS SageMaker, I found that AutoPilot does not support object detection models, so I had to train using notebooks. As a result, I hit some “gotchas” fine-tuning TensorFlow Object Detection models. While this notebook works a treat on its own training data (at least when run through SageMaker Studio), this discussion will focus on things I learned while trying to run it on my own data on August 31, 2024.

Hopefully others trying to do the same work will find this information useful!

Gotcha 1: Training Image Filenames Must Not Have Extra Periods

Part of fine-tuning object detection models in SageMaker is organizing training data into the following structure in S3:

input_directory
    |--images
        |--abc.png
        |--def.png
    |--annotations.json

Obviously, one must conform to the documented structure, but there is an additional undocumented requirement: the image filenames (here shown as abc.png, def.png) must not contain extra periods!

If the image filenames contain extra periods (e.g., abc.1.png, abc.2.png), then the inputs won’t be recognized for training and will be silently ignored. The training job still reports success. This was what I saw in my training output that suggested there was a problem:

[Epoch 0], Speed: 0.000 samples/sec, loss=0.
[Epoch 0], Val_localization=0.000, Val_classification=0.000
[Epoch 1], Speed: 0.000 samples/sec, loss=0.
[Epoch 1], Val_localization=0.000, Val_classification=0.000
[Epoch 2], Speed: 0.000 samples/sec, loss=0. 
[Epoch 2], Val_localization=0.000, Val_classification=0.000
[Epoch 3], Speed: 0.000 samples/sec, loss=0. 
[Epoch 3], Val_localization=0.000, Val_classification=0.000
[Epoch 4], Speed: 0.000 samples/sec, loss=0.
[Epoch 4], Val_localization=0.000, Val_classification=0.000

Specifically, it looks like no learning is taking place, mostly because no learning is taking place, since all input is silently ignored!
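A quick pre-upload check can catch this before a training job is wasted. The following is a minimal sketch, not part of the SageMaker notebook: the helper names has_extra_periods and sanitize_filename are hypothetical, and replacing the extra periods with underscores is just one possible renaming scheme.

```python
from pathlib import PurePosixPath


def has_extra_periods(filename: str) -> bool:
    """Return True if the name contains any period besides the extension's."""
    return PurePosixPath(filename).name.count(".") > 1


def sanitize_filename(filename: str, replacement: str = "_") -> str:
    """Replace every period except the final (extension) one."""
    name = PurePosixPath(filename).name
    stem, dot, ext = name.rpartition(".")
    if not dot:
        # No extension at all, so there is nothing to protect.
        return name
    return stem.replace(".", replacement) + dot + ext


# Example names taken from the article.
names = ["abc.png", "abc.1.png", "abc.2.png"]
flagged = [n for n in names if has_extra_periods(n)]
renamed = {n: sanitize_filename(n) for n in flagged}

print(flagged)
print(renamed)
```

If images are renamed this way, any filenames referenced in annotations.json would presumably need the same change so the two stay in sync.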

Gotcha 2: PNG Files Must Have 3 Channels

Another undocumented requirement is that if the image files are PNG files, then they must have only 3 channels.

For example, images that look like this using the file command (at least on my Mac):

$ file rgba.png
rgba.png: PNG image data, 1920 x 1080, 8-bit/color RGBA, non-interlaced

…have 4 channels (R, G, B, and A). If files like these are used, then the training will fail with the following error message:

AlgorithmError: ExecuteUserScriptError: ExitCode 1 ErrorMessage "ValueError: cannot reshape array of size 8294400 into shape (1080,1920,3)" Command "/usr/local/bin/python3.9 transfer_learning.py --batch_size 16 --beta_1 0.9 --beta_2 0.999 --early_stopping True --early_stopping_min_delta 0.0 --early_stopping_patience 5 --epochs 10 --epsilon 1e-07 --initial_accumulator_value 0.1 --learning_rate 0.002 --momentum 0.9 --optimizer adam --reinitialize_top_layer Auto --rho 0.95 --train_only_top_layer True", exit code: 1

The training script loads the image as one long list (length 8294400) of R, G, B, A values, and then tries to reorganize it into a list of shape (height, width, 3). This doesn’t work, since there are 4 channels. Readers should note that (for width=1920 and height=1080), width * height * 4 = 8294400 whereas width * height * 3 = 6220800.
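The arithmetic can be checked in a few lines. This is only an illustration of the size mismatch, with hypothetical helper names; it is not the training script's actual code.

```python
def flat_size(width: int, height: int, channels: int) -> int:
    """Number of values in an image stored as one flat list."""
    return width * height * channels


def channels_in_flat_array(size: int, width: int, height: int) -> int:
    """Work backwards from a flat array size to the channel count."""
    pixels = width * height
    if size % pixels != 0:
        raise ValueError("size is not a whole multiple of width * height")
    return size // pixels


width, height = 1920, 1080
print(flat_size(width, height, 4))  # what the RGBA file actually holds
print(flat_size(width, height, 3))  # what the reshape to (1080, 1920, 3) needs
print(channels_in_flat_array(8294400, width, height))
```

Dividing the size in the error message by width * height is a handy way to confirm that an unexpected extra channel is the culprit.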

The following command converts a 4-channel PNG file to a 3-channel PNG file using ImageMagick:

$ convert rgba.png -alpha off rgb.png
$ file rgb.png
rgb.png: PNG image data, 1920 x 1080, 8-bit/color RGB, non-interlaced

Training with the converted file worked nicely.
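For a whole directory of images, running file by hand gets tedious. The sketch below reads the PNG header directly using only the Python standard library, as a rough equivalent of the file check above. The function names are hypothetical; the byte offsets and color type codes come from the PNG specification. It only reports what the file declares, and how the training script handles grayscale or palette PNGs is not something I tested.

```python
import struct

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

# Color type codes defined by the PNG specification.
COLOR_TYPE_NAMES = {
    0: "grayscale",
    2: "RGB",
    3: "palette",
    4: "grayscale+alpha",
    6: "RGBA",
}


def png_color_type(header: bytes) -> str:
    """Return the color type name given the first 26 bytes of a PNG file."""
    if len(header) < 26 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError("not a PNG file with a leading IHDR chunk")
    # IHDR data starts at byte 16: width, height, bit depth, color type.
    _width, _height, _bit_depth, color_type = struct.unpack(">IIBB", header[16:26])
    if color_type not in COLOR_TYPE_NAMES:
        raise ValueError(f"unknown PNG color type {color_type}")
    return COLOR_TYPE_NAMES[color_type]


def png_file_color_type(path: str) -> str:
    """Read just enough of the file at `path` to report its color type."""
    with open(path, "rb") as f:
        return png_color_type(f.read(26))


# In-memory demo: a header like the one rgba.png would start with.
demo_header = (
    PNG_SIGNATURE
    + struct.pack(">I", 13)  # IHDR data length
    + b"IHDR"
    + struct.pack(">IIBBBBB", 1920, 1080, 8, 6, 0, 0, 0)
)
print(png_color_type(demo_header))
```

Any file for which png_file_color_type returns "RGBA" is a candidate for the ImageMagick conversion shown above.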

Gotcha 3: Deploying Endpoints and Models using the UI is not the Same as Creating using Scripts

Perhaps this is obvious, but I found out the hard way that creating object detection models using the UI, as opposed to using (for example) the SageMaker SDK, changes the input and output formats.

For example, when I mounted a pre-rolled Object Detection model to an endpoint, I found that the expected input is the standard OpenCV-style image representation with shape [ 1, 1, 3 ] ([ height, width, 3 ]).

However, deploying using the code from the notebook causes the model to use the documented input and output format:

from sagemaker import image_uris, model_uris, script_uris
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base

# model_version="*" fetches the latest version of the model.
infer_model_id, infer_model_version = dropdown.value, "*"

endpoint_name = name_from_base(f"jumpstart-example-{infer_model_id}")

inference_instance_type = "ml.p2.xlarge"

# Retrieve the inference docker container uri.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope="inference",
    model_id=infer_model_id,
    model_version=infer_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri.
deploy_source_uri = script_uris.retrieve(
    model_id=infer_model_id, model_version=infer_model_version, script_scope="inference"
)
# Retrieve the base model uri.
base_model_uri = model_uris.retrieve(
    model_id=infer_model_id, model_version=infer_model_version, model_scope="inference"
)
# Create the SageMaker model instance. Note that we need to pass Predictor class when we deploy model through Model class,
# for being able to run inference through the sagemaker API.
model = Model(
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    model_data=base_model_uri,
    entry_point="inference.py",
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)
# deploy the Model.
base_model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    endpoint_name=endpoint_name,
)

Note that the code above sets the entry_point script (entry_point="inference.py") when deploying. AWS provides good documentation for how this affects the inference container and why it changes the input and output formats.
