※現在、ブログ記事を移行中のため一部表示が崩れる場合がございます。
順次修正対応にあたっておりますので何卒ご了承いただけますよう、お願い致します。

Chainer:学習結果(脳)をsave/loadする前に


2017年 02月 21日

Deep Learningを色々いじってきて、なんとなく学習出来ているのは分かったのだが、学習には結構時間が掛かる。そのため、一度学習した結果をファイルに出力しておき、必要になったときにファイルを読み込むことで学習することなく賢くして、使いたいものである。

ということで、学習結果(脳)のファイルへのsave/loadについて検討してみよう。

そのために、いままで利用してきたプログラムを使おう。

いままでのプログラムは、学習部分も、学習結果の判定のためのテスト部分も1つのファイルになっていたので、まずは2つの部分に分けようと思う。

このとき、学習結果が入っているオブジェクト(脳)は同じでなければならず、プログラムではそれを表すクラスが使われていた。
手書きデータの場合には、次のクラスがそれだ。

class DigitsChain(Chain):
def __init__(self):
super(DigitsChain, self).__init__(
l1=L.Linear(64,32),		# 1-2層
l2=L.Linear(32,10),		# 2-3層
)

def __call__(self,x,y):
return F.mean_squared_error(self.fwd(x), y)

def fwd(self,x):
h1 = F.sigmoid(self.l1(x))
h2 = self.l2(h1)
return h2
今までのファイルは、「学習&セーブ」と「ロード&テスト」の2つのファイルに別れる。
このとき、上記クラスは共通だから、これを”digitschain.py”という名前のファイルにする。

Pythonのモジュールは異常に簡単だ。 上のように、何でも良いから、.pyがついたファイルにしてしまえば、それだけでモジュールになってしまう。 他の言語のように、モジュールにするにはあれこれ宣言しないとダメとか全然ない。 そして、そのモジュールをインポートすれば使えるようになる。 記述フォーマットはいくつか種類があるが、モジュール名(=ファイル名)と、インポートする物を記述するのが次である。
from digitschain import DigitsChain
ということで、実際にimportしてみた。
>>> from digitschain import DigitsChain
Traceback (most recent call last):
File "", line 1, in 
File "/home/fuji/Study/Python/Chainer/digitschain.py", line 11, in 
class DigitsChain(Chain):
NameError: name 'Chain' is not defined
>>>
クラスだけのファイルはダメだった。
Chainというのは親クラスなので、
from chainer import Chain
の1行だけを加えたら、一応エラーは出なくなった。
L.やF.が使われているのにエラーにならないということは、メソッドを定義だけでは中身は検査されず、エラーにならないようだ。
でも、一応、おまじないを最初に全部加えておいた。
 
“digitschain.py”
# Define model

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable
from chainer import optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L

class DigitsChain(Chain):
def __init__(self):
super(DigitsChain, self).__init__(
l1=L.Linear(64,32),		# 1-2層
l2=L.Linear(32,10),		# 2-3層
)

def __call__(self,x,y):
return F.mean_squared_error(self.fwd(x), y)

def fwd(self,x):
h1 = F.sigmoid(self.l1(x))
h2 = self.l2(h1)
return h2
さて、次回は、これをimport して、学習内容をファイルに書き出すことをやってみよう。