Deep Q-Learning - Combining Neural Networks and Reinforcement Learning
Deep Q-Learning and Deep Q-Networks explained
What's up, guys? In this post, we'll finally bring artificial neural networks into our discussion of reinforcement learning! Specifically, we'll be building on the concept of Q-learning we've discussed over the last few posts to introduce the concept of deep Q-learning and deep Q-networks (or DQNs). This will move us into the world of deep reinforcement learning. So, let's get to it!
Limitations of Q-learning with value iteration
From everything we've discussed over the last few posts, we should now be comfortable with the idea of Q-learning. Now, while it's true that the Q-learning algorithm we used to play Frozen Lake may do a pretty decent job in relatively small state spaces, its performance will drop off considerably when we work in more complex and sophisticated environments.
In Frozen Lake, for example, our environment was relatively simple, with only 16 states and 4 actions, giving us a total state-action space of just 16 × 4. This meant we only had 16 × 4 = 64 Q-values to update in the Q-table. Given that these Q-value updates occur iteratively, we can imagine that as our state space grows, the time it takes to traverse all those states and iteratively update the Q-values will grow as well.
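To make the bookkeeping concrete, here's a minimal sketch of a Frozen Lake-sized Q-table and a single tabular Q-learning update. The learning rate `alpha`, the discount `gamma`, and the sample transition are illustrative values we picked, and the action numbering is assumed to follow Gym's FrozenLake encoding; none of these come from the earlier posts.

```python
N_STATES, N_ACTIONS = 16, 4

# One row per state, one column per action: 16 x 4 = 64 Q-values, all starting at zero.
q_table = [[0.0] * N_ACTIONS for _ in range(N_STATES)]


def q_update(table, state, action, reward, next_state, alpha=0.1, gamma=0.99):
    """Blend the old Q-value with the Bellman target for one observed transition.

    alpha (learning rate) and gamma (discount) are illustrative values, not tuned settings.
    """
    target = reward + gamma * max(table[next_state])
    table[state][action] = (1 - alpha) * table[state][action] + alpha * target


print(sum(len(row) for row in q_table))  # 64

# Hypothetical transition: from state 14, action 2 (assumed to mean "right", as in Gym's
# FrozenLake encoding) reaches the goal, state 15, with reward 1.
q_update(q_table, state=14, action=2, reward=1.0, next_state=15)
print(q_table[14][2])  # 0.1
```

Each of those 64 cells has to be visited many times before its estimate settles, and that is exactly the cost that balloons as the state space grows.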
Think about a video game where a player has a large environment to roam around in. Each state in the environment would be represented by a set of pixels, and the agent may be able to take several actions from each state. The iterative process of computing and updating Q-values for each state-action pair in a large state space becomes computationally inefficient and perhaps infeasible due to the computational resources and time this may take.
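To get a feel for "large," here's a rough count of how many distinct states a pixel-based environment can have. The frame size (84 × 84) and the 256 gray levels are assumptions chosen for illustration; the point is only the order of magnitude compared with Frozen Lake's 16 states.

```python
import math

# Hypothetical preprocessed frame: 84 x 84 pixels, each taking one of 256 gray levels.
HEIGHT, WIDTH, LEVELS = 84, 84, 256

# Every distinct image is a distinct state, so there are LEVELS ** (HEIGHT * WIDTH) of them.
# Work in log10 rather than building an integer with roughly 17,000 digits.
log10_states = HEIGHT * WIDTH * math.log10(LEVELS)
print(f"about 10^{log10_states:.0f} possible states")  # versus 16 states in Frozen Lake
```

Most of those images would never occur in a real game, but a Q-table has no way of sharing what it learns between similar-looking states, so even the reachable ones are far too many to enumerate.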
So, what can we do when we want to step up our game from a simple toy environment, like Frozen Lake, to something more sophisticated? Well, rather than using value iteration to directly compute Q-values and find the optimal Q-function, we can instead use function approximation to estimate the optimal Q-function.
Well, you know what can do a pretty darn good job at approximating functions? Artificial Neural Networks!
We'll make use of a deep neural network to estimate the Q-values for each state-action pair in a given environment, and in turn, the network will approximate the optimal Q-function. The act of combining Q-learning with a deep neural network is called deep Q-learning, and a deep neural network that approximates a Q-function is called a deep Q-network, or DQN.
Let's break down how exactly this integration of neural networks and Q-learning works. We'll first discuss this at a high level, and then we'll get into all the nitty-gritty details.
Suppose we have some arbitrary deep neural network that accepts states from a given environment as input. For each given state input, the network outputs estimated Q-values for each action that can be taken from that state. The objective of this network is to approximate the optimal Q-function, and remember that the optimal Q-function will satisfy the Bellman optimality equation that we covered previously:
$$q_*(s,a) = E\left[R_{t+1} + \gamma \max_{a'} q_*(s',a')\right]$$
With this in mind, the loss from the network is calculated by comparing the output Q-values to the target Q-values from the right-hand side of the Bellman equation, and, as with any network, the objective here is to minimize this loss.
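Here's a minimal sketch of that comparison for a single transition. We assume a squared-error loss, which is one common choice (a Huber loss is another), and we only compare the Q-value of the action that was actually taken, since that's the only action whose reward we observed. The Q-value lists and the reward are hypothetical numbers standing in for network outputs.

```python
def bellman_target(reward, next_q_values, done, gamma=0.99):
    """Right-hand side of the Bellman optimality equation for one observed transition.

    If the episode ended, there is no next state to bootstrap from, so the target is the reward.
    """
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def squared_error_loss(q_values, action, target):
    """Squared error between the network's Q-value for the taken action and its target."""
    return (q_values[action] - target) ** 2


# Hypothetical network outputs for the current state and for the next state (4 actions each).
q_values = [0.2, 0.5, 0.1, 0.4]
next_q_values = [0.3, 0.7, 0.0, 0.6]

target = bellman_target(reward=1.0, next_q_values=next_q_values, done=False, gamma=0.9)
loss = squared_error_loss(q_values, action=1, target=target)
print(target, loss)
```

Here the target is 1 + 0.9 × 0.7 = 1.63, while the network currently predicts 0.5 for that action, so the loss is (0.5 − 1.63)² ≈ 1.28.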
After the loss is calculated, the weights within the network are updated via SGD and backpropagation, again, just like with any other typical network. This process is done over and over again for each state in the environment until we sufficiently minimize the loss and get an approximate optimal Q-function.
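The sketch below shows the update loop in the smallest possible form. It's an assumption-heavy stand-in: a single linear layer over one-hot Frozen Lake states rather than a deep network, so "backpropagation" reduces to one hand-derived gradient. A deep-learning framework would compute the same kind of gradient automatically through many layers. As is usual in Q-learning, the Bellman target is treated as a fixed number, so no gradient flows through it.

```python
import random

N_STATES, N_ACTIONS = 16, 4


def one_hot(state):
    """Encode a Frozen Lake state index as a length-16 one-hot feature vector."""
    x = [0.0] * N_STATES
    x[state] = 1.0
    return x


class LinearQ:
    """A single linear layer mapping a state's features to one Q-value per action."""

    def __init__(self, seed=0):
        rng = random.Random(seed)
        self.w = [[rng.uniform(-0.01, 0.01) for _ in range(N_STATES)] for _ in range(N_ACTIONS)]
        self.b = [0.0] * N_ACTIONS

    def forward(self, x):
        return [sum(wi * xi for wi, xi in zip(row, x)) + b for row, b in zip(self.w, self.b)]

    def sgd_step(self, state, action, reward, next_state, done, gamma=0.99, lr=0.1):
        x = one_hot(state)
        q = self.forward(x)
        # Bellman target, treated as a constant (no gradient flows through it).
        target = reward if done else reward + gamma * max(self.forward(one_hot(next_state)))
        error = q[action] - target
        # Loss = 0.5 * error**2, so dLoss/dw[action][j] = error * x[j] and dLoss/db[action] = error.
        for j in range(N_STATES):
            self.w[action][j] -= lr * error * x[j]
        self.b[action] -= lr * error
        return 0.5 * error ** 2  # loss measured before this step's update


model = LinearQ()
# Hypothetical transition repeated 20 times: state 14, action 2, reward 1, episode ends.
losses = [model.sgd_step(state=14, action=2, reward=1.0, next_state=15, done=True) for _ in range(20)]
print(f"loss went from {losses[0]:.4f} to {losses[-1]:.6f}")
```

Repeating one transition is only to show the loss shrinking; a real training run feeds in many different transitions gathered while the agent plays.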
So, take a second now to think about how we previously used the Bellman equation to compute and update Q-values in our Q-table in order to find the optimal Q-function. Now, with deep Q-learning, our network will make use of the Bellman equation to estimate the Q-values to find the optimal Q-function. So, we're still solving the same general problem here, just with a different algorithm. Rather than making use of value iteration to solve the problem, we're now using a deep neural network.
Alright, we should now have a general idea about what deep Q-learning is and what, at a high level, the deep Q-network is doing. Now, let's get a little more into the details about the network itself.
We discussed earlier that the network would accept states from the environment as input. Thinking of Frozen Lake, we could easily represent the states using a simple coordinate system from the grid of the environment and use this as input.
SFFF
FHFH
FFFH
HFFG
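As a sketch of that coordinate idea, the snippet maps Frozen Lake's flat state index (numbered row by row, 0 to 15) to a (row, column) pair on the 4 × 4 map (S = start, F = frozen, H = hole, G = goal). Scaling the coordinates into [0, 1] before feeding them to a network is our own assumption, and `coords_to_input` is a hypothetical helper name.

```python
FROZEN_LAKE_MAP = ["SFFF", "FHFH", "FFFH", "HFFG"]  # S=start, F=frozen, H=hole, G=goal
N_ROWS, N_COLS = len(FROZEN_LAKE_MAP), len(FROZEN_LAKE_MAP[0])


def state_to_coords(state):
    """Map a flat, row-major state index (0-15) to (row, column) grid coordinates."""
    return divmod(state, N_COLS)


def coords_to_input(state):
    """Hypothetical network input: the (row, col) pair scaled into [0, 1]."""
    row, col = state_to_coords(state)
    return [row / (N_ROWS - 1), col / (N_COLS - 1)]


for s in (0, 5, 15):
    row, col = state_to_coords(s)
    print(s, (row, col), FROZEN_LAKE_MAP[row][col], coords_to_input(s))
```

State 5, for example, lands on (1, 1), which is one of the holes.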
If we're in a more complex environment, though, like a video game, for example, then we'll use images as our input. Specifically, we'll use still frames that capture states from the environment as the input to the network.
The standard preprocessing done on the frames usually involves converting the RGB data into grayscale data, since the color in the image usually isn't going to affect the state of the environment. Additionally, we'll typically see some cropping and scaling as well, both to cut out unimportant information from the frame and to shrink the size of the image.
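Here's a minimal, dependency-free sketch of those three steps on a frame stored as nested lists of (R, G, B) tuples. The grayscale weights are the ITU-R BT.601 luminance coefficients; the crop box, the downscaling factor, and the tiny 10 × 8 frame are hypothetical. Real code would typically use an array or image library instead of Python lists.

```python
def to_grayscale(rgb):
    """Collapse each (R, G, B) pixel to one luminance value using ITU-R BT.601 weights."""
    return [[0.299 * r + 0.587 * g + 0.114 * b for (r, g, b) in row] for row in rgb]


def crop(img, top, bottom, left, right):
    """Keep rows [top, bottom) and columns [left, right)."""
    return [row[left:right] for row in img[top:bottom]]


def downscale(img, factor):
    """Shrink the image by averaging each factor x factor block of pixels."""
    h, w = len(img) // factor, len(img[0]) // factor
    return [
        [
            sum(img[i * factor + di][j * factor + dj] for di in range(factor) for dj in range(factor))
            / factor ** 2
            for j in range(w)
        ]
        for i in range(h)
    ]


def preprocess(rgb, crop_box, factor):
    """Grayscale, then crop, then downscale (the hypothetical pipeline described above)."""
    return downscale(crop(to_grayscale(rgb), *crop_box), factor)


# Hypothetical 10 x 8 all-white frame; the top 2 rows stand in for a score bar we don't need.
frame = [[(255, 255, 255)] * 8 for _ in range(10)]
out = preprocess(frame, crop_box=(2, 10, 0, 8), factor=2)
print(len(out), len(out[0]))  # 4 4
```

Cropping first and averaging second means the discarded score bar never leaks into the downscaled pixels.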
Now actually, rather than having a single frame represent a single input, we usually use a stack of a few consecutive frames to represent a single input. So, we would grab, say, four consecutive frames from the video game. We'd then do all the preprocessing we mentioned earlier on each of these four frames – the grayscale conversion, the cropping, and the scaling – and then we'd take the preprocessed frames and stack them on top of each other in the order in which they occurred in the game.
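One simple way to maintain that stack is a fixed-length queue that always holds the four most recent preprocessed frames, oldest first. This is a sketch: `FrameStack` is a hypothetical name, strings stand in for real frames, and padding the stack at the start of an episode by repeating the first frame is an assumed convention.

```python
from collections import deque


class FrameStack:
    """Keep the k most recent preprocessed frames, oldest first, as one network input."""

    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # Assumed convention: at the start of an episode, fill the stack with copies of the first frame.
        self.frames.clear()
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return list(self.frames)

    def step(self, new_frame):
        # Appending to a full deque with maxlen set silently drops the oldest frame.
        self.frames.append(new_frame)
        return list(self.frames)


stack = FrameStack(k=4)
print(stack.reset("f0"))  # ['f0', 'f0', 'f0', 'f0']
for t in range(1, 6):
    state = stack.step(f"f{t}")
print(state)  # ['f2', 'f3', 'f4', 'f5']
```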
We do this because a single frame usually isn't going to be enough for our network, or even for our human brains, to fully understand the state of the environment. For example, by looking at just a single frame from the Atari game Breakout, we can't tell whether the ball is coming down to the paddle or going up to hit the blocks. We also don't have any indication of the speed of the ball, or of which direction the paddle is moving in.
If we look at four consecutive frames, though, then we have a much better idea about the current state of the environment because we now do indeed have information about all of these things that we didn't know with just a single frame. So, the takeaway is that a stack of frames will represent a single input, which represents the state of the environment.
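As a toy illustration of that argument, the snippet below tracks a single bright "ball" pixel across four tiny synthetic frames. One frame yields only a position; comparing frames yields a displacement per frame, i.e. direction and speed. Everything here is hypothetical, and a DQN is never told to compute this: it has to learn useful motion features from the stacked pixels by itself.

```python
def ball_position(frame):
    """Return (row, col) of the single bright pixel that stands in for the ball."""
    for r, row in enumerate(frame):
        for c, value in enumerate(row):
            if value:
                return (r, c)
    return None


def make_frame(ball_row, ball_col, size=6):
    """Hypothetical 6 x 6 binary frame with the ball at the given pixel."""
    frame = [[0] * size for _ in range(size)]
    frame[ball_row][ball_col] = 1
    return frame


# Four consecutive toy frames: the ball moves down one row and right one column per frame.
stack = [make_frame(r, c) for r, c in [(0, 1), (1, 2), (2, 3), (3, 4)]]

positions = [ball_position(f) for f in stack]
# A single frame only gives a position; two or more give a displacement per frame.
velocity = (positions[-1][0] - positions[-2][0], positions[-1][1] - positions[-2][1])
print(positions[-1], velocity)  # (3, 4) (1, 1)
```

Since row indices grow downward, a velocity of (1, 1) means the ball is heading down and to the right, which is exactly the kind of information a single frame can't provide.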
Now that we know what the input is, the next thing to address is the inner workings of the network – the layers. Well, really, we're not going to see much more than we're used to seeing with any other network we've already covered. Like, seriously, many deep Q-networks are nothing more than a few convolutional layers, each followed by a non-linear activation function, and then a couple of fully connected layers, and that's it.
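To show how ordinary this stack is, here's a sketch that traces tensor shapes through one example configuration: the 84 × 84 × 4 input and three-convolution, 512-unit layout of DeepMind's 2015 Atari DQN, used purely as an illustration. The number of actions (4) is a hypothetical choice, and the shape formula assumes convolutions without padding.

```python
def conv_output_size(size, kernel, stride):
    """Spatial output size of a convolution with no padding."""
    return (size - kernel) // stride + 1


# Illustrative configuration (DeepMind's 2015 Atari DQN); N_ACTIONS is a hypothetical choice.
INPUT_SIZE, STACKED_FRAMES, N_ACTIONS = 84, 4, 4
conv_layers = [  # (out_channels, kernel, stride), each followed by a ReLU activation
    (32, 8, 4),
    (64, 4, 2),
    (64, 3, 1),
]

size, channels = INPUT_SIZE, STACKED_FRAMES
print(f"input: {channels} x {size} x {size}")
for out_channels, kernel, stride in conv_layers:
    size = conv_output_size(size, kernel, stride)
    channels = out_channels
    print(f"conv {kernel}x{kernel}, stride {stride}: {channels} x {size} x {size}")

flattened = channels * size * size
print(f"flatten: {flattened} -> fully connected 512 -> ReLU -> fully connected {N_ACTIONS} (Q-values)")
```

The spatial size goes 84 → 20 → 9 → 7, so the last convolution hands 64 × 7 × 7 = 3136 values to the fully connected layers.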
So, the layers used in a DQN are nothing new and nothing to be freaked out about.
If you do need a crash course or a refresher on convolutional neural networks or neural networks in general, then be sure to check out the Deep Learning Fundamentals series.
The last piece of the network to discuss is the output. The output layer will be a fully connected layer, and it will produce the Q-value for each action that can be taken from the given state that was passed as input.
For example, suppose in a given game, the actions we can take consist of moving left, moving right, jumping, and ducking. Then the output layer would consist of four nodes, each representing one of the four actions. The value produced from a single output node would be the Q-value associated with taking the action that corresponds to that node from the state that was supplied as input to the network.
We won't see the output layer followed by any activation function since we want the raw, non-transformed Q-values from the network.
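Here's a small sketch of such an output layer using the four example actions. The hidden-layer features, weights, and biases are made-up numbers. It shows that a plain linear layer can produce negative Q-values, and that adding an activation such as ReLU would squash every negative estimate to the same 0, erasing the difference between a bad action and a terrible one. Picking the largest Q-value is the greedy choice; during training an exploration strategy would sometimes pick another action.

```python
ACTIONS = ["left", "right", "jump", "duck"]  # one output node per action


def output_layer(features, weights, biases):
    """Fully connected output layer with no activation: one raw Q-value per action."""
    return [sum(w * f for w, f in zip(row, features)) + b for row, b in zip(weights, biases)]


def relu(values):
    return [max(0.0, v) for v in values]


# Hypothetical features from the last hidden layer and hypothetical output-layer parameters.
features = [0.5, -1.0, 2.0]
weights = [
    [0.2, 0.1, -0.3],   # left
    [0.4, -0.2, 0.5],   # right
    [-0.6, 0.3, -0.1],  # jump
    [0.1, 0.0, -0.4],   # duck
]
biases = [0.0, 0.1, 0.0, -0.2]

q_values = output_layer(features, weights, biases)  # roughly [-0.6, 1.5, -0.8, -0.95]
greedy_action = ACTIONS[max(range(len(q_values)), key=lambda i: q_values[i])]
print(q_values, greedy_action)  # greedy action: right
print(relu(q_values))  # an activation like ReLU would flatten every negative estimate to 0
```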
Alright, so we now know what deep Q-learning and deep Q-networks are, what these networks consist of, and how they work.
I know when I was originally learning about DQNs, I was prepared to learn some type of new and mysterious neural network and was pretty surprised when I found out that, really, there was nothing new about the network at all. Instead, we were just utilizing a CNN to solve a different type of problem. Let me know in the comments if you're feeling the same way.
Stay tuned because in the next post, we'll continue discussing DQNs and dissect the training process step-by-step to prepare ourselves to create and train our own DQN in code. Thanks for contributing to collective intelligence, and I'll see ya in the next one!