논문 리뷰

[논문 리뷰] Dynamic Curriculum Learning for Imbalanced Data Classification

컴퓨터비전 LCK 2024. 4. 10. 20:50

1.Introduction

저번 포스팅에서와 같이 long tail 데이터셋에서의 예측 성능을 높이기 위한 기법을 제시된 논문입니다.
 
차이점은 이 논문에서는 ImageNet과 같은 데이터셋과는 달리 CelebA 데이터셋을 중심으로 설명하고 있습니다.(RAP와 CIFAR-100 데이터셋으로 실험한 결과도 제공합니다.)
 
CelebA 데이터셋은 여러 유명인들의 얼굴 이미지인데, 총 202,599개의 얼굴 이미지로 구성되며 10,177명 가량의 사람들(명당 20장 정도)이 데이터셋을 이루고 있습니다.
 
주목할 점은 각 사진마다 40개의 특성 정보들(Smiling, Male, Young, Big nose,...)이 있고, 1또는 0으로 라벨링 되어있다는 것입니다.
 
따라서 이 데이터셋은 각 사람에 대한 샘플 수는 균형적이라고 볼 수 있으나 각 특성을 따져보았을 때는 불균형적인 특성이 많습니다. (ex. 'Bald' 특성의 경우 positive sample의 수가 훨씬 적고, negative sample의 수는 많습니다.)
 
특성 공간을 나누는 feature embedding과 클래스 간의 classification을 모두 고려합니다.
 
 

3. Method

 
논문에서 제시하는 기법을 정리해보면 다음과 같습니다.

논문에서는 두가지를 스케줄러를 사용하여 curriculum learning을 구현합니다.(curriculum learning이란 쉬운 샘플, 즉 정보가 많은 샘플을 트레이닝 초기에 학습하고 에포크가 진행될수록 어려운 샘플들을 학습하는 것과 같이 훈련 진행에 따른 learning focus가 달라지는 훈련방식입니다.)

(1) sampling scheduler

논문에서 제시하는 sampling scheduler를 활용하여 데이터 분포를 에포크가 진행됨에 따라 불균형 → 균형으로 만들어 줍니다. 트레이닝 초기 단계에서는 모델이 head 클래스에 대해 학습하고 학습이 진행될수록 tail 클래스쪽을 훈련할 수 있도록합니다.

(2) loss scheduler

Classification을 위한 손실함수와 feature embedding을 위한 손실함수 두가지를 합친 손실함수를 정의합니다.

① 특성에 따른 feature embedding을 잘하기 위한 손실함수 → metric learning loss
② 분류(decision boundary 설정)를 잘하기 위한 손실함수 → cross entropy

 
이제 하나하나 알아보도록 할텐데요.
 
먼저 샘플링 전략을 알아보겠습니다.
 
 
3.1 Scheduler Function Design
 

모두 에포크에 따라 1에서 출발해서 0으로 가는 함수들입니다. 이 함수들을 스케줄러 함수라고 합니다.
 
 
3.2 Sampling Scheduler
 
다음으로는 데이터 분포를 정의합니다.

 
가장 minority한 클래스의 샘플 수를 공통 분모로 적용하고 분자로는 각 클래스의 샘플수가 들어가게 됩니다. 이를 오름차순 정렬한 것을 기호 D_train으로 표기합니다.
 

 
D_train의 sampling scheduler를 제곱하여 줍니다. 샘플링 스케줄러는 1→0 이므로 데이터 분포 D_target(l)은 마지막 에포크에서 1:1:1:...:1, 즉 모두 같은 분포를 같게 됩니다.
 
이 데이터 분포에 따라 샘플링 확률이 결정되게 됩니다.
 
또 이것을 활용하여 손실함수 Dynamic Selective learning(DSL) loss 를 정의합니다.

DSL(Dynamic Selective learning) loss

 
D_target,j(l) < D_current,j 일때 웨이트가 0이 되어 손실함수에 반영이 되지 않습니다.
 
현재 배치에 뽑힌 샘플들로 산출한 디스트리뷰션(D_current,j)이 에포크에서 정해준 디스트리뷰션(D_target,j(l))을 넘는다면 손실함수에 포함하지 않는다는 것을 의미합니다.
 
이를 통해 tail클래스의 손실함수 반영치가 에포크 진행에 따라 점진적으로 늘어감을 알 수 있습니다.
 
 
3.3 Metric Learning with Easy Anchors
 
feature embedding을 위한 손실함수를 정의합니다.
 

metric learning loss(triplet loss)

앵커와 negative sample 사이의 유클리디안 거리를 벌리고, positive sample사이의 거리는 좁히는 알고리즘입니다.
 
앵커와 positive sample사이의 거리에서 마진 m_j를 더한 거리보다 큰 차이가 앵커와 negative sample 사이에서 나야한다는 것을 수식적으로 이해할 수 있습니다.
 
이를 사용하여 triplet pairs 에 대하여 합한 값을 평균내준 함수가 CRL loss입니다. 음수값이 나올 수 있어 0과 해당값 중에 max를 취한값을 최종 반환합니다.

metric learning loss(triplet loss)

하지만 이 손실함수에는 단점이 있는데 모든 샘플을 앵커로 적용한다는 것입니다. 이렇게 했을 때 발생하는 문제점에 대해서 살펴보겠습니다.

 
hard positive가 앵커로 적용되었을 때 문제가 발생합니다. 결정경계에 가까운 hard positive sample이 앵커로 작용하여 negative sample을 밀고 positive sample을 끌어오게 되는데 이것이 매끄럽지 못한 결정 경계를 만들게 합니다.
 

 
논문에서는 이러한 문제를 해결하기 위해서 triplet loss를 변형하여 easy positive 샘플들만을 앵커로 사용하는 손실함수를 제안합니다.

Triplet loss with Easy Anchors (TEA)

 
 
3.4 Loss Scheduler
 
이렇게 제안된 DSL loss 와 TEA loss 를 결합하여 DCL 손실함수가 정의됩니다.
 
feature embedding을 담당하는 TEA loss에서, classification을 담당하는 DSL loss로 curriculum learning이 진행됩니다.
 

Dynamic Curriculum Learning (DCL) loss

 
(p는 0이상 1이하의 하이퍼 파라미터)
 
로스 스케줄러 f(l)은 에포크에 따라 1+ ε 로 시작하여 ε 로 끝납니다. 
 
이를 통해 학습 초반에는 feature embedding에 힘을 주고 학습이 진행될수록 classification이 손실함수에서 차지하는 비중이 커진다는 것을 알 수 있습니다.
 
 

4. Experiments

 
4.1 Datasets
 
1) CelebA: 총 202,599개의 얼굴 이미지, 10,177명 가량의 사람들(명당 20장 정도), 40개의 이진 특성
2) RAP: 총 41,585개의 이미지, 각 이미지에 72개의 속성이 주석
3) CIFAR-100: balanced dataset, 총 60,000개의 이미지, 100개의 클래스
 
4.2 Evaluation Metric
 

 
(P_i는 positive, TP_i는 true positive, N_i는 negative, TN_i는 true negative)
 
CelebA와 RAP 데이터셋에서는 class-balanced mean accuracy을 모든 binary classification 에서 적용한 평균값으로 정확도를 평가합니다.
 
CIFAR-100
 
4.3 Experiments on CelebA Face Dataset

위는 다른 논문에서 제시된 방법들과 DCL의 성능지표를 비교한 표입니다. 각 특성에 대한 이진 분류 정확도를 나타내었습니다. 붉은색이 가장 높은 정확도, 푸른색이 그 다음으로 높은 정확도를 보인 기법입니다.
 
또한 불균형 정도가 높은 특성들에서 DCL 기법이 다른 기법들보다 확연 좋은 성능을 냄을 알 수 있습니다.
 

Backbone으로는 DeepID2를 채용하였다. SS(Sampling Scheduler), TL(TEA Loss), LS(Loss Scheduler)가 하나씩 추가될수록 모델의 정확도가 상승하였습니다.
 

스케줄러 함수는 Convex 함수가 가장 높은 성능을 보였습니다.
 
4.4 Experiments on RAP Pedestrian Dataset

 
4.5 Experiments on CIFAR-100 Dataset

RAP과 CIFAR-100 dataset 으로 실험한 경우에서도 좋은 성능을 보였습니다.