覚え書きブログ

train_stepのオーバーライドによるsubclassingの簡単化

以下では、tensorflowのバージョン2.1を前提にSequential、FunctionalおよびSubclassing APIを用いた実装方法についておおざっぱにまとめた。
hirotaka-hachiya.hatenablog.com



tensorflow 2.1のSubclassing APIでは、自由度が高い分学習まわりでfor文でepochを回すなどコードが煩雑になり一貫性の無いところが残念だった。
しかし、以下のtensorflow 2.2以上ではtf.keras.Modelのtrain_stepのオーバーライドが許可され、損失を複数設定し同時に最適化するなどの自作の学習に対しても、compile, fit, evaluateが実行できるようになり、シンプルにコードが書けるようになった。
https://blog.exxactcorp.com/tensorflow-2-2-0-released/

具体的には、Subclassingで独自のModelクラスの定義にて、train_stepメソッドを以下のように独自に定義する。
ここで、損失と評価方法は、model.compileにて指定したものを使う場合は、以下のように、self.compiled_lossと self.compiled_metrics.update_stateメソッドを用いる。

# Modelクラスを継承し,独自のlayerクラス(myConvとmyFC)を用いてネットワークを定義する
# 独自のモデルクラスを作成
class myModel(tf.keras.Model):
~省略~
    def train_step(self,data):
        x, y = data

        with tf.GradientTape() as tape:
            
            # 予測
            y_pred = self(x, training=True)

            # 損失
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        
        # 勾配を用いた学習
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # 評価値の更新
        self.compiled_metrics.update_state(y, y_pred)

        # 評価値をディクショナリで返す
        return {m.name: m.result() for m in self.metrics}
~省略~

# モデルの設定
model = myModel()

# 学習方法の設定
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
#----------------------------

そうすると、SequentialとFunctional APIと同様に、fitとevaluateを用いて学習評価を行うことができるようになり、煩わしいfor文などを省略することができる。

#----------------------------
# 学習
# - cnn関数を実行しネットワークを定義
# - fitで学習を実行
model.fit(x_train, y_train, batch_size=200, epochs=1)
#----------------------------

#----------------------------
# 学習データに対する評価
train_loss, train_accuracy = model.evaluate(x_train, y_train, verbose=0)
print('Train data loss:', train_loss)
print('Train data accuracy:', train_accuracy)
#----------------------------

コード全体は以下を参照。

gistc596cfdc5e60e40a16deb9eecd0bbcc9