[CS236n] 4. Latent Variable Models
카테고리: CS236n
태그: VAE Generation
💡 이 글은 『2024 YAI 봄전반기 생성모델팀』으로 진행되었으며, CS236n Fall 2023를 따라 정리했습니다.
1. Introduction
2장에서 다룬 자가회귀모형(autoregressive model)은 데이터셋 $\mathcal{D}$로부터 데이터의 분포 $p_{data}$와 유사한 분포 $p_{\theta}$를 학습한다. 그러나 autoregressive model은 이미지 자체에서 드러나지 않는 이미지의 특성들을 담아내기에는 어렵다.
예를 들어 사람 얼굴 데이터셋 $\mathcal{D}$를 보면 우리는 사람의 성별, 눈의 색깔, 머리의 색깔, 표정 등을 쉽게 파악할 수 있다. 그러나 autoregressive model은 이러한 특성들을 담아내기에는 어렵다. 이러한 특성들을 명시적으로(explicitly) 드러내 모델링하기 위하여 잠재변수모형(latent variable model)을 사용한다.
잠재변수모형(latent variable model)에서는 이러한 특성들을 잠재변수(latent variables) $z$로 나타낸다. 데이터에서 우리가 관찰할 수 있는 것은 $x$뿐이지만, 사실 $x$는 latent variables $z$에 의해 생성된 것이라고 보는 것이다. 만약, 우리가 적절한 latent variables $z$를 선택하고 이미지마다 일일히 latent variables $z$를 라벨링한 뒤, 이를 이용하여 모델을 학습시킨다면, 우리는 이미지의 특성들을 명시적으로 모델링할 수 있을 것이다. 예를 들어 사람 얼굴 데이터셋 $\mathcal{D}$에서 latent variables $z$를 성별, 눈의 색깔, 머리의 색깔, 표정 등으로 나누어 모두 라벨링하고 위 모델을 학습시키는 것을 생각할 수 있다.
만약 적절한 특성을 latent variables로 선정하고 라벨링했다면, 아무것도 모르는 상태에서 확률분포 $p(x)$를 추정하는 것보다는 당연히 latent variables가 주어졌을 때의 확률분포 $p(x \mid z)$를 추정하는 것이 더 쉬울 것이다. 그리고 잘 훈련된 모델이라면 사후확률(posterior) $p(z \mid x)$를 추정하여 각 이미지의 특성을 알아낼 수도 있을 것이다.
그러나 latent variables $z$를 일일히 라벨링하는 것은 불가능에 가깝다. 우리는 모델이 데이터를 통해 latent variables $z$를 스스로 학습하도록 만들고 싶다. 이 아이디어를 활용한 것이 latent variable model의 기초 버전이라고 볼 수 있는 gaussian mixture model이다.
2. Gaussian mixture model
2.1. Setting
Gaussian mixture model은 다음과 같이 정의할 수 있다. latent variables $z$는 $K$개의 클래스 중 하나에 속하며, $z=k$일 때의 $x$의 분포는 가우시안 분포를 따른다고 가정한다. 이를 수식으로 나타내면 다음과 같다.
\[z \sim \text{Categorical}(1, 2, \cdots, K) \\ p(x \mid z=k) = \mathcal{N}(\mu_k, \Sigma_k)\]예를 들어 얼굴 데이터셋 $\mathcal{D}$에서 $z=1$은 남성, $z=2$는 여성으로 라벨링했다고 하면, $z=k$일 때의 $x$의 분포는 각각 남성의 얼굴, 여성의 얼굴을 가우시안 분포로 모델링하는 것이라고 생각하면 된다. (실제로는 명시적으로 $z=1$은 남성, $z=2$는 여성처럼 라벨링하는 것이 아니라, 학습한 뒤에 $z$가 그런 특성들에 의해 잘 나누어지기를 바라는 것이다.)
데이터를 생성할 때는 먼저 $z$를 샘플링하고, $z=k$일 때의 $x$의 분포를 따르는 가우시안 분포에서 $x$를 샘플링한다.
2.2. Clustering
이는 unsupervised learning 중에서도 clustering이라고 불리는 문제이다. 즉, 데이터를 $K$개의 클러스터로 나누는 문제이다. 아래는 왼쪽부터 데이터셋, $K=1$인 경우, $K=2$인 경우 clustering를 나타낸 그림이다.
잘 훈련된 모델이라면 $z=1, 2, \cdots, K$ 중 $z=i$로부터 데이터 $x$가 생성되었을 확률, 즉 posterior probability $p(z = i \mid x)$를 추정할 수 있을 것이다. 이를 통해 각 데이터가 어떤 클러스터에 속하는지 알아낼 수 있을 것이다. 아래는 $K=3$인 모델에서, 각 데이터가 어떤 클러스터에 속하는지 알아낸 결과이다. 수학적으로는 각 데이터에 대하여 $k = {\text{argmax}}_ i \, p(z=i \mid x)$를 계산하여 나타낸 것이다.
2.3. Alternative motivation
Gaussian mixture model을 다시 해석해보면, 간단한 가우시안 분포들인 $p(x \mid z)$의 합으로 복잡한 분포 $p(x)$를 모델링하는 것이라고 볼 수 있다. 수학적으로는 다음과 같이 표현한다.
\[\begin{aligned} p(x) &= \sum_ z p(x, z) \\ &= \sum_ z p(x \mid z) p(z) \\ &= \sum_ {k=1}^K p(z=k) \mathcal{N}(\mu_k, \Sigma_k) \\ \end{aligned}\]3. Variational autoencoder
3.1. Setting
Variational autoencoder에서는 이를 확장하여 무수히 많은 수의 가우시안 분포 $p(x \mid z)$를 사용한다. 즉, gaussian mixture model에서 $z \sim \text{Categorical}(1, 2, \cdots, K)$였던 것을 $z \sim \mathcal{N}(0, I)$로 바꾼 것이다. 이를 수식으로 나타내면 다음과 같다.
\[z \sim \mathcal{N}(0, I) \\ p(x \mid z) = \mathcal{N}(\mu_\theta(z), \Sigma_\theta(z))\]$z$가 유한한 개수를 가지는 것이 아니므로 gaussian mixture model에서 각각 $z=k$에 대해 $\mu_k, \Sigma_k$를 사용했던 것처럼 각 $z$에 대해 $\mu, \Sigma$를 할당할 수 없다. 따라서 이를 대신하여 neural networks를 사용해 $\mu_\theta(z), \Sigma_\theta(z)$로 나타낸다.
이때 $p(x)$는 다음과 같이 간단한 가우시안 분포들인 $p(x \mid z)$의 합으로 볼 수 있고, $z$가 continuous variable이므로 다음과 같이 나타낼 수 있다. 이를 베이지안 통계학에서는 marginal likelihood라고 부른다.
\[\begin{aligned} p(x) &= \int_ z p(x, z) \, dz \\ &= \int_ z p(x \mid z) p(z) \, dz \\ \end{aligned}\]3.2. Marginal likelihood
3.2.1. Motivation
Variational autoencoder의 개념을 이해하기 위해 문제 하나를 정의하자. 우리에게는 이미지 하단 절반만 있는 MNIST 데이터셋 $\mathcal{D} = \lbrace x^{(1)}, x^{(2)}, \cdots, x^{(m)}\rbrace $이 있고, 이때 관찰 가능한 하단 절반을 $X$ 즉 observed random variables이라 하고, 알지 못하는 상단 절반을 $Z$ 즉 일종의 latent variables로 두자.
이를 직관적으로 생각해보자. 우리는 하단 절반을 잘 생성하는 모델, 즉 $p(x)$를 잘 추정하는 모델을 만들고 싶다. 그러나 사실 하단 절반 $x$는 상단 절반 $z$에 의해 생성된 것이어서, $p(x \mid z)$를 잘 추정한 뒤 이것을 다 더해 $p(x)$를 추정하는 것이 더 낫다(latent variable model의 기본 가정). 따라서 이를 추정하는 모델 $\mathcal{M}$의 파라미터를 $\theta$라 하고, $p(x \mid z; \theta)$를 훈련하고자 한다. 이때 $p(z)$는 우리가 마음대로 정할 수 있으므로 (gaussian mixture model에서는 $p(z) = \text{Categorical}(1, 2, \cdots, K)$, variational autoencoder에서는 $p(z) = \mathcal{N}(0, I)$ 등) joint probability $p(x, z; \theta) = p(x \mid z; \theta) p(z)$도 알 수 있다. 여기서는 $z \in \lbrace 0, 1\rbrace ^{H \times W}$이므로 $p(z)$는 $\lbrace 0, 1\rbrace ^{H \times W}$에서의 uniform distribution으로 두어도 무방하다.
3.2.2. Maximum likelihood estimation
우리는 autoregressive model 때와 같이 maximum likelihood estimation을 통해 $p(x)$를 추정하고자 한다. 단 latent variable $z$를 활용할 것이다. Log-likelihood를 다음과 같이 나타낼 수 있다. (* 엄밀하게는 variational autoencoder에서는 $z$가 continuous variable이므로 $\sum_z$ 대신 $\int_z dz$를 사용해야겠지만, 하단 절반만 있는 MNIST 데이터셋에서는 $z$가 상단 절반을 나타내는 것이므로 discrete variable이어서 $\sum_z$를 사용했다.)
\[\begin{aligned} \log \prod_ {x \in \mathcal{D}} p(x; \theta) &= \sum_ {x \in \mathcal{D}} \log p(x; \theta) \\ &= \sum_ {x \in \mathcal{D}} \log \sum_ z p(x, z; \theta) \\ \end{aligned}\]그런데 이를 계산하기 위해서는 $z$에 대한 모든 가능한 경우의 수를 다 더해야 하므로 계산량이 너무 많다. 여기서도 $z \in \lbrace 0, 1\rbrace ^{H \times W}$이므로 $2^{H \times W}$개의 경우의 수가 있다. MNIST 데이터셋이 32 x 32 이고, 그 상단 절반을 추론한다고 가정해도 $2^{16 \times 32}$개의 $p(x,z; \theta)$항을 모두 더해야 하는 것이다. Gaussian mixture model에서는 $z$가 $K$개 뿐이었으므로 이를 모두 더하는 것이 가능했을지 몰라도, variational autoencoder에서는 이를 모두 더하는 것이 불가능하다. 이를 intractable하다고 한다. 당연히 gradient $\nabla_\theta$를 계산하는 것도 불가능하다. 우리에게는 근사(approximation)가 필요하다.
3.2.3. First attempt: Naive Monte Carlo
앞으로 $p(x; \theta)$를 $p_\theta(x)$로, $p(x, z; \theta)$를 $p_\theta(x, z)$로 나타낸다. 이제 $p_\theta(x)$를 근사하기 위한 방법을 생각해보자. 가장 간단한 방법은 naive Monte Carlo이다. 먼저 $p_\theta(x)$를 다음과 같이 나타낼 수 있다.
\[\begin{aligned} p_\theta(x) &= \sum_ z p_\theta(x, z) \\ &= \left\vert \mathcal{Z} \right\vert \sum _{z \in \mathcal{Z}} \frac{1}{\left\vert \mathcal{Z} \right\vert} p_\theta(x, z) \\ &= \left\vert \mathcal{Z} \right\vert \mathbb{E}_ {z \sim \mathrm{Uniform}(\mathcal{Z})} \left[ p_\theta(x, z) \right] \\ \end{aligned}\]여기서 $\mathcal{Z}$는 $z$의 모든 가능한 경우의 수이다. 당연히 모든 가능한 경우의 수를 다 더하는 것은 불가능(intractable)하므로, 이를 Monte Carlo approximation으로 근사하자.
- 랜덤으로 $z^{(1)}, z^{(2)}, \cdots, z^{(k)}$를 $\text{Uniform}(\mathcal{Z})$에서 샘플링한다.
- 기댓값을 다음과 같이 근사한다.
쉽게 말하면 $2^{16 \times 32}$가지 $z$에 대해 $p_\theta(x, z)$를 다 더하는 것이 불가능하므로, 랜덤으로 $k$개의 $z$를 샘플링하여 $p_\theta(x, z)$를 다 더하는 것으로 근사하자는 것이다. 그러나 이 방법은 $k$가 충분히 크지 않으면 $p_\theta(x)$를 잘 근사하지 못한다. 이는 많은 $z$에 대하여 $p_\theta(x, z)$ 값은 굉장히 낮고, 몇몇 $z$에서만 $p_\theta(x, z)$ 값이 높기 때문이다. 그래서 “운이 좋게” 높은 $p_\theta(x, z)$를 가진 $z$를 샘플링하지 못하면 $p_\theta(x)$를 잘 근사하지 못한다. 즉, high variance를 가진다. 그러므로 $z$를 잘 샘플링하여, $k$가 조금 작더라도 $p_\theta(x)$를 잘 근사할 수 있는, 즉 low variance를 가지는 방법이 필요하다.
3.2.4. Second attempt: Importance sampling
이를 위해 importance sampling이라는 방법을 사용한다. 먼저 $p_\theta(x)$를 다음과 같이 나타낼 수 있다.
\[\begin{aligned} p_\theta(x) &= \sum_ z p_\theta(x, z) \\ &= \sum_ z \frac{q(z)}{q(z)} p_\theta(x, z) \\ &= \mathbb{E}_ {z \sim q(z)} \left[ \frac{p_\theta(x, z)}{q(z)} \right] \\ \end{aligned}\]여기서 $q(z)$는 샘플링하고자 하는 $z$의 분포로, 임의로 정할 수 있다. Naive Monte Carlo에서는 $q(z) = \text{Uniform}(\mathcal{Z})$였다. 이제 $p_\theta(x)$를 다음과 같이 근사할 수 있다.
- 랜덤으로 $z^{(1)}, z^{(2)}, \cdots, z^{(k)}$를 $q(z)$에서 샘플링한다.
-
기댓값을 다음과 같이 근사한다.
\[\begin{aligned} \sum_ z p_\theta(x, z) &\approx \frac{1}{k} \sum_ {j=1}^k \frac{p_\theta(x, z^{(j)})}{q(z^{(j)})} \\ \end{aligned}\] -
참고로, 이 경우에도 naive Monte Carlo와 같이 그 평균은 $p_\theta(x)$이다. 즉 unbiased estimator이다.
\[\begin{aligned} \mathbb{E}_ {z^{(j)} \sim q(z)} \left[ \frac{1}{k} \sum_ {j=1}^k \frac{p_\theta(x, z^{(j)})}{q(z^{(j)})} \right] &= p_\theta(x) \\ \end{aligned}\]
그렇다면 어떤 $q(z)$를 사용해야 할까? 직관적으로는 데이터 $\bar{x}$가 주어졌을 때, $p_\theta(x=\bar{x}, z)$가 높은 $z$를 더 자주 샘플링하는 분포를 사용하면 좋을 것이다. 이를 다음과 같이 나타낼 수 있다.
\[\begin{aligned} p_\theta(x=\bar{x}, z) &= p_\theta(z \mid x=\bar{x}) p_\theta(x=\bar{x}) \\ \end{aligned}\]결국 $p_\theta(z \mid x=\bar{x})$의 값이 높을수록 $p_\theta(x=\bar{x}, z)$의 값이 높아지므로, $p_\theta(z \mid x)$를 $q(z)$로 사용하면 좋을 것이다. 이를 수학적으로 증명하는 과정이 그 유명한 ELBO(=Evidence Lower Bound)의 유도과정이다.
3.3. Evidence Lower Bound (ELBO)
ELBO의 유도는 log-likelihood를 구하는 것으로부터 시작한다. 식은 importance sampling에서 유도했던 것과 동일하다.
\[\begin{aligned} \log p_\theta(x) &= \log \left( \sum_ {z \in \mathcal{Z}} p_\theta(x, z) \right)\\ &= \log \left( \sum_ {z \in \mathcal{Z}} \frac{q(z)}{q(z)} p_\theta(x, z) \right)\\ &= \log \left( \mathbb{E}_ {z \sim q(z)} \left[ \frac{p_\theta(x, z)}{q(z)} \right] \right)\\ \end{aligned}\]
이때 위로 볼록(concave)한 함수인 $\log$ 함수와 임의의 분포 $q(z)$, 임의의 함수 $f(z)$에 대해 젠센 부등식(Jensen’s inequality)는 다음과 같이 나타낼 수 있다. 등호성립조건은 $f(z)$가 상수함수일 때이다.
\[\begin{aligned} \log \left( \mathbb{E}_ {z \sim q(z)} \left[ f(z) \right] \right) &= \log \left( \sum _z q(z) f(z) \right) \\ &\geq \sum_ z q(z) \log f(z) \\ &= \mathbb{E}_ {z \sim q(z)} \left[ \log f(z) \right] \\ \end{aligned}\]이를 적용하면 다음과 같이 나타낼 수 있다.
\[\begin{aligned} \log \mathbb{E}_ {z \sim q(z)} \left[ \frac{p_\theta(x, z)}{q(z)} \right] &\geq \mathbb{E}_ {z \sim q(z)} \left[ \log \frac{p_\theta(x, z)}{q(z)} \right] \\ \end{aligned}\]이 lower bound를 ELBO(=Evidence Lower Bound)라고 부른다. “Evidence”가 들어가는 이유는 베이지안 추론에서 $p(x)$를 evidence라고 부르기 때문이다. 즉 evidence인 $p_\theta(x)$의 lower bound인 것이다.
등호성립조건을 다시 보면 $p_\theta(x, z)/q(z)$가 $z$에 무관한 상수함수여야 하는데, $p_\theta(x,z) = p_\theta(z \mid x) p_\theta(x)$이므로 $q(z) = p_\theta(z \mid x)$일 때 등호가 성립한다. 3.2.4절과 동일한 결론이다. 실제로 $q(z) = p_\theta(z \mid x)$를 대입해보자.
\[\begin{aligned} \mathbb{E}_ {z \sim q(z)} \left[ \log \frac{p_\theta(x, z)}{q(z)} \right] &= \mathbb{E}_ {z \sim q(z)} \left[ \log \frac{p_\theta(x, z)}{p_\theta(z \mid x)} \right] \\ &= \mathbb{E}_ {z \sim q(z)} \left[ \log p_\theta(x) \right] \\ &= \log p_\theta(x) \\ \end{aligned}\]다시 돌아가서, ELBO를 다시 정리하면 다음과 같다. 여기서 $H(q)$는 entropy로 $- \sum _z q(z) \log q(z)$로 정의된다.
\[\begin{aligned} \log p_\theta(x) &\geq \mathbb{E}_ {z \sim q(z)} \left[ \log \frac{p_\theta(x, z)}{q(z)} \right] \\ &= \mathbb{E}_ {z \sim q(z)} \left[ \log p_\theta(x, z) - \log q(z) \right] \\ &= \sum _z q(z) \log p_\theta(x, z) - \sum _z q(z) \log q(z) \\ &= \sum _z q(z) \log p_\theta(x, z) + H(q) \end{aligned}\]따라서 ELBO를 다음과 같이 나타낼 수 있다.
\[\begin{aligned} \mathrm{ELBO} &= \sum _z q(z) \log p_\theta(x, z) + H(q) \\ \end{aligned}\]3.4. Variational inference
그렇다면 $q(z) = p_\theta(z \mid x)$로 두면 등호가 성립하고, 직접 $\log p_\theta(x)$를 구할 수 있을 텐데 왜 ELBO를 구하는 것일까? 이유는 사후확률(posterior) $p_\theta(z \mid x)$를 구하는 것이 불가능(intractable)하기 때문이다. 왜? 베이즈 정리에 의해 사후확률 $p_\theta(z \mid x)$는 다음과 같이 나타낼 수 있다.
\[\begin{aligned} p_\theta(z \mid x) &= \frac{p_\theta(x \mid z)}{p_\theta(x)} p(z) \\ \end{aligned}\]이 식의 분모, 즉 marginal likelihood $p_\theta(x)$를 구하는 것은 intractable하고 지금 이를 근사하기 위해 ELBO까지 열심히 구했는데, 또 이것을 활용해서 $p_\theta(z \mid x)$를 구하는 것은 당연히 intractable하다. 대신, 우리는 $q(z)$를 사용하여 $p_\theta(z \mid x)$를 근사하고자 한다. 이를 variational inference라고 한다.
이를 수학적으로는 $q(z)$와 $p_\theta(z \mid x)$의 KL-divergence를 최소화하는 것으로 나타낼 수 있다. 수식으로 전개해보면 다음과 같다.
\[\begin{aligned} D_{KL}(q(z) \parallel p_\theta(z \mid x)) &= \sum _z q(z) \log \frac{q(z)}{p_\theta(z \mid x)} \\ &= \sum _z q(z) \log q(z) - \sum _z q(z) \log p_\theta(z \mid x) \\ &= -H(q) - \sum _z q(z) \log \frac{p_\theta(x, z)}{p_\theta(x)} \\ &= -H(q) - \sum _z q(z) \log p_\theta(x, z) + \sum _z q(z) \log p_\theta(x) \\ &= -H(q) - \sum _z q(z) \log p_\theta(x, z) + \log p_\theta(x) \\ &= - \sum _z q(z) \log p_\theta(x, z) + \log p_\theta(x) - H(q) \\ &\geq 0 \end{aligned}\]위에서 유도한 대로, (혹은 위 식에서 부등호에 대해 정리하면) ELBO는 다음과 같다.
\[\begin{aligned} \mathrm{ELBO} &= \sum _z q(z) \log p_\theta(x, z) + H(q) \\ \end{aligned}\]따라서 다음과 같이 나타낼 수 있다.
\[\begin{aligned} \log p_\theta(x) &= \mathrm{ELBO} + D_{KL}(q(z) \parallel p_\theta(z \mid x)) \\ \end{aligned}\]즉, $q(z)$가 $p_\theta(z \mid x)$와 가까워질수록, 즉 $D_{KL}(q(z) \parallel p_\theta(z \mid x))$가 0에 가까워질수록 ELBO의 값은 실제 log-likelihood $\log p_\theta(x)$에 가까워진다. 우리는 $q(z)$를 tractable probability distribution인 간단한 분포로 만들기 위해 variational parameters $\phi$를 사용한다. 즉, $q(z; \phi)$로 나타낸다.
예를 들어 다음과 같은 정규분포로 모델링할 수도 있다.
\[q(z; \phi) = \mathcal{N}(\phi_1, \phi_2)\]이때 variational inference는 $q(z; \phi)$를 $p_\theta(z \mid x)$에 가깝게 만드는 것이다. 예를 들면 아래 그림에서 posterior $p_\theta(z \mid x)$는 초록색 $\mathcal{N}(-4, 0.75)$보다 주황색 $\mathcal{N}(2,2)$에 가깝다.
이제 $q(z)$는 $\phi$를 파라미터로 가지는 분포 $q(z; \phi)$가 되었다. 따라서 ELBO를 다음과 같이 나타내는 것이 좋을 것이다.
\[\begin{aligned} \mathrm{ELBO} &= \sum _z q(z; \phi) \log p_\theta(x, z) + H(q; \phi) \\ &= \mathcal{L} (x; \theta, \phi) \\ \log p_\theta(x) &= \mathcal{L} (x; \theta, \phi) + D_{KL}(q(z; \phi) \parallel p_\theta(z \mid x)) \\ \end{aligned}\]
이제 우리는 ELBO를 최대화하기 위해 $\theta, \phi$를 모두 학습시켜야 한다.
학습 과정을 보기 전에, 지금까지의 과정을 이미지 하단 절반만 있는 MNIST 데이터셋에 적용시켜 보자. 일단 $p_\theta(z,x)$는 충분히 $p_{data}(z,x)$와 비슷한 상태(학습된 상태)라고 가정하자. 이때, $q(z; \phi)$를 다음과 같이 나타낼 수 있다.
\[q(z; \phi) = \prod_ {z_i} (\phi_i)^{z_i} (1-\phi_i)^{1-z_i}\]해석하자면 픽셀 $z_i$가 1일 확률이 $\phi_i$라는 뜻이다. 그렇다면, $q(z; \phi)$가 posterior $p_\theta(z \mid x)$의 좋은 근사가 되기 위해서는 어떤 조건을 만족해야 할까? 직관적으로 “숫자가 있어야 할 만한 곳”에는 $\phi_i = 1$, 그렇지 않은 곳에는 $\phi_i = 0$이 되어야 한다고 생각할 것이다. 그렇게 되도록 학습시키는 것이 우리의 목표이고, 이는 ELBO를 $\phi$에 대해 최대화하면 달성할 수 있다.
3.5. Stochastic variational inference
이제 ELBO를 최대화하기 위해 $\theta, \phi$를 동시에(jointly) 학습시켜야 한다. 데이터셋 $\mathcal{D} = \lbrace x^{(1)}, x^{(2)}, \cdots, x^{(m)}\rbrace $가 주어졌을 때, 우리는 $\theta, \phi^1, \phi^2, \cdots, \phi^m$의 함수인 $\sum_ {x^i \in \mathcal{D}} \mathcal{L} (x^i; \theta, \phi^i)$를 최적화(최대화)해야 한다. 여기서 $\phi^i$는 데이터 $x^{(i)}$가 주어졌을 때 $q(z; \phi^i)$에 사용되는 파라미터이다. 데이터마다 파라미터가 다르다. 이는 $q(z; \phi)$가 $p_\theta(z \mid x)$를 근사하기 위한 것이므로 데이터 $x$마다 $q(z; \phi)$를 다르게 근사해야 할 것이기 때문이다. 그럼 데이터셋의 크기가 커지면 $\phi$도 엄청나게 많아지지 않을까? 이는 이후 “amortized inference”라는 방법으로 해결한다.
일단 ELBO $\mathcal{L} (x^i; \theta, \phi^i)$를 다시 살펴보자. ELBO는 다음과 같이 나타낼 수 있고, stochastic gradient descent로 다음과 같이 최적화한다.
\[\begin{aligned} \mathcal{L} (x^i; \theta, \phi^i) &= \sum _z q(z; \phi^i) \log p_\theta(x^i, z) + H(q; \phi^i) \\ &= \mathbb{E}_ {z \sim q(z; \phi^i)} \left[ \log p_\theta(x^i, z) - \log q(z; \phi^i) \right] \\ \end{aligned}\]- $\theta, \phi^1, \phi^2, \cdots, \phi^m$를 초기화한다.
- 데이터셋 $\mathcal{D}$에서 랜덤으로 $x^i$를 샘플링한다.
- $\mathcal{L} (x^i; \theta, \phi^i)$를 $\phi^i$의 함수로 보고 최적화한다. 즉
- $\phi^i = \phi^i + \eta \nabla_{\phi^i} \mathcal{L} (x^i; \theta, \phi^i)$를 반복한다.
- $\phi^{i, *} \approx \mathrm{argmax}_{\phi^i} \mathcal{L} (x^i; \theta, \phi^i)$로 수렴시킨다.
- $\mathcal{L} (x^i; \theta, \phi^i)$를 $\theta$의 함수로 보고 $\nabla_{\theta} \mathcal{L} (x^i; \theta, \phi^i)$을 계산한다.
- $\theta$를 업데이트한다.
- 2-5를 반복한다.
앞으로는 표기의 단순성을 위해 $\phi^i$를 단순히 $\phi$로 표기하겠다. 이제 gradient $\nabla_{\phi} \mathcal{L} (x; \theta, \phi)$와 $\nabla_{\theta} \mathcal{L} (x; \theta, \phi)$를 계산해야 하는데, $\mathcal{L} (x; \theta, \phi)$는 계산할 수 없다. 대신 3.2.4절의 importance sampling을 이용하여 근사한다. 즉 $q(z; \phi)$로부터 $z^1, \cdots, z^K$를 샘플링하여 다음과 같이 나타낸다.
\[\begin{aligned} \mathcal{L} (x; \theta, \phi) &= \mathbb{E}_ {z \sim q(z; \phi)} \left[ \log p_\theta(x, z) - \log q(z; \phi) \right] \\ &\approx \frac{1}{K} \sum_ {k=1}^K \left[ \log p_\theta(x, z^k) - \log q(z^k; \phi) \right] \\ \end{aligned}\]Gradient $\nabla_{\theta} \mathcal{L} (x; \theta, \phi)$는 다음과 같이 쉽게 계산할 수 있다.
\[\begin{aligned} \nabla_{\theta} \mathcal{L} (x; \theta, \phi) &= \nabla_{\theta} \mathbb{E}_ {q(z; \phi)} \left[ \log p_\theta(x, z) - \log q(z; \phi) \right] \\ &= \mathbb{E}_ {q(z; \phi)} \left[ \nabla_{\theta} \log p_\theta(x, z) \right] \\ &\approx \frac{1}{K} \sum_ k \nabla_{\theta} \log p_\theta(x, z^k) \\ \end{aligned}\]Gradient $\nabla_{\phi} \mathcal{L} (x; \theta, \phi)$의 경우는 말이 다르다. 기댓값이 $\phi$에 관련된 분포 $q(z; \phi)$에 의해 계산되기 때문에 sampling을 통해서 gradient를 계산할 수 없다. 이를 reparametrization trick이라는 방법으로 해결한다. 대신 이 방법은 $z$가 continuous variable일 때만 가능하다. 아쉽지만 이제 절반만 있는 MNIST 데이터셋을 떠나보내야 할 때인 것 같다 😥
📌 Reparametrization trick 대신 discrete variable일 때도 graient를 계산할 수 있는 “REINFORCE”라는 테크닉을 사용해서 근사할 수도 있다. 대신 REINFORCE 방법은 reparametrization trick에 비해 high variance를 가진다.
3.6. Reparametrization trick
$r(z) = \log p_\theta(x, z) - \log q(z; \phi)$라고 두면, 우리는 아래 식의 $\phi$에 대한 gradient를 계산해야 한다.
\[\begin{aligned} \mathbb{E}_ {q(z; \phi)} \left[ r(z) \right] &= \int q(z; \phi) r(z) dz \\ \end{aligned}\]지금까지 $q(z; \phi)$는 임의의 분포였는데, 이제 이를 가우시안 분포라고 가정하자. 즉 $q(z; \phi) = \mathcal{N}(\mu, \sigma^2 I)$이고, 파라미터 $\phi = (\mu, \sigma)$라고 하자. 이때 $z$를 다음과 같이 두 가지 방법으로 샘플링할 수 있다. 두 방법은 동치(equivalent)이다.
- $z \sim q(z; \phi)$로부터 샘플링한다.
- $\epsilon \sim \mathcal{N}(0, I)$로부터 샘플링한 뒤, $z = \mu + \sigma \epsilon = g(\epsilon; \phi)$로 계산한다.
두 방법은 완전히 동일하다. 즉, 아래의 두 식도 동일하다.
\[\begin{aligned} \mathbb{E}_ {q(z; \phi)} \left[ r(z) \right] &= \int q(z; \phi) r(z) dz \\ \mathbb{E}_ {\epsilon \sim \mathcal{N}(0, I)} \left[ r(g(\epsilon; \phi)) \right] &= \int \mathcal{N}(\epsilon) r(g(\epsilon; \phi)) d\epsilon \\ \end{aligned}\]하지만 이제 기댓값이 $\phi$에 관련된 분포에 의존하지 않고, $\epsilon$에 대한 분포에 의존하게 되었다. 이제 gradient를 계산할 수 있다.
\[\begin{aligned} \nabla_{\phi} \mathbb{E}_ {q(z; \phi)} \left[ r(z) \right] &= \nabla_{\phi} \mathbb{E}_ {\epsilon} \left[ r(g(\epsilon; \phi)) \right] \\ &= \mathbb{E}_ {\epsilon} \left[ \nabla_{\phi} r(g(\epsilon; \phi)) \right] \\ &\approx \frac{1}{K} \sum_ k \nabla_{\phi} r(g(\epsilon^k; \phi)) \\ \end{aligned}\]이것으로 모든 variational autoencoder의 학습 과정을 완성했다.
3.7. Amortized inference
지금까지 우리의 학습 목표를 표현하면 다음과 같다.
\[\begin{aligned} \max_{\theta} \ell (\theta; \mathcal{D}) \geq \max_{\theta, \phi^1, \cdots, \phi^m} \sum_ {x^i \in \mathcal{D}} \mathcal{L} (x^i; \theta, \phi^i) \\ \end{aligned}\]즉 우리는 지금까지 모든 데이터 $x^i$에 대하여 다른 $\phi^i$를 훈련했다. 이는 데이터셋의 크기가 커지면 커질수록 더 비효율적이다. Amortized inference는 이를 해결하기 위한 방법이다. 즉, $x^i \mapsto \phi^{i, *}$ 대응을 잘 학습하는 하나의 네트워크 $f_\lambda$를 학습하는 것이다. 예를 들어, 원래 $q(z; \phi^i)$가 모두 다른 평균 $\mu^1, \cdots, \mu^m$을 가지고 있었다면, 이제는 $f_\lambda$에 의해 $x^i$가 있으면 $\mu^i$를 출력하도록 학습한다. 이제 다음과 같이 표기할 수 있다.
\[q(z; \phi^i) = q(z; f_\lambda(x^i)) = q_\phi(z \mid x)\]Amortized inference를 사용하면 ELBO를 다음과 같이 나타낼 수 있다.
\[\begin{aligned} \mathcal{L} (x; \theta, \phi) &= \mathbb{E}_ {z \sim q_\phi(z|x)} \left[ \log p_\theta(x, z) - \log q_\phi(z|x) \right] \\ \end{aligned}\]이제 $\theta, \phi$에 대해 $\sum_ {x^i \in \mathcal{D}} \mathcal{L} (x^i; \theta, \phi)$를 최적화하는 과정은 다음과 같다.
- $\theta, \phi$를 초기화한다.
- 데이터셋 $\mathcal{D}$에서 랜덤으로 $x^i$를 샘플링한다.
- $\nabla_\theta \mathcal{L} (x^i; \theta, \phi)$, $\nabla_\phi \mathcal{L} (x^i; \theta, \phi)$를 계산한다.
- $\theta, \phi$를 업데이트한다. (반복)
3.8. Interpretation
지금까지의 과정을 autoencoder의 관점으로 다시 한 번 정리해보자. Encoder, decoder의 과정을 각각 다음과 같이 해석할 수 있다.
- Encoder: 데이터 $x^i$가 주어졌을 때, $q_\phi(z \mid x^i)$로부터 $\hat{z}$를 샘플링한다. 이때 $q_\phi(z \mid x^i)$는 가우시안 분포로 파라미터 $(\mu, \sigma) = \mathrm{encoder}_ \phi (x^i)$에 의해 결정된다.
- Decoder: $\hat{z}$가 주어졌을 때, $p_\theta(x^i \mid \hat{z})$로부터 $\hat{x}$를 샘플링한다. 이때 $p_\theta(x^i \mid \hat{z})$는 가우시안 분포로 파라미터는 $\mathrm{decoder}_ \theta (\hat{z})$에 의해 결정된다.
한편, 우리의 학습 목표인 ELBO를 다시 정리해보면 다음과 같다.
\[\begin{aligned} \mathcal{L} (x; \theta, \phi) &= \mathbb{E}_ {q_\phi (z \mid x)} \left[ \log p_\theta(x, z) - \log q_\phi(z \mid x) \right] \\ &= \mathbb{E}_ {q_\phi (z \mid x)} \left[ \log p_\theta(x, z) - \log p(z) + \log p(z) - \log q_\phi(z \mid x) \right] \\ &= \mathbb{E}_ {q_\phi (z \mid x)} \left[ \log p_\theta (x \mid z) \right] - D_{KL}(q_\phi(z \mid x) \parallel p(z)) \\ \end{aligned}\]두 항의 역할은 다음과 같이 해석할 수 있다.
- Reconstruction term: $\mathbb{E}_ {q_\phi (z \mid x)} \left[ \log p_\theta (x \mid z) \right]$는 $x_i$가 주어졌을 때, $z$로부터 $x_i$를 잘 복원하는지를 나타낸다. 즉, $x_i \approx \hat{x}_i$이도록 하는 autoencoding loss이다.
- Regularization term: $D_{KL}(q_\phi(z \mid x) \parallel p(z))$는 $q_\phi(z \mid x)$가 $p(z)$와 얼마나 가까운지를 나타낸다. 즉, $q_\phi(z \mid x)$가 $p(z)$와 가까워지도록 하는 regularization term이다.
Variational autoencoder가 일반적인 autoencoder와 달리 생성모델인 것은 두 번째 항, 즉 regularization term 때문이다. 우리는 3.1절에서 $p(z) = \mathcal{N}(0, I)$로 가정하였다. Regularization term에 의해 $q_\phi(z \mid x)$가 $p(z)$와 충분히 가까워졌다면 $p(z)$라는 단순한 분포에서 $z$를 샘플링해도 $x$를 잘 복원할 수 있을 것이다. 이것이 바로 variational autoencoder가 생성모델인 이유이다.
3.9. Research directions
Variational autoencoder의 variational learning을 향상시키고자 하는 다양한 연구가 진행되고 있다. 이는 크게 3가지 방향으로 나눌 수 있다.
- Better optimization techniques
- More expressive approximating families
- Alternative loss functions
관련 논문들은 CS236n 강의안에 나열되어 있으므로, 관심있는 분은 참고하기 바란다.
4. Summary
지금까지 latent variable model에 대해서 알아보았다. Latent variable model의 기본 가정은 데이터를 설명하는 latent variable $z$가 존재한다는 것이고, 이로부터 얻은 간단한 분포들 $p(x \mid z)$의 합으로 복잡한 분포 $p(x)$를 설명할 수 있다는 것이다. 이때 ancestral sampling, 즉 $z \sim p(z)$, $x \sim p(x \mid z; \theta)$로부터 데이터를 생성할 수 있다.
그러나, log-likelihood는 계산이 불가능(intractable)했고, 따라서 ELBO를 통해 이를 근사하고자 했다. 이때 모델의 파라미터 $\theta$와 amortized inference의 파라미터 $\phi$를 같이(jointly) 학습하였다. 데이터 $x$의 latent representation은 학습 후 $q_\phi (z \mid x)$로 추론할 수 있다.
댓글 남기기