強化学習の初心者向けに、DQNでCart-Poleを動かしてみます。
【強化学習入門】DQNでCart-Poleを動かす(Google Colab+OpenAI Gym)
強化学習の入門として「Cart-Pole問題」というものがあります。Cart-Pole問題とは、棒(Pole)の支点を台車(Cart)に固定して、その台車を「左右」に動かすことでバランスをとり、棒が倒さないようにする問題のことです。

以前の記事では、そのCart-PoleをQ学習というアルゴリズムで動かしました。よければ下記リンクからチェックしてみてください。
▶【強化学習入門】Q学習でCart-Poleを動かす(Google Colab+OpenAI Gym)
一方で、今回の記事では、深層学習とQ学習を組み合わせたDQN(Deep Q Network)で動かしてみます。
今回も、実行環境はGoogle Colaboratory、Cart-PoleはOpenAI Gymのモデルを用います。
ブラウザでGoogle Colaboratoryをひらき、本記事のプログラムをコピー&ペーストすれば実行できます。
この記事のポイント
- 強化学習の入門者向け
- 強化学習を実装レベルで勉強(i.e. アルゴリズムの詳細解説は省略)
- Google Colaboratoryで動かせるコードをご紹介(i.e. ブラウザ上だけで完結可能)
パッケージのインストール
OpenAI Gymをインストールします。
!pip install gym
表示のためのパッケージもインストールします。
!apt update !apt install xvfb !pip install pyvirtualdisplay
DQNで動かしてみる
ソースコード
## OpenAI Gym
import gym
## 表示用のパッケージ
import base64
import io
from gym.wrappers import Monitor
from IPython import display
from pyvirtualdisplay import Display
## その他必要なパッケージ
import numpy as np
import torch
from torch import nn
import torch.optim as optim
## QNetクラス
class QNet(nn.Module):
def __init__(self, num_states, dim_mid, num_actions):
super().__init__()
## ネットワーク構造を設定(単純な全結合層)
self.fc = nn.Sequential(
nn.Linear(num_states, dim_mid),
nn.ReLU(),
nn.Linear(dim_mid, dim_mid),
nn.ReLU(),
nn.Linear(dim_mid, num_actions)
)
def forward(self, x):
x = self.fc(x)
return x
## Brainクラス
class Brain:
def __init__(self, num_states, num_actions, gamma, r, lr):
## パラメータをセット
self.num_states = num_states
self.num_actions = num_actions
self.eps = 1.0 # for epsilon greedy algorithm
self.gamma = gamma
self.r = r
## 計算アクセレータを指定(CPU or GPU)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("self.device = ", self.device)
## Qネットワークを用意
self.q_net = QNet(num_states, 64, num_actions)
self.q_net.to(self.device)
## 損失関数を用意(MSE=平均二乗誤差)
self.criterion = nn.MSELoss()
## 最適化アルゴリズムを用意
# self.optimizer = optim.RMSprop(self.q_net.parameters(), lr=lr)
self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
## Qネットワークを更新
## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報
def updateQnet(self, obs_numpy, action, reward, next_obs_numpy):
## numpy型をテンソル型へ変換
obs_tensor = torch.from_numpy(obs_numpy).float()
obs_tensor.unsqueeze_(0) # 先頭にバッチ数(=1)を格納する次元を増やす
obs_tensor = obs_tensor.to(self.device)
next_obs_tensor = torch.from_numpy(next_obs_numpy).float()
next_obs_tensor.unsqueeze_(0)
next_obs_tensor = next_obs_tensor.to(self.device)
## 最適化アルゴリズムの勾配をリセット
self.optimizer.zero_grad()
## 推論(Q値を出力)
self.q_net.train() # 訓練モード
q = self.q_net(obs_tensor)
## 勾配を固定(勾配を計算しない)
with torch.no_grad():
## Q値の目標値を計算
self.q_net.eval() # 推論モード
label = self.q_net(obs_tensor)
next_q = self.q_net(next_obs_tensor)
label[:, action] = reward + self.gamma*np.max(next_q.cpu().detach().numpy(), axis=1)[0]
## 損失を計算&誤差逆伝播
loss = self.criterion(q, label)
loss.backward()
self.optimizer.step()
## アクションを決定
## 引数:観測情報、訓練フラグ
def getAction(self, obs_numpy, is_training):
if is_training and np.random.rand() < self.eps:
action = np.random.randint(self.num_actions)
else:
## numpy型をテンソル型へ変換
obs_tensor = torch.from_numpy(obs_numpy).float()
obs_tensor.unsqueeze_(0)
obs_tensor = obs_tensor.to(self.device)
with torch.no_grad():
## Q値が高いアクションを取得
self.q_net.eval()
q = self.q_net(obs_tensor)
action = np.argmax(q.cpu().detach().numpy(), axis=1)[0]
## epsを更新
if is_training and self.eps > 0.1:
self.eps *= self.r
return action
## Agentクラス
class Agent:
def __init__(self, num_states, num_actions, gamma, r, lr):
## Brainを用意
self.brain = Brain(num_states, num_actions, gamma, r, lr)
## Qネットワークを更新
## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報
def updateQnet(self, obs, action, reward, next_obs):
self.brain.updateQnet(obs, action, reward, next_obs)
## アクションを決定
## 引数:観測情報、訓練フラグ
def getAction(self, obs, is_training):
action = self.brain.getAction(obs, is_training)
return action
## Environmentクラス
class Environment:
def __init__(self, num_episodes, max_consecutive_completion, max_step, gamma, r, lr):
## パラメータをセット
self.num_episodes = num_episodes
self.max_consecutive_completion = max_consecutive_completion
self.max_step = max_step
## Cart-Poleの環境を用意
self.env = Monitor(gym.make('CartPole-v0'), './videos/', video_callable=(lambda ep: ep % 100 == 0), force=True) # 100エピソードごとの動画を保存
## Agentを用意
num_states = self.env.observation_space.shape[0] # position, velocity, angle, angular velocity
num_actions = self.env.action_space.n
self.agent = Agent(num_states, num_actions, gamma, r, lr)
## 訓練
## 引数:なし
def train(self):
## 連続で成功したエピソード数を数えるカウンター
consecutive_completion = 0
## 指定するエピソード数でループ
for episode in range(self.num_episodes):
obs = self.env.reset()
episode_reward = 0
## 指定する最大ステップ数でループ
for step in range(self.max_step):
## アクションを決定
action = self.agent.getAction(obs, is_training=True)
## アクション後の状態を取得
next_obs, _, is_done, _ = self.env.step(action)
## 報酬を付与
if is_done:
if step < max_step - 1:
reward = -1
consecutive_completion = 0
else:
reward = 1
consecutive_completion += 1
else:
reward = 0
episode_reward += reward
## Qネットワークを更新
self.agent.updateQnet(obs, action, reward, next_obs)
## 次のステップへ
obs = next_obs
## 終了判定
if is_done:
print('{0} Episode: Finished after {1} time steps with reward {2}'.format(episode, step+1, episode_reward))
break
## 早期終了判定
if consecutive_completion > self.max_consecutive_completion:
print("It has completed {} consecutive episodes".format(consecutive_completion))
break
## 評価(Q値が最大となるアクションを選択して、Qテーブルの更新はしない)
## 引数:なし
def evaluate(self):
obs = self.env.reset()
for step in range(self.max_step):
## Q値が最大となるアクションを選択
action = self.agent.getAction(obs, is_training=False)
## アクション後の状態を取得
next_obs, _, is_done, _ = self.env.step(action)
## 次のステップへ
obs = next_obs
## 終了判定
if is_done:
print('Evaluation: Finished after {} time steps'.format(step+1))
break
## 描画用の関数
def show_video(env):
env.reset()
for frame in env.videos:
print("frame = ", frame)
video = io.open(frame[0], 'r+b').read()
encoded = base64.b64encode(video)
display.display(display.HTML(data="""
<video alt="" controls>
<source src="data:video/mp4;base64,{0}" type="video/mp4" />
</video>
""".format(encoded.decode('ascii')))
)
## 仮想のディスプレイを用意
virtual_display = Display()
virtual_display.start()
## パラメータ
num_episodes = 1000
max_consecutive_completion = 10
max_step = 200
gamma = 0.9
r = 0.99
lr = 0.001
## 実行
cartpole_env = Environment(num_episodes, max_consecutive_completion, max_step, gamma, r, lr)
cartpole_env.train()
cartpole_env.evaluate()
show_video(cartpole_env.env)
結果
学習終了後、上手くバランスをとれるようになりました。

さいごに
Google Colaboratory上で、OpenAI GymのCart-PoleをDQNで動かしてみました。
少しでも参考になれば幸いです。
以上です。
関連記事
Ad.
コメント