LSTMってなんだ??
最近の機械学習で大活躍しているって聞いたけど?
今回は,LSTMを知ってみたいと思っている方や,pythonでの実装を参考にしたい方を対象とした内容になります。本記事はpython実践講座シリーズの内容になります。その他の記事は,こちらの「Python入門講座/実践講座まとめ」をご覧ください。
読みたい場所へジャンプ!
はじめに
早く結論を言いなさい!LSTMって何やねん!
となっているかと思いますが,LSTMの説明を始める前に少しだけお話しさせてください。
今やLSTMは機械学習屋さんでは知らない人の方が少ないネットワークになっています。音声認識をはじめとしたあらゆる分野のブレイクスルーを起こした革新的な技術といっても過言ではないでしょう。
しかし,実はその歴史は古く,1997年[1]に最初の原型が提案されました。多くのLSTMの解説では,勾配消失/発散が問題となっていたことからLSTMを導入しています。歴史を辿っていくという意味で,最も妥当な説明だと思います。
しかし,本記事はもう少しキャッチーにLSTMを理解してみたいという方に向けた内容になっています。具体的に言えば,LSTMが
時系列データをうまく扱えるようなネットワークなんでしょ!(知らんけど)
この程度にご存知である層の方々向けに書いていきます。
私自身の経験を少しだけお話ししておきます。音声認識や音楽情報処理の分野では,LSTMはなくてはならない存在となっています。最近になってAttentionなどの台頭もありますが,いまだにLSTMの性能は劣りません。
このような背景から,LSTMを使ってみようとは思っていたものの,多くの解説記事やPytorchのドキュメントなどを読んでも,サッパリわからないという状況でした。
特に,まずは実装だと思ってPytorchのドキュメントを読むと,絶望すると思います。だって,こんな意味わからん数式が羅列されているのですから。

ナニコレ?!?!
何の説明もなしに,こんな数式見せられて,正しくPytorchのLSTMを使える人なんているのでしょうか。私も,最初の頃は食わず嫌いをしていて,何とかLSTMを使わずにモデルを組み立てることはできないかと試行錯誤していました。恥ずかしながら。
でもでも。思うに,このような理由からLSTMを毛嫌いして実装を避けて通っている人が多いというのは,学問の発展を妨げている気がしてなりません。「完璧に理解してから次に進む」というアプローチは,かえって遠回りになるケースが往往にしてあります。
ですので,本記事の目標は,先ほどの数式を理解して,PytorchのLSTMをある程度使いこなせるようになることを目指したいと思います。
前置きが長くなりましたが,まとめておきます。
【本記事の目標】
1.LSTMに出てくる基本的な数式を理解して
2.Pytorchで実装できるようになる
本記事の構成
本記事では,ボトムアップ形式でLSTMを理解していきます。最初に全体像を見せて詳細を説明していくトップダウン形式の説明とは異なり,徐々に詳細を詰めていくという方法をとりたいと思います。
ボトムアップを実感していただくために,LSTMの理解をレベル1〜レベル9に分けました。レベルが上がるにしたがって,どんどんLSTMの詳細に突っ込んでいきます。最初に目次を示しておきます。
●レベル1:入出力
●レベル2:ループ構造
●レベル3:展開
●レベル4:注目
●レベル5:正規化
●レベル6:状態
★コラム:WhereとHow
●レベル7:忘却
●レベル8:更新
●レベル9:最終出力
レベル1:入出力

何事もまずは問題定義から入りましょう。LSTMは何をするのかというと,「入力から出力を予測する」ことを行います。機械学習には色々な問題設定(分類・クラスタリング・回帰等)がありますが,まず意識してほしいことはLSTMはNN(ニューラルネットワーク)であるということです。
LSTMは,NNの中でも系列データを得意とする再帰型ニューラルネットワークに該当します。具体的な問題設定としては,NNというくくりの中で,系列データの回帰に使われることが多い印象です。
図中のAはネットワークを意味しています。まだブラックボックスにしています。添字の
【レベル1のまとめ】
レベル2:ループ構造

LSTM(RNN)は系列データが得意ということで,先ほどの入出力関係を複数個に拡張します。そのときに,RNNではネットワーク部にループ構造を持たせることで,系列データの情報を保持していきます。
ループ構造導入に伴って,入力はベクトルになり,出力もベクトルになっています。ただし,1つ1つのデータが複数の要素を持っているために,レベル1における入出力がベクトル,レベル2における入出力は行列になります。以下では,わざわざ太字に表すのが面倒なため,基本的に変数はベクトルだということにします。
【レベル2のまとめ】
レベル3:展開

ループ構造を横に展開します。LSTMを含んだ多くの時系列ネットワークは,このように横に展開されて説明されることが非常に多いため,このレベル3の図はよく頭に入れておくといいと思います。
【レベル3のまとめ】
レベル4:注目

さて,横に展開されたネットワークは,全てに同じ振る舞いを示します。ですので,1つに注目すれば十分なのです。そこで,レベル1と同じように
レベル1と異なるのは,前の出力がネットワークに組み込まれている点です。出力がなぜ2つあるのかについては,1つの出力(上に出ている方)はネットワークの吐き出され,もう1つの出力(横に出ている方)はネットワークの内部で使い回されるからです。
【レベル4のまとめ】
レベル5:正規化

さて,徐々にLSTMの内部に突っ込んでいきます。今までブラックボックス化していたLSTMの中身を少し覗いてみましょう。ここで出てくるのが「

任意の実数を取る値を,
え?でも何で正規化する必要があるの?
もっともな質問です。LSTMに限った話ではありませんが,NNでは内部のパラメータが撮りうる値を制限した方が学習が上手くいきます。結局,NNの学習原理は「誤差逆伝播」です。これは何かというと,NNが吐き出した値を正解の値の誤差を「微分」という操作を介して入力まで遡って誤差を正していく手法のことを指します。
イメージしてほしいのですが,出力から入力まで遡っていくときに内部のパラメータは好き勝手な値を取ってしまうとどのようなことが起こるでしょうか。誤差の情報が「発散」してしまったり「消失」してしまったりしますね。
これを防ぐためにも,
【レベル5のまとめ】
レベル6:状態

ここで,新しい概念が登場します。「状態」と呼ばれるものです。図中では
これから内部でごにょごにょ情報をいじっていきますが,その結果を内部状態
【レベル6のまとめ】
コラム:WhereとHow

ここで一旦コラムに入ります。説明するのは,LSTMにおける「Where」 と「How」です。Whereとは「どこの情報をいじるのか」を示し,Howとは「どのように情報をいじるのか」を示します。
どこの情報をいじるのか,というのはいわばindex情報とも言い換えることができます。そして,あるベクトルのindexを指定するという操作は,算数でいうと「掛け算」に相当します。ベクトル同士の演算で言えば「要素積」に相当します。
例えば,長さ
すると,ベクトル
ここで気づいてほしいのですが,掛け算に利用する要素は重要度の確率的な意味合いを持ちます。そのため,とりうる値は
Sigmoidに関してはこちらの記事「【初学者向け】SoftmaxとSigmoidの関係とは。」で詳しく解説しています。Sigmoidは
以上のお話から分かる通り,LSTMでは基本的に「Sigmoid」と「掛け算(要素積)」はセットで扱われます。なぜなら,重要度の確率値的な値を算出するためにSigmoidを利用し,実際に必要な値を取ってくるために要素積を利用するからです。
LSTMにおいては,もう1つセットで扱われる演算と関数があります。それが「How」に関する情報の伝達です。Howを実現する演算は「足し算」です。なぜなら,値を増やしたいなら大きい数を足せばよいですし,値を減らしたいなら小さい負の数を足せばよいからです。ベクトルで表すならば,要素ごとの足し算ということになります。
先ほど
ですので,どのような値を足し算するのかに関しては「
【コラムのまとめ】
●Where?(どこの情報をいじるのか)
Sigmoid関数「
●How?(どのように情報をいじるのか)
正規化関数「
レベル7:忘却

コラムの内容をぶんぶん振り回します。どこをいじるのかという「Where」の情報は,Sigmoid「

そして,Sigmoidで算出されたindex情報は,掛け算(要素積)によって内部状態
LSTMでは,この操作で「どこの情報を忘れるんか」を学習します。ですので,忘却ゲートなんて呼ばれることもあります。数式に関して言えば,まだ出力
【レベル7のまとめ】
レベル8:更新

忘れさせたからには,今度は情報を付け加えなくてはなりません。ここでも,忘却のときと全く同じ操作を行います。つまり,Sigmoidによってindexの重要度を算出し,要素積によって実際に計算を行います。
そして,今回はここで終了というわけではありません。内部状態に「どのように値を更新するのか」という情報を伝達しなくてはなりません。コラムを思い出してほしいのですが,Howに対しては
ここで先ほどのWhereに対する操作とコラボレーションを起こします。

この操作によって,内部状態に「どのように値を更新するか」という情報が付け加えられます。ちなみに「
【レベル8のまとめ】
レベル9:最終出力

最後は,今まで蓄えてきた内部情報

これまた例によって,要素積を取る前には
例えば,一番左のSigmoidは以下のような重み付け和とバイアス値の加算を表しています。
一方,ただの丸で表されている箇所は演算を表します。「+」マークや「

各区間におけるベクトルを上の図のように定義します。すると,以下のように入出力関係を表すことができます。
あ…!この式!
はい!最初にお見せしたPytorchの公式ドキュメントの式と同じですね。

これで,ようやくLSTMのブラックボックスを説明し終わりました。最後に,簡単にLSTMを振り返っておきます。再び,LSTMの内部をブラックボックス化すれば,以下のような図に表されることが分かります。

Pytorch実装
最後に,簡単なLSTMの実装方法をお伝えしておきます。主要部分だけ抜き出していきます。以下のようなネットワークを設計します。

LSTMネットワークのクラス定義
Pytorchのネットワークの記法にしたがってLSTMを用いたネットワークを記述していきます。LSTMのパラメータとして「input_dim, hidden_dim, num_layers, batch_first, bidirectional」があります。input_dimとoutput_dimは各データの次元数です。データの個数ではないため注意が必要です。
num_layersはLSTMを何層重ねるかといった引数です。batch_firstは入力テンソルの軸の順番を「(バッチ数,特徴量次元数,データ数)」とするための引数です。デフォルトでは「(データ数,バッチ数,特徴量次元数)」です。何故かこんな順番なのかは謎です。forwardにはLSTMに通した後にFC層に突っ込むということを書きます。
学習
学習for文の中では,いつも通りmodelをインスタンス化して呼び出せばOKです。
[1] Hochreiter, Sepp, and Jürgen Schmidhuber. “Long short-term memory.” Neural computation 9.8 (1997): 1735-1780.
[2] http://colah.github.io/posts/2015-08-Understanding-LSTMs/
めちゃくちゃ分かりやすかったです!
tgaさま
非常に嬉しいお言葉ありがとうございます…!!
業務で取り組む機会がありLSTMの記事をいろいろと読みあさりましたが、書いていただいた記事が最もわかりやすかったです!
ありがとうございます!
くり様
ありがたいお言葉身に沁みます。拙い箇所もありますので、どうぞご参考程度におさめていただけますと幸いです。