強化学習の初心者向けに、DQNでCart-Poleを動かしてみます。

【強化学習入門】DQNでCart-Poleを動かす(Google Colab+OpenAI Gym)

強化学習の入門として「Cart-Pole問題」というものがあります。Cart-Pole問題とは、棒(Pole)の支点を台車(Cart)に固定して、その台車を「左右」に動かすことでバランスをとり、棒が倒さないようにする問題のことです。

cartpole

以前の記事では、そのCart-PoleをQ学習というアルゴリズムで動かしました。よければ下記リンクからチェックしてみてください。

【強化学習入門】Q学習でCart-Poleを動かす(Google Colab+OpenAI Gym)

一方で、今回の記事では、深層学習とQ学習を組み合わせたDQN(Deep Q Network)で動かしてみます。

今回も、実行環境はGoogle Colaboratory、Cart-PoleはOpenAI Gymのモデルを用います。

ブラウザでGoogle Colaboratoryをひらき、本記事のプログラムをコピー&ペーストすれば実行できます。

この記事のポイント

  • 強化学習の入門者向け
  • 強化学習を実装レベルで勉強(i.e. アルゴリズムの詳細解説は省略)
  • Google Colaboratoryで動かせるコードをご紹介(i.e. ブラウザ上だけで完結可能)

パッケージのインストール

OpenAI Gymをインストールします。

!pip install gym

表示のためのパッケージもインストールします。

!apt update
!apt install xvfb
!pip install pyvirtualdisplay

DQNで動かしてみる

ソースコード

## OpenAI Gym
import gym

## 表示用のパッケージ
import base64
import io
from gym.wrappers import Monitor
from IPython import display
from pyvirtualdisplay import Display

## その他必要なパッケージ
import numpy as np

import torch
from torch import nn
import torch.optim as optim


## QNetクラス
class QNet(nn.Module):
	def __init__(self, num_states, dim_mid, num_actions):
		super().__init__()

		## ネットワーク構造を設定(単純な全結合層)
		self.fc = nn.Sequential(
			nn.Linear(num_states, dim_mid),
			nn.ReLU(),
			nn.Linear(dim_mid, dim_mid),
			nn.ReLU(),
			nn.Linear(dim_mid, num_actions)
		)

	def forward(self, x):
		x = self.fc(x)
		return x


## Brainクラス
class Brain:
	def __init__(self, num_states, num_actions, gamma, r, lr):
		## パラメータをセット
		self.num_states = num_states
		self.num_actions = num_actions
		self.eps = 1.0  # for epsilon greedy algorithm
		self.gamma = gamma
		self.r = r
		## 計算アクセレータを指定(CPU or GPU)
		self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
		print("self.device = ", self.device)
		## Qネットワークを用意
		self.q_net = QNet(num_states, 64, num_actions)
		self.q_net.to(self.device)
		## 損失関数を用意(MSE=平均二乗誤差)
		self.criterion = nn.MSELoss()
		## 最適化アルゴリズムを用意
		# self.optimizer = optim.RMSprop(self.q_net.parameters(), lr=lr)
		self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)

	## Qネットワークを更新
	## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報
	def updateQnet(self, obs_numpy, action, reward, next_obs_numpy):
		## numpy型をテンソル型へ変換
		obs_tensor = torch.from_numpy(obs_numpy).float()
		obs_tensor.unsqueeze_(0)	# 先頭にバッチ数(=1)を格納する次元を増やす
		obs_tensor = obs_tensor.to(self.device)

		next_obs_tensor = torch.from_numpy(next_obs_numpy).float()
		next_obs_tensor.unsqueeze_(0)
		next_obs_tensor = next_obs_tensor.to(self.device)

		## 最適化アルゴリズムの勾配をリセット
		self.optimizer.zero_grad()

		## 推論(Q値を出力)
		self.q_net.train()	# 訓練モード
		q = self.q_net(obs_tensor)

		## 勾配を固定(勾配を計算しない)
		with torch.no_grad():
			## Q値の目標値を計算
			self.q_net.eval()	# 推論モード
			label = self.q_net(obs_tensor)
			next_q = self.q_net(next_obs_tensor)

			label[:, action] = reward + self.gamma*np.max(next_q.cpu().detach().numpy(), axis=1)[0]
## 損失を計算&誤差逆伝播
		loss = self.criterion(q, label)
		loss.backward()
		self.optimizer.step()

## アクションを決定
	## 引数:観測情報、訓練フラグ
	def getAction(self, obs_numpy, is_training):
		if is_training and np.random.rand() < self.eps:
			action = np.random.randint(self.num_actions)
		else:
			## numpy型をテンソル型へ変換
			obs_tensor = torch.from_numpy(obs_numpy).float()
			obs_tensor.unsqueeze_(0)
			obs_tensor = obs_tensor.to(self.device)
			with torch.no_grad():
				## Q値が高いアクションを取得
				self.q_net.eval()
				q = self.q_net(obs_tensor)
				action = np.argmax(q.cpu().detach().numpy(), axis=1)[0]
		## epsを更新
		if is_training and self.eps > 0.1:
			self.eps *= self.r
		return action


## Agentクラス
class Agent:
	def __init__(self, num_states, num_actions, gamma, r, lr):
		## Brainを用意
		self.brain = Brain(num_states, num_actions, gamma, r, lr)

	## Qネットワークを更新
	## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報
	def updateQnet(self, obs, action, reward, next_obs):
		self.brain.updateQnet(obs, action, reward, next_obs)

	## アクションを決定
	## 引数:観測情報、訓練フラグ
	def getAction(self, obs, is_training):
		action = self.brain.getAction(obs, is_training)
		return action


## Environmentクラス
class Environment:
	def __init__(self, num_episodes, max_consecutive_completion, max_step, gamma, r, lr):
		## パラメータをセット
		self.num_episodes = num_episodes
		self.max_consecutive_completion = max_consecutive_completion
		self.max_step = max_step
		## Cart-Poleの環境を用意
		self.env = Monitor(gym.make('CartPole-v0'), './videos/', video_callable=(lambda ep: ep % 100 == 0), force=True)	# 100エピソードごとの動画を保存
		## Agentを用意
		num_states = self.env.observation_space.shape[0]	# position, velocity, angle, angular velocity
		num_actions = self.env.action_space.n
		self.agent = Agent(num_states, num_actions, gamma, r, lr)

	## 訓練
	## 引数:なし
	def train(self):
		## 連続で成功したエピソード数を数えるカウンター
		consecutive_completion = 0

## 指定するエピソード数でループ		
		for episode in range(self.num_episodes):
			obs = self.env.reset()
			episode_reward = 0

			## 指定する最大ステップ数でループ
			for step in range(self.max_step):
				## アクションを決定
				action = self.agent.getAction(obs, is_training=True)
				## アクション後の状態を取得
				next_obs, _, is_done, _ = self.env.step(action)
				## 報酬を付与
				if is_done:
					if step < max_step - 1:
						reward = -1
						consecutive_completion = 0
					else:
						reward = 1
						consecutive_completion += 1
				else:
					reward = 0
				episode_reward += reward
				## Qネットワークを更新
				self.agent.updateQnet(obs, action, reward, next_obs)
				## 次のステップへ
				obs = next_obs

				## 終了判定
				if is_done:
					print('{0} Episode: Finished after {1} time steps with reward {2}'.format(episode, step+1, episode_reward))
					break
			## 早期終了判定
			if consecutive_completion > self.max_consecutive_completion:
				print("It has completed {} consecutive episodes".format(consecutive_completion))
				break

	## 評価(Q値が最大となるアクションを選択して、Qテーブルの更新はしない)
	## 引数:なし
	def evaluate(self):
		obs = self.env.reset()
		
		for step in range(self.max_step):
			## Q値が最大となるアクションを選択
			action = self.agent.getAction(obs, is_training=False)
			## アクション後の状態を取得
			next_obs, _, is_done, _ = self.env.step(action)
			## 次のステップへ
			obs = next_obs

			## 終了判定
			if is_done:
				print('Evaluation: Finished after {} time steps'.format(step+1))
				break


## 描画用の関数
def show_video(env):
	env.reset()
	for frame in env.videos:
		print("frame = ", frame)
		video = io.open(frame[0], 'r+b').read()
		encoded = base64.b64encode(video)

		display.display(display.HTML(data="""
			<video alt="" controls>
			<source src="data:video/mp4;base64,{0}" type="video/mp4" />
			</video>
			""".format(encoded.decode('ascii')))
		)


## 仮想のディスプレイを用意
virtual_display = Display()
virtual_display.start()

## パラメータ
num_episodes = 1000
max_consecutive_completion = 10
max_step = 200
gamma = 0.9
r = 0.99
lr = 0.001

## 実行
cartpole_env = Environment(num_episodes, max_consecutive_completion, max_step, gamma, r, lr)
cartpole_env.train()
cartpole_env.evaluate()
show_video(cartpole_env.env)

結果

学習終了後、上手くバランスをとれるようになりました。

cartpole_dqn

さいごに

Google Colaboratory上で、OpenAI GymのCart-PoleをDQNで動かしてみました。

少しでも参考になれば幸いです。


以上です。

関連記事

【強化学習入門】Q学習でCart-Poleを動かす(Google Colab+OpenAI Gym)

Ad.