チラ裏備忘録

情報整理

Tensorflow 2.x Functional API ショボいメモ

Tensorflow 2.xの勉強を始めたのでとても初歩的な内容をメモ.

インポート

import tensorflow as tf
from tensorflow import keras

TF1では,keras
TF2では,tensorflow.keras
TF2を使うので2行目のように読み込む.

モデルの作成

層の定義には,keras.layersを使う.

主な種類

  • Input
    • 入力層
  • Flantten
    • 入力の平坦化(バッチサイズには影響を与えない)
  • Dropout
  • Dense
    • 全結合層
  • Conv2D
    • 畳み込み層

ここで紹介したDenseとConv2D等は引数にactivationを取るので,そこで任意の活性化関数を指定できる.

# example
keras.layers.Dense(64, activation='relu')

詳細: Coreレイヤー - Keras Documentation

簡単な流れ(1例)

  1. 入力層
  2. 結合層
  3. 活性化関数(結合層等の引数で指定する場合は不要?)
  4. (2,3を適宜繰り返す)
  5. 最後にモデルを作成
    • keras.Model(inputs=入力層, outputs=出力層)
# モデル構造を定義
inputs = keras.layers.Input(shape=(28, 28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)

# 入出力を定義
model = keras.Model(inputs=inputs, outputs=outputs)

keras functional APIの使い方メモ - Qiitaより抜粋・改変させていただきました.

Q.入力層以外の層の右についてる括弧は何?

その行で定義した『層への入力』を『層(入力)』と表す.
(よって始端のInputにはなく,入力層への入力は学習時にfitメソッドで渡す)
Inputs->Flatten->Dense->Dropout->Denseのフィードフォワードな接続をここで定義する.

学習

学習の流れは,

  1. Modelクラスのcompileメソッド
  2. Modelクラスのfitメソッド

compileメソッド

引数に,optimizer(最適化手法),loss(損失関数), metrics(評価関数)等を取る.

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Modelクラスの詳細: Modelクラス (functional API) - Keras Documentation
最適化手法の詳細: 最適化 - Keras Documentation
損失関数の詳細: 損失関数 - Keras Documentation

fitメソッド

主な引数

  • x(訓練データ)
    • モデルが単一の入力を持つ場合はNumpy配列
    • モデルが複数の入力を持つ場合はNumpy配列のリスト
  • y(教師データ)
    • <訓練データと同様>
  • batch_size(バッチサイズ)
    • 整数もしくはNone.
    • 指定しなければデフォルトで32.
  • epochs(エポック数)
    • 整数.
    • 1エポック=訓練データをすべて使い切ったことを意味する ※1
  • validation_split
    • 0から1の間の浮動小数点数
    • モデルは訓練データからこの割合分のデータを区別し,それらでは学習を行わず,各試行の終わりにこのデータを用いて損失関数と評価関数を評価する.
hist = model.fit(x_train, y_train, validation_split=0.1, epochs=5)

※1補足
訓練データが1,000個,バッチサイズが200とすると,5(1000/200)バッチ(イテレーション数5)の学習を終えた時点が1エポック
 ・バッチサイズ…1バッチに含まれるデータ数

精度の計算

検証用データを用いてモデルの精度を確認する.

print(model.evaluate(x_test, y_test))