[23’] LISA: Reasoning Segmentation via Large Language Model
카테고리: Multimodal
🔍 Abstract
ICLR 2024에서 withdrawl한 논문이지만, 이를 기반으로 파생된 모델이 있고, 2024년 5월 말 기준 120회 넘게 인용된 논문이기에 충분히 중요한 논문이라 판단되어 리뷰하였다. 본 논문에서는 다음 3가지를 제시한다.
- Reasoning segmentation이라는, 복잡한 추론이 필요한 text에 대해 segmentation을 수행하는 task를 최초로 제시한다.
- 제시한 task에 대한 benchmark로 ReasonSeg를 제시한다.
- 제시한 task에 대한 model로 LISA(Large Language Instructed Segmentation Assistant)를 제시한다.
이때 LISA는 기존 vocabulary에 <SEG>
token을 추가하여 LLM 모델이 이 token을 예측하도록 만들고, 이를 embedding-as-mask paradigm으로 해석한다. 즉, <SEG>
token은 vision decoder에 들어가 mask로 decoding된다. <SEG>
token을 활용한 아키텍쳐를 제안해 reasoning segmentation task에 대한 해결책을 제시한 것이다.
참고: 오늘 다시 확인하니 CVPR 2024 Oral을 받았다. 이 분야 연구가 더욱 활발해질 전망이다.
1. Reasoning Segmentation
Reasoning segmentation과 유사한 task로 referring segmentation이 있다. 간단한 단어가 text input이 되는 referring segmentation과 달리, reasoning segmentation은 복잡한 추론이 필요하고, 긴 문장이 포함된 text가 input이 된다. 예를 들면 다음과 같다.
- Referring Segmentation: “the trash can”
- Reasoning Segmentation: ““something that the garbage should be put into”, “After cooking, consuming food, and preparing for food, where can we throw away the rest of the food and scraps?”
따라서 reasoning segmentation task는 complex reasoning과 world knowledge를 요구한다. 또한, 이를 평가하기 위한 벤치마크인 ReasonSeg를 제시한다. 이는 OpenImages와 ScanNetv2 이미지에 직접 annotation을 수행하여 1218개의 image-instruction-mask triplet을 제공한다.
2. Method
2.1. Architecture
저자들이 주장한 것은 embedding-as-mask paradigm이다. 먼저 저자들은 Multimodal LLM vocabulary에 <SEG>
token을 추가하였다. 이 토큰은 image-text input에 의해 생성되어 mask 정보를 implicit하게 담고 있다고 가정한다. 따라서 이를 vision decoder에 넣어 mask로 decoding하도록 한다.
자세한 과정은 다음과 같다. image-text input $\mathbf{x}_ \text{img}, \mathbf{x}_ \text{txt}$를 LLM에 넣어 text output $\hat{\mathbf{y}}$를 얻는다. 이때 prompt $\mathbf{x}_ \text{txt}$의 명령에 의해, $\hat{\mathbf{y}}$는 <SEG>
token을 포함하고 있다.
이때 <SEG>
에 해당하는 LLM의 last-layer embedding $\tilde{h}_ \text{seg}$를 MLP projection layer $\gamma$를 통해 mask embedding $h_ \text{seg}$로 변환한다. 이후 이를 vision decoder에 넣어 mask prediction을 수행한다. 이 과정을 정리하면 다음과 같다.
Decoder의 구조는 아래와 같은 SAM lightweight decoder를 사용하였다.
Training objective는 LoRA Fine-tuning 및 Segmentation Loss로 구성되어 있다.
2.2. Training
Training data는 기존 semantic segmentation dataset, referring segmentation dataset, 그리고 VQA dataset을 사용하였고, segmentation dataset의 경우 prompt의 형태로 rule-based로 재구성하였다. VQA dataset은 기존 MLLM의 VQA ability를 보존하기 위해 사용하였다고 한다(즉, segmentation output에만 overfitting되지 않도록 하여 catastrophic forgetting을 방지함). 저자들은 이러한 데이터셋은 전혀 reasoning segmentation task에 대한 데이터셋이 아니라고 강조한다. 따라서 ReasonSeg의 training set을 통해 fine-tuning한 결과(ft)도 비교하였다.
3. Experiments
저자들은 먼저 가장 중요한 결과로 benchmark인 ReasonSeg에서의 결과를 비교하였다. 당연히 새로운 task이기 때문에 다른 모델들과 직접적인 비교는 불가능하고, 그 외의 결론은 3가지이다.
- Parameter 수가 많을수록 성능이 좋다.
- 기본 LLM 성능이 높을수록 성능이 좋다.
- ReasonSeg Training set으로 fine-tuning하면 성능이 좋다.
사실 당연한 이야기이다. 조금 더 유의미한 비교는 LLaVA + OVSeg vs LLaVA + LISA 비교이다. LLaVA + OVSeg은 LLaVA가 가능한 input query를 예측하면 OVSeg이 이를 가지고 segmentation을 진행하는 방식으로 진행했다. 이러한 방식의 문제로 저자들은 (1) end-to-end가 아닌 decoupled two-stage method이고, (2) $h_ \text{seg}$에 비해 text(=input query)는 덜 expressive하다고 제시했다. 따라서 $h_ \text{seg}$가 더 많은 정보를 잘 가지고 있다는 것이다.
다른 모델들과의 더 직접적인 비교를 위해 저자들은 기존 task인 referring segmentation에서도 성능이 잘 나오는지 확인하였고, 여기서도 성능이 높다는 것은 고무적이다.
그 외에 insight를 가질만한 결과들을 정리해보았다.
- SAM이 성능이 좋은 이유는 pre-training phase의 high-quality data 덕분이고, SAM LoRA finetuning은 generalization ability를 낮추기 때문에 성능이 좋지 않다.
- ReasonSeg fine-tuning 시 기존 text instruction을 GPT-3.5를 이용해 rephrasing하여 data augmentation 후 fine-tuning하면 성능이 좋아진다.
💡 Summary
LISA에서는 reasoning segmentation이라는 새로운 task를 제시하였고, 이를 해결하기 위한 benchmark인 ReasonSeg와 model LISA를 제안했다. LISA는 embedding-as-mask paradigm 방식을 통해 mask 정보를 MLLM이 추출하도록 했다.
OpenReview의 평가를 보면 몇 가지 Limitation을 볼 수 있다. 먼저 <SEG>
token 하나이기 때문에 아직 multiple segmentation mask를 만들지 못한다. 그리고 여기서의 contribution은 <SEG>
token을 제안했다는 것인데, 이를 통해 reasoning segmentation이 가능해진 것이 아니라 기존 MLLM의 능력을 활용한 것뿐이라는 의견이 있다.
댓글 남기기