アカデミック

【Pythonお悩み解決】リストにおける軸の順番の意味について

この記事は, Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんでいる方々の参考になれば嬉しいです。Pythonつまずきポイント集の目次は以下のページをご覧ください。

【超初心者お悩み解決】Pythonつまずきポイント記事まとめページ この記事は,Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんで...

本記事で紹介する解決策がBestという保証はできません。正確な情報を発信するように心掛けていますが図らずも誤った情報を記載してしまう場合があります。もしご指摘等がありましたら,コメント欄またはお問い合わせページよりご連絡下さい。

環境

●Ubuntu 16.04
●Python 3.7.3
●conda 4.7.12
●pytorch 1.2.0

目的

リストの軸の順番に意味ってあるの?

軸の順番

Pytorchなどのライブラリを使うと,テンソルのサイズさえ合っていれば(もちろんそれ以外の箇所で問題がないことが全体になっていますが),学習が回ることが多いです。しかし,transpose(テンソルの場合permute)やflatten(テンソルの場合view)を軸の順番を意識せずに乱用しまくると,元のデータの形をぐちゃぐちゃにしてしまう可能性があります。

私がハマったのは,以下のような状況でした。入力特徴量として,(80, 512)のサイズのリストを用意していました。これをバッチサイズ200で(200, 80, 512)というテンソルにしてネットワークの学習を回していました。

推論時に,(200, 80, 512)を(80, 200*512)にしたいときに,何も考えずに以下のスクリプトを書いてしまったのです。

# dataは推論時に吐き出された(200, 80, 512)のテンソル

data_flat = data.transpose(1,2,0).reshape(80, -1)

dataを(80, 512, 200)にしてから(80, 200*512)にしようとするスクリプトです。しかし,このスクリプトには重大な問題があります。それは「512と200の意味を反映していないtranspose」になっているという点です。

どゆこと??

上の例では,(80, 512, 200)で80を保持しながらflattenしましたね。これはつまり,「200というカタマリを512個繋げたリストを作る」という操作に該当します。これって,間違っていますよね。

もともと,学習で使っていたリストのサイズは,バッチサイズを無視すれば(80, 512)でした。つまり,512側の軸が学習で利用されていたのです。「200というカタマリを512個繋げたリストを作る」のではなく,「512というカタマリを200個繋げたリストを作る」という操作をしなくてはなりません。

以上を踏まえて,「512というカタマリを200個繋げたリストを作る」スクリプトに直したものがこちらです。

# dataは推論時に吐き出された(200, 80, 512)のテンソル

data_flat = data.transpose(1,0,2).reshape(80, -1)

つまり,dataを(80, 200, 512)に直してから(200, 512)の部分を(200*512)にreshapeしています。(200, 512)というのは「512が200個ある」という意味ですので,推論で吐き出された「512」側の情報をしっかりと保持しながら整形を行うことができます。これにて一件落着です。

ABOUT ME
zuka
京都大学で機械学習を学んでいます。

COMMENT

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

※ Please enter your comments in Japanese to prevent spam.