CNN Confusion Matrix with PyTorch - Neural Network Programming
Create a Confusion Matrix with PyTorch
Welcome to this neural network programming series. In this episode, we're going to build some functions that will allow us to get a prediction tensor for every sample in our training set.

Then, we'll see how we can take this prediction tensor, along with the labels for each sample, to create a confusion matrix. This confusion matrix will allow us to see which categories our network is confusing with one another. Without further ado, let's get started.
Where we are now in the course.
- Prepare the data
- Build the model
- Train the model
- Analyze the model's results
  - Building, plotting, and interpreting a confusion matrix
Be sure to see the previous episode in this course for all the code setup details.
Confusion Matrix Requirements
To create a confusion matrix for our entire dataset, we need a prediction tensor whose first dimension has the same length as our training set.
```
> len(train_set)
60000
```
This prediction tensor will contain ten predictions for each sample from our training set (one for each category of clothing). After we have obtained this tensor, we can use the labels tensor to generate a confusion matrix.
```
> len(train_set.targets)
60000
```
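To make the shape requirement concrete, here is a minimal sketch with random scores standing in for real network output. The sizes (6 samples, 10 classes) and the tensors are hypothetical; the point is that the first dimension of the predictions lines up with the labels, and `argmax(dim=1)` reduces each row of ten scores to one predicted label.

```python
import torch

# Hypothetical stand-ins: 6 samples, 10 classes (the real training set has 60,000 samples).
num_samples, num_classes = 6, 10
preds = torch.randn(num_samples, num_classes)   # one row of 10 scores per sample
targets = torch.tensor([9, 0, 0, 3, 0, 5])      # one true label per sample

# The first dimension of the predictions must match the number of labels.
assert preds.shape[0] == targets.shape[0]

# argmax over the class dimension turns each row of scores into a single predicted label.
pred_labels = preds.argmax(dim=1)
print(pred_labels.shape)  # torch.Size([6])
```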
A confusion matrix will show us where the model is getting confused. To be more specific, the confusion matrix will show us which categories the model is predicting correctly and which categories the model is predicting incorrectly. For the incorrect predictions, we will be able to see which category the model predicted, and this will show us which categories are confusing the model.
Get Predictions for the Entire Training Set
To get the predictions for all the training set samples, we need to pass all of the samples forward through the network. One way to do this is to create a DataLoader that has batch_size=len(train_set) (60,000 here). This will pass all of the samples to the network as a single batch and will give us the desired prediction tensor for the whole training set in one go.
However, depending on our computing resources, and on the size of the training set if we were training on a different dataset, a single batch may not fit in memory. So, we need a way to make predictions on smaller batches and collect the results. To collect the results, we'll use the torch.cat() function to concatenate the output tensors together into our single prediction tensor. Let's build a function to do this.
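As a quick sanity check of this idea, the sketch below uses a hypothetical toy model (a single `torch.nn.Linear` layer) and random data in place of the real network and training set. It predicts batch by batch, joins the pieces with `torch.cat()`, and confirms the result matches a single full-size forward pass.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data and model: 100 "images" flattened to 8 features, 10 classes.
inputs = torch.randn(100, 8)
labels = torch.randint(0, 10, (100,))
model = torch.nn.Linear(8, 10)

# batch_size=32 gives batches of 32, 32, 32 and a final batch of 4 samples.
loader = DataLoader(TensorDataset(inputs, labels), batch_size=32)

with torch.no_grad():
    # Predict batch by batch, then join the pieces along the batch dimension (dim=0).
    batched = torch.cat([model(x) for x, _ in loader], dim=0)
    # Predict everything in one pass for comparison.
    full = model(inputs)

print(batched.shape)                             # torch.Size([100, 10])
print(torch.allclose(batched, full, atol=1e-6))  # True
```

Because the loader does not shuffle, the concatenated rows stay in the same order as the dataset, which is what lets us line them up with the labels later.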
Building a Function to get Predictions for ALL Samples
We'll create a function called get_all_preds(), and we'll pass it a model and a data loader. The model will be used to obtain the predictions, and the data loader will be used to provide the batches from the training set.
All the function needs to do is iterate over the data loader, passing the batches to the model and concatenating the results of each batch to a prediction tensor that will be returned to the caller.
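```python
import torch

@torch.no_grad()  # no gradient tracking while collecting predictions
def get_all_preds(model, loader):
    all_preds = torch.tensor([])  # empty tensor that will grow batch by batch
    for batch in loader:
        images, labels = batch

        preds = model(images)
        # Append this batch's predictions along the batch dimension.
        all_preds = torch.cat(
            (all_preds, preds)
            ,dim=0
        )
    return all_preds
```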
The implementation of this function creates an empty tensor, all_preds, to hold the output predictions. Then, it iterates over the batches coming from the data loader and concatenates the output predictions with the all_preds tensor. Finally, all the predictions, all_preds, are returned to the caller.
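Calling `torch.cat()` on every iteration copies the growing tensor each time. A common alternative, sketched here as a hypothetical variant named `get_all_preds_list()` (not part of the original code), is to collect each batch's output in a Python list and concatenate once at the end. The toy data and model in the usage lines are also hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def get_all_preds_list(model, loader):
    """Hypothetical variant of get_all_preds(): gather per-batch outputs in a
    list and concatenate them once at the end instead of on every iteration."""
    batch_preds = []
    for images, _labels in loader:
        batch_preds.append(model(images))
    return torch.cat(batch_preds, dim=0)

# Usage with hypothetical toy data: 50 samples, 4 features, 10 classes.
torch.manual_seed(0)
data = TensorDataset(torch.randn(50, 4), torch.randint(0, 10, (50,)))
model = torch.nn.Linear(4, 10)
preds = get_all_preds_list(model, DataLoader(data, batch_size=16))
print(preds.shape)  # torch.Size([50, 10])
```

Both versions return the same tensor; the list-based one simply avoids repeated copying, which matters more as the dataset grows.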
Note that at the top, we have annotated the function using the @torch.no_grad() PyTorch decorator. This is because we want this function's execution to omit gradient tracking.

Gradient tracking uses memory, and during inference (getting predictions while not training) there is no need to keep track of the computational graph. The decorator is one way of locally turning off the gradient tracking feature while executing specific functions.
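Here is a small sketch of what the decorator changes, using a hypothetical one-layer model and a hypothetical helper `predict()`. Outputs computed inside the decorated function carry no autograd history, while a normal call does, and tracking is switched back on once the function returns.

```python
import torch

layer = torch.nn.Linear(3, 2)  # hypothetical tiny model; its weights require gradients
x = torch.randn(4, 3)

@torch.no_grad()
def predict(model, inputs):
    # Inside the decorated function, gradient tracking is switched off.
    print("grad enabled inside:", torch.is_grad_enabled())  # False
    return model(inputs)

tracked = layer(x)             # normal call: autograd records the graph
untracked = predict(layer, x)  # decorated call: no graph is recorded

print(tracked.requires_grad)    # True
print(untracked.requires_grad)  # False
print(torch.is_grad_enabled())  # True (tracking is restored after the call)
```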
Locally Disabling PyTorch Gradient Tracking
We are ready now to make the call to obtain the predictions for the training set. All we need to do is create a data loader with a reasonable batch size, and pass the model and data loader to the get_all_preds() function.
In a previous episode, we saw how we turned off PyTorch's gradient tracking feature when it was not needed, and we turned it back on when we started the training process.

We specifically need the gradient calculation feature any time we are going to calculate gradients using the backward() function. Otherwise, it is a good idea to turn it off, because having it off will reduce memory consumption for computations, e.g. when we are using networks for predicting (inference).
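To locally turn off gradient tracking, we can use a decorator as we did above, or we can use Python's with keyword together with the torch.no_grad() context manager to specify that a specific block of code should exclude gradient computations.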
```python
with torch.no_grad():
    prediction_loader = torch.utils.data.DataLoader(train_set, batch_size=10000)
    train_preds = get_all_preds(network, prediction_loader)
```
Both of these options are valid. Let's keep both of these and get our predictions.
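The sketch below, using a hypothetical parameter `w` and input `x`, shows the scope of the `with torch.no_grad():` block: inside it no graph is built, tracking is back on after the block, and calling `backward()` on a result computed inside the block raises a `RuntimeError` because nothing was recorded.

```python
import torch

w = torch.randn(3, requires_grad=True)  # hypothetical parameter
x = torch.randn(3)

with torch.no_grad():
    # Inside the block: no graph is built, so the result has no grad_fn.
    inside = (w * x).sum()
    grad_inside = torch.is_grad_enabled()

grad_after = torch.is_grad_enabled()  # tracking is back on after the block
outside = (w * x).sum()
outside.backward()                    # works: the graph was recorded

try:
    inside.backward()                 # fails: nothing was recorded for this result
    backward_failed = False
except RuntimeError:
    backward_failed = True

print(grad_inside, grad_after, backward_failed)  # False True True
```

This is why gradient tracking only needs to be on while training: any computation we intend to call backward() on must happen with tracking enabled.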
Using the Predictions Tensor
Now that we have the prediction tensor, we can pass it, along with the training set labels, to the get_num_correct() function that we created in a previous episode to get the total number of correct predictions.
```
> preds_correct = get_num_correct(train_preds, train_set.targets)
> print('total correct:', preds_correct)
> print('accuracy:', preds_correct / len(train_set))
total correct: 53578
accuracy: 0.8929666666666667
```
We can see the total number of correct predictions and print the accuracy by dividing by the number of samples in the training set.
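For reference, here is a minimal sketch of a get_num_correct() that fits how it is used here. The body is an assumption (it takes the argmax over dimension 1 and counts matches), and the scores and labels are hypothetical.

```python
import torch

def get_num_correct(preds, labels):
    # Assumed definition: count rows whose highest-scoring class equals the label.
    return preds.argmax(dim=1).eq(labels).sum().item()

# Hypothetical scores for 4 samples over 3 classes.
preds = torch.tensor([[0.1, 0.7, 0.2],
                      [0.8, 0.1, 0.1],
                      [0.2, 0.2, 0.6],
                      [0.3, 0.4, 0.3]])
labels = torch.tensor([1, 0, 1, 1])

correct = get_num_correct(preds, labels)  # predicted labels are [1, 0, 2, 1]
accuracy = correct / len(labels)
print(correct, accuracy)  # 3 0.75
```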
Building the Confusion Matrix
Our task in building the confusion matrix is to count the number of predicted values against the true values (targets).
This will create a matrix that acts as a heat map telling us where the predicted values fall relative to the true values.
To do this, we need to have the targets tensor and the predicted labels from the train_preds tensor, which we get by taking the argmax over dimension 1.
```
> train_set.targets
tensor([9, 0, 0,  ..., 3, 0, 5])

> train_preds.argmax(dim=1)
tensor([9, 0, 0,  ..., 3, 0, 5])
```
Now, if we compare the two tensors element-wise, we can see if the predicted label matches the target. Additionally, if we are counting the number of predicted labels vs the target labels, the values inside the two tensors act as coordinates for our matrix.
Let's stack these two tensors along the second dimension so we can have 60,000 ordered pairs.
```
> stacked = torch.stack(
    (
        train_set.targets
        ,train_preds.argmax(dim=1)
    )
    ,dim=1
)

> stacked.shape
torch.Size([60000, 2])

> stacked
tensor([
    [9, 9],
    [0, 0],
    [0, 0],
    ...,
    [3, 3],
    [0, 0],
    [5, 5]
])

> stacked[0].tolist()
[9, 9]
```
Now, we can iterate over these pairs and count the number of occurrences at each position in the matrix. Let's create the matrix. Since we have ten prediction categories, we'll have a ten by ten matrix. See the PyTorch documentation to learn more about the stack() function.
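To see the element-wise comparison and the stacked coordinate pairs on something small enough to read, here is a sketch with five hypothetical targets and predictions.

```python
import torch

# Hypothetical targets and predicted labels for 5 samples.
targets = torch.tensor([9, 0, 3, 3, 5])
pred_labels = torch.tensor([9, 0, 4, 3, 7])

# Element-wise comparison: True where the prediction matches the target.
matches = targets.eq(pred_labels)

# Stacking along dim=1 pairs each target with its prediction: shape [5, 2].
# Each pair is a (row, column) coordinate in the confusion matrix.
stacked = torch.stack((targets, pred_labels), dim=1)

print(matches.tolist())  # [True, True, False, True, False]
print(stacked.tolist())  # [[9, 9], [0, 0], [3, 4], [3, 3], [5, 7]]
```

Pairs such as [3, 4] and [5, 7] land off the diagonal, which is exactly where the confusion matrix records mistakes.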
```
> cmt = torch.zeros(10,10, dtype=torch.int64)
> cmt
tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
])
```
Now, we'll iterate over the prediction-target pairs and add one to the value inside the matrix each time the particular position occurs.
```python
for p in stacked:
    tl, pl = p.tolist()  # tl = true label (row), pl = predicted label (column)
    cmt[tl, pl] = cmt[tl, pl] + 1
```
This gives us the following confusion matrix tensor.
```
> cmt
tensor([
    [5637,    3,   96,   75,   20,   10,   86,    0,   73,    0],
    [  40, 5843,    3,   75,   16,    8,    5,    0,   10,    0],
    [  87,    4, 4500,   70, 1069,    8,  156,    0,  106,    0],
    [ 339,   61,   19, 5269,  203,   10,   72,    2,   25,    0],
    [  23,    9,  263,  209, 5217,    2,  238,    0,   39,    0],
    [   0,    0,    0,    1,    0, 5604,    0,  333,   13,   49],
    [1827,    7,  716,  104,  792,    3, 2370,    0,  181,    0],
    [   0,    0,    0,    0,    0,   22,    0, 5867,    4,  107],
    [  32,    1,   13,   15,   19,    5,   17,   11, 5887,    0],
    [   0,    0,    0,    0,    0,   28,    0,  234,    6, 5732]
])
```
Note that the example below will have different values because these two examples were created at different times.
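The Python loop above visits the 60,000 pairs one at a time. A common vectorized alternative, sketched here under the assumption that labels are integers in the range 0 to num_classes - 1, encodes each (target, prediction) pair as a single index and counts the indices with `torch.bincount()`. The helper name `confusion_matrix_bincount()` and the toy labels are hypothetical.

```python
import torch

def confusion_matrix_bincount(targets, pred_labels, num_classes):
    # Encode each (target, prediction) pair as one integer in [0, num_classes**2),
    # count occurrences of each code, then reshape the counts into a matrix.
    codes = targets * num_classes + pred_labels
    counts = torch.bincount(codes, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

# Hypothetical labels for 8 samples over 3 classes.
targets = torch.tensor([0, 0, 1, 1, 2, 2, 2, 0])
pred_labels = torch.tensor([0, 1, 1, 1, 2, 0, 2, 0])

# Loop version from the text, for comparison.
cmt = torch.zeros(3, 3, dtype=torch.int64)
for tl, pl in torch.stack((targets, pred_labels), dim=1).tolist():
    cmt[tl, pl] += 1

fast = confusion_matrix_bincount(targets, pred_labels, 3)
print(fast.tolist())           # [[2, 1, 0], [0, 2, 0], [1, 0, 2]]
print(torch.equal(fast, cmt))  # True
```

For the training set, the equivalent call would be `confusion_matrix_bincount(train_set.targets, train_preds.argmax(dim=1), 10)`.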
Plotting the Confusion Matrix
To generate the actual confusion matrix as a numpy.ndarray, we use the confusion_matrix() function from the sklearn.metrics module. Let's get this imported along with our other needed imports.
```python
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
from resources.plotcm import plot_confusion_matrix
```
For the last import, note that plotcm is a file, plotcm.py, that lives in a folder called resources in the current directory. Inside the plotcm.py file, there is a function called plot_confusion_matrix() that we will call. You'll need to implement this on your system. We'll look at how to do this in a minute. First, let's generate the confusion matrix.
We can generate the confusion matrix like so:
```
> cm = confusion_matrix(train_set.targets, train_preds.argmax(dim=1))
> print(type(cm))
> cm
<class 'numpy.ndarray'>
array([[5431,   14,   88,  145,   26,    7,  241,    0,   48,    0],
       [   4, 5896,    6,   75,    8,    0,    8,    0,    3,    0],
       [  92,    6, 5002,   76,  565,    1,  232,    1,   25,    0],
       [ 191,   49,   23, 5504,  162,    1,   61,    0,    7,    2],
       [  15,   12,  267,  213, 5305,    1,  168,    0,   19,    0],
       [   0,    0,    0,    0,    0, 5847,    0,  112,    3,   38],
       [1159,   16,  523,  189,  676,    0, 3396,    0,   41,    0],
       [   0,    0,    0,    0,    0,   99,    0, 5540,    0,  361],
       [  28,    6,   29,   15,   32,   23,   26,   14, 5827,    0],
       [   0,    0,    0,    0,    1,   61,    0,  107,    1, 5830]], dtype=int64)
```
PyTorch tensors are array-like Python objects, so we can pass them directly to the confusion_matrix() function. We pass the training set labels tensor (targets) and the argmax with respect to the second dimension (dim=1, the class dimension) of the train_preds tensor, and this gives us the confusion matrix data structure.
To actually plot the confusion matrix, we need some custom code that I've put in a local file called plotcm.py. The function is called plot_confusion_matrix(). The plotcm.py file needs to contain the following contents and live inside the resources folder of the current directory.

Note that you can also just copy this code into your notebook to avoid the import.
plotcm.py:
```python
import itertools
import numpy as np
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
```
Source - scikit-learn.org
For importing from the resources folder, we do it like this:

```python
from resources.plotcm import plot_confusion_matrix
```
We are ready to plot the confusion matrix, but first we need to create a list of prediction class names to pass to the plot_confusion_matrix() function. Our prediction classes and their corresponding indexes are given by the table below:
Index | Label |
---|---|
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |

This allows us to make the call to plot the matrix:
```
> plt.figure(figsize=(10,10))
> plot_confusion_matrix(cm, train_set.classes)
Confusion matrix, without normalization
[[5431   14   88  145   26    7  241    0   48    0]
 [   4 5896    6   75    8    0    8    0    3    0]
 [  92    6 5002   76  565    1  232    1   25    0]
 [ 191   49   23 5504  162    1   61    0    7    2]
 [  15   12  267  213 5305    1  168    0   19    0]
 [   0    0    0    0    0 5847    0  112    3   38]
 [1159   16  523  189  676    0 3396    0   41    0]
 [   0    0    0    0    0   99    0 5540    0  361]
 [  28    6   29   15   32   23   26   14 5827    0]
 [   0    0    0    0    1   61    0  107    1 5830]]
```
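plot_confusion_matrix() also accepts normalize=True, which divides each row by its sum so that every cell becomes the fraction of a true class that was predicted as each label. The sketch below applies that same row normalization, copied from plotcm.py, to a hypothetical 3-class matrix so the arithmetic is easy to follow.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows = true labels, columns = predicted labels).
cm = np.array([[50, 5, 0],
               [10, 30, 0],
               [0, 0, 20]])

# Same row normalization that plot_confusion_matrix() applies when normalize=True:
# each row is divided by the number of samples that truly belong to that class.
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

print(cm_norm.round(2))
print(cm_norm.sum(axis=1))  # each row now sums to (approximately) 1
```

For the real matrix, the corresponding call would be `plot_confusion_matrix(cm, train_set.classes, normalize=True)`. Normalizing is most useful when classes have different numbers of samples, since raw counts then make larger classes look darker.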

Interpreting the confusion matrix
The confusion matrix has three axes:
- Prediction label (class)
- True label
- Heat map value (color)
The prediction label and true labels show us which prediction class we are dealing with. The matrix diagonal represents locations in the matrix where the prediction and the truth are the same, so this is where we want the heat map to be darker.
Any values that are not on the diagonal are incorrect predictions because the prediction and the true label don't match. To read the plot, we can use these steps:
- Choose a prediction label on the horizontal axis.
- Check the diagonal location for this label to see the total number correct.
- Check the other non-diagonal locations to see where the network is confused.
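The same reading can be done numerically. In this sketch, on a hypothetical 3-class matrix with true labels as rows and predictions as columns (the layout used above), the diagonal gives the correct counts, row sums give how many samples truly belong to each class, and column sums give how often each label was predicted. Dividing the diagonal by the row sums gives the fraction of each true class the network gets right.

```python
import torch

# Hypothetical 3-class confusion matrix: rows = true labels, columns = predicted labels.
cm = torch.tensor([[50, 5, 0],
                   [10, 30, 0],
                   [0, 0, 20]])

correct_per_class = cm.diag()        # diagonal: predictions that match the truth
samples_per_class = cm.sum(dim=1)    # row sums: samples that truly have each label
predicted_per_class = cm.sum(dim=0)  # column sums: how often each label was predicted

# Fraction of each true class that was predicted correctly.
per_class_accuracy = correct_per_class.float() / samples_per_class.float()
overall_accuracy = correct_per_class.sum().item() / cm.sum().item()

print(correct_per_class.tolist())    # [50, 30, 20]
print(per_class_accuracy.tolist())
print(overall_accuracy)
```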
For example, the network is confusing a T-shirt/top with a shirt, but is not confusing the T-shirt/top with things like:
- Ankle boot
- Sneaker
- Sandal
If we think about it, this makes pretty good sense. As our model learns, we will see the numbers that lie outside the diagonal become smaller and smaller.
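To find the biggest sources of confusion without reading every cell, we can zero out the diagonal and take the largest remaining counts. This sketch uses the confusion matrix printed earlier in this episode (its diagonal sums to 53,578, matching the total-correct count from get_num_correct()) and the Fashion-MNIST class names from the table; picking the top three is an arbitrary choice.

```python
import torch

classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Confusion matrix from the sklearn output above (rows = true, columns = predicted).
cm = torch.tensor([
    [5431,   14,   88,  145,   26,    7,  241,    0,   48,    0],
    [   4, 5896,    6,   75,    8,    0,    8,    0,    3,    0],
    [  92,    6, 5002,   76,  565,    1,  232,    1,   25,    0],
    [ 191,   49,   23, 5504,  162,    1,   61,    0,    7,    2],
    [  15,   12,  267,  213, 5305,    1,  168,    0,   19,    0],
    [   0,    0,    0,    0,    0, 5847,    0,  112,    3,   38],
    [1159,   16,  523,  189,  676,    0, 3396,    0,   41,    0],
    [   0,    0,    0,    0,    0,   99,    0, 5540,    0,  361],
    [  28,    6,   29,   15,   32,   23,   26,   14, 5827,    0],
    [   0,    0,    0,    0,    1,   61,    0,  107,    1, 5830],
])

# Zero out the diagonal so only mistakes remain.
errors = cm.clone()
errors.fill_diagonal_(0)

# Take the three largest off-diagonal counts and map flat indices back to (row, column).
top = torch.topk(errors.flatten(), k=3)
pairs = []
for count, flat_idx in zip(top.values.tolist(), top.indices.tolist()):
    true_idx, pred_idx = divmod(flat_idx, errors.shape[1])
    pairs.append((classes[true_idx], classes[pred_idx], count))
    print(f"true {classes[true_idx]!r} predicted as {classes[pred_idx]!r}: {count}")
# true 'Shirt' predicted as 'T-shirt/top': 1159
# true 'Shirt' predicted as 'Coat': 676
# true 'Pullover' predicted as 'Coat': 565
```

The top result is the Shirt and T-shirt/top mix-up discussed above, just read from the Shirt row rather than the T-shirt/top row.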
Conclusion
At this point in the series, we have completed quite a lot of work on building and training a CNN in PyTorch. Congratulations for making it this far!