Week 4, Day 3 (Deep Q-Learning with PyTorch)
Welcome to the third day (Week 4) of the McE-51069 course.
Deep Q-Learning with PyTorch
In this notebook, we will implement the deep Q-learning (DQN) algorithm using PyTorch to solve Atari games from OpenAI Gym. More specifically, we will be using the PongNoFrameskip-v4 environment. You can find more information on Deep Q-Learning in the original paper.
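Deep Q-learning trains a neural network $Q(s, a; \theta)$ to approximate the action-value function. At each update, a minibatch of transitions $(s, a, r, s')$ is sampled from a replay buffer $\mathcal{D}$ and the network is trained to minimize the mean squared error between its prediction and a bootstrapped target computed with a separate, periodically synced target network with parameters $\theta^{-}$:
$$L(\theta) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}}\Big[\big(r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta)\big)^2\Big]$$
where $\gamma$ is the discount factor and, for terminal transitions, the target is simply $r$. This is exactly the loss we will compute in the training loop below.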
Let's start by importing the necessary modules and packages.
import math
import time
import random
import gym
import gym.spaces
import cv2
import numpy as np
import matplotlib.pyplot as plt
import collections
from collections import namedtuple, deque
# pytorch related
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
We will be using several Gym wrappers to wrap our Pong env so that we do not have to implement all of the hard work described in the DQN paper ourselves, such as frame skipping, frame stacking, and image preprocessing.
# ref: https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On/blob/master/Chapter06/lib/wrappers.py
# ref: https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
class FireResetEnv(gym.Wrapper):
def __init__(self, env=None):
"""For environments where the user need to press FIRE for the game to start."""
super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def step(self, action):
return self.env.step(action)
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env=None, skip=4):
"""Return only every `skip`-th frame"""
super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = collections.deque(maxlen=2)
self._skip = skip
def step(self, action):
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
def reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
class ProcessFrame84(gym.ObservationWrapper):
def __init__(self, env=None):
super(ProcessFrame84, self).__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
def observation(self, obs):
return ProcessFrame84.process(obs)
@staticmethod
def process(frame):
if frame.size == 210 * 160 * 3:
img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
elif frame.size == 250 * 160 * 3:
img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
else:
assert False, "Unknown resolution."
img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
x_t = resized_screen[18:102, :]
x_t = np.reshape(x_t, [84, 84, 1])
return x_t.astype(np.uint8)
class ImageToPyTorch(gym.ObservationWrapper):
def __init__(self, env):
super(ImageToPyTorch, self).__init__(env)
old_shape = self.observation_space.shape
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]),
dtype=np.float32)
def observation(self, observation):
return np.moveaxis(observation, 2, 0)
class ScaledFloatFrame(gym.ObservationWrapper):
def observation(self, obs):
return np.array(obs).astype(np.float32) / 255.0
class BufferWrapper(gym.ObservationWrapper):
def __init__(self, env, n_steps, dtype=np.float32):
super(BufferWrapper, self).__init__(env)
self.dtype = dtype
old_space = env.observation_space
self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
old_space.high.repeat(n_steps, axis=0), dtype=dtype)
def reset(self):
self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
return self.observation(self.env.reset())
def observation(self, observation):
self.buffer[:-1] = self.buffer[1:]
self.buffer[-1] = observation
return self.buffer
def make_env(env_name):
env = gym.make(env_name)
env = MaxAndSkipEnv(env)
env = FireResetEnv(env)
env = ProcessFrame84(env)
env = ImageToPyTorch(env)
env = BufferWrapper(env, 4)
return ScaledFloatFrame(env)
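To see what each wrapper contributes, the short sketch below (using a throwaway check_env variable) builds the chain one wrapper at a time and prints how the observation space changes, from the raw 210x160x3 Atari frame down to the stacked 4x84x84 tensor that our network will consume.
check_env = gym.make("PongNoFrameskip-v4")
print("raw:             ", check_env.observation_space.shape)
check_env = MaxAndSkipEnv(check_env)
print("MaxAndSkipEnv:   ", check_env.observation_space.shape)
check_env = FireResetEnv(check_env)
print("FireResetEnv:    ", check_env.observation_space.shape)
check_env = ProcessFrame84(check_env)
print("ProcessFrame84:  ", check_env.observation_space.shape)
check_env = ImageToPyTorch(check_env)
print("ImageToPyTorch:  ", check_env.observation_space.shape)
check_env = BufferWrapper(check_env, 4)
print("BufferWrapper:   ", check_env.observation_space.shape)
check_env = ScaledFloatFrame(check_env)
print("ScaledFloatFrame:", check_env.observation_space.shape)
check_env.close()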
After initializing the wrapper classes and functions, environment creation is straightforward: just call make_env("PongNoFrameskip-v4").
env = make_env("PongNoFrameskip-v4")
The env's state is already stacked into 4 consecutive frames (4x84x84), thanks to our wrappers.
env.observation_space
The env has a discrete action space with 6 actions, namely:
env.action_space
env.unwrapped.get_action_meanings()
Let's visualize the preprocessed state as an image.
frame = env.reset()
# (4x84x84) -> (84x84x4)
transposed_frame = frame.transpose(1,2,0)
plt.imshow(transposed_frame)
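Matplotlib treats the four stacked channels above as RGBA, which is only a rough way to look at the state. Since the state is really four consecutive grayscale frames, it is clearer to plot each channel separately; here is a minimal sketch that reuses the frame array from the previous cell.
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i, ax in enumerate(axes):
    ax.imshow(frame[i], cmap="gray")  # i-th stacked frame, shape (84, 84)
    ax.set_title(f"frame {i}")
    ax.axis("off")
plt.show()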
Before moving on, let's watch a random agent play Pong. We cannot render directly on Colab, so we first need to install a few libraries by running the following cells.
!apt update && apt install -y xvfb
!pip install gym-notebook-wrapper
import gnwrapper
env = gnwrapper.Monitor(make_env("PongNoFrameskip-v4"),directory="pong-video-v1", force=True)
total_reward = 0
state = env.reset()
while True:
state, reward, done, _ = env.step(env.action_space.sample())
total_reward += reward
if done:
break
print("Total reward: %.2f" % total_reward)
env.display()
The following cell defines the deep Q-network model DQN using PyTorch.
class DQN(nn.Module):
def __init__(self, in_channels=4, n_actions=14):
"""
Initialize Deep Q Network
Args:
in_channels (int): number of input channels
n_actions (int): number of outputs
"""
super(DQN, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
self.fc4 = nn.Linear(7 * 7 * 64, 512)
self.head = nn.Linear(512, n_actions)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.fc4(x.view(x.size(0), -1)))
return self.head(x)
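Before training, it is worth confirming that the network accepts the 4x84x84 observations produced by our wrappers and returns one Q-value per action. Below is a minimal sanity-check sketch; test_net and dummy are throwaway names used only here.
test_net = DQN(in_channels=4, n_actions=env.action_space.n).to(device)
dummy = torch.zeros(1, 4, 84, 84, device=device)  # a batch containing one all-zero stacked observation
print(test_net(dummy).shape)  # expected: torch.Size([1, 6]) for Pong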
The following cell defines a ReplayBuffer class, which implements the experience replay technique: transitions are stored in a fixed-size buffer and sampled uniformly at random during training.
Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'done', 'new_state'])
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def __len__(self):
return len(self.buffer)
def append(self, experience):
self.buffer.append(experience)
def sample(self, batch_size):
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
return np.array(states), np.array(actions), np.array(rewards, dtype=np.float32), \
np.array(dones, dtype=np.uint8), np.array(next_states)
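As a quick illustration of how the buffer will be used during training, here is a throwaway sketch (demo_buffer and the demo_* names are only for this check) that fills the buffer with random data and samples a small batch.
demo_buffer = ReplayBuffer(capacity=100)
for _ in range(10):
    s = np.random.rand(4, 84, 84).astype(np.float32)       # fake current state
    s_next = np.random.rand(4, 84, 84).astype(np.float32)  # fake next state
    demo_buffer.append(Experience(s, np.random.randint(6), 0.0, False, s_next))
demo_states, demo_actions, demo_rewards, demo_dones, demo_next = demo_buffer.sample(4)
print(demo_states.shape, demo_actions.shape, demo_rewards.shape, demo_dones.shape, demo_next.shape)
# expected: (4, 4, 84, 84) (4,) (4,) (4,) (4, 4, 84, 84)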
# Hyperparameters
##################
# minibatch size
BATCH_SIZE = 32
# learning rate of Adam
LEARNING_RATE = 1e-4
# discount factor
GAMMA = 0.99
# initial exploration
EPS_START = 1.0
# final exploration
EPS_END = 0.02
# final exploration frame
EPS_DECAY = 100000
# target network update frequency
TARGET_UPDATE = 1000
# number of transitions to collect before learning starts
INITIAL_MEMORY = 10000
# replay memory capacity
REPLAY_MEMORY = 10000
# create Q-network
q_network = DQN(n_actions=env.action_space.n).to(device)
# create Target-network
target_network = DQN(n_actions=env.action_space.n).to(device)
target_network.load_state_dict(q_network.state_dict())
# adam optimizer
optimizer = optim.Adam(q_network.parameters(), lr=LEARNING_RATE)
# create memory object for experience replay
replay_buffer = ReplayBuffer(REPLAY_MEMORY)
total_rewards = []
total_reward = 0.0
frame_idx = 0
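During training, actions are chosen with an epsilon-greedy policy: with probability epsilon we take a random action, otherwise the greedy action according to the Q-network. Epsilon decays linearly from EPS_START to EPS_END over roughly the first EPS_DECAY frames, exactly as computed inside the training loop. The minimal sketch below (sched_frames and sched_eps are throwaway names) visualizes that schedule.
sched_frames = np.arange(0, 3 * EPS_DECAY)
sched_eps = np.maximum(EPS_END, EPS_START - sched_frames / EPS_DECAY)
plt.plot(sched_frames, sched_eps)
plt.xlabel("frame")
plt.ylabel("epsilon")
plt.title("Linear epsilon decay schedule")
plt.show()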
Run the code cell below to train our agent for total_frames frames.
total_frames = 1000000
state = env.reset()
while frame_idx < total_frames:
frame_idx += 1
#########################################################
# Epsilon-Greedy Policy action selection
epsilon = max(EPS_END, EPS_START - frame_idx / EPS_DECAY)
if np.random.random() < epsilon:
action = env.action_space.sample()
else:
state_a = np.array([state], copy=False)
state_v = torch.tensor(state_a).to(device)
q_vals_v = q_network(state_v)
_, act_v = torch.max(q_vals_v, dim=1)
action = int(act_v.item())
#########################################################
# do step in the environment
next_state, reward, done, info = env.step(action)
total_reward += reward
# experience replay
exp = Experience(state, action, reward, done, next_state)
replay_buffer.append(exp)
state = next_state
if done:
total_rewards.append(total_reward)
state = env.reset()
total_reward = 0.0
avg_reward = np.mean(total_rewards[-100:])
# print every 10 episodes
if len(total_rewards) % 10 == 0:
print(f"Frames: {frame_idx}, Episodes: {len(total_rewards)}, Average reward: {avg_reward}, Epsilon: {epsilon}")
if avg_reward > 19.5:
print(f"Environment is solved in {frame_idx} frames!")
break
if len(replay_buffer) < INITIAL_MEMORY:
continue
if frame_idx % TARGET_UPDATE == 0:
target_network.load_state_dict(q_network.state_dict())
optimizer.zero_grad()
batch = replay_buffer.sample(BATCH_SIZE)
states, actions, rewards, dones, next_states = batch
states_v = torch.tensor(states).to(device)
next_states_v = torch.tensor(next_states).to(device)
actions_v = torch.tensor(actions).to(device)
rewards_v = torch.tensor(rewards).to(device)
done_mask = torch.tensor(dones, dtype=torch.bool).to(device)  # bool mask of terminal transitions (avoids deprecated ByteTensor indexing)
state_action_values = q_network(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)  # Q(s, a) for the actions actually taken
next_state_values = target_network(next_states_v).max(1)[0]  # max_a' Q_target(s', a')
next_state_values[done_mask] = 0.0  # no bootstrapping from terminal states
next_state_values = next_state_values.detach()  # do not backprop through the target
expected_state_action_values = next_state_values * GAMMA + rewards_v  # Bellman target: r + gamma * max_a' Q_target(s', a')
loss_t = nn.MSELoss()(state_action_values, expected_state_action_values)
loss_t.backward()
optimizer.step()
print("Training Complete!")
torch.save(q_network.state_dict(), f"pong-{total_frames}-v1.pt")
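Before watching the agent, it is also worth plotting how the episode rewards evolved during training. Here is a minimal sketch using the total_rewards list collected above; the 100-episode moving average is the same quantity we monitored as avg_reward.
plt.figure(figsize=(8, 4))
plt.plot(total_rewards, label="episode reward")
if len(total_rewards) >= 100:
    moving_avg = np.convolve(total_rewards, np.ones(100) / 100.0, mode="valid")
    plt.plot(np.arange(99, 99 + len(moving_avg)), moving_avg, label="100-episode average")
plt.xlabel("episode")
plt.ylabel("reward")
plt.legend()
plt.show()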
Now it's time to test and watch our trained agent.
saved_file_path = f"pong-{total_frames}-v1.pt"
# save the trained Q-network weights (so this cell works on its own) and reload them for evaluation
torch.save(q_network.state_dict(), saved_file_path)
q_network = DQN(n_actions=env.action_space.n).to(device)
q_network.load_state_dict(torch.load(saved_file_path))
env = gnwrapper.Monitor(make_env("PongNoFrameskip-v4"),directory="pong-video-v2", force=True)
state = env.reset()
total_reward = 0
while True:
state_a = np.array([state], copy=False)
state_v = torch.tensor(state_a).to(device)
q_vals_v = q_network(state_v)
_, act_v = torch.max(q_vals_v, dim=1)
action = int(act_v.item())
state, reward, done, _ = env.step(action)
total_reward += reward
if done:
break
print("Total reward: %.2f" % total_reward)
env.display()