Overfitting in a neural network
In this post, we’ll discuss what it means when a model is said to be overfitting. We’ll also cover some techniques we can use to try to reduce overfitting when it happens.
We briefly mentioned the concept of overfitting in a previous post where we discussed the purpose of a validation set. Let’s build more on this concept now.
Overfitting occurs when our model becomes very good at classifying or predicting on data that was included in the training set, but is not as good at classifying data that it wasn’t trained on. Essentially, the model has fit the training data too closely.
How to spot overfitting
We can tell if the model is overfitting based on the metrics that are given for our training data and validation data during training. We previously saw that when we specify a validation set during training, we get metrics for the validation accuracy and loss, as well as the training accuracy and loss.
If the validation metrics are considerably worse than the training metrics, then that is an indication that our model is overfitting.
We can also suspect overfitting if the model’s metrics were good during training, but the model doesn’t accurately classify the data in the test set when we use it to predict on test data.
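To make the comparison between training and validation metrics a bit more concrete, here is a minimal sketch in plain Python. The function names `overfitting_gap` and `looks_overfit`, the 0.10 threshold, and the per-epoch numbers are all assumptions made up for illustration. They are not output from a real training run, and the threshold is not a standard rule.

```python
def overfitting_gap(train_acc, val_acc):
    """Return final training accuracy minus final validation accuracy."""
    if not train_acc or len(train_acc) != len(val_acc):
        raise ValueError("expected two non-empty lists of equal length")
    return train_acc[-1] - val_acc[-1]


def looks_overfit(train_acc, val_acc, threshold=0.10):
    """Flag a run whose validation accuracy trails training accuracy.

    The threshold is an arbitrary illustrative choice, not a standard value.
    """
    return overfitting_gap(train_acc, val_acc) > threshold


# Made-up per-epoch metrics, for illustration only.
train_acc = [0.62, 0.78, 0.90, 0.97]
val_acc = [0.60, 0.70, 0.72, 0.71]

print(round(overfitting_gap(train_acc, val_acc), 2))
print(looks_overfit(train_acc, val_acc))
```

With these made-up numbers, training accuracy keeps climbing while validation accuracy stalls, so the gap is large and the run gets flagged. In practice, what counts as "considerably worse" depends on the problem, so we should treat any fixed threshold as a rough guide.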
The concept of overfitting boils down to the fact that the model is unable to generalize well. It has learned the features of the training set extremely well, but if we give the model any data that slightly deviates from the exact data used during training, it’s unable to generalize and accurately predict the output.
Overfitting is an incredibly common issue. How can we reduce it? Let's look at some techniques.
Adding more data to the training set
The easiest thing we can do, as long as we have access to it, is to add more data. The more data we can train our model on, the more it will be able to learn from the training set. With more data, we’re also hoping to add more diversity to the training set.
For example, if we train a model to classify whether an image is an image of a dog or cat, and the model has only seen images of larger dogs, like Labs, Golden Retrievers, and Boxers, then in practice if it sees a Pomeranian, it may not do so well at recognizing that a Pomeranian is a dog.
If we add more data to this model to encompass more breeds, then our training data will become more diverse, and the model will be less likely to overfit.
Data augmentation
Another technique we can use to reduce overfitting is data augmentation. This is the process of creating additional, augmented data by reasonably modifying the data in our training set. For image data, for example, we can make these modifications by cropping, rotating, flipping, or zooming the images.
Data augmentation lets us add data to our training set that is similar to the data we already have, but reasonably modified so that it’s not exactly the same.
For example, if most of our dog images were dogs facing to the left, then it would be a reasonable modification to add augmented flipped images so that our training set would also have dogs that faced to the right.
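The flipping idea can be sketched in a few lines of plain Python. This is a toy example under some assumptions: an image is represented as a list of pixel rows, and the helper names `flip_horizontal` and `augment_with_flips` are hypothetical. A real pipeline would typically rely on an image or deep learning library instead.

```python
def flip_horizontal(image):
    """Return a mirrored copy of an image stored as a list of pixel rows."""
    return [list(reversed(row)) for row in image]


def augment_with_flips(images):
    """Return the original images followed by a flipped copy of each one."""
    return images + [flip_horizontal(img) for img in images]


# A tiny 2x3 "image": larger values sit on the left, like a dog facing left.
dog_facing_left = [
    [9, 5, 0],
    [8, 4, 0],
]

augmented = augment_with_flips([dog_facing_left])
print(len(augmented))   # originals plus one flipped copy each
print(augmented[1])     # the mirrored image
```

The original image is left untouched, and the training set grows by one mirrored copy per image. Whether a flip is a reasonable modification depends on the data. It makes sense for dogs, but it would not for, say, images of handwritten text.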
Reduce the complexity of the model
Something else we can do to reduce overfitting is to reduce the complexity of our model. We could reduce complexity by making simple changes, like removing some layers from the model, or reducing the number of neurons in the layers. This may help our model generalize better to data it hasn’t seen before.
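One rough way to see the effect of removing layers or neurons is to count the learnable parameters in a fully connected network. The sketch below does this in plain Python. The function name `count_dense_params` and the two sets of layer sizes are hypothetical, and parameter count is only one proxy for model complexity.

```python
def count_dense_params(layer_sizes):
    """Count weights and biases in a fully connected network.

    layer_sizes lists the number of units per layer, input layer first.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out  # weights plus one bias per output unit
    return total


# Hypothetical architectures: 100 input features, 2 output classes.
bigger = [100, 64, 64, 32, 2]
smaller = [100, 16, 2]

print(count_dense_params(bigger))
print(count_dense_params(smaller))
```

Here the smaller network has far fewer parameters available for memorizing the training set. If we shrink the model too much, though, it may no longer be able to learn the training data well at all.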
Dropout
The last tip we’ll cover for reducing overfitting is to use something called dropout. The general idea behind dropout is that, if we add it to a model, it will randomly ignore some subset of nodes in a given layer during training, i.e., it drops out the nodes from the layer. Hence the name dropout. This prevents the dropped-out nodes from participating in producing a prediction for that training step.
This technique may also help our model to generalize better to data it hasn’t seen before. We’ll cover the full concept of dropout as a regularization technique in another post, and there we’ll understand why this makes sense.
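In the meantime, here is a minimal sketch of the mechanics in plain Python. It assumes the common "inverted dropout" variant, where the surviving activations are scaled up by 1 / (1 - rate) so their expected sum stays the same. The function name `dropout` and the sample values are hypothetical, and this is an illustration of the idea, not how any particular library implements it.

```python
import random


def dropout(activations, rate, rng=random):
    """Apply inverted dropout to a list of activations for one training step.

    Each activation is zeroed with probability `rate`. Survivors are scaled
    by 1 / (1 - rate). This scaling is an assumption of this sketch.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else a * keep_scale for a in activations]


# Seeded generator so the example is repeatable.
rng = random.Random(0)
layer_output = [0.5, 1.0, 1.5, 2.0]
print(dropout(layer_output, rate=0.5, rng=rng))
```

A different random subset of nodes is dropped on each call, which is what happens from one training step to the next. Dropout is applied only during training. When we use the model to predict, all nodes participate.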
Underfitting is next
Hopefully now we understand the concept of overfitting, why it occurs, and how we can reduce it if we see it happening in our models. In the next post, we’ll explore the concept of underfitting. I’ll see ya there!