오늘은 I-JEPA라는 논문을 공부해보려고 한다. 이 논문에서 제안한 아키텍처가 훈련이 안정적으로 잘 된다는 점이 특히 인상적이다. 직관적으로는 학습이 쉽지 않을 것처럼 보이는데도 실제로는 잘 수렴한다는 점이 꽤 신기하게 느껴졌다. 이 논문을 통해 이런 형태의 구조 역시 딥러닝으로 충분히 학습 가능하다는 사실을 다시 한 번 실감하게 되었고, 그런 점에서 I-JEPA를 공부하는 것은 상당히 의미 있는 일이라고 생각한다.
Background
Computer Vision 분야에서 self-supervised learning은 크게 두 가지 계열로 나눌 수 있다. 하나는 DINO로 대표되는 invariance-based methods이고, 다른 하나는 MAE로 대표되는 generative methods이다.
Invariance-based methods는 서로 다른 view 간의 표현을 일치시키는 방식으로 학습되며, 비교적 high-level semantic 정보를 잘 포착하는 장점이 있다. 그러나 이러한 방법들은 사람이 설계한 inductive bias에 크게 의존하기 때문에, 특정 downstream task에서는 오히려 성능이 제한될 수 있다는 한계도 함께 가진다.
반면, generative methods는 입력을 복원하거나 예측하는 과정을 통해 학습되며, 상대적으로 inductive bias가 적어 이미지의 본질적인 구조를 더 직접적으로 학습한다는 인상을 준다. 이러한 관점은 인지과학적으로도 어느 정도 타당성이 있는 것으로 알려져 있는데, 인간의 인지 능력 역시 감각 입력(sensory input)을 바탕으로 다른 감각 입력을 예측하는 과정과 밀접하게 연관되어 있기 때문이다.
다만 generative methods는 raw image를 복원해야 한다는 특성상, 저수준(low-level) 시각 정보에 많은 용량을 할애하게 되고, 그 결과 high-level semantic 정보를 충분히 학습하지 못하는 경향이 있다. 이러한 한계를 극복하기 위해, 본 논문에서는 generative framework 내에서도 high-level semantic 표현을 효과적으로 학습할 수 있는 방법을 제안한다.
Method

I-JEPA는 MAE와 비교했을 때 크게 두 가지 중요한 차이점을 가진다.
첫 번째는 masking 방식이다. MAE와 마찬가지로 context(보이는 부분)로부터 masked part를 예측하는 구조를 유지하지만, masking을 pixel 단위가 아니라 공간적으로 연속된 pixel들로 구성된 block 단위로 수행한다. 또한 context는 전체 이미지에서 masked block을 단순히 제거한 결과가 아니라, 전체 이미지에서 context 영역을 먼저 sampling한 뒤, 그 안에서 masked block을 제거한 부분으로 구성된다. (그림에서 context와 target에 겹치는 부분이 있는 것은 저자들이 설명을 쉽게 하려고 하다보니 실수한 것 같다.)
두 번째 차이점은 예측 대상의 수준이다. MAE가 masking된 영역의 raw pixel을 직접 복원하는 반면, I-JEPA는 pixel level이 아닌 latent space에서 masking된 부분을 예측한다.
이러한 설계 선택에는 명확한 이유가 있다.
우선, 이러한 block-based masking 방식을 사용하면 context 영역 내에서 서로 연결된 pixel들이 더 많이 유지되므로, 모델이 local semantic 정보를 보다 잘 파악한 상태에서 나머지 영역을 예측할 수 있다. 또한 context 영역을 결정할 때 전체 이미지에서 먼저 sampling을 수행하는 이유는, context가 과도하게 커지거나 정보가 지나치게 많아져 문제가 지나치게 쉬워지는 것을 방지하기 위함이다.
또한 예측을 latent space에서 수행함으로써, downstream task에는 크게 필요하지 않지만 raw pixel 복원에는 필수적인 불필요한 low-level detail을 embedding에서 자연스럽게 제거할 수 있다. 그 결과, 표현은 보다 추상적이면서도 semantic한 정보에 집중하게 된다. 심지어는, 불필요한 low-level detail을 제거함으로써 학습의 수렴속도도 빨라질 수 있다.
Detail : How to make Context & Masking target
그렇다면 이제 이미지를 block 단위로 masking하는 방식에 대해 살펴보자. I-JEPA에서는 총 4개의 masking block을 사용하며, 모델은 각 block에 대해 예측을 수행한다. 각 block은 가로세로 비율이 (0.75, 1.5) 범위 내에서 결정되며, 전체 이미지 크기에 대해 (0.15, 0.2)의 scale을 갖는다.
한편, seen part에 해당하는 context block은 전체 이미지에서 (0.85, 1) 범위의 scale을 가지는 단일 block을 먼저 sampling한 뒤, 그 안에서 masking block과 겹치는 영역을 제거하여 구성된다.
이에 대한 직관적인 예시는 아래 그림에 제시되어 있다.

Detail : Model Architecture & How to train?
여기까지 잘 따라왔다면, context로 부터 각각의 masking block을 latent space에서 맞춘다는 것까지는 이해했을 것이다. 이제, 조금 더 자세히 모델의 구조를 보도록 하자.

그림을 보면, 학습에 사용되는 네트워크는 총 세 가지로 구성되어 있다. 바로 context encoder, target encoder, 그리고 predictor이다. Architecture는 모두 ViT를 사용한다.
Context encoder는 입력으로 주어진 context 영역을 encoding하여 context embedding을 생성하며, predictor는 이 context embedding과 함께 masked block의 위치 정보를 입력으로 받아 해당 block에 대응하는 latent representation을 예측한다. 한편, target encoder는 모델이 맞춰야 할 정답 latent를 생성하는 역할을 한다.
이 구조에서 자연스럽게 알 수 있는 점은, context encoder와 target encoder는 동일한 표현 공간을 가져야 한다는 것이다. 즉, 두 encoder는 구조적으로 동일해야 하며, 같은 의미의 latent representation을 생성하도록 설계된다. 가장 간단히 생각하면, 그냥 동일한 net을 쓰고, weight도 공유하면 된다. 하지만 이렇게 되면 문제가 하나 생기는데 그것이 바로 representation collapse이다. 아무 제약 없이 학습을 진행하면, 모델은 모든 입력에 대해 동일한 embedding을 출력하도록 수렴할 수 있고, 이 경우 예측과 타깃이 항상 일치하게 되어 loss가 0이 되는 trivial solution에 빠지게 된다. 따라서 이러한 collapse를 방지하기 위한 장치가 반드시 필요하다.
I-JEPA에서는 이를 해결하기 위해, context encoder와 predictor는 gradient descent로 학습하는 반면, target encoder는 context encoder의 가중치에 대해 EMA(Exponential Moving Average)를 적용하여 업데이트하는 방식을 사용한다. 이와 같은 비대칭적 학습 방식이 collapse를 효과적으로 방지한다는 점은 직관적으로 이해하기 쉽지는 않지만, 기존의 여러 연구들에서 실험적·이론적으로 그 타당성이 검증되어 왔다고 알려져 있다.
Experiment
I-JEPA로 학습한 Embedding을 여러가지 downstream task에 적용해본 결과이다. 다른 모델들에 비해서 성능이 우수한 것을 알 수 있다.


그리고 RCDM에서 제시한 방법을 통해서 I-JEPA embedding을 visualization해 본 결과, embedding이 실제로 low level detail에 치중하지 않은 유의미한 semantic 표현을 학습하고 있다는 것을 알 수 있었다.

Additional Knowledge

이 논문에서는 self-supervised learning의 서로 다른 접근 방식을 구현하기 위한 model architecture를 분류하고, 각 범주에 대해 명확한 naming을 제안한다. 이를 정리한 내용을 아래에 옮겨 적어본다.
먼저, DINO와 같이 서로 다른 view에서 동일한 object를 관측했을 때, 두 view에서 추출된 embedding이 동일하도록 학습하는 구조를 Joint-Embedding Architecture라고 정의한다.
다음으로, MAE와 같이 partial observation을 latent space로 인코딩한 뒤, masking된 영역에 대한 힌트 (예: masking된 patch의 위치 정보)를 함께 사용하여 raw observation을 복원하는 구조를 Generative Architecture라고 부른다.
마지막으로, I-JEPA와 같이 partial observation으로부터 얻은 latent와 masked observation에 대한 힌트 z를 입력으로 받아, raw observation 자체가 아니라 그 latent representation을 예측하는 구조를 논문에서는 Joint Embedding Predictive Architecture (JEPA)라고 명명한다.
Comment
앞으로 저렇게 latent에서 예측을 하고 싶을 때, JEPA를 사용하는 것이 유용할 수 있겠다는 생각이 든다.