PyTorch - Python Deep Learning Neural Network API

Deep Learning Course 4 of 6 - Level: Intermediate

Reset Weights PyTorch Network - Deep Learning Course

Resetting Network Weights in PyTorch

Welcome to deeplizard. My name is Chris. In this lesson, we're going to see how we can reset the weights in a PyTorch network.

Without further ado, let's get started.

Ways to Reset Weights

We'll look at several ways we can reset a network's weights.

  1. Individual layer
  2. Individual layer inside a network
  3. Subset of layers inside a network
  4. All weights layer by layer
  5. All weights using a snapshot
  6. All weights using re-initialization

Individual Layer

Let's begin by seeing how we can reset the weights of an individual layer. To do this, we'll work with a single linear layer.

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(50)
layer = nn.Linear(2,1)

Here, we've created a linear layer, and we've manually set the seed used to generate random numbers in PyTorch. This will ensure that we can regenerate the same weights when we do the reset.

Note that we've created our linear layer using the nn.Linear class. This layer takes in two input features and returns a single output.

Now, let's see how we can reset the weights for this layer. The code below performs the following tasks on the layer's weights:

  1. Check
  2. Change
  3. Check
  4. Reset
  5. Check

layer.weight ##### 1 Check
# Output:
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)

t = torch.rand(2) # a random input tensor with two features
o = layer(t)      # forward pass; output has a single element
o.backward()      # backprop populates the weight gradients

optimizer = optim.Adam(layer.parameters(), lr=.01)
optimizer.step() ##### 2 Change

layer.weight ##### 3 Check
# Parameter containing:
# tensor([[ 0.1569, -0.6200]], requires_grad=True)

torch.manual_seed(50)
layer.reset_parameters() ##### 4 Reset

layer.weight ##### 5 Check
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)

Here, we can see that the weights were changed and then reset to their original values. Note the use of the reset_parameters() method, which does the primary work: it simply re-initializes the parameters, i.e. the weights, randomly. To actually get the same values on each reset, remember to call the manual_seed() function with the same seed value beforehand.

The nn.Linear source code for the reset_parameters() method can be seen below:

# linear.py version 1.7.0                                              
def reset_parameters(self) -> None:
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is not None:
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

Notice that the weights and bias are both re-initialized. This is because the parameters of a layer include both the weights and the biases.
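
We can confirm this by listing the layer's parameters by name. This is just a quick check, not part of the reset workflow:

for name, param in layer.named_parameters():
    print(name, param.shape)
# Output:
# weight torch.Size([1, 2])
# bias torch.Size([1])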

Individual layer inside a network

If we are working with an individual layer inside a network, the process is almost the same as when working with the layer directly. The only difference is that we access the layer using the network instance.

Let's create a network with a single layer, and see how to access the layer using the network instance.

torch.manual_seed(50)
network = nn.Sequential(nn.Linear(2,1))
network[0]
# Output:
#Linear(in_features=2, out_features=1, bias=True)

With this, we can see that we have the same layer as before, but now the layer is inside the network instance. We can access this layer by indexing into the network instance, like so: network[0]. With this information in hand, the code below is nearly identical to what we used before.

network[0].weight
# Output:
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)

t = torch.rand(2)
o = network(t)
o.backward()

optimizer = optim.Adam(network.parameters(), lr=.01)
optimizer.step()

network[0].weight
# Parameter containing:
# tensor([[ 0.1569, -0.6200]], requires_grad=True)

torch.manual_seed(50)
network[0].reset_parameters()

network[0].weight
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)

Subset of layers inside a network

To reset a subset of layer weights inside a network, we simply target each layer inside the network directly. We could, for example, collect the subset of layers into a Python list and iterate over the list, calling the reset_parameters() method on each layer, as in the sketch below.
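
Here is a minimal sketch of that idea. The choice of layers is hypothetical; with a Sequential network we can index the layers directly:

subset = [network[0]] # hypothetical subset: here, just the first layer

torch.manual_seed(50)
for layer in subset:
    layer.reset_parameters()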

All weights layer by layer

Resetting all the weights of a network layer by layer can become problematic as the network's architecture increases in complexity. We have to remember that PyTorch networks are built using composable modules, and so weights can be nested deep in a module stack.

This means that we would need to implement some type of recursion or complicated logic just to access the layers. We can even run into problems using simple networks. Let's see an example.

network = nn.Sequential(nn.Linear(2,1))
network[0].weight
# Output
# Parameter containing:
# tensor([[-0.0559, -0.0174]], requires_grad=True)

torch.manual_seed(50)
for module in network.children():
    module.reset_parameters()

network[0].weight
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)

In this example, we iterate over the child modules of the network and successfully reset the parameters. Now, let's add another module to our network to see the trouble.

network = nn.Sequential(
    nn.Linear(2,1)
    , nn.Softmax()
)
network
# Output:
# Sequential(
#     (0): Linear(in_features=2, out_features=1, bias=True)
#     (1): Softmax(dim=None)
# )

try: 
    torch.manual_seed(50)
    for module in network.children():
        module.reset_parameters()
except Exception as e:
    print(e)
# Output
# 'Softmax' object has no attribute 'reset_parameters'

With this setup, we can see that we get an error about nn.Softmax not having the required method. Yes, it is possible to add more logic inside the for loop, and a guarded version is sketched below for reference. However, we're not going to pursue this solution as it's error prone, and we have better options.
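
A minimal sketch of that guarded approach uses the Module.apply() method, which recurses through all nested submodules, so it also handles deeper module stacks. Still, the options that follow are simpler:

def reset_weights(module):
    # Only call reset_parameters() on modules that define it.
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()

torch.manual_seed(50)
network.apply(reset_weights)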

All weights using a snapshot

The most versatile way to reset all of the weights in a PyTorch model is to use a snapshot. This lets us choose the exact state the network returns to when resetting, and it works without the torch.manual_seed() function.

We can create a network snapshot at any time by saving the network's state to disk. Then, when we're ready to reset, we load the state from disk and apply it to the model.

The code below shows a full example of this:

torch.manual_seed(50)
network = nn.Sequential(nn.Linear(2,1))

network[0].weight
# Output
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)

torch.save(network.state_dict(), "./network.pt") # snapshot the initial state

t = torch.rand(2)
o = network(t)
o.backward()
optimizer = optim.Adam(network.parameters(), lr=.01)
optimizer.step()

network[0].weight
# Parameter containing:
# tensor([[ 0.1569, -0.6200]], requires_grad=True)

network.load_state_dict(torch.load("./network.pt")) # restore the snapshot

network[0].weight
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)
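
As a variation, we can also keep the snapshot in memory instead of on disk. The sketch below (not part of the lesson's code above) deep copies the state dict so that later optimizer steps can't mutate the saved copy:

import copy

snapshot = copy.deepcopy(network.state_dict()) # in-memory snapshot

# ... train ...

network.load_state_dict(snapshot) # reset to the snapshot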
                                     

All weights using re-initialization

The best option for most use cases is actually very easy and intuitive, but often overlooked. To reset the weights, we simply re-initialize the network instance. This gives us a network with the same starting values, as long as we use the torch.manual_seed() function with the same seed.

The following code demonstrates this method.

torch.manual_seed(50)
network = nn.Sequential(nn.Linear(2,1))

network[0].weight
# Output
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)

t = torch.rand(2)
o = network(t)
o.backward()
optimizer = optim.Adam(network.parameters(), lr=.01)
optimizer.step()

network[0].weight
# Parameter containing:
# tensor([[ 0.1569, -0.6200]], requires_grad=True)

torch.manual_seed(50)
network = nn.Sequential(nn.Linear(2,1))

network[0].weight
# Parameter containing:
# tensor([[ 0.1669, -0.6100]], requires_grad=True)
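
This approach pairs naturally with repeated experiments. As a quick sketch (the build_network() helper is our own, not from the lesson), wrapping the construction in a function guarantees each run starts from identical weights:

def build_network():
    torch.manual_seed(50) # same seed -> same starting weights
    return nn.Sequential(nn.Linear(2,1))

for run in range(3):
    network = build_network()    # fresh, identically initialized network
    print(network[0].weight)     # identical on every run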
