
Reinforcement Learning - Developing Intelligent Agents

Deep Learning Course 5 of 6 - Level: Advanced

Solving Cart and Pole with Deep Q-Network - Reinforcement Learning Code Project


Welcome back to this course on reinforcement learning! In this episode, we'll discuss how we can tune our current code project in order for our deep Q-network to solve the Cart and Pole environment.

We last left off having completed our program by training our deep Q-network on the Cart and Pole environment. There, we saw that the network's average reward consistently grew over the 1000-episode training duration, with its final 100-episode average reward reaching 90.

This growing reward indicates that the DQN was learning the environment over time. However, in order to solve the Cart and Pole environment, we must reach a 100-episode average reward of 195 or higher.

Solving the Cart and Pole Environment

After experimenting with both the network architecture and the hyperparameters, I found that tuning these specs alone was not sufficient for solving this environment with our current setup.

Sometimes in deep learning, we may need to do more than experiment with hyperparameters or network architecture to solve a problem. We may, for example, decide that we need a more sophisticated algorithm to solve this environment, or perhaps reconsider how we're processing and passing the environment state inputs to the model.

In this particular environment, there are multiple options for how we may choose to determine the state of the environment.

In our project, for example, we chose to work with the difference of the last two rendered screen frames, and then pass this pixel data as input to the network.

As another option, the authors of the original paper that used a DQN to solve Atari environments used a stack of the last four frames as their input.
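To make that idea concrete, here's a minimal sketch (not part of our project's code) of how the last four frames could be collected and stacked into a single network input, assuming each frame has already been preprocessed into a 2-dimensional tensor:

from collections import deque

import torch

# Keep only the four most recent preprocessed frames
frame_history = deque(maxlen=4)

def stacked_state(new_frame):
    # new_frame is assumed to be a preprocessed tensor of shape (H, W)
    frame_history.append(new_frame)
    while len(frame_history) < 4:
        # Pad with copies of the first frame at the start of an episode
        frame_history.append(new_frame)
    # Shape (4, H, W); adding a batch dimension gives (1, 4, H, W) for the network
    return torch.stack(tuple(frame_history), dim=0)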

The Cart and Pole OpenAI Gym environment, as another example, returns the environment states as 1-dimensional arrays (which we'll convert to tensors before passing them to the network as input) containing the current cart position, cart velocity, pole angle, and pole velocity at the tip.

env = gym.make('CartPole-v0')
state = env.reset()
print(state)

> [0.02193801 -0.02728747  0.04033063 -0.02978346]

As we can see, the states returned by the Gym environment are much simpler than the pixel data we have been working with in our project.
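Using the env object created above, we can also confirm the size of these states directly from the environment's observation space (the exact printed format varies across Gym versions):

print(env.observation_space)        # a 4-dimensional continuous (Box) space
print(env.observation_space.shape)  # (4,)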

As an experiment, I left the network architecture and hyperparameters unchanged from last time, and changed only the way we determined the environment's states and therefore network inputs. I used the states returned by the Gym environment rather than using image data.

With this change, our deep Q-network was able to solve the environment in just 156 episodes!

Now, this environment may still be solvable using our original image inputs, but again, we may need to use another algorithm or change exactly how we're processing the image inputs in order to do so.

I encourage you to continue experimenting with the environment and share your results in the comments! In the meantime, below are all of the code changes necessary in our current project to solve this environment with the new states.

Code Changes

To make our current code work with the states returned by the Gym environment, we need to make a few adjustments to our program.

Deep Q-Network

The first change that we need to make is to our DQN class.

Previously, as input to the DQN, we were passing the state of the environment as the pixel data resulting from the difference of the last two screen frames. Now, we're passing the state returned by the Gym environment as a 1-dimensional tensor with 4 elements: [cart_position, cart_velocity, pole_angle, pole_velocity_at_tip].

As such, we need to change the in_features in our first hidden layer, and we no longer need to flatten the input in the forward() function, as it's already flattened.

class DQN(nn.Module):
    def __init__(self, num_state_features):
        super().__init__()
         
        self.fc1 = nn.Linear(in_features=num_state_features, out_features=24)   
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=2)            

    def forward(self, t):
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t

Agent

In the select_action() function of our Agent class, when exploiting the environment, we return the best action for the given state, as determined by the policy_net. Now that the shape of our states has changed, we need to reshape the result returned by the network by adding one additional dimension to the tensor using the unsqueeze() function.

class Agent():
    def __init__(self, strategy, num_actions, device):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net):
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            action = random.randrange(self.num_actions)
            return torch.tensor([action]).to(self.device) # explore      
        else:
            with torch.no_grad():
                return policy_net(state) \
                    .unsqueeze(dim=0) \
                    .argmax(dim=1) \
                    .to(self.device) # exploit

This change is shown in the final return statement of the select_action() function.
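To make the shapes concrete, here is a rough walkthrough of the exploit branch using a made-up state (for illustration only, with policy_net instantiated as in the main program):

state = torch.tensor([0.02, -0.03, 0.04, -0.03]) # shape: (4,)
q_values = policy_net(state)                     # shape: (2,) - one Q-value per action
q_values = q_values.unsqueeze(dim=0)             # shape: (1, 2) - batch dimension added
action = q_values.argmax(dim=1)                  # shape: (1,) - index of the best action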

Environment Manager

In the CartPoleEnvManager class, we have a few changes to make.

Firstly, since we're no longer dealing with screen frames to determine the environment states, the class no longer has a current_screen attribute. Instead, we now have a current_state attribute, which is initialized to None.

class CartPoleEnvManager():
    def __init__(self, device):
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped
        self.env.reset()
        self.done = False
        self.current_state = None

    ...

Similarly, the CartPoleEnvManager's reset() function now resets its current_state to the initial observation returned by the Gym environment when it's reset, rather than resetting the screen as was implemented previously.

def reset(self):
    self.current_state = self.env.reset()

We also need to set the current_state in the take_action() function that is called whenever the agent takes an action in the environment.

Previously, when we called env.step() in the CartPoleEnvManager's take_action() function, we would not store the state returned by env.step() since we were not making use of it. Now, since we're using the returned state, rather than the screen frames, we store the state returned by step() in the class's current_state attribute.

def take_action(self, action):        
    self.current_state, reward, self.done, _ = self.env.step(action.item())
    return torch.tensor([reward], device=self.device)

Next, we need to change what is returned by CartPoleEnvManager's get_state() function.

Previously, this function would return the pixel data that resulted from the difference of the last two screen frames rendered in the environment. Now, we simply return the state of the environment that we've stored in the class's current_state attribute.

def get_state(self):
    if self.done:
        return torch.zeros_like(
            torch.tensor(self.current_state), device=self.device
        ).float()
    else:
        return torch.tensor(self.current_state, device=self.device).float()

Finally, we create a new function num_state_features(), which returns the number of features included in a state returned by the Gym environment. This is so that we can know the size of states that will be passed to the network as input.

def num_state_features(self):
    return self.env.observation_space.shape[0]

The only other change to the CartPoleEnvManager class is that we no longer need any of the functions we previously used for screen processing, so they can all be deleted. These include:

  • get_screen_height()
  • get_screen_width()
  • get_processed_screen()
  • crop_screen()
  • transform_screen_data()
  • just_starting()

Tensor Processing

In our extract_tensors() utility function, now that the states are 1-dimensional, we need to stack the states and next_states along a new axis, rather than concatenate (cat) them along an existing axis, as we were doing previously when our states were 4-dimensional (BCHW).

def extract_tensors(experiences):
    # Convert batch of Experiences to Experience of batches
    batch = Experience(*zip(*experiences))

    t1 = torch.stack(batch.state)
    t2 = torch.cat(batch.action)
    t3 = torch.cat(batch.reward)
    t4 = torch.stack(batch.next_state)

    return (t1,t2,t3,t4)

Check out this episode on stacking versus concatenating to explore this concept further.
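As a quick illustration of the difference (with arbitrary values):

s1 = torch.tensor([0.1, 0.2, 0.3, 0.4])
s2 = torch.tensor([0.5, 0.6, 0.7, 0.8])

print(torch.stack((s1, s2)).shape) # torch.Size([2, 4]) - new batch dimension
print(torch.cat((s1, s2)).shape)   # torch.Size([8])    - merged along the existing dimension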

Main Program

Given the changes we've gone over so far, we now just need to carry those changes over to our main program where the actual training takes place.

The first change in the main program is that when we initialize both the policy_net and the target_net, we no longer pass the screen height and width to determine the input size, but instead call our new em.num_state_features() function to determine it.

policy_net = DQN(em.num_state_features()).to(device)
target_net = DQN(em.num_state_features()).to(device)

Next, we've added a line to our main program's nested timestep loop to render() the environment to the screen. Previously, the call to render() was nested inside get_processed_screen(), one of the CartPoleEnvManager's screen processing functions that we no longer use.

for timestep in count():
    em.render()
    ...

Finally, we've added one line to the end of the episode loop to end the program if the network solves the environment by reaching a 100-episode average reward of 195 or higher.

for episode in range(num_episodes):
    ...
    if get_moving_average(100, episode_durations)[-1] >= 195:
        break

With this, our program will terminate once we reach a score greater than or equal to 195.
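For reference, get_moving_average() is the same plotting utility we've been using throughout the project. Below is a minimal sketch of how such a function can be written, assuming values is a plain Python list of numbers; the project's own implementation may differ in detail.

def get_moving_average(period, values):
    values = torch.tensor(values, dtype=torch.float)
    if len(values) >= period:
        # Average each sliding window of `period` values
        moving_avg = values.unfold(dimension=0, size=period, step=1) \
            .mean(dim=1).flatten(start_dim=0)
        # Pad the front with zeros so the output length matches the input
        moving_avg = torch.cat((torch.zeros(period - 1), moving_avg))
        return moving_avg.numpy()
    else:
        return torch.zeros(len(values)).numpy()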

Note that due to the inherent randomness involved in this environment and with neural networks in general, final results may vary. If you find that the model diverges before solving the environment, you can restart the program to reinitialize everything and try a second attempt.

That wraps up all of the code changes necessary to use our program with the states returned by the Gym environment!

Sources

  • Reinforcement Learning: An Introduction, Second Edition by Richard S. Sutton and Andrew G. Barto
    http://incompleteideas.net/book/RLbook2020.pdf
  • Playing Atari with Deep Reinforcement Learning by DeepMind Technologies
    https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
