PyTorch - Python Deep Learning Neural Network API

Deep Learning Course 4 of 6 - Level: Intermediate

CNN Image Prediction with PyTorch - Forward Propagation Explained

Forward Propagation Explained - Using a PyTorch Neural Network

Welcome to this series on neural network programming with PyTorch. In this episode, we will see how we can use our convolutional neural network to generate an output prediction tensor from a sample image of our dataset. Without further ado, let's get started.

At this point in the series, we've finished building our model, and technically, we could jump right into the training process from here. However, let's work to better understand how our network is working right out of the box, and then, once we understand our network a little more deeply, we'll be better prepared to understand the training process.

The first step is to understand forward propagation.

What is forward propagation?

Forward propagation is the process of transforming an input tensor to an output tensor. At its core, a neural network is a function that maps an input tensor to an output tensor, and forward propagation is just a special name for the process of passing an input to the network and receiving the output from the network.

As we have seen, neural networks operate on data in the form of tensors. The term forward propagation indicates that the input tensor data is transmitted through the network in the forward direction.

For our network, what this means is simply passing our input tensor to the network and receiving the output tensor. To do this, we pass our sample data to the network's forward() method.

This is why the forward() method is named forward: executing forward() is the process of forward propagation.

If you've been following the series, you know by now that we don't call the forward() method directly; rather, we call the network instance. See this episode for more details.
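One way to see why we call the instance rather than forward() directly: nn.Module's __call__ wraps forward() with extra machinery such as hooks. The sketch below uses a hypothetical TinyNet (not the series' Network class) and a forward hook to show that both calls compute the same output, but only the instance call runs the hook.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the series' Network class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=4, out_features=3)

    def forward(self, t):
        return self.fc(t)


net = TinyNet()
x = torch.randn(2, 4)

# Record every time a forward hook fires on the module.
hook_calls = []
net.register_forward_hook(lambda module, inputs, output: hook_calls.append(output.shape))

out_call = net(x)            # nn.Module.__call__ runs forward() and any registered hooks
out_direct = net.forward(x)  # calling forward() directly bypasses __call__, so hooks are skipped

print(torch.equal(out_call, out_direct))  # True: same computation, same weights
print(len(hook_calls))                    # 1: only the instance call triggered the hook
```

Skipping hooks is harmless in this toy case, but code that relies on them would silently misbehave, which is one reason to stick with network(x).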

The word forward is pretty straightforward. ;)

However, the word propagate means to move or transmit through some medium. In the case of neural networks, data propagates through the layers of the network.

There is also a notion of backward propagation (backpropagation), which makes forward propagation a fitting name for the first step. During the training process, backpropagation occurs after forward propagation.
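To illustrate the ordering, here is a minimal sketch of one forward pass followed by one backward pass. It assumes a hypothetical single nn.Linear layer as the model and made-up inputs and labels; the actual training loop for our network comes later in the series.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

model = nn.Linear(in_features=4, out_features=3)  # hypothetical stand-in model
x = torch.randn(2, 4)                             # a batch of two made-up inputs
targets = torch.tensor([0, 2])                    # made-up class labels

# 1) Forward propagation: input tensor -> output tensor
logits = model(x)
loss = F.cross_entropy(logits, targets)

print(model.weight.grad)        # None: nothing has been computed backward yet

# 2) Backpropagation: gradients of the loss with respect to the weights
loss.backward()

print(model.weight.grad.shape)  # torch.Size([3, 4]), same shape as the weights
```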

In our case, and from a practical standpoint, forward propagation is the process of passing an input image tensor to the forward() method that we implemented in the last episode. The output is the network's prediction.

In the episode on datasets and data loaders, we saw how to access a single sample image tensor from our training set and, more importantly, how to access a batch of image tensors from our data loader. Now that we have our network defined and our forward() method implemented, let's pass an image to our network to get a prediction.

Predicting with the network: Forward pass

Before we begin, we are going to turn off PyTorch's gradient calculation feature. This will stop PyTorch from automatically building a computation graph as our tensor flows through the network.


The computation graph keeps track of the network's mapping by tracking each computation that happens. The graph is used during the training process to calculate the derivative (gradient) of the loss function with respect to the network's weights.

Since we are not training the network yet, we aren't planning on updating the weights, and so we don't require gradient calculations. We will turn this back on when training begins.

This process of tracking calculations happens in real time, as the calculations occur. Remember that at the beginning of the series, we said that PyTorch uses a dynamic computational graph. Well, now we're turning it off.

Turning it off isn't strictly necessary, but having the feature turned off does reduce memory consumption since the graph isn't stored in memory. This code will turn the feature off.

> torch.set_grad_enabled(False) 
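To make the effect visible, the sketch below (using a hypothetical nn.Linear layer in place of our network) checks the output's requires_grad and grad_fn attributes with tracking on and off. It also shows torch.no_grad(), a context manager that disables tracking only inside its block; that is a common alternative to flipping the global switch, not what this episode uses.

```python
import torch
import torch.nn as nn

layer = nn.Linear(in_features=4, out_features=2)  # hypothetical stand-in for the network
x = torch.randn(1, 4)

# Default: gradient tracking is on, so the output is attached to a computation graph.
torch.set_grad_enabled(True)
y_tracked = layer(x)
print(y_tracked.requires_grad, y_tracked.grad_fn is not None)  # True True

# Globally off, as in this episode: no graph is recorded.
torch.set_grad_enabled(False)
y_untracked = layer(x)
print(y_untracked.requires_grad, y_untracked.grad_fn)  # False None

# Turn it back on (as we will for training).
torch.set_grad_enabled(True)

# Scoped alternative: tracking is off only inside the with-block.
with torch.no_grad():
    y_scoped = layer(x)
print(y_scoped.requires_grad, torch.is_grad_enabled())  # False True
```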

Passing a single image to the network

Let's continue by creating an instance of our Network class:

> network = Network()

Next, we'll procure a single sample from our training set, unpack the image and the label, and verify the image's shape:

> sample = next(iter(train_set)) 
> image, label = sample 
> image.shape 
torch.Size([1, 28, 28])

The image tensor's shape indicates that we have a single-channel image that is 28 pixels high and 28 pixels wide. Cool, this is what we expect.

[Figure: Fashion-MNIST sample image, an ankle boot]

Now, there's a second step we must perform before simply passing this tensor to our network. When we pass a tensor to our network, the network is expecting a batch, so even if we want to pass a single image, we still need a batch.

This is no problem. We can create a batch that contains a single image. All of this will be packaged into a single four-dimensional tensor that reflects the following dimensions:

(batch_size, in_channels, height, width)

This requirement of the network arises from the fact that the forward() method of the nn.Conv2d convolutional layer class expects its input tensor to have 4 dimensions. This is pretty standard, as most neural network implementations deal with batches of input samples rather than single samples.

To put our single sample image tensor into a batch with a size of 1, we just need to unsqueeze() the tensor to add an additional dimension. We saw how to do this in previous episodes.

> # Inserts an additional dimension that represents a batch of size 1
> image.unsqueeze(0).shape
torch.Size([1, 1, 28, 28])
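To see what unsqueeze(0) does in isolation, here is a small sketch on a random stand-in tensor with the same (1, 28, 28) shape as our sample image. It also shows two related operations, indexing with None and squeeze(0), plus torch.stack() for building a batch of several images; these are general PyTorch idioms, not steps from this episode.

```python
import torch

image = torch.rand(1, 28, 28)  # random stand-in for one image: (channels, height, width)

batch = image.unsqueeze(0)     # insert a new axis of length 1 at position 0
print(batch.shape)             # torch.Size([1, 1, 28, 28])

# Indexing with None adds the same leading axis.
print(image[None].shape)       # torch.Size([1, 1, 28, 28])

# squeeze(0) removes the batch axis again.
print(batch.squeeze(0).shape)  # torch.Size([1, 28, 28])

# unsqueeze() returns a view, so no pixel data is copied.
print(batch.data_ptr() == image.data_ptr())  # True

# Stacking several images creates a batch along a new leading axis.
print(torch.stack([image, image, image]).shape)  # torch.Size([3, 1, 28, 28])
```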

Using this, we can now pass the unsqueezed image to our network and get the network's prediction.

> pred = network(image.unsqueeze(0)) # image shape needs to be (batch_size × in_channels × H × W)

> pred
tensor([[0.0991, 0.0916, 0.0907, 0.0949, 0.1013, 0.0922, 0.0990, 0.1130, 0.1107, 0.1074]])

> pred.shape
torch.Size([1, 10])

> label
9

> pred.argmax(dim=1)
tensor([7])

And we did it! We've used our forward method to get a prediction from the network. The network has returned a prediction tensor that contains a prediction value for each of the ten categories of clothing.

The shape of the prediction tensor is 1 x 10. This tells us that the first axis has a length of one while the second axis has a length of ten. The interpretation of this is that we have one image in our batch and ten prediction classes.

(batch size, number of prediction classes)
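One way to check the (batch size, number of prediction classes) reading is to vary the batch size and watch the first axis follow it. The Network class below is an assumption: a CNN modeled on the one built earlier in the series (two conv layers, three linear layers, ten outputs). If your class differs, the first axis should still track the batch size as long as the last layer has ten outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    # Assumed architecture, modeled on the CNN built earlier in the series.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)  # (N, 6, 12, 12)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)  # (N, 12, 4, 4)
        t = t.reshape(-1, 12 * 4 * 4)                                     # (N, 192)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)                                                # (N, 10)


torch.set_grad_enabled(False)  # inference only, as in this episode
network = Network()

image = torch.rand(1, 28, 28)          # random stand-in for one image
pred = network(image.unsqueeze(0))     # batch of one
print(pred.shape)                      # torch.Size([1, 10])

images = torch.rand(3, 1, 28, 28)      # batch of three random images
preds = network(images)
print(preds.shape)                     # torch.Size([3, 10])
print(preds.argmax(dim=1).shape)       # torch.Size([3]): one predicted class index per image
```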

For each input in the batch, and for each prediction class, we have a prediction value. If we wanted these values to be probabilities, we could just use the softmax() function from the nn.functional package.

> F.softmax(pred, dim=1)
tensor([[0.1096, 0.1018, 0.0867, 0.0936, 0.1102, 0.0929, 0.1083, 0.0998, 0.0943, 0.1030]])

> F.softmax(pred, dim=1).sum()
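A short sketch of what softmax() computes along dim=1, using the net1 output values shown later in this episode as sample input. It checks that each row sums to 1, compares against the formula exp(x_i) / sum_j exp(x_j) written out by hand, and shows why dim matters: with dim=0, a batch of one would be normalized across the batch axis and every entry would come out as 1.

```python
import torch
import torch.nn.functional as F

# Raw outputs for a batch of one image (values taken from the net1 output in this episode).
pred = torch.tensor([[0.0855, 0.1123, -0.0290, -0.1411, -0.1293,
                      -0.0688, 0.0149, 0.1410, -0.0936, -0.1157]])

# dim=1 is the class axis, so each row (one image) is normalized separately.
probs = F.softmax(pred, dim=1)
print(torch.allclose(probs.sum(dim=1), torch.ones(1)))  # True

# The same computation written out: exp(x_i) / sum_j exp(x_j)
manual = pred.exp() / pred.exp().sum(dim=1, keepdim=True)
print(torch.allclose(manual, probs))  # True

# Softmax is monotonic, so it never changes which class scores highest.
print(torch.equal(pred.argmax(dim=1), probs.argmax(dim=1)))  # True

# Wrong axis: normalizing over the batch axis of a one-image batch gives all ones.
print(torch.allclose(F.softmax(pred, dim=0), torch.ones_like(pred)))  # True
```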

The label for the first image in our training set is 9, and using the argmax() function we can see that the highest value in our prediction tensor occurred at the class represented by index 7.

  • Prediction: Sneaker (7)
  • Actual: Ankle boot (9)

Remember, each prediction class is represented by a corresponding index.

Index Label
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
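Turning an index into a readable label is just a list lookup. This sketch hard-codes the class names from the table above so it runs without the dataset; if train_set is torchvision's FashionMNIST, its classes attribute should hold the same list in the same order. The prediction values and label are the ones from this episode's single-image example.

```python
import torch

# Class names ordered by index, matching the table above.
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Prediction tensor and label from this episode's single-image example.
pred = torch.tensor([[0.0991, 0.0916, 0.0907, 0.0949, 0.1013,
                      0.0922, 0.0990, 0.1130, 0.1107, 0.1074]])
label = 9

pred_index = pred.argmax(dim=1).item()  # .item() turns the one-element tensor into a Python int
print(f'Prediction: {classes[pred_index]} ({pred_index})')  # Prediction: Sneaker (7)
print(f'Actual: {classes[label]} ({label})')                 # Actual: Ankle boot (9)
print(pred_index == label)                                   # False
```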

The prediction in this case is incorrect, which is what we expect because the weights in the network were generated randomly.

Network weights are randomly generated

There are a couple of important things we need to point out about these results. Most of the probabilities came in close to 10%, and this makes sense because our network is guessing and we have ten prediction classes coming from a balanced dataset.

Another implication of the randomly generated weights is that each time we create a new instance of our network, the weights within the network will be different. This means that the predictions we get will be different if we create different networks. Keep this in mind. Your predictions will be different from what we see here.

> net1 = Network()
> net2 = Network()

> net1(image.unsqueeze(0))
tensor([[ 0.0855,  0.1123, -0.0290, -0.1411, -0.1293, -0.0688,  0.0149,  0.1410, -0.0936, -0.1157]])

> net2(image.unsqueeze(0))
tensor([[-0.0408, -0.0696, -0.1022, -0.0316, -0.0986, -0.0123,  0.0463, 0.0248,  0.0157, -0.1251]])
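If you want your own numbers to be repeatable from run to run (they still won't match ours), one option is to seed PyTorch's global random number generator with torch.manual_seed() before creating the network. The sketch uses a hypothetical make_net() in place of Network(); only the random initialization matters here.

```python
import torch
import torch.nn as nn


def make_net():
    # Hypothetical small CNN standing in for Network().
    return nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),  # (N, 6, 24, 24)
        nn.Flatten(),                                              # (N, 3456)
        nn.Linear(in_features=6 * 24 * 24, out_features=10),
    )


# Two instances draw independent random weights.
net1 = make_net()
net2 = make_net()
print(torch.equal(net1[0].weight, net2[0].weight))  # False

# Re-seeding before each construction reproduces the same initial weights.
torch.manual_seed(42)
net3 = make_net()
torch.manual_seed(42)
net4 = make_net()
print(torch.equal(net3[0].weight, net4[0].weight))  # True

image = torch.rand(1, 1, 28, 28)  # random stand-in image batch
with torch.no_grad():
    print(torch.allclose(net3(image), net4(image)))  # True: same weights, same input
```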

Using the data loader to pass a batch is next

We're now ready to pass batches of data to our network and interpret the results.

We should now have a good understanding of what forward propagation is and how we can pass a single image tensor to a convolutional neural network in PyTorch. In the next post, we will see how to use the data loader to pass a batch to our network. I'll see you there!

In this episode, we will see how we can use our convolutional neural network (CNN) to generate an output prediction tensor from a sample image of our dataset.

VIDEO SECTIONS
00:00 Welcome to DEEPLIZARD - Go to for learning resources
00:30 Help deeplizard add video timestamps - See example in the description
11:29 Collective Intelligence and the DEEPLIZARD HIVEMIND

Hey, we're Chris and Mandy, the creators of deeplizard!
Special thanks to the following polymaths of the deeplizard hivemind: Tammy BufferUnderrun Mano Prime
deeplizard uses music by Kevin MacLeod.
Please use the knowledge gained from deeplizard content for good, not evil.