以下を参考に、tensorflow2.2以降で許可されたtrain_step, test_stepおよびpredict_stepのオーバーライドを用いて、Subclassing APIでautoencoderを実装してみた。
hirotaka-hachiya.hatenablog.com
以下は、全体のコードである。元画像と復元画像の平均二乗誤差を最小化するように学習している。
今回から新たに、tf.keras.callbacks.ModelCheckpointとmodel.load_weightsを用いて学習したパラメータの保存と読み込みをできるようにした。
gistbaf88b3900c5a179a2d6474fe3ae3919
実行すると以下のように学習とテストデータにおける誤差と画像が表示される。
Epoch 1/3 300/300 [==============================] - ETA: 0s - loss: 0.0289 - mae: 0.0856 Epoch 00001: saving model to autoencoder_training/cp.ckpt 300/300 [==============================] - 107s 358ms/step - loss: 0.0289 - mae: 0.0856 Epoch 2/3 300/300 [==============================] - ETA: 0s - loss: 0.0053 - mae: 0.0282 Epoch 00002: saving model to autoencoder_training/cp.ckpt 300/300 [==============================] - 113s 376ms/step - loss: 0.0053 - mae: 0.0282 Epoch 3/3 300/300 [==============================] - ETA: 0s - loss: 0.0040 - mae: 0.0231 Epoch 00003: saving model to autoencoder_training/cp.ckpt 300/300 [==============================] - 89s 298ms/step - loss: 0.0040 - mae: 0.0231 Train data loss: 0.0037610987201333046 Train data mae: 0.021624118089675903 Test data loss: 0.003681402187794447 Test data mae: 0.02133253403007984
1行目が復元した画像で、2行目が元の画像である。ほぼ完ぺきに再現できている。
train_step, test_stepおよびpredict_stepをオーバーライドすることにより学習まわりはシンプルに書くことができたが、SequentialやFunctionalのようにmode.summary()を用いてネットワークの構造を可視化できないので、ネットワークの設計時のデバッグがしにくい問題がある。なぜ、summary()をつかえないのだろうか。。。
そこで、モデルの定義はFunctional APIを使い、train_step, test_stepおよびpredict_stepをオーバーライドしたmyModelに、引数でinputs, outputsを設定することによりネットワークを定義するように変更してみた。
#---------------------------- # Functionalを用いたネットワークの定義 def autoencoder(input_shape): inputs = tf.keras.layers.Input(shape=input_shape, name="inputs") # conv1 conv1 = tf.keras.layers.Conv2D(filters=32, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(inputs) conv1 = tf.keras.layers.BatchNormalization()(conv1) # conv2 conv2 = tf.keras.layers.Conv2D(filters=64, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(conv1) conv2 = tf.keras.layers.BatchNormalization()(conv2) # fc1 conv2_flat = tf.keras.layers.Flatten()(conv2) fc = tf.keras.layers.Dense(units=64,activation='relu')(conv2_flat) # defc defc = tf.keras.layers.Dense(units=3136,activation='relu')(fc) defc = tf.reshape(defc, tf.shape(conv2)) # deconv1 deconv1 = tf.keras.layers.Conv2DTranspose(filters=64, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(defc) deconv1 = tf.keras.layers.BatchNormalization()(deconv1) # deconv2 deconv2 = tf.keras.layers.Conv2DTranspose(filters=32, strides=(2, 2), padding='same', kernel_size=(3, 3), activation='relu')(deconv1) deconv2 = tf.keras.layers.BatchNormalization()(deconv2) # deconv2 outputs = tf.keras.layers.Conv2DTranspose(filters=1, strides=(1, 1), padding='same', kernel_size=(1, 1), activation='sigmoid')(deconv2) return inputs, outputs #---------------------------- ~省略~ # モデルの設定 inputs, outputs = autoencoder((H,W,C)) model = myModel(inputs,outputs) # 学習方法の設定 model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae']) model.summary()
これにより、以下のようにsummary()でネットワークの構造が確認できるとともに、学習方法を独自に設定することができる。
Layer (type) Output Shape Param # Connected to ================================================================================================== inputs (InputLayer) [(None, 28, 28, 1)] 0 __________________________________________________________________________________________________ conv2d (Conv2D) (None, 14, 14, 32) 320 inputs[0][0] __________________________________________________________________________________________________ batch_normalization (BatchNorma (None, 14, 14, 32) 128 conv2d[0][0] __________________________________________________________________________________________________ conv2d_1 (Conv2D) (None, 7, 7, 64) 18496 batch_normalization[0][0] __________________________________________________________________________________________________ batch_normalization_1 (BatchNor (None, 7, 7, 64) 256 conv2d_1[0][0] __________________________________________________________________________________________________ flatten (Flatten) (None, 3136) 0 batch_normalization_1[0][0] __________________________________________________________________________________________________ dense (Dense) (None, 64) 200768 flatten[0][0] __________________________________________________________________________________________________ dense_1 (Dense) (None, 3136) 203840 dense[0][0] __________________________________________________________________________________________________ tf_op_layer_Shape (TensorFlowOp [(4,)] 0 batch_normalization_1[0][0] __________________________________________________________________________________________________ tf_op_layer_Reshape (TensorFlow [(None, 7, 7, 64)] 0 dense_1[0][0] tf_op_layer_Shape[0][0] __________________________________________________________________________________________________ conv2d_transpose (Conv2DTranspo (None, 14, 14, 64) 36928 tf_op_layer_Reshape[0][0] __________________________________________________________________________________________________ batch_normalization_2 (BatchNor (None, 14, 14, 64) 256 conv2d_transpose[0][0] __________________________________________________________________________________________________ conv2d_transpose_1 (Conv2DTrans (None, 28, 28, 32) 18464 batch_normalization_2[0][0] __________________________________________________________________________________________________ batch_normalization_3 (BatchNor (None, 28, 28, 32) 128 conv2d_transpose_1[0][0] __________________________________________________________________________________________________ conv2d_transpose_2 (Conv2DTrans (None, 28, 28, 1) 33 batch_normalization_3[0][0] ================================================================================================== Total params: 479,617 Trainable params: 479,233 Non-trainable params: 384
ただし、やはり、中間層の値をtrain_stepに渡したりすることができないので、できることが限られている。。。
なお、中間層の値を参照することは以下のように、別途それ専用のモデルを作ればできる。ただし、これを学習のtrain_stepで参照するのはどうすればよいのだろうか?
# 中間層の値を取得するためのモデル features_list = [layer.output for layer in model.layers] feat_extraction_model = tf.keras.Model(inputs=inputs, outputs=features_list) # 中間層の値を取得 features = feat_extraction_model(x_test[:img_num])
全体のコードは以下参照。
gist89bacab85ca1bdb63dca0462a172cef8