アカデミック

【Pythonお悩み解決】LSTMがうまく学習してくれない

この記事は, Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんでいる方々の参考になれば嬉しいです。Pythonつまずきポイント集の目次は以下のページをご覧ください。

【超初心者お悩み解決】Pythonつまずきポイント記事まとめページ この記事は,Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんで...

本記事で紹介する解決策がBestという保証はできません。正確な情報を発信するように心掛けていますが図らずも誤った情報を記載してしまう場合があります。もしご指摘等がありましたら,コメント欄またはお問い合わせページよりご連絡下さい。

読みたい場所へジャンプ!

環境

●Ubuntu 16.04
●Python 3.7.3
●conda 4.7.12
●pytorch 1.2.0

現象

原因は分からないけど何故かLSTMがうまく学習してくれない。

原因

nn.LSTM」の戻り値の2つ目が(h_0, c_0)であるのに,(c_0, h_0)として受け取ってしまっていました。再帰的なLSTMを設計していたため,hだと思って扱っていたものが実はcで,cだと思って扱っていたものが実はhであったという状況でした。学習が進まないのも納得です。

解決方法

# Before
y, (c_n, h_n) = self.lstm_beat(x_beat, t, h_0, c_0)

# After
y, (h_n, c_n) = self.lstm_beat(x_beat, t, h_0, c_0)
ABOUT ME
zuka
京都大学で機械学習を学んでいます。

COMMENT

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

※ Please enter your comments in Japanese to prevent spam.