チラ裏備忘録

情報整理

MNISTデータの正規化・正解データのone-hot表記化(to_categoricalメソッド)

MNIST読み込み

import tensorflow as tf 
tf.keras.backend.clear_session() # Destroys the current TF graph and creates a new one.

from tensorflow import keras
from tensorflow.keras.utils import to_categorical # one-hot表記化メソッド

# データセットの読み込み
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

画像データを0~1に正規化

# 0-255の値が入っているので、0-1に収まるよう正規化
x_train, x_test = x_train / 255.0, x_test / 255.0

正解データをone-hot表記

y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
to_categoricalの挙動
y = np.array([1,2,1,3]) # one-hot表記化前の正解データ
to_categorical(y, num_classes=5) # yをone-hot表記化

# 出力
array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)

正解番号の一覧が入ったリストをto_categoricalメソッドに食わせると,一覧がone-hot表記化されたリストが返される.
num_classesはいくつに分類するか(MNISTなら0~9の10通りで分類するのでnum_classes=10とする).