본문 바로가기

AI/Vision

DANN : Domain Adversarial Training Neural Network

 

연구 주제 관련 논문을 서치하다가 알게된 고전적인 방법이다. 

이 논문은 2016년에 JMNL 저널에 나온 논문이지만, 아이디어 자체는 간단하여 지금까지도 인용이 많이 되고 있다. 

알아두면 좋을 것 같아서 정리하기로 했다. 

년도별 논문 인용 수 그래프

 

 

기존의 Domain Adaptation(DA) 방법에 GAN의 concept을 일부 도입한 방법이라고 이해하면 될 것 같다. 

 

Traditional 한 DA 방법은, 주로 fixed feature representation을 추출하여 원하는 target distribution에 mapping하는 것이다.

하지만 위 논문은 representation 자체를 바꾸는 방법으로, 한 번의 training을 통해 위 결과를 도출해 낸다는 것이 가장 큰 특징이라고 할 수 있다. 

 

 

 

Domain Adaptation이 뭘까?

가장 간단하게 설명해 보자면, 모델이 학습한 데이터의 분포와 실제로 사용할 때 input으로 받게 되는 데이터의 분포가 너무 달라서 생기는 문제를 해결하기 위해, 두 도메인 간의 분포를 줄이는 방법론이다. 

 

아래 예시에서 왼쪽(MNIST)은 이미지에서 가장 많이 쓰이는 손글씨 데이터이고, 오른쪽(SVHN)은 실제로 일상 생활에서 사용하게 되는 데이터이다. 딱 봐도, 오른쪽 이미지가 훨씬 다양해보인다. 

 

Image classification task에서 왼쪽 이미지로 학습시킨 모델이 오른쪽 이미지에 대해서는 classification 성능이 좋지 않다면, 그 이유는 바로 domain divergence에 있다. 

좌(source data) 우(target data)

 

따라서, 두 도메인 간의 분포 격차를 줄이기 위해서 아래와 같은 방법을 제시했다. 

 

 

 

먼저, 데이터 샘플이 어떤 도메인에서 왔는지 예측하는 classifier를 만든다. ( source 도메인 / target 도메인 )

 

더보기

- x : input 

- y : ground-truth label 

- N : total data samples 

- n : total  source data samples 

- R_Dt : Target rosk 

- eta : classifier X -> Y (0,1) 

source domain에서 학습된 모델이 target domain에서도 잘 작동하는 것이 핵심. 

즉, Target risk를 적게 하면서 target domain에서의 error를 최소화 하는 것이 중요하다. 

Target domain error는 source domain error + domain divergence로 upper bound를 정할 수 있다.

Domain Divergence 

두 도메인 간의 분포를 좁히기 위해서는 먼저 두 도메인 간의 거리를 측정해야 한다. 

여기서, domain divergence라는 개념이 등장하는데, 이는 두 도메인간의 거리를 측정하는 measure이다. 

 

여기서는 Ben-David의 논문에서 H-divergence를 사용한다. (매우 유명하신 분이라고 한다.)

 

H-divergence란 두 도메인을 잘 구분하는 classifier를 얼마나 담을 수 있는지를 뜻하기에 도메인을 구분하는 능력을 칭한다. 실제로 거리의 개념은 아님! 

더보기

- sup : upper bound를 나타내는 수식 

- hypothesis class H  : Binary classifier "eta"들의 집합 

즉, source domain을 1로 판단할 확률 - target domain을 1로 판단할 확률로 정의한다. 

( 예로, source domain을 1로 판단할 확률이 1이고, target domain도 1로 판단할 확률이 1이라면 두 값의 차는 0이 된다. 

즉, classifier가 도메인을 구분하지 못한다면 위 식은 1/2 - 1/2 = 0 이 되고, 도메인을 잘 구분한다면 1-0=1 or 0-1 = -1 로 절댓값과 2배를 취하면 2가 된다.) 

 

더보기

이때, H 안에 있는 모든 classifier들에 대해서만 계산을 진행하는데, 

이런 의문이 들 수도 있다. (사실 나는 안 들었지만...;;)

 

Q. H의 범위를 무한히 늘리면 잘 구분하는 classifier를 H안에 충분히 넣을 수 있지 않을까?

=> The upper bound of the true minimum risk = empirical risk(what we can calculate) + model complexity

H의 범위를 무한히 늘리면 emperical risk (계산 가능) 는 0으로 수렴할 수 있을지언정, model complexity는 증가하여 결국 upper bound가 올라가게 된다. 

따라서, 우리는 충분한 성능을 갖는 classifier를 갖되, 모델의 복잡도가 너무 높지 않은 최적의 값을 찾아야 한다. 

Approximation 

H-Divergence

H-Divergence 에서 H가 symmetric하다면, 아래 식과 같다. 

Empirical H-divergence

I[a]는 a가 true면 1 아니면 0인 indicator function

위 식은 정확한 계산이 어렵기 때문에, 아래 식으로 근사시킨다. 

Proxy A Distance (PAD) 

ϵ : classification error

위 식에서 min( ) 부분이 아래 식에서 입실론으로 치환되어 있는 꼴이다. 

즉, sample의 출처가 source domain인지 target domain인지 classifier가 정확히 구분할 수 있으면 ϵ=0 이다.

 

즉, 도메인이 달라지더라도 일반화 할 수 있는 모델을 만들기 위해서는, 
모델 학습 단계에서 label classifier의 성능은 좋게, domain classifier의 성능은 나빠지도록 해야한다.

 

Domain-Adversarial neural networks(DANN) 구조

더보기
  • feature extractor(Green)
  • label predictor(Blue)
  • domain classifier(Red)

논문의 목표는 앞의 feature extractor Gf 가 최대한 source와 target에 동시에 포함되는, domain의 특성을 지우고 class 분류에만 쓰인 특징을 뽑게 하는 것이다.

Gradient Reversal Layer 

그러기 위해서, GRL을 추가하여 역전파 당시 -lamda 값을 곱하여 부호를 바꿔 전달한다. 

이는 domain classifier가 헷갈리도록 하기 위함이다. 

 

이를 통해, 1. 학습 당시의 classificaion error를 줄이고
2. Train/Test domain classification error를 최대화 할 수 있다. 

 

Experiments

 

 

Two - moon dataset 을 사용하여, classification 문제를 해결해 보았다. 

더보기
  • o : Upper half-moon 
  • + : Down half-moon
  • Blue : Source domain 
  • Red : Target domain
    (Made by rotating the Blue one.)

 

더보기
  • Red : o predict / Blue : + predict
  • Green : Ground truth
  • sNN : shallow Neural Network
  • sDANN : shallow Domain Adversal Training Neural Network 

1행에서, 왼쪽 반원은 o로 예측하고 오른쪽 반원은 +로 예측해야 한다. 

2행과 3행에서는, 초록색 부분이 ground-truth 이므로 초록색이 많이 보일 수록 예측을 정확하게 한 것이다. 

 

따라서, sDANN의 성능이 더 좋다고 할 수 있다. 

 

References 

1.https://arxiv.org/pdf/1505.07818


2. https://jayeon8282.tistory.com/7 

 

[논문 리뷰] Domain Adversarial Training of Neural Networks

이번 방학에는 domain adaptation 논문들을 정리하여 리뷰해 보려고 한다. 심심할 때 리눅스 쉘에서 자주 쓴 명령어도 정리해 기록해 보고자 한다. 지금 2020년도 기준으로 이번에 리뷰하는 [Domain Advers

jayeon8282.tistory.com

3. https://velog.io/@juyeon048/%EB%85%BC%EB%AC%B8-%EA%B3%B5%EB%B6%80-DANN-Domain-Adversarial-Training-of-Neural-Networks#2domain-adaptationda-%EC%9D%B4%EB%A1%A0%EC%A0%81-%EB%B0%B0%EA%B2%BD 

 

[논문 공부] DANN : Domain-Adversarial Training of Neural Networks

(paper) https://arxiv.org/abs/1505.07818 2016년 JMLR(Journal of Machine Learning Research)발표 논문. 이 논문은 Train과 test의 data distribution이 다른 경우, Domain

velog.io