[CS285] 5. Policy Gradients

Date:     Updated:

카테고리:

태그:


💡 이 글은 『2024 PseudoLab 전반기 강화학습팀』으로 진행되었으며, CS285 Fall 2023를 따라 정리했습니다.

1. Introduction

Policy Gradient는 가장 간단한 RL 알고리즘 중 하나이다. Policy gradient에서는 policy $\pi_{\theta}$를 최적화하게 된다. 여기서는 RL의 목표로부터 policy gradient의 식을 유도해보자. 먼저 RL의 궁극적인 목표는 다음과 같았다.

\[\theta ^ \ast = \arg \max _ \theta \mathbb{E}_ {\tau \sim p_\theta (\tau)} \left[\sum_t r(s_t, a_t)\right]\]

여기서 $J(\theta)$를 다음과 같이 정의하자. 마지막은 Monte Carlo estimation을 통해 추정한 것이다.

\[J(\theta) = \mathbb{E}_ {\tau \sim p_\theta (\tau)} \left[\sum_t r(s_t, a_t)\right] \approx \frac{1}{N} \sum_i \sum_t r(s_{i, t}, a_{i, t})\]

실제로는 $J(\theta)$를 구하는 것이 중요한 것이 아니라 이것의 gradient를 통해 policy를 업데이트해야 하므로 $\nabla_ \theta J(\theta)$를 구해야 한다. 이때 아래 log-gradient trick을 사용한다.

\[p_\theta (\tau) \nabla_ \theta \log p_\theta (\tau) = \nabla_ \theta p_\theta (\tau)\]

이를 이용하면 다음과 같이 정리할 수 있다. (이때 편의상 $\sum_t r(s_t, a_t) = r(\tau)$라고 하자.)

\[\begin{aligned} \nabla_ \theta J(\theta) &= \nabla_ \theta \mathbb{E}_ {\tau \sim p_\theta (\tau)} \left[r(\tau)\right] \\ &= \nabla _\theta \int p_\theta (\tau) r(\tau) d\tau \\ &= \int \nabla_ \theta p_\theta (\tau) r(\tau) d\tau \\ &= \int p_\theta (\tau) \nabla_ \theta \log p_\theta (\tau) r(\tau) d\tau \\ &= \mathbb{E}_ {\tau \sim p_\theta (\tau)} \left[\nabla_ \theta \log p_\theta (\tau) r(\tau)\right] \end{aligned}\]

한편 $p_\theta(\tau)$를 풀어 쓰면 다음과 같다.

\[\begin{aligned} p_\theta (\tau) &= p(s_1) \prod _{t=1} ^T \pi_\theta (a_t \vert s_t) p(s_{t+1} \vert s_t, a_t) \\ \log p_\theta (\tau) &= \log p(s_1) + \sum _{t=1} ^T \left[ \log \pi_\theta (a_t \vert s_t) + \log p(s_{t+1} \vert s_t, a_t) \right] \\ \end{aligned}\]

$\nabla_\theta \log p_\theta (\tau)$를 구하면 $\theta$와 관련있는 항만 남게 되므로 $\nabla_\theta J(\theta)$를 다음과 같이 정리할 수 있다.

\[\nabla_ \theta J(\theta) = \mathbb{E}_ {\tau \sim p_\theta (\tau)} \left[ \left(\sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_t \vert s_t) \right) \left(\sum _{t=1} ^T r(s_t, a_t) \right) \right]\]

Monte Carlo estimation을 통해 이를 추정하면 다음과 같다.

\[\nabla_ \theta J(\theta) \approx \frac{1}{N} \sum _{i=1} ^N \left( \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \right) \left( \sum _{t=1} ^T r(s_{i, t}, a_{i, t}) \right)\]

이를 통하여 policy를 업데이트한다.

\[\theta \leftarrow \theta + \alpha \nabla_ \theta J(\theta)\]

직관적으로 이 과정은 reward가 높은 쪽으로 policy를 업데이트하는 것으로 해석할 수 있다. 이러한 알고리즘의 예시로 1990년대에 등장한 REINFORCE algorithm이 있으며, 다음과 같이 정리할 수 있다.

  1. Sample $\lbrace \tau^ i \rbrace$ from $\pi_\theta(a_t \vert s_t)$ (run the policy)
  2. $\nabla_ \theta J(\theta) \approx \frac{1}{N} \sum _ i \left( \sum _ t \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \right) \left( \sum _ t r(s_{i, t}, a_{i, t}) \right)$
  3. $\theta \leftarrow \theta + \alpha \nabla_ \theta J(\theta)$

그러나 실제로는 위 알고리즘은 잘 작동하지 않는다. 이러한 이유를 이해하고 이를 개선하기 위한 방법에 대해 논의해보자.


2. Understanding Policy Gradients

2.1. Formalization of trial-and-error

먼저 $\nabla_ \theta J(\theta)$ 식에서 등장하는 $\nabla _ \theta \log \pi_\theta (a_t \vert s_t)$를 분석해보자.

\[\nabla_ \theta J(\theta) \approx \frac{1}{N} \sum _{i=1} ^N \left( \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \right) \left( \sum _{t=1} ^T r(s_{i, t}, a_{i, t}) \right)\]

이는 action에 대한 log probability의 gradient이다. 이는 supervised learning에서의 log likelihood maximization과 유사하다.

\[\nabla_ \theta J_ {\text{ML}} (\theta) \approx \frac{1}{N} \sum _ {i=1} ^N \left( \sum _ {t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \right)\]

즉, policy gradient를 통한 policy update 과정은 일종의 log-likelihood maximization 과정으로 볼 수 있다. 다만, supervised learning과는 다르게 trajectory가 나쁠 수도 있고 좋을 수도 있기에 reward에 따라 policy가 positive/negative 두 방향 중 어떤 방향으로도 업데이트가 가능하다. 즉, 일종의 weighted log-likelihood maximization으로 해석할 수 있다. 이는 trial-and-error를 log-likelihood maximization에 반영한 식으로 볼 수 있다.

2.2. Partial Observability

실제로는 agent가 partial observability를 가지는 경우가 많다. 이 경우, agent는 state $s$가 아닌 observation $o$를 관측하게 된다. 즉 POMDP의 경우에도 policy gradient를 간단히 다음과 같이 적용할 수 있다.

\[\nabla_ \theta J(\theta) \approx \frac{1}{N} \sum _{i=1} ^N \left( \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert o_{i, t}) \right) \left( \sum _{t=1} ^T r(s_{i, t}, a_{i, t}) \right)\]

2.3. Problem of Policy Gradients

지금까지 formalization을 마쳤으니, 다시 돌아와서 이 알고리즘의 문제가 무엇인지 알아보자. 아래 그래프의 초록색은 $\tau$에 따른 reward, 파란색은 policy distribution을 나타낸다.

위와 같은 상황이라면 big negative reward sample을 피하는 방향으로 policy가 업데이트될 것이다. 동일한 케이스로 아래 그래프를 살펴보자.

위의 그래프의 reward에 상수만큼의 값을 더해준 뒤, 이를 통해 update한 모습이다. 이전보다 policy distribution이 크게 변하지 않았다. 이처럼 REINFORCE algorithm은 reward를 어떻게 정의하느냐에 따라 distribution의 차이가 크고 따라서 high variance를 가진다. 이를 줄이기 위해 샘플 수를 늘릴 수 있으나, 이는 high variance를 완전히 해결하지는 못한다. 이 문제를 해결하기 위한 advanced policy gradient method를 알아보자.


3. Reducing Variance

3.1. Causality

먼저 causality라는 개념을 알아야 한다. 이는 $t \lt t^ \prime$일 때, $t^ \prime$ 시점의 policy는 $t$ 시점의 reward에 영향을 미치지 못한다는 것이다. 찬찬히 곱씹어보면 항상 참인 문장임을 알 수 있다. 이는 Markov property와는 다른 개념이다. Markov property는 세팅에 따라 거짓일 수 있지만 causality는 항상 참이다. 조금 더 고급스럽게 설명하자면 past reward가 present decision에 independent하다는 특징이다.

이를 사용하면 $\nabla _ \theta J(\theta)$를 다음과 같이 정리할 수 있다.

\[\begin{aligned} \nabla_ \theta J(\theta) &\approx \frac{1}{N} \sum _{i=1} ^N \left( \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \right) \left( \sum _{t=1} ^T r(s_{i, t}, a_{i, t}) \right) \\ &= \frac{1}{N} \sum _{i=1} ^N \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \left( \sum _{t^ \prime = t} ^T r(s_{i, t^ \prime}, a_{i, t^ \prime}) \right) \\ \end{aligned}\]

이것이 가능한 이유는 $t$ 시점의 policy $\pi_\theta(a_{i, t} \vert s_{i, t})$는 $t^ \prime \lt t$ 시점의 reward $r(s_{i, t^ \prime}, a_{i, t^ \prime})$에 영향을 미치지 않기 때문이다. Expectation value 관점에서는 여전히 unbiased estimator이다. 이제 term이 많이 사라지면서, variance가 감소하는 효과가 있다.

한편 이때 reward to go를 $\hat{Q}_ {i, t}$로 나타낼 수 있다. 이는 Q function과 동일한 것이다.

\[\hat{Q}_ {i, t} = \sum _{t^ \prime = t} ^T r(s_{i, t^ \prime}, a_{i, t^ \prime})\]

3.2. Baselines

평균보다 잘했다면 probability를 올리고, 그렇지 않다면 낮추는 것을 나타내기 위해 baseline $b$를 설정한다.

\[b = \frac{1}{N} \sum_ {i=1} ^ N r(\tau)\]

그리고 $\nabla_ \theta J(\theta)$를 다음과 같이 나타낸다.

\[\nabla_ \theta J(\theta) \approx \frac{1}{N} \sum _{i=1} ^N \log p_\theta (\tau) \left(r(\tau) - b \right)\]

이렇게 써도 상관없는 이유는 baseline을 빼거나 더하는 것은 unbiased이기 때문이다.

\[\begin{aligned} \mathbb{E} \left[ \nabla_ \theta \log p_\theta (\tau) b \right] &= \int p_\theta (\tau) \nabla_ \theta \log p_\theta (\tau) b d\tau \\ &= \int \nabla_ \theta p_\theta (\tau) b d\tau \\ &= b \nabla_ \theta \int p_\theta (\tau) d\tau \\ &= b \nabla_ \theta 1 \\ &= 0 \end{aligned}\]

실제로 이처럼 평균을 사용한 baseline은 best baseline이 아니지만, 이정도면 괜찮다. 실제로 optimal baseline도 유도할 수 있지만, 실제로는 잘 사용되지 않는다. 이렇게 baseline을 사용하면 variance를 줄일 수 있다.


4. Off-Policy Policy Gradients

Policy gradient는 on-policy 방법이다. 이는 policy가 업데이트될 때마다 $p_\theta(\tau)$가 바뀌기 때문이다. 하지만 실제로 neural network는 1회의 gradient step마다 아주 조금씩만 바뀐다. 이전의 sample을 버리는 것은 너무나도 아까운 일이다. 이를 해결하기 위해 off-policy를 사용할 수 있으며, 이를 위해 importance sampling을 사용할 수 있다. Importance sampling이란 분포 $p(x)$ 대신 $q(x)$에서 샘플링하기 위해서는 importance weight에 따라 sample의 가중치를 조절해야 한다는 것이다.

\[\begin{aligned} \mathbb{E} _{x \sim p(x)} [f(x)] &= \int p(x) f(x) dx \\ &= \int \frac{q(x)}{q(x)} p(x) f(x) dx \\ &= \int q(x) \frac{p(x)}{q(x)} f(x) dx \\ &= \mathbb{E} _{x \sim q(x)} \left[ \frac{p(x)}{q(x)} f(x) \right] \end{aligned}\]

따라서 만약 distribution이 $p_\theta(\tau)$에서 $p_{\theta^\prime}(\tau)$로 바뀐다면, importance weight를 사용하여 다음과 같이 여전히 $p_\theta(\tau)$에서 샘플링한 샘플을 사용할 수 있다.

\[J(\theta^ \prime ) = \mathbb{E} _{\tau \sim p_{\theta}(\tau)} \left[ \frac{p_{\theta^ \prime}(\tau)}{p_{\theta}(\tau)} r(\tau) \right]\]

Gradient를 구하면 다음과 같이 된다.

\[\begin{aligned} \nabla_ {\theta^ \prime} J(\theta^ \prime) &= \mathbb{E} _{\tau \sim p_{\theta}(\tau)} \left[ \frac{\nabla_ {\theta^ \prime} p_{\theta^ \prime}(\tau)}{p_{\theta}(\tau)} r(\tau) \right] \\ &= \mathbb{E} _{\tau \sim p_{\theta}(\tau)} \left[ \frac{p_{\theta^ \prime}(\tau)}{p_{\theta}(\tau)} \nabla_ {\theta^ \prime} \log p_{\theta^ \prime}(\tau) r(\tau) \right] \\ \end{aligned}\]

이때 importance weight을 조금 더 정리해보면 다음과 같다.

\[\begin{aligned} \frac{p_{\theta^ \prime}(\tau)}{p_{\theta}(\tau)} &= \frac{p(s_1) \prod _{t=1} ^T \pi_{\theta^ \prime}(a_t \vert s_t) p(s_{t+1} \vert s_t, a_t)}{p(s_1) \prod _{t=1} ^T \pi_{\theta}(a_t \vert s_t) p(s_{t+1} \vert s_t, a_t)} \\ &= \frac{\prod _{t=1} ^T \pi_{\theta^ \prime}(a_t \vert s_t)}{\prod _{t=1} ^T \pi_{\theta}(a_t \vert s_t)} \\ \end{aligned}\]

즉, importance weight는 policy의 비율로만 정의된다. 이를 통해 off-policy를 사용할 수 있게 되었다. 이를 정리하여 나타내면 $\nabla_ {\theta^ \prime} J(\theta^ \prime)$는 다음과 같다.

\[\nabla_ {\theta^ \prime} J(\theta^ \prime) = \mathbb{E} _{\tau \sim p_{\theta}(\tau)} \left[ \left( \prod _{t=1} ^T \frac{\pi_{\theta^ \prime}(a_t \vert s_t)}{\pi_{\theta}(a_t \vert s_t)} \right) \left( \sum _{t=1} ^T \nabla_ {\theta^ \prime} \log \pi_ {\theta^ \prime} (a_t \vert s_t) \right) \left( \sum _{t=1} ^T r(s_t, a_t) \right) \right]\]

Causality를 적용하면 다음과 같이 정리할 수 있다.

\[\nabla_ {\theta^ \prime} J(\theta^ \prime) = \mathbb{E} _{\tau \sim p_{\theta}(\tau)} \left[ \sum_ {t=1} ^T \nabla_ {\theta^ \prime} \log \pi_ {\theta^ \prime} (a_t \vert s_t) \left( \sum _{t^ \prime = 1} ^ t \frac{\pi_{\theta^ \prime}(a_{t^ \prime} \vert s_{t^ \prime})}{\pi_{\theta}(a_{t^ \prime} \vert s_{t^ \prime})} \right) \left( \sum _{t^ \prime = t} ^ T r(s_{t^ \prime}, a_{t^ \prime}) \left(\prod_ {t^ {\prime \prime} = t} ^ {t^ \prime} \frac{\pi_{\theta^ \prime}(a_{t^ {\prime \prime}} \vert s_{t^ {\prime \prime}})}{\pi_{\theta}(a_{t^ {\prime \prime}} \vert s_{t^ {\prime \prime}})} \right) \right) \right]\]

이때 마지막 부분 $\prod_ {t^ {\prime \prime} = t} ^ {t^ \prime} \frac{\pi_{\theta^ \prime}(a_{t^ {\prime \prime}} \vert s_{t^ {\prime \prime}})}{\pi_{\theta}(a_{t^ {\prime \prime}} \vert s_{t^ {\prime \prime}})}$를 무시하면 state marginal probability를 무시하는 것이 되고, 이 알고리즘이 policy iteration algorithm이다. 이를 무시해도 꽤나 잘 작동한다. 관련해서는 이후 강의에서 더 자세히 다룰 예정이다.

문제는 policy ratio $\prod _ {t^ {\prime} = 1} ^ t \frac{\pi_ {\theta^ \prime}(a_ {t^ {\prime}} \vert s_ {t^ {\prime}})}{\pi_ {\theta} (a_{t^ {\prime}} \vert s_{t^ {\prime}})}$가 0으로 가며 gradient vanishing 문제가 발생할 수 있다는 점이다. 이유는 $\tau$를 $p_{\theta}(\tau)$에서 샘플링했기 때문에 $\pi_{\theta}$의 확률이 $\pi_{\theta^ \prime}$보다 높아서 비율이 1보다 작아지고, 이것이 계속 곱해질 것이기 때문이다. 따라서 variance가 커진다. 이를 해결하기 위해 다른 방식으로 objective를 보도록 하자.

먼저 on-policy policy gradient를 다시 살펴보자.

\[\nabla_ \theta J(\theta) \approx \frac{1}{N} \sum _{i=1} ^N \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \hat{Q}_ {i, t}\]

이는 곧 $(s_ {i, t}, a_ {i, t})$를 $\pi_\theta (s_t, a_t)$에서 샘플링한 것이고, importance sampling을 적용하여 다음과 같이 off-policy policy gradient를 나타낼 수 있다.

\[\begin{aligned} \nabla_ {\theta^ \prime} J(\theta^ \prime) &= \frac{1}{N} \sum _{i=1} ^N \sum _ {t=1} ^T \frac{\pi_ {\theta^ \prime} (s_{i, t}, a_{i, t})}{\pi_ {\theta} (s_{i, t}, a_{i, t})} \nabla_ {\theta^ \prime} \log \pi_ {\theta^ \prime} (a_{i, t} \vert s_{i, t}) \hat{Q}_ {i, t} \\ &= \frac{1}{N} \sum _{i=1} ^N \sum _ {t=1} ^T \frac{\pi_ {\theta^ \prime} (s_ {i, t})}{\pi_ {\theta} (s_ {i, t})} \frac{\pi_ {\theta^ \prime} (a_ {i, t} \vert s_ {i, t})}{\pi_ {\theta} (a_ {i, t} \vert s_ {i, t})} \nabla_ {\theta^ \prime} \log \pi_ {\theta^ \prime} (a_ {i, t} \vert s_ {i, t}) \hat{Q}_ {i, t} \\ \end{aligned}\]

이때 $\theta$와 $\theta^ \prime$가 크게 다르지 않으면 state marginal probability $\frac{\pi_{\theta^ \prime}(s)}{\pi_{\theta}(s)}$를 무시하는 것이 합리적이다. 이유는 이후 강의에서 설명할 것이다. 어쨌든, 이러한 방식으로 접근하면 vanishing gradient 문제를 해결하여 variance를 줄일 수 있다.


5. Implementing Policy Gradients

TensorFlow, PyTorch와 같은 automatic differentiation tool을 사용하여 policy gradient를 계산할 수 있다. 이를 위해서는 policy gradient를 일종의 weighted maximum likelihood 식으로 해석하여 pseudo loss를 구현해주면 된다.

\[\begin{aligned} \text{PG} &\quad \nabla_ \theta J(\theta) \approx \frac{1}{N} \sum _{i=1} ^N \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \hat{Q}_ {i, t} \\ \text{MLE} &\quad \nabla_ \theta J_ {\text{ML}} (\theta) \approx \frac{1}{N} \sum _{i=1} ^N \sum _{t=1} ^T \nabla_ \theta \log \pi_\theta (a_{i, t} \vert s_{i, t}) \\ \end{aligned}\]

실제 코드로 작성하면 다음과 같다.

추가로 실제로 policy gradient를 사용할 때 주의할 점을 정리하였다.

  • Gradient가 high variance를 가지고 있으므로, noisy gradient를 보게 될 것임을 염두에 두어야 한다.
  • Large batch size를 사용하면 variance를 줄일 수 있다.
  • SGD와 같은 방법보다는 Adam과 같은 adaptive step size rule을 적용하는 것이 낫다. 또한, policy-gradient specific optimizer를 사용하는 것도 좋다.


6. Advanced Policy Gradients

Natural policy gradient를 소개하기 위해 아주 간단한 예시를 보자. 1차원의 파란색 공은 $s = 0$으로 가고자 하고, reward function이 다음과 같이 정의된다고 하자. 즉, $s$가 0에서 너무 멀거나 action이 너무 크면 penalty를 받는다.

\[r(s_t, a_t) = -s_ t ^ 2 - a_ t ^ 2\]

Policy는 다음과 같이 대략 $ks_t$만큼 이동할 수 있는 Gaussian distribution이다.

\[\log \pi_\theta (a_t \vert s_t) = -\frac{1}{2 \sigma ^ 2} (k s_t - a_t) ^ 2 + \text{const}\]

여기서 학습할 수 있는 $\theta = (k , \sigma)$이다. 그런데 실제로 parameter space에서 gradient를 구해보면 아래와 같다.

즉, $\sigma$가 작아질수록 $\sigma$ 방향으로의 gradient만 너무 커지기 때문에 $k$는 거의 업데이트되지 않는다. 이를 해석해보면, 지금까지 우리가 하던 gradient update는 다음과 같이 일종의 Taylor 1차 근사를 통해 $\theta$를 update하는 것으로 볼 수 있다.

\[\theta ^ \prime \leftarrow \arg \max_ {\theta ^ \prime} (\theta ^ \prime - \theta) ^ T \nabla_ \theta J(\theta) \text{ s.t. } \Vert \theta ^ \prime - \theta \Vert ^ 2 \leq \epsilon\]

지금까지는 parameter space에서 일정한 제한을 두고 update하는 것이었다. 하지만 실제로는 policy space에서 update하는 것이 더 효과적이다. 즉 아래와 같이 쓸 수 있다.

\[\theta ^ \prime \leftarrow \arg \max_ {\theta ^ \prime} (\theta ^ \prime - \theta) ^ T \nabla _ \theta J(\theta) \text{ s.t. } D_{KL} (\pi_{\theta ^ \prime} \Vert \pi_{\theta}) \leq \epsilon\]

이때 KL-divergence는 다음과 같이 쓸 수 있고, 이때 $\mathbf{F}$는 Fischer information matrix이다.

\[D_{KL} (\pi_{\theta ^ \prime} \Vert \pi_{\theta}) = (\theta ^ \prime - \theta) ^ T \mathbf{F} (\theta ^ \prime - \theta \\ \mathbf{F} = \mathbb{E} _{\pi_ \theta} \left[ \nabla_ \theta \log \pi_\theta (a \vert s) \nabla_ \theta \log \pi_\theta (a \vert s) ^ T \right]\]

실제 update 과정을 비교해보면 다음과 같다. 이러한 방법을 natural policy gradient 혹은 covariant policy gradient라고 하며, conjugate gradient와 같은 방법으로 update할 수 있다.

\[\begin{aligned} \text{Vanilla} &\quad \theta \leftarrow \theta + \alpha \nabla_ \theta J(\theta) \\ \text{Natural} &\quad \theta \leftarrow \theta + \alpha \mathbf{F} ^ {-1} \nabla_ \theta J(\theta) \\ \end{aligned}\]


CS285 카테고리 내 다른 글 보러가기

댓글 남기기