Computational Graphs for Neural Networks Code Demo
In this episode, we'll examine the computational graphs used to compute gradients via backpropagation during neural network training.
We'll use PyTorch for demo purposes, but the concepts we cover here are fundamental to neural networks in general. Beyond gaining an understanding of this core piece of neural network training, there is a specific reason to examine it in the GAN course.
When we trained DCGAN with PyTorch, we saw that during the discriminator training portion, we called detach() on the fake images generated by the generator when we passed them to the discriminator for classification. When we passed this same batch of images to the discriminator a second time during generator training, however, we did not call detach().
There, we explained that calling detach() returns a new tensor that is detached from the current graph. To understand what this means, we need to first understand computational graphs in general.
Computational Graph Code Demo
In our demo, we'll use the torchviz and graphviz visualization libraries to plot a visual representation of neural network computational graphs. We need to first install torchviz and graphviz with the commands below in order to import them in the next step.
conda install graphviz
conda install python-graphviz
pip install torchviz==0.0.2
Now, we'll import PyTorch and the nn and functional modules we'll make use of, as well as torchviz and the Digraph class from graphviz. We'll also set PyTorch's random seed so that we can generate the same random values each time we run this code.
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
torch.set_printoptions(linewidth=120)
from graphviz import Digraph
import torchviz
Now, we'll create two tensors a and b, which each contain just one float value.
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
By setting requires_grad=True, we are specifying that we want to track the operations that occur on the tensor. This tracking will allow us to later calculate gradients with respect to the tensor.
Generally, we set the requires_grad parameter of a given tensor to True if gradients will later be computed with respect to the tensor. When this parameter is set to True, we also say that the tensor has gradient tracking turned on or enabled.
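As a quick aside that's not part of the original demo, we can check this flag on any tensor and enable it with the in-place method requires_grad_(). The tensor x below is just an illustrative example.
x = torch.tensor(5.0)
print(x.requires_grad)  # False - gradient tracking is off by default
x.requires_grad_(True)  # enable gradient tracking in place
print(x.requires_grad)  # True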
The record of operations that occur on a tensor is what is referred to as the tensor's computational graph, or just graph for short. We can think of a graph as being attached to a tensor.
In a neural network, the learnable parameters like weights and biases have requires_grad set to True since we'll later want to compute the gradient of the network's loss with respect to these parameters.
For a given graph, tensors that are created by the user with requires_grad set to True are referred to as leaf tensors or leaf nodes, meaning that they are not the result of another operation in the graph.
Leaf nodes in a graph are nodes where the graph starts or stops. They are sometimes called terminal nodes of a graph.
Consider the weights in a network. They are not the result of any operation in the graph and therefore are leaf tensors. We'll understand why these tensors are referred to as leaves when we plot a visualization of a tensor's graph.
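As a small verification that's not in the original demo, PyTorch exposes an is_leaf attribute that we can inspect directly:
w = torch.randn(3, requires_grad=True)  # created by the user, so it's a leaf
y = w * 2                               # result of an operation, so it's not
print(w.is_leaf, y.is_leaf)             # True False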
Now we'll create a new tensor c as the sum of a and b.
c = a + b
By printing c, we get the value of the tensor, as well as its gradient function grad_fn.
tensor(3., grad_fn=<AddBackward0>)
When a tensor is created as a result of an operation performed on one or more tensors that have gradient tracking turned on, it will have a gradient function grad_fn, which specifies how the tensor was created.
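Conversely, and as a quick aside that's not part of the original demo, a tensor created only from inputs without gradient tracking has no grad_fn at all:
x = torch.tensor(1.0)  # requires_grad defaults to False
y = x + 2
print(y.grad_fn)       # None - no operations are being recorded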
In the case of c, its grad_fn is AddBackward0. This indicates that c is the result of a sum. The referenced function AddBackward0 is the function that will be called when backpropagation occurs on this tensor.
Since c has a grad_fn, this means that we can see how c will be affected if we change either of its inputs, a or b. In other words, we can calculate the derivatives or gradients of c with respect to a and b.
This is analogous to the loss of a network being a function of all the network's weights, and then calculating the gradients of the loss with respect to these weights. The weights will have gradient tracking enabled, and therefore, the loss will have a grad_fn that discloses how to compute its gradients.
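As a minimal sketch of this analogy, and not part of the original demo, we can build a single linear layer, reduce its output to a scalar as a stand-in for a real loss, and observe that the result carries a grad_fn just like c does:
layer = nn.Linear(2, 1)  # the layer's weight and bias are leaf tensors with requires_grad=True
out = layer(torch.tensor([1.0, 2.0]))
loss = out.sum()         # a stand-in for a real loss function
print(loss.grad_fn)      # something like <SumBackward0 object at ...>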
Now let's create one more tensor.
d = c * a
Here, d is the product of the leaf tensor a with the tensor c, which itself is the sum of the two leaf tensors a and b. Let's print d.
tensor(3., grad_fn=<MulBackward0>)
We can see the value of d is 3 as a result of the multiplication of a and c, and we also see its grad_fn is MulBackward0.
This tells us that d was created as a result of the multiplication operation, and also that somewhere in the history of computations that resulted in d's creation, we have at least one tensor that has gradient tracking enabled.
Now, we know d is a result of the multiplication operation with one of the inputs being c, which itself was the result of the addition operation of the two leaf tensors a and b.
If we want to calculate the gradients of d with respect to a and b, we have to first call backward() on d. The backward() method computes the gradient of the given tensor with respect to any leaf tensors. In our case, this will be with respect to a and b. The resulting gradients will be accessible via the grad attribute of the leaf tensors.
For example, let's look at the gradient of d with respect to a and b by accessing grad on these leaf tensors.
print(a.grad, b.grad)
We get the following output.
None None
This is because backward() has not been called yet on a tensor that was derived from a and b. Without this call, no gradients will have been computed with respect to a and b.
So now let's call backward() on d to compute d's gradients with respect to a and b, and then inspect the corresponding gradients again.
d.backward()
print(a.grad, b.grad)
Now we have the following results.
tensor(4.) tensor(1.)
Now we can see that the gradient of d with respect to a is 4, and the gradient of d with respect to b is 1.
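As a sanity check that's not spelled out in the original demo, these values match what we'd compute by hand. Since d = c * a and c = a + b, the product rule gives ∂d/∂a = a * ∂c/∂a + c = a + c = 1 + 3 = 4, while ∂d/∂b = a * ∂c/∂b = a = 1.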
Visualizing a Computational Graph
Now that we're well acquainted with the relationships of these tensors and how gradients are tracked via the use of a computational graph, let's visualize these results using torchviz.
torchviz.make_dot(d, params = {
    "a": a, "b": b, "c": c, "d": d
})
Passing d to torchviz.make_dot() plots a visual of d's graph.
This is a graph representation of the operations we did that resulted in d. This gives us a visual representation of how backward() uses this graph to calculate d's gradients with respect to a and b.
As we know, we can inspect a tensor's grad_fn, and we earlier saw that d's grad_fn is MulBackward0 since it is the direct result of the multiplication operation. In addition to seeing the operation that led directly to d's creation, we can also see previous operations by accessing next_functions on the grad_fn.
print(d.grad_fn)
print(d.grad_fn.next_functions)
We can now see the previous operations that occurred before the multiplication that led to d's creation.
<MulBackward0 object at 0x000001F229DFBC50>
(
    (<AddBackward0 object at 0x000001F229DFBA20>, 0),
    (<AccumulateGrad object at 0x000001F229DFBCF8>, 0)
)
This shows us that the multiplication operation took two inputs. One input was the result of the addition operation, and the other input was a leaf tensor.
Going a step further, we can go all the way back to the creation of the leaf tensors a and b.
print(d.grad_fn.next_functions[0][0])
print(d.grad_fn.next_functions[0][0].next_functions)
This gives us the following result.
<AddBackward0 object at 0x000001F229DFBA20>
(
    (<AccumulateGrad object at 0x000001F229DFBCF8>, 0),
    (<AccumulateGrad object at 0x000001F229DFBEB8>, 0)
)
This shows us that tensor c, which was created as a result of the addition operation, had two leaf inputs.
This process is representative of "walking through" the graph of d. This is further illustrated with the function below, which prints the next_functions in an organized manner that corresponds directly to the graph we plotted with torchviz.
def walk_graph(g, step=0):
    if g is not None:
        print(step, step * ' ', type(g))
        for f in g.next_functions:
            walk_graph(f[0], step + 1)

walk_graph(d.grad_fn, 0)
Calling this function as shown above gives us the following output.
0  <class 'MulBackward0'>
1   <class 'AddBackward0'>
2    <class 'AccumulateGrad'>
2    <class 'AccumulateGrad'>
1   <class 'AccumulateGrad'>
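A brief aside on the AccumulateGrad entries: these nodes correspond to the leaf tensors, and when backward() reaches one, the computed gradient is accumulated into that leaf's grad attribute. This is standard PyTorch behavior rather than something shown in the original demo, and it's why gradients add up across repeated backward passes unless we zero them:
x = torch.tensor(1.0, requires_grad=True)
(x * 2).backward()
print(x.grad)   # tensor(2.)
(x * 2).backward()
print(x.grad)   # tensor(4.) - the new gradient was added to the old one
x.grad = None   # reset before the next backward pass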
Detaching a Tensor's Graph
Now we're going to discuss what happens when we detach a graph from a tensor. This is directly applicable to what we saw when we implemented DCGAN training with PyTorch. We'll tie this specific example back in towards the end.
Let's create three leaf tensors.
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
c = torch.tensor(3.0, requires_grad=True)
Now let's create another tensor d, which is the sum of a and b.
d = a + b
Now we'll create another tensor e, which is the result of c summed with the relu operation applied to d. Note that we've called detach() on d when creating e.
e = F.relu(d.detach()) + c
The last tensor f is the same as tensor e, except we're not detaching d's graph when we create f.
f = F.relu(d) + c
Let's now look at e's graph using torchviz.
torchviz.make_dot(e, params = {
    "a": a, "b": b, "c": c, "d": d, "e": e, "f": f
})
We can see that e is the result of the addition that occurred between the leaf tensor c and some other tensor that is not having its operations recorded. As we know, this other tensor is the result of relu(d.detach()).
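We can confirm this with a quick check that isn't part of the original demo:
d_detached = d.detach()
print(d_detached.requires_grad)  # False - operations on it won't be recorded
print(d_detached.grad_fn)        # None - its history was left behind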
Let's now look at f's graph.
torchviz.make_dot(f, params = {
    "a": a, "b": b, "c": c, "d": d, "e": e, "f": f
})
Recall that e and f were created in the exact same manner, except that when creating e, we detached d's graph.
Tensor f's graph looks more comprehensive because now we have d's computational history included as well.
From the graph, we can see that f is the result of the addition that occurred between the leaf tensor c and relu(d).
We can then see the additional information that d is the result of the addition that occurred between leaf tensors a and b. Since d's history was detached when we created e, we did not see this additional piece of history in e's graph.
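As another aside not in the original demo, we can see the same difference without plots by walking both graphs with the walk_graph function from earlier (the exact branch ordering may vary):
walk_graph(e.grad_fn)  # AddBackward0, then AccumulateGrad for c only
walk_graph(f.grad_fn)  # AddBackward0, ReluBackward0, AddBackward0, and AccumulateGrad for a, b, and c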
Now let's circle back to the use of detach() during GAN training.
The use of detach() during GAN Training
Below is a code snippet from the PyTorch training loop for training the discriminator of the DCGAN we implemented earlier. We're specifically focusing on the step of training the discriminator with an all-fake batch.
# TRAIN DISCRIMINATOR
# Train with all-real batch
...
# Train with all-fake batch
noise = torch.randn(real_images.size(0), z_size, 1, 1).to(device)
fake_images = netG(noise)
fake_output = netD(fake_images.detach())
d_loss = discriminator_loss(real_output, fake_output)
d_loss.backward()
d_optimizer.step()
We first generate a batch of fake images by passing a batch of random noise vectors to the generator netG and store the results in fake_images.
We then pass fake_images to the discriminator netD to get predictions for this fake batch.
When doing this step, we use a copy of fake_images that has been detached from its graph. The result of these predictions, fake_output, is then used to calculate the discriminator's loss d_loss, which we ultimately call backward() on.
The network netG is made up of many leaf tensors that lead to the creation of fake_images, and so fake_images has a history of operations already attached to its graph from its creation. We'll need access to this history whenever we train netG.
Note that calling backward() clears the graph used to compute the gradients. Therefore, if we didn't detach fake_images from its graph, then its graph would be cleared in the upcoming backward() pass for netD and would not be accessible when training netG afterwards.
Computational graphs can become large, and this is why, by default, PyTorch frees a graph after using it in a backward() pass.
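As an aside, backward() also accepts a retain_graph=True argument that keeps the graph alive for additional backward passes. The commented sketch below shows that hypothetical alternative, which is not what this demo does: retaining netG's entire graph wastes memory, and the discriminator's backward pass would also compute unneeded gradients for netG's parameters.
# hypothetical alternative to detaching (not the approach used in this demo):
# fake_output = netD(fake_images)
# d_loss.backward(retain_graph=True)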
Recall that after training the discriminator, we use this same tensor when training netG. There, we pass this tensor to netD again during netG's training, and it is at this point that we'll need access to fake_images' graph to compute netG's gradients.
# TRAIN GENERATOR
netG.zero_grad()
fake_output = netD(fake_images)
g_loss = generator_loss(fake_output)
g_loss.backward()
g_optimizer.step()
This is precisely why we take the crucial step of detaching the graph from fake_images when training the discriminator. If we didn't, we wouldn't have access to fake_images' graph at this point, as it would have been cleared when we called backward() during the discriminator training.