안녕하세요. 오늘 소개드릴 논문은 AAAI'2021에 게재된 DegAug입니다. 논문 제목에서 알 수 있듯이 Feature representation을 나누고 semantic augmentation을 통해 Out-of-Distribution generalization을 달성한 논문입니다.
About Out-of-Distribution
논문에 대한 내용에 앞서 IID와 OoD의 개념을 말씀드리겠습니다.
Independent, Identical Distribution (I.I.D)→예를 들면 CIFAR10의 train-set과 test-set이 나눠져 있지만 그 둘은 동일한 분포를 가지고 있습니다. 따라서 이 둘은 I.I.D입니다.
어떤 랜덤 확률 변수 집합이 있을 때 각각의 랜덤 확률변수들은 독립적이면서 동일한 분포를 가지는 것을 의미합니다.
Out of Distribution (OoD)(ICML 2021, Stanford) WILDS: A Benchmark of in-the-Wild Distribution Shifts→ 당연하게도 실제 상황에선 OoD가 더 많이 존재합니다.
예를 들어, Medical 응용의 경우 특정 병원에서 얻은 데이터로 학습한 뒤, 다른 병원에 배포되는 경우가 있습니다. 학습 때와 다른 분포를 테스트시 추론하게 되므로 이 경우 OoD입니다.
학습 데이터의 분포를 따르지 않는 데이터를 따르지 않는 데이터를 이야기합니다.
그렇다면 기존의 학습방법으로 OoD을 잘 추론할 수 있을까요? → 그렇지 않습니다!
위 그림은 2019년 페이스북에서 나온Invariant Risk Minimization에 나온 예시입니다.초록색 배경의 소🐮 와모래색 배경의 낙타🐪 로 학습을 시킨 모델이 있다고 가정해봅시다.
이 모델에초록색 배경낙타🐪를 추론시킬 경우 소🐮로 추론한다고 합니다.
즉, 기존의 학습 방법은 데이터간의 가장 큰 공통 특징(배경색)을 가지고 학습 및 추론을 하기 때문에 학습 데이터와 다른 데이터(OoD)은 잘 추론하지 못합니다.
OoD는 언제 문제가 될까요?
(CVPR 2019, Facebook) Does Object Recognition Work for Everyone?
위 그림은 페이스북에서 CVPR'2019에 발표한 논문에서 발췌한 그림입니다.
왼쪽 사진은 상용 클라우드 플랫폼 (Azure, Clarifai, Google, Amazon)의 AI classification 모델에 서로 다른 국가에서 촬영한 비누를 추론시킨 결과입니다 네팔에서 촬영한 사진의 경우 모든 상용 솔루션이 틀리는 것을, 영국에서 촬영한 사진의 경우 대부분의 경우에서 맞는것을 알 수 있습니다.→ 상용 솔루션의 경우 대부분 미국에서 배포하고 있고, 미국을 포함한 수입이 높은 국가에서 수집한 데이터로 학습을 진행합니다. 따라서 데이터를 수집한 국가와 다를수록 즉 OoD의 경우 정확도가 낮아지는 것을 알 수 있습니다.
오른쪽 사진은 한달 수입이 $x$ 인 국가에서 데이터를 수집한 뒤, 상용 솔루션 모델들에 추론시켰을 때 정확도를 나타낸 표입니다. 수입이 높은 국가에서 수집한 데이터인 경우 정확도가 높은 것을 알 수 있습니다.
DegAug
Category vs . Context
이 논문에선 하나의 이미지가 category, context 2개의 정보를 담고 있다고 얘기하고 있습니다
context : 각 이미지가 가지고 있는 환경적인 특징을 의미합니다. 예를 들어 잔디밭 위에 양이 있다면 잔디밭이 context가 됩니다.
category : label을 의미합니다.
따라서 하나의 이미지는 category-context의 쌍으로 이루어져 있습니다. 만약 train-set에 없는 category-context쌍을 기존의 학습 방법으로 학습한 모델의 추론시키면 잘 추론을 못하고, 이 연구는 이를 해결하고자 하는 논문입니다.
Proposed Method
Overall Architecture
자 그럼 이 논문에서 OoD을 어떻게 해결했는지. 즉, OoD generalization을 어떻게 달성했는지 알아보도록 하겠습니다.
Overall Diagram of DegAug
위 그림은 OoD 데이터를 학습하기 위한 이 논문의 구조를 도식화 한 그림입니다
input data를 backbone네트워크에 추론
1에서 얻은 embedding vector를 category feature extractor와 context feature extractor에 각각 추론
gradient간의 코사인 유사도를 감소하는 방향으로 backbone 네트워크의 학습을 진행함으로써 backbone 네트워크가 category와 context를 모두 담은 embedding vector를 추출할 수 있도록 학습이 진행됩니다.
train-set에 없는 데이터를 잘 추론하기 위해 DecAug는 feature-level에서 augmentation을 진행합니다. (Semantic augmentation)
→ 위 수식을 보시면 context feature에 context loss gradient를 더해줌으로써 train-set에 없는 category-context 쌍의 학습을 진행할 수 있습니다.
위의 방법을 통해 얻은 category feature와 context feature를 합쳐 최종적인 classification을 하게 됩니다.
모든 loss function을 정리하면 다음과 같습니다.
Experiments
DegAug는 총 3가지 데이터셋에 대한 실험 결과를 보여주고 있습니다
먼저 colored-MNIST에 대한 실험 결과 입니다.
ERM은 일반적인 학습 방법입니다. DegAug가 가장 높은 정확도를 보이는 것을 확인할 수 있습니다.
PACS 실험 결과입니다.
PACS는 Picture, Art, Cartoon, Scatch 4개의 도메인으로 이루어진 데이터셋으로 3개의 도메인을 train-set에, 나머지 하나의 도메인을 test-set으로 두고 정확도를 측정합니다.
NICO 실험 결과입니다.
Conclusion
Feature augmentation을 통해 OoD generalization을 달성한 DegAug에 대해 알아보았습니다. feature를 augmentation한다는 점에서 흥미로웠지만 OoD generalization을 달성하기 위해 context label이 필요하다는 점이 굉장히 크리티컬한 것 같아 아쉬움이 남습니다. self-supervised learning과 이 연구를 합치거나 StyleGAN과 같이 context를 변경해 DegAug를 진행한다면 더 재밌지 않을까.. 라는 생각을 논문을 읽으며 많이 했던 것 같습니다.
KDST(KIST Data Science Team)는 기계학습을 포함한 데이터와 지능에 관련된 여러 주제에 대해서 연구하는 팀입니다. KDST