Managing the Environment - Reinforcement Learning Code Project
Welcome back to this series on reinforcement learning! In this episode, we’ll be continuing to develop the code project we’ve been working on to build a deep Q-network to master the cart and pole problem. We'll see how to manage the environment and process images that will be passed to our deep Q-network as input.
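Before we start, a quick note on setup: the original import cell isn't shown in this post, but based on what the snippets below reference, a reasonable set of imports looks like this (the aliases np, plt, and T, and the is_ipython flag, are assumptions matching how the code uses them):

import gym
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T

# Used by plot() to refresh the figure when running in a Jupyter notebook
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython: from IPython import display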
Update to Agent
Before getting to the new code, I have one quick update to share for the Agent class we developed previously. Agent will now require a device, as shown below.
class Agent():
    def __init__(self, strategy, num_actions, device):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device
    ...
device will be the device that we tell PyTorch to use for tensor calculations, i.e. a CPU or GPU. We’ll see exactly how to set this later in our main program, but for now, this explanation is all we need to know.
device gets initialized with whatever device we passed in, and we then use it on the tensor we return in our select_action() function, as shown below. We do this simply by calling to() on the tensor and passing in the device.
if rate > random.random():
    action = random.randrange(self.num_actions)
    return torch.tensor([action]).to(self.device) # explore
else:
    with torch.no_grad():
        return policy_net(state).argmax(dim=1).to(self.device) # exploit
Alright, that’s it for the update.
Now that we have created the classes for all of our crucial objects, like our Agent, we’re going to move on to creating a class that we’ll call CartPoleEnvManager.
This class will manage our cart and pole environment. It will wrap several of gym’s environment capabilities, and it will also give us some added functionality, like image preprocessing, for the environment images that will be given to our network as input.
To create a CartPoleEnvManager object, we just require a device be passed to the constructor. Just like the explanation we gave for the Agent class, this device will be the device that we’re telling PyTorch to use for tensor calculations.
class CartPoleEnvManager():
    def __init__(self, device):
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped
        self.env.reset()
        self.current_screen = None
        self.done = False
    ...
Within the class constructor, we first initialize the class’s device attribute with the device that was passed in, and we initialize the env attribute to be gym’s cart and pole environment. Calling unwrapped gives us access to behind-the-scenes dynamics of the environment that we wouldn’t have access to otherwise.
Since we must reset the environment to get an initial observation of it, we do this right after initializing env. We then set the current_screen attribute equal to None. current_screen will track the current screen of the environment at any given time, and when it’s set to None, that indicates that we’re at the start of an episode and have not yet rendered the screen of the initial observation. We’ll see more about this soon.
We then set the done attribute equal to False. done will track whether or not any taken action has ended an episode.
We now have a few wrapper functions that simply wrap a function with the same name used by gym. Specifically, we have functions that reset, close, and render the environment using gym’s corresponding functions.
def reset(self):
    self.env.reset()
    self.current_screen = None

def close(self):
    self.env.close()

def render(self, mode='human'):
    return self.env.render(mode)
As a reminder from when we covered these gym functions during the Frozen Lake project, we call reset() on the gym environment when we want the environment to be reset to a starting state, and reset() returns an initial observation from the environment. We call close() to close the environment when we’re finished with it, and we call render() on the environment to render the current state to the screen. We can also get a numpy array version of the rendered screen from this function as well.
The only one of our wrapper functions that does anything outside of calling gym’s function with the same name is reset(). You can see that we’re also setting the current_screen attribute there. When we reset the environment, we’re typically going to be at the end of an episode, and therefore, we want to set current_screen back to None since this indicates that we’re at the start of an episode and have not yet rendered the screen of the initial observation.
We’re wrapping these functions in this way so that later, in our main program, we’ll only have to deal with a CartPoleEnvManager, and not both this manager and an environment.
Instead, we’re encapsulating the environment functionality within our environment manager, so that our manager can completely manage the environment using these functions, as well as new functions that will be introduced in just a moment that an original gym environment wouldn’t have access to. This gives our main program a clean and consistent interface for interacting with the environment.
Number of actions available to agent
Moving on to our next function, num_actions_available() returns the number of actions available to an agent in the environment. In our cart and pole environment, at any given time, an agent will only have two actions available: move left or move right.
def num_actions_available(self):
    return self.env.action_space.n
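For CartPole, gym’s action space is Discrete(2), so this function returns 2. A quick hypothetical check (assuming the full class from this post is available):

em = CartPoleEnvManager(torch.device('cpu'))
print(em.env.action_space)        # Discrete(2)
print(em.num_actions_available()) # 2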
Taking an action in the environment
take_action() is a function that requires an action to be passed in. Using this action, we call step() on the environment, which will execute the given action taken by the agent in the environment.
def take_action(self, action):
    _, reward, self.done, _ = self.env.step(action.item())
    return torch.tensor([reward], device=self.device)
As you may recall from previous episodes where we used gym to set up our Frozen Lake environment, step() returns a tuple containing the environment observation, reward, whether or not the episode ended, and diagnostic info, all of which resulted from the agent executing that particular action.
For our purposes, we only care about the reward and whether or not the episode ended from taking the given action, so we set the reward variable accordingly, and also update the class’s done attribute with the boolean value of whether or not the episode ended by taking the given step.
Notice that we’re calling item() on the action we’re passing to step(). This is because the action that will be passed to this function in our main program will be a tensor. We’ll be consistently working with tensors throughout the main program. item() just returns the value of this tensor as a standard Python number, which is what step() expects. The take_action() function then returns the reward wrapped in a PyTorch tensor. We’re processing the reward in this way, by wrapping it in a tensor, to put it in the format that will be needed later on in our main program.
So, we have a tensor coming into the function, and a tensor coming out of it. This is how we keep the data type consistent in our main program. Here, we can see where the device comes into play, as we’re setting the device of this tensor to be the device that was passed in to the constructor.
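As a quick illustration of item(), here is what it does to an action tensor like the ones we’ll be passing around:

action = torch.tensor([1])
print(action.item())       # 1
print(type(action.item())) # <class 'int'> - a standard Python number that step() can use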
Starting an episode
Next, we have this function just_starting() that returns True when the current_screen attribute is None and returns False otherwise.
def just_starting(self):
    return self.current_screen is None
current_screen is set to None in the class constructor and also gets set to None when the environment is reset after ending an episode. So, if current_screen is None, that means we are at the starting state of an episode and haven’t yet rendered an initial observation from the environment.
Getting the state of the environment
Next we define the function get_state(). The point of this function is to return the current state of the environment in the form of a processed image of the screen. Remember, a deep Q-network takes states of the environment as input, and we previously mentioned that for our environment, states would be represented using screenshot-like images.
Actually, note that we will represent a single state in the environment as the difference between the current screen and the previous screen. This will allow the agent to take the velocity of the pole into account from one single image. So, a single state will be represented as a processed image of the difference between two consecutive screens. We’ll see in a moment what type of processing is being done.
def get_state(self):
    if self.just_starting() or self.done:
        self.current_screen = self.get_processed_screen()
        black_screen = torch.zeros_like(self.current_screen)
        return black_screen
    else:
        s1 = self.current_screen
        s2 = self.get_processed_screen()
        self.current_screen = s2
        return s2 - s1
We have two conditions we’re checking for in this function.
We check first to see if we are just starting or if we’re done with the episode. Remember, if we’re just starting, then the initial screen has not yet been rendered from the initial observation in the environment. If done == True, then that means the last action taken by the agent ended the episode.
We said that states would be represented as the difference between the last two screens. Well, when we’re at the start of a new episode, there is no last screen to compare to the current screen. So, we’re going to represent our starting state with a fully black screen. The fact we’re doing this will make more sense once we see some visual examples of states in a few minutes.
When we’re in the state that occurs after an agent has taken an action that ended the episode, we’ll represent this state with a fully black screen as well.
We do this by first calling get_processed_screen(), which returns the processed screen from the environment, and we assign this result to the current screen. We then create a fully black screen of the same shape as the current screen using torch.zeros_like(). We’ll explore the get_processed_screen() function more in a moment.
If we’re not just starting an episode, and we’re not ending it either, then we’re somewhere in the middle of an episode.
In this case we’ll take the difference between the current screen and the last screen and return this result.
In the code above, s1 stands for screen1 and is set to the current screen. s2 stands for screen2 and is set to the result of a new call to get_processed_screen(). We then update our current_screen to the value of s2. So now, s2 is our current screen and s1 is our previous screen, so we return the difference of these two screens to represent a single state.
Get processed screen dimensions
Next we have these two simple functions, get_screen_height() and get_screen_width().
def get_screen_height(self):
    screen = self.get_processed_screen()
    return screen.shape[2]

def get_screen_width(self):
    screen = self.get_processed_screen()
    return screen.shape[3]
These functions return the height and width of a processed screen by first getting a processed screen from the get_processed_screen() function (which we’re about to cover), and then indexing into the shape of the screen with a 2 to get the height, or with a 3 to get the width.
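To make the indexing concrete, a processed screen will be a BCHW tensor (batch x channels x height x width). Given the 40 x 90 resize we’ll see shortly, checking the dimensions on an instance em of our manager would look something like this:

screen = em.get_processed_screen()
print(screen.shape)           # torch.Size([1, 3, 40, 90])
print(em.get_screen_height()) # 40 - screen.shape[2]
print(em.get_screen_width())  # 90 - screen.shape[3]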
Processing the screen image
Now, we’ll move on to the get_processed_screen() function we’ve been referencing.
def get_processed_screen(self):
    screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
    screen = self.crop_screen(screen)
    return self.transform_screen_data(screen)
This function first renders the environment as an RGB array using the render() function and then transposes this array into the order of channels by height by width, which is what our PyTorch DQN will expect.
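For example, gym renders the CartPole screen as a height x width x channels array (400 x 600 x 3 by default), and the transpose reorders it to channels x height x width. A quick sketch with a stand-in array:

hwc_screen = np.zeros((400, 600, 3), dtype=np.uint8) # stand-in for render('rgb_array') output
chw_screen = hwc_screen.transpose((2, 0, 1))
print(chw_screen.shape) # (3, 400, 600)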
The transposed result is then cropped by passing it to the crop_screen() function, which we’ll cover next. We then pass the cropped screen to the transform_screen_data() function, which we’ll also cover in a moment; it does some final data conversion and rescaling on the cropped image. This transposed, cropped, and transformed version of the original screen returned by gym is what is returned by get_processed_screen().
Crop screen image
The crop_screen() function accepts a screen and will return a cropped version of it. We first get the height of the screen that was passed in, and then we strip off the top and bottom of the screen. We’ll see an example of a screen both before and after it’s been processed in a moment, and there you’ll see how there is a lot of plain white space at the top and bottom of the cart and pole environment, so we’re removing this empty space here.
def crop_screen(self, screen):
    screen_height = screen.shape[1]

    # Strip off top and bottom
    top = int(screen_height * 0.4)
    bottom = int(screen_height * 0.8)
    screen = screen[:, top:bottom, :]
    return screen
We set top equal to the value that corresponds to 40% of the screen_height. Similarly, we set bottom equal to the value that corresponds to 80% of the screen_height. Using these top and bottom values, we then take a slice of the screen starting from the top value down to the bottom value so that we’ve essentially stripped off the top 40% of the original screen and the bottom 20%.
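As a quick sanity check of the arithmetic, with gym’s default 400-pixel-tall CartPole render the crop works out like this:

screen_height = 400 # default CartPole render height

top = int(screen_height * 0.4)    # 160
bottom = int(screen_height * 0.8) # 320

# screen[:, 160:320, :] keeps rows 160 through 319,
# a 160-pixel-tall strip containing the cart and pole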
Convert and rescale screen image data
Our last image processing function is transform_screen_data(). This function accepts a screen whose data it will convert and rescale.
def transform_screen_data(self, screen):
    # Convert to float, rescale, convert to tensor
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)

    # Use torchvision package to compose image transforms
    resize = T.Compose([
        T.ToPILImage()
        ,T.Resize((40,90))
        ,T.ToTensor()
    ])

    return resize(screen).unsqueeze(0).to(self.device) # add a batch dimension (BCHW)
We first pass this screen to the numpy ascontiguousarray() function, which returns a contiguous array of the same shape and content as screen, meaning that all the values of this array will be stored sequentially next to each other in memory.
We’re also converting the individual pixel values into type float32 and rescaling all the values by dividing them each by 255. This is a common rescaling process that occurs during image processing for neural network input.
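Here’s a small illustration of that conversion and rescaling on a few made-up pixel values:

pixels = np.array([0, 128, 255], dtype=np.uint8) # raw pixel intensities
scaled = np.ascontiguousarray(pixels, dtype=np.float32) / 255
print(scaled) # [0. 0.5019608 1.] - all values now lie in [0, 1]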
Back in transform_screen_data(), we then convert this array to a PyTorch tensor using torch.from_numpy().
We then use torchvision’s Compose class to chain together several image transformations. We’ll call this composition resize. So, when a tensor is passed to resize, it will first be converted to a PIL image, then it will be resized to a 40 x 90 image. The resized PIL image is then transformed back to a tensor.
So, we pass our screen from above to resize and then add an extra batch dimension to the tensor by calling unsqueeze(). This result is then what is returned by transform_screen_data().
Alright, we’re now finished up with our CartPoleEnvManager class. Let’s now take a look visually at the results of all the image processing that we went over from this class.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
em = CartPoleEnvManager(device)
em.reset()
screen = em.render('rgb_array')

plt.figure()
plt.imshow(screen)
plt.title('Non-processed screen example')
plt.show()
This first screen is an example of what a non-processed screen looks like from the environment. We’re getting this screen just by setting up an instance of CartPoleEnvManager, calling reset() to get an initial observation, and then rendering the screen.
screen = em.get_processed_screen()

plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0), interpolation='none')
plt.title('Processed screen example')
plt.show()
After we do all the processing to this image simply from calling get_processed_screen(), we now get this processed image. There is actually some further processing we could do, like stripping some of the empty space from the right and left sides, but we’ll save that for a later episode as an experiment to try to improve performance.
screen = em.get_state()

plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0), interpolation='none')
plt.title('Starting state example')
plt.show()
Now, remember how a state from the environment is created from the difference between two processed images? And also how the starting state will always be a fully black screen based on our earlier discussion? Well, here we can see that example by calling get_state() on our environment manager.
for i in range(5):
    em.take_action(torch.tensor([1])) # take an arbitrary action

screen = em.get_state()

plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0), interpolation='none')
plt.title('Non starting state example')
plt.show()
If we want to see what a state looks like that is not a starting state, we can take some actions in the environment and call get_state() again to get this result.
Since we’re taking the difference between the current screen and the previous screen, most pixel values will become zero. The only ones that aren’t zero form the kind of highlight that we’re seeing here, which gives us an idea of where our cart and pole were in the previous screen, and where they have moved to now.
em.done = True
screen = em.get_state()

plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0), interpolation='none')
plt.title('Ending state example')
plt.show()

em.close()
Lastly, if we want to see the state of the environment after an episode has ended, we set done = True and call em.get_state() again, and we can see the fully black screen that we’d expect.
So, these states are exactly what will be passed to our DQN as input during training.
Now, we’re going to move on to a couple of quick utility functions we’ll have available to us during training so that we can plot our performance on a chart.
We’re creating a function called plot() that accepts values and a moving average period. This function will plot the duration of each episode, as well as the 100-episode moving average.
To solve cart and pole, the average reward must be greater than or equal to 195 over 100 consecutive episodes. Recall that our agent gets a reward of +1 for each step it takes that doesn’t end the episode. So, the duration of an episode measured in timesteps is exactly equivalent to the reward for that episode.
def plot(values, moving_avg_period):
    plt.figure(2)
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(values)
    plt.plot(get_moving_average(moving_avg_period, values))
    plt.pause(0.001)
    if is_ipython: display.clear_output(wait=True)
Within plot(), using matplotlib’s pyplot module, we set up the figure, give it a title, name the axes, and give it the values to plot, which in our case will be episode durations.
We’ll also want to plot the 100-episode moving average, so we do so by calling the function get_moving_average(), which accepts the moving average period and the values from which it will calculate the moving average.
def get_moving_average(period, values):
    values = torch.tensor(values, dtype=torch.float)
    if len(values) >= period:
        moving_avg = values.unfold(dimension=0, size=period, step=1) \
            .mean(dim=1).flatten(start_dim=0)
        moving_avg = torch.cat((torch.zeros(period-1), moving_avg))
        return moving_avg.numpy()
    else:
        moving_avg = torch.zeros(len(values))
        return moving_avg.numpy()
Within get_moving_average(), we first transform the values to a PyTorch tensor and then check to see if the length of the values is greater than or equal to the period. We do this because we can’t calculate a moving average of a data set when the data set is not at least as large as the period we want to calculate the moving average for.
For example, if we want to calculate the 100-period moving average of episode durations, then if we’ve only played 90 episodes, a 100-period moving average can’t be calculated.
If this condition is met, then we calculate the moving average by first calling unfold() on the tensor, which returns a tensor that contains all slices with a size equal to the period that was passed in (in our case, that is going to be 100). It does this along the zeroth dimension of the original values tensor.
This gives us a new tensor containing all slices of size 100 across the original values tensor. We then take the mean of each of these slices and flatten the tensor so that moving_avg is now a tensor containing all 100-period averages from the values that were passed in.
We then concatenate a tensor of zeros with a size equal to period-1 onto the front of this resulting tensor. This is to show that the moving average for the first period-1 values is zero, given the explanation we just gave a moment ago. So, if our period is 100, then the first 99 values of the moving_avg tensor will be 0, and each value afterwards will be the actual calculated 100-period moving average.
We then convert the moving_avg tensor to a numpy array and return this result.
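To see how all of this works on a small scale, here is the same unfold, mean, and concatenate pattern with a period of 3 instead of 100:

values = torch.tensor([1., 2., 3., 4., 5.])

slices = values.unfold(dimension=0, size=3, step=1)
print(slices)     # tensor([[1., 2., 3.], [2., 3., 4.], [3., 4., 5.]])

moving_avg = slices.mean(dim=1)
print(moving_avg) # tensor([2., 3., 4.])

print(torch.cat((torch.zeros(2), moving_avg))) # tensor([0., 0., 2., 3., 4.])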
Now, if our initial condition was not met, meaning the length of the values array that was initially passed in was not at least the period size, then we just return a numpy array of all zeros with a length equal to that of the values array that was passed in.
To show an example of this plot() function, we’ll pass in a numpy array that contains 300 random values between 0 and 1, and we’ll specify 100 as our moving average period.
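The call itself presumably looks something like this (assuming numpy is imported as np):

plot(np.random.rand(300), 100)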
The actual values are plotted in blue, and the orange line is the 100-period moving average across these values.
We can see that the 100-period moving average is 0 for the first 99 values, and then we get the first calculated moving average at the 100th value. This represents the average of the first 100 values in the array. If we skip over to the moving average at value 200, then the orange line at this point represents the average of the second 100 values, those between values 100 and 200.
When we train our network, we’ll be using this plot to show our performance over time.
Next time we’ll be picking up with developing our main program! We’ll take all of these classes and functions we’ve developed over the last couple episodes and see how they all come together in our main program to train our DQN.
Until then, please like this video to let us know you’re learning, and take the corresponding quiz to test your own understanding! Don’t forget about the deeplizard hivemind for exclusive perks and rewards. See ya in the next one!