[22’ EMNLP] MedCLIP: Contrastive Learning from Unpaired Medical Images and Text
카테고리: Medical
🔍 Abstract
저자들은 기존 Medical Domain에서의 CLIP Pre-training에 여러 문제가 있다고 주장한다. 일반적으로 Medical Image-Text Pair를 사용하여 InfoNCE Loss로 훈련하는데, 이 부분에서 문제가 생긴다는 것이다. 주요한 문제들은 다음과 같다.
- Paired data의 수가 부족하다. Image-only, text-only dataset은 CLIP 학습에 사용할 수 없다.
- Contrastive learning 시에는 FN(False Negative)가 생길 수 있다. 즉, negative text로 사용한 것들 중에서도 positive image의 설명과 일치하는 부분이 있을 수 있다.
따라서 저자들은 이러한 문제를 해결하기 위해 각각의 데이터의 더 깊은 의미를 추출하여 학습하는 soft target learning 방식을 제안하고, 이러한 방식으로 MedCLIP 모델을 제안하였다. 해당 모델은 성능이 좋을 뿐만 아니라, 적은 데이터로도 학습이 가능하다는 장점을 가지고 있다.
1. Method
저자들은 위에서 언급한 각각의 문제를 (1) Image-Text Decoupling, (2) Semantic Matching Loss를 제안하여 해결하였다. 각각의 방법을 알아보자.
1.1. Image-Text Decoupling
먼저 Image-Text Decoupling은 image, text의 semantic feature를 추출하여 pair가 아니더라도 유사도를 계산할 수 있도록 하는 방법이다. 데이터셋이 image-text pair가 $n$개, image-only $m$개, text-only $h$개로 이루어져 있다면, CLIP에서는 학습을 $n^ 2$ pair에 대해서만 할 수 있었다. 그러나 MedCLIP에서는 $(n + m) \times (n + h)$ pair에 대해서 학습을 진행할 수 있다. 이는 Image-only, Text-only dataset도 학습에 사용할 수 있어 data efficiency를 높이는데 도움이 된다.
먼저 이미지를 $x_ {img}$, 텍스트를 $x_ {txt}$라고 하자. 이때 $x_ {img}$로부터는 diagnosis label로부터 class를 추출하여 $l_ {img}$로, $x_ {txt}$로부터는 특정 entity를 추출하여 $l_ {txt}$로 변환한다. 이때 추출에 사용한 알고리즘은 MetaMap으로, 이를 활용하여 정형화된 UMLS(Unified Medical Language System) concept로 변환한다.
1.2. Semantic Matching Loss
이렇게 추출한 $l_ {img}$, $l_ {txt}$를 사용하여 Semantic Matching Loss를 계산한다. 이는 $l_ {img}$, $l_ {txt}$의 cosine similarity를 계산하여 이를 최대화하는 방향으로 간접적인 학습을 진행하는 것이다. 즉 $image v_ i$로부터 추출한 $l_ i$와 $text v_ j$로부터 추출한 $l_ j$로부터 soft target $s_ {ij}$와 normalized target $y_ {ij}$를 다음과 같이 정의한다.
\[s_ {ij} = \frac {l_ i \cdot l_ j} {\Vert l_ i \Vert \Vert l_ j \Vert} \quad y_ {ij} = \frac {\exp (s_ {ij})} {\sum_ {j} \exp (s_ {ij})}\]한편 실제 $x_ {img}$, $x_ {txt}$로부터 Encoder를 이용하여 얻은 embedding $v_ i$, $t_ j$을 사용해서 predicted similarity를 계산한다.
\[\hat {s_ {ij}} = \frac {v_ i \cdot t_ j} {\Vert v_ i \Vert \Vert t_ j \Vert} \quad \hat {y_ {ij}} = \frac {\exp (\hat {s_ {ij}})} {\sum_ {j} \exp (\hat {s_ {ij}})}\]이 둘 사이의 cross-entropy loss를 계산하여 학습을 진행하며, 저자들은 이를 Semantic Matching Loss라고 불렀다.
\[\mathcal {L} = - \frac {1} {N} \sum_ i \sum_ j y_ {ij} \log \hat {y_ {ij}}\]실제로는 CLIP과 같이 InfoNCE Loss를 사용한다.
\[\mathcal {L} = \frac {\mathcal{L}^ {v \rightarrow l} + \mathcal{L}^ {t \rightarrow l}} {2}\]이러한 soft target 방식을 사용하면 semantic meaning을 잘 추출할 수 있기 때문에, 기존의 CLIP training pipeline에서 발생하는 FN(False Negative) 문제를 해결할 수 있다.
학습 과정에서 Text Encoder는 BioClinicalBERT, Vision Encoder는 Swin Transformer를 사용하였다. 학습 시에는 대략 CheXpert(200K), MIMIC-CXR(200K) 데이터셋을 사용하였다. 실제로는 여기에 Image-only, Text-only data도 포함되더 더 많은 데이터를 사용하였다.
2. Evaluation
2.1. Main Results
성능 검증은 (1) Classification (Zero-shot, Fine-tuning), (2) Retrieval을 통해 이루어졌다. 결과는 전반적으로 다른 모델보다 확연히 우수한 성능을 보인다. 이는 학습 파이프라인의 semantic matching loss가 충분히 좋은 representation을 학습하는 데 큰 도움이 된다고 이해할 수 있다.
2.2. Data Efficiency
놀라운 점은 이러한 semantic matching loss가 data efficiency를 높이는 데 큰 도움이 된다는 것이다. 생각해보면, image-text pair가 맞으면 1, 아니면 0을 부여하는 hard target보다 medical domain에서의 semantic meaning을 추출하여 이들의 유사도와 가까워지게 만드는 soft target은 어느 정도 medical knowledge prior를 가지고 있다. 이러한 prior를 가지고 학습을 진행한다는 점에서 data efficiency가 높다고 이해할 수 있다.
2.3. Visualization
마지막으로 t-SNE를 사용해서 MedCLIP의 representation이 얼마나 효과적인지 확인하였다. 결과는 위와 같이 기존 CLIP에 비해 훨씬 더 semantic meaning을 잘 구분하는 것을 확인할 수 있다.
💡 Summary
MedCLIP은 Image-Text Decoupling과 Semantic Matching Loss를 통해 Medical Domain에서의 일종의 Prior Knowledge를 활용하여 Data-efficient한 CLIP 모델을 학습시키는 데 성공하였다. 이러한 방식은 기존의 CLIP에서 발생하는 FN(False Negative) 문제를 해결하고, 가뜩이나 데이터의 수가 적은 medical domain에서 효과적으로 데이터를 활용할 수 있도록 도와준다.
댓글 남기기