[21’ NIPS-WS] CFG: Classifier-free Diffusion Guidance

Date:     Updated:

카테고리:

태그:


0. Abstract

  • CG(Classifier Guidance)는 image classifer를 추가로 학습해야 한다는 문제가 있었다. 이를 해결하기 위해 CFG(Classifier-free Diffusion Guidance)를 제안하였으며, 이는 classifer 없이 conditional image synthesis를 수행할 수 있게 해주었다.


1. Introduction

CG(Classifier Guidance) 이전까지는 GAN에서의 truncation, Glow의 low temperature sampling과 같이 low temperature sampling을 하는 방법이 없었다. 즉 diversity가 낮고, fidelity가 높도록 sampling하는 방법이 없었다는 것이다. Model의 score vector (score function) 값을 scaling한다든지, Gaussian noise 크기를 감소한다든지 하는 나이브한 방법들은 별로 효과적이지 않았다. 그렇다고 CG(Classifier Guidance)를 사용하자니 몇 가지 문제가 있었다. 첫 번째는 classifier를 추가로 훈련해야 한다는 점이고, 두 번째는 훈련 과정이 마치 classifier와 diffusion model이 서로를 학습하는 것처럼 보인다는 점이다. 즉, classifier에 대한 gradient-based adversarial attack이라는 것이다. FID나 IS와 같은 척도들은 classifier-based metric인데, classifier guidance를 사용하면 당연히 이러한 척도가 향상될 것이고 실제 효과에 비하여 높은 점수를 받게 된다는 것이다. 따라서 여기서는 classifier-free diffusion guidance를 제안하였다.


2. Background

Notation이 조금 더 generalized version으로 작성되어 있어 앞으로의 논문을 이해하기 위해 간단하게 정리하겠다. 원본 이미지는 $\mathbf{x}$이고, latent variable은 $\mathbf{z} = \lbrace \mathbf{z}_ \lambda \vert \lambda \in \left[\lambda_ {\text{min}}, \lambda_ {\text{max}}\right] \rbrace$이다. Forward process는 다음과 같은 과정으로 진행된다고 이해하면 된다.

\[\mathbf{x} \rightarrow \mathbf{z}_ {\lambda_ {\text{max}}} \rightarrow \cdots \rightarrow \mathbf{z}_ {\lambda ^\prime} \rightarrow \mathbf{z}_ {\lambda} \rightarrow \cdots \rightarrow \mathbf{z}_ {\lambda_ {\text{min}}}\]

분포는 다음과 같이 나타낼 수 있다.

\[\begin{aligned} q(\mathbf{z}_ {\lambda} \vert \mathbf{x}) &= \mathcal{N}(\alpha_ \lambda \mathbf{x}, \sigma_ \lambda ^ 2 \mathbf{I}), \quad \text{where } \alpha_ \lambda ^ 2 = \frac{1}{1 + e^{-\lambda}}, \sigma_ \lambda ^ 2 = 1 - \alpha_ \lambda ^ 2 \\ q(\mathbf{z}_ {\lambda} \vert \mathbf{z}_ {\lambda ^\prime}) &= \mathcal{N}(\frac{\alpha_ \lambda}{\alpha_ {\lambda ^\prime}} \mathbf{z}_ {\lambda ^\prime}, \sigma_ {\lambda \vert \lambda ^\prime} ^ 2 \mathbf{I}), \quad \text{where } \sigma_ {\lambda \vert \lambda ^\prime} ^ 2 = (1 - e^ {\lambda - \lambda ^\prime}) \sigma ^ 2 _ {\lambda} \end{aligned}\]

여기서 $\lambda = \log \alpha_ \lambda ^ 2 / \sigma_ \lambda ^ 2$로 log signal-to-noise ratio로 해석할 수 있다. 따라서 forward process가 진행될수록 $\lambda$가 감소하는 것이 옳다. 한편 reverse process는 다음과 같이 parametrize할 수 있다.

\[p_\theta (\mathbf{z}_ {\lambda ^ \prime} \vert \mathbf{z}_ {\lambda}) = \mathcal{N}(\tilde{\mu}_ {\lambda ^ \prime \vert \lambda} (\mathbf{z}_ \lambda, \mathbf{x}_ \theta (\mathbf{z}_ \lambda)), (\tilde{\sigma}_ {\lambda ^ \prime \vert \lambda} ^ 2) ^ {1-v} (\tilde{\sigma}_ {\lambda \vert \lambda ^ \prime} ^ 2) ^ v)\]

Sampling 과정은 $\lambda_{\text{min}} = \lambda_1$부터 시작하여 $T$ 단계로 $\lambda_{\text{max}} = \lambda_T$까지 진행된다. DDPM과 numbering이 반대임에 주의해야겠다. 본 논문의 notation이 조금 더 reverse process를 강조하는 numbering이라고 할 수 있다. 학습 시에는 $\mathbf{x}_ \theta$를 reparametrization한 $\epsilon_ \theta$를 사용한다.

\[\mathbb{E}_ {\epsilon, \lambda} \left[ \Vert \epsilon_ \theta (\mathbf{z}_ \lambda) - \epsilon \Vert ^ 2_ 2 \right]\]

여기서 $p(\lambda)$가 uniform distribution이면 objective는 latent variable model의 variational lower bound와 같다. 그렇지 않으면 weighted variational lower bound가 되며, DDPM에서의 방법이 이 방법이다. 여기서는 IDDPM의 cosine noise scheduling에 영감을 받아 hyperbolic noise scheduling을 사용하였다. 즉, $\lambda = -2 \log \tan (au + b)$로 두는 것이다. (여기서 $u \in [0, 1]$인 uniform distribution이다.)


3. Guidance

3.1. Classifier Guidance

Class label $\mathbf{c}$가 주어졌을 때, conditional model을 $\epsilon_ \theta (\mathbf{z}_ \lambda, \mathbf{c})$로 확장할 수 있다. 그리고 estimator가 다음과 같이 score를 잘 학습하기를 바랄 것이다.

\[\epsilon_ \theta (\mathbf{z}_ \lambda, \mathbf{c}) \approx - \sigma_ \lambda \nabla _ {\mathbf{z}_ \lambda} \log p (\mathbf{z}_ \lambda \vert \mathbf{c})\]

CG(Classifier Guidance) 논문에서 제안했던 것은 이러한 conditional model에 classifier guidance를 추가하는 것이다. 정확히는 classifier guidance가 핵심이고 conditional model을 쓰면 추가적인 성능 개선이 있다는 뉘앙스였지만… 저자들의 논리 전개에 맞게 변형된 듯하다. Classifier model을 $p_ \theta (\mathbf{c} \vert \mathbf{z}_ \lambda)$로 정의하면 modified score $\tilde{\epsilon}_ \theta$는 다음과 같이 나타낼 수 있다.

\[\begin{aligned} \tilde{\epsilon}_ \theta (\mathbf{z}_ \lambda, \mathbf{c}) &= \epsilon_ \theta (\mathbf{z}_ \lambda) - w \sigma_ \lambda \nabla _ {\mathbf{z}_ \lambda} \log p_ \theta (\mathbf{c} \vert \mathbf{z}_ \lambda) \\ &\approx - \sigma_ \lambda \nabla _ {\mathbf{z}_ \lambda} \left[ \log p (\mathbf{z}_ \lambda \vert \mathbf{c}) + w \log p_ \theta (\mathbf{c} \vert \mathbf{z}_ \lambda) \right] \\ \end{aligned}\]

따라서 CG는 다음과 같은 분포에서 샘플링을 수행하는 것으로, classifier가 correct label $\mathbf{c}$에 대해 high likelihood를 가질 것이라 보는 데이터의 샘플링 확률을 높이는 것이라 볼 수 있다. 따라서 IS나 FID와 같은 classifier-based metric을 향상시킬 수 있다.

\[\tilde{p}_ \theta (\mathbf{z}_ \lambda \vert \mathbf{c}) \propto p_ \theta (\mathbf{z}_ \lambda \vert \mathbf{c}) p_ \theta (\mathbf{c} \vert \mathbf{z}_ \lambda) ^ w\]

이론적으로 CG weight을 $w+1$만큼 주고 unconditional model을 학습시키는 것과, CG weight을 $w$만큼 주고 conditional model을 학습시키는 것은 동일하다.

\[\begin{aligned} \tilde{\epsilon}_ \theta (\mathbf{z}_ \lambda, \mathbf{c}) &= \epsilon_ \theta (\mathbf{z}_ \lambda) - (w+1) \sigma_ \lambda \nabla _ {\mathbf{z}_ \lambda} \log p_ \theta (\mathbf{c} \vert \mathbf{z}_ \lambda) \\ &\approx - \sigma_ \lambda \nabla _ {\mathbf{z}_ \lambda} \left[ \log p (\mathbf{z}_ \lambda) + (w+1) \log p_ \theta (\mathbf{c} \vert \mathbf{z}_ \lambda) \right] \\ &= - \sigma_ \lambda \nabla _ {\mathbf{z}_ \lambda} \left[ \log p (\mathbf{z}_ \lambda \vert \mathbf{c}) + w \log p_ \theta (\mathbf{c} \vert \mathbf{z}_ \lambda) \right] \\ \end{aligned}\]

그러나 실제로 IDDPM에서는 CG weight을 $w$만큼 주고 conditional model을 학습시키는 것의 결과가 더 좋았다. 이는 conditional model이 중요하다는 것을 의미하고, 따라서 본 논문에서는 classifier-free diffusion guidance를 선택하였다.

3.2. Classifier-free Guidance

Classifier-free guidance에서는 unconditional model $p_ \theta (\mathbf{z})$와 conditional model $p_ \theta (\mathbf{z} \vert \mathbf{c})$를 같이 학습시킨다. 즉, score estimator $\epsilon_ \theta(\mathbf{z}, \mathbf{c})$를 같이 사용하는데 unconditional case에서는 null token $\varnothing$을 class label로 사용하는 것이다. 즉, $\epsilon_ \theta(\mathbf{z}, \varnothing) = \epsilon_ \theta(\mathbf{z})$이다. 그리고 estimator는 다음과 같이 conditional과 unconditional score estimate의 linear combination으로 선택하였다.

\[\tilde{\epsilon}_ \theta(\mathbf{z}_ \lambda, \mathbf{c}) = (1+w) \epsilon_ \theta(\mathbf{z}_ \lambda, \mathbf{c}) - w \epsilon_ \theta(\mathbf{z}_ \lambda)\]

이와 같은 방법으로 estimator를 정의하면 더 이상 classifier gradient를 사용하지 않고, classifier에 대한 gradient-based adversarial attack으로 해석될 수 있다는 문제를 해결하게 된다. 대신 마치 implicit classifier가 존재하는 것처럼 학습이 이루어진다고 생각할 수 있다.

저자들은 CFG(Classifier-free Guidance)의 학습 방법 및 샘플링 방법을 다음과 같이 나타내었다. Conditional sample에 비해 unconditional sample을 얼마나 많이 학습할지도 hyperparameter $p_{\text{uncond}}$로 두었다.


4. Experiments

4.1. Varying the classifier-free guidance strength

본 논문에서는 FID를 diversity, IS를 fidelity로 해석하여 둘 사이의 trade-off를 분석하였다. 결과를 보지 않아도 $w$가 증가하면 diversity는 떨어지고 fidelity가 증가하므로 FID가 증가하고, IS가 증가할 것으로 예상된다. 실제로 그렇다. 그리고 CG(ADM)과 비교하여 성능이 더 좋은 것을 확인할 수 있다.

정성적으로 보아도 그렇다. 흥미로운 점은 CFG를 사용한 경우 채도(saturation)가 높아진다는 것이다.

4.2. Varying the unconditional training probability

$p_{\text{uncond}}$에 따라 IS-FID curve를 그려보면 다음과 같다. 그래프가 더 낮을수록 좋은데, unconditional sample을 더 많이 학습할수록 결과가 나쁜 것을 볼 수 있다. 따라서 적은 수의 unconditional sample만 학습해도 샘플을 생성해내는 데 충분하다는 결론을 얻을 수 있다.

4.3. Varying the number of sampling steps

당연히 sampling step 수가 길어질수록 효과가 좋은데, $T=256$과 $T=512$는 큰 차이가 없었다. 따라서 본 논문에서는 $T=256$을 선택하여 사용하였다.


5. Discussion

CFG의 장점은 extreme simplicity이다. Classifier를 따로 훈련할 필요 없고, 코드 상으로 간단하게 구현할 수 있다. 게다가 gradient-based adversarial attack을 피할 수 있다. 그러나 단점은 sampling speed가 느리다는 것이다. CG의 경우 sampling을 위해 classifer와 diffusion model을 사용하지만, CFG의 경우 diffusion model만 2회 사용한다. 따라서 CG보다 CFG가 느리게 동작하게 된다. 저자들은 network 후반부에만 condition을 넣어주는 식으로 해결할 수 있지 않을까 하고 추측한다.


💡 Summary

  • CGclassifier를 추가로 훈련해야 하고, classifier에 대한 gradient-based adversarial attack으로 해석될 수 있는 문제가 있다.
  • CFGclassifier-free guidance로 이러한 문제를 해결하였다. 이는 unconditional model과 conditional model을 함께 학습하는 것이다.
  • CFG의 conditional weight $w$를 조절하여 diversity와 fidelity 간 trade-off를 조절할 수 있다.
  • Opinion: CFG는 지금까지도 굉장히 중요하게 사용되는 테크닉이다. 저자들이 말한대로 simplicity가 가장 큰 장점이다. 그러나 ImageNet 데이터셋에서만 실험을 진행했다는 점이 아쉽다.


📃 Reference


Generation 카테고리 내 다른 글 보러가기

댓글 남기기