Neural Network Programming - Deep Learning with PyTorch


CNN Confusion Matrix with PyTorch - Neural Network Programming

July 4, 2019 by deeplizard


Create a Confusion Matrix with PyTorch

Welcome to this neural network programming series. In this episode, we're going to build some functions that will allow us to get a prediction tensor for every sample in our training set.

Then, we'll see how we can take this prediction tensor, along with the labels for each sample, to create a confusion matrix. This confusion matrix will allow us to see which categories our network is confusing with one another. Without further ado, let's get started.

Where we are now in the course:

  • Prepare the data
  • Build the model
  • Train the model
  • Analyze the model's results
    • Building, plotting, and interpreting a confusion matrix

Be sure to see the previous episode in this course for all the code setup details.

Confusion Matrix Requirements

To create a confusion matrix for our entire dataset, we need to have a prediction tensor with a single dimension that has the same length as our training set.

> len(train_set)
60000

This prediction tensor will contain ten predictions for each sample from our training set (one for each category of clothing). After we have obtained this tensor, we can use the labels tensor to generate a confusion matrix.

> len(train_set.targets)
60000

A confusion matrix will show us where the model is getting confused. To be more specific, the confusion matrix will show us which categories the model is predicting correctly and which categories the model is predicting incorrectly. For the incorrect predictions, we will be able to see which category the model predicted, and this will show us which categories are confusing the model.

Get Predictions for the Entire Training Set

To get the predictions for all the training set samples, we need to pass all of the samples forward through the network. To do this, it is possible to create a DataLoader with batch_size=len(train_set). This would pass the entire training set to the network in a single batch and give us the desired prediction tensor for all the training set samples.

However, depending on our computing resources and the size of the training set (especially if we were training on a larger dataset), we need a way to make predictions on smaller batches and collect the results. To collect the results, we'll use the torch.cat() function to concatenate the output tensors from each batch into a single prediction tensor. Let's build a function to do this.
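As a quick refresher, torch.cat() joins tensors along an existing dimension. Here is a minimal sketch with made-up values showing how concatenating along dim=0 stacks batches on top of one another:

> t1 = torch.tensor([[1., 2.], [3., 4.]])  # pretend batch of two predictions
> t2 = torch.tensor([[5., 6.]])            # pretend batch of one prediction
> torch.cat((t1, t2), dim=0).shape
torch.Size([3, 2])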

Building a Function to get Predictions for ALL Samples

We'll create a function called get_all_preds(), and we'll pass a model and a data loader. The model will be used to obtain the predictions, and the data loader will be used to provide the batches from the training set.

All the function needs to do is iterate over the data loader, passing the batches to the model and concatenating the results of each batch into a prediction tensor that will be returned to the caller.

@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])  # empty tensor to hold all predictions
    for batch in loader:
        images, labels = batch

        preds = model(images)  # forward pass for this batch
        all_preds = torch.cat(
            (all_preds, preds)
            ,dim=0  # concatenate along the batch dimension
        )
    return all_preds

The implementation of this function creates an empty tensor, all_preds, to hold the output predictions. Then, it iterates over the batches coming from the data loader, and concatenates the output predictions with the all_preds tensor. Finally, the tensor of all predictions, all_preds, is returned to the caller.

Note that at the top, we have annotated the function using the @torch.no_grad() PyTorch decorator. This is because we want this function's execution to omit gradient tracking.

This is because gradient tracking uses memory, and during inference (getting predictions while not training) there is no need to keep track of the computational graph. The decorator is one way of locally turning off the gradient tracking feature while executing specific functions.
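As a minimal illustration (a toy function, not part of the course code), any tensor produced inside a function decorated with @torch.no_grad() is detached from the computational graph:

import torch

@torch.no_grad()
def double(t):
    return t * 2

t = torch.ones(3, requires_grad=True)
print(double(t).requires_grad)  # False: no graph is built inside the decorated call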

Locally Disabling PyTorch Gradient Tracking

We are ready now to make the call to obtain the predictions for the training set. All we need to do is create a data loader with a reasonable batch size, and pass the model and data loader to the get_all_preds() function.

In a previous episode, we saw how we turned off PyTorch's gradient tracking feature when it was not needed, and how we turned it back on when we started the training process.

We specifically need the gradient calculation feature anytime we are going to calculate gradients using the backward() function. Otherwise, it is a good idea to turn it off because having it off will reduce memory consumption for computations, e.g. when we are using networks for predicting (inference).

We can disable gradient computations for specific or local spots in our code, e.g. like what we just saw with the annotated function. As another example, we can use Python's with context manager keyword to specify that a specific block of code should exclude gradient computations.

with torch.no_grad():
    prediction_loader = torch.utils.data.DataLoader(train_set, batch_size=10000)
    train_preds = get_all_preds(network, prediction_loader)

Both of these options are valid. Let's keep both of these and get our predictions.
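As a quick sanity check (assuming the setup above), we can confirm that the returned tensor covers the whole training set, has one prediction per class, and is detached from the graph:

> train_preds.shape
torch.Size([60000, 10])

> train_preds.requires_grad
False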

Using the Predictions Tensor

Now that we have the prediction tensor, we can pass it to the get_num_correct() function that we created in a previous episode, along with the training set labels, to get the total number of correct predictions.
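For reference, here is a minimal sketch of get_num_correct() that is consistent with how it is used here (the version from the earlier episode may differ slightly in form):

def get_num_correct(preds, labels):
    # count predictions whose highest-scoring class matches the label
    return preds.argmax(dim=1).eq(labels).sum().item()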

> preds_correct = get_num_correct(train_preds, train_set.targets)

> print('total correct:', preds_correct)
> print('accuracy:', preds_correct / len(train_set))
total correct: 53578
accuracy: 0.8929666666666667

We can see the total number of correct predictions and print the accuracy by dividing by the number of samples in the training set.

Building the Confusion Matrix

Our task in building the confusion matrix is to count the number of predicted values against the true values (targets).

This will create a matrix that acts as a heat map telling us where the predicted values fall relative to the true values.

To do this, we need to have the targets tensor and the predicted label from the train_preds tensor.

> train_set.targets
tensor([9, 0, 0,  ..., 3, 0, 5])

> train_preds.argmax(dim=1)
tensor([9, 0, 0,  ..., 3, 0, 5])

Now, if we compare the two tensors element-wise, we can see if the predicted label matches the target. Additionally, if we are counting the number of predicted labels vs the target labels, the values inside the two tensors act as coordinates for our matrix. Let's stack these two tensors along the second dimension so we can have 60,000 ordered pairs.

> stacked = torch.stack(
    (
        train_set.targets
        ,train_preds.argmax(dim=1)
    )
    ,dim=1
)


> stacked.shape
torch.Size([60000, 2])


> stacked
tensor([
    [9, 9],
    [0, 0],
    [0, 0],
    ...,
    [3, 3],
    [0, 0],
    [5, 5]
])

> stacked[0].tolist()
[9, 9]

Now, we can iterate over these pairs and count the number of occurrences at each position in the matrix. Let's create the matrix. Since we have ten prediction categories, we'll have a ten by ten matrix.

> cmt = torch.zeros(10,10, dtype=torch.int64)
> cmt
tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
])

Now, we'll iterate over the prediction-target pairs and add one to the value inside the matrix each time the particular position occurs.

for p in stacked:
    tl, pl = p.tolist()            # true label, predicted label
    cmt[tl, pl] = cmt[tl, pl] + 1  # count this (true, predicted) pair

This gives us the following confusion matrix tensor.

> cmt
tensor([
    [5637,    3,   96,   75,   20,   10,   86,    0,   73,    0],
    [  40, 5843,    3,   75,   16,    8,    5,    0,   10,    0],
    [  87,    4, 4500,   70, 1069,    8,  156,    0,  106,    0],
    [ 339,   61,   19, 5269,  203,   10,   72,    2,   25,    0],
    [  23,    9,  263,  209, 5217,    2,  238,    0,   39,    0],
    [   0,    0,    0,    1,    0, 5604,    0,  333,   13,   49],
    [1827,    7,  716,  104,  792,    3, 2370,    0,  181,    0],
    [   0,    0,    0,    0,    0,   22,    0, 5867,    4,  107],
    [  32,    1,   13,   15,   19,    5,   17,   11, 5887,    0],
    [   0,    0,    0,    0,    0,   28,    0,  234,    6, 5732]
])
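As a side note, the Python loop above can be slow for large datasets. A vectorized sketch that produces the same counts encodes each (true, predicted) pair as a single index and uses torch.bincount():

> tl = train_set.targets
> pl = train_preds.argmax(dim=1)
> cmt = torch.bincount(tl * 10 + pl, minlength=100).reshape(10, 10)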

Note that the example below will have different values because these two examples were created at different times.

Plotting the Confusion Matrix

To generate the actual confusion matrix as a numpy.ndarray, we use the confusion_matrix() function from the sklearn.metrics library. Let's get this imported along with our other needed imports.

import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
from resources.plotcm import plot_confusion_matrix

For the last import, note that plotcm is a file, plotcm.py, that lives in a folder called resources in the current directory. Inside the plotcm.py file, there is a function called plot_confusion_matrix() that we will call. You'll need to create this file on your system. We'll look at how to do this in a minute. First, let's generate the confusion matrix.

We can generate the confusion matrix like so:

> cm = confusion_matrix(train_set.targets, train_preds.argmax(dim=1))
> print(type(cm))
> cm

<class 'numpy.ndarray'>
array([[5431,   14,   88,  145,   26,    7,  241,    0,   48,    0],
       [   4, 5896,    6,   75,    8,    0,    8,    0,    3,    0],
       [  92,    6, 5002,   76,  565,    1,  232,    1,   25,    0],
       [ 191,   49,   23, 5504,  162,    1,   61,    0,    7,    2],
       [  15,   12,  267,  213, 5305,    1,  168,    0,   19,    0],
       [   0,    0,    0,    0,    0, 5847,    0,  112,    3,   38],
       [1159,   16,  523,  189,  676,    0, 3396,    0,   41,    0],
       [   0,    0,    0,    0,    0,   99,    0, 5540,    0,  361],
       [  28,    6,   29,   15,   32,   23,   26,   14, 5827,    0],
       [   0,    0,    0,    0,    1,   61,    0,  107,    1, 5830]],
      dtype=int64)

PyTorch tensors are array-like Python objects, so we can pass them directly to the confusion_matrix() function. We pass the training set labels tensor (targets) and the argmax of the train_preds tensor along dim=1 (the class dimension), and this gives us the confusion matrix data structure.

To actually plot the confusion matrix, we need some custom code that I've put in a local file called plotcm. The function is called plot_confusion_matrix(). The plotcm.py file needs to contain the following contents and live inside the resources folder of the current directory.

Note that you can also just copy this code into your notebook or whatever to avoid the import.

plotcm.py:

import itertools
import numpy as np
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

Source: scikit-learn.org

For importing (with plotcm.py inside the resources folder), we do it like this:

from resources.plotcm import plot_confusion_matrix

We are ready to plot the confusion matrix, but first we need to create a list of prediction class names to pass to the plot_confusion_matrix() function. Our prediction classes and their corresponding indexes are given by the table below:

Index  Label
0      T-shirt/top
1      Trouser
2      Pullover
3      Dress
4      Coat
5      Sandal
6      Shirt
7      Sneaker
8      Bag
9      Ankle boot

This allows us to make the call to plot the matrix:

> names = (
    'T-shirt/top'
    ,'Trouser'
    ,'Pullover'
    ,'Dress'
    ,'Coat'
    ,'Sandal'
    ,'Shirt'
    ,'Sneaker'
    ,'Bag'
    ,'Ankle boot'
)
> plt.figure(figsize=(10,10))
> plot_confusion_matrix(cm, names)

Confusion matrix, without normalization
[[5431   14   88  145   26    7  241    0   48    0]
 [   4 5896    6   75    8    0    8    0    3    0]
 [  92    6 5002   76  565    1  232    1   25    0]
 [ 191   49   23 5504  162    1   61    0    7    2]
 [  15   12  267  213 5305    1  168    0   19    0]
 [   0    0    0    0    0 5847    0  112    3   38]
 [1159   16  523  189  676    0 3396    0   41    0]
 [   0    0    0    0    0   99    0 5540    0  361]
 [  28    6   29   15   32   23   26   14 5827    0]
 [   0    0    0    0    1   61    0  107    1 5830]]


Interpreting the Confusion Matrix

The confusion matrix has three axes:

  1. Prediction label (class)
  2. True label
  3. Heat map value (color)

The prediction and true label axes show us which prediction class we are dealing with. The matrix diagonal represents locations where the prediction and the truth are the same, so this is where we want the heat map to be darker.

Any values that are not on the diagonal are incorrect predictions because the prediction and the true label don't match. To read the plot, we can use these steps:

  1. Choose a prediction label on the horizontal axis.
  2. Check the diagonal location for this label to see the total number of correct predictions.
  3. Check the other non-diagonal locations to see where the network is confused.

For example, the network is confusing a T-shirt/top with a shirt, but is not confusing the T-shirt/top with things like:

  • Ankle boot
  • Sneaker
  • Sandal

If we think about it, this makes pretty good sense. As our model learns, we will see the numbers that lie outside the diagonal become smaller and smaller.
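Since every class in this dataset has the same number of samples, the raw counts are directly comparable. For imbalanced data, we can pass normalize=True to the plot_confusion_matrix() function above to show each row as per-class fractions instead:

> plt.figure(figsize=(10,10))
> plot_confusion_matrix(cm, names, normalize=True)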

Conclusion

At this point in the series, we have completed quite a lot of work on building and training a CNN in PyTorch. Congratulations for making it this far!

Description

In this episode, we learn how to build, plot, and interpret a confusion matrix using PyTorch. We also talk about locally disabling PyTorch gradient tracking and computational graph generation, since we are using our network to obtain predictions for every sample in our training set.

Related episodes:

  • FashionMNIST Explained - https://deeplizard.com/learn/video/EqpzfvxBx30
  • Training Loop Explained - https://deeplizard.com/learn/video/XfYmia3q2Ow