[CS285] 12. Model-Based Policy Learning

Date:     Updated:

카테고리:

태그:

💡 이 글은 『2024 PseudoLab 전반기 강화학습팀』으로 진행되었으며, CS285 Fall 2023를 따라 정리했습니다.


1. Introduction

Lecture 11에서 얻은 model-based RL version 1.5를 다시 살펴보면 다음과 같다.

image

이러한 MPC(Model Predictive Control)은 오류가 커지기 전에 바로잡기 때문에 이점을 가지지만 아직 sub-optimal하다. 이는 본 알고리즘이 open-loop이기에, 모든 sequence를 한 번에 계산하여 예측하기 때문이다. 이러한 방법은 비효율적이고, closed-loop인 경우에는 외부의 state를 고려하여 action planning이 이루어지므로 더 효율적이다.

image

그렇다면 어떤 policy $\pi$를 사용할 것인가? Neural network는 global planning이 가능하고, Lecture 10에서 다룬 LQR의 경우 initial plan과 유사한 경우에만 local planning이 가능하다. 이번에는 global policy, 즉 neural network를 사용한 model-based policy learning에 대해 다룬다.

model-based policy learning의 간단한 방법을 생각해보자. 먼저 dynamic model $f(s, a)$를 사용하여 state를 예측하고, policy $\pi_ \theta$를 통해 action을 예측하는 순서로 진행한다. 이를 통해 단순화된 computation graph를 그려보자. 이제 closed-loop으로 model-based learning을 할 수 있는 간단한 모델이 완성되었다.

image

이를 학습시키는 방법을 model-based RL version 2.0이라고 부르자.

image

그러나 이러한 방법에는 두 가지 문제가 있다. 첫 번째는 parameter sensitivity이다. 시작 시점의 action의 약간의 변화가 trajectory 전체를 크게 변화시킬 수 있다. 이는 Lecture 10에서 다룬 shooting method에서와 동일한 문제인데, shooting method에서는 parameter sensitivity 문제를 해결하기 위해 second derivative를 사용하여 variance를 줄이는 등의 테크닉을 사용할 수 있었다. 그러나 shooting method에서는 dynamic programming으로 transition dynamics를 간단하게 modeling했던 것과 달리 model-based RL에서는 neural network를 사용하기 때문에 이러한 테크닉을 사용할 수 없다.

두 번째 문제는 vanishing gradients 혹은 exploding gradients 문제이다. 이러한 문제는 RNN과 같이 BPTT(Backpropagation Through Time)에서 발생할 수 있는데, 이 경우에는 RNN의 경우처럼 LSTM과 같은 구조를 사용할 수 없다.

따라서 policy로 backpropagation을 직접 수행하는 것은 좋은 선택이 아니다. 대신 derivative free RL 알고리즘을 사용해야 한다. 이러한 derivative-free 알고리즘은 BPTT를 하지 않는 알고리즘을 의미하며, 결국 trajectory sampling을 통해 계산하는 model-free algorithm을 사용해야 한다. 관련해서 다음 절에서 더 자세히 다루어보자. 지금까지의 내용을 정리하면 다음과 같다.

  • Open-loop model-based RL 대신 closed-loop model-based RL을 사용하기 위해서는 policy $\pi_ \theta (s_ t \vert a_ t)$를 훈련해야 한다.
  • 그러나 단순히 BPTT(backpropagation through time) 방식으로 policy를 훈련시키는 것은 좋은 선택이 아니다. Parameter sensitivityvanishing gradients 문제가 발생할 수 있기 때문이다.


2. Model-free Learning with a Model

Lecture 5에서 다룬 policy gradient와 지금 다루었던 backprop gradient 또는 pathwise gradient의 식을 살펴보면 다음과 같다.

image

Policy gradient는 단순 trajectory sampling을 통해 학습하기 때문에 explicit하게 transition dynamics를 사용하지 않는다. 물론, implicit하게 사용되긴 하지만 실제로 나타나지 않기에 transition dynamics를 통해 gradient 관련 문제가 발생하지 않는다. 반면 pathwise gradient 혹은 backprop gradient는 식의 마지막 부분의 product of jacobians 때문에 gradient 관련 문제가 발생할 수 있다. 따라서 policy gradient가 충분한 샘플만 있다면 더 안정적으로 학습할 수 있다.

그렇다면 지금까지 policy gradient와 같은 model-free RLreal physical system에서 sampling을 해야 하는 문제가 있었는데, 이제 model-based RL을 적용하여 인공적인 sample을 만들어 학습시키면 이러한 문제도 해결되고 일석이조다. 이러한 leverage를 통해 두 알고리즘의 장점을 같이 얻을 수 있다. 이러한 알고리즘을 model-based RL version 2.5라고 부르자. Model-based RL을 사용하여 인공적으로 trajectory $\tau_ i$를 생성하고, 이를 policy gradient를 통해 학습시킨다.

image

그렇다면 이러한 방법의 문제는 없는가? 역시나 distribution shift 문제가 생긴다. 즉, true dynamics로 얻는 trajectory와 learned model로 얻는 trajectory 사이에 distribution shift가 생길 수 있다. 이는 imitation learning에서 training trajectory와 expected trajectory가 달랐던 것과 동일한 문제이다. 따라서 long model-based rollout을 시행하면 error가 축적되며, 이는 behavior cloning에서 다뤘던 것과 같은 원리로 $O(\epsilon T^2)$의 크기로 커진다.

image

따라서 long rollout을 피하고, short rollout을 사용해야 한다. 그러나 이러한 방법을 사용하게 되면 later time step에 대해 최적화가 되지 않는다. 예를 들어 30분 동안 요리를 해야 하는 로봇이 있다면, 요리를 준비하기만 하는 앞 5분에 대해서만 계속 최적화하는 상황이 발생할 수 있다. 이를 해결하기 위해 적은 수의 real trajectory를 sampling한 뒤, 각 time step에서 시작하여 short model-based rollout을 시행하는 trick을 사용한다.

image

그러나 이 방법에도 한 가지 문제가 있다. 이러한 model-based rolloutstate distribution은 정확하지 않다. 왜냐하면 real world rollout을 얻을 때의 policy는 short model-based rollout을 얻을 때의 policy와 다르기 때문이다. 즉, Real world rolloutshort model-based rollout을 통해 policy가 계속 update될 것이고, 더 이상 real word rolloutlatest policy로 얻은 trajectory가 아니기에 이로부터 파생된 model-based rollout은 잘못된 state distribution을 얻게 된다. 따라서 on-policy algorithmpolicy gradient와 같은 알고리즘은 이러한 문제를 해결하기 어렵다. 대신 Q-learning, Q-function actor-critic algorithm과 같은 off-policy algorithm이 더 잘 작동한다. 이를 사용하면 어느 정도 문제를 해결할 수 있다. 이러한 방법을 model-based RL version 3.0이라고 부르자.

image

지금까지의 내용을 정리하면 다음과 같다.

  • Policy gradient와 같은 model-free RLBPTT가 필요없어 closed-loop model-based RL에서 policy learning을 위한 적절한 선택이다. 즉, model-based RL로 인공적인 샘플을 만들어 model-free RL을 학습시킨다.
  • 그러나 long model-based rollout을 사용하면 distribution shift 문제가 발생하고, short model-based rollout을 사용하면 later time step에 대해 최적화가 되지 않는다. 따라서 real world rollout을 적은 수만 sampling한 뒤 short model-based rollout을 사용하는 trick을 사용한다.
  • 대신 이러한 방법은 off-policy sample을 생성하기에 on-policy algorithmpolicy gradient는 이러한 문제를 해결하기 어렵다. 따라서 off-policy algorithmQ-learning과 같은 알고리즘을 사용해야 한다.


3. Dyna-style Algorithms

Dyna는 지금까지 설명한 model을 사용한 model-free RL을 수행하는 online Q-learning algorithm의 시초이다. 그 알고리즘은 다음과 같고, 여기서 1~4 step은 일반적인 online Q-learning과 같고, 5~7 step이 중요하다. 이때 $(s^ \prime, a^ \prime)$을 one step model-based rollout으로만 얻어 distributional shift를 최소화할 수 있다.

image

이를 일반화한 Dyna-style algorithm은 다음과 같다. 현대의 model-based RL은 모두 이러한 형식을 가지고 있으며, short model-based rollout만 요구하고, diverse state에 대해 샘플링이 가능하다는 점에서 강점을 가진다.

image

지금까지의 내용을 Q-learning process와 함께 나타내면 아래와 같다. Process 4, 5가 추가된 것이며, 따라서 이를 model-accelerated off-policy RL이라고 부를 수 있다. 다만 일반적으로 real buffer보다 synthetic buffer가 더 크다는 점을 유의해야 한다.

image

이를 활용한 실제 알고리즘으로 MBA(Model-based Acceleration), MVE(Model-based Value Expansion), MBPO(Model-based Policy Optimization) 등이 있다. 지금까지 설명한 내용은 MBPO에 가까우며, 이 모든 알고리즘은 전반적으로 다음 과정을 거친다고 생각할 수 있다.

image

본 알고리즘의 장점은 일종의 model-based dataset amplification을 통해 성능을 향상시킬 수 있다는 것이다. 그러나 단점으로는 model이 정확하지 않으면 model-based rolloutbias가 생길 수 있다는 점이 있다. 이를 Lecture 11에서 다루었던 ensemble 등으로 해결할 수 있으며, 본 Lecture에서의 접근법인 off-policy algorithm으로 어느 정도 해결하긴 했지만 어쨌든 wrong state distribution을 가지므로 문제가 생길 수 있다. 따라서 real world data collection을 적당한 주기로 해주는 것이 필요하다. 즉, sampling efficiencybias 간의 trade-off가 존재한다. 지금까지의 내용을 정리하면 다음과 같다.

  • Dyna-style algorithmmodel-based RLmodel-free RL에 적용한 알고리즘으로, model-accelerated off-policy RL이라고 부를 수 있다.
  • 이러한 알고리즘은 real world data collection 빈도에 따라 sampling efficiencybias 간의 trade-off가 존재하여 그 trade-off를 잘 조절해야 한다.


CS285 카테고리 내 다른 글 보러가기

댓글 남기기