Variational Autoencoder (VAE)

Jan. 2, 2024, 12:15 p.m. · 7 min read

deep learning

This post is the author's summary of Professor Ernest Ryu's course <Mathematical Foundations of Deep Neural Networks>, taken in the second semester of the 2022 academic year.

Introduction

Before explaining the Variational Autoencoder, let us first look at its predecessor, the Autoencoder.
An autoencoder is a neural network made of two parts: an encoder, which compresses high-dimensional input data such as images into a low-dimensional representation vector in a latent space, and a decoder, which aims to reconstruct the original input from that latent vector. By using the difference between the encoder's input and the decoder's output (for example, the MSE) as the loss function, the encoder learns to represent the original image as a vector that preserves its characteristics, and the decoder learns to reconstruct something close to the original image from the representation vector alone.
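A minimal sketch of this setup, assuming a purely linear encoder and decoder, synthetic data, and plain full-batch gradient descent (all illustrative choices, not taken from the course): the encoder maps 8-dimensional inputs to a 2-dimensional latent vector, the decoder maps it back, and both are trained on the reconstruction MSE.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "images": 8-dimensional points lying close to a 2-D subspace
n, d, k = 500, 8, 2
basis = rng.normal(size=(k, d))
data = rng.normal(size=(n, k)) @ basis + 0.01 * rng.normal(size=(n, d))

# Linear encoder (d -> k) and linear decoder (k -> d), small random init
enc = 0.1 * rng.normal(size=(d, k))
dec = 0.1 * rng.normal(size=(k, d))

def mse(x, x_hat):
    return np.mean((x - x_hat) ** 2)

initial_loss = mse(data, data @ enc @ dec)
lr = 0.1
for step in range(3000):
    z = data @ enc              # encode: compress each input to a k-dim latent vector
    x_hat = z @ dec             # decode: reconstruct the input from the latent vector
    err = x_hat - data
    # Gradients of mean((x_hat - x)^2) with respect to the decoder and encoder weights
    grad_dec = 2 * z.T @ err / err.size
    grad_enc = 2 * data.T @ (err @ dec.T) / err.size
    dec -= lr * grad_dec
    enc -= lr * grad_enc

final_loss = mse(data, data @ enc @ dec)
print(f"MSE before: {initial_loss:.4f}, after: {final_loss:.6f}")
```

Because the synthetic data lies near a 2-D subspace, a 2-D latent vector is enough to reconstruct it almost perfectly; real autoencoders replace the two matrices with deep nonlinear networks.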

Note that the decoder reconstructs the original image from the representation vector alone. This means the representation vector must contain all the important information in the original image. In other words, the representation vector extracts and stores the features of the original image, so it can be seen as a form of dimensionality reduction. In practice, VAEs are often used to reduce the dimensionality of data.

A Variational Autoencoder improves on the autoencoder by adding probabilistic concepts. In this post, I summarize the mathematical explanation of why the VAE is a sound model and where its motivation comes from.

Key Idea of VAE

A Variational Autoencoder consists of two parts: a decoder $p_\theta(x|z)$, which gives the probability distribution of an image given a latent vector $z$, and an encoder $q_\phi(z|x)$, which describes the distribution of the latent vector $z$ given an image $x$. If you understand autoencoders, it is easy to see why these two functions are called the decoder and the encoder, but let us set the names aside for now. Here, I will explain why these two functions are needed using a somewhat different motivation.

Goal: Maximum Likelihood Estimation

First, suppose we are given $N$ images $X_1, X_2, \cdots, X_N$ (for example, $N$ photos of cats). Our goal is to understand the underlying structure of these high-dimensional images. Put differently, we can assume the $N$ cat photos are $N$ samples drawn from a "probability distribution of cat photos," and set our goal as finding its probability density function $p_X(x)$.

As usual, this can be done by maximum likelihood estimation (MLE). If we sample IID from $p_X(x)$, the probability (more precisely, the likelihood) of obtaining all $N$ images is
$$p_X(X_1)p_X(X_2)\cdots p_X(X_N)$$
so we simply maximize this. Products are hard to work with, so after taking the logarithm our goal becomes
$$ \text{maximize}_{p} \sum_{i=1}^N\log p(X_i)$$
๊ฐ€ ๋œ๋‹ค. $p$๋ผ๋Š” ํ•จ์ˆ˜๊ฐ€ $\theta$๋กœ ๋งค๊ฐœํ™”๋˜๋Š” ํ•จ์ˆ˜๋ผ๊ณ  ํ•˜๋ฉด, ๋‹ค์‹œ

$$ \text{maximize}_{\theta \in \Theta} \sum_{i=1}^N\log p_\theta(X_i)$$
Here $p_\theta$ is implemented as a neural network, and $\theta$ denotes its weights.
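The logarithm is not just a convenience. A small sketch, assuming a 1-D Gaussian model $p_\theta = \mathcal{N}(\mu, \sigma^2)$ with $\theta = (\mu, \sigma)$ and 1,000 synthetic samples (both hypothetical): the raw product of likelihoods underflows to zero in double precision, while the sum of log-likelihoods stays finite and can be maximized, here by a crude grid search standing in for gradient-based training.

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(loc=3.0, scale=2.0, size=1000)   # stand-in for X_1, ..., X_N

def log_density(x, mu, sigma):
    # log of the Gaussian density N(x; mu, sigma^2)
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def likelihood(mu, sigma):
    # product of 1000 densities: underflows in float64
    return np.prod(np.exp(log_density(samples, mu, sigma)))

def log_likelihood(mu, sigma):
    # sum of log-densities: numerically safe
    return np.sum(log_density(samples, mu, sigma))

print(likelihood(3.0, 2.0))        # 0.0
print(log_likelihood(3.0, 2.0))    # a finite negative number

# Crude MLE: grid search over theta = (mu, sigma)
grid = [(m, sd) for m in np.linspace(0, 6, 61) for sd in np.linspace(0.5, 4, 36)]
mu_hat, sigma_hat = max(grid, key=lambda t: log_likelihood(*t))
print(mu_hat, sigma_hat)           # close to the generating values (3, 2)
```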

๊ทธ๋Ÿฐ๋ฐ autoencoder์—์„œ ์„ค๋ช…ํ–ˆ๋“ฏ์ด ์ด๋ฏธ์ง€ $X$์—๋Š” ๊ทธ ๊ธฐ์ €์— $Z$๋ผ๋Š”, ์ด๋ฏธ์ง€์˜ ํŠน์„ฑ์„ ์„ค๋ช…ํ•˜๋Š” ๋ณ€์ˆ˜๊ฐ€ ์žˆ์–ด $Z$๋งŒ ์•Œ๋ฉด $X$๊ฐ€ ๊ฑฐ์˜ ๊ฒฐ์ •๋œ๋‹ค๊ณ  ํ•  ์ˆ˜ ์žˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ „ํ™•๋ฅ ๊ณต์‹๊ณผ ์กฐ๊ฑด๋ถ€ํ™•๋ฅ ์„ ์‚ฌ์šฉํ•ด์„œ
$$ p_\theta(X) = \int p_\theta(X|z)p_Z(z) dz = \mathbb{E}_{Z \sim p_Z}[p_\theta(X|Z)]$$
Our goal then becomes

$$ \text{maximize}_{\theta \in \Theta} \sum_{i=1}^N \log \mathbb{E}_{Z \sim p_Z}[p_\theta(X_i|Z)]$$

Here $p_Z$ is a known distribution; a (multivariate) standard normal is typically used.
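To make the marginalization concrete, here is a sketch on a hypothetical 1-D linear-Gaussian model, $Z \sim \mathcal{N}(0, 1)$ and $X|Z \sim \mathcal{N}(wZ, s^2)$, chosen because integrating out $Z$ has the closed form $X \sim \mathcal{N}(0, w^2 + s^2)$. A neural-network decoder has no such closed form; the point is only to check that $\mathbb{E}_{Z \sim p_Z}[p_\theta(x|Z)]$ really equals $p_\theta(x)$.

```python
import numpy as np

rng = np.random.default_rng(0)
w, s = 2.0, 0.5           # hypothetical decoder parameters theta = (w, s)

def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (np.sqrt(2 * np.pi) * std)

def p_x_given_z(x, z):
    # decoder: x | z ~ N(w z, s^2)
    return normal_pdf(x, w * z, s)

def p_x_exact(x):
    # integrating z out of this linear-Gaussian model gives x ~ N(0, w^2 + s^2)
    return normal_pdf(x, 0.0, np.sqrt(w**2 + s**2))

x = 1.3
z = rng.standard_normal(200_000)          # Z_j ~ p_Z = N(0, 1)
p_x_mc = np.mean(p_x_given_z(x, z))       # Monte Carlo estimate of E_{Z~p_Z}[p(x|Z)]
print(p_x_mc, p_x_exact(x))
```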

Importance Sampling

Now we need to figure out how to maximize this expression. The difficulty is that it contains an expectation. If $Z$ were a discrete random variable, we could simply sum over all its values:
$\mathbb{E}_{Z \sim p_Z}[p_\theta(X|Z)]=\sum_i p_Z(z_i)p_\theta(X|z_i)$
But $Z$ has a continuous distribution, so we would have to compute $\int p_\theta(X|z)p_Z(z) dz$, an integral over the entire latent space of a function defined by a neural network, which generally has no closed form and is very hard to evaluate. We therefore approximate the expectation by sampling $Z_i$:

$$\sum_{i=1}^N \log \mathbb{E}_{Z \sim p_Z} [p_\theta(X_i|Z)] \approx \sum_{i=1}^N \log p_\theta(X_i|Z_i)\quad\quad Z_i \sim p_Z$$
However, this is a very poor approximation: for each image $X_i$, it samples only a single latent vector $Z$ that is supposed to have produced (or plausibly produced) that image. A $Z$ drawn blindly from the prior $p_Z$ is typically not one that could have generated $X_i$, so $p_\theta(X_i|Z_i)$ is usually tiny and the estimate is extremely noisy. We address this by introducing importance sampling.
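A sketch of how poor the single-sample estimate can be, using a hypothetical 1-D toy model ($Z \sim \mathcal{N}(0,1)$, $X|Z \sim \mathcal{N}(wZ, s^2)$, so $\log p_\theta(x)$ is known exactly). Repeating the estimate $\log p_\theta(x|Z_i)$ with fresh $Z_i \sim p_Z$ shows both a large spread and a systematic downward bias; the bias follows from Jensen's inequality, $\mathbb{E}[\log p_\theta(x|Z)] \le \log \mathbb{E}[p_\theta(x|Z)]$.

```python
import numpy as np

rng = np.random.default_rng(1)
w, s, x = 2.0, 0.5, 1.3    # hypothetical toy model: z ~ N(0,1), x|z ~ N(w z, s^2)

def log_normal_pdf(v, mean, std):
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((v - mean) / std) ** 2

log_p_x = log_normal_pdf(x, 0.0, np.sqrt(w**2 + s**2))   # exact log p(x)

# 10,000 independent single-sample estimates log p(x | Z_i), each with its own Z_i ~ p_Z
z = rng.standard_normal(10_000)
single = log_normal_pdf(x, w * z, s)

print(f"exact log p(x)        : {log_p_x:.3f}")
print(f"mean of single-sample : {single.mean():.3f}")   # systematically lower (Jensen)
print(f"std of single-sample  : {single.std():.3f}")    # large spread
```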

The Idea of Importance Sampling

Consider computing $\mathbb{E}_{X\sim f}[\phi(X)]$, where $X$ has probability density function $f(x)$. Evaluating the integral directly is often difficult, so, as above, we usually draw IID samples $X_1, \ldots, X_N \sim f$ and approximate
$$\mathbb{E}_{X\sim f}[\phi(X)]\approx \frac{1}{N}\sum_{i=1}^N \phi(X_i)$$
This is called Monte Carlo estimation. By the law of large numbers, the right-hand side converges to the true expectation as $N$ grows.

However, this approximation often has variance so large that it is impractical, or it becomes accurate only for very large $N$. Importance sampling is a way to reduce this variance. Its key idea is to replace the distribution $f$ of $X$ with a different, "better" distribution $g$, using the following trick:
$$\mathbb{E}_{X\sim f}[\phi(X)] = \int \phi(x) f(x) dx = \int \frac{\phi(x)f(x)}{g(x)} g(x) dx$$
Written as expectations, this is
$$ \mathbb{E}_{X\sim f}[\phi(X)] = \mathbb{E}_{X \sim g}\left[\frac{\phi(X)f(X)}{g(X)} \right]$$
As promised, the distribution (probability density function) that $X$ follows has changed from $f$ to $g$; for the identity to hold, $g(x)$ must be positive wherever $\phi(x)f(x) \neq 0$. With a well-chosen $g$, we obtain a more accurate (lower-variance) estimate than before.
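A standard illustration, not from the original post: estimate $I = P(X > 4)$ for $X \sim \mathcal{N}(0, 1)$, i.e. $\phi(x) = \mathbf{1}[x > 4]$. Plain Monte Carlo rarely sees a sample above 4, so its estimate rests on a handful of hits. Sampling instead from $g = \mathcal{N}(4, 1)$, an assumed proposal centred on the region that matters, and reweighting by $f/g$ gives a far more stable estimate.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
exact = 0.5 * math.erfc(4 / math.sqrt(2))         # P(X > 4) for X ~ N(0, 1)

# Plain Monte Carlo: sample from f = N(0, 1)
x_f = rng.standard_normal(n)
plain = np.mean(x_f > 4)

# Importance sampling: sample from g = N(4, 1), reweight by f(x) / g(x)
x_g = rng.normal(loc=4.0, scale=1.0, size=n)
weights = np.exp(-0.5 * x_g**2 + 0.5 * (x_g - 4.0) ** 2)   # f/g; the 1/sqrt(2 pi) factors cancel
importance = np.mean((x_g > 4) * weights)

print(f"exact      {exact:.3e}")
print(f"plain MC   {plain:.3e}")
print(f"importance {importance:.3e}")
```

The exact value uses the Gaussian tail formula $P(X>4) = \tfrac12 \operatorname{erfc}(4/\sqrt{2})$.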

How, then, should we choose $g$? Ideally, assuming $\phi \ge 0$,
$$ g(X) = \frac{\phi(X)f(X)}{I} \quad(I = \int \phi(x) f(x) dx) $$
which makes the variance zero, the smallest possible. But $I$ is unknown to us (since $I = \mathbb{E}_{X\sim f}[\phi(X)]$, if we knew $I$ we would not need any of this in the first place), so we cannot use this function.
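To see why this ideal choice gives zero variance, write $g^*$ for it (assuming $\phi \ge 0$ so that $g^*$ is a valid density). Every importance weight is then the same constant:

$$\frac{\phi(X)f(X)}{g^*(X)} = \frac{\phi(X)f(X)}{\phi(X)f(X)/I} = I \quad\text{whenever } g^*(X) > 0, \qquad\text{so}\qquad \operatorname{Var}_{X\sim g^*}\left[\frac{\phi(X)f(X)}{g^*(X)}\right] = 0.$$

A single sample from $g^*$ would therefore return $I$ exactly.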

๋”ฐ๋ผ์„œ $g$๊ฐ€ ์ด์ƒ์ ์ธ ํ•จ์ˆ˜ $\frac{\phi(X)f(X)}{I}$์™€ ๊ฐ–๋Š” ๊ฑฐ๋ฆฌ๋ฅผ ๊ตฌํ•ด์„œ, ์ด๊ฒƒ์ด ์ตœ์†Œํ™”๋˜๋„๋ก ํ•จ์œผ๋กœ์จ ์–ด๋Š ์ •๋„ ์ข‹์€ $g$๋ฅผ ๊ตฌํ•  ์ˆ˜ ์žˆ๋‹ค. $g$๋Š” $\theta$๋กœ parametrize๋œ ์‹ ๊ฒฝ๋ง์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜์ž. KL-Divergence๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด,

$$ D_{KL} (g_\theta||\phi f/I) = \mathbb{E}_{X\sim g_\theta}\left[{\log\left(\frac{I\,g_\theta (X)}{\phi(X)f(X)}\right)}\right]$$

$$ = \mathbb{E}_{X\sim g_\theta}\left[{\log\left(\frac{g_\theta (X)}{\phi(X)f(X)}\right)}\right] + \log I $$

Since $\log I$ does not depend on $\theta$, it suffices to minimize $\mathbb{E}_{X\sim g_\theta}\left[{\log\left(\frac{g_\theta (X)}{\phi(X)f(X)}\right)}\right]$ with SGD. Importance sampling with the resulting $g_\theta$ then estimates $I$ with relatively low variance.
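A sketch of this procedure under simplifying assumptions (none of them from the original): $f = \mathcal{N}(0,1)$ and $\phi(x) = e^{ax}$ with $a = 2$, chosen because $\phi > 0$ and the ideal density $\phi f / I$ is then $\mathcal{N}(a, 1)$ with $I = e^{a^2/2}$, so the result can be checked. Instead of a neural network, $g_\theta = \mathcal{N}(\mu, \sigma^2)$ with $\theta = (\mu, \log\sigma)$, and the gradient of the expectation is estimated with the reparameterization $X = \mu + \sigma\varepsilon$, one common way (not the only one) to differentiate through sampling.

```python
import numpy as np

rng = np.random.default_rng(0)
a = 2.0          # phi(x) = exp(a x), f = N(0, 1), so I = exp(a^2 / 2)

def log_phi(x):
    return a * x

def log_f(x):
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

def log_g(x, mu, sigma):
    return -np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2 - 0.5 * np.log(2 * np.pi)

# g_theta = N(mu, sigma^2) with theta = (mu, rho = log sigma); start away from the optimum
mu, rho = 0.0, np.log(2.0)
lr, batch = 0.05, 64
for step in range(2000):
    sigma = np.exp(rho)
    eps = rng.standard_normal(batch)
    x = mu + sigma * eps                  # reparameterization: X ~ g_theta
    # Pathwise gradients of log g(X) - log phi(X) - log f(X); log I is dropped (constant in theta)
    grad_mu = np.mean(x - a)
    grad_rho = np.mean(-1.0 + sigma * (x - a) * eps)
    mu -= lr * grad_mu
    rho -= lr * grad_rho
sigma = np.exp(rho)

# Importance sampling of I with the learned proposal g_theta
xs = mu + sigma * rng.standard_normal(10_000)
estimate = np.mean(np.exp(log_phi(xs) + log_f(xs) - log_g(xs, mu, sigma)))
print(mu, sigma)                  # close to (a, 1), i.e. the ideal N(a, 1)
print(estimate, np.exp(a**2 / 2))
```

Because this Gaussian family contains the ideal density, the learned proposal makes every importance weight almost exactly $I$; with a more restricted family, the KL fit would only reduce the variance rather than eliminate it.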

Importance Sampling for Z

์ด์ œ ์›๋ž˜์˜ ๋ฌธ์ œ๋กœ ๋Œ์•„์™€์„œ, ์ด๋ฏธ์ง€ $X_i$์— ๋Œ€ํ•ด
$$p_\theta(X_i) =\mathbb{E}_{Z \sim p_Z} [p_\theta(X_i|Z)] $$
by importance sampling with $Z_i\sim q_i(z)$:
$$\mathbb{E}_{Z \sim p_Z} [p_\theta(X_i|Z)] \approx p_\theta(X_i|Z_i)\frac{p_Z(Z_i)}{q_i(Z_i)}\quad \quad Z_i \sim q_i(z)$$
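A sketch comparing two proposals on a hypothetical 1-D linear-Gaussian model ($Z \sim \mathcal{N}(0,1)$, $X|Z \sim \mathcal{N}(wZ, s^2)$), where both $p_\theta(x)$ and the posterior $p_\theta(z|x)$ have closed forms: using the prior as $q_i$ is just plain Monte Carlo and is noisy, while using the exact posterior makes every single-sample estimate equal $p_\theta(x)$.

```python
import numpy as np

rng = np.random.default_rng(0)
w, s, x = 2.0, 0.5, 1.3      # hypothetical toy model: Z ~ N(0,1), X|Z ~ N(wZ, s^2)

def normal_pdf(v, mean, std):
    return np.exp(-0.5 * ((v - mean) / std) ** 2) / (np.sqrt(2 * np.pi) * std)

p_x = normal_pdf(x, 0.0, np.sqrt(w**2 + s**2))          # exact p_theta(x)

def is_estimates(q_mean, q_std, m=10_000):
    # m independent single-sample estimates p(x|Z) p_Z(Z) / q(Z), with Z ~ q
    z = rng.normal(q_mean, q_std, size=m)
    return normal_pdf(x, w * z, s) * normal_pdf(z, 0.0, 1.0) / normal_pdf(z, q_mean, q_std)

# Proposal 1: q = p_Z (this is plain Monte Carlo)
prior_est = is_estimates(0.0, 1.0)

# Proposal 2: q = exact posterior p_theta(z|x), available in closed form for this model
post_var = 1.0 / (1.0 + w**2 / s**2)
post_mean = post_var * w * x / s**2
post_est = is_estimates(post_mean, np.sqrt(post_var))

print(p_x)
print(prior_est.mean(), prior_est.std())   # right on average, but very noisy
print(post_est.mean(), post_est.std())     # every estimate equals p(x) up to rounding
```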
As in the general case, the optimal choice of $q_i$ is

$$q_i^*(z) = \frac{p_\theta(X_i|z)p_Z(z)}{p_\theta(X_i)} = p_\theta(z|X_i)$$

where the last equality is Bayes' theorem: the ideal proposal is exactly the posterior, and with it every importance weight equals $p_\theta(X_i)$, so the estimate has zero variance. Of course, it cannot be computed exactly (we do not know $p_\theta(X_i)$), so we use the KL divergence to find a $q_i$ as close to $q_i^*$ as possible.

$$D_{KL}(q_i(\cdot) || q_i^*(\cdot)) = D_{KL}(q_i(\cdot) || p_\theta(\cdot|X_i)) = \mathbb{E}_{Z\sim q_i}\log\left(\frac{q_i(Z)}{p_\theta(Z|X_i)} \right)$$

$$=\mathbb{E}_{Z\sim q_i}\log\left(\frac{q_i(Z)}{p_\theta(X_i|Z)p_Z(Z)/p_\theta(X_i)} \right)$$

$$=\mathbb{E}_{Z\sim q_i} \left[\log(q_i(Z)) - \log(p_\theta(X_i|Z))-\log p_Z(Z) \right]+ \log p_\theta(X_i)$$

In the last line, $\log p_\theta(X_i)$ does not depend on $q_i$, so it can be ignored when minimizing over $q_i$. The remaining terms $q_i(Z)$, $p_\theta(X_i|Z)$, and $p_Z(Z)$ can all be evaluated, so the expression can be minimized by adjusting $q_i$.
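In a hypothetical 1-D model ($Z \sim \mathcal{N}(0,1)$, $X|Z \sim \mathcal{N}(wZ, s^2)$), if $q_i$ is restricted to Gaussians $\mathcal{N}(m, v)$, the computable part of the objective, $\mathbb{E}_{Z\sim q_i}[\log q_i(Z) - \log p_\theta(X_i|Z) - \log p_Z(Z)]$, has a closed form. The sketch below minimizes it by grid search (standing in for SGD) and checks that the minimizer is the exact posterior, where the objective equals $-\log p_\theta(X_i)$ because the KL term vanishes.

```python
import numpy as np

w, s, x = 2.0, 0.5, 1.3     # hypothetical toy model: Z ~ N(0,1), X|Z ~ N(wZ, s^2)

def objective(m, v):
    """E_{Z~q}[log q(Z) - log p(x|Z) - log p_Z(Z)] for q = N(m, v), in closed form.
    Every term uses only quantities we can evaluate; p(x) never appears."""
    e_log_q = -0.5 * np.log(2 * np.pi * v) - 0.5
    e_neg_log_lik = 0.5 * np.log(2 * np.pi * s**2) + ((x - w * m) ** 2 + w**2 * v) / (2 * s**2)
    e_neg_log_prior = 0.5 * np.log(2 * np.pi) + (m**2 + v) / 2
    return e_log_q + e_neg_log_lik + e_neg_log_prior

log_p_x = -0.5 * np.log(2 * np.pi * (w**2 + s**2)) - x**2 / (2 * (w**2 + s**2))

post_v = 1.0 / (1.0 + w**2 / s**2)          # exact posterior p(z|x) = N(post_m, post_v)
post_m = post_v * w * x / s**2

# Minimizing over (m, v) on a grid recovers the posterior without ever using p(x)
ms = np.linspace(-1, 2, 301)
vs = np.linspace(0.01, 0.5, 491)
M, V = np.meshgrid(ms, vs, indexing="ij")
F = objective(M, V)
i, j = np.unravel_index(np.argmin(F), F.shape)
print(ms[i], post_m)       # grid minimizer vs exact posterior mean
print(vs[j], post_v)       # grid minimizer vs exact posterior variance
# At the posterior the KL is 0, so the objective equals -log p(x)
print(objective(post_m, post_v), -log_p_x)
```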

Amortized Inference

๊ทธ๋Ÿฐ๋ฐ ์œ„์—์„œ $q_i$๋ฅผ ๋ณด๋ฉด index $i$๊ฐ€ ๋ถ™์–ด์žˆ๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค. ์ฆ‰, ๊ฐ ๋ฐ์ดํ„ฐ(์ด๋ฏธ์ง€) $X_i$์— ๋Œ€ํ•ด์„œ ๊ฐœ๋ณ„์ ์œผ๋กœ ์ตœ์ ํ™” ๋ฌธ์ œ๋ฅผ ํ’€๊ณ  ์žˆ๋Š” ๊ฒƒ์ด๋‹ค. ๋‹น์—ฐํžˆ ์ด๋Š” ๊ณ„์‚ฐ์ด ๋งค์šฐ ๋งŽ์ด ๊ฑธ๋ฆด ๊ฒƒ์ด๋‹ค.

๋”ฐ๋ผ์„œ ์šฐ๋ฆฌ๋Š” ํ•จ์ˆ˜ $q$๋ฅผ ์‹ ๊ฒฝ๋ง์œผ๋กœ ๊ตฌ์„ฑํ•˜๊ณ , ๊ทธ ๊ฐ€์ค‘์น˜ $\phi$๋กœ parametrizeํ•˜์—ฌ $q_\phi$๋กœ ๋งŒ๋“ ๋‹ค. ๊ทธ๋ฆฌ๊ณ 
$$\sum_{i=1}^ND_{KL}(q_\phi(\cdot|X_i) || q_i^*(\cdot))$$

as the loss function, which we minimize with SGD. Then $q_\phi$ represents a different distribution $q_i(z)$ for each input image $X_i$: a single function $q_\phi(z|X)$ does the work of $N$ separate optimizations. That is,
$$q_\phi(z|X_i) = q_i(z) \approx q_i^*(z) = p_\theta(z|X_i)\quad \text{for all } i = 1, \cdots, N$$
๊ฐ€ ๋˜๋Š” ๊ฒƒ์ด๋‹ค. ์ด $q_\phi$๊ฐ€ ๋ฐ”๋กœ ์ธ์ฝ”๋”๊ฐ€ ๋œ๋‹ค.

Optimizing the Encoder and Decoder

์ด์ œ ์ธ์ฝ”๋” $q_\phi$์™€ $p_\theta$๋ฅผ ์ตœ์ ํ™”ํ•˜๋ฉด ๋œ๋‹ค. ๋จผ์ € ์ธ์ฝ”๋”์˜ ๋ชฉํ‘œ๋Š” ์•ž์—์„œ ์„ค๋ช…ํ•œ ๊ฒƒ์ฒ˜๋Ÿผ ๊ฐ ์ด๋ฏธ์ง€ $X_i$์— ๋Œ€ํ•ด importance sampling์„ ํ•˜๋Š” ์ตœ์ ์˜ ํ•จ์ˆ˜ $q_i^*$๋ฅผ amortized inference๋กœ ๊ทผ์‚ฌํ•˜๋Š” ๊ฒƒ์ด ๋œ๋‹ค.

$$\text{minimize}_{\phi\in\Phi}\sum_{i=1}^N D_{KL}(q_\phi(\cdot|X_i) || q_i^*(\cdot))$$

$$\iff \text{minimize}_{\phi\in\Phi} \sum_{i=1}^N \mathbb{E}_{Z\sim q_\phi(z|X_i)}\left[\log\left(\frac{q_\phi(Z|X_i)}{p_\theta(Z|X_i)} \right)\right] $$

Expanding $p_\theta(Z|X_i)$ with Bayes' theorem as before, dropping $\log p_\theta(X_i)$ (which does not depend on $\phi$), and flipping the sign:

$$\iff \text{maximize}_{\phi\in\Phi}\sum_{i=1}^N \mathbb{E}_{Z\sim q_\phi(z|X_i)} \left[\log\left(\frac{p_\theta(X_i|Z)p_Z(Z)}{q_\phi(Z|X_i)}\right) \right]$$

$$\iff \text{maximize}_{\phi\in\Phi}\sum_{i=1}^N \left(\mathbb{E}_{Z\sim q_\phi(z|X_i)}\left[ \log p_\theta(X_i|Z)\right]-D_{KL} (q_\phi(\cdot|X_i)||p_Z(\cdot)) \right)$$

๋””์ฝ”๋”์˜ ๋ชฉํ‘œ๋Š” (๋‹น์—ฐํžˆ) Maximum Likelihood Estimation์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์ด๋‹ค.

$$ \text{maximize}_{\theta\in\Theta}\sum_{i=1}^N \log p_\theta(X_i) $$

$$ = \text{maximize}_{\theta\in\Theta} \sum_{i=1}^N \log\mathbb{E}_{Z\sim p_Z}\left[p_\theta(X_i|Z)\right]$$

$$\approx\text{maximize}_{\theta \in\Theta} \sum_{i=1}^N \log\left(\frac{p_\theta(X_i|Z_i)p_Z(Z_i)}{q_\phi(Z_i|X_i)} \right)\quad (Z_i\sim q_{\phi}(z|X_i))$$

$$\approx\text{maximize}_{\theta \in\Theta} \sum_{i=1}^N \mathbb{E}_{Z\sim q_{\phi}(z|X_i)}\left[\log\left(\frac{p_\theta(X_i|Z)p_Z(Z)}{q_\phi(Z|X_i)} \right)\right]$$

By Jensen's inequality, each expectation in the last line is a lower bound on $\log p_\theta(X_i)$.

$$= \text{maximize}_{\theta \in\Theta} \sum_{i=1}^N \left(\mathbb{E}_{Z\sim q_{\phi}(z|X_i)} \left[\log p_\theta(X_i|Z)\right] - D_{KL}(q_\phi (\cdot|X_i)||p_Z(\cdot))\right) $$

์šฐ์—ฐํžˆ๋„ ๋‘ ์‹์˜ ํ˜•ํƒœ๊ฐ€ ๋˜‘๊ฐ™์€ ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค! ๋”ฐ๋ผ์„œ ์œ„ ์‹์„ ์ตœ๋Œ€ํ™”ํ•˜๋Š” $\theta$์™€ $\phi$๋ฅผ ์ฐพ์œผ๋ฉด ๋œ๋‹ค. ์ฆ‰,
$$ \text{maximize}_{\theta \in\Theta, \phi \in \Phi} \sum_{i=1}^N \left(\mathbb{E}_{Z\sim q_{\phi}(z|X_i)} \left[\log p_\theta(X_i|Z)\right] - D_{KL}(q_\phi (\cdot|X_i)||p_Z(\cdot))\right) $$
is the training objective of the VAE. The quantity being maximized is known as the evidence lower bound (ELBO).
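A sketch of estimating this objective for one data point, on a hypothetical 1-D model (decoder $p_\theta(x|z) = \mathcal{N}(x; wz, s^2)$, prior $\mathcal{N}(0,1)$, encoder $q_\phi(z|x) = \mathcal{N}(\alpha x, v)$). The reconstruction term is estimated by Monte Carlo with reparameterized samples $Z = m + \sqrt{v}\,\varepsilon$, and the KL term uses the closed form for two Gaussians, $D_{KL}(\mathcal{N}(m, v) || \mathcal{N}(0,1)) = \tfrac12(m^2 + v - 1 - \log v)$. Both are common choices in VAE implementations rather than something the derivation above prescribes.

```python
import numpy as np

rng = np.random.default_rng(0)
w, s = 2.0, 0.5      # hypothetical decoder p_theta(x|z) = N(x; w z, s^2), prior p_Z = N(0, 1)

def elbo(x, alpha, v, num_samples=100_000):
    """Monte Carlo estimate of E_{Z~q}[log p(x|Z)] - KL(q || p_Z) for q = N(alpha * x, v)."""
    m = alpha * x
    z = m + np.sqrt(v) * rng.standard_normal(num_samples)        # reparameterized samples
    log_lik = -0.5 * np.log(2 * np.pi * s**2) - (x - w * z) ** 2 / (2 * s**2)
    kl = 0.5 * (m**2 + v - 1.0 - np.log(v))                      # closed-form Gaussian KL
    return log_lik.mean() - kl

def log_p(x):
    # exact log marginal likelihood of this toy model: x ~ N(0, w^2 + s^2)
    return -0.5 * np.log(2 * np.pi * (w**2 + s**2)) - x**2 / (2 * (w**2 + s**2))

x = 1.3
best = elbo(x, w / (w**2 + s**2), s**2 / (w**2 + s**2))   # encoder equal to the exact posterior
worse = elbo(x, 0.2, 0.5)                                  # some other encoder
print(log_p(x), best, worse)     # best is close to log p(x); worse is clearly lower
```

The gap between $\log p_\theta(x)$ and this objective is exactly $D_{KL}(q_\phi(\cdot|x) || p_\theta(\cdot|x))$, which is why maximizing over the encoder pulls $q_\phi$ toward the posterior.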