Code for Deep Learning - ArgMax and Reduction Tensor Ops
Tensor Reduction Ops for Deep Learning
Welcome back to this series on neural network programming. In this post, we'll learn about reduction operations for tensors.
- Reshaping operations
- Element-wise operations
- Reduction operations
- Access operations
We'll focus on the frequently used argmax() function, and we'll see how to access the data inside our tensors. Without further ado, let's get started.
Tensor reduction operations
Let's kick things off by giving a definition for a reduction operation.
So far in this series, we've learned that tensors are the data structures of deep learning. Our first task is to load our data elements into a tensor.
For this reason, tensors are super important, but ultimately, what we are doing with the operations we've been learning about in this series is managing the data elements contained within our tensors.
Reshaping operations gave us the ability to position our elements along particular axes, element-wise operations let us perform operations between the corresponding elements of two tensors, and reduction operations allow us to perform operations on the elements within a single tensor.
Let's look at an example in code.
Reduction operation example
Suppose we have the following 3 x 3 rank-2 tensor:
> import torch

> t = torch.tensor([
    [0,1,0],
    [2,0,2],
    [0,3,0]
], dtype=torch.float32)
Let's look at our first reduction operation, a summation:
> t.sum()
tensor(8.)
Using the fact that
> t.numel()
9
> t.sum().numel()
1
We can see that
> t.sum().numel() < t.numel()
True
The sum of our tensor's scalar components is calculated using the sum()
tensor method. The result of this call is a scalar valued tensor.
Checking the number of elements in the original tensor against the result of the sum()
call, we can see that, indeed, the tensor returned by the call to sum()
contains fewer elements
than the original.
Since the number of elements has been reduced by the operation, we can conclude that the sum()
method is a reduction operation.
Common tensor reduction operations
As you may expect, here are some other common reduction functions:
> t.sum()
tensor(8.)
> t.prod()
tensor(0.)
> t.mean()
tensor(0.8889)
> t.std()
tensor(1.1667)
All of these tensor methods reduce the tensor to a single element scalar valued tensor by operating on all the tensor's elements.
Reduction operations in general allow us to compute aggregate (total) values across data structures. In our case, our structures are tensors.
Here is a question though: do reduction operations always reduce a tensor down to a single-element tensor?
The answer is no!
In fact, we often reduce specific axes at a time. This process is important. It's just like we saw with reshaping when we aimed to flatten the image tensors within a batch while still maintaining the batch axis.
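As a quick refresher on that idea, here is a minimal sketch (using a hypothetical batch of three 2 x 2 images, not data from this post) where we flatten each image while keeping the batch axis intact:
> images = torch.ones(3, 2, 2)    # hypothetical batch: 3 images of size 2 x 2
> images.flatten(start_dim=1).shape
torch.Size([3, 4])
The batch axis survives, and only the axes within each image are collapsed. Reducing by axis works in the same spirit: we operate within the tensor while leaving the other axes intact.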
Reducing tensors by axes
To reduce a tensor with respect to a specific axis, we use the same methods, and we just pass a value for the dimension parameter. Let's see this in action.
Suppose we have the following tensor:
> t = torch.tensor([
[1,1,1,1],
[2,2,2,2],
[3,3,3,3]
], dtype=torch.float32)
This is a 3 x 4
rank-2 tensor. Having different lengths for the two axes will help us understand these reduce operations.
Let's consider the
sum()
method again. Only, this time, we will specify a dimension to reduce. We have two axes so we'll do both. Check it out.
> t.sum(dim=0)
tensor([6., 6., 6., 6.])
> t.sum(dim=1)
tensor([ 4., 8., 12.])
When I first saw this while learning how it works, I was confused. If you're confused like I was, I highly recommend you try to understand what's happening here before going forward.
Remember, we are reducing this tensor across the first axis, and elements running along the first axis are arrays, and the elements running along the second axis are numbers.
Let's go over what happened here.
Understanding reductions by axes
We'll tackle the first axis first. When we take the summation across the first axis, we are summing the elements of the first axis.
It's like this:
> t[0]
tensor([1., 1., 1., 1.])
> t[1]
tensor([2., 2., 2., 2.])
> t[2]
tensor([3., 3., 3., 3.])
> t[0] + t[1] + t[2]
tensor([6., 6., 6., 6.])
Surprise! Element-wise operations are in play here.
When we sum across the first axis, we are taking the summation of all the elements of the first axis. To do this, we must utilize element-wise addition. This is why we covered element-wise operations before reduction operations in the series.
The second axis in this tensor contains numbers that come in groups of four. Since we have three groups of four numbers, we get three sums.
> t[0].sum()
tensor(4.)
> t[1].sum()
tensor(8.)
> t[2].sum()
tensor(12.)
> t.sum(dim=1)
tensor([ 4., 8., 12.])
This may take a little bit of time to sink in. If it does, don't worry, you can do it.
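One way to double-check which axis gets collapsed is to compare the shapes of the results. Reducing over dim=0 removes the first axis and leaves the length-4 second axis, while reducing over dim=1 removes the second axis and leaves the length-3 first axis:
> t.sum(dim=0).shape
torch.Size([4])

> t.sum(dim=1).shape
torch.Size([3])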
Now, with this heavy lifting out of the way, let's look at a very common reduction operation used in neural network programming called argmax.
Argmax tensor reduction operation
Argmax is a mathematical function that tells us which argument, when supplied to a function as input, results in the function's max output value.
When we call the argmax()
method on a tensor, the tensor is reduced to a new tensor that contains an index value indicating where the max value is inside the tensor. Let's see this in
code.
Suppose we have the following tensor:
> t = torch.tensor([
    [1,0,0,2],
    [0,3,3,0],
    [4,0,0,5]
], dtype=torch.float32)
In this tensor, we can see that the max value is the 5 in the last position of the last array.
Suppose we are tensor walkers. To arrive at this element, we walk down the first axis until we reach the last array element, and then we walk down to the end of this array passing by the 4, and the two 0s.
Let's see some code.
> t.max()
tensor(5.)
> t.argmax()
tensor(11)
> t.flatten()
tensor([1., 0., 0., 2., 0., 3., 3., 0., 4., 0., 0., 5.])
The first piece of code confirms for us that the max is indeed 5
, but the call to the argmax()
method tells us that the 5
is sitting at index 11
. What's
happening here?
We'll have a look at the flattened output for this tensor. If we don't specify an axis to the argmax()
method, it returns the index location of the max value from the flattened
tensor, which in this case is indeed 11
.
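As a small side note (just a sketch, not something covered above), if we ever want the row and column of the max value instead of the flattened index, we can convert the flat index back using the tensor's width:
> divmod(t.argmax().item(), t.shape[1])
(2, 3)
Row 2, column 3 is exactly where the 5 lives.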
Let's see how we can work with specific axes now.
> t.max(dim=0)
(tensor([4., 3., 3., 5.]), tensor([2, 1, 1, 2]))
> t.argmax(dim=0)
tensor([2, 1, 1, 2])
> t.max(dim=1)
(tensor([2., 3., 5.]), tensor([3, 1, 3]))
> t.argmax(dim=1)
tensor([3, 1, 3])
We're working with both axes of this tensor in this code. Notice how the call to the max()
method returns two tensors. The first tensor contains the max values and the second tensor contains
the index locations for the max values. This is what argmax gives us.
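Since max() returns both pieces together, a handy pattern is to unpack them into two variables:
> max_values, max_indices = t.max(dim=0)
> max_values
tensor([4., 3., 3., 5.])
> max_indices
tensor([2, 1, 1, 2])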
For the first axis, the max values are 4, 3, 3, and 5. These values are determined by taking the element-wise maximum of the arrays running along the first axis.
For each of these max values, the argmax() method tells us the index along the first axis where the value lives.
- The 4 lives at index two of the first axis.
- The first 3 lives at index one of the first axis.
- The second 3 lives at index one of the first axis.
- The 5 lives at index two of the first axis.
For the second axis, the max values are 2, 3, and 5. These values are determined by taking the maximum inside each array of the first axis. We have three groups of four, which gives us three maximum values.
The argmax values here tell us the index inside each respective array where the max value lives.
In practice, we often use the argmax()
function on a network's output prediction tensor, to determine which category has the highest prediction value.
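For example, here's a quick sketch with a made-up prediction tensor for a batch of three samples across four classes (the values are hypothetical, not output from a real network). Taking the argmax along the class axis gives the predicted class index for each sample:
> preds = torch.tensor([
    [0.1, 0.2, 0.6, 0.1],
    [0.8, 0.1, 0.05, 0.05],
    [0.2, 0.2, 0.2, 0.4]
])
> preds.argmax(dim=1)
tensor([2, 0, 3])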
Accessing elements inside tensors
The last type of common operation that we need for tensors is the ability to access data from within the tensor. Let's look at these for PyTorch.
Suppose we have the following tensor:
> t = torch.tensor([
[1,2,3],
[4,5,6],
[7,8,9]
], dtype=torch.float32)
> t.mean()
tensor(5.)
> t.mean().item()
5.0
Check out these operations on this one. When we call mean on this 3 x 3
tensor, the reduced output is a scalar valued tensor. If we want to actually get the value as a number, we use the
item()
tensor method. This works for scalar valued tensors.
Have a look at how we do it with multiple values:
> t.mean(dim=0).tolist()
[4.0, 5.0, 6.0]
> t.mean(dim=0).numpy()
array([4., 5., 6.], dtype=float32)
When we compute the mean across the first axis, multiple values are returned, and we can access the numeric values by transforming the output tensor into a Python list or a NumPy array.
Advanced indexing and slicing
With NumPy ndarray
objects, we have a pretty robust set of operations for indexing and slicing, and PyTorch tensor
objects support most of these operations as well.
The NumPy indexing documentation is a good resource for these advanced indexing and slicing operations.
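Just to give a flavor of what's possible, here are a few NumPy-style operations applied to the 3 x 3 tensor from above (a quick sketch, not an exhaustive list):
> t[0]
tensor([1., 2., 3.])

> t[:, 1]
tensor([2., 5., 8.])

> t[t > 5]
tensor([6., 7., 8., 9.])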
Deep learning project
Congrats for making it this far in the series. All of these tensor topics are pretty raw and low level, but having a strong understanding of them make our lives much easier as we develop as neural network programmers.
We're ready now to start part two of the series where we'll be putting all of this knowledge to use. We'll be kicking things off by exploring the dataset we'll be training on, the Fashion-MNIST dataset.
This dataset contains a training set of sixty thousand examples from ten different classes of clothing. We will use PyTorch to build a convolutional neural network that can accurately predict the correct article of clothing given an input image, so stay tuned!
I'll see you in the next one!