深度學習—BN的理解(一)

0、問題

機器學習領域有個很重要的假設:IID獨立同分佈假設,就是假設訓練數據和測試數據是滿足相同分佈的,這是通過訓練數據獲得的模型能夠在測試集獲得好的效果的一個基本保障。那BatchNorm的作用是什麼呢?BatchNorm就是在深度神經網絡訓練過程中使得每一層神經網絡的輸入保持相同分佈的。

  思考一個問題:為什麼傳統的神經網絡在訓練開始之前,要對輸入的數據做Normalization?原因在於神經網絡學習過程本質上是為瞭學習數據的分佈,一旦訓練數據與測試數據的分佈不同,那麼網絡的泛化能力也大大降低;另一方面,一旦在mini-batch梯度下降訓練的時候,每批訓練數據的分佈不相同,那麼網絡就要在每次迭代的時候去學習以適應不同的分佈,這樣將會大大降低網絡的訓練速度,這也正是為什麼我們需要對所有訓練數據做一個Normalization預處理的原因。

  為什麼深度神經網絡隨著網絡深度加深,訓練起來越困難,收斂越來越慢?這是個在DL領域很接近本質的好問題。很多論文都是解決這個問題的,比如ReLU激活函數,再比如Residual Network,BN本質上也是解釋並從某個不同的角度來解決這個問題的。

  結論:BN 層在激活函數之前。BN層的作用機制也許是通過平滑隱藏層輸入的分佈,幫助隨機梯度下降的進行,緩解隨機梯度下降權重更新對後續層的負面影響。因此,實際上,無論是放非線性激活之前,還是之後,也許都能發揮這個作用。隻不過,取決於具體激活函數的不同,效果也許有一點差別(比如,對sigmoid和tanh而言,放非線性激活之前,也許順便還能緩解sigmoid/tanh的梯度衰減問題,而對ReLU而言,這個平滑作用經ReLU“扭曲”之後也許有所衰弱)。

  (1)sigmoid、tanh 激活函數。函數圖像的兩端,相對於x的變化,y的變化都很小(這其實很正常,畢竟tanh就是拉伸過的sigmoid)。也就是說,容易出現梯度衰減的問題。那麼,如果在tanh或sigmoid之前,進行一些normalization處理,就可以緩解梯度衰減的問題。我想這可能也是最初的BN論文選擇把BN層放在非線性激活之前的原因。

  (2)relu激活函數。所以,現在我們假設所有的激活都是relu,也就是使得負半區的卷積值被抑制,正半區的卷積值被保留。而bn的作用是使得輸入值的均值為0,方差為1,也就是說假如relu之前是bn的話,會有接近一半的輸入值被抑制,一半的輸入值被保留。

  所以bn放到relu之前的好處可以這樣理解:bn可以防止某一層的激活值全部都被抑制,從而防止從這一層往前傳的梯度全都變成0,也就是防止梯度消失。(當然也可以防止梯度爆炸)

1、“Internal Covariate Shift”問題

  從論文名字可以看出,BN是用來解決“Internal Covariate Shift”問題的,那麼首先得理解什麼是“Internal Covariate Shift”?

  論文首先說明Mini-Batch SGD相對於One Example SGD的兩個優勢:梯度更新方向更準確;並行計算速度快;(為什麼要說這些?因為BatchNorm是基於Mini-Batch SGD的,所以先誇下Mini-Batch SGD,當然也是大實話);然後吐槽下SGD訓練的缺點:超參數調起來很麻煩。(作者隱含意思是用BN就能解決很多SGD的缺點)

  接著引入covariate shift的概念:如果ML系統實例集合<X,Y>中的輸入值X的分佈老是變,這不符合IID假設,網絡模型很難穩定的學規律,這不得引入遷移學習才能搞定嗎,我們的ML系統還得去學習怎麼迎合這種分佈變化啊。對於深度學習這種包含很多隱層的網絡結構,在訓練過程中,因為各層參數不停在變化,所以每個隱層都會面臨covariate shift的問題,也就是在訓練過程中,隱層的輸入分佈老是變來變去,這就是所謂的“Internal Covariate Shift”,Internal指的是深層網絡的隱層,是發生在網絡內部的事情,而不是covariate shift問題隻發生在輸入層。

  然後提出瞭BatchNorm的基本思想:能不能讓每個隱層節點的激活輸入分佈固定下來呢?這樣就避免瞭“Internal Covariate Shift”問題瞭,順帶解決反向傳播中梯度消失問題。BN 其實就是在做 feature scaling,而且它的目的也是為瞭在訓練的時候避免這種 Internal Covariate Shift 的問題,隻是剛好也解決瞭 sigmoid 函數梯度消失的問題。

  BN不是憑空拍腦袋拍出來的好點子,它是有啟發來源的:之前的研究表明如果在圖像處理中對輸入圖像進行白化(Whiten)操作的話——所謂白化,就是對輸入數據分佈變換到0均值,單位方差的正態分佈——那麼神經網絡會較快收斂,那麼BN作者就開始推論瞭:圖像是深度神經網絡的輸入層,做白化能加快收斂,那麼其實對於深度網絡來說,其中某個隱層的神經元是下一層的輸入,意思是其實深度神經網絡的每一個隱層都是輸入層,不過是相對下一層來說而已,那麼能不能對每個隱層都做白化呢?這就是啟發BN產生的原初想法,而BN也確實就是這麼做的,可以理解為對深層神經網絡每個隱層神經元的激活值做簡化版本的白化操作。

2、BatchNorm的本質思想

  BN的基本思想其實相當直觀:因為深層神經網絡在做非線性變換前的激活輸入值(就是那個x=WU+B,U是輸入)隨著網絡深度加深或者在訓練過程中,其分佈逐漸發生偏移或者變動,之所以訓練收斂慢,一般是整體分佈逐漸往非線性函數的取值區間的上下限兩端靠近(對於Sigmoid函數來說,意味著激活輸入值WU+B是大的負值或正值),所以這導致反向傳播時低層神經網絡的梯度消失,這是訓練深層神經網絡收斂越來越慢的本質原因,而BN就是通過一定的規范化手段,把每層神經網絡任意神經元這個輸入值的分佈強行拉回到均值為0方差為1的標準正態分佈,其實就是把越來越偏的分佈強制拉回比較標準的分佈,這樣使得激活輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導致損失函數較大的變化,意思是這樣讓梯度變大,避免梯度消失問題產生,而且梯度變大意味著學習收斂速度快,能大大加快訓練速度。

  THAT’S IT。其實一句話就是:對於每個隱層神經元,把逐漸向非線性函數映射後向取值區間極限飽和區靠攏的輸入分佈強制拉回到均值為0方差為1的比較標準的正態分佈,使得非線性變換函數的輸入值落入對輸入比較敏感的區域,以此避免梯度消失問題。因為梯度一直都能保持比較大的狀態,所以很明顯對神經網絡的參數調整效率比較高,就是變動大,就是說向損失函數最優值邁動的步子大,也就是說收斂地快。BN說到底就是這麼個機制,方法很簡單,道理很深刻。

  從上面幾個圖應該看出來BN在幹什麼瞭吧?其實就是把隱層神經元激活輸入x=WU+B從變化不拘一格的正態分佈通過BN操作拉回到瞭均值為0,方差為1的正態分佈,即原始正態分佈中心左移或者右移到以0為均值,拉伸或者縮減形態形成以1為方差的圖形。什麼意思?就是說經過BN後,目前大部分Activation的值落入非線性函數的線性區內,其對應的導數遠離導數飽和區,這樣來加速訓練收斂過程。

  但是很明顯,看到這裡,稍微瞭解神經網絡的讀者一般會提出一個疑問:如果都通過BN,那麼不就跟把非線性函數替換成線性函數效果相同瞭?這意味著什麼?我們知道,如果是多層的線性函數變換其實這個深層是沒有意義的,因為多層線性網絡跟一層線性網絡是等價的。這意味著網絡的表達能力下降瞭,這也意味著深度的意義就沒有瞭。所以BN為瞭保證非線性的獲得,對變換後的滿足均值為0方差為1的x又進行瞭scale加上shift操作(y=scale*x+shift),每個神經元增加瞭兩個參數scale和shift參數,這兩個參數是通過訓練學習到的,意思是通過scale和shift把這個值從標準正態分佈左移或者右移一點並長胖一點或者變瘦一點,每個實例挪動的程度不一樣,這樣等價於非線性函數的值從正中心周圍的線性區往非線性區動瞭動。

核心思想應該是想找到一個線性和非線性的較好平衡點,既能享受非線性的較強表達能力的好處,又避免太靠非線性區兩頭使得網絡收斂速度太慢。當然,這是我的理解,論文作者並未明確這樣說。但是很明顯這裡的scale和shift操作是會有爭議的,因為按照論文作者論文裡寫的理想狀態,就會又通過scale和shift操作把變換後的x調整回未變換的狀態,那不是饒瞭一圈又繞回去原始的“Internal Covariate Shift”問題裡去瞭嗎,感覺論文作者並未能夠清楚地解釋scale和shift操作的理論原因。

3、訓練階段如何做BatchNorm

  對於Mini-Batch SGD來說,一次訓練過程裡面包含m個訓練實例,其具體BN操作就是對於隱層內每個神經元的激活值來說,進行如下變換:

  要註意,這裡t層某個神經元的x(k)不是指原始輸入,就是說不是t-1層每個神經元的輸出,而是t層這個神經元的線性激活x=WU+B,這裡的U才是t-1層神經元的輸出。變換的意思是:某個神經元對應的原始的激活x通過減去mini-Batch內m個實例獲得的m個激活x求得的均值E(x)並除以求得的方差Var(x)來進行轉換。

  上文說過經過這個變換後某個神經元的激活x形成瞭均值為0,方差為1的正態分佈,目的是把值往後續要進行的非線性變換的線性區拉動,增大導數值,增強反向傳播信息流動性,加快訓練收斂速度。但是這樣會導致網絡表達能力下降,為瞭防止這一點,每個神經元增加兩個調節參數(scale和shift),這兩個參數是通過訓練來學習到的,用來對變換後的激活反變換,使得網絡表達能力增強,即對變換後的激活進行如下的scale和shift操作,這其實是變換的反操作:

  BN其具體操作流程,如論文中描述的一樣:

  走一遍Batch Normalization網絡層的前向傳播過程。

4、BatchNorm的推理(Inference)過程

  BN在訓練的時候可以根據Mini-Batch裡的若幹訓練實例進行激活數值調整,但是在推理(inference)的過程中,很明顯輸入就隻有一個實例,看不到Mini-Batch其它實例,那麼這時候怎麼對輸入做BN呢?因為很明顯一個實例是沒法算實例集合求出的均值和方差的。這可如何是好?既然沒有從Mini-Batch數據裡可以得到的統計量,那就想其它辦法來獲得這個統計量,就是均值和方差。可以用從所有訓練實例中獲得的統計量來代替Mini-Batch裡面m個訓練實例獲得的均值和方差統計量,因為本來就打算用全局的統計量,隻是因為計算量等太大所以才會用Mini-Batch這種簡化方式的,那麼在推理的時候直接用全局統計量即可。

  決定瞭獲得統計量的數據范圍,那麼接下來的問題是如何獲得均值和方差的問題。很簡單,因為每次做Mini-Batch訓練時,都會有那個Mini-Batch裡m個訓練實例獲得的均值和方差,現在要全局統計量,隻要把每個Mini-Batch的均值和方差統計量記住,然後對這些均值和方差求其對應的數學期望即可得出全局統計量

5、BatchNorm的好處

  BatchNorm為什麼NB呢,關鍵還是效果好。

①不僅僅極大提升瞭訓練速度,收斂過程大大加快;

②還能增加分類效果,一種解釋是這是類似於Dropout的一種防止過擬合的正則化表達方式,所以不用Dropout也能達到相當的效果;

③另外調參過程也簡單多瞭,對於初始化要求沒那麼高,而且可以使用大的學習率等。

總而言之,經過這麼簡單的變換,帶來的好處多得很,這也是為何現在BN這麼快流行起來的原因。

6、tensorflow中的BN

  為瞭activation能更有效地使用輸入信息,所以一般BN放在激活函數之前。

   一個batch裡的128個圖,經過一個64 kernels卷積層處理,得到瞭128×64個圖,再針對每一個kernel所對應的128個圖,求它們所有像素的mean和variance,因為總共有64個kernels,輸出的結果就是一個一維長度64的數組啦!最後輸出是(64,)的數組向量。

參考文獻:郭耀華博客:https://cloud.tencent.com/developer/article/1157136

     https://zhuanlan.zhihu.com/p/36222443

赞(0)