覚え書きブログ

tensorflow2.2のsubclassingによるautoencoderの実装2

前回に引き続き、autoencoderをsublassing AIPで実装してみた。
hirotaka-hachiya.hatenablog.com

今回は、以下のように画像を2値化し分類問題として解いてみた。
また、ラベルをone-hotベクトルで表現し、embeddingベクトルに追加してみた。

# 画像の2値化
x_train[x_train<125] = 0
x_train[x_train>=125] = 1
x_test[x_test<125] = 0
x_test[x_test>=125] = 1

# ラベルをone-hot表現にする
y_train_onehot = np.eye(10)[y_train]
y_test_onehot = np.eye(10)[y_test]

損失関数は普通の分類と同様にsparse_categorical_crossentropyを用いて、decoderの出力のチャネル数を2(背景か前景か)に設定している。

self.deconv3 = myDeconv(chn=2, conv_kernel=(3,3), strides=(1,1), activation='softmax', isBatchNorm=False)
~省略~
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

また、embeddingベクトルにクラスラベルのone-hotベクトルの追加は、以下のtf.concatで行っている。

        if self.isEmbedLabel:
            # embeddingにone-hotラベルをconcat
            zz = tf.concat([z,y], axis=1)
        else:
            zz = z  

また、以下のcallbacks.ModelCheckpointを用いて学習したモデルを選択し、load_weightsを用いて学習したモデルの読み込みができるようにした。

# 学習したパラメータを保存するためのチェックポイントコールバックを作る
checkpoint_path = "autoencoder_training/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True, verbose=1)
    # 学習したパラメータの読み込み
    model.load_weights(checkpoint_path)


また、fitの返り値を使って学習と評価データに対するlossとaccuracyのプロットもしている。

    # fitで学習を実行
    history = model.fit((x_train, y_train_onehot), x_train, batch_size=200, epochs=5, validation_split=0.1, callbacks=[cp_callback])

    # 損失のプロット
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = np.arange(len(acc))

    plt.plot(epochs,acc,'bo-',label='training acc')
    plt.plot(epochs,val_acc,'b',label='validation acc')
    plt.title('Training and validation acc')
    plt.legend()

    plt.figure()
    plt.plot(epochs,loss,'bo-',label='training loss')
    plt.plot(epochs,val_loss,'b',label='validation loss')
    plt.title('Training and Validation loss')
    plt.legend()

    plt.show()  

コード全体は以下を参照。

gist2a8c18b3799fbf7331f1f24c831b62f5

実行すると以下のようにaccuracyとlossのグラフが表示される。学習が進むにつれ順調にaccuracyがあがり、lossが下がっているのがわかる。
f:id:hirotaka_hachiya:20200913105929p:plain

posクラスのスコアのプロットは以下のようになる(1行目が予測、2行目が真値)。
f:id:hirotaka_hachiya:20200913110228p:plain