numpyの内積(2) vdotのソース


2023年 09月 21日

ソースコード

関数がどういう動きをするかを調べるには、ソースを見るのが一番である。

/usr/lib/python3/dist-packages/numpy/core/multiarray.pyを見たら、説明が細々と書かれていたのであるが、結局C言語で実装されていたので、標準のソースを見るのはあきらめ、別のnumpyのソースを見ることにした。

それで見つけたのがjaxのnumpyのソース

## from https://jax.readthedocs.io/en/latest/_modules/jax/_src/numpy/lax_numpy.html#vdot
@util._wraps(np.vdot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def vdot(a, b, *, precision=None):
  util.check_arraylike("vdot", a, b)dot
  if issubdtype(_dtype(a), complexfloating):
    a = ufuncs.conj(a)
  return dot(ravel(a), ravel(b), precision=precision)

あるいは、このvdotのドキュメント を見ると良いだろう。

これによると、第1引数aは、共役複素数に置き換えられている。そして、実際の内積の計算はdot関数を呼び出して行われている。

数学での複素ベクトルの内積と違い、第1引数の複素共役と第2引数とのdot関数による内積を計算して返している。

なぜ、数学とnumpyとでは計算が異なるのか。

複素ベクトルを利用する分野の代表は量子力学だろう。シュレディンガー方程式には$i$が含まれているし、物理量の計算をしようとすると複素ベクトルの内積を計算することになる。どちらの引数の共役複素数を利用するのかで計算結果は異なるので重要である。

理工学全般で、複素ベクトルの内積をとるとき、第1引数の共役複素数を利用するようである。numpyも、利用のことを考えてこのように決めたのではないかと思う。

ravelとflatten

vdotのソースで普段あまり見かけないravel()が使われている。これは、多次元配列を1次元配列に変換する。

>>> a = np.arange(1,13).reshape(2,2,3);a
array([[[ 1,  2,  3],
    	[ 4,  5,  6]],

   	[[ 7,  8,  9],
    	[10, 11, 12]]])
>>> a.ravel()
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
>>> a.flatten()
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

2x2x3の3次元配列を作って、ravel()とflatten()を作用させたら、どちらも1次元の配列になった。
ravel()とflatten()は同じと考えてよいのだろうか。

多次元配列を1次元配列に平坦化する点は同じだが、ravel()はできるだけビューを返し、flatten()はコピーを返す。
reshape()でも引数に-1を与えることで、1次元配列にすることができるが、ravel()と同じでビューを返す。

>>> a = np.arange(1,13).reshape(2,2,3);a
array([[[ 1,  2,  3],
    	[ 4,  5,  6]],

   	[[ 7,  8,  9],
    	[10, 11, 12]]])
>>> a.ravel()
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
>>> a.flatten()
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

引数中の /*

あまり見かけないと思うが、関数・メソッドの引数の中に、*が単独で書かれていることがある。

まず、引数が4つでテストしよう。

>>> def param_test( a, b, c, d ):
... 	print(a,b,c,d)
...
>>> param_test( 1, 2, 4, 8 )
1 2 4 8
>>> param_test( b=2, c=4, d=8, a=1 )
1 2 4 8
>>> param_test( 1, 2, d=8, c=4 )
1 2 4 8

引数の順番に割り当てられる引数を位置引数、キーワードへの代入の形で引数を書く引数をキーワード引数と呼ぶ。

上の例では、引数a,b,c,dは、呼び出し時の引数の書き方次第で、位置引数とみなされたり、キーワード引数とみなされる。位置引数は必ずキーワード引数より前と決まっている。

引数の間に、/ と * を挿入することで、強制的に位置引数にしたり、キーワード引数にすることができる。

  • / より前は、必ず位置引数になる。
  • * より後ろは、必ずキーワード引数になる。
  • / と * の間は、位置引数にもキーワード引数にもなれる。

では、例で試してみよう。

>>> def param_test( a, b, /, c, d, *, e, f=0 ):
... 	print( a, b, c, d, e, f )
...
>>> param_test( 1, 2, 3, 4, 5, 6 )
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: param_test() takes 4 positional arguments but 6 were given
>>> param_test( 1, 2, 3, 4, e=5 )
1 2 3 4 5 0
>>> param_test( 1, 2, 3, d=4, f=6, e=5 )
1 2 3 4 5 6
>>> param_test( 1, 2, 3, e=5, d=4 )
1 2 3 4 5 0
>>> param_test( 1, 2, d=4, e=5, f=6, c=3 )
1 2 3 4 5 6
>>> param_test( 1, b=2, d=4, e=5, c=3 )
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: param_test() got some positional-only arguments passed as keyword arguments: 'b'

説明は省略するが、以上の例で使い方は分かると思う。

関数を提供する側が、/ や * を仮引数の間に挿入することで、位置引数かキーワード引数かを指定できるのである。

その他にも、仮引数で色々指定できる。

  • 引数の直後に、:型名 を書くとアノテーション(注釈)になる。
  • 位置引数の直前に * を書くと、位置引数に対応する実引数のタプルが仮引数に渡される。
  • キーワード引数の直前に ** を書くと、キーワードとキーと値の組でできた辞書が仮引数に渡される。

これらの例は省略する。

ベクトル、配列、リスト

今回は、ベクトルについて書こうと思っていたのだが、長くなってしまったので次回に回す。