前回に引き続き、autoencoderをsublassing AIPで実装してみた。
hirotaka-hachiya.hatenablog.com
今回は、以下のように画像を2値化し分類問題として解いてみた。
また、ラベルをone-hotベクトルで表現し、embeddingベクトルに追加してみた。
# 画像の2値化 x_train[x_train<125] = 0 x_train[x_train>=125] = 1 x_test[x_test<125] = 0 x_test[x_test>=125] = 1 # ラベルをone-hot表現にする y_train_onehot = np.eye(10)[y_train] y_test_onehot = np.eye(10)[y_test]
損失関数は普通の分類と同様にsparse_categorical_crossentropyを用いて、decoderの出力のチャネル数を2(背景か前景か)に設定している。
self.deconv3 = myDeconv(chn=2, conv_kernel=(3,3), strides=(1,1), activation='softmax', isBatchNorm=False) ~省略~ model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
また、embeddingベクトルにクラスラベルのone-hotベクトルの追加は、以下のtf.concatで行っている。
if self.isEmbedLabel: # embeddingにone-hotラベルをconcat zz = tf.concat([z,y], axis=1) else: zz = z
また、以下のcallbacks.ModelCheckpointを用いて学習したモデルを選択し、load_weightsを用いて学習したモデルの読み込みができるようにした。
# 学習したパラメータを保存するためのチェックポイントコールバックを作る checkpoint_path = "autoencoder_training/cp.ckpt" checkpoint_dir = os.path.dirname(checkpoint_path) cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True, verbose=1)
# 学習したパラメータの読み込み
model.load_weights(checkpoint_path)
また、fitの返り値を使って学習と評価データに対するlossとaccuracyのプロットもしている。
# fitで学習を実行 history = model.fit((x_train, y_train_onehot), x_train, batch_size=200, epochs=5, validation_split=0.1, callbacks=[cp_callback]) # 損失のプロット acc = history.history['accuracy'] val_acc = history.history['val_accuracy'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = np.arange(len(acc)) plt.plot(epochs,acc,'bo-',label='training acc') plt.plot(epochs,val_acc,'b',label='validation acc') plt.title('Training and validation acc') plt.legend() plt.figure() plt.plot(epochs,loss,'bo-',label='training loss') plt.plot(epochs,val_loss,'b',label='validation loss') plt.title('Training and Validation loss') plt.legend() plt.show()
コード全体は以下を参照。
gist2a8c18b3799fbf7331f1f24c831b62f5
実行すると以下のようにaccuracyとlossのグラフが表示される。学習が進むにつれ順調にaccuracyがあがり、lossが下がっているのがわかる。
posクラスのスコアのプロットは以下のようになる(1行目が予測、2行目が真値)。