Variational Inference

본 포스팅은 연세대학교 DSL 학회의 베이지안 인퍼런스 스터디에서 손진원 선배님의 강의를 정리한 내용이다. 관련내용에 관한 손진원 선배님의 깃헙은 링크를 통해 들어갈 수 있다.


1. 개요

베이지안 통계학에서 컴퓨팅을 통해 모수 분포를 찾아내는 방법은 1. Acceptance-Rejection Method와 MCMC (;Markov Chain Monte Carlo) 기반의 2. Metropolitan-Hasting Algorithm, 3. Gibbs Sampler 등이 있다. 이에 대해서는 차후에 포스팅을 할 예정이다.

상기에 거론한 방법은 모수 분포를 Sampling을 통해서 찾아내는 방법이다. 반면 이번에 포스팅할 Variational Inference는 모수분포를 Optimization을 통해서 찾아내는 방법이라고 할 수 있다.


Posterior Distribution $P(Z \mid X)$에 대해서 MCMC 기반의 방법들은 Sampling을 기반으로 비교적 정확하게 모수 분포를 뽑아낼 수 있다. 그러나 컴퓨팅 코스트가 굉장히 무거운 고로 목표로 하는 분포를 찾을 때 까지 오래걸린다.

Variational Inference는 Posterior Distribution을 좀 더 포착하기 쉬운 분포로 근사시키는 것을 목표로 하며 두 분포사이의 거리를 최소화 시키는 (최적화 시키는) 것을 목표로 한다. 이러한 방식은 비록 Global Optima를 보장하진 못하지만 MCMC 방식에 비해서 복잡한 분포도 훨신 빠르게 찾아낼 수 있다고 한다.

즉 정리하면 다음과 같다.

MCMC의 특징

  • Sampling 기반 방법으로 목표 분포를 찾아낸다.
  • Global Optima를 보장한다.
  • Computing Cost가 비교적 크다.

VI의 특징

  • Optimization 기반 방법으로 목표 분포를 찾아낸다.
  • Local Optima에 빠질 수 있다.
  • Computing Cost가 비교적 적다.

2. 목적 함수 정의하기

$p(z \mid \textbf{x}) \approx q(z)$ (where x is Sample Data, z is Latent Variables (target)) 라고 할 시 p함수는 다루기 어렵지만 true함수이고 q 함수는 다루기 쉽지만 true는 아닌 함수이다. 따라서 두 함수의 거리를 정의한다음 이를 최소화 하는 방식을 취한다. 두 함수의 거리는 Kullback - Leibler Divergence라고 불리는 방법으로 정의한다. 그 정의는 다음과 같다.

Kullback - Leibler Divergence

$KL(q(z) \mid \mid p(z \mid \textbf{x})) \equiv \int q(z)log\frac{q(z)}{p(z \mid \textbf{x})}dz $

위의 KL 거리는 Jensen’s inequality에 의해 항상 non-negative하고 q=p일시 0이므로 함수간의 거리 역할을 할 수 있다. KL 거리의 정의를 사용하면 다음과 같은 수식이 또 만족한다.

$log \ p(\textbf{x}) = log \frac{p(z,\textbf{x})}{p(z \mid \textbf{x})} = \int q(z) log \frac{p(z,\textbf{x})}{p(z \mid \textbf{x})} dz $ $=\int q(z) log \frac{p(z,\textbf{x})}{q(z)}dz + \int q(z)log \frac{q(z)}{p(z \mid X)}dz$ $=E_q(log \ p(z,\textbf{x}))-log \ q(z)) + KL(q(z) \mid \mid p(z \mid \textbf{x}))$

  • 위의 수식에서 $log \ p(\textbf{x})$는 일종의 likelihood이며 data $\text{x}$는 픽스되있는 값이므로 상수이다.
  • KL거리는 양수이며 $log \ p(\textbf{x})$ 또한 양수이다.
  • 따라서 $E_q(log \ p(z,\textbf{x}))-log \ q(z)) $ 부분을 최대화 시키면 KL 거리가 줄어든다고 볼 수 있다.

이러한 수식하에 $E_q(log \ p(z,\textbf{x}))-log \ q(z)) $ =F(q) 를 일종의 목적함수라고 두고 이를 최대화 시키는 걸 목표로 함으로써 최소거리 근사를 수행할 수 있다. 이 F(q)를 Evidence Lower BOund(;ELBO)라고 부른다.


+ ELBO에 관하여

  • $log \ p(\textbf{x})$가 베이즈 룰에서 evidence 꼴이기에 F(q)가 Evidence가 취할수 있는 최소값, ELBO라고 불리는 것이다.
  • ELBO는 함수 q를 input으로 받기 떄문에 Functional이다.

  • ELBO에 대해서는 다시 다음과 같은 수식이 성립된다.

$F(q) = \int q(z) log \frac{p(\textbf{x} \mid z)p(z)}{q(z)}dz$ $= E_q(log \ p(\textbf{x} \mid z))-KL(q(z) \mid \mid p(z))$

위의 수식으로 보아 아래와 같은 사실이 성립한다.

  • 수식의 전반부는 q에 대한 log likelihood 꼴이다. 따라서 가능도가 높을수록 ELBO가 커진다.
  • 수식의 후반부는 q와 p 간의 KL거리이다. 따라서 q와 p가 가까울수록 ELBO가 커진다.

3. 최적화 시키기

일반적인 최적화 방법과 마찬가지로 ELBO를 최적화 시킬때도 미분해서 0이 나오는 점을 찾는 것을 목표로 한다. 다른 방법과 다른점은 ELBO는 함수를 인풋으로 가지는 functional이며, 그렇기에 일반적인 미분법을 따르지 않는 다는 점이다. 따라서 functional의 미분법을 살펴볼 필요가 있다.

+ Functional 미분하기

일반적으로 다음 수식이 성립한다고 한다.

$\int \frac{\delta F}{\delta q(x)} \eta(x) dx = lim_{\epsilon \rightarrow 0} \frac{F(q+\epsilon\eta)-F(q)}{\epsilon} = [\frac{d}{d\epsilon}F[q+\epsilon\eta]]_{\epsilon \rightarrow0}$

첫번째 등호 부분은 이해하지 못했다. 그냥 받아들이도록 하자. 첫번째 항 부분에는 functional 미분이 들어있지만 세번째 항에는 일반적인 변수에 대한 미분 부분이 있다. 따라서 세번째는 충분히 풀 수 있는 부분이며 이를 이용하여 첫째 항의 functional 미분 부분을 도출 해낼 수 있다.

ex) $H[q] \equiv - \sum_x q(x) log \ q(x)$ (entropy 함수)

이에 대해 위의 공식을 이용하면 다음과 같은 수식을 유도할 수 있다.

$\sum_x \frac{\delta H}{\delta q(x)}\eta(x) =[\frac{d}{d\epsilon}H[q+\epsilon \eta]]_{\epsilon =0}$

$= -[\frac{d}{d \epsilon}(\sum (q+\epsilon \eta)log (q+ \epsilon \eta))]_{\epsilon=0}$

$= - \sum [\frac{d}{d \epsilon}q \ log(q+ \epsilon \eta)+\frac{d}{d \epsilon} \epsilon \eta log(q+\epsilon\eta)]_{\epsilon=0}$

$=- \sum [\frac{q\eta}{q+\epsilon \eta}+\eta log(q+\epsilon \eta) + \frac{\epsilon \eta \eta}{q+\epsilon \eta}]_{\epsilon=0}$

$= - \sum[1+log(q+ \epsilon\eta)]_{\epsilon=0}\eta(x)$

$= \sum_x [-(1+log \ q)] \eta(x)$

따라서 $\frac{\delta H}{\delta q(x)} = -(1+log \ q)$를 얻어낼 수 있다.


자 이제 목적함수 ELBO 를 최대화 시키는 q* 함수를 찾아보자. ($q^*(x) = argmax_{q}F(q)$)

이를 바로 도출해내기는 어렵다고 한다. 따라서 다음과 같은 가정을 적용한다.

Mean Field Assumption

$q(z)=\prod q_j(z_j)$

이는 찾고자 하는 타깃 변수들이 모두 독립적이라는 가정이며 실제에서 그다지 볼 수 없는 강한 가정이다. 그러나 이러한 가정하에서 ELBO를 구하는 알고리즘이 간소화가 되므로 어쩔 수 없이 차용한다고 한다.

이러한 가정하에서 ELBO는 아래와 같은 식이 된다.

$F(z)=E_q(log \ p(z,\textbf{x})-log \ q(z))$
$=\int \prod q_j(z_j)log \ p(z, \textbf{x})dz - \int \prod q_j (z_j)\sum log \ q_j(z_j) dz $

하지만 여기서 q가 확률밀도함수라는 제약이 붙기에 다음과 같은 라그랑주 승수법을 사용해준다.

$L = \int \prod q_j(z_j)log \ p(z, \textbf{x})dz - \int \prod q_j (z_j)\sum log \ q_j(z_j) dz + \sum \lambda_j (\int q_j(z_j)dz_i - 1)$

이 함수에 대해서 미분을 하고 0이 나오는 지점을 찾으면 다음과 같은 최고점을 찾을 수 있다.

$1.\frac{d}{d \lambda_k}L = \int q_k(z_k)dz_k-1=0$

$2.\frac{\delta}{\delta q_k}L $

$= \frac{\delta}{\delta q_k} [\int \prod q_j(z_j)log \ p(z, \textbf{x})dz_j - \int \prod q_j (z_j)\sum log \ q_j(z_j) dz + \sum \lambda_j (\int q_j(z_j)dz_i - 1) ]$

$=\frac{\delta}{\delta q_k}[\int q_k(z_k)dz_k\int \prod_{-k} q_j(z_j)log \ p(z, \textbf{x})dz_j - \int \prod q_j (z_j) log \ q_k(z_k) dz-\int \prod q_j (z_j)\sum_{-k} log \ q_j(z_j) dz + \sum \lambda_j (\int q_j(z_j)dz_i - 1) ]$

$=\int \prod_{-k}q_j(z_j)log \ p(z,\textbf{x}) dz_j \frac{\delta}{\delta q_k}[\int q_k(z_k) dz_k]-\prod_{-k} \int q_j(z_j)dz_j \frac{\delta}{\delta q_k}[\int q_k(z_k) log \ q_k(z_k) dz_k]$

$-\sum_{i \neq k} \int \prod_{-k} q_j (z_j) log \ q_i(z_i) dz_{i} \frac{\delta}{\delta q_k}[\int q_k(z_k)dz_k]+\lambda_k \frac{\delta}{\delta q_k} \int q_k (z_k) dz_k $

이때 functional 미분 부분들은 다음과 같다.

$\frac{\delta}{\delta q} \int q(z) dz = 1$

$\frac{\delta}{\delta q} \int q(z) log \ q(z) dz = 1+log \ q(z)$

$q_k$가 아닌 $q$ 함수들은 1번조건에 따라 $\int q_j(z_j)dz_j = 1$이다.

위의 세가지를 이용해 식을 간단하게 만들면 다음과 같다.

$=\int \prod_{-k}q_j(z_j)log \ p(z,\textbf{x}) dz_j - (1+log \ q_k(z_k))-C+\lambda_k =0$

(C는 $q_k$와는 무관하므로 상수 취급)

위 식을 넘기면 아래와 같이 성립한다.

$q_k(z_k) = exp(\lambda_k-C-1)exp(\int \prod_{-k} q_j(z_j)log \ p(z,\textbf{x})dz_j)$

$\propto exp(\int \prod_{-k} q_j(z_j) log \ p(z_k\mid z_{-k},\textbf{x})dz_j)$

$\propto exp(E_{q_{-j}}[log \ p (z_j \mid z_{-j},\textbf{x})])$

$q_j^*(z_j) \propto exp(E_{q_{-j}}[log \ p(z_j\mid z_{-j},\textbf{x})])$

이는 Gibbs Sampler에서 보았던 Full Conditional Distribution의 형태와 흡사하다. 즉, j=1의 이상적 분포는 j=1 이외의 모수들에서 나온 분포를 활용하게 된다.

그러나 일반적으로 모든 모수에 대해서 VI를 진행한다고 한다면 j=1 이외의 분포 또한 이상적 분포를 모른다고 할 수 있다. 따라서 사용해주는 것이 Coordinate Ascend Variational Inference(;CAVI) 방법으로 iteration form을 사용함으로써 모든 모수 분포의 최적 분포를 찾아줄 수 있다.

CAVI 방법은 Expectation Maximization(;EM) 기법에서 나온 방법이다. 이에 대해서는 다음에 자세히 다루어 보기로 하자.


지금까지 베이지안 사후분포를 찾아내는 VI 방법에 대해서 살펴보았다. 수학적으로 이해하기 쉬운 기법은 아니지만 현재 빅데이터를 다루는 데 있어서 적합한 방법이라고 한다. 다음에는 MCMC기반의 방법 또한 살펴보고 관련 실습을 진행해보려고 한다.