Code Update for CNN Training with TensorFlow's Keras API
In this episode, we'll discuss a code update we need to know about before we build and train our first convolutional neural network (CNN).
In the upcoming episodes, we'll demonstrate how to train a CNN using the image data we organized and processed previously.
Recall that we stored the image data in a Keras Sequence, specifically a DirectoryIterator, using the ImageDataGenerator.flow_from_directory() function. This function generates batches of image data from the specified location on disk.
As you've seen in a previous episode, when we train a model, we call the fit() function on the model and pass in the training data. We've seen how this works when the training data is stored in a simple NumPy array, but in the upcoming CNN episodes, we'll see how it's done when the training data is stored in a DirectoryIterator.
Recently, TensorFlow introduced a change that requires us to pass another parameter to the fit() function when our data is stored in an infinitely repeating data set, like a DirectoryIterator.
Note that a DirectoryIterator is indeed an infinitely repeating data set: the iterator will keep generating batches of data for as long as we keep asking for them. You can see this by passing the iterator to Python's built-in next() function over and over again; each call returns a new batch, without end.
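To illustrate why such an iterator never runs out, here is a minimal pure-Python sketch. The generator `infinite_batches` and the file names are hypothetical stand-ins, not the real DirectoryIterator, which also loads and preprocesses images and shuffles by default; the sketch only mimics the repeat-forever behavior, without shuffling, when the iterator is passed to next().

```python
from typing import Iterator, List, Sequence


def infinite_batches(samples: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield batches forever, walking through `samples` in order.

    Hypothetical stand-in for a DirectoryIterator: the last batch of each
    pass may be smaller than batch_size, and then the iterator starts over.
    """
    while True:
        for start in range(0, len(samples), batch_size):
            yield list(samples[start:start + batch_size])


# Hypothetical file names standing in for images on disk.
images = [f"img_{i}.jpg" for i in range(5)]
batches = infinite_batches(images, batch_size=2)

# next() never raises StopIteration here; the fourth call wraps back to img_0.
for _ in range(4):
    print(next(batches))
```

Because the outer `while True` loop restarts the pass, any code that simply iterates until the data runs out would never finish, which is the situation the parameters below address.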
Required steps_per_epoch Parameter
Now, back to the parameter that must be passed to the fit() function for this type of data. This parameter is called steps_per_epoch, and it should be set to the number of steps (batches of samples) to yield from the training set before declaring one epoch finished and starting the next epoch.
This is typically set to the number of samples in our training set divided by the batch size. For example, if we have 100 training images and our batch size is 5, then we would set steps_per_epoch=20.
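The arithmetic can be sketched as a small helper. The function name `steps_for` is hypothetical; the sketch assumes that when the sample count is not an exact multiple of the batch size, we round up so the final, smaller batch is still counted.

```python
def steps_for(num_samples: int, batch_size: int) -> int:
    """Batches needed to cover num_samples once, counting a final partial batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Integer ceiling division: avoids floating-point rounding on large counts.
    return (num_samples + batch_size - 1) // batch_size


print(steps_for(100, 5))  # 20, the example from the text
print(steps_for(103, 5))  # 21, the last batch holds only 3 images
```

Because a Keras Sequence such as a DirectoryIterator implements `__len__`, calling `len()` on it (for example `len(train_batches)`, where `train_batches` is a hypothetical variable name) should return this same rounded-up count, so it can be passed directly as steps_per_epoch.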
This parameter actually isn't new. However, in previous TensorFlow versions, it did not need to be specified when our data was stored in a Keras Sequence, like the DirectoryIterator we've stored our data in. Instead, TensorFlow would default to using the size of the data set divided by the batch size as the number of steps_per_epoch.
Depending on which version of TensorFlow you're running, if you don't specify this parameter, model.fit() may run infinitely on the first epoch and never complete.
Additional Required Parameters
Note that in addition to steps_per_epoch, which we specify for the training data when we call model.fit(), we also need to specify a parameter called validation_steps if we are also passing validation data to the model. This parameter works exactly like steps_per_epoch, except that it applies to our validation set.
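A toy, pure-Python simulation may help show what steps_per_epoch and validation_steps control. `make_batches` and `simulate_fit` are hypothetical helpers, not Keras functions, and the 20-sample validation set is made up. The sketch does no training; it only counts how many samples each epoch draws from two infinitely repeating iterators, with `itertools.islice` playing the role the two parameters play in model.fit().

```python
from itertools import cycle, islice
from typing import Iterator, List, Tuple


def make_batches(n: int, batch_size: int) -> Iterator[List[int]]:
    """Hypothetical infinite iterator over batched sample indices 0..n-1."""
    one_pass = [list(range(i, min(i + batch_size, n))) for i in range(0, n, batch_size)]
    return cycle(one_pass)  # repeats the same pass forever


def simulate_fit(train: Iterator[List[int]], valid: Iterator[List[int]], epochs: int,
                 steps_per_epoch: int, validation_steps: int) -> List[Tuple[int, int, int]]:
    """Toy stand-in for model.fit(): record samples consumed per epoch."""
    history = []
    for epoch in range(1, epochs + 1):
        # islice stops after a fixed number of batches; a plain
        # `for batch in train` loop over this iterator would never end.
        seen_train = sum(len(b) for b in islice(train, steps_per_epoch))
        seen_valid = sum(len(b) for b in islice(valid, validation_steps))
        history.append((epoch, seen_train, seen_valid))
    return history


train = make_batches(100, 5)  # 100 training samples, batch size 5
valid = make_batches(20, 5)   # 20 hypothetical validation samples
for epoch, n_train, n_valid in simulate_fit(train, valid, epochs=2,
                                            steps_per_epoch=20, validation_steps=4):
    print(f"epoch {epoch}: {n_train} training samples, {n_valid} validation samples")
```

With steps_per_epoch=20, each epoch covers all 100 training samples exactly once. A smaller value would end the epoch early, and the next epoch would pick up where the previous one stopped, so epoch boundaries would no longer line up with full passes over the data.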
Lastly, when we use the model for inference by calling predict() on the model and passing in the test set, we also need to specify a parameter called steps. In this case, it is the number of steps (batches of samples) to yield from the test set before declaring the prediction round finished.
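The same idea applies to prediction. In this minimal sketch, `toy_model` and `simulate_predict` are hypothetical stand-ins for a trained model and for predict(), and the five test values are made up. Running exactly `steps` batches over an infinitely repeating test iterator yields one prediction per test sample.

```python
from itertools import cycle, islice
from typing import Callable, Iterator, List


def toy_model(batch: List[float]) -> List[float]:
    """Hypothetical model: 'predicts' by doubling each input value."""
    return [2 * x for x in batch]


def simulate_predict(model: Callable[[List[float]], List[float]],
                     batches: Iterator[List[float]], steps: int) -> List[float]:
    """Toy stand-in for predict(): run the model on exactly `steps` batches."""
    predictions: List[float] = []
    for batch in islice(batches, steps):
        predictions.extend(model(batch))
    return predictions


test_set = [1.0, 2.0, 3.0, 4.0, 5.0]  # 5 hypothetical test samples
batch_size = 2
one_pass = [test_set[i:i + batch_size] for i in range(0, len(test_set), batch_size)]
test_batches = cycle(one_pass)  # repeats forever, like a DirectoryIterator

steps = (len(test_set) + batch_size - 1) // batch_size  # 3 batches: [1, 2], [3, 4], [5]
preds = simulate_predict(toy_model, test_batches, steps)
print(len(preds), preds)
```

If steps were 4 instead of 3, the loop would wrap back to the first batch and return 7 predictions, two of them for samples already predicted, so the results would no longer line up one-to-one with the test set.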
Tracking the Issue
It's unclear whether this parameter will continue to be required, as it was initially required for this type of infinitely repeating data set, then not required, and now required again. You can track this issue on TensorFlow's GitHub if you're interested.
In the upcoming episodes, when we call fit() or predict(), you will see that steps_per_epoch, validation_steps, and steps have not been specified in the video portion of the episodes. However, the corresponding blogs will have code updated for whatever TensorFlow requires at the time. So, if the parameters are required, they will be set in the blog; if they are no longer required, they will not be.
Now we're ready to begin building and training our first convolutional neural network in the next episode.