Deeper Learning

GAN: Generative Adversarial Nets 본문

AI/Deep Learning

GAN: Generative Adversarial Nets

Dlaiml 2022. 3. 28. 14:33

Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio, Universite de Montreal (2014.06)

 

Abstract

  • 데이터 분포를 파악하는 generative model $G$와 sample이 training data에서 온 것인지 생성된 것인지 판별하는 discriminator $D$를 적대적인 프로세스로 동시에 학습하는 생성모델 프레임워크를 제시
  • $G$는 $D$가 제대로 판별하지 못하도록 학습되는 minimax two-player game
  • $G$와 $D$는 MLP로 구성하였으며 역전파를 통해 학습이 가능한 모델
  • Markov chain 또는 approximate inference를 학습, 추론과정에서 사용하지 않으며 quantitative, qualitative evaluation으로 제시한 프레임워크의 가능성을 증명

1. Introduction

  • 딥러닝은 분류 task에서 relu와 같은 piecewise linear units을 사용하여 gradient vanishing과 같은 문제를 완화시키고 성공적인 학습을 이룸
  • 하지만 Deep generative model의 경우 MLE나 다른 접근법의 적용에서 마주치는 intractable probabilistic computations 근사에서 어려움이 있으며 piecewise linear units의 장점을 활용하기도 쉽지 않음
  • 이러한 문제점을 피해갈 수 있는 생성모델 추정 방법론을 제시
  • 제시한 적대적 네트워크 프레임워크에서 generative model은 model 분포와 data 분포를 판별하는 discriminator와 적대적인 관계를 가진다.
  • 마치 generative model은 위조지폐범, discrminative model은 경찰과 같은 관계로 둘의 경쟁으로 generative model은 판별 불가능한 진짜 같은 sample을 생성하도록 학습된다.
  • random noise를 input으로 sample을 생성하는 MLP로 구성된 generative model, sample을 판별하는 MLP로 구성된 discriminative model을 사용하는 case를 adversarial nets으로 명명

2. Related work

  • explicitly probability distribution defining 대신 generative machine이 목표로 하는 분포로부터 sample을 뽑아내도록 학습하는 method가 존재
  • 위 method의 경우 역전파로 학습이 가능하다는 장점이 있음
  • denosing auto-encoders 그리고 이를 확장한 형태의 generative stochasitc network(GSN) framework은 모두 parameterized Markov chainn을 정의 (generative Markov chain의 1 step으로 machine의 parameter를 학습)
  • 제시한 adversarial nets은 generation step에서 feedback loop가 필요없기 때문에 sampling에 Markov chain을 사용하지 않는다
  • adversarial nets은 relu와 같은 piecewise linear units의 역전파의 성능 측면의 장점을 활용할 수 있으나 feedback loop 내의 unbounded activation으로 인한 문제가 존재
  • 역전파로 학습하는 생성모델의 최신 연구는 VAE, stochasitc backpropagation 등이 있다

3. Adversarial nets

  • generator의 분포 $p_g$가 data $x$에 대한 분포를 학습하기 위해 input noise variables $p_z(z)$의 prior를 정의한다
  • data space로의 매핑: $G(z;\theta_g)$에서 $G$는 parametes $\theta_g$를 가지는 MLP로 구성되어있다
  • $D(x;\theta_d)$는 $x$가 $p_g$가 아닌 data에서 나왔을 확률(scalar)이 output

  • $G$와 $D$는 위와 같은 value function을 사용하여 학습된다

  • 위 그림을 보면 1번 지점에서는 생성 결과가 좋지 않음에도 적은 graient가 발생하며 2번 지점에서는 생성 결과가 1번 지점보다 훨씬 좋음에도 불구하고 더 큰 gradient가 발생, 초기 학습이 더딘 문제를 해결하기 위해 $\log(1-D(G(z)))$를 minimize하지 않고 $\log(D(G(z))$를 maximize하는 방식을 사용

  • 검정선: data 분포, 파랑선: discriminator decision, 초록선: generative distribution $p_g$
  • (b): discriminator가 학습, (c): Generator 학습, (d): 여러 스텝 이후 Nash equilibrium 도달

4. Theoredical Results

  • $G$는 확률분포 $p_g$를 직접 정의하는 것이 아닌 $G$가 만들어낸 samples $G(z)$가 특정 지표를 만족하도록 학습하면서 implicit하게 정의

  • 위 알고리즘을 사용하면 충분한 capacity와 time이 주어진다면 $p_{data}$의 좋은 estimator로 수렴

4.1 Global Optimally of $p_g$ = $p_{data}$

  • $G$가 fixed 상황에서, optimal discriminator $D$는 아래와 같다

  • 둘다 0이 아닌 실수 a,b에서 $y=a\log(y)+b\log(1-y)$는 $\frac{a}{a+b}$에서 최댓값을 가짐
  • $D$의 학습 목표는 조건부 확률 $P(Y=y|x)$를 추정하기 위한 maximizing log-likelihood로 해석할 수 있다. optimal $D^*$를 가정하고 다시 수식을 써보면 아래와 같다

  • Theorem 1: C(G)의 global minimum은 오직 $p_g = p_{data}$일 때 얻을 수 있으며 이때 C(G)의 값은 -log4

  • 익히 알고있는 Kullback-Leibler Divergence 식을 이용하여 식을 전개해보면 위처럼 $p_{data},p_g$의 Jensen-Shannon divergence를 줄이는 것이 GAN의 목적함수임을 알 수 있다 (자세한 전개는 생략, KL식 대입이 전부인 간단한 과정)
  • Jensen-Shannon divergence는 0이상의 값을 가지기 때문에 global minimum $C(G) = -\log(4)$이며 이때 $p_g = p_{data}$가 유일한 solution (생성모델이 data의 분포를 완벽하게 모방)

4.2 Convergence of Algorithm 1

  • 지금까지 설명한 세팅으로 ($G, D, V(G,D))$ 학습을 진행할 경우 $p_g$가 $p_{data}$로 적은 update로도 수렴한다
  • 증명의 골자는 아래와 같은 목적함수가 $\theta_g$에 대해서 convex하다는 것

  • 실제로는 $p_g$가 아닌 parameter set $\theta_g$를 optimize하기 때문에 parameter space에 critical point가 여럿 존재하지만 MLP의 성능은 뛰어났다 (MLP는 이론적 보장이 부족하더라도 reasonble model)

5. Experiments


6. Advantages and disadvantages

장점

  • 학습 중 추론이 필요없음
  • Markov chain이 필요 없음
  • 결과가 sharp

단점

  • $p_g(x)$의 explicit representation의 부재
  • 학습이 어렵다


7. Conclusions and future work

Future work

  • Conditional generative model
  • generator의 학습이 끝난 후 x → z를 추정하는 auxiliary network 학습을 통한 approximate inference 학습
  • MP-DBM의 stochastic extension
  • Semi-superviesed learning (discriminator)
  • Efficiency improvements

정리 & 후기

  • 적대적인 프로세스 기반 생성모델인 GAN을 제시
  • 데이터 분포와 생성 데이터 분포의 Jensen-Shannon divergence를 줄이도록 학습
  • 이전에 보았던 논문이라 논문의 흐름만 가볍게 기술

Reference

[0] Ian J. Goodfellow at al. (2014). "Generative Adversarial nets". https://arxiv.org/abs/1406.2661

 

 

Comments