Extract, Transform, and Load (ETL) with PyTorch
Welcome back to this series on neural network programming with PyTorch. In this post, we will write the first code for part two of the series.
We’ll demonstrate a very simple extract, transform and load pipeline using torchvision, PyTorch’s computer vision package for machine learning. Without further ado, let’s get started.
The project (Bird's-eye view)
There are four general steps that we’ll be following as we move through this project:
- Prepare the data
- Build the model
- Train the model
- Analyze the model’s results
The ETL process
In this post, we’ll kick things off by preparing the data. To prepare our data, we'll be following what is loosely known as an ETL process.
- Extract data from a data source.
- Transform data into a desirable format.
- Load data into a suitable structure.
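To make the three steps concrete before bringing PyTorch in, here is a minimal plain-Python sketch of an ETL pipeline. The CSV text, its column names, and the extract/transform/load helper functions are all hypothetical, made up for illustration. The "source" is a string rather than a real file or web service.

```python
import csv
import io

# Hypothetical raw source: CSV text standing in for a file, database, or web API.
# Each row is a label followed by three made-up pixel values (0-255).
RAW_CSV = """label,pixel_0,pixel_1,pixel_2
3,0,128,255
7,255,0,64
"""


def extract(raw_text):
    """Extract: read raw records from the source (every field is still a string)."""
    return list(csv.DictReader(io.StringIO(raw_text)))


def transform(records):
    """Transform: convert strings to numbers and scale pixels into [0, 1]."""
    samples = []
    for row in records:
        label = int(row["label"])
        pixels = [int(value) / 255 for key, value in row.items() if key.startswith("pixel_")]
        samples.append((pixels, label))
    return samples


def load(samples):
    """Load: store samples in a structure that supports len() and indexing."""
    return list(samples)


dataset = load(transform(extract(RAW_CSV)))
print(len(dataset))    # 2
print(dataset[1][1])   # 7
```

Each stage only depends on the output of the previous one, which is what lets the same pattern scale from a single script up to separate systems.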
The ETL process can be thought of as a fractal process because it can be applied at various scales. The process can be applied at a small scale, like a single program, or at a large scale, all the way up to the enterprise level where there are huge systems handling each of the individual parts.
If you want to know more about the general data science pipeline, check out the data science post, where we cover this in greater detail.
Once we have completed the ETL process, we are ready to begin building and training our deep learning model. PyTorch has some built-in packages and classes that make the ETL process pretty easy.
We begin by importing all of the necessary PyTorch libraries.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
This table describes the purpose of each of these packages:
| Package | Description |
|---|---|
| torch | The top-level PyTorch package and tensor library. |
| torch.nn | A subpackage that contains modules and extensible classes for building neural networks. |
| torch.optim | A subpackage that contains standard optimization operations like SGD and Adam. |
| torch.nn.functional | A functional interface that contains typical operations used for building neural networks like loss functions and convolutions. |
| torchvision | A package that provides access to popular datasets, model architectures, and image transformations for computer vision. |
| torchvision.transforms | An interface that contains common transforms for image processing. |
The next imports are standard packages used for data science in Python:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
#from plotcm import plot_confusion_matrix

import pdb

torch.set_printoptions(linewidth=120)
pdb is the Python debugger, the commented-out import is a local file that we'll introduce in future posts for plotting the confusion matrix, and the last line sets the print options for PyTorch print statements so that printed tensors wrap at 120 characters per line.
We are now ready to prepare our data.
Preparing our data using PyTorch
Our ultimate goal when preparing our data is to do the following (ETL):
- Extract – Get the Fashion-MNIST image data from the source.
- Transform – Put our data into tensor form.
- Load – Put our data into an object to make it easily accessible.
For these purposes, PyTorch provides us with two classes:
| Class | Description |
|---|---|
| torch.utils.data.Dataset | An abstract class for representing a dataset. |
| torch.utils.data.DataLoader | Wraps a dataset and provides access to the underlying data. |
An abstract class is a Python class that has methods we must implement, so we can create a custom dataset by creating a subclass that extends the functionality of the Dataset class.
Specifically, there are two methods that are required to be implemented: the __len__ method, which returns the length (size) of the dataset, and the __getitem__ method, which gets an element from the dataset at a specific index location, supporting integer indexing in the range from 0 up to (but not including) the length of the dataset.
To create a custom dataset using PyTorch, we extend the Dataset class by creating a subclass that implements these required methods. Upon doing this, our new subclass can then be passed to a PyTorch DataLoader object.
We will be using the Fashion-MNIST dataset that comes built-in with the torchvision package, so we won't have to do this for our project. Just know that the built-in Fashion-MNIST dataset class is doing this behind the scenes.
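To make the two required methods concrete, here is a minimal sketch of a custom Dataset subclass. The SquaresDataset class and its toy data (input [i], label i squared) are hypothetical, invented only for illustration; Dataset and DataLoader are the real PyTorch classes from torch.utils.data.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the input tensor [i] with label i**2."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Number of samples in the dataset
        return self.size

    def __getitem__(self, index):
        # Integer indexing from 0 up to (but not including) len(self)
        if not 0 <= index < self.size:
            raise IndexError(f"index {index} is out of range for size {self.size}")
        x = torch.tensor([float(index)])
        y = torch.tensor(index ** 2)
        return x, y


squares = SquaresDataset(5)
print(len(squares))  # 5
print(squares[3])    # (tensor([3.]), tensor(9))

# Because it implements __len__ and __getitem__, it can be handed to a DataLoader
loader = DataLoader(squares, batch_size=2)
xb, yb = next(iter(loader))
print(xb.shape, yb)  # torch.Size([2, 1]) tensor([0, 1])
```

The DataLoader never needs to know how the samples are produced; it only calls len() and indexing, then stacks the returned tensors into batches.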
PyTorch torchvision package
The torchvision package gives us access to the following resources:
- Datasets (like MNIST and Fashion-MNIST)
- Models (like VGG16)
All of these resources are related to deep learning computer vision tasks.
When we learned about the Fashion-MNIST dataset in our previous post, the arXiv paper that introduced the fashion dataset indicated that the authors wanted it to be a drop-in replacement for the original MNIST dataset.
The idea was to make it so that frameworks like PyTorch could add Fashion-MNIST by just changing the URL for retrieving the data.
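This is the case for PyTorch. The PyTorch FashionMNIST dataset class simply extends the MNIST dataset class and overrides its urls attribute.
Here is the class definition from PyTorch's torchvision source code (this excerpt comes from an older torchvision release; newer releases list the download locations differently, but FashionMNIST still subclasses MNIST):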
from torchvision.datasets import MNIST  # added so the excerpt is self-contained; in torchvision, FashionMNIST is defined next to MNIST

class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]
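The "subclass and override only the data location" idea can be sketched in plain Python. The classes below (DigitImages, ClothingImages) and their sources() method are hypothetical stand-ins, not torchvision's MNIST or FashionMNIST, and the example.com URLs are placeholders; nothing is downloaded.

```python
from urllib.parse import urlparse


class DigitImages:
    """Hypothetical base dataset: all of the shared logic lives here."""

    # Class attribute: where the raw files would come from (placeholder URLs)
    urls = [
        "https://digits.example.com/train-images.gz",
        "https://digits.example.com/train-labels.gz",
    ]

    def sources(self):
        # Reads whatever `urls` the concrete class defines: (host, filename) pairs
        return [(urlparse(url).netloc, url.rsplit("/", 1)[-1]) for url in self.urls]


class ClothingImages(DigitImages):
    """Hypothetical drop-in replacement: only the data location changes."""

    urls = [
        "https://clothing.example.com/train-images.gz",
        "https://clothing.example.com/train-labels.gz",
    ]


print(DigitImages().sources()[0])     # ('digits.example.com', 'train-images.gz')
print(ClothingImages().sources()[0])  # ('clothing.example.com', 'train-images.gz')
```

Because the base class reads self.urls at call time, overriding the class attribute is enough to redirect all of the inherited behavior, which is why keeping the file names and format identical makes the new dataset a drop-in replacement.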
Let's see now how we can take advantage of torchvision.
PyTorch Dataset class
To get an instance of the FashionMNIST dataset using torchvision, we just create one like so:
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)
We specify the following arguments:
| Parameter | Description |
|---|---|
| root | The location on disk where the data is located. |
| train | Whether the dataset is the training set. |
| download | Whether the data should be downloaded. |
| transform | A composition of transformations that should be performed on the dataset elements. |
Since we want our images to be transformed into tensors, we use the built-in transforms.ToTensor() transformation, and since this dataset is going to be used for training, we'll name the instance train_set.
When we run this code for the first time, the Fashion-MNIST dataset will be downloaded locally. Subsequent calls check for the data before downloading it. Thus, we don't have to worry about double downloads or repeated network calls.
PyTorch DataLoader class
To create a DataLoader wrapper for our training set, we do it like this:
train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
)
We just pass train_set as an argument. Now, we can leverage the loader for tasks that would otherwise be pretty complicated to implement by hand, controlled by parameters such as:
- batch_size (1000 in our case)
- shuffle (True in our case)
- num_workers (the default is 0, which means the main process will be used)
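The effect of these parameters can be checked without downloading Fashion-MNIST by wrapping a synthetic dataset instead. In this sketch, the 2,500 random "images" and the TensorDataset wrapper are stand-ins for train_set, chosen only so that the batch sizes are easy to predict.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for train_set: 2,500 random 1x28x28 "images" with labels 0-9
images = torch.rand(2500, 1, 28, 28)
labels = torch.randint(0, 10, (2500,))
fake_set = TensorDataset(images, labels)

fake_loader = DataLoader(fake_set
    ,batch_size=1000  # number of samples per batch
    ,shuffle=True     # reshuffle sample order at the start of every pass (epoch)
    ,num_workers=0    # 0 means batches are loaded in the main process
)

batch_sizes = []
for batch_images, batch_labels in fake_loader:
    batch_sizes.append(batch_images.shape[0])

print(batch_sizes)         # [1000, 1000, 500]
print(batch_images.shape)  # last batch: torch.Size([500, 1, 28, 28])
```

Because 2,500 is not a multiple of 1,000, the final batch holds the remaining 500 samples; passing drop_last=True to DataLoader would discard that partial batch instead.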
From an ETL perspective, we have achieved the extract and the transform using torchvision when we created the dataset:
- Extract – The raw data was extracted from the web.
- Transform – The raw image data was transformed into a tensor.
- Load – The train_set is wrapped by (loaded into) the data loader, giving us access to the underlying data.
Now, we should have a good understanding of the torchvision module that is provided by PyTorch, and how we can use Datasets and DataLoaders in the PyTorch torch.utils.data package to streamline ETL tasks.
In the next post, we’ll see how we can work with datasets and data loaders to access and view individual samples as well as batches of samples.
I’ll see you in the next one!