以下では、tensorflowのバージョン2.1を前提にSequential、FunctionalおよびSubclassing APIを用いた実装方法についておおざっぱにまとめた。
hirotaka-hachiya.hatenablog.com
tensorflow 2.1のSubclassing APIでは、自由度が高い分学習まわりでfor文でepochを回すなどコードが煩雑になり一貫性の無いところが残念だった。
しかし、以下のtensorflow 2.2以上ではtf.keras.Modelのtrain_stepのオーバーライドが許可され、損失を複数設定し同時に最適化するなどの自作の学習に対しても、compile, fit, evaluateが実行できるようになり、シンプルにコードが書けるようになった。
https://blog.exxactcorp.com/tensorflow-2-2-0-released/
具体的には、Subclassingで独自のModelクラスの定義にて、train_stepメソッドを以下のように独自に定義する。
ここで、損失と評価方法は、model.compileにて指定したものを使う場合は、以下のように、self.compiled_lossと self.compiled_metrics.update_stateメソッドを用いる。
# Modelクラスを継承し,独自のlayerクラス(myConvとmyFC)を用いてネットワークを定義する # 独自のモデルクラスを作成 class myModel(tf.keras.Model): ~省略~ def train_step(self,data): x, y = data with tf.GradientTape() as tape: # 予測 y_pred = self(x, training=True) # 損失 loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) # 勾配を用いた学習 trainable_vars = self.trainable_variables gradients = tape.gradient(loss, trainable_vars) self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # 評価値の更新 self.compiled_metrics.update_state(y, y_pred) # 評価値をディクショナリで返す return {m.name: m.result() for m in self.metrics} ~省略~ # モデルの設定 model = myModel() # 学習方法の設定 model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy']) #----------------------------
そうすると、SequentialとFunctional APIと同様に、fitとevaluateを用いて学習評価を行うことができるようになり、煩わしいfor文などを省略することができる。
#---------------------------- # 学習 # - cnn関数を実行しネットワークを定義 # - fitで学習を実行 model.fit(x_train, y_train, batch_size=200, epochs=1) #---------------------------- #---------------------------- # 学習データに対する評価 train_loss, train_accuracy = model.evaluate(x_train, y_train, verbose=0) print('Train data loss:', train_loss) print('Train data accuracy:', train_accuracy) #----------------------------
コード全体は以下を参照。
gistc596cfdc5e60e40a16deb9eecd0bbcc9