※現在、ブログ記事を移行中のため一部表示が崩れる場合がございます。
順次修正対応にあたっておりますので何卒ご了承いただけますよう、お願い致します。

Pythonのコマンド引数パーサー


2017年 04月 07日

MNISTの学習データは6万画像もあって、非常に多い。
もっと少ない枚数でも大丈夫かも。
ということで、学習データ画像数とテストの正解率の関係を知りたいと思ったのだが、コマンド引数にはデータ数を指定する項目がない。

それで、サンプルソース(train_mnist.py)プログラムを見たら、コマンド引数の処理をargparseで行っているようだ。
main()の最初の部分を以下に示す。

import argparse

def main():
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--batchsize', '-b', type=int, default=100,
help='Number of images in each mini-batch')
parser.add_argument('--epoch', '-e', type=int, default=20,
help='Number of sweeps over the dataset to train')
parser.add_argument('--frequency', '-f', type=int, default=-1,
help='Frequency of taking a snapshot')
parser.add_argument('--gpu', '-g', type=int, default=-1,
help='GPU ID (negative value indicates CPU)')
parser.add_argument('--out', '-o', default='result',
help='Directory to output the result')
parser.add_argument('--resume', '-r', default='',
help='Resume the training from snapshot')
parser.add_argument('--unit', '-u', type=int, default=1000,
help='Number of units')
args = parser.parse_args()

print('GPU: {}'.format(args.gpu))
print('# unit: {}'.format(args.unit))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')
ということで、Pythonのマニュアルを調べたら、説明があった。
16.4. argparse — コマンドラインオプション、引数、サブコマンドのパーサー
このドキュメントを実はほとんど参考にせず、元のプログラムにちょっと手を入れてみた。
    parser.add_argument('--number', '-n', type=int, default=60000,
help='Number of training data')
args = parser.parse_args()

print('GPU: {}'.format(args.gpu))
print('# number: {}'.format(args.number))
print('# unit: {}'.format(args.unit))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')
これで、引数 -n が読み取られるはずだ。 引数が省略されているときは、全データ数の60000をdefault値として設定しておいた。 さて、実際に学習のデータ量を、引数で指定した値、 args.number まで減らすには、データを読み込んだところで、先頭から指定した個数だけにしてしまうことにした。
    # Load the MNIST dataset
train, test = chainer.datasets.get_mnist()
train = train[:args.number]

これで実行すると、こんな感じになった。
Chainer$ python train_mnist0.py -n 1000
GPU: -1
# number: 1000
# unit: 1000
# Minibatch-size: 100
# epoch: 20

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           1.31764     0.606311              0.631          0.8063                    0.632579
2           0.444206    0.490377              0.867          0.8456                    1.48039
3           0.259161    0.44472               0.911          0.8633                    2.32615
4           0.149083    0.41302               0.958          0.8786                    3.21353
5           0.0815992   0.402668              0.983          0.8868                    4.07794
6           0.0438824   0.434466              0.994          0.8755                    4.95173
7           0.0247996   0.412777              0.999          0.8861                    5.80761
8           0.0134032   0.398092              1              0.8901                    6.66825
9           0.00692543  0.387881              1              0.8963                    7.55733
10          0.00444668  0.397806              1              0.8985                    8.4276
11          0.00350568  0.408884              1              0.8968                    9.27244
12          0.00275337  0.409523              1              0.8982                    10.1134
13          0.00231785  0.413871              1              0.8971                    10.9584
14          0.00198377  0.419695              1              0.8968                    11.8287
15          0.00173918  0.422949              1              0.8973                    12.6778
16          0.00155182  0.426604              1              0.897                     13.5328
17          0.00139003  0.42915               1              0.897                     14.395
18          0.0012566   0.431308              1              0.8975                    15.2749
19          0.00113922  0.434765              1              0.897                     16.1276
20          0.00104006  0.436891              1              0.8973                    17.0124
Chainer$
学習データ数が60000から1000に減ると、かなり速くなった。 ということで、とりあえず学習データの個数をコマンドラインから制御できるようになった。 データ数による学習成果への影響については、次回調べることにする。