TensorFlow - Python Deep Learning Neural Network API

Deep Learning Course - Level: Beginner

Create a Confusion Matrix for Neural Network Predictions


In this episode, we'll demonstrate how to create a confusion matrix, which lets us visually assess how well a neural network predicts during inference.

We'll continue working with the predictions we obtained from the tf.keras.Sequential model in the last episode.

In the last episode, we showed how to use a trained model for inference on new data in a test set it hasn't seen before. As mentioned in that episode, we had the labels for the test set, but we didn't provide these labels to the network.

Additionally, we were able to see the values that the model was predicting for each of the samples in the test set by just observing the predictions themselves.

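Below are the probabilities that the model assigned to each class for patients in the test set. Given the label order used later in this episode, the first column is the probability of no side effects from the experimental drug, and the second is the probability of side effects.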

for i in predictions:
    print(i)
[ 0.74106830  0.25893170]
[ 0.14958295  0.85041702]
[ 0.96918124  0.03081879]
[ 0.12985019  0.87014979]
[ 0.88596725  0.11403273]
...

We then looked only at the most probable prediction for each sample.

for i in rounded_predictions:
    print(i)
0
1
0
1
0
...
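
The step from probabilities to a single predicted class can be sketched with NumPy's argmax. This is a minimal sketch that assumes rounded_predictions was obtained by taking the index of the largest probability in each row; the five rows are copied from the output above, and the variable names simply mirror the ones used in this episode.

```python
import numpy as np

# First five probability rows, copied from the output shown above.
predictions = np.array([
    [0.74106830, 0.25893170],
    [0.14958295, 0.85041702],
    [0.96918124, 0.03081879],
    [0.12985019, 0.87014979],
    [0.88596725, 0.11403273],
])

# Index of the largest probability in each row (assumed to be how
# rounded_predictions was produced in the last episode).
rounded_predictions = np.argmax(predictions, axis=-1)

for i in rounded_predictions:
    print(i)
```

For these five rows, the loop prints 0, 1, 0, 1, 0, matching the output above.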

Although we could easily read the model's predictions, we couldn't easily compare them to the true labels for the test data.

With a confusion matrix, we'll be able to visually observe how well the model predicts on test data.

Let's jump into the code for how this is done.

Plotting a confusion matrix

First, we import all the required libraries we'll be working with.

%matplotlib inline
from sklearn.metrics import confusion_matrix
import itertools
import numpy as np  # used by plot_confusion_matrix() below
import matplotlib.pyplot as plt

The confusion matrix we'll be plotting comes from scikit-learn.

We then create the confusion matrix and assign it to the variable cm.

cm = confusion_matrix(y_true=test_labels, y_pred=rounded_predictions)

To the confusion matrix, we pass in the true labels test_labels as well as the network's predicted labels rounded_predictions for the test set.
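
To make the contents of cm concrete, here is a minimal NumPy sketch of the tally that a confusion matrix represents: entry [i, j] counts the samples whose true label is i and whose predicted label is j. The helper name tally_confusion and the toy labels are hypothetical and are not part of this episode's code; in practice, scikit-learn's confusion_matrix does this work for us.

```python
import numpy as np

def tally_confusion(y_true, y_pred, num_classes):
    """Hypothetical helper: count (true label, predicted label) pairs."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1  # row = true label, column = predicted label
    return cm

# Toy labels, made up for illustration only.
toy_true = [0, 0, 1, 1, 1, 0]
toy_pred = [0, 1, 1, 1, 0, 0]

toy_cm = tally_confusion(toy_true, toy_pred, num_classes=2)
print(toy_cm)
```

In this toy case, the diagonal holds the four correct predictions and the off-diagonal cells hold the two mistakes.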

Below, we have a function called plot_confusion_matrix(), adapted from an example on scikit-learn's website, that plots the confusion matrix.

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    # Normalize before plotting so the image and the cell text agree.
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Two decimals for normalized rates, integers for raw counts.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

Next, we define the labels for the confusion matrix. In our case, the labels are titled “no side effects” and “had side effects.”

cm_plot_labels = ['no_side_effects','had_side_effects']

Lastly, we plot the confusion matrix by using the plot_confusion_matrix() function we just discussed. To this function, we pass in the confusion matrix cm and the labels cm_plot_labels, as well as a title for the confusion matrix.

plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title='Confusion Matrix')

Reading a confusion matrix

Looking at the plot of the confusion matrix, we have the predicted labels on the x-axis and the true labels on the y-axis. The blue cells running from the top left to the bottom right contain the number of samples that the model predicted correctly. The white cells contain the number of samples that were predicted incorrectly.

There are 420 total samples in the test set. Looking at the confusion matrix, we can see that the model correctly predicted 391 of the 420 samples and incorrectly predicted the remaining 29.

Among the samples the model predicted as having no side effects, it was correct 191 times. It incorrectly predicted no side effects 10 times when the patient actually did experience side effects.

On the other side, the model correctly predicted side effects 200 times. It incorrectly predicted side effects 19 times when the patient actually did not experience side effects.
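
The arithmetic behind these counts can be checked directly. The sketch below rebuilds the matrix by hand from the four numbers quoted above (rows are true labels, columns are predicted labels), so the array cm_from_text is an assumption reconstructed from the text rather than output copied from the notebook.

```python
import numpy as np

# Rows: true label, columns: predicted label.
# Class order: [no_side_effects, had_side_effects]
cm_from_text = np.array([
    [191, 19],   # true no side effects: 191 correct, 19 predicted as side effects
    [10, 200],   # true side effects: 10 predicted as no side effects, 200 correct
])

total = cm_from_text.sum()        # all test samples
correct = np.trace(cm_from_text)  # diagonal cells = correct predictions
incorrect = total - correct       # off-diagonal cells = mistakes
accuracy = correct / total

print(total, correct, incorrect)
print(f"accuracy: {accuracy:.3f}")
```

The totals come out to 420 samples, 391 correct, and 29 incorrect, which is an accuracy of about 0.931.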

As you can see, a confusion matrix is a good way to visually interpret how well the model is predicting and to understand where it may need some work.
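
One way to see where the model may need work is to normalize each row so that it sums to 1, which is the same calculation plot_confusion_matrix() applies when normalize=True. As a sketch, again using a matrix reconstructed by hand from the counts in this section (an assumption, not notebook output):

```python
import numpy as np

# Rows: true label, columns: predicted label, rebuilt from the text.
cm_from_text = np.array([[191, 19],
                         [10, 200]])

# Divide each row by its total so every row sums to 1.
row_totals = cm_from_text.sum(axis=1)[:, np.newaxis]
cm_normalized = cm_from_text.astype('float') / row_totals

print(np.round(cm_normalized, 3))
```

Read this way, each diagonal entry is the fraction of that true class the model got right: roughly 0.91 for no_side_effects and 0.95 for had_side_effects.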
