前回のコードを拡張し、今度はUnetを実装してみた。
hirotaka-hachiya.hatenablog.com
UNetは、以下のように、Uのような形をしたネットワークで、画像を一度低解像度化し特徴抽出してから、元の解像度に戻すことにより復元する。
特徴としては、encoderの各大きさの特徴マップをdecoderに持ってきてconcatすることと、encoderとdecoderともに1x1のconvolutionを用いて、convolutionでは特徴マップのサイズを変更せずに、max poolingとup samplingにより画像サイズを変更するところである。
したがって、実装では以下のようにmyConvというconvolutionのクラスを、myModelクラスで呼び出しネットワーク構造を設定している。
class myConv(tf.keras.layers.Layer): def __init__(self, chn=32, conv_kernel=(3,3), strides=(2,2), activation='relu', isBatchNorm=True, padding='same'): super(myConv, self).__init__() self.activation = activation self.isBatchNorm = isBatchNorm self.conv_relu = tf.keras.layers.Conv2D(filters=chn, strides=strides, padding=padding, kernel_size=conv_kernel, activation='relu') self.conv_sigmoid = tf.keras.layers.Conv2D(filters=chn, strides=strides, padding=padding, kernel_size=conv_kernel, activation='sigmoid') self.conv = tf.keras.layers.Conv2D(filters=chn, strides=strides, padding=padding, kernel_size=conv_kernel) self.batchnorm = tf.keras.layers.BatchNormalization() def call(self, x): if self.activation == 'relu': x = self.conv_relu(x) elif self.activation == 'sigmoid': x = self.conv_sigmoid(x) elif self.activation == 'softmax': x = self.conv(x) x = tf.keras.activations.softmax(x, axis=3) if self.isBatchNorm: x = self.batchnorm(x) return x # Modelクラスを継承し,独自のlayerクラス(myConvとmyFC)を用いてネットワークを定義する # 独自のモデルクラスを作成 class myModel(tf.keras.Model): def __init__(self,isEmbedLabel=True): super(myModel, self).__init__() self.isEmbedLabel = isEmbedLabel # maxpool & upsample self.maxpool = tf.keras.layers.MaxPool2D((2,2), padding='same') self.upsample = tf.keras.layers.UpSampling2D((2,2), interpolation='nearest') # encoder self.conv1_1 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv1_2 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv2_1 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv2_2 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv3_1 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv3_2 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv4_1 = myConv(chn=baseChn*8, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv4_2 = myConv(chn=baseChn*8, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') # decoder self.conv5_1 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv5_2 = myConv(chn=baseChn*4, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv6_1 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv6_2 = myConv(chn=baseChn*2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv7_1 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv7_2 = myConv(chn=baseChn, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='same') self.conv7_3 = myConv(chn=2, conv_kernel=(3,3), strides=(1,1), activation='relu', padding='valid') self.conv7_4 = myConv(chn=2, conv_kernel=(3,3), strides=(1,1), activation='softmax', padding='valid', isBatchNorm=False) def call(self, data): (x,label) = data # encoder conv1_0 = tf.keras.layers.ZeroPadding2D(padding=(2,2))(x) conv1_1 = self.conv1_1(conv1_0) conv1_2 = self.conv1_2(conv1_1) conv2_0 = self.maxpool(conv1_2) conv2_1 = self.conv2_1(conv2_0) conv2_2 = self.conv2_2(conv2_1) conv3_0 = self.maxpool(conv2_2) conv3_1 = self.conv3_1(conv3_0) conv3_2 = self.conv3_2(conv3_1) conv4_0 = self.maxpool(conv3_2) conv4_1 = self.conv4_1(conv4_0) conv4_2 = self.conv4_2(conv4_1) # embeddingにone-hotラベルをconcat if self.isEmbedLabel: shape = tf.shape(conv4_2) label_map = tf.tile(tf.expand_dims(tf.expand_dims(label,1),1),[1,shape[1],shape[2],1]) conv4_2=tf.concat([conv4_2,label_map],axis=3) # decoder conv5_0 = self.upsample(conv4_2) conv5_0_concat = tf.concat([conv5_0,conv3_2],axis=-1) conv5_1 = self.conv5_1(conv5_0_concat) conv5_2 = self.conv5_2(conv5_1) conv6_0 = self.upsample(conv5_2) conv6_0_concat = tf.concat([conv6_0,conv2_2],axis=-1) conv6_1 = self.conv6_1(conv6_0_concat) conv6_2 = self.conv6_2(conv6_1) conv7_0 = self.upsample(conv6_2) conv7_0_concat = tf.concat([conv7_0,conv1_2],axis=-1) conv7_1 = self.conv7_1(conv7_0_concat) conv7_2 = self.conv7_2(conv7_1) conv7_3 = self.conv7_3(conv7_2) output = self.conv7_4(conv7_3) return output, (conv1_2,conv2_2,conv3_2,conv4_2,conv5_2,conv6_2,conv7_2)
気を付けなければならないところとしては、encoderの各階層の画像サイズと、decoderの各階層の画像サイズが合わないとconcatできないところである。
padding='same'を使っているので、max poolingでは、画像サイズ割るstrideの大きさ、つまり1/2ずつ画像サイズを小さくしていくが、MNISTの場合、28、14、7、3.5なり、4階層目で3.5となり割り切れない。そして、padding='same'の場合は、4階層目で画像サイズ4に設定するわけだが、今度はdecoder側でupsamplingにて2倍ずつしていくと、4、8、16、32となりdecoderとencoderで画像サイズが合わないためconcatができなくなってしまうのである。
そこで、今回は、まず画像サイズをZeroPaddingを用いて32x32に解像度を上げてから、encoderおよびdecoderを行うことにより、32, 16, 8, 4と画像サイズが合うように調整している。
conv1_0 = tf.keras.layers.ZeroPadding2D(padding=(2,2))(x)
そして、deconvolutionの最後にて、カーネルサイズ3x3、strideが1x1のconvolutionを2回行い、28x28に解像度を落とし出力するようにしている。
conv7_2 = self.conv7_2(conv7_1) conv7_3 = self.conv7_3(conv7_2)
全体のコードは以下を参照。
gistb71bbfe5276041c22048919baad1f102
実行すると以下のようにaccuracyとloss(交差エントロピー)のグラフが表示される。前回のautoencoderよりも早い段階から誤差はほぼゼロとなりaccuracyもほぼ1になっていることがわかる。
復元した画像(1行目)と元画像(2行目)は以下のようになっている。