MNISTデータについて

備忘録としてMNISTデータの読み込み方法を忘れないように記載しておきます。

そもそもMNISTデータセットとは機械学習のための、0~9までの手書き文字の訓練画像60,000枚、テスト画像10,000万枚で構成されています。詳しくは、英語しか見つからなかったけどWikipediaをご参照ください。→(MNIST database - Wikipedia)

データの仕様はこのサイトが詳しく書かれていました。→(MNIST データの仕様を理解しよう

とりあえずデータをダウンロードすると4つのファイルが落ちてきます。

train-images-idx3-ubyte: 学習用の画像セット
train-labels-idx1-ubyte: 学習用のラベルセット
t10k-images-idx3-ubyte: 検証用の画像セット
t10k-labels-idx1-ubyte: 検証用のラベルセット

とりあえずデータを読み込んで、データの形を確認してみます。
なお、本ブログではゼロから作るDeep Learningのコードを参照しているので、datasetというフォルダに上記の4ファイルおよびmnist.pyが入っていると想定しています。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
from dataset.mnist import load_mnist # mnist.pyのload_mnist関数を呼び出す。

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False) # mnistデータ読み込むためのコード

print(x_train.shape) # (60000, 784)
print(t_train.shape) # (60000,)
print(x_test.shape) # (10000, 784)
print(t_test.shape) # (10000,)

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)で読み込み、データの形を確認すると、訓練データは60,000個の784次元(28 x 28のグレー画像)、テストデータは10,000個の784次元(28 x 28のグレー画像)であることがわかります。

さらに、load_mnist関数の中身は以下の感じです。
P73〜74に何をしているかは書かれています。

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """MNISTデータセットの読み込み
    
    Parameters
    ----------
    normalize : 画像のピクセル値を0.0~1.0に正規化する
    one_hot_label : 
        one_hot_labelがTrueの場合、ラベルはone-hot配列として返す
        one-hot配列とは、たとえば[0,0,1,0,0,0,0,0,0,0]のような配列
    flatten : 画像を一次元配列に平にするかどうか 
    
    Returns
    -------
    (訓練画像, 訓練ラベル), (テスト画像, テストラベル)
    """
    dataset = {}
    for key in key_file.keys():
        if not os.path.exists(get_save_file_path(key)):
            init_mnist(key)

Saveしたfileがある場合は init_mnist()をスキップする

        with open(get_save_file_path(key), 'rb') as f:
            dataset[key] = np.array(pickle.load(f))

セーブしたデータを読み込んで、1次元配列で読み込み

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

読み込んだデータを正規化する。(ピクセル数255で割る)

 if one_hot_label:
        dataset['train_label'] = _change_ont_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_ont_hot_label(dataset['test_label'])    

On hot labelに変換する

    if not flatten:
         for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

もし1次元配列で読み込まない場合は1 x 28 x 28の3次元配列で読み込む

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) 
if __name__ == '__main__':