Train, Test, & Validation Sets explained
Datasets for deep learning
In this post, we'll discuss the different datasets used when training and testing a neural network.
To train and test our model, we should have our data broken down into three distinct datasets:
- Training set
- Validation set
- Test set
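One common way to produce these three sets from a single pool of labeled samples is to shuffle the pool once and then slice it. The sketch below assumes the samples fit in a Python list and uses a 70/15/15 split. The function name `train_val_test_split` and the default fractions are illustrative choices, not something this post prescribes.

```python
import random

def train_val_test_split(samples, val_frac=0.15, test_frac=0.15, seed=0):
    """Shuffle `samples` once and slice it into disjoint train/val/test lists.

    Hypothetical helper for illustration; the default fractions are arbitrary.
    """
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac >= 1:
        raise ValueError("need val_frac >= 0, test_frac >= 0, val_frac + test_frac < 1")
    shuffled = list(samples)               # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)  # fixed seed makes the split reproducible
    n = len(shuffled)
    n_test = int(n * test_frac)            # int() rounds down, so sizes never exceed n
    n_val = int(n * val_frac)
    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]      # everything left over goes to training
    return train, val, test

# Toy (input, label) pairs standing in for a real dataset.
data = [([x / 10], int(x > 0)) for x in range(-50, 50)]
train, val, test = train_val_test_split(data)
print(len(train), len(val), len(test))  # 70 15 15
```

Shuffling before slicing matters when the pool is ordered (for example, sorted by class); otherwise one of the sets could end up containing only one kind of sample.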
Let's start by discussing the training set.
Training set
The training set is what it sounds like: the set of data used to train the model. Over the course of training, the model is trained over and over again on this same data, once per epoch, and it continues to learn about the features of this data.
The hope is that we can later deploy our model and have it make accurate predictions on new data that it has never seen before. It will make these predictions based on what it has learned from the training data. Ok, now let's discuss the validation set.
Validation set
The validation set is a set of data, separate from the training set, that is used to validate our model during training. The results of this validation give us information that can help us adjust our hyperparameters.
Recall how we just mentioned that with each epoch during training, the model will be trained on the data in the training set. Well, in that same epoch, it will also be validated on the data in the validation set.
We know from our previous posts on training that, during the training process, the model classifies each input in the training set. After this classification occurs, the loss is calculated, and the weights in the model are adjusted. Then, during the next epoch, the model classifies the same inputs again.
Also during training, the model classifies each input from the validation set. It does this classification based only on what it has learned from the data in the training set. The weights in the model are not updated based on the loss calculated from our validation data.
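The loop described above can be sketched in plain Python. This sketch assumes a tiny logistic-regression model (a single sigmoid unit) trained with per-sample gradient descent on binary cross-entropy, standing in for a full neural network. The names `train`, `epochs`, and `lr` are illustrative.

```python
import math
import random

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def train(train_set, epochs=20, lr=0.1, seed=0):
    """Train a single sigmoid unit with per-sample SGD on binary cross-entropy.

    Returns (weights, bias, history), where history[i] is the mean training
    loss observed during epoch i.
    """
    weights = [0.0] * len(train_set[0][0])
    bias = 0.0
    rng = random.Random(seed)
    data = list(train_set)
    history = []
    for _ in range(epochs):
        rng.shuffle(data)  # same training samples every epoch, new order
        total_loss = 0.0
        for x, y in data:
            # 1. Classify the input with the current weights.
            p = sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)
            # 2. Compute the loss (clipped so log() never sees 0).
            pc = min(max(p, 1e-12), 1 - 1e-12)
            total_loss += -(y * math.log(pc) + (1 - y) * math.log(1 - pc))
            # 3. Adjust the weights: for sigmoid + cross-entropy, dLoss/dz = p - y.
            grad = p - y
            weights = [w - lr * grad * xi for w, xi in zip(weights, x)]
            bias -= lr * grad
        history.append(total_loss / len(data))
    return weights, bias, history

# Toy 1-D data: label 1 for positive inputs, 0 for negative ones.
train_set = [([x / 10], int(x > 0)) for x in range(-20, 21) if x != 0]
weights, bias, history = train(train_set)
print(f"first-epoch loss {history[0]:.3f}, last-epoch loss {history[-1]:.3f}")
```

Every epoch revisits exactly the same training samples; only their order changes, which is why the training loss can keep falling epoch after epoch.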
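Validation can be sketched as a separate function that only reads the model's parameters. This assumes the same kind of single-sigmoid stand-in model, with its weights passed in as a plain list; `evaluate` is a hypothetical name. In a training loop, it would be called on the validation set once per epoch, and its loss would only be recorded, never used to update the weights.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def evaluate(weights, bias, dataset):
    """Return (mean cross-entropy loss, accuracy) on labeled data.

    Read-only: no gradient is computed and the parameters are not modified.
    """
    total_loss = 0.0
    correct = 0
    for x, y in dataset:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)
        pc = min(max(p, 1e-12), 1 - 1e-12)  # clip so log() never sees 0
        total_loss += -(y * math.log(pc) + (1 - y) * math.log(1 - pc))
        correct += int((p >= 0.5) == (y == 1))
    return total_loss / len(dataset), correct / len(dataset)

weights, bias = [2.0], 0.0  # pretend these came out of training
val_set = [([0.5], 1), ([-0.5], 0), ([1.0], 1), ([-0.2], 1)]
val_loss, val_acc = evaluate(weights, bias, val_set)
print(f"val loss {val_loss:.3f}, val accuracy {val_acc:.2f}")  # accuracy is 0.75 here
```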
Remember, the data in the validation set is separate from the data in the training set. So when the model is validating on this data, this data does not consist of samples that the model is already familiar with from training.
One of the major reasons we need a validation set is to ensure that our model is not overfitting to the data in the training set. We'll discuss overfitting and underfitting in detail at a later time. The idea of overfitting is that our model becomes really good at classifying data in the training set, but it's unable to generalize and make accurate classifications on data that it wasn't trained on.
During training, if we're also validating the model on the validation set and see that its results on the validation data are about as good as its results on the training data, then we can be more confident that our model is not overfitting.
On the other hand, if the results on the training data are really good, but the results on the validation data are lagging behind, then our model is likely overfitting. Now let's move on to the test set.
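The comparison described above can be sketched as a small check over per-epoch accuracies. The accuracy values below are made up for illustration, and the 0.05 gap threshold is an arbitrary assumption; what counts as "lagging behind" depends on the problem. `overfitting_epochs` is a hypothetical helper name.

```python
def overfitting_epochs(train_acc, val_acc, max_gap=0.05):
    """Return the (1-based) epochs where training accuracy beats validation
    accuracy by more than `max_gap`, a sign the model may be overfitting."""
    if len(train_acc) != len(val_acc):
        raise ValueError("need one training and one validation value per epoch")
    return [
        epoch
        for epoch, (t, v) in enumerate(zip(train_acc, val_acc), start=1)
        if t - v > max_gap
    ]

# Hypothetical per-epoch accuracies: training keeps improving, validation stalls.
train_acc = [0.70, 0.80, 0.90, 0.97, 0.99]
val_acc = [0.69, 0.78, 0.86, 0.85, 0.83]
print(overfitting_epochs(train_acc, val_acc))  # [4, 5]
```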
Test set
The test set is a set of data that is used to test the model after the model has already been trained. The test set is separate from both the training set and validation set.
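Because all three sets must be separate, it can be worth checking that no sample ended up in more than one of them. This sketch assumes each sample has a hashable ID (plain integers here); `check_disjoint` is a hypothetical helper. It only catches identical IDs, not near-duplicate samples stored under different IDs.

```python
from itertools import combinations

def check_disjoint(splits):
    """Raise ValueError if any sample ID appears in more than one split.

    `splits` maps a split name (e.g. "train") to an iterable of hashable IDs.
    """
    for (name_a, ids_a), (name_b, ids_b) in combinations(splits.items(), 2):
        overlap = set(ids_a) & set(ids_b)
        if overlap:
            raise ValueError(
                f"{name_a} and {name_b} share {len(overlap)} sample(s): {sorted(overlap)[:5]}"
            )

# Passes: the three ID ranges do not overlap.
check_disjoint({"train": range(0, 70), "val": range(70, 85), "test": range(85, 100)})

try:
    check_disjoint({"train": [1, 2, 3], "val": [3, 4], "test": [5]})
except ValueError as err:
    print(err)  # train and val share 1 sample(s): [3]
```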
After our model has been trained and validated using our training and validation sets, we will then use our model to predict the output of the unlabeled data in the test set.
One major difference between the test set and the two other sets is that the model is never given the labels of the test set. The training set and validation set have to be labeled so that we can see the metrics given during training, like the loss and the accuracy from each epoch. If we do keep labels for the test set, they are set aside and used only to score the model's final predictions.
When the model is predicting on unlabeled data in our test set, this would be the same type of process that would be used if we were to deploy our model out into the field.
For example, if we're using a model to classify data without knowing the labels of the data beforehand, or without ever having shown the model the exact data it's going to classify, then of course we wouldn't be giving our model labeled data to do this.
The entire goal of a classification model is to classify data without knowing what that data is beforehand.
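Prediction on the test set, as in deployment, can be sketched as a function that receives only inputs and never labels. This assumes a single-sigmoid stand-in model with hand-picked weights; `predict_labels` and the 0.5 threshold are illustrative choices.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def predict_labels(weights, bias, inputs, threshold=0.5):
    """Classify inputs for which no labels are available.

    Only the inputs are passed in, so the model has no way to see a label.
    """
    predictions = []
    for x in inputs:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)
        predictions.append(int(p >= threshold))
    return predictions

weights, bias = [2.0], -1.0  # pretend these came out of training
test_inputs = [[0.0], [1.0], [0.4], [0.6]]  # inputs only, no labels
print(predict_labels(weights, bias, test_inputs))  # [0, 1, 0, 1]
```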
Deep learning datasets in summary
The table below summarizes deep learning datasets:
Dataset | Updates Weights | Description |
---|---|---|
Training set | Yes | Used to train the model. The goal of training is to fit the model to the training set while still generalizing to unseen data. |
Validation set | No | Used during training to check how well the model is generalizing. |
Test set | No | Used to test the model's final ability to generalize before deploying to production. |
Now hopefully we have an idea about how our data should be organized in terms of datasets and how each of these datasets is used.
The main reason for having three separate datasets is to ensure that the model is able to generalize by predicting accurately on unseen data. When the model fails to generalize, the cause is usually overfitting or underfitting. We'll look at these in the next post. I'll see ya there!