MNISTデータについて
備忘録としてMNISTデータの読み込み方法を忘れないように記載しておきます。
そもそもMNISTデータセットとは機械学習のための、0~9までの手書き文字の訓練画像60,000枚、テスト画像10,000万枚で構成されています。詳しくは、英語しか見つからなかったけどWikipediaをご参照ください。→(MNIST database - Wikipedia)
データの仕様はこのサイトが詳しく書かれていました。→(MNIST データの仕様を理解しよう)
とりあえずデータをダウンロードすると4つのファイルが落ちてきます。
train-images-idx3-ubyte: 学習用の画像セット
train-labels-idx1-ubyte: 学習用のラベルセット
t10k-images-idx3-ubyte: 検証用の画像セット
t10k-labels-idx1-ubyte: 検証用のラベルセット
とりあえずデータを読み込んで、データの形を確認してみます。
なお、本ブログではゼロから作るDeep Learningのコードを参照しているので、datasetというフォルダに上記の4ファイルおよびmnist.pyが入っていると想定しています。
# coding: utf-8 import sys, os sys.path.append(os.pardir) # 親ディレクトリのファイルをインポートするための設定 from dataset.mnist import load_mnist # mnist.pyのload_mnist関数を呼び出す。 (x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False) # mnistデータ読み込むためのコード print(x_train.shape) # (60000, 784) print(t_train.shape) # (60000,) print(x_test.shape) # (10000, 784) print(t_test.shape) # (10000,)
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)で読み込み、データの形を確認すると、訓練データは60,000個の784次元(28 x 28のグレー画像)、テストデータは10,000個の784次元(28 x 28のグレー画像)であることがわかります。
さらに、load_mnist関数の中身は以下の感じです。
P73〜74に何をしているかは書かれています。
def load_mnist(normalize=True, flatten=True, one_hot_label=False): """MNISTデータセットの読み込み Parameters ---------- normalize : 画像のピクセル値を0.0~1.0に正規化する one_hot_label : one_hot_labelがTrueの場合、ラベルはone-hot配列として返す one-hot配列とは、たとえば[0,0,1,0,0,0,0,0,0,0]のような配列 flatten : 画像を一次元配列に平にするかどうか Returns ------- (訓練画像, 訓練ラベル), (テスト画像, テストラベル) """ dataset = {} for key in key_file.keys(): if not os.path.exists(get_save_file_path(key)): init_mnist(key)
Saveしたfileがある場合は init_mnist()をスキップする
with open(get_save_file_path(key), 'rb') as f: dataset[key] = np.array(pickle.load(f))
セーブしたデータを読み込んで、1次元配列で読み込み
if normalize: for key in ('train_img', 'test_img'): dataset[key] = dataset[key].astype(np.float32) dataset[key] /= 255.0
読み込んだデータを正規化する。(ピクセル数255で割る)
if one_hot_label: dataset['train_label'] = _change_ont_hot_label(dataset['train_label']) dataset['test_label'] = _change_ont_hot_label(dataset['test_label'])
On hot labelに変換する
if not flatten: for key in ('train_img', 'test_img'): dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
もし1次元配列で読み込まない場合は1 x 28 x 28の3次元配列で読み込む
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':