覚え書きブログ

Deep Learning覚え書き(Batch Normalization)

Deep Learningの各階層の入力データの分布は、学習の過程において、下位層のパラメータが更新されることにより変化する。各階層の勾配は、ミニバッチ内で平均をとることにより推定しているが、この分布の変化により推定に、ミニバッチごとに異なるバイアスが乗りやすくなる。そのため、学習が不安定になるという問題がある。この問題は、internal convariance shiftと呼ばれている。この問題を解決するために、下記の論文では各階層の入力分布をミニバッチごとに平均=0と分散=1とに正規化するbatch normalizationという方法が提案されている。
http://arxiv.org/pdf/1502.03167v3.pdf

今回は、chainerにてBatch normalizationを、各CNNに2つの方法で適用してみた。
1)下記のようにconvと非線形な活性化関数ReLUの間ににbatch normalizationを適用

h = F.relu(self.bnorm1(self.conv1(x)))

2)下記のようにconvとReLUの後にbatch normalizationを適用

h = F.relu(self.conv1(x))
h = self.bnorm1(h)

具体的には、各ネットワークのクラスは以下のように定義(MnistCNN_BN_beforeとMnistCNN_BN_after)されている。

class MnistCNN_BN_before(chainer.Chain):

    """An example of convolutional neural network for MNIST dataset.

    """

    def __init__(self, channel=1, c1=16, c2=32, c3=64, f1=256, \
                 f2=512, filter_size1=3, filter_size2=3, filter_size3=3):
        super(MnistCNN_BN_before, self).__init__(
            conv1=L.Convolution2D(channel, c1, filter_size1),
            conv2=L.Convolution2D(c1, c2, filter_size2),
            conv3=L.Convolution2D(c2, c3, filter_size3),
            l1=L.Linear(f1, f2),
            l2=L.Linear(f2, 10),
            bnorm1=L.BatchNormalization(c1),
            bnorm2=L.BatchNormalization(c2),
            bnorm3=L.BatchNormalization(c3)
        )

    def __call__(self, x):
        # param x --- chainer.Variable of array

        x.data = x.data.reshape((len(x.data), 1, 28, 28))

        h = F.relu(self.bnorm1(self.conv1(x)))
        h = F.max_pooling_2d(h, 2)
        h = F.relu(self.bnorm2(self.conv2(h)))
        h = F.max_pooling_2d(h, 2)
        h = F.relu(self.bnorm3(self.conv3(h)))
        h = F.max_pooling_2d(h, 2)
        h = F.dropout(F.relu(self.l1(h)))
        y = self.l2(h)
        return y
        
class MnistCNN_BN_after(chainer.Chain):

    """An example of convolutional neural network for MNIST dataset.

    """

    def __init__(self, channel=1, c1=16, c2=32, c3=64, f1=256, \
                 f2=512, filter_size1=3, filter_size2=3, filter_size3=3):
        super(MnistCNN_BN_after, self).__init__(
            conv1=L.Convolution2D(channel, c1, filter_size1),
            conv2=L.Convolution2D(c1, c2, filter_size2),
            conv3=L.Convolution2D(c2, c3, filter_size3),
            l1=L.Linear(f1, f2),
            l2=L.Linear(f2, 10),
            bnorm1=L.BatchNormalization(c1),
            bnorm2=L.BatchNormalization(c2),
            bnorm3=L.BatchNormalization(c3)
        )

    def __call__(self, x):
        # param x --- chainer.Variable of array

        x.data = x.data.reshape((len(x.data), 1, 28, 28))

        h = F.relu(self.conv1(x))
        h = self.bnorm1(h)
        h = F.max_pooling_2d(h, 2)
        h = F.relu(self.conv2(h))
        h = self.bnorm2(h)
        h = F.max_pooling_2d(h, 2)
        h = F.relu(self.conv3(h))
        h = self.bnorm3(h)
        h = F.max_pooling_2d(h, 2)
        h = F.dropout(F.relu(self.l1(h)))
        y = self.l2(h)
        return y

以下がBatch normalizationを入れない場合、convとReLUの間に入れた場合、およびconvとReLUの後に入れた場合の結果である。

Batch Normalizationなしの場合:
f:id:hirotaka_hachiya:20160802220957p:plain

convとReLUの間にBatch Normalizationを入れた場合:
f:id:hirotaka_hachiya:20160806182959p:plain

convとReLUの後にBatch Normalizationを入れた場合:
f:id:hirotaka_hachiya:20160806175038p:plain

3つのaccuracyを比較してみると、全体的にBatch normalizationを入れた方が改善している。特にSGDのaccuracyが大きく改善していて、convとReLUの間に入れた方が若干よくなっている。下記では、ReLUの後に入れた方が改善されたと報告されているので、問題依存かもしれない。
https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md

さらに、fully connected層にも、Batch Normalizationを入れてみた。

class MnistCNN_BN_linear_before(chainer.Chain):

    """An example of convolutional neural network for MNIST dataset.

    """

    def __init__(self, channel=1, c1=16, c2=32, c3=64, f1=256, \
                 f2=512, filter_size1=3, filter_size2=3, filter_size3=3):
        super(MnistCNN_BN_linear_before, self).__init__(
            conv1=L.Convolution2D(channel, c1, filter_size1),
            conv2=L.Convolution2D(c1, c2, filter_size2),
            conv3=L.Convolution2D(c2, c3, filter_size3),
            l1=L.Linear(f1, f2),
            l2=L.Linear(f2, 10),
            bnorm1=L.BatchNormalization(c1),
            bnorm2=L.BatchNormalization(c2),
            bnorm3=L.BatchNormalization(c3),
            bnorm4=L.BatchNormalization(f2),
            bnorm5=L.BatchNormalization(10)
        )

    def __call__(self, x):
        # param x --- chainer.Variable of array

        x.data = x.data.reshape((len(x.data), 1, 28, 28))

        h = F.relu(self.bnorm1(self.conv1(x)))
        h = F.max_pooling_2d(h, 2)
        h = F.relu(self.bnorm2(self.conv2(h)))
        h = F.max_pooling_2d(h, 2)
        h = F.relu(self.bnorm3(self.conv3(h)))
        h = F.max_pooling_2d(h, 2)
        h = F.dropout(F.relu(self.bnorm4(self.l1(h))))
        y = self.bnorm5(self.l2(h))
        return y

f:id:hirotaka_hachiya:20160807204804p:plain
全体的に学習曲線が安定したような感じがする。