覚え書きブログ

tensorflow2.2のsubclassingによるUnetの実装

前回のコードを拡張し、今度はUnetを実装してみた。
hirotaka-hachiya.hatenablog.com

UNetは、以下のように、Uのような形をしたネットワークで、画像を一度低解像度化し特徴抽出してから、元の解像度に戻すことにより復元する。
特徴としては、encoderの各大きさの特徴マップをdecoderに持ってきてconcatすることと、encoderとdecoderともに1x1のconvolutionを用いて、convolutionでは特徴マップのサイズを変更せずに、max poolingとup samplingにより画像サイズを変更するところである。
f:id:hirotaka_hachiya:20200913163142p:plain

したがって、実装では以下のようにmyConvというconvolutionのクラスを、myModelクラスで呼び出しネットワーク構造を設定している。

class myConv(tf.keras.layers.Layer):
    def __init__(self, chn=32, conv_kernel=(3,3), strides=(2,2), activation='relu', isBatchNorm=True, padding='same'):
        super(myConv, self).__init__()
        self.activation = activation
        self.isBatchNorm = isBatchNorm
        self.conv_relu = tf.keras.layers.Conv2D(filters=chn, strides=strides, padding=padding, kernel_size=conv_kernel, activation='relu')
        self.conv_sigmoid =  tf.keras.layers.Conv2D(filters=chn, strides=strides, padding=padding, kernel_size=conv_kernel, activation='sigmoid')
        self.conv =  tf.keras.layers.Conv2D(filters=chn, strides=strides, padding=padding, kernel_size=conv_kernel)
        self.batchnorm = tf.keras.layers.BatchNormalization()
         

    def call(self, x):
        if self.activation == 'relu':
            x = self.conv_relu(x)
        elif self.activation == 'sigmoid':
            x = self.conv_sigmoid(x)
        elif self.activation == 'softmax':
            x = self.conv(x)
            x = tf.keras.activations.softmax(x, axis=3)

        if self.isBatchNorm:
            x = self.batchnorm(x)

        return x

# Modelクラスを継承し,独自のlayerクラス(myConvとmyFC)を用いてネットワークを定義する
# 独自のモデルクラスを作成
class myModel(tf.keras.Model):
    def __init__(self,isEmbedLabel=True):
        super(myModel, self).__init__()
        self.isEmbedLabel = isEmbedLabel

        # maxpool & upsample
        self.maxpool = tf.keras.layers.MaxPool2D((2,2), padding='same')
        self.upsample = tf.keras.layers.UpSampling2D((2,2), interpolation='nearest') 

        # encoder
        self.conv1_1 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv1_2 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv2_1 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv2_2 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv3_1 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv3_2 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv4_1 = myConv(chn=baseChn*8, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv4_2 = myConv(chn=baseChn*8, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')

        # decoder
        self.conv5_1 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv5_2 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv6_1 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv6_2 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')                
        self.conv7_1 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv7_2 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same')
        self.conv7_3 = myConv(chn=2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='valid')
        self.conv7_4 = myConv(chn=2, conv_kernel=(3,3), strides=(1,1), activation='softmax', padding='valid', isBatchNorm=False)

    def call(self, data):
        (x,label) = data

        # encoder
        conv1_0 = tf.keras.layers.ZeroPadding2D(padding=(2,2))(x)
        conv1_1 = self.conv1_1(conv1_0)
        conv1_2 = self.conv1_2(conv1_1)

        conv2_0 = self.maxpool(conv1_2)
        conv2_1 = self.conv2_1(conv2_0)
        conv2_2 = self.conv2_2(conv2_1)

        conv3_0 = self.maxpool(conv2_2)
        conv3_1 = self.conv3_1(conv3_0)
        conv3_2 = self.conv3_2(conv3_1)

        conv4_0 = self.maxpool(conv3_2)
        conv4_1 = self.conv4_1(conv4_0)
        conv4_2 = self.conv4_2(conv4_1)

        # embeddingにone-hotラベルをconcat
        if self.isEmbedLabel:
            shape = tf.shape(conv4_2)            
            label_map = tf.tile(tf.expand_dims(tf.expand_dims(label,1),1),[1,shape[1],shape[2],1])
            conv4_2=tf.concat([conv4_2,label_map],axis=3)

        # decoder
        conv5_0 = self.upsample(conv4_2)
        conv5_0_concat = tf.concat([conv5_0,conv3_2],axis=-1)
        conv5_1 = self.conv5_1(conv5_0_concat)
        conv5_2 = self.conv5_2(conv5_1)

        conv6_0 = self.upsample(conv5_2)              
        conv6_0_concat = tf.concat([conv6_0,conv2_2],axis=-1)
        conv6_1 = self.conv6_1(conv6_0_concat)
        conv6_2 = self.conv6_2(conv6_1)

        conv7_0 = self.upsample(conv6_2)              
        conv7_0_concat = tf.concat([conv7_0,conv1_2],axis=-1)
        conv7_1 = self.conv7_1(conv7_0_concat)
        conv7_2 = self.conv7_2(conv7_1)
        conv7_3 = self.conv7_3(conv7_2)

        output = self.conv7_4(conv7_3)

        return output, (conv1_2,conv2_2,conv3_2,conv4_2,conv5_2,conv6_2,conv7_2)

気を付けなければならないところとしては、encoderの各階層の画像サイズと、decoderの各階層の画像サイズが合わないとconcatできないところである。
padding='same'を使っているので、max poolingでは、画像サイズ割るstrideの大きさ、つまり1/2ずつ画像サイズを小さくしていくが、MNISTの場合、28、14、7、3.5なり、4階層目で3.5となり割り切れない。そして、padding='same'の場合は、4階層目で画像サイズ4に設定するわけだが、今度はdecoder側でupsamplingにて2倍ずつしていくと、4、8、16、32となりdecoderとencoderで画像サイズが合わないためconcatができなくなってしまうのである。

そこで、今回は、まず画像サイズをZeroPaddingを用いて32x32に解像度を上げてから、encoderおよびdecoderを行うことにより、32, 16, 8, 4と画像サイズが合うように調整している。

        conv1_0 = tf.keras.layers.ZeroPadding2D(padding=(2,2))(x)

そして、deconvolutionの最後にて、カーネルサイズ3x3、strideが1x1のconvolutionを2回行い、28x28に解像度を落とし出力するようにしている。

        conv7_2 = self.conv7_2(conv7_1)
        conv7_3 = self.conv7_3(conv7_2)

全体のコードは以下を参照。

gistb71bbfe5276041c22048919baad1f102

実行すると以下のようにaccuracyとloss(交差エントロピー)のグラフが表示される。前回のautoencoderよりも早い段階から誤差はほぼゼロとなりaccuracyもほぼ1になっていることがわかる。
f:id:hirotaka_hachiya:20200913163109p:plain

復元した画像(1行目)と元画像(2行目)は以下のようになっている。
f:id:hirotaka_hachiya:20200913171532p:plain