Training a deep Q-network with fixed Q-targets
What’s up, guys? In this post, we’ll continue our discussion of deep Q-networks, and, as promised last time, we'll be introducing a second network, called the target network, into the mix. We'll see how exactly this target network fits into the DQN training process. So, let's get to it!
Review: Training the policy network
Recall from last time that we left off with this summary that describes the training process of a deep Q-network.
- Initialize replay memory capacity.
- Initialize the network with random weights.
- For each episode:
    - Initialize the starting state.
    - For each time step:
        - Select an action.
            - Via exploration or exploitation
        - Execute selected action in an emulator.
        - Observe reward and next state.
        - Store experience in replay memory.
        - Sample random batch from replay memory.
        - Preprocess states from batch.
        - Pass batch of preprocessed states to policy network.
        - Calculate loss between output Q-values and target Q-values.
            - Requires a second pass to the network for the next state
        - Gradient descent updates weights in the policy network to minimize loss.
We briefly mentioned previously that there were some issues that could arise from this approach though. These issues come into play in step \(8\) where we calculate the loss between the output Q-values and the target Q-values. Remember, this is the step that requires a second pass to the deep Q-network, otherwise known as the policy network.
As a quick refresher, remember, for a single sample, the first pass to the network occurs for the state from the experience tuple that was sampled. The network then outputs the Q-values associated with each possible action that can be taken from that state, and then the loss is calculated between the Q-value for the action from the experience tuple and the target Q-value for this action.
To calculate the target Q-value though, we were required to do a second pass to the network with the next state. From this second pass, we can obtain the maximum Q-value among the possible actions that can be taken from that next state, and plug that in to the Bellman equation to calculate the target Q-value for the action from the first pass.
This process is a bit of an earful, I know, so if you're struggling at all, be sure to spend some time on the previous post, where we cover it in full detail.
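If it helps, here's what those two passes look like for a single sample, sketched with a hypothetical tiny linear "network" in NumPy. The shapes, names, and values here are all made up for illustration; this isn't a real DQN, just the mechanics of the two passes:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical tiny "network": a linear map from a 4-dim state to 2 action Q-values.
weights = rng.normal(size=(4, 2))

def q_values(state, w):
    # Forward pass: one Q-value per possible action in the given state.
    return state @ w

# A single experience tuple (state, action, reward, next_state) sampled from replay memory.
state      = rng.normal(size=4)
action     = 1
reward     = 1.0
next_state = rng.normal(size=4)
gamma      = 0.99  # discount factor

# First pass: Q-value for the action from the experience tuple.
predicted_q = q_values(state, weights)[action]

# Second pass to the SAME network for the next state; the max Q-value
# gets plugged into the Bellman equation to form the target.
target_q = reward + gamma * np.max(q_values(next_state, weights))

loss = (target_q - predicted_q) ** 2  # squared error for this one sample
```

Notice that both `predicted_q` and `target_q` come from the same `weights` — that's exactly what causes the trouble we'll look at next.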
Potential training issues with deep Q-networks
Alright, now that we have that refresher out of the way, let's focus on the potential issues with this process. As mentioned, the issues stem from the second pass to the network.
We do the first pass to calculate the Q-value for the relevant action, and then we do a second pass in order to calculate the target Q-value for this same action. Our objective is to get the Q-value to approximate the target Q-value.
Remember, we don't know ahead of time what the target Q-value even is, so we attempt to approximate it with the network. This second pass occurs using the same weights in the network as the first pass.
Given this, when our weights update, our output Q-values will update, but so will our target Q-values, since the targets are calculated using those same weights. So, our Q-values will be updated with each iteration to move closer to the target Q-values, but the target Q-values will also be moving in the same direction.
As Andong put it in the comments of the last video, this makes the optimization appear to be chasing its own tail, which introduces instability. As our Q-values move closer and closer to their targets, the targets continue to move further and further because we're using the same network to calculate both of these values.
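Here's a toy numeric picture of that tail-chasing, using a made-up one-parameter linear "network." Nothing here is the real DQN; it just shows the target shifting with every weight update:

```python
# One-parameter "network": q(s) = w * s, with everything scalar for simplicity.
w = 0.5
state, next_state = 1.0, 1.0
reward, gamma, lr = 1.0, 0.9, 0.1

def q(s, w):
    return w * s

targets, preds = [], []
for step in range(3):
    target = reward + gamma * q(next_state, w)  # target computed from the SAME w
    pred = q(state, w)
    targets.append(target)
    preds.append(pred)
    # Gradient step on (target - pred)^2, treating the target as a constant.
    w += lr * (target - pred) * state
```

With each update, `pred` moves toward `target`, but by the next iteration `target` has already moved, because it's computed from the very weights we just changed.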
The target network
Well, here's a perfect time to introduce the second network that we mentioned earlier. Rather than doing a second pass to the policy network to calculate the target Q-values, we instead obtain the target Q-values from a completely separate network, appropriately called the target network.
The target network is a clone of the policy network. Its weights start out frozen at the policy network's original weights, and we only update them to the policy network's current weights every certain number of time steps. This number of time steps can be looked at as yet another hyperparameter that we'll have to test out to see what works best for us.
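In code, the cloning and the periodic weight sync might look something like this. The weights are shown as a plain dict of NumPy arrays, the training update is a stand-in, and the `TARGET_UPDATE` name is made up:

```python
import copy

import numpy as np

rng = np.random.default_rng(0)

# Policy network weights, represented here as a plain dict of arrays.
policy_weights = {"w": rng.normal(size=(4, 2))}

# Clone the policy network; this frozen copy is the target network.
target_weights = copy.deepcopy(policy_weights)

TARGET_UPDATE = 10  # sync interval in time steps -- a hyperparameter to tune

for t in range(1, 31):
    # Stand-in for a real gradient update to the policy network.
    policy_weights["w"] += 0.01 * rng.normal(size=(4, 2))

    # Every TARGET_UPDATE time steps, copy the policy weights into the target network.
    if t % TARGET_UPDATE == 0:
        target_weights = copy.deepcopy(policy_weights)
```

In PyTorch, for example, that sync is usually a one-liner: `target_net.load_state_dict(policy_net.state_dict())`.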
So now, the first pass still occurs with the policy network. The second pass, however, for the following state occurs with the target network. With this target network, we're able to obtain the \(\max\) Q-value for the next state, and again, plug this value into the Bellman equation in order to calculate the target Q-value for the first state.
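For a sampled batch, the two passes now split across the two networks. Here's a rough vectorized sketch, again with hypothetical linear networks and made-up shapes and data:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical linear networks: same architecture, separate weights.
policy_w = rng.normal(size=(4, 2))
target_w = policy_w.copy()  # target network starts out as a clone

# A sampled batch of 32 experiences (made-up data).
batch = 32
states      = rng.normal(size=(batch, 4))
actions     = rng.integers(0, 2, size=batch)
rewards     = rng.normal(size=batch)
next_states = rng.normal(size=(batch, 4))
gamma = 0.99

# First pass: POLICY network, Q-values for the actions that were actually taken.
predicted_q = (states @ policy_w)[np.arange(batch), actions]

# Second pass: TARGET network, max Q-value for each next state,
# plugged into the Bellman equation.
target_q = rewards + gamma * np.max(next_states @ target_w, axis=1)

loss = np.mean((target_q - predicted_q) ** 2)  # MSE over the batch
```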
As it turns out, this removes much of the instability introduced by using a single network to calculate both the Q-values and the target Q-values. We now have something fixed, i.e. fixed Q-targets, that we want our policy network to approximate. So, we no longer have the dog-chasing-its-tail problem.
As mentioned though, these values don't stay completely fixed the entire time. After \(x\) time steps, we'll update the weights in the target network with the weights from our policy network, which will, in turn, update the target Q-values with respect to what the policy network has learned over those past time steps. This will cause the policy network to start to approximate the updated targets.
Wrapping up
Alright, so now let's just highlight what's changed in our training summary.
For the most part, this is the same. We only have a few tweaks. The first change is that now we have a new step at the start where we clone the policy network and call it the target network.
Additionally, when we calculate the loss between the Q-values output from the policy network and the target Q-values, we do this using the new target network now, rather than with a second pass to the policy network.
The last change is just that we update the target network weights with the policy network weights every \(x\) time steps.
- Initialize replay memory capacity.
- Initialize the policy network with random weights.
- Clone the policy network, and call it the target network.
- For each episode:
    - Initialize the starting state.
    - For each time step:
        - Select an action.
            - Via exploration or exploitation
        - Execute selected action in an emulator.
        - Observe reward and next state.
        - Store experience in replay memory.
        - Sample random batch from replay memory.
        - Preprocess states from batch.
        - Pass batch of preprocessed states to policy network.
        - Calculate loss between output Q-values and target Q-values.
            - Requires a pass to the target network for the next state
        - Gradient descent updates weights in the policy network to minimize loss.
        - After \(x\) time steps, weights in the target network are updated to the weights in the policy network.
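Putting it all together, the updated algorithm can be sketched as a runnable skeleton. Everything environment-related here is a trivial stand-in (random states, a made-up reward), and the "networks" are just linear maps; the point is only to show where each step of the summary lands:

```python
import copy

import numpy as np

rng = np.random.default_rng(0)

N_STATE, N_ACTIONS = 4, 2

def env_step(state, action):
    # Stand-in emulator: random next state, reward 1 for action 0, never done.
    return rng.normal(size=N_STATE), float(action == 0), False

memory, CAPACITY = [], 1000                        # initialize replay memory capacity
policy_w = rng.normal(size=(N_STATE, N_ACTIONS))   # policy network, random weights
init_w = policy_w.copy()
target_w = copy.deepcopy(policy_w)                 # clone policy net -> target net

gamma, lr, eps = 0.99, 0.01, 0.1
BATCH, TARGET_UPDATE = 8, 20
step_count = 0

for episode in range(5):
    state = rng.normal(size=N_STATE)               # initialize the starting state
    for t in range(50):                            # for each time step
        step_count += 1
        # Select an action via exploration or exploitation (epsilon-greedy).
        if rng.random() < eps:
            action = int(rng.integers(N_ACTIONS))
        else:
            action = int(np.argmax(state @ policy_w))
        next_state, reward, done = env_step(state, action)  # execute, observe
        memory.append((state, action, reward, next_state))  # store experience
        memory = memory[-CAPACITY:]
        state = next_state

        if len(memory) >= BATCH:
            # Sample a random batch from replay memory.
            idx = rng.choice(len(memory), size=BATCH, replace=False)
            s  = np.array([memory[i][0] for i in idx])
            a  = np.array([memory[i][1] for i in idx])
            r  = np.array([memory[i][2] for i in idx])
            s2 = np.array([memory[i][3] for i in idx])
            # Loss between policy-network Q-values and target-network targets.
            pred = (s @ policy_w)[np.arange(BATCH), a]
            targ = r + gamma * np.max(s2 @ target_w, axis=1)
            # Gradient descent on the MSE, updating the policy network only.
            grad = np.zeros_like(policy_w)
            for j in range(BATCH):
                grad[:, a[j]] += -2.0 * (targ[j] - pred[j]) * s[j] / BATCH
            policy_w -= lr * grad

        # After TARGET_UPDATE time steps, sync the target network's weights.
        if step_count % TARGET_UPDATE == 0:
            target_w = copy.deepcopy(policy_w)
        if done:
            break
```

In a real implementation, the linear maps would be deep networks and `env_step` an actual emulator, but the control flow follows the summary above step for step.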
Take some time to go over this new algorithm and see if you now have the full picture for how target networks fit into the deep Q-network training process. Let me know your thoughts in the comments! I’ll see ya in the next one!