AI 논문리뷰 - Vision

Research Paper Review : Learning predictable and robust neural representations by straightening image sequences

study_love 2026. 1. 29. 19:22

 오늘은 밑에 논문에 대해서 리뷰해보려고 한다.

https://arxiv.org/abs/2411.01777

 

Learning predictable and robust neural representations by straightening image sequences

Prediction is a fundamental capability of all living organisms, and has been proposed as an objective for learning sensory representations. Recent work demonstrates that in primate visual systems, prediction is facilitated by neural representations that fo

arxiv.org

 

Introduction

 모든 영장류는 미래를 예측하며 행동한다. 이 과정에서 RGB space에서 highly non-linear한 시각 입력을 latent space에서 보다 linear한 표현으로 변환하고, 이를 바탕으로 미래를 예측한다고 알려져 있다. 본 논문은 이러한 메커니즘을 딥러닝을 통해 증명한다.

Methods

 이 논문에서 하고자 하는 핵심은 하나의 image encoder를 학습하는 것이다. RGB space에서 highly non-linear한 frame sequence, 즉 비디오를 image encoder에 통과시키면, 그 결과로 얻어지는 latent vector sequence가 보다 linear한 형태를 띠도록 만드는 것이 목표다. 저자들은 이를 통해 linear extrapolation만으로도 미래 state를 예측할 수 있을 것이라고 믿는다.

dataset 준비 

 이 논문에서 사용한 dataset은 여러 개의 sequence로 구성되어 있다. 이러한 sequence들은 하나의 image를 기반으로, 일정한 속도를 가지는 변형을 적용해 생성된다. 구체적으로는 rotation, translation, zoom-in과 같은 변형을 시간에 따라 연속적으로 적용함으로써, 단일 image로부터 video sequence를 만들어낸다.

 

모델 훈련 

 한 batch는 위의 그림과 같이 서로 다른 n개의 sequence로 구성된다. 각 sequence를 vision encoder에 통과시킨 뒤, 얻어진 latent representation에 대해 다음의 세 가지 loss를 적용한다.

 

 첫 번째는 straightness loss이다. 이 loss는 시점 t 사이의 변화 방향이, 사이의 변화 방향과 같아지도록 강제한다. 즉, local한 구간에서는 latent trajectory가 linear하게 변화해야 한다는 의미다. 이러한 특성 때문에 straightness loss는 하나의 sequence 내부에서만 적용된다.

 

 하지만 이 loss만 사용할 경우, 모든 representation이 하나의 constant 값으로 수렴해버리는 collapse 현상이 발생할 수 있다.

 

 두 번째는 variance loss이다. 이 loss는 sequence 구분 없이, 한 batch 안의 모든 데이터를 하나의 population으로 보고 적용된다. batch에 포함된 모든 데이터는 각 feature에 대해 최소 1 이상의 분산을 가져야 하며, 이를 통해 전체 representation이 전반적으로 퍼지도록 강제한다. 다만 이 제약은 global하게 적용되기 때문에, 같은 sequence에 속한 데이터들끼리는 여전히 서로 가까울 수 있다.

 

 마지막은 covariance loss이다. 이 loss는 서로 다른 feature들 사이의 correlation을 0으로 만들어, 각 feature가 서로 다른 정보를 capture하도록 유도한다. 일반적으로 covariance loss에 대각 성분을 1로 만드는 항을 함께 포함하면 variance loss와 유사한 역할을 하게 된다. (특정 feature에 대해서 모든 데이터가 동일한 feature 값을 가지면 batch 평균을 뺐을 때, 해당 feature값이 0이 되어 correlation matrix를 구했을 때 대각 성분이 0이되어 loss가 세게 걸림) 이 논문에서는 variance loss를 별도로 분리해 제시했는데, 이는 각 loss의 역할을 명확히 보여주기 위함이거나, 혹은 weighting을 독립적으로 조절하기 위함으로 보인다.

 

 세 가지 loss를 종합하면 다음과 같은 목표를 가진다고 정리할 수 있다.
서로 다른 sample들은 전반적으로 멀어져야 하지만, 같은 sequence 안의 sample들은 linear한 trajectory를 따라야 하며, 동시에 각 feature는 서로 다른 역할을 수행해야 한다.

 

vision encoder architecture

 vision encoder의 architecture는 사용한 dataset에 따라 다르게 설정된다. MNIST를 이용한 실험에서는 비교적 단순한 CNN을 사용해 학습을 진행했고, CIFAR-10을 이용한 실험에서는 ResNet을 backbone으로 하고 그 뒤에 3-layer MLP를 추가한 구조를 사용했다고 한다.

experiment results

 먼저 인상적인 점은, 같은 숫자끼리 모이도록—즉 같은 sequence에 속한 샘플들이 서로 가까워지도록—직접적인 loss를 준 적이 없음에도 불구하고, 결과를 보면 같은 숫자에 해당하는 sample들이 자연스럽게 모여 있다는 것이다. 논문에서는 이러한 현상이 straightening 과정에서 자연스럽게 발생한다고 설명한다.

 그 이유를 살펴보면, 훈련 과정 중에 모델은 straightening을 하기 위한 규칙을 찾아내게 되는데, 이를 위해서는 결국 객체의 identity(class)를 찾아야 할 것이다. 이 과정에서 identity가 같은 객체들끼리 representation이 너무 멀 수는 없을 것이고, 반대로 말하면 identity가 다른 객체들이 서로 가까워지는 것도 어렵다. 이로 인해 같은 sequence 내부에서는 representation이 상대적으로 가까워지게 되고, 그 결과 발생하는 분산의 부족은 variance loss를 통해, 서로 다른 sequence에 속한 sample들 사이의 거리로 보완되는 것으로 보인다.

 

 또한 다음 frame을 예측한 뒤, 그 예측된 frame을 이용해 downstream task를 수행해본 결과, 실제 frame을 사용했을 때와 맞먹는 성능을 보였다.

 

 자세히 보면 좋은 점 : 사람 눈으로보기에 linear한 특성이 없는 것도 배울 순 없다. 

 밑의 그림을 보면, 을 이용해 를 예측할 때, 에서 로 갈수록 숫자가 커지고 있었다면 t 역시 계속 커지고 있을 것으로 예측하는 것을 확인할 수 있다. 이는 매우 자연스러운 결과이며, 오히려 이 상황에서 값이 작아질 것이라고 예측하는 편이 더 말이 되지 않는다.

 이와 같이 상식에 어긋나는 예측까지 straightening이 학습할 수 있다고 기대해서는 안 된다. straightening은 변화의 방향과 경향을 보존하는 역할을 할 뿐, 관측된 dynamics와 정반대의 변화를 만들어내는 메커니즘은 아니다.

 

Comment

 개인적으로 비디오를 본 image encoder가 이상적으로 가져야 할 특징은, video sequence를 latent space에 그렸을 때 만으로도 를 예측할 수 있도록 임베딩이 학습되는 것이라고 생각한다. 그런 의미에서 이 논문은 그 예시를 아주 잘 보여준다고 느꼈다.

 다만 이 논문을 통해 또 하나 알 수 있는 점은, 사람의 눈으로 보아도 예측 불가능한 것은 딥러닝 역시 예측하기 어렵다는 것이다. 가만히 굴러가는 공과 같은 단순한 영상에서 straightening을 학습하는 것은 큰 문제가 없을 것이다. 하지만 사람이 춤을 추는 비디오를 보고 straightening loss를 준다고 해서 과연 의미 있는 학습이 될까? 개인적으로는 거의 안 될 것이라고 확신한다.

 곰곰이 생각해보면, 사람 역시 모든 것을 예측하면서 세상을 보지는 않는다. 춤추는 사람의 다음 동작을 정확히 예측할 수 있을까? 사람에게도 쉽지 않은 일이다. 그렇다면 모든 비디오에 대해 일반화가 잘 되는 representation을 학습하기 위해, 다음 frame 예측을 강한 제약으로 거는 것이 과연 적절한가라는 의문이 들었다. 오히려 진짜 foundation model을 목표로 한다면, 다음 frame 예측을 지나치게 강하게 요구하는 것은 한계가 있을 수도 있겠다는 생각이 들었다.

 아래는 대표적인 비디오 데이터셋인 Kinetics-400의 일부 예시이다. 이런 sequence들을 보면, straightening은커녕 다음 frame 예측 자체도 매우 어려워 보인다.