CNN Image Prediction with PyTorch - Forward Propagation Explained
Forward Propagation Explained - Using a PyTorch Neural Network
Welcome to this series on neural network programming with PyTorch. In this episode, we will see how we can use our convolutional neural network to generate an output prediction tensor from a sample image of our dataset. Without further ado, let's get started.
At this point in the series, we've finished building our model, and technically, we could jump right into the training process from here. However, let's work to better understand how our network is working right out of the box, and then, once we understand our network a little more deeply, we'll be better prepared to understand the training process.
The first step is to understand forward propagation.
What is forward propagation?
Forward propagation is the process of transforming an input tensor to an output tensor. At its core, a neural network is a function that maps an input tensor to an output tensor, and forward propagation is just a special name for the process of passing an input to the network and receiving the output from the network.
As we have seen, neural networks operate on data in the form of tensors. The term forward propagation indicates that the input tensor data is transmitted through the network in the forward direction.
For our network, this simply means passing our input tensor to the network and receiving the output tensor. To do this, we pass our sample data to the network's forward() method. This is why the forward() method is named forward: executing forward() is the process of forward propagation.
If you've been following the series, you know by now that we don't call the forward() method directly; rather, we call the network instance. See the earlier episode in this series for more details.
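To make this concrete, here is a minimal sketch using a hypothetical `TinyNet` module (not the series' Network class). Calling the instance goes through `nn.Module.__call__`, which runs any registered hooks around `forward()`, so both calls return the same tensor here, but only the instance call triggers a forward hook:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer network used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.out = nn.Linear(8, 3)

    def forward(self, t):
        # Forward propagation: the input tensor flows through each layer in turn
        t = torch.relu(self.fc1(t))
        return self.out(t)


net = TinyNet()
x = torch.randn(2, 4)  # a batch of 2 samples with 4 features each

via_call = net(x)             # preferred: goes through nn.Module.__call__
via_forward = net.forward(x)  # bypasses hooks; shown only for comparison
print(via_call.shape)                      # torch.Size([2, 3])
print(torch.equal(via_call, via_forward))  # True

# A forward hook fires only when the instance itself is called
calls = []
net.register_forward_hook(lambda module, inp, out: calls.append(out.shape))
net(x)
net.forward(x)
print(len(calls))  # 1
```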
The word forward is pretty straightforward. ;)
However, the word propagate means to move or transmit through some medium. In the case of neural networks, data propagates through the layers of the network.
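One way to watch data propagate is to push a tensor through a small stack of layers by hand and record its shape after each one. The layer sizes below are arbitrary assumptions chosen for illustration; the last line checks that the step-by-step result matches calling the whole stack at once:

```python
import torch
import torch.nn as nn

# Hypothetical layer stack; sizes are illustrative only
layers = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(6 * 12 * 12, 10),
)

x = torch.randn(1, 1, 28, 28)  # one single-channel 28x28 image in a batch of 1
t = x
shapes = []
for layer in layers:
    t = layer(t)  # the output of one layer becomes the input of the next
    shapes.append(tuple(t.shape))
    print(type(layer).__name__, tuple(t.shape))

print(torch.allclose(t, layers(x)))  # True: same result as calling the whole stack
```

The printed shapes go from (1, 6, 24, 24) after the convolution down to (1, 10) after the final linear layer.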
There is also a notion of backward propagation (backpropagation), which is why forward propagation is a fitting name for the first step. During the training process, backpropagation occurs after forward propagation.
In our case and from a practical standpoint, forward propagation is the process of passing an input image tensor to the forward() method that we implemented in the last episode and receiving the output tensor. This output is the network's prediction.
In the episode on datasets and data loaders, we saw how to access a single sample image tensor from our training set and, more importantly, how to access a batch of image tensors from our data loader. Now that we have our network defined and our forward() method implemented, let's pass an image to our network to get a prediction.
Predicting with the network: Forward pass
Before we begin, we are going to turn off PyTorch's gradient calculation feature. This will stop PyTorch from automatically building a computation graph as our tensor flows through the network.
The computation graph keeps track of the network's mapping by tracking each computation that happens. The graph is used during the training process to calculate the derivative (gradient) of the loss function with respect to the network's weights.
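A minimal sketch of what the graph is used for, with a single hypothetical `nn.Linear` layer and a made-up loss (the sum of the outputs, chosen only to get a scalar). With tracking on, the output carries a `grad_fn`, and calling `backward()` on the loss fills in `.grad` for the layer's weights. This is the forward-then-backward order used during training:

```python
import torch
import torch.nn as nn

torch.set_grad_enabled(True)  # tracking is on by default; set explicitly here for clarity

layer = nn.Linear(3, 2)
x = torch.randn(4, 3)

out = layer(x)    # forward propagation records each operation in the graph
loss = out.sum()  # stand-in loss
print(out.grad_fn is not None)  # True: the graph was recorded

loss.backward()   # backpropagation walks the graph in reverse
print(layer.weight.grad.shape)  # torch.Size([2, 3]): d(loss)/d(weight)

# For this particular loss, d(loss)/dW[j, k] is the sum of x[:, k] over the batch
print(torch.allclose(layer.weight.grad, x.sum(dim=0).expand(2, 3)))  # True
```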
Since we are not training the network yet, we aren't planning on updating the weights, and so we don't require gradient calculations. We will turn this back on when training begins.
This process of tracking calculations happens in real time, as the calculations occur. Remember that back at the beginning of the series, we said that PyTorch uses a dynamic computational graph. Well, now we're turning it off.
Turning it off isn't strictly necessary, but having the feature turned off does reduce memory consumption since the graph isn't stored in memory. We can turn the feature off globally by calling torch.set_grad_enabled(False).
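A minimal sketch of two ways to do this: the scoped `torch.no_grad()` context manager and the global `torch.set_grad_enabled(False)` switch. The layer is a hypothetical stand-in for a network; with tracking off, the output has no `grad_fn` because no graph was built:

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)  # hypothetical stand-in for a network
x = torch.randn(1, 3)

# Option 1: scoped. Tracking is off only inside the with-block.
with torch.no_grad():
    out_scoped = layer(x)
print(out_scoped.grad_fn)  # None

# Option 2: global. Tracking stays off until it is switched back on.
torch.set_grad_enabled(False)
out_global = layer(x)
print(out_global.requires_grad)  # False
torch.set_grad_enabled(True)     # turn it back on, e.g. before training

print(layer(x).grad_fn is not None)  # True: tracking is active again
```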
Passing a single image to the network
Let's continue by creating an instance of our Network class:
```python
network = Network()
```
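If the Network class from the previous episode isn't at hand, the sketch below is a hypothetical stand-in so the steps in this episode can be run. The layer sizes are assumptions (a small CNN sized for 28 x 28 single-channel images), and a random tensor replaces the real training image:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in; layer sizes are assumed, not taken from the previous episode."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        # (N, 1, 28, 28) -> conv 5x5 -> (N, 6, 24, 24) -> pool -> (N, 6, 12, 12)
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
        # -> conv 5x5 -> (N, 12, 8, 8) -> pool -> (N, 12, 4, 4)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)
        t = t.reshape(-1, 12 * 4 * 4)  # flatten to (N, 192)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)  # (N, 10): one raw score per class


network = Network()
image = torch.rand(1, 28, 28)  # random stand-in for a real (channels, height, width) image

with torch.no_grad():
    pred = network(image.unsqueeze(0))
print(pred.shape)  # torch.Size([1, 10])
```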
Next, we'll procure a single sample from our training set, unpack the image and the label, and verify the image's shape:
```python
sample = next(iter(train_set))
image, label = sample

image.shape
# torch.Size([1, 28, 28])
```
The image tensor's shape indicates that we have a single-channel image that is 28 pixels in height and 28 pixels in width. Cool, this is what we expect.
Now, there's a second step we must perform before simply passing this tensor to our network. When we pass a tensor to our network, the network expects a batch, so even if we want to pass a single image, we still need a batch.
This is no problem. We can create a batch that contains a single image. All of this will be packaged into a single four-dimensional tensor with the following dimensions: (batch size, color channels, height, width).
This requirement of the network arises from the fact that the forward() methods in the nn.Conv2d convolutional layer classes expect their tensors to have 4 dimensions. This is pretty standard, as most neural network implementations deal with batches of input samples rather than single samples. (Recent PyTorch releases also accept an unbatched 3-dimensional input for nn.Conv2d, but passing a batch remains the standard approach.)
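A short sketch of that layout, using random tensors as stand-ins for real images: `torch.stack()` combines several same-shaped images into one batch, and a standalone `nn.Conv2d` (channel count and kernel size are arbitrary choices) keeps the batch axis first, mapping (N, C_in, H, W) to (N, C_out, H_out, W_out):

```python
import torch
import torch.nn as nn

# Random stand-ins for three grayscale images, each (channels, height, width)
images = [torch.rand(1, 28, 28) for _ in range(3)]

batch = torch.stack(images)  # the new axis 0 becomes the batch axis
batch_size, channels, height, width = batch.shape
print(batch_size, channels, height, width)  # 3 1 28 28

conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
out = conv(batch)
# Batch axis is preserved; without padding, 28 - 5 + 1 = 24
print(out.shape)  # torch.Size([3, 6, 24, 24])
```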
To put our single sample image tensor into a batch with a size of 1, we just need to unsqueeze() the tensor to add an additional dimension. We saw how to do this earlier in the series.
```python
# Inserts an additional dimension that represents a batch of size 1
image.unsqueeze(0).shape
# torch.Size([1, 1, 28, 28])
```
Using this, we can now pass the unsqueezed image to our network and get the network's prediction.
```python
pred = network(image.unsqueeze(0))  # image shape needs to be (batch_size × in_channels × H × W)

pred
# tensor([[0.0991, 0.0916, 0.0907, 0.0949, 0.1013, 0.0922, 0.0990, 0.1130, 0.1107, 0.1074]])

pred.shape
# torch.Size([1, 10])

label
# 9

pred.argmax(dim=1)
# tensor([7])
```
And we did it! We've used our forward method to get a prediction from the network. The network has returned a prediction tensor that contains a prediction value for each of the ten categories of clothing.
The shape of the prediction tensor is 1 x 10. This tells us that the first axis has a length of one while the second axis has a length of ten. The interpretation of this is that we have one image in our batch and ten prediction classes.
For each input in the batch, and for each prediction class, we have a prediction value. If we wanted these values to be probabilities, we could just use the softmax() function from the nn.functional package:
```python
import torch.nn.functional as F

F.softmax(pred, dim=1)
# tensor([[0.1096, 0.1018, 0.0867, 0.0936, 0.1102, 0.0929, 0.1083, 0.0998, 0.0943, 0.1030]])

F.softmax(pred, dim=1).sum()
# tensor(1.)
```
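For reference, softmax computes $\mathrm{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$ along the chosen dimension. The sketch below, using made-up scores, checks a hand-written version against `F.softmax` and shows that softmax never changes which index is largest, so `argmax()` gives the same predicted class before and after:

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[0.2, -1.0, 0.5, 3.0]])  # made-up raw scores for one sample

manual = pred.exp() / pred.exp().sum(dim=1, keepdim=True)
probs = F.softmax(pred, dim=1)

print(torch.allclose(manual, probs))                         # True
print(probs.sum(dim=1))                                      # one value per sample, equal to 1 up to rounding
print(torch.equal(pred.argmax(dim=1), probs.argmax(dim=1)))  # True
```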
The label for the first image in our training set is 9, and using the argmax() function, we can see that the highest value in our prediction tensor occurred at the class represented by index 7:
- Prediction: Sneaker (7)
- Actual: Ankle boot (9)
Remember, each prediction class is represented by a corresponding index.
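To turn an index into a readable name, a lookup list is enough. The list below is the Fashion-MNIST class order; torchvision's `FashionMNIST` dataset exposes the same names through its `classes` attribute, so `train_set.classes` should also work if `train_set` was built that way (an assumption about your setup):

```python
import torch

# Fashion-MNIST class names, ordered by label index 0-9
CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

pred_index = torch.tensor([7])  # e.g. the result of pred.argmax(dim=1)
label = 9

print(f"Prediction: {CLASSES[pred_index.item()]} ({pred_index.item()})")  # Prediction: Sneaker (7)
print(f"Actual: {CLASSES[label]} ({label})")                             # Actual: Ankle boot (9)
```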
The prediction in this case is incorrect, which is what we expect because the weights in the network were generated randomly.
Network weights are randomly generated
There are a couple of important things we need to point out about these results. Most of the probabilities came in close to 10%, and this makes sense because our network is guessing and we have ten prediction classes coming from a balanced dataset.
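The 10% figure is just 1/10: when the ten raw scores are nearly equal, softmax spreads the probability almost evenly across the classes. A quick check with made-up scores:

```python
import torch
import torch.nn.functional as F

equal_scores = torch.zeros(1, 10)  # ten identical raw scores
uniform = F.softmax(equal_scores, dim=1)
print(uniform)  # every entry is 0.1

near_equal = torch.linspace(-0.05, 0.05, 10).unsqueeze(0)  # scores that differ only slightly
probs = F.softmax(near_equal, dim=1)
print(probs.min().item(), probs.max().item())  # both close to 0.1
```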
Another implication of the randomly generated weights is that each time we create a new instance of our network, the weights within the network will be different. This means that the predictions we get will be different if we create different networks. Keep this in mind. Your predictions will be different from what we see here.
```python
net1 = Network()
net2 = Network()

net1(image.unsqueeze(0))
# tensor([[ 0.0855,  0.1123, -0.0290, -0.1411, -0.1293, -0.0688,  0.0149,  0.1410, -0.0936, -0.1157]])

net2(image.unsqueeze(0))
# tensor([[-0.0408, -0.0696, -0.1022, -0.0316, -0.0986, -0.0123,  0.0463,  0.0248,  0.0157, -0.1251]])
```
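A small sketch of that effect with a hypothetical one-layer model: two fresh instances start from different random weights, while seeding PyTorch's random number generator with `torch.manual_seed()` before each construction makes them identical, which can help when you want repeatable runs:

```python
import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in; any module with randomly initialized layers behaves the same way
    return nn.Linear(4, 10)


net1, net2 = make_model(), make_model()
print(torch.equal(net1.weight, net2.weight))  # False: independent random initializations

torch.manual_seed(42)
net3 = make_model()
torch.manual_seed(42)
net4 = make_model()
print(torch.equal(net3.weight, net4.weight))  # True: same seed, same starting weights
```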
Using the data loader to pass a batch is next
We're now ready to pass batches of data to our network and interpret the results.
We should now have a good understanding of what forward propagation is and how we can pass a single image tensor to a convolutional neural network in PyTorch. In the next post, we will see how to use the data loader to pass a batch to our network. I'll see you there!