TensorFlow - Python Deep Learning Neural Network API

Deep Learning Course - Level: Beginner

Build a Fine-Tuned Neural Network with TensorFlow's Keras API

In this episode, we'll demonstrate how to fine-tune a pre-trained model to classify images as cats and dogs.

VGG16 and ImageNet

The pre-trained model we'll be working with to classify images of cats and dogs is called VGG16, the model that won first place in the localization task of the 2014 ImageNet competition (ILSVRC) and second place in classification.

In the ImageNet competition, multiple teams compete to build a model that best classifies images from the ImageNet library. The competition's subset of the ImageNet library houses over a million images belonging to 1000 different categories.

We'll import this VGG16 model and then fine-tune it using Keras. The fine-tuned model will no longer classify images into the 1000 categories on which it was originally trained; instead, it will classify images only as either cats or dogs.

Note that dogs and cats were included in the ImageNet library on which VGG16 was originally trained, so the model has already learned the features of cats and dogs. Given this, the fine-tuning we'll do on this model will be very minimal. In later episodes, we'll do more involved fine-tuning and utilize transfer learning to classify data completely different from what was included in the original training set.

To understand fine-tuning and transfer learning on a fundamental level, check out the corresponding episode in the Deep Learning Fundamentals course.

VGG16 Preprocessing

Let's first check out a batch of training data using the plotting function we brought in previously.

# Grab a batch of training images with their one-hot encoded labels
imgs, labels = next(train_batches)
plotImages(imgs)
print(labels)
[[1. 0.]
 [0. 1.]
 [1. 0.]
 [0. 1.]
 [0. 1.]
 [1. 0.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [1. 0.]]
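
Each label here is one-hot encoded. As a quick illustration (assuming the class ordering ['cat', 'dog'] from the generators we set up in earlier episodes), we can map a label back to its class name:

# Map a one-hot label back to its class name
# (assumes the generator's class ordering is ['cat', 'dog'])
import numpy as np
class_names = ['cat', 'dog']
print(class_names[np.argmax(labels[0])])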

When we previously inspected these images, we briefly discussed that the color data was skewed as a result of preprocessing the images using the tf.keras.applications.vgg16.preprocess_input function.

To understand what preprocessing is needed for images that will be passed to a VGG16 model, we can look at the VGG16 paper.

Under the 2.1 Architecture section, we can see that the authors stated that, "The only preprocessing we do is subtracting the mean RGB value, computed on the training set, from each pixel."

This is the preprocessing that was used on the original training data, and therefore, this is the way we need to process images before passing them to VGG16 or a fine-tuned VGG16 model.

This preprocessing is what causes the underlying color data to look distorted when we plot the images.
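
To make this concrete, here is a minimal sketch of what this preprocessing amounts to for a single RGB image array. Note that the actual tf.keras.applications.vgg16.preprocess_input function also reorders the channels from RGB to BGR, matching the original Caffe implementation; the mean values below are the standard ImageNet ones, but treat this as an illustration rather than a drop-in replacement.

# Illustrative sketch of VGG16-style preprocessing (not a drop-in
# replacement for tf.keras.applications.vgg16.preprocess_input)
import numpy as np

def vgg16_style_preprocess(rgb_image):
    # Mean RGB values computed on the ImageNet training set
    mean_rgb = np.array([123.68, 116.779, 103.939])
    bgr_image = rgb_image[..., ::-1]      # reorder channels RGB -> BGR
    return bgr_image - mean_rgb[::-1]     # subtract the means in BGR order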

Building a fine-tuned model

Now, let's begin building our model. First, be sure that you still have all the imports that we brought in a couple episodes back when we began our work on CNNs.
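
If you need them again, the imports we're assuming look roughly like this (a sketch; your exact list from the earlier episodes may differ slightly):

# Imports assumed from earlier episodes (your list may differ slightly)
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Conv2D, MaxPool2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator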

Next, we'll import the VGG16 model from Keras. Note that an internet connection is needed to download this model the first time it's used.

# Download (on first use) and instantiate the full pre-trained VGG16 model
vgg16_model = tf.keras.applications.vgg16.VGG16()

The original trained VGG16 model, along with its saved weights and other parameters, is now downloaded onto our machine.

We can check out a summary of the model just to see what the architecture looks like.

vgg16_model.summary()
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

In contrast, recall how much simpler the CNN we worked with in the last episode was. VGG16 is far more complex and has many more layers than our previous model.

Notice that the last Dense layer of VGG16 has 1000 outputs. These outputs correspond to the 1000 categories in the ImageNet library.

Since we're only going to be classifying two categories, cats and dogs, we need to modify this model so that its output layer matches that task.
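
Before modifying anything, we can verify the shape of the output layer directly. This is an optional check, not part of the original walkthrough:

# Optional check: inspect the final layer of VGG16
last_layer = vgg16_model.layers[-1]
print(last_layer.name)   # predictions
print(last_layer.units)  # 1000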

Before we do that, note that the Keras models we've been working with so far in this series have been of type Sequential.

If we check the type of vgg16_model, we see that it is of type Model, which comes from the Keras Functional API.

type(vgg16_model)
tensorflow.python.keras.engine.training.Model

We've not yet worked with the more sophisticated Functional API, although we will work with it in later episodes using the MobileNet model.

For now, we're going to go through a process to convert the Functional model to a Sequential model, so that it will be easier for us to work with given our current knowledge.

We first create a new model of type Sequential. We then iterate over each of the layers in vgg16_model, except for the last layer, and add each layer to the new Sequential model.

# Create a new Sequential model and copy over every VGG16 layer
# except the final 1000-way output layer
model = Sequential()
for layer in vgg16_model.layers[:-1]:
    model.add(layer)

Now, we have replicated the entire vgg16_model (excluding the output layer) into a new Sequential model, which we've given the name model.
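
As an optional sanity check, we can compare the layer counts of the two models; the Sequential copy should hold every layer except the final Dense layer (exact counts depend on how your TensorFlow version reports the input layer):

# Optional sanity check: compare layer counts of the two models
print(len(vgg16_model.layers))  # original VGG16
print(len(model.layers))        # Sequential copy, minus the output layer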

Next, we'll iterate over each of the layers in our new Sequential model and set them to be non-trainable. This freezes the weights and other trainable parameters in each layer so that they will not be trained or updated when we later pass in our images of cats and dogs.

# Freeze the copied layers so their weights won't be updated during training
for layer in model.layers:
    layer.trainable = False

The reason we don't want to retrain these layers is because, as mentioned earlier, cats and dogs were already included in the original ImageNet library, so VGG16 already does a nice job of classifying these categories. We only want to modify the model so that the output layer classifies cats and dogs and nothing else. Therefore, we don't want any retraining to occur on the earlier layers.

Next, we add our new output layer, consisting of only 2 nodes that correspond to cat and dog. This output layer will be the only trainable layer in the model.

# Add a new 2-node softmax output layer for the cat and dog classes
model.add(Dense(units=2, activation='softmax'))

We can now check out a summary of our model and see that everything is exactly the same as the original vgg16_model, except that the output layer now has only 2 nodes rather than 1000, and the number of trainable parameters has drastically decreased since we froze all the parameters in the earlier layers.

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 8194      
=================================================================
Total params: 134,268,738
Trainable params: 8,194
Non-trainable params: 134,260,544
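
These parameter counts are easy to verify by hand. The only trainable layer is the new output layer, which has 4096 inputs feeding 2 units plus 2 biases, and the total is the original VGG16 count minus the old 1000-way output layer plus this new one:

# Trainable parameters: weights (4096 x 2) plus 2 biases
print(4096 * 2 + 2)                # 8194

# Total parameters: original count, minus the old output layer,
# plus the new one
print(138357544 - 4097000 + 8194)  # 134268738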

In the next episode, we'll see how we can train this modified model on our images of cats and dogs.
