強化学習の初心者向けに、DQNでCart-Poleを動かしてみます。
【強化学習入門】DQNでCart-Poleを動かす(Google Colab+OpenAI Gym)
強化学習の入門として「Cart-Pole問題」というものがあります。Cart-Pole問題とは、棒(Pole)の支点を台車(Cart)に固定して、その台車を「左右」に動かすことでバランスをとり、棒が倒さないようにする問題のことです。

以前の記事では、そのCart-PoleをQ学習というアルゴリズムで動かしました。よければ下記リンクからチェックしてみてください。
▶【強化学習入門】Q学習でCart-Poleを動かす(Google Colab+OpenAI Gym)
一方で、今回の記事では、深層学習とQ学習を組み合わせたDQN(Deep Q Network)で動かしてみます。
今回も、実行環境はGoogle Colaboratory、Cart-PoleはOpenAI Gymのモデルを用います。
ブラウザでGoogle Colaboratoryをひらき、本記事のプログラムをコピー&ペーストすれば実行できます。
この記事のポイント
- 強化学習の入門者向け
- 強化学習を実装レベルで勉強(i.e. アルゴリズムの詳細解説は省略)
- Google Colaboratoryで動かせるコードをご紹介(i.e. ブラウザ上だけで完結可能)
パッケージのインストール
OpenAI Gymをインストールします。
!pip install gym
表示のためのパッケージもインストールします。
!apt update !apt install xvfb !pip install pyvirtualdisplay
DQNで動かしてみる
ソースコード
## OpenAI Gym import gym ## 表示用のパッケージ import base64 import io from gym.wrappers import Monitor from IPython import display from pyvirtualdisplay import Display ## その他必要なパッケージ import numpy as np import torch from torch import nn import torch.optim as optim ## QNetクラス class QNet(nn.Module): def __init__(self, num_states, dim_mid, num_actions): super().__init__() ## ネットワーク構造を設定(単純な全結合層) self.fc = nn.Sequential( nn.Linear(num_states, dim_mid), nn.ReLU(), nn.Linear(dim_mid, dim_mid), nn.ReLU(), nn.Linear(dim_mid, num_actions) ) def forward(self, x): x = self.fc(x) return x ## Brainクラス class Brain: def __init__(self, num_states, num_actions, gamma, r, lr): ## パラメータをセット self.num_states = num_states self.num_actions = num_actions self.eps = 1.0 # for epsilon greedy algorithm self.gamma = gamma self.r = r ## 計算アクセレータを指定(CPU or GPU) self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("self.device = ", self.device) ## Qネットワークを用意 self.q_net = QNet(num_states, 64, num_actions) self.q_net.to(self.device) ## 損失関数を用意(MSE=平均二乗誤差) self.criterion = nn.MSELoss() ## 最適化アルゴリズムを用意 # self.optimizer = optim.RMSprop(self.q_net.parameters(), lr=lr) self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr) ## Qネットワークを更新 ## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報 def updateQnet(self, obs_numpy, action, reward, next_obs_numpy): ## numpy型をテンソル型へ変換 obs_tensor = torch.from_numpy(obs_numpy).float() obs_tensor.unsqueeze_(0) # 先頭にバッチ数(=1)を格納する次元を増やす obs_tensor = obs_tensor.to(self.device) next_obs_tensor = torch.from_numpy(next_obs_numpy).float() next_obs_tensor.unsqueeze_(0) next_obs_tensor = next_obs_tensor.to(self.device) ## 最適化アルゴリズムの勾配をリセット self.optimizer.zero_grad() ## 推論(Q値を出力) self.q_net.train() # 訓練モード q = self.q_net(obs_tensor) ## 勾配を固定(勾配を計算しない) with torch.no_grad(): ## Q値の目標値を計算 self.q_net.eval() # 推論モード label = self.q_net(obs_tensor) next_q = self.q_net(next_obs_tensor) label[:, action] = reward + self.gamma*np.max(next_q.cpu().detach().numpy(), axis=1)[0] ## 損失を計算&誤差逆伝播 loss = self.criterion(q, label) loss.backward() self.optimizer.step() ## アクションを決定 ## 引数:観測情報、訓練フラグ def getAction(self, obs_numpy, is_training): if is_training and np.random.rand() < self.eps: action = np.random.randint(self.num_actions) else: ## numpy型をテンソル型へ変換 obs_tensor = torch.from_numpy(obs_numpy).float() obs_tensor.unsqueeze_(0) obs_tensor = obs_tensor.to(self.device) with torch.no_grad(): ## Q値が高いアクションを取得 self.q_net.eval() q = self.q_net(obs_tensor) action = np.argmax(q.cpu().detach().numpy(), axis=1)[0] ## epsを更新 if is_training and self.eps > 0.1: self.eps *= self.r return action ## Agentクラス class Agent: def __init__(self, num_states, num_actions, gamma, r, lr): ## Brainを用意 self.brain = Brain(num_states, num_actions, gamma, r, lr) ## Qネットワークを更新 ## 引数:観測情報、アクションのインデックス、報酬、アクション後の観測情報 def updateQnet(self, obs, action, reward, next_obs): self.brain.updateQnet(obs, action, reward, next_obs) ## アクションを決定 ## 引数:観測情報、訓練フラグ def getAction(self, obs, is_training): action = self.brain.getAction(obs, is_training) return action ## Environmentクラス class Environment: def __init__(self, num_episodes, max_consecutive_completion, max_step, gamma, r, lr): ## パラメータをセット self.num_episodes = num_episodes self.max_consecutive_completion = max_consecutive_completion self.max_step = max_step ## Cart-Poleの環境を用意 self.env = Monitor(gym.make('CartPole-v0'), './videos/', video_callable=(lambda ep: ep % 100 == 0), force=True) # 100エピソードごとの動画を保存 ## Agentを用意 num_states = self.env.observation_space.shape[0] # position, velocity, angle, angular velocity num_actions = self.env.action_space.n self.agent = Agent(num_states, num_actions, gamma, r, lr) ## 訓練 ## 引数:なし def train(self): ## 連続で成功したエピソード数を数えるカウンター consecutive_completion = 0 ## 指定するエピソード数でループ for episode in range(self.num_episodes): obs = self.env.reset() episode_reward = 0 ## 指定する最大ステップ数でループ for step in range(self.max_step): ## アクションを決定 action = self.agent.getAction(obs, is_training=True) ## アクション後の状態を取得 next_obs, _, is_done, _ = self.env.step(action) ## 報酬を付与 if is_done: if step < max_step - 1: reward = -1 consecutive_completion = 0 else: reward = 1 consecutive_completion += 1 else: reward = 0 episode_reward += reward ## Qネットワークを更新 self.agent.updateQnet(obs, action, reward, next_obs) ## 次のステップへ obs = next_obs ## 終了判定 if is_done: print('{0} Episode: Finished after {1} time steps with reward {2}'.format(episode, step+1, episode_reward)) break ## 早期終了判定 if consecutive_completion > self.max_consecutive_completion: print("It has completed {} consecutive episodes".format(consecutive_completion)) break ## 評価(Q値が最大となるアクションを選択して、Qテーブルの更新はしない) ## 引数:なし def evaluate(self): obs = self.env.reset() for step in range(self.max_step): ## Q値が最大となるアクションを選択 action = self.agent.getAction(obs, is_training=False) ## アクション後の状態を取得 next_obs, _, is_done, _ = self.env.step(action) ## 次のステップへ obs = next_obs ## 終了判定 if is_done: print('Evaluation: Finished after {} time steps'.format(step+1)) break ## 描画用の関数 def show_video(env): env.reset() for frame in env.videos: print("frame = ", frame) video = io.open(frame[0], 'r+b').read() encoded = base64.b64encode(video) display.display(display.HTML(data=""" <video alt="" controls> <source src="data:video/mp4;base64,{0}" type="video/mp4" /> </video> """.format(encoded.decode('ascii'))) ) ## 仮想のディスプレイを用意 virtual_display = Display() virtual_display.start() ## パラメータ num_episodes = 1000 max_consecutive_completion = 10 max_step = 200 gamma = 0.9 r = 0.99 lr = 0.001 ## 実行 cartpole_env = Environment(num_episodes, max_consecutive_completion, max_step, gamma, r, lr) cartpole_env.train() cartpole_env.evaluate() show_video(cartpole_env.env)
結果
学習終了後、上手くバランスをとれるようになりました。

さいごに
Google Colaboratory上で、OpenAI GymのCart-PoleをDQNで動かしてみました。
少しでも参考になれば幸いです。
以上です。
関連記事
Ad.
コメント