[22’ NIPS] Flamingo: a Visual Language Model for Few-Shot Learning
카테고리: Multimodal
🔍 Abstract
Flamingo는 VLM의 일종으로 task-specific fine-tuning 없이도 다양한 task에서 few-shot learning으로 SOTA 성능을 보이는 모델이다. Flamingo는 다음 3가지 architectural contribution을 가지며, 대부분 Perceiver Resampler라고 하는 구조에 의해 해결되었다.
- Pretrained vision-only model 및 language-only model을 연결하는 방법을 제안한다. 참고로 이러한 구조는 이후 연구에서 modality interface라는 이름으로 불리며 계속 사용되었다.
- 임의의 개수와 길이를 가진 visual 및 textural data를 잘 처리할 수 있도록 설계되었다.
- Visual data가 image이든 video이든 동일하게 잘 처리할 수 있도록 설계되었다.
이러한 I/O에 대한 유연성 덕분에 Flamingo는 기존의 잘 가공된 데이터셋이 아닌 인터넷으로부터 수집한 거대한 데이터셋으로 훈련이 가능했으며, 이를 통해 in-context few-shot learning에서의 성능을 향상시켰다.
1. Architecture
Flamingo의 전체 구조는 위와 같다. 먼저 Perceiver Resampler는 Vision Encoder로부터 얻은 spatio-temporal feature를 가공하여 길이가 고정된 visual token을 생성한다. 이 token은 gated cross-attention layer를 통해 Language Model에 전달된다. Text 사이에 끼워들어간(interleaved) 이미지 또는 비디오를 $x$라고 하면, Flamingo는 다음 식을 모델링한다.
\[p(y \vert x) = \prod_ {\ell = 1} ^ L p(y_ \ell \vert y_ {\lt \ell}, x_ {\lt \ell})\]끼워들어갔다(interleaved)는 의미는 위 그림을 참고하면 쉽게 이해할 수 있을 것이다. 그럼 이제 각 구조를 자세히 살펴보자. 가장 중요한 부분은 Perceiver Resampler, 그 다음으로 중요한 부분은 gated cross-attention layer라는 것만 기억하자.
1.1. Vision Encoder and Perceiver Resampler
Vision Encoder로는 NFNet(Normalizer-Free ResNet)을 사용했고, CLIP과 같이 image-text pair에 대한 contrastive objective로 훈련했다. 이러한 방식으로 image-text alignment를 어느 정도 가할 수 있다.
Perceiver Resampler는 Vision Encoder로부터 얻은 spatio-temporal feature를 가공하여 64개의 visual token을 생성한다. 이러한 token의 개수가 정해져 있기에 이후 계산의 복잡도가 감소한다. Token의 개수를 어떻게 동일하게 생성했는지 보면, DETR과 같이 latent queries의 개수를 64개로 제한하는 방식을 사용하였음을 알 수 있다.
1.2. Gated Cross-Attention Layer
Gated Cross-Attention Layer는 Vision Encoder로부터 얻은 visual token과 Language Model로부터 얻은 text token을 연결하는 역할을 한다. 다른 cross-attention과 다른 점은 $\tanh$ gating을 사용한다는 것이다. 즉, 학습 처음에는 cross-attention 시 $\tanh(\alpha)$ gating에서 $\alpha=0$으로 두어 cross-attention이 사용되지 않도록 한 것이다. 이 방식을 사용하면 freeze LM의 output을 훼손시키지 않기에 초기 학습의 안정성을 높일 수 있고, $\alpha$는 learnable parameter로 학습이 진행됨에 따라 vision input의 contribution이 적절한 정도로 조절된다.
1.3. Multi-visual Input Support
저자들은 모델이 특정 시간에서 단 하나의 이미지에만 접근하도록 제한했다. 위 그림을 보면 직관적으로 이해할 수 있다. 이전의 이미지 정보는 self-attention에 담겨 계속 전달된다고 가정한다. Ablation study에 따르면, 모든 이미지에 mask 없이 한번에 접근하는 것보다 이러한 방식이 더 효과적이었다. 그리고 학습 시에는 최대 5개의 이미지만 볼 수 있도록 제한했는데, 추론 시에는 32개의 이미지까지 볼 수 있도록 하였으며 이 때에도 성능의 향상이 있었다.
1.4. Multi-objective Optimization
저자들은 여러 데이터셋에 대하여 학습을 진행했기 때문에 다음과 같은 objective를 사용했다.
\[\sum_ {m=1} ^ M \lambda_ m \cdot \mathbb{E} _ {(x, y) \sim \mathcal{D}_ m} \left[ - \sum _ {\ell = 1} ^ L \log p(y_ \ell \vert y_ {\lt \ell}, x_ {\lt \ell}) \right]\]이러한 학습에 있어 per-dataset weights $\lambda_ m$을 잘 조절하는 것이 성능에 지대한 영향을 끼쳤다고 한다. 그리고 각 데이터셋에서 한 번씩 최적화를 반복하는 “round-robin” 방식보다는 여러 데이터셋에서의 gradient를 축적해 한번에 backpropagation을 진행하는 accumulation 방식이 더 효과적이었다고 한다.
2. Experiments
2.1. Few-shot Learning
Flamingo의 메인 결과다. 결론은 간단한데, 지금까지의 few-shot SOTA를 모두 뛰어넘었으며, fine-tuning SOTA보다도 높은 성능을 보이는 경우도 있었다. 그리고 shot이 증가함에 따라, 모델 크기가 커짐에 따라 성능이 향상되는 scalability를 보였다.
2.2. Fine-tuning
그리고 Flamingo가 얼마나 더 잘 할 수 있는지 궁금하니 fine-tuning을 진행했다. 결과적으로 과반수의 task에서 또 SOTA를 달성했다. 그러나 이 모델이 추구하는 바가 아니기에 그렇게 중요한 결과는 아닌 듯.
2.3. Ablation Study
마지막으로 ablation study이다. 위에서 언급했던 내용은 제외하고, 중요한 내용만 정리해보았다.
- 다양한 dataset을 사용하는 것이 굉장히 중요하다. Image-text와 Video-text를 모두 사용하는 것이 성능을 높이는 데 중요한 역할을 한다.
- Gated cross-attention layer를 모든 layer마다 사용하면 성능이 향상되지만, 많은 계산량이 필요한 trade-off가 있다. 그러나 전체 layer의 1/4 정도에만 사용해도 약간의 성능 하락만 있을 뿐이어서, 모델이 거대하면 전체 layer의 일부에만 사용해도 된다. Flamingo-9B, 80B에서 각각 1/4, 1/7의 layer에만 적용했다.
- LM fine-tuning은 오히려 성능을 하락시킨다. 이는 catastrophic forgetting 때문이다. 따라서 LM을 freeze하는 것이 더 좋다.
3. Limitations
저자들은 Flamingo의 limitation으로 다음 3가지를 언급했다.
첫 번째로, Flamingo는 pretrained LM의 단점을 그대로 물려받는다. 위와 같이 hallucination이 발생하거나, 정답이 없는 경우, 즉 ungrounded guess에 대해 잘 처리하지 못한다. 또한, training 때보다 긴 sequence를 잘 처리하지 못한다.
두 번째로, classification 성능이 CLIP으로 대표되는 contrastive model에 비해 떨어진다. Contrastive model은 label을 직접 text에 넣어 유사도를 비교할 수 있지만, Flamingo는 autoregressive model이기에 open-ended answer를 내도록 되어 있기에 당연한 결과이다.
세 번째로, in-context learning은 사람의 손을 많이 탄다. 무엇이 최적의 prompt인지 알 수 없다는 뜻이다. 따라서 prompt engineering 분야의 연구가 필요하다고 말하고 있다.
💡 Summary
Flamingo는 Vision Encoder와 LM 사이를 연결(bridge)하는 modality interface 개념을 초기에 도입한 중요한 논문이다. 본 논문에서는 Perceiver Resampler와 gated cross-attention layer를 통해 vision-language multimodal model을 구현하는 방법을 제시했다. 이러한 방법은 이후 multimodal model의 핵심이 되었으며, few-shot learning에서의 성능을 높이는 데 중요한 역할을 했다. 워낙 중요한 논문이기에 읽어보기를 권한다.
댓글 남기기