Building a Deep Q-Network in Code
Welcome back to this series on reinforcement learning! In this episode, we’ll get started with building our deep Q-network to be able to perform in the cart and pole environment. Let’s get to it!
After following the environment prep we covered last time, we’re now ready to start writing our code. We’ll be making use of everything we’ve learned about deep Q-networks so far, including the topics of experience replay, fixed Q-targets, and epsilon greedy strategies, to develop our code.
We’ll use the final summary of the DQN training process below, which we discussed in an earlier episode, to guide our understanding while developing our code. Make sure you’re familiar with these concepts first so that you have a solid grasp of why we’re doing what we’re doing in the upcoming code.
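- Initialize replay memory capacity.
- Initialize the policy network with random weights.
- Clone the policy network, and call it the target network.
- For each episode:
  - Initialize the starting state.
  - For each time step:
    - Select an action.
      - Via exploration or exploitation
    - Execute the selected action in an emulator.
    - Observe the reward and next state.
    - Store the experience in replay memory.
    - Sample a random batch from replay memory.
    - Preprocess the states from the batch.
    - Pass the batch of preprocessed states to the policy network.
    - Calculate the loss between the output Q-values and the target Q-values.
      - Requires a pass to the target network for the next state
    - Gradient descent updates the weights in the policy network to minimize the loss.
      - After \(x\) time steps, weights in the target network are updated to the weights in the policy network.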
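The "clone" step and the "after \(x\) time steps" step can be made concrete with a minimal sketch. This is not the code we’ll build in this series. It assumes a tiny nn.Linear stand-in for the policy network, a hypothetical TARGET_UPDATE value of 10 for \(x\), and a fake weight nudge in place of real gradient descent. The target network is cloned by copying the policy network’s state dict, which is one common way to do this in PyTorch.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the policy network; any nn.Module works the same way.
policy_net = nn.Linear(in_features=4, out_features=2)

# "Clone the policy network": same architecture, weights copied over.
target_net = nn.Linear(in_features=4, out_features=2)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()  # the target network is only used for inference

def weights_match(a, b):
    """True if every parameter tensor in a equals the matching one in b."""
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))

TARGET_UPDATE = 10  # hypothetical value of x, in time steps
sync_steps = []

for step in range(1, 26):
    # Stand-in for a gradient descent update on the policy network only.
    with torch.no_grad():
        for p in policy_net.parameters():
            p.add_(0.01)
    # Every x steps, copy the policy network's weights into the target network.
    if step % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        sync_steps.append(step)

print(sync_steps)                             # [10, 20]
print(weights_match(policy_net, target_net))  # False: 5 updates since the last sync
```

Between syncs, the target network’s weights stay fixed while the policy network keeps changing. This is the "fixed Q-targets" idea.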
Also, remember that I mentioned last time that we’ll be using PyTorch to train our DQN. The PyTorch code we use can be adapted to whatever other neural network API you may want to use as well; the implementation should generalize easily.
One quick announcement before we get to the code. Recall that last time we discussed how the code we’d be writing would be based on the original PyTorch deep Q-network tutorial available on PyTorch’s website, with just some minor modifications of my own. Since the last episode, though, I’ve spent more time going over the code and decided on several more changes that will differ from the original tutorial on PyTorch’s site. I just wanted to give you a heads-up, since there will now be considerably more differences than I originally alluded to last time.
Without further ado, let’s jump into it!
Code setup
```python
%matplotlib inline
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
```
As expected, the first thing we’re doing is importing all of the libraries we’re going to use. We’ve got gym and some PyTorch modules, plus standard libraries like math, random, and a few others.
Set up display
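Next, we import IPython’s display module to help us plot images to the screen later. The import only happens when matplotlib is using the inline backend, as it is after the %matplotlib inline magic in a Jupyter notebook.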
```python
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
```
Now that we’ve gotten past the initial setup, we can move on to implementing some of the concepts we’ve been discussing throughout this series.
I’ve organized this code in a very object-oriented way, which I think makes things a lot easier to understand. We’ll start by covering all of the classes and functions we need to create. Then, at the end, we’ll see all of them come into play in our main program.
Let’s start with our deep Q-network. This is where PyTorch comes into play. To build a neural network in PyTorch, we use the torch.nn package, which we gave the alias nn when we imported it earlier. This package contains all of the typical components needed to build neural networks. Within the nn package, there is a class called Module. Module is the base class for all neural network modules, so our network and all of its layers will extend nn.Module.
We define our DQN as a class that extends nn.Module. Our DQN will receive screenshot-like images of the cart and pole environment as input. To create a DQN object, we therefore need the height and width of the images that will be coming in to this model.
```python
class DQN(nn.Module):
    def __init__(self, img_height, img_width):
        super().__init__()
        self.fc1 = nn.Linear(in_features=img_height*img_width*3, out_features=24)
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=2)
```
To start out simple, our network will consist of only two fully connected hidden layers and an output layer. PyTorch refers to fully connected layers as Linear layers. Our first Linear layer accepts input with dimensions equal to the passed-in img_height times img_width times 3. The 3 corresponds to the three color channels of the RGB images the network receives as input. Our first Linear layer will have 24 outputs, so our second Linear layer will accept 24 inputs. Our second layer will have 32 outputs. Lastly, our output layer will have 32 inputs from the previous layer and 2 outputs.
In our particular cart and pole example, remember that the network will output the Q-values that correspond to each possible action the agent can take from a given state. Our only available actions are to move right or to move left, so the number of outputs is two.
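As a quick check on the arithmetic above, the sketch below counts the weights and biases in each layer. The 40 by 90 image size is a hypothetical choice for illustration only; the actual screen size depends on preprocessing we haven’t covered yet. The count shows that almost all of the parameters live in fc1, because its input is the whole flattened image.

```python
def linear_params(in_features, out_features):
    """Weights plus biases in one fully connected (Linear) layer."""
    return in_features * out_features + out_features

def dqn_layer_sizes(img_height, img_width, channels=3):
    """(in_features, out_features) for fc1, fc2, and out, as described above."""
    flat = img_height * img_width * channels
    return [(flat, 24), (24, 32), (32, 2)]

# Hypothetical screen size, chosen only for illustration.
sizes = dqn_layer_sizes(img_height=40, img_width=90)
counts = [linear_params(i, o) for i, o in sizes]

print(sizes)        # [(10800, 24), (24, 32), (32, 2)]
print(counts)       # [259224, 800, 66]
print(sum(counts))  # 260090
```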
As you can see, this architecture is built within the DQN class constructor. We’ve given the arbitrary names fc1 and fc2 to the two fully connected layers and out to the output layer.
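Because DQN extends nn.Module, every Linear layer assigned as an attribute in the constructor is registered automatically. Its parameters are reported under the attribute names we chose. Here is a minimal check, again assuming a hypothetical 40 by 90 image. Note that Linear stores its weight with shape (out_features, in_features).

```python
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, img_height, img_width):
        super().__init__()
        self.fc1 = nn.Linear(in_features=img_height*img_width*3, out_features=24)
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=2)

net = DQN(img_height=40, img_width=90)  # hypothetical image size

# nn.Module tracks each layer's weight and bias under the attribute name.
for name, param in net.named_parameters():
    print(name, tuple(param.shape))
# fc1.weight (24, 10800)
# fc1.bias (24,)
# fc2.weight (32, 24)
# fc2.bias (32,)
# out.weight (2, 32)
# out.bias (2,)
```

This registration is what lets an optimizer later find every trainable weight through net.parameters().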
Also, note that this network is pretty arbitrary and very basic. It doesn’t even contain any convolutional layers. I wanted to start out with something very straightforward. Once we see how this network performs, we can start tuning the architecture and experimenting with different variations.
The last thing we have to do for our DQN class is to define a function called forward(). This function implements a forward pass through the network. Note that all PyTorch neural networks require an implementation of forward().
```python
    def forward(self, t):
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t
```
For any particular image tensor t passed to the network, t first needs to be flattened before it can be passed to the first fully connected layer. Calling flatten(start_dim=1) flattens every dimension except the first, which is the batch dimension. After this, t is passed to the first fully connected layer and then has relu applied to it. This result is passed to the second fully connected layer and, again, has relu applied. That result is then passed to the output layer, and the output layer’s result is returned by the forward() function.
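The sketch below traces tensor shapes through the forward pass. It assumes images arrive in PyTorch’s usual (batch, channels, height, width) layout and uses a hypothetical batch of five random 40 by 90 images. Because flatten(start_dim=1) keeps the first dimension, even a single image needs a batch dimension of 1, which unsqueeze(0) adds.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, img_height, img_width):
        super().__init__()
        self.fc1 = nn.Linear(in_features=img_height*img_width*3, out_features=24)
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=2)

    def forward(self, t):
        t = t.flatten(start_dim=1)  # (N, 3, H, W) -> (N, 3*H*W)
        t = F.relu(self.fc1(t))     # (N, 24)
        t = F.relu(self.fc2(t))     # (N, 32)
        t = self.out(t)             # (N, 2): one Q-value per action
        return t

net = DQN(img_height=40, img_width=90)  # hypothetical image size

batch = torch.rand(5, 3, 40, 90)        # hypothetical batch of 5 RGB images
q_values = net(batch)
print(q_values.shape)                   # torch.Size([5, 2])

# A single image of shape (3, 40, 90) needs a leading batch dimension first.
single = net(torch.rand(3, 40, 90).unsqueeze(0))
print(single.shape)                     # torch.Size([1, 2])
```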
If this is your first time being exposed to PyTorch and you want to go deeper into the steps we just covered to build a network, be sure to check out our PyTorch series, where all of this is covered in complete and thorough detail. Otherwise, if you’re at all shaky on the fundamental concepts of forward passes, relu, or layer inputs and outputs, then you’ll definitely want to spend some time on the Deep Learning Fundamentals series.
Now that we have our network, let’s move on to experiences. Recall that experiences from replay memory are what we’ll use to train our network. To create experiences, we create a class called Experience. This class will be used to create instances of Experience objects that will get stored in and sampled from replay memory later.
```python
Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward')
)
```
As you can see, we’re creating this class by calling namedtuple(), a Python function for creating tuples with named fields. Here, namedtuple() returns a new tuple subclass named Experience, as specified by our first argument. This new Experience class will be used to create tuple-like objects that have the fields state, action, next_state, and reward. Remember, these are exactly the fields we previously discussed that make up an individual experience.
Let’s show a quick example of an Experience object.
```python
e = Experience(2, 3, 1, 4)
```
We set e equal to an instance of the Experience class and pass in the arguments 2, 3, 1, and 4. Given the way we set up the Experience class, 2 will be the state of the experience, 3 will be the action, 1 will be the next_state, and 4 will be the reward.
```python
e
# Experience(state=2, action=3, next_state=1, reward=4)
```
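A few more behaviors of namedtuple are worth seeing: fields can be read by name or by position, instances are immutable, and a list of experiences can be transposed with zip(*...). Transposing is one common way to regroup a sampled batch so that all states sit together, all actions sit together, and so on. Treat it as an illustration here rather than the exact approach used later in the series.

```python
from collections import namedtuple

Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward')
)

e = Experience(2, 3, 1, 4)
print(e.action, e[1])  # 3 3  (by field name, or by position)

# Fields are read-only (e.action = 5 raises AttributeError);
# _replace() returns a modified copy instead.
e2 = e._replace(reward=10)
print(e2)  # Experience(state=2, action=3, next_state=1, reward=10)

# Transposing a list of experiences gives one Experience whose fields are tuples.
batch = [Experience(2, 3, 1, 4), Experience(5, 0, 6, 1)]
grouped = Experience(*zip(*batch))
print(grouped)
# Experience(state=(2, 5), action=(3, 0), next_state=(1, 6), reward=(4, 1))
```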
Now that we have our Experience class, let’s define our ReplayMemory class, which is where these experiences will be stored.
Recall that replay memory will have some set capacity. This capacity is the only parameter that needs to be specified when creating a ReplayMemory object.

```python
class ReplayMemory():
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.push_count = 0
```
We set capacity to whatever was passed in, and we also define a memory attribute equal to an empty list. The memory list is the structure that actually holds the stored experiences. We also create a push_count attribute, which we initialize to 0 and use to keep track of how many experiences we’ve added to memory.
Now, we need a way to store experiences in replay memory as they occur, so we define a push() function to do just that.
```python
    def push(self, experience):
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1
```
The push() function accepts an experience. When we push a new experience into replay memory, we first check whether the number of experiences already in memory is less than capacity. If it is, we append the experience to memory. If, on the other hand, the number of experiences in memory has reached capacity, we start overwriting memory from the front. The index push_count % capacity cycles through the list, so the oldest experiences are overwritten first. We then update push_count by incrementing it by 1.
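To watch the overwrite behavior, here is a small sketch that pushes five items into a memory with a capacity of three. Plain integers stand in for Experience objects, an assumption made only to keep the output short. Once memory is full, slot push_count % capacity is overwritten, so 0 and 1 are replaced while 2, the oldest survivor, is next in line.

```python
class ReplayMemory():
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.push_count = 0

    def push(self, experience):
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            # Overwrite slot push_count % capacity, which holds the oldest entry.
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

memory = ReplayMemory(capacity=3)
for experience in range(5):  # integers stand in for Experience objects
    memory.push(experience)

print(memory.memory)      # [3, 4, 2]
print(memory.push_count)  # 5
```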
Aside from storing experiences in replay memory, we also want to be able to sample experiences from replay memory. Remember, these sampled experiences will be what we use to train our DQN.
We define a sample() function, which returns a random sample of experiences. The number of randomly sampled experiences returned equals the batch_size parameter passed to the function.
```python
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
```
Finally, we have a can_provide_sample() function that returns a boolean telling us whether or not we can sample from memory. Recall that the size of a sample we obtain from memory will be equal to the batch size we use to train our network.
```python
    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size
```
For example, suppose we only have \(20\) experiences in replay memory and our batch size is \(50\). Then we will be unable to sample, because we do not have \(50\) experiences yet. Therefore, before we try to sample from memory, we’ll check whether it’s possible to do so by calling the can_provide_sample() function first. We’ll see this in practice later.
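Here’s a sketch of that guard in action, matching the example above: 20 experiences with a batch size of 50. Integers again stand in for Experience objects. Calling sample() without the guard fails, because random.sample() raises a ValueError when asked for more items than the population holds.

```python
import random

class ReplayMemory():
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.push_count = 0

    def push(self, experience):
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size

memory = ReplayMemory(capacity=100)
batch_size = 50

for experience in range(20):      # only 20 experiences so far
    memory.push(experience)
early_ok = memory.can_provide_sample(batch_size)

try:
    memory.sample(batch_size)     # what the guard protects against
    sample_failed = False
except ValueError:
    sample_failed = True

for experience in range(20, 60):  # now 60 experiences
    memory.push(experience)
late_ok = memory.can_provide_sample(batch_size)
batch = memory.sample(batch_size)

print(early_ok, sample_failed, late_ok)  # False True True
print(len(batch), len(set(batch)))       # 50 50: no experience is drawn twice
```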
Epsilon Greedy Strategy
Hopefully you remember from earlier in this series the concept of exploration versus exploitation. This has to do with the way our agent selects actions. Recall, our agent’s actions will either fall in the category of exploration, where the agent is just exploring the environment by taking a random action from a given state, or the category of exploitation, where the agent exploits what it’s learned about the environment to take the best action from a given state.
To get a balance of exploration and exploitation, we use what we previously introduced as an epsilon greedy strategy. With this strategy, we define an exploration rate called epsilon that we initially set to \(1\). This exploration rate is the probability that our agent will explore the environment rather than exploit it. With epsilon equal to \(1\), it is \(100\%\) certain that the agent will start out by exploring the environment.
As the agent learns more about the environment, though, epsilon will decay by some decay rate that we set, so exploration becomes less and less likely as the agent learns more and more about the environment. We’re now going to write an EpsilonGreedyStrategy class that puts this idea into code.
```python
class EpsilonGreedyStrategy():
    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay
```
Our EpsilonGreedyStrategy class has three attributes, start, end, and decay, which correspond to the starting, ending, and decay values of epsilon. These attributes are all initialized from the values passed in during object creation.
```python
    def get_exploration_rate(self, current_step):
        return self.end + (self.start - self.end) * \
            math.exp(-1. * current_step * self.decay)
```
We then have a single function, get_exploration_rate(), which requires the agent’s current_step to be passed in. This function returns the calculated exploration rate, based on the formula we covered in an earlier episode. Our agent will use the exploration rate to decide how to select its actions, either by exploring or by exploiting the environment.
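Written out, the exploration rate at step \(t\) is \(\epsilon_t = \epsilon_{end} + (\epsilon_{start} - \epsilon_{end})\, e^{-t \cdot decay}\). The sketch below evaluates get_exploration_rate() at a few steps. The settings start=1, end=0.01, and decay=0.001 are hypothetical values picked for illustration, not necessarily the ones we’ll train with. After \(1/decay\) steps, the gap above end has shrunk by a factor of \(e\).

```python
import math

class EpsilonGreedyStrategy():
    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step):
        return self.end + (self.start - self.end) * \
            math.exp(-1. * current_step * self.decay)

# Hypothetical settings, chosen only for illustration.
strategy = EpsilonGreedyStrategy(start=1, end=0.01, decay=0.001)
rates = {step: strategy.get_exploration_rate(step) for step in (0, 1000, 5000, 20000)}

for step, rate in rates.items():
    print(step, round(rate, 4))
# 0 1.0
# 1000 0.3742
# 5000 0.0167
# 20000 0.01
```

The rate never drops below end, so the agent always keeps a small chance of exploring.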
Reinforcement Learning Agent
Speaking of our agent, an Agent class is where we’re headed next. Our Agent class requires a strategy and num_actions. So, later, when we create an Agent object, we’ll already need an instance of the EpsilonGreedyStrategy class so that we can use that strategy to create our agent. The num_actions parameter is the number of possible actions the agent can take from a given state. In our cart and pole example, this number will always be two, since the agent can only ever choose to move left or right.
```python
class Agent():
    def __init__(self, strategy, num_actions):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
```
We initialize the agent’s strategy and num_actions accordingly, and we also initialize the current_step attribute to 0. This corresponds to the agent’s current step in the environment. The Agent class has a single function called select_action(), which requires a state and a policy_net.
```python
    def select_action(self, state, policy_net):
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            return random.randrange(self.num_actions) # explore
        else:
            with torch.no_grad():
                return policy_net(state).argmax(dim=1).item() # exploit
```
Remember, a policy network is the name we give to the deep Q-network that we train to learn the optimal policy.
Within this function, we first set rate equal to the exploration rate returned by the epsilon greedy strategy that was passed in when we created our agent. We then increment the agent’s current_step by 1.
Next, we check whether the exploration rate is greater than a randomly generated number between 0 and 1. If it is, we explore the environment by randomly selecting an action, either 0 or 1, corresponding to left or right moves.
If the exploration rate is not greater than the random number, we exploit the environment by selecting the action with the highest Q-value output by our policy network for the given state. Here, argmax(dim=1) picks the highest Q-value across the action dimension, and item() converts the resulting one-element tensor to a Python integer. We wrap the call in with torch.no_grad() before passing data to our policy_net to turn off gradient tracking, since we’re using the model for inference, not training.
During training, PyTorch keeps track of all the forward pass calculations that happen within the network so that it knows how to apply backpropagation later. Since we’re only using the model for inference at the moment, we’re telling PyTorch not to keep track of any forward pass calculations.
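To see both branches of select_action() and the effect of torch.no_grad(), here’s a self-contained sketch. It uses a hypothetical stand-in policy network: a single Linear layer whose bias always favors action 1. The state is a random 4-value tensor rather than an image. The strategy settings are extreme on purpose: start=end=0 forces exploitation, and start=end=1 forces exploration. The strategy is read through self.strategy, the attribute set in the Agent constructor.

```python
import math
import random

import torch
import torch.nn as nn

class EpsilonGreedyStrategy():
    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step):
        return self.end + (self.start - self.end) * \
            math.exp(-1. * current_step * self.decay)

class Agent():
    def __init__(self, strategy, num_actions):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions

    def select_action(self, state, policy_net):
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1
        if rate > random.random():
            return random.randrange(self.num_actions)  # explore
        else:
            with torch.no_grad():
                return policy_net(state).argmax(dim=1).item()  # exploit

random.seed(0)  # makes the exploration run repeatable

# Hypothetical stand-in policy network: zero weights and a bias of [0, 1],
# so the Q-value for action 1 is always the highest.
policy_net = nn.Linear(in_features=4, out_features=2)
with torch.no_grad():
    policy_net.weight.zero_()
    policy_net.bias.copy_(torch.tensor([0.0, 1.0]))

state = torch.rand(1, 4)  # one hypothetical state, with a batch dimension of 1

# start=end=0: the rate is 0, which is never greater than random.random().
greedy = Agent(EpsilonGreedyStrategy(start=0, end=0, decay=0.001), num_actions=2)
greedy_actions = [greedy.select_action(state, policy_net) for _ in range(5)]

# start=end=1: the rate is 1, which is always greater than random.random().
explorer = Agent(EpsilonGreedyStrategy(start=1, end=1, decay=0.001), num_actions=2)
explore_actions = [explorer.select_action(state, policy_net) for _ in range(1000)]

# Outside no_grad(), autograd records the forward pass; inside, it doesn't.
tracked = policy_net(state).requires_grad
with torch.no_grad():
    untracked = policy_net(state).requires_grad

print(greedy_actions, greedy.current_step)  # [1, 1, 1, 1, 1] 5
print(sorted(set(explore_actions)))
print(tracked, untracked)                   # True False
```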
Next time, we’ll pick up with the code for how we’ll be extracting and preprocessing the cart and pole input for our DQN.
Let me know in the video comments how you’re progressing so far, and please like this video to let us know you’re learning! Don’t forget to take the corresponding quiz to test your own understanding. See ya in the next one!