28×28の手書き数字の場合のAutoEncoder


2017年 03月 25日

MNISTの28×28の手書き数字は、現実の処理でも使うくらいのドット数がある。
これでAutoEndoderを作って、どのくらい再現性があるか調べてみよう。

まず、クラスMyAeのノード数の部分を書き換えよう。
中間層は、40ノードとする。つまり、748 –> 40 –> 748 と変換していこう。

class MyAE(Chain):
def __init__(self):
super(MyAE, self).__init__(
l1=L.Linear(784,40),
l2=L.Linear(40,784),
)
変更はたったこれだけで、28×28の画像に対応できるようになる。
後は、書くだけだ。

最初に、データを読み込む。

# http://yann.lecun.com/exdb/mnist/
train, test = chainer.datasets.get_mnist()
xtrain = train._datasets[0]
ytrain = train._datasets[1]
xtest = test._datasets[0]
ytest = test._datasets[1]
今回は、学習の途中で、一定エポック毎にAutoEncoderの結果の最初の48枚の画像を1つの画像ファイルまとめて出力した。

# Learn
losslist = []
for j in range(1000000):
x = Variable(xtrain[:10000])
model.cleargrads()             # model.zerograds() 非推奨
loss = model(x)
if j%10000 == 9999:
print( "%6d   %10.6f" % (j+1, loss.data) )
xx = Variable(xtrain[:48], volatile='on')
yy = model.fwd(xx)
plotresults( yy, "mnistaeout/mnistae%d.png" % (j+1) )

losslist.append(loss.data)     # 誤差をリストに追加
loss.backward()
optimizer.update()
10000エポック毎にスナップショット画像を吐き出して、全体で100万エポックまでやったのだが、このくらいやるとしっかり時間がかかり、走らせて結果は翌日確認することになった。(まだGPUは使っていない)

といことで、今回はプログラムの紹介だけで、結果は次回に示す。
今回のプログラムは:”minstae.py”

#!/usr/bin/env python
# from http://nlp.dse.ibaraki.ac.jp/~shinnou/book/chainer.tgz

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable
from chainer import optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import matplotlib.pyplot as plt

# http://yann.lecun.com/exdb/mnist/
train, test = chainer.datasets.get_mnist()
xtrain = train._datasets[0]
ytrain = train._datasets[1]
xtest = test._datasets[0]
ytest = test._datasets[1]

class MyAE(Chain):
def __init__(self):
super(MyAE, self).__init__(
l1=L.Linear(784,40),
l2=L.Linear(40,784),
)

def __call__(self,x):
bv = self.fwd(x)
return F.mean_squared_error(bv, x)

def fwd(self,x):
fv = F.sigmoid(self.l1(x))
bv = self.l2(fv)
return bv

def plotresults(yy,filename):
fig,ax = plt.subplots(nrows=6,ncols=8,sharex=True,sharey=True)
ax = ax.flatten()
for i in range(48):
img = yy[i].data.reshape(28,28)
ax[i].imshow(img,cmap='Greys',interpolation='none')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.savefig(filename)
plt.close()

# Initialize model
model = MyAE()
optimizer = optimizers.SGD()
optimizer.setup(model)

# Learn
losslist = []
for j in range(1000000):
x = Variable(xtrain[:10000])
model.cleargrads()             # model.zerograds() 非推奨
loss = model(x)
if j%10000 == 9999:
print( "%6d   %10.6f" % (j+1, loss.data) )
xx = Variable(xtrain[:48], volatile='on')
yy = model.fwd(xx)
plotresults( yy, "mnistaeout/mnistae%d.png" % (j+1) )

losslist.append(loss.data)     # 誤差をリストに追加
loss.backward()
optimizer.update()