[24’ ICML-WS] Transformers need glasses! Information over-squashing in language tasks

Date:     Updated:

카테고리:

태그:

image


🔍 Abstract

2024년 ICML Workshop 중 Workshop on Theoretical Foundations of Foundation Models (TF2M)에 나온 논문이다. Google DeepMind에서 낸 논문이라 주목하여 읽게 되었다. 아무래도 이론적인 내용이 주를 이루는 논문이라 엄밀한 정도의 이해는 생략하고, 해당 논문에서 분석한 개념들 위주로 요약하고자 한다.

image

저자들은 Representational CollapseOver-squashing이라는 개념을 LLM에서 처음으로 제안한다.

  • Representational Collapse란 LLM이 입력 데이터의 정보를 충분히 표현하지 못하는 것을 의미한다. 특히 final token의 representation이 input이 길어지면 길어질수록 하나에 수렴하게 되며, 따라서 LLM은 나중에 나오는 정보를 충분히 이해하기 어렵다. 이는 어느 정도 fp16과 같은 low-precision floating point로 인해 발생하는 문제이다.
  • Over-squashing은 Transformer 구조의 한계로 인해 Input의 Early Token에만 LLM이 집중하게 되는 현상을 의미한다. Early Token은 Token Prediction에 기여할 수 있는 경로가 Late Token보다 많고, 따라서 LLM은 Input의 앞쪽에만 집중하게 된다.


1. Motivating Examples

저자들은 위의 두 개념을 추론하기 위해 아주 간단한 Task인 Copying, Counting을 예시로 들었다. 각각에 대해 살펴보자.


1.1. Copying

image

저자들은 Google의 LLM인 Gemini에게 1...10의 last element 또는 01...1의 first element를 copy하여 출력하도록 했다. 어떠한 경우에도 정답은 0이어야 한다. 그러나, input의 길이가 길어지면 길어질수록 last element를 예측하는 task에서는 0이 아닌 1을 예측하는 현상이 발생했다. 반면, first element를 예측하는 task에서는 0을 잘 예측하는 것을 확인할 수 있었다. 반면 앞에 “Hint It’s not necessarily a 1, check carefully” 와 같은 Hint를 제시한 경우, 또는 0과 1을 번갈아 제시한 경우에는 0을 잘 예측하는 것을 확인할 수 있었다. 저자들은 이러한 문제가 Representational Collapse로 인해 발생하며, 0이 앞에 있는 경우나 Hint를 제시하거나 번갈아 제시하는 경우에는 Over-squashing으로 인해 그러한 문제가 발생하지 않는다고 주장한다.


1.2. Counting

image

Counting은 위 그림을 통해 쉽게 이해할 수 있다. 숫자가 커지고, Input이 길어짐에 따라 LLM은 수를 잘 세지 못하게 된다. 내용을 생략하였으나, 저자들은 Positional Encoding, Causal Attention Mechanism 없이 Transformer는 Counting이 불가능하다는 것을 수학적으로 증명하였다. 이 또한 Representational Collapse에 의한 문제로, 점점 1에 대한 Representation이 유사해지면서 LLM은 이를 숫자로 세야 한다는 것을 잘 인지하지 못하게 된다. 대신 LLM은 직접 수를 세기보다는 대충 많이 등장하는 숫자를 대답하는 방식을 사용한다. 아래 그림을 통해 LLM이 100을 넘어가면 숫자를 세는 것을 “포기”하고 답으로 100을 말하는 것을 확인할 수 있다.

image


2. Representational Collapse

Representational Collapse는 LLM이 입력 데이터의 정보를 충분히 표현하지 못하는 것을 의미한다. 특히 final token의 representation이 input이 길어지면 길어질수록 하나에 수렴하게 되며, 따라서 LLM은 나중에 나오는 정보를 충분히 이해하기 어렵다. 이를 수학적으로는 다음과 같이 이야기할 수 있다.

image

실제로 Copying, Counting 문제에서 1이 하나 추가되었을 때의 Representation의 차이는 길이가 길어짐에 따라 Positional Encoding의 차이가 줄어들면서 0에 가까워진다. 따라서 LLM은 이러한 차이를 잘 인지하지 못하게 된다. 저자들은 아주 간단히 3개의 Token마다 Comma ,를 넣어 representation distance를 늘리는 방식으로 이를 해결할 수 있었다.

image


3. Over-squashing

Over-squashing을 수학적으로는 다음과 같이 나타낼 수 있다.

image

이 식으로부터 final token의 sensitivity를 알 수 있다. 즉, Final Token $\mathbf{y}_ n$의 Input Token $\mathbf{v}_ i ^ {(0)}$에 대한 sensitivity를 위 식으로 이해할 수 있으며, 결국 $i$가 작을수록, 즉 input sequence 초기에 위치할수록 sensitivity가 높아진다. 이는 LLM이 Input의 Early Token에만 집중하고 해당 정보가 더 잘 보존되는 현상을 설명한다. 극단적인 상황에서, Layer의 수 $L \to \infty$로 가면 $\mathbf{y}_ n$는 $\mathbf{v}_ 1 ^ {(0)}$에만 의존하게 된다.

image

이러한 over-squashing은 LLM에서 관찰된 U-shape effect를 설명하는 데 도움을 준다. U-shape effect란 LLM이 retrieval task를 수행하는 경우 input의 앞쪽과 뒤쪽에 해당 정보가 있는 경우 더 잘 수행하는 현상을 의미한다. 앞쪽의 경우 over-squashing에 의해 정보가 더 잘 보존되기 때문에, 그리고 뒤쪽의 경우 학습 당시 attention mechanism이 주변 정보를 더 잘 받아들이는 recency bias로 인해 더 잘 수행하는 것으로 설명할 수 있다.


💡 Summary

지금까지의 내용을 요약하면 다음과 같다.

  • 저자들은 Copying, Counting과 같은 간단한 문제일지라도 Sequence가 길어지면 LLM이 Task를 잘 수행하지 못한다는 것을 발견하고, 이를 Representational CollapseOver-squashing이라는 개념으로 설명하였다.
  • Representational Collapse는 input sequence가 길어질수록 final token의 representation이 하나로 수렴하게 되어 나중에 나오는 정보를 충분히 이해하기 어렵게 된다는 것이고, Over-squashing는 input sequence의 앞쪽 부분만 충분한 정보를 보존하여 final token에 전달할 수 있고 중간 부분은 정보가 소실되는 현상이다.


📃 Reference


Language 카테고리 내 다른 글 보러가기

댓글 남기기