覚え書きブログ

tensorflow2.2のsubclassingによるautoencoderの実装

以下を参考に、tensorflow2.2以降で許可されたtrain_step, test_stepおよびpredict_stepのオーバーライドを用いて、Subclassing APIでautoencoderを実装してみた。
hirotaka-hachiya.hatenablog.com

以下は、全体のコードである。元画像と復元画像の平均二乗誤差を最小化するように学習している。
今回から新たに、tf.keras.callbacks.ModelCheckpointとmodel.load_weightsを用いて学習したパラメータの保存と読み込みをできるようにした。

gistbaf88b3900c5a179a2d6474fe3ae3919

実行すると以下のように学習とテストデータにおける誤差と画像が表示される。

Epoch 1/3
300/300 [==============================] - ETA: 0s - loss: 0.0289 - mae: 0.0856
Epoch 00001: saving model to autoencoder_training/cp.ckpt
300/300 [==============================] - 107s 358ms/step - loss: 0.0289 - mae: 0.0856
Epoch 2/3
300/300 [==============================] - ETA: 0s - loss: 0.0053 - mae: 0.0282
Epoch 00002: saving model to autoencoder_training/cp.ckpt
300/300 [==============================] - 113s 376ms/step - loss: 0.0053 - mae: 0.0282
Epoch 3/3
300/300 [==============================] - ETA: 0s - loss: 0.0040 - mae: 0.0231
Epoch 00003: saving model to autoencoder_training/cp.ckpt
300/300 [==============================] - 89s 298ms/step - loss: 0.0040 - mae: 0.0231
Train data loss: 0.0037610987201333046
Train data mae: 0.021624118089675903
Test data loss: 0.003681402187794447
Test data mae: 0.02133253403007984

1行目が復元した画像で、2行目が元の画像である。ほぼ完ぺきに再現できている。
f:id:hirotaka_hachiya:20200906234254p:plain

train_step, test_stepおよびpredict_stepをオーバーライドすることにより学習まわりはシンプルに書くことができたが、SequentialやFunctionalのようにmode.summary()を用いてネットワークの構造を可視化できないので、ネットワークの設計時のデバッグがしにくい問題がある。なぜ、summary()をつかえないのだろうか。。。

そこで、モデルの定義はFunctional APIを使い、train_step, test_stepおよびpredict_stepをオーバーライドしたmyModelに、引数でinputs, outputsを設定することによりネットワークを定義するように変更してみた。

#----------------------------
# Functionalを用いたネットワークの定義
def autoencoder(input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape, name="inputs")

    # conv1
    conv1 = tf.keras.layers.Conv2D(filters=32, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(inputs)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)

    # conv2
    conv2 = tf.keras.layers.Conv2D(filters=64, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(conv1)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    # fc1
    conv2_flat = tf.keras.layers.Flatten()(conv2)
    fc = tf.keras.layers.Dense(units=64,activation='relu')(conv2_flat)

    # defc
    defc = tf.keras.layers.Dense(units=3136,activation='relu')(fc)
    defc = tf.reshape(defc, tf.shape(conv2))

    # deconv1    
    deconv1 = tf.keras.layers.Conv2DTranspose(filters=64, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(defc)
    deconv1 = tf.keras.layers.BatchNormalization()(deconv1)

    # deconv2
    deconv2 = tf.keras.layers.Conv2DTranspose(filters=32, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(deconv1)
    deconv2 = tf.keras.layers.BatchNormalization()(deconv2)    

    # deconv2
    outputs = tf.keras.layers.Conv2DTranspose(filters=1, strides=(1, 1), padding='same', kernel_size=(1, 1), activation='sigmoid')(deconv2)

    return inputs, outputs
#----------------------------
~省略~
# モデルの設定
inputs, outputs = autoencoder((H,W,C))
model = myModel(inputs,outputs)

# 学習方法の設定
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae'])
model.summary()

これにより、以下のようにsummary()でネットワークの構造が確認できるとともに、学習方法を独自に設定することができる。

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
inputs (InputLayer)             [(None, 28, 28, 1)]  0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 14, 14, 32)   320         inputs[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 14, 14, 32)   128         conv2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 7, 7, 64)     18496       batch_normalization[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 7, 7, 64)     256         conv2d_1[0][0]
__________________________________________________________________________________________________
flatten (Flatten)               (None, 3136)         0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 64)           200768      flatten[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 3136)         203840      dense[0][0]
__________________________________________________________________________________________________
tf_op_layer_Shape (TensorFlowOp [(4,)]               0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
tf_op_layer_Reshape (TensorFlow [(None, 7, 7, 64)]   0           dense_1[0][0]
                                                                 tf_op_layer_Shape[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 14, 14, 64)   36928       tf_op_layer_Reshape[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 14, 14, 64)   256         conv2d_transpose[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 28, 28, 32)   18464       batch_normalization_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 28, 28, 32)   128         conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 28, 28, 1)    33          batch_normalization_3[0][0]
==================================================================================================
Total params: 479,617
Trainable params: 479,233
Non-trainable params: 384

ただし、やはり、中間層の値をtrain_stepに渡したりすることができないので、できることが限られている。。。
なお、中間層の値を参照することは以下のように、別途それ専用のモデルを作ればできる。ただし、これを学習のtrain_stepで参照するのはどうすればよいのだろうか?

# 中間層の値を取得するためのモデル
features_list = [layer.output for layer in model.layers]
feat_extraction_model = tf.keras.Model(inputs=inputs, outputs=features_list)

# 中間層の値を取得
features = feat_extraction_model(x_test[:img_num])

全体のコードは以下参照。

gist89bacab85ca1bdb63dca0462a172cef8