PyTorch DataLoader Source Code - Debugging Session
Welcome to DeepLizard. My name's Chris. In this episode, we're going to pick up where we left off last time looking at data normalization. Only this time, instead of writing code, we're going to be debugging the code, and specifically, we're going to be debugging down into the PyTorch source code to see exactly what's going on when we normalize a dataset.

Without further ado, let's get started.
Short Program to Debug PyTorch Source
Before we start debugging, we just want to give you a quick overview of the program we've written. It will allow us to step in and see the normalization of the dataset, and see exactly how it's done under the hood in PyTorch.
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader
As we discussed in the last episode, we have the mean and standard deviation values. Now, instead of having to calculate these, we just pulled these out and hard coded them into the program here.
mean = 0.2860347330570221
std = 0.3530242443084717
This is the kind of thing we would do if we were to snag these values from somewhere offline. We don't want to go through the trouble of recalculating them, so we are hard coding them here. We have the mean and standard deviation, and we know we need both of these values to be able to normalize every member, or every pixel, of our dataset.
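For reference, here is one way these values can be computed, similar to what we did in the previous episode. This is just a sketch, and it assumes the easy case where the entire training set fits in memory as a single batch:
plain_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([transforms.ToTensor()])
)
# Load every sample in one batch and compute the statistics directly.
plain_loader = DataLoader(plain_set, batch_size=len(plain_set))
images, _ = next(iter(plain_loader))
print(images.mean().item(), images.std().item())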
Next, we initialize our train set using the FashionMNIST class constructor. The key point to take note of here is the transform parameter. We have a composition of transforms.
train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
        ,transforms.Normalize(mean, std)
    ])
)
The first of the composed transforms converts the PIL image into a tensor, and then the second is the normalize transform, which is going to normalize our data. Our goal is to verify in the source code how this particular transform works.
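Note that the composed transform runs every time a sample is accessed, so indexing the dataset directly already gives us a normalized tensor:
# Indexing the dataset triggers the composed transform, so we get
# back a normalized tensor rather than a PIL image.
sample, target = train_set[0]
print(type(sample))  # <class 'torch.Tensor'>
print(sample.shape)  # torch.Size([1, 28, 28])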
Lastly, we create a DataLoader and use it.
loader = DataLoader(train_set, batch_size=1)
image, label = next(iter(loader))
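Since batch_size=1, each iteration of the loader yields a batch containing a single normalized image. A quick sanity check of the shapes shows the extra batch dimension:
print(image.shape)  # torch.Size([1, 1, 28, 28])
print(label.shape)  # torch.Size([1])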
Debugging the PyTorch Source Code
All right, so now we're ready to actually debug. To debug, we're going to make sure that we have our Python run configuration selected, and then we're going to click start debugging.
Use this link to access the current source code for the PyTorch DataLoader class. This discussion assumes PyTorch version 1.5.0.
The Sampler: To Shuffle or Not
A sampler is the object that generates the index values used to retrieve the actual samples from the underlying dataset. We can see there are two particular samplers that are relevant: the random sampler and the sequential sampler.
- Random Sampler
- Sequential Sampler
If the shuffle value is true, the sampler will be a random one; otherwise, it will be a sequential one.
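Inside the DataLoader constructor, this choice looks roughly like the following, paraphrased from the 1.5.0 source for map-style datasets:
# Paraphrased from the DataLoader constructor (PyTorch 1.5.0):
# the default sampler depends on the shuffle flag.
if sampler is None:
    if shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)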
How the Batch Size is Used
We found that the sampler is used to gather the index values in this code:
def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
Here, we can see the batch_size parameter in play as it limits the number of index values collected. Note that the yield keyword here makes this iterator into what is called a generator.
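We can watch this batching behavior from the outside by driving a BatchSampler by hand. A small standalone check, with arbitrary sizes:
from torch.utils.data import BatchSampler, SequentialSampler

# With drop_last=False, the final, smaller batch of indices is still yielded.
batches = BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
print(list(batches))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]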
After the index values are obtained, they are used to fetch the data in the following way:
def fetch(self, possibly_batched_index):
    if self.auto_collation:
        data = [self.dataset[idx] for idx in possibly_batched_index]
    else:
        data = self.dataset[possibly_batched_index]
    return self.collate_fn(data)
This line does the work of pulling each sample from the underlying dataset.
data = [self.dataset[idx] for idx in possibly_batched_index]
This syntax, or notation, is referred to as a list comprehension. It returns a list of data elements that are then extracted and put into a single batch tensor using the collate_fn() method.
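To see what the default collate_fn does with such a list, we can call it directly. Note that in version 1.5.0, default_collate lives in a private module, so this is only for poking around:
import torch
from torch.utils.data._utils.collate import default_collate

# default_collate stacks a list of (image, label) samples into
# a single image tensor and a single label tensor.
samples = [(torch.rand(1, 28, 28), 0), (torch.rand(1, 28, 28), 1)]
images, labels = default_collate(samples)
print(images.shape)  # torch.Size([2, 1, 28, 28])
print(labels)        # tensor([0, 1])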
Normalizing the Dataset
Finally, we found that each element added to the batch is normalized using the normalize() function of the functional API.
def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation."""
    if not inplace:
        # Work on a copy so the caller's tensor is left untouched.
        tensor = tensor.clone()
    tensor.sub_(mean).div_(std)
    return tensor
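We can confirm the arithmetic with a quick check: a constant image equal to the mean should normalize to (mean - mean) / std = 0 everywhere. We pass the statistics as one-element lists, since the function expects per-channel values:
import torchvision.transforms.functional as F

# A constant image at the mean value normalizes to all zeros.
img = torch.full((1, 28, 28), mean)
out = F.normalize(img, [mean], [std])
print(out.abs().max())  # tensor(0.)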
Note that the dataset class calls a transform that then calls the functional API. We also ran into some bad design, which required some hacking to keep things consistent. By using the term hacking here, we are referring to the fact that we saw the code doing unnecessary transformations.
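Concretely, the unnecessary transformation happens in the dataset's __getitem__ method: the image is already stored as a tensor, but it is converted to a PIL image just so ToTensor() can convert it right back. Paraphrased from the torchvision MNIST source of that era:
# Paraphrased from torchvision's MNIST __getitem__ (circa torchvision 0.6.0).
# The source comments that the round-trip is done so the dataset is
# consistent with all other datasets, which return PIL images.
def __getitem__(self, index):
    img, target = self.data[index], int(self.targets[index])
    img = Image.fromarray(img.numpy(), mode='L')  # tensor -> PIL image
    if self.transform is not None:
        img = self.transform(img)                 # ToTensor() -> tensor again
    return img, target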