VAE解讀

VAE簡介

生成式模型有兩個著名的傢族GAN和VAE,圖片可以看做高維數據的分佈 P(X),我們手上的數據集 {x_{1}, x_{2}…x_{n}} 可以看做從 P(X) 中采樣得到的若幹樣本。 P(X) 很難直接求得,常用的思路是通過引入隱藏變量(latent variable) z ,尋找 Z 空間到 X 空間的映射,這樣我們通過在 Z 空間采樣映射到 X 空間就可以生成新的圖片。

我們希望通過映射 Z sim X^{'} 得到的分佈 X^{'} 與真實的分佈 X 盡可能接近,但是由於不知道它們具體的形式,無法計算KL散度來衡量。GAN的思路很直接-用一個神經網絡(Discriminator)去衡量 XX' 的相似度,VAE則使用瞭一種更間接也更巧妙的思路。

GAN中對於隱變量 z 有著很強的先驗假設(Gaussian distribution),實際上隻有decode的過程而沒有encode。通過VAE(variational auto-encoder)的全稱就能看出來VAE的主要貢獻就在於在encode上的改進。如今被人熟知的VAE模型中的 z ,其prior distribution與posterior distribution均為正態分佈,但網上很少解釋地清楚為什麼會有這麼奇怪的先驗假設,根據VAE的論文寫一點自己的理解。

問題定義

我們有數據集 X=left{ x^{(i)} right}_{i=1}^{N} ,並且假設圖片 x 由潛在變量 z 生成,其過程可以描述為:

  • z 的先驗分佈 p_{theta^{*}}(z) 中采樣 z
  • x 的條件概率分佈 p_{theta^{*}}left( x |zright) 中采樣產生x

這裡兩個概率分佈 p_{theta^{*}}(z)p_{theta^{*}}left( x |zright) 是來自 p_{theta}(z)p_{theta}left( x |zright) 這兩個parametric families of distributions。我們的任務是找到或者逼近這個參數 theta^{*}

論文作者在這裡還特意強調瞭一下他們這個estimation算法的通用性,即使面對intractability和large dataset也能用

這段意思大概是說, p_{theta}(x)p_{theta}left( z |xright) 都是算不出來的,因為算這兩個都需要用到 p_{theta}left( x |zright)做積分 ,而這個是用復雜的神經網絡模擬的,所以求解很困難。

既然沒辦法直接求解,那就引入一個新的分佈去逼近(模擬)這個不好求解的 p_{theta}left( z |xright) ,記做 q_{phi}left( z|x right) ,我們接下來要做的,就是去同時求解 thetaphi

論文對於encoder和decoder的圖形化表示

虛線部分表示 q_{phi}left( z|x right) ,相當於encode過程,由 X 得到 Z ;實線部分表示 p_{theta}left( z right)p_{theta}left( x|z right) ,相當於decode過程。之所以encode過程和decode過程要用到兩個參數,簡單的說是因為encode和decode是“不可逆”的,所以分別引入一個變量。至於為什麼不可逆,前面講過p_{theta}left( z |xright)=p_{theta}left( x |zright)p_{theta}left( zright)/p_{theta}left( xright) 是intractable的。

優化目標

極大似然法:優化模型參數,使得出現已知樣本 X 的概率最高。

log p_{theta}left( x^{(1)},…,x^{(N)}right)=sum_{i =1}^{N}{logp_{theta}}left(x^{(i)}right)

等式成立前提是Dataset中數據樣本是獨立同分佈的,我們的目的就是使這個marginal likelihood越大越好。

接下來是第一個非常重要的公式。

log p_{boldsymbol{theta}}(textbf{x}^{(i)})=D_{KL}(q_{boldsymbol{phi}}(textbf{z}|textbf{x}^{(i)})||p_{boldsymbol{theta}}(textbf{z}|textbf{x}^{(i)}))+mathcal{L}(boldsymbol{theta},boldsymbol{phi};textbf{x}^{(i)})quad tag{1}

此公式推導如下:

由於我們用 q_{phi}left( z|x right) 近似 p_{theta}left( z |xright) ,它們之間的KL散度為

D_{KL}(q_{phi}(z|x)||p_{theta}(z|x))=begin{array}{c}{{-sum_{z}q_{phi}(z|x)logleft(frac{p_{theta}(z|x)}{q_{phi}(z|x)}right)}}end{array}\=-sum_{z}q_{phi}(z|x)logleft(frac{frac{p_{theta}(x,z)}{p_{theta}(x)}}{q_{phi}(z|x)}right)\=-sum_z q_phi(z|x)left[logleft(dfrac{p_theta(x,z)}{q_phi(z|x)}right)-{logleft(p_theta(x)right)}right]\=logbigl(p_theta(x)bigr)-sum_{z}q_{phi}(z|x)logleft(frac{p_{theta}(x,z)}{q_{phi}(z|x)}right) tag{2}

將後項記為 mathcal{L}(boldsymbol{theta},boldsymbol{phi};textbf{x}^{(i)})quad ,移項過去就得到瞭上面的公式 (1)

公式 (1) 中,由KL散度的非負性我們可以得到 log p_{boldsymbol{theta}}(textbf{x}^{(i)}) 的下界:

log p_{boldsymbol{theta}}(textbf{x}^{(i)})geqmathcal{L}(boldsymbol{theta},boldsymbol{phi};textbf{x}^{(i)})=mathbb{E}_{q_{boldsymbol{phi}}(textbf{z}|textbf{x})}left[-log q_{boldsymbol{phi}}(textbf{z}|textbf{x})+log p_{boldsymbol{theta}}(textbf{x},textbf{z})right]tag{3}mathcal{L}(theta,phi;textbf{x}^{(i)})=-D_{KL}(q_{phi}(textbf{z}|textbf{x}^{(i)})||p_{theta}(textbf{z}))+mathbb{E}_{q_{phi}(textbf{z}|textbf{x}^{(i)})}left[log p_{theta}(textbf{x}^{(i)}|textbf{z})right]tag{3'} (3')(3) 展開後得到的另一種形式, D_{KL}(q_{phi}(textbf{z}|textbf{x}^{(i)})||p_{theta}(textbf{z})) 描述的是p_{theta}(textbf{z}) 與我們用來逼近它的 q_{phi}(textbf{z}|textbf{x}^{(i)}) 之間的KL距離,可以理解為模型產生 z 的encode能力, mathbb{E}_{q_{phi}(textbf{z}|textbf{x}^{(i)})}left[log p_{theta}(textbf{x}^{(i)}|textbf{z})right] 描述的是根據 z 重建 x 的能力,即decode能力。

到現在我們的目標轉成瞭去提高這個下界 mathcal{L}(boldsymbol{theta},boldsymbol{phi};textbf{x}^{(i)})quad 。直接對其求導的方法不可行,用Monte Carlo gradient estimator(簡單來說就是抽樣估計梯度)有:

nabla_{phi}mathbb{E}_{q_{phi}(textbf{z})}left[f(textbf{z})right]=mathbb{E}_{q_{phi}(textbf{z})}left[f(textbf{z})nabla_{q_{phi}(textbf{z})}log q_{phi}(textbf{z})right]simeqfrac{1}{L}sum_{l=1}^L f(textbf{z})nabla_{q_{phi}(textbf{z}^{(l)})}log q_{phi}(textbf{z}^{(l)}) tag{4} 其中 textbf{z}^{(l)} 是從 q_{boldsymbol{phi}}(textbf{z}|textbf{x}^{(i)}) 中采樣得到的,對於 textbf{z}^{(l)} 很難求梯度,沒辦法用 (4) 去優化。

SGVB 估計和AEVB優化

針對現在的優化目標形式 mathbb{E}_{q_{phi}(textbf{z})}left[f(textbf{z})right] ,既然沒辦法直接求導,那就找一個它的近似,論文的一大貢獻就在於重參數化技巧(The reparameterization trick),使得 mathbb{E}_{q_{phi}(textbf{z})}left[f(textbf{z})right] 可微。

上節講過 mathbb{E}_{q_{phi}(textbf{z})}left[f(textbf{z})right] 沒辦法微分問題出在 z 是隨機采樣出來的,重參數化技巧本質上是將這個采樣的過程用另一個變量 epsilon (與其他變量獨立)描述生成,即 textbf{z}=g_{boldsymbol{phi}}(boldsymbol{epsilon},textbf{x}) 。於是有:

q_{boldsymbol{phi}}(textbf{z}|textbf{x})prod_i dz_i=p(boldsymbol{epsilon})prod_i depsilon_iquadtext{}tag{5} int q_{phi}(textbf{z}|textbf{x})f(textbf{z})d{z}=int p({epsilon})f({z})d{epsilon}=int p({epsilon})f(g_{phi}({epsilon},{x}))depsilontag{6} 現在 z 可以看做是一個由 epsilon 產生的“普通的變量”,采樣過程帶來的不確定性被轉嫁到瞭 epsilon 上,於是我們得到瞭 mathbb{E}_{q_{phi}(textbf{z})}left[f(textbf{z})right] 的一個可微的近似值:

mathbb{E}_{q_phi(mathbf{z})mathbf{x}^{(i)})}left[f(mathbf{z})right]=mathbb{E}_{mathfrak{p}(mathbf{epsilon})}left[f(g_{phi}(mathbf{epsilon},mathbf{x}^{(i)}))right]simeqdfrac{1}{L}sum_{l=1}^{L}f(g_{phi}(mathbf{epsilon}^{(l)},mathbf{x}^{(i)}))quadtext{where}quadepsilon^{(l)}sim p(mathbf{epsilon})tag{7} 優化目標 mathcal{L}(boldsymbol{theta},boldsymbol{phi};textbf{x}^{(i)})quad(3) 被重寫為:

begin{aligned}widetilde{mathcal{L}}^A(boldsymbol{theta},boldsymbol{phi};boldsymbol{mathbf{x}}^{(i)})&=frac{1}{L}sum_{l=1}^Llog pboldsymbol{theta}(boldsymbol{mathbf{x}}^{(i)},boldsymbol{mathbf{z}}^{(i,l)})-log q_phi(boldsymbol{mathbf{z}}^{(i,l)}|boldsymbol{mathbf{x}}^{(i)})\ text{where}quadboldsymbol{mathbf{z}}^{(i,l)}&=g_phi(boldsymbol{epsilon}^{(i,l)},boldsymbol{mathbf{x}}^{(i)})quadtext{and}quadboldsymbol{epsilon}^{(l)}sim p(boldsymbol{epsilon})end{aligned} tag{8}

(3') 被重寫為:

begin{aligned}widetilde{mathcal{L}}^{mathcal{B}}(boldsymbol{theta},boldsymbol{phi};boldsymbol{x}^{(i)})=-D_{KL}(q_{boldsymbol{theta}}(boldsymbol{z}|boldsymbol{x}^{(i)})||p_{boldsymbol{theta}}(boldsymbol{z}))+frac{1}{L}sum_{i=1}^L(log p_{boldsymbol{theta}}(boldsymbol{x}^{(i)}|boldsymbol{x}^{(i,l)}))\ text{where}quadboldsymbol{z}^{(i,l)}=g_{boldsymbol{phi}}(boldsymbol{epsilon}^{(i,l)},boldsymbol{x}^{(i)})quadtext{and}quadboldsymbol{epsilon}^{(l)}sim p(boldsymbol{epsilon})end{aligned} tag{9} 前半部分 D_{KL}(q_{boldsymbol{theta}}(boldsymbol{z}|boldsymbol{x}^{(i)})||p_{boldsymbol{theta}}(boldsymbol{z})) 大多數情況下可以直接被計算出來。因此一般用 (9) u去估計 mathcal{L}

到現在我們有瞭優化目標以及它的一個可微的近似 (9) ,可以派出我們的老朋友SGD瞭。這樣我們就得到瞭完整的Auto-Encoding VB algorithm:

AEVB

VAE

如果你認真的看到這裡,其實VAE用一句話就可以將完瞭,前面的推導中我們並沒有為 p_{boldsymbol{theta}}(boldsymbol{z}) 或者 q_{boldsymbol{theta}}(boldsymbol{z}|boldsymbol{x}^{(i)}) 指定任何具體的概率分佈形式,而我們現在熟知的VAE做出瞭以下假設: p_{boldsymbol{theta}}(textbf{z})=mathcal{N}(textbf{z};textbf{0},{textbf{I}})left.q_{phi}(textbf{z}|textbf{x}^{(i)})=right.mathcal{N}(textbf{z};boldsymbol{mu}^{(i)},boldsymbol{sigma}^{2(i)}textbf{I}) 。其他所有的原理、推導和之前一樣,最後經過數學化簡,有:

begin{aligned}mathcal{L}(boldsymbol{theta},boldsymbol{phi};boldsymbol{mathbf{x}}^{(i)})&simeqfrac{1}{2}sum_{j=1}^JBig(1+log((sigma_j^{(i)})^2)-(mu_j^{(i)})^2-(sigma_j^{(i)})^2Big)+frac{1}{L}sum_{l=1}^Llog p_{boldsymbol{theta}}(mathbf{x}^{(i)}midmathbf{z}^{(i)})\ text{where}quadmathbf{z}^{(i,l)}&=boldsymbol{mu}^{(i)}+boldsymbol{sigma}^{(i)}circmathbf{e}^{(l)}quadtext{and}quadboldsymbol{epsilon}^{(l)}simmathcal{N}(0,mathbf{I})end{aligned} tag{10} 再使用AEVB優化就可以瞭。(原論文有這部分詳細的數學推導和 log p(textbf{x}|textbf{z}) 的計算方法,在此不再重復)。

總結

VAE與GAN的目標是一致的:最大化 X 的likelihood。GAN的思路很巧妙:我學習一個網絡去衡量我的generator效果。而VAE的思路其實還是最大化似然,但是不是直接計算 log p_{boldsymbol{theta}}(textbf{x}^{(i)}) ,而是去提升它的下界 mathcal{L} ,在這一優化過程中提出瞭重參數化的技巧。

VAE這篇論文距今已經十年,但是讀這篇論文比起讀近年來論文著實艱難,VAE很大程度上解決瞭AE被人詬病的生成能力差的問題,且其隱空間的連續性在實驗結果上得到瞭充分的體現。高山仰止。

VAE可視化結果

赞(0)