前回、状態値及び行動値が共に離散化された環境下で Deep Q-Network を用いた強化学習を実装した. 今回は、状態値及び行動値が共に実数で表される環境下での強化学習を見てみる. これは、センサー付きのロボット制御等に応用可能な技術となる. 連続的な行動を学習する為に、replay buffer と actor-critic を用いた policy gradient method を採用した Deep Deterministic Policy Gradient (DDPG) を試してみる. actor に policy gradients を実装し、critic に DQN を実装して、 actor が決定した行動の善し悪しを critic が評価する. replay buffer に過去の状況・行動・報酬を保存しておき、それを用いて学習を行う. exploration policy には、 Ornstein-Uhlenbeck process を使用する.
Ornstein-Uhlenbeck process とは、ブラウン運動をする粒子の速度を表す様な物で、短期的には連続な変化をする. 以下がその実装.
class OUNoise:
def __init__(self, size, mu=0., theta=0.05, sigma=0.25):
self.size = size
self.mu = mu
self.theta = theta
self.sigma = sigma
self.reset()
def reset(self):
self.state = np.ones(self.size)*self.mu
def sample(self):
x = self.state
dx = self.theta * (self.mu - x) + self.sigma*np.random.randn(len(x))
self.state = x + dx
return self.state
replay buffer とは、リングバッファーのようなもので、ある一定期間まで遡った最新情報を保存しておき、バッチ学習に使用する.
class ReplayBuffer:
def __init__(self, buf_size):
self.buffer = deque(maxlen=buf_size)
self.experience = namedtuple("Experience",
field_names=['s0', 'a', 'r', 's1'])
def add(self, s0, a, r, s1):
e = self.experience(s0, a, r, s1)
def sample(self, batch_size):
return random.sample(self.buffer, k=batch_size)
def clear(self):
self.buffer.clear()
def __len__(self):
return len(self.memory)
critic には DQN を実装する.
states = Input(shape=(nstate,), name='states')
actions = Input(shape=(naction,), name='actions')
nnet = Dense(units=20, activation='relu')(states)
nnet = Add()([nnet, actions])
nnet = Dense(units=20, activation='relu')(nnet)
Q_values = Dense(units=1, name='q_values')(nnet)
model = Model(inputs=[states, actions], outputs=Q_values)
actor には policy network を実装する
states = Input(shape=(nstate,), name='states')
nnet = Dense(units=40, activation='relu')(states)
nnet = Dense(units=20, activation='relu')(nnet)
actions = Dense(units=naction, activation='tanh', name='actions')(nnet)
model = Model(inputs=states, outputs=actions)
以下に示す、DDPG algorithm を実装する.
for t = 1, T do
select action a[t] = mu(s[t]|theta^mu) + N[t]
action then get reward r[t] and observe new stat s[t+1]
store experience [s[t], a[t], r[t], s[t+1]] in replay buffer
set y[i] = r[i] + gamma Q'(s[i+1], mu'(s[i+1]|theta^mu')|theta^Q')
update critic by minimising the MSE loss
update actor policy with sampled policy gradient
soft update target networks
MountainCarContinuous-v0 に対して DDPG を実装したコードはこちらで、学習結果は以下の様になった.