CNN Training Process
Welcome to this neural network programming series with PyTorch. In this episode, we will learn the steps needed to train a convolutional neural network.
So far in this series, we've learned about tensors, and we've learned all about PyTorch neural networks. We are now ready to begin the training process.
- Prepare the data
- Build the model
- Train the model
  - Calculate the loss and the gradients, then update the weights
- Analyze the model's results
Training: What we do after the forward pass
During training, we do a forward pass, but then what? Suppose we get a batch and pass it forward through the network. Once the output is obtained, we compare the predicted output to the actual labels. Once we know how close the predicted values are to the actual labels, we tweak the weights inside the network in such a way that the values the network predicts move closer to the true values (labels).
All of this is for a single batch, and we repeat this process for every batch until we have covered every sample in our training set. After we've completed this process for all of the batches and passed over every sample in our training set, we say that an epoch is complete. We use the word epoch to represent a time period in which our entire training set has been covered.
During the entire training process, we do as many epochs as necessary to reach our desired level of accuracy. With this, we have the following steps:
1. Get a batch from the training set.
2. Pass the batch to the network.
3. Calculate the loss (the difference between the predicted values and the true values).
4. Calculate the gradient of the loss function with respect to the network's weights.
5. Update the weights using the gradients to reduce the loss.
6. Repeat steps 1-5 until one epoch is completed.
7. Repeat steps 1-6 for as many epochs as required to reach the desired level of accuracy.
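The seven steps above can be sketched as a small PyTorch training loop. This is a minimal sketch under stated assumptions, not the series' own code: the random stand-in dataset, the tiny nn.Sequential model, and the names images_all, labels_all, model, and epoch_losses are all hypothetical. Note the optimizer.zero_grad() call, which is not one of the listed steps: PyTorch accumulates gradients across backward() calls, so they are cleared before each new batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 1,000 random single-channel 28x28 "images"
# with labels from 10 classes (a placeholder for the real training set).
torch.manual_seed(0)
images_all = torch.randn(1000, 1, 28, 28)
labels_all = torch.randint(0, 10, (1000,))
train_set = TensorDataset(images_all, labels_all)
train_loader = DataLoader(train_set, batch_size=100, shuffle=True)

# Hypothetical tiny model; any nn.Module with a 10-way output would do.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = optim.Adam(model.parameters(), lr=0.01)

epoch_losses = []
for epoch in range(3):                         # Step 7: repeat for several epochs
    total_loss = 0.0
    for images, labels in train_loader:        # Steps 1 and 6: every batch once per epoch
        preds = model(images)                  # Step 2: forward pass
        loss = F.cross_entropy(preds, labels)  # Step 3: loss
        optimizer.zero_grad()                  # clear gradients from the previous batch
        loss.backward()                        # Step 4: gradients
        optimizer.step()                       # Step 5: update the weights
        total_loss += loss.item()
    epoch_losses.append(total_loss)
    print(f"epoch {epoch}: total loss {total_loss:.4f}")
```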
We already know exactly how to do steps 1 and 2. If you've already covered the deep learning fundamentals series, then you know that we use a loss function to perform step 3, and you know that we use backpropagation and an optimization algorithm to perform steps 4 and 5. Steps 6 and 7 are just standard Python loops (the training loop). Let's see how this is done in code.
The Training Process
Since we disabled PyTorch's gradient tracking feature in a previous episode, we need to be sure to turn it back on (it is on by default).
> torch.set_grad_enabled(True)
<torch.autograd.grad_mode.set_grad_enabled at 0x15b22d012b0>
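As a quick illustration of what this switch controls, here is a minimal sketch with a hypothetical tensor x: while tracking is off, new results are not recorded in the computation graph and report requires_grad=False; once it is back on, they are recorded again.

```python
import torch

x = torch.ones(3, requires_grad=True)   # hypothetical leaf tensor

# Tracking off: operations are not recorded, so results don't require gradients.
torch.set_grad_enabled(False)
y_off = x * 2
print(y_off.requires_grad)       # False

# Tracking back on (the default): operations are recorded again.
torch.set_grad_enabled(True)
y_on = x * 2
print(y_on.requires_grad)        # True
print(torch.is_grad_enabled())   # True
```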
Preparing for the Forward Pass
We already know how to get a batch and pass it forward through the network. Let's see what we do after the forward pass is complete.
We'll begin by:
- Creating an instance of our Network class.
- Creating a data loader that provides batches of size 100 from our training set.
- Unpacking the images and labels from one of these batches.
> network = Network()
> train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
> batch = next(iter(train_loader)) # Getting a batch
> images, labels = batch
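To see what one of these batches contains, here is a minimal sketch that swaps in a hypothetical random TensorDataset for train_set: 1,000 single-channel images (the 28x28 size is an assumption here). With batch_size=100, each batch holds 100 images and 100 labels, and 10 batches make up one epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for train_set: 1,000 single-channel 28x28 images
# with integer class labels in [0, 10).
train_set = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))

train_loader = DataLoader(train_set, batch_size=100)
batch = next(iter(train_loader))   # one batch: an (images, labels) pair
images, labels = batch

print(images.shape)       # torch.Size([100, 1, 28, 28])
print(labels.shape)       # torch.Size([100])
print(len(train_loader))  # 10 batches cover all 1,000 samples (one epoch)
```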
Next, we are ready to pass our batch of images forward through the network and obtain the output predictions. Once we have the prediction tensor, we can use the predictions and the true labels to calculate the loss.
Calculating the Loss
To do this, we will use the cross_entropy() loss function that is available in PyTorch's nn.functional API. Once we have the loss, we can print it and also check the number of correct predictions using the get_num_correct() function we created in a previous episode.
> preds = network(images)
> loss = F.cross_entropy(preds, labels) # Calculating the loss
> loss.item()
2.307542085647583
> get_num_correct(preds, labels)
9
The cross_entropy() function returned a scalar-valued tensor, so we used the item() method to print the loss as a Python number. We got 9 out of 100 correct, and since we have 10 prediction classes, this is about what we'd expect from guessing at random (roughly 10 correct per 100).
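The starting loss of about 2.31 is also consistent with random guessing. If the network spread its probability evenly over the 10 classes, cross-entropy would give every sample a loss of -ln(1/10) = ln 10 ≈ 2.3026. A minimal check, using hypothetical all-zero logits (which softmax maps to a uniform 1/10 per class) and random labels:

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical batch: 100 samples, 10 classes, all logits equal,
# so every class gets probability 1/10.
logits = torch.zeros(100, 10)
labels = torch.randint(0, 10, (100,))

loss = F.cross_entropy(logits, labels)   # mean of -log(1/10) over the batch
print(loss.item())     # approximately 2.302585
print(math.log(10))    # 2.302585092994046
```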
Calculating the Gradients
Calculating the gradients is very easy using PyTorch. Since our network is a PyTorch nn.Module whose weights are parameters with gradient tracking turned on, PyTorch has created a computation graph under the hood. As our tensor flowed forward through our network, all of the computations were added to the graph. PyTorch then uses the computation graph to calculate the gradients of the loss function with respect to the network's weights.
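To make the computation graph idea concrete, here is a minimal sketch with a single hypothetical weight w and a squared-error loss (not the network from this episode). Calling backward() fills in w.grad, and the result matches the gradient worked out by hand.

```python
import torch

# Hypothetical one-weight "network" with a squared-error loss.
w = torch.tensor(3.0, requires_grad=True)
x = torch.tensor(2.0)
target = torch.tensor(10.0)

pred = w * x                  # recorded in the graph (pred.grad_fn is MulBackward0)
loss = (pred - target) ** 2   # also recorded

loss.backward()               # walks the graph backward and fills in w.grad

# By hand: d/dw (w*x - t)^2 = 2 * (w*x - t) * x = 2 * (6 - 10) * 2 = -16
print(w.grad)                 # tensor(-16.)
```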
Before we calculate the gradients, let's verify that we currently have no gradients inside our conv1 layer. The gradients are tensors that are accessible in the grad (short for gradient) attribute of the weight tensor of each layer.
> network.conv1.weight.grad
None
To calculate the gradients, we call the backward() method on the loss tensor, like so:
loss.backward() # Calculating the gradients
Now, the gradients of the loss function have been stored in the grad attribute of each weight tensor.
> network.conv1.weight.grad.shape
torch.Size([6, 1, 5, 5])
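There is one gradient value for every weight, so each grad tensor has the same shape as the parameter it belongs to. The sketch below assumes conv1 is nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5), which matches the [6, 1, 5, 5] shape above; the input batch and the mean-of-outputs loss are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

# Assumption: conv1 is a 1-in, 6-out, 5x5 convolution, matching [6, 1, 5, 5].
conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
print(conv1.weight.grad)          # None: no backward pass yet

images = torch.randn(100, 1, 28, 28)   # hypothetical batch
out = conv1(images)                    # shape [100, 6, 24, 24]
loss = out.mean()                      # any scalar works for this check
loss.backward()

# Each gradient has exactly the shape of its parameter.
for name, param in conv1.named_parameters():
    print(name, tuple(param.shape), tuple(param.grad.shape))
# weight (6, 1, 5, 5) (6, 1, 5, 5)
# bias (6,) (6,)
```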
These gradients are used by the optimizer to update the respective weights. To create our optimizer, we use the torch.optim package, which has many optimization algorithm implementations that we can use. We'll use Adam for our example.
Updating the Weights
To the Adam class constructor, we pass the network parameters (this is how the optimizer is able to access the gradients), and we pass the learning rate (lr=0.01 here).
Finally, all we have to do to update the weights is to tell the optimizer to use the gradients to step in the direction of the loss function's minimum.
optimizer = optim.Adam(network.parameters(), lr=0.01)
optimizer.step() # Updating the weights
When the step() function is called, the optimizer updates the weights using the gradients that are stored in the network's parameters. This means that we should expect our loss to be reduced if we pass the same batch through the network again. Checking this, we can see that this is indeed the case:
> preds = network(images)
> loss = F.cross_entropy(preds, labels)
> loss.item()
2.262690782546997
> get_num_correct(preds, labels)
15
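Adam's update rule is more involved (it keeps running averages of past gradients), so to show plainly what "stepping in the direction of the minimum" means, the sketch below swaps in plain SGD from torch.optim instead. With no momentum or weight decay, SGD's step() replaces each weight w with w - lr * w.grad. The layer, input, and loss are hypothetical.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
layer = nn.Linear(4, 2)          # hypothetical small layer
x = torch.randn(8, 4)            # hypothetical input batch
lr = 0.01

loss = layer(x).pow(2).mean()    # hypothetical scalar loss
loss.backward()                  # fills p.grad for the weight and the bias

# The update plain SGD should perform: w_new = w - lr * w.grad
expected = [p.detach() - lr * p.grad for p in layer.parameters()]

optimizer = optim.SGD(layer.parameters(), lr=lr)
optimizer.step()                 # updates each parameter in place

matches = [torch.allclose(p.detach(), e) for p, e in zip(layer.parameters(), expected)]
print(matches)                   # [True, True]

# A small step against the gradient lowers the loss on the same batch.
new_loss = layer(x).pow(2).mean()
print(loss.item(), new_loss.item())
```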
Train Using a Single Batch
We can summarize the code for training with a single batch in the following way:
import torch
import torch.nn.functional as F
import torch.optim as optim

# Network and train_set are defined in earlier episodes of this series
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

batch = next(iter(train_loader)) # Get Batch
images, labels = batch

preds = network(images) # Pass Batch
loss = F.cross_entropy(preds, labels) # Calculate Loss

loss.backward() # Calculate Gradients
optimizer.step() # Update Weights

print('loss1:', loss.item())
preds = network(images)
loss = F.cross_entropy(preds, labels)
print('loss2:', loss.item())
loss1: 2.3034827709198
loss2: 2.2825052738189697
Building the Training Loop is Next
We should now have a good understanding of the training process. In the next episode, we'll extend these ideas and complete the process by constructing the training loop. See you in the next one!