Chainer:MNISTの手書き数字の読み込みと表示


2017年 03月 21日

MNISTの手書き数字の読み込みは、きわめて簡単である。
最初におまじないを並べたあと、次の1行だけで、トレーニングデータとテストデータが読み込まれ、2つのオブジェクトに入る。
train, test = chainer.datasets.get_mnist()

トレーニングデータがどのように入っているか、確認しよう。
>>> type(train)
<class 'chainer.datasets.tuple_dataset.tupledataset'>
>>> len(train)
60000
>>> train[0]
(array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
0.        ,  0.        ,  0.        ,  0.        ,  0.        ,

.........中略..........
0. , 0.11764707, 0.14117648, 0.36862746, 0.60392159, 0.66666669, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.88235301, 0.67450982, 0.99215692, 0.94901967, 0.76470596, 0.25098041, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.19215688, 0.9333334 , 0.99215692, 0.99215692, 0.99215692, 0.99215692, 0.99215692, .........中略.......... 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32), 5) >>> train[0][0].shape (784,)
これから、trainは要素数60000個のリストで、各リストは、画像データと、数値(0から9)のタプル。
画像データは、要素数784個の1次元配列。

今回は、最初の48個の画像だけを表示するので、最初の48個だけを取り出す。
xtrain = train._datasets[0][:48]
ytrain = train._datasets[1][:48]
あとは、subplotsを使って、画面を分割して表示するだけである。
fig,ax = plt.subplots(nrows=6,ncols=8,sharex=True,sharey=True)
ax = ax.flatten()
for i in range(48):
img = xtrain[i].reshape(28,28)
ax[i].imshow(img,cmap='Greys',interpolation='none')
ax.flatten()により、forループ中で、表示枠を指定するのに、単にax[i]で済ますことができる。
画像データは、サイズが784の1次元配列になっているので、28×28の2次元配列に直している。
画像データは賢く表示してくれると困るので、ありのまま表示するようにinterpolation='none'を指定している。
後は、ファイルにセーブし表示している。
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.savefig("mnistdisp48.png")
print(ytrain.reshape(6,8))
plt.show()
plt.subplots()で、分割された枠のx軸y軸の情報を共有するために、sharex、shareyをTrueにしている。そして、最後のところで、set_xticks([])、set_yticks([])により、目盛りなど目障りなものを表示しないようにしている。

ということで、最後にプログラム全体を示す。 “mnistdisp.py”: NMISTの読み込みと表示
#!/usr/bin/env python
# from http://nlp.dse.ibaraki.ac.jp/~shinnou/book/chainer.tgz

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable
from chainer import optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import matplotlib.pyplot as plt

# http://yann.lecun.com/exdb/mnist/
train, test = chainer.datasets.get_mnist()
xtrain = train._datasets[0][:48]
ytrain = train._datasets[1][:48]
#xtest = test._datasets[0]
#ytest = test._datasets[1]

fig,ax = plt.subplots(nrows=6,ncols=8,sharex=True,sharey=True)
ax = ax.flatten()
for i in range(48):
img = xtrain[i].reshape(28,28)
ax[i].imshow(img,cmap='Greys',interpolation='none')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.savefig("mnistdisp48.png")
print(ytrain.reshape(6,8))
plt.show()