本記事の内容は新ブログに移行されました。

新しい記事へ

こちらのブログにコメントをいただいても
ご返信が遅れてしまう場合がございます。
予めご了承ください。

ご質問やフィードバックは
上記サイトへお願い致します。

今流行りの深層生成モデルを実装したい!

チュートリアルも大体終わったし次は実装!

今回は,深層生成モデルの一種であるVAE(Variational Autoencoder)をPythonで実装する方法をお伝えしていこうと思います。

本記事はpython実践講座シリーズの内容になります。その他の記事は,こちらの「Python入門講座/実践講座まとめ」をご覧ください。また,本記事の実装は以下のサイトを参考にさせていただきました。ありがとうございます。

コーディングに関して未熟な部分がたくさんあると思いますので,もし何かお気づきの方は教えていただけると幸いです。また,誤りについてもご指摘していただけると非常に助かります。

読みたい場所へジャンプ!

VAEの概要

VAE(変分オートエンコーダ)とは,簡単にまとめると以下のような手法です。

●潜在変数モデルにおけるモデルエビデンスの推論方法の1つ
●入力を潜在空間上の特徴量で表す(エンコーダ)
●潜在空間から元の次元に戻す(デコーダ)
●潜在空間には何かしらの分布を仮定

潜在空間としては,分布のパラメータを設定します。例えば,潜在空間にガウス分布を仮定した場合,エンコーダでは潜在空間の「平均」と「分散」を学習するようにします。(出力ユニット数が2つのニューラルネットワークになるということです。)

ネットワークの構造

VAEのネットワークは「エンコーダ部」と「デコーダ部」に分かれます。よく勘違いされるのですが,VAEは「確率分布のパラメータ」を出力しているのであって,値そのものを出力しているわけではありません。

VAEの構造

一方で,デコーダ部に関しては再構成データをそのまま出力するモデルも存在します。一般にVAEと呼ぶときは,以下のネットワークを指していることが多いようです。MNISTなどで学習を行うときは,デコーダの出力はシグモイドにかけて[0,1]の値域にします。

VAEの構造バージョン2

また,実用上以下のような分布の仮定を置くことが多いです。潜在空間には平均が$\boldsymbol{0}$で共分散行列が単位行列の標準多次元ガウス分布を仮定します。また,エンコーダ部にもガウス分布を仮定することで目的関数を解析的に求めることができます。(以下で説明します)

\begin{eqnarray}
q_{\phi}(\boldsymbol{z}|\boldsymbol{x}) &\sim& \mathcal{N}(\boldsymbol{z};\boldsymbol{\mu}_{\phi},\boldsymbol{\sigma^2}_{\phi})\\
p_{\theta}(\boldsymbol{z}) &\sim& \mathcal{N}(\boldsymbol{z};\boldsymbol{0},\boldsymbol{I})
\end{eqnarray}

大切なのは,デコーダ$p_{\theta}$に関する分布は自分たちで定める必要があるという点です。例えば,MNISTのような画像を対象とする場合は,ベルヌーイ分布を仮定してシグモイドを通すのが適しています。他には,音声のスペクトログラムなどを扱う場合は,デコーダ$p_{\theta}$にもガウス分布を仮定してしまいます。

VAEの学習(目的関数)

VAEの学習は,生成器$p_\theta(x)$の対数周辺尤度最大化です。しかし,$p_\theta(x)$は一般には計算できないため,潜在変数を噛ませて変分下界を最大化するという方向性で考えていきます。イエンセンの不等式より,変分下界の式を作ります。変分下界を$L(x; \varphi, \theta)$とおくことにします。

\begin{eqnarray}
\log p_\theta(x) &=& \log \int p_\theta(x, z) dz \\
&=& \log \int q_\varphi(z|x)\frac{p_\theta(x, z)}{q_\varphi(z|x)} dz \\
&\geq& \int q_\varphi(z|x) \log \frac{p_\theta(x, z)}{q_\varphi(z|x)} dz \\
&=& L(x; \varphi, \theta)
\end{eqnarray}

実は,変分下界(右辺)と対数周辺尤度(左辺)の差は識別器$q_\varphi (z|x)$と生成器$p_\theta (x|z)$のKLダイバージェンスになります。実際に計算してみると,以下のようになります。二つの項をくくり出すために積分して1になる$\int q_\varphi (z|x) dz$を持ち出す点がかなりトリッキーです。

\begin{eqnarray}
\log p_\theta(x) – L(x; \varphi, \theta) &=&  \log p_\theta(x) – \int q_\varphi(z|x) \log \frac{p_\theta(x, z)}{q_\varphi(z|x)} dz \\
&=& \log p_{\theta}(x) \int q_{\varphi} (z|x) dz – \int q_{\varphi} (z|x) \log \frac{p_{\theta} (z|x)p(x)}{q_{\varphi}(z|x)} dz \\
&=& \int q_\varphi (z|x) \{ \log p_{\theta}(x) – \log p_\theta(z|x) – \log p_{\theta}(x) + \log q_\varphi (z|x) \} dz\\
&=& \int q_\varphi (z|x) \{ \log q_\varphi (z|x) – \log p_\theta(z|x) \} dz\\
&=& KL[q_\varphi (z|x) \| p_\theta (z|x)]
\end{eqnarray}

KLダイバージェンスは距離関数の一種で,必ず非負の値を取り,KLダイバージェンスが非負の値を取るため,結局対数周辺尤度を最大化することは,変分下界の最大化と等価になります。さて,変分下界を改めて計算し直してみましょう。(ベイズの定理を利用して$\log$で分解しています。)

\begin{eqnarray}
L(x; \varphi, \theta) &=& \log p_\theta(x) – KL[q_\varphi (z|x) \| p_\theta (z|x)] \\
\nonumber\\
&=& \log p_\theta(x) – E_{q_\varphi (z|x)}[\log q_\varphi(z|x) – \log p_\theta (z|x) ] \\
\nonumber\\
&=& \log p_\theta(x) – E_{q_\varphi (z|x)}[\log q_\varphi(z|x) – \log p_\theta (x|z) – \log p_\theta(z) + \log p_\theta (x)] \\
\nonumber\\
&=& E_{q_\varphi (z|x)}[\log p_\theta (x|z)] – KL[q_\varphi (z|x) \| p_\theta (z)]
\end{eqnarray}

変分下界を2つの項で表せましたね。期待値の方は,$p_\theta$(生成器)の分布を仮定してしまって計算すればOKです。例えば,二値画像の分類では出力は[0,1]であるので,生成器にベルヌーイ分布を仮定することが多いです。この場合,第1項目は以下のようになります。ただし,$f$は活性化関数を,$L$は潜在変数の次元を表しています。

\begin{eqnarray}
E_{q_\varphi (z|x)}[\log p_\theta (x|z)]
&=& E_{q_\varphi (z|x)}[\log \prod_l^{L} f(z_l)^x (1 – f(z_l))^{(1 – x)}] \\
&=& \frac{1}{L} \sum_{l=1}^L \{ x \log f(z_l) + (1 – x) \log (1 – f(z_l)) \}
\end{eqnarray}

上の式では,デコーダの出力の各次元$f(z_i)$に対して,ベルヌーイ分布を仮定しています。さらに,式(16)から(17)でモンテカルロ近似を利用しています。これは,次に,簡単に言えば期待値の積分計算を離散で有限サンプルの平均で表してしまおうという近似になります。

$z$に関する期待値なので$z$の次元数でならすというイメージです。

次に,KLダイバージェンスの方を考えてみましょう。こちら,真面目に計算すると非常に面倒臭いです。($p_\theta (z)$と$q_\varphi (z|x)$に正規分布を仮定したとしても計算が煩雑になります)以下のPRML記事解説で計算過程はお伝えしていますので,参考にしていただければと思います。

【第2章確率分布】PRML演習問題解答を全力で分かりやすく解説<2.13>本記事はPRML「パターン認識と機械学習<上>第7版」(C.M.ビショップ著)の演習問題の基本問題・標準問題を解説したページになります。...

以下では,文献[1]のAppendix Bにならって$p_\theta (z)$に$\mathcal{N}(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I})$,$q_\varphi (z|x)$に$\mathcal{N}(\boldsymbol{z};\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$を仮定します。結果は,このようになります。

\begin{eqnarray}
-KL[q_\varphi (z|x) \| p_\theta (z)]
&=& \frac{1}{2} \sum_{l=1}^L (1 + \log \sigma^2 – \mu^2 – \sigma^2)
\end{eqnarray}

以上をまとめると,変分下界は

\begin{eqnarray}
L(x; \varphi, \theta)
= &\frac{1}{L}& \sum_{l=1}^L \{ x \log f(z_l) + (1 – x) \log (1 – f(z_l)) \} \nonumber\\
&&+ \frac{1}{2} \sum_{l=1}^L (1 + \log \sigma^2 – \mu^2 – \sigma^2)
\end{eqnarray}

と書くことができます。こちらの変分下界を最大化することが,対数周辺尤度の最大化と等価になるのでした。実際の実装でも,こちらの式を計算しています。

問題点

これでネットワークの学習ができるね!

と考えたあなた!まだまだ先があるんです…。よくよく考えてみると,エンコーダの出力は,どのようにして潜在変数のパラメータとして利用されるのでしょうか。

エンコーダは,潜在空間の分布のパラメータを出力しているわけですよね。そしたら,そのパラメータを利用して潜在空間を定義すればよいのでしょうか。もし,潜在空間を定義できたとしても,そこから$z$を得る作業(サンプリング)が必要になってしまいます。

ああ!ギブスサンプリングとか使えばええんちゃう?

Nooooです。サンプリングを行ってしまうと,誤差を逆伝播することが不可能になってしまうからです。ですから,$z$の値は決定的に定めなくてはなりません。そこで,編み出された妙案がこちらの式です。

\begin{eqnarray}
z = \mu + \epsilon \sigma
\end{eqnarray}

ただし,$\epsilon \sim \mathcal{N} (0, I)$とします。つまり,分布を仮定してサンプリングするのではなく,zというのは平均値にノイズ項を加えたものですよと近似してしまうというアイディアです。こちらの式は決定的に定まりますから,誤差の逆伝播を遮らずに済みます。ここが,VAEのアルゴリズムの中でも大切な部分です。

誤差逆伝播は命!

よくある質問

ここでは,本記事でいただいたご質問とその回答をまとめていきます。

質問1:モデルのどこに単位行列を仮定しているのか

デコーダの事前分布です。数式を用いれば,デコーダが条件付き確率で表される点がポイントです。つまり,デコーダは$p_{\theta}(x|z)$であって,$p_{\theta}(z)$はデコーダの事前分布です。また,エンコーダ・デコーダ型のモデルでは,デコーダの事前分布を「潜在空間」と呼ぶことが多いです。つまり,単位行列を仮定しているのはデコーダの事前分布(潜在空間)ということになります。本文中の式(2)です。

質問2:式(1)の示す分布の正体

式(1)は$q_{\phi}$に関する分布を定めていますので,エンコーダに関する分布です。これは,条件付き確率を意識すると分かりやすいと思います。つまり,$q_{\phi}(z|x)$は「$x$から$z$を生成する」と捉えればOKです。対して,条件付けられていない確率密度関数は事前分布を表すことが多いです。ここら辺は,ベイズ推論の考え方に通じるところがあります。以下の記事をぜひご参照ください。

【初学者向き】ベイズ推論とは?事前分布や事後分布を分かりやすく解説してみます!
【初学者向き】ベイズ推論の学習と予測とは?1次元ガウス分布を例に解説してみます!
【これなら分かる!】変分ベイズ詳解&Python実装。最尤推定/MAP推定との比較まで。

質問3:$p_{\theta}$の正体

$p_{\theta}$はデコーダに関する分布を表しています。デコーダの事前分布として$z$を考えるときに,$z$に関する分布の情報を加味する必要が出てきます。ここは少し頭が混乱するところですよね。$p_{\theta}$が吐き出す$x$は再構成された$\hat{x}$(のパラメータ)であることに注意が必要です。また,$p_{\theta}(x)$の分布は分かりません。なぜなら,それが分かればこんなに苦労してモデルを組み立てる必要がないからです。逆に言えば,最も汎用的にフィットするような$p_{\theta}(x)$を求めるのが私たちの目的です。

質問4:「デコーダ$p_{\theta}$に関する分布は自分たちで定める必要がある」のに「$p_{\theta}(x)$は一般には計算できない」ことの説明

日本語が下手くそでした…。確かに分かりにくいですね。「デコーダ$p_{\theta}$に関する分布」をより詳しく言えば「$p_{\theta}(x|z)$」と「$p_{\theta}(z)$」です。これらは,$p_{\theta}(x)$に潜在変数を噛ませてベイズの定理を利用することで,いわば無理やり「一般的には計算できない$p_{\theta}(x)$に関する裏の情報」を仮定しているわけです。このような背景から,「デコーダ$p_{\theta}$に関する分布は自分たちで定める必要がある」と記述しました。

質問5:潜在変数の依存関係

実は,潜在空間はエンコーダ・デコーダの両方に依存します。なぜなら,エンコーダでは$q_{\phi}(z|x)$,デコーダでは$p_{\theta}(z)$で$z$が分布の式に登場しているからです。若干語弊をうむことにはなりますが,基本的に分布は登場する変数に依存します(どれを主役に取るかの話)。

ただし,「条件付き確率」としての依存を指している場合は,デコーダに依存します。条件付き確率の右側に$z$が出てくるのはデコーダだけだからです。お答えとしては,条件付き確率として考える場合はデコーダのみに依存,VAEの原理として捉える場合は両方に依存します。

確かに$z$を生成するのはエンコーダの役割ですが,VAEの基本原理に立ち返ってみると「再構成誤差+潜在空間の分布の良さ」を誤差関数として学習を回していきますので,$z$を変数として含むエンコーダ・デコーダに$z$は依存します。イメージとしては,エンコーダが入力を元にして潜在空間にデータをマッピングします。そして,マッピングされたデータを元にしてデコーダが入力と同じようなデータを再現します。このとき,「どれだけ入力と同じような出力がなされているか」と「どれだけエンコーダがマッピングした分布が標準正規分布(多次元)に近いか」によってVAEは学習していきます。これらは,両方とも$z$に依存しているため,VAEの原理として捉える場合は$z$はエンコーダ・デコーダの両方に依存するのです。

質問6:関数とパラメータが同じであれば同じ分布なのか

分布は必ずしも同じとは限りません。$p$と$q$は分布としては同じ種類ですが,各パラメータの値は異なります。つまり,母集団が従う分布の形は同じと仮定していますが,取ってきた結果はそれぞれ異なっているということです。

質問7:エンコーダは高次元の空間から低次元の空間への射影なのか

概ねその通りです!というのも,エンコーダ・デコーダ型のモデルの意義は,低次元の潜在空間への変換だからです。なぜ低次元に変換するのかというと,それだけ情報が凝縮され洗練されるからです。この操作をニューラルネットワークさんが勝手にしてくれるというのはかなり大きいです。従来は主成分分析を利用して行っていましたからね…。オートエンコーダが主成分分析の非線型変換であると捉えられているのもそのためです。逆に言えば,エンコーダ・デコーダ型のモデルで潜在空間を元の次元よりも大きくするようなモデルはあまり見かけません(エンコーダ・デコーダ型の良さを使えないモデルになってしまいます)。

質問8:エンコーダの操作は『低次元の空間へ埋め込み』とも呼ばれるのでしょうか。

正しい理解だと思います!ただ,埋め込みという用語はあまり使われない印象です。なぜなら,埋め込み(Embed)は専ら自然言語処理の分散表現に用いられる用語だからです。ここでは,変換という言葉をよく用いる印象を受けます。

質問9:潜在変数は等式で生成されるのでしょうか

はい。そういうことになります。ただし,本来であればエンコーダが出力したパラメータを元に$z$が自然に生成されるべきなのですが,本文中にもある通り「自然に生成される」というランダム操作を組み入れてしまうとニューラルネットワークの誤差伝播が途切れてしまいますので,等式「$=$」によって$z$を生成(決定)しています。これを「Reparameterization Trick」と呼び,VAEの発案者Kingma先生が考案された妙技です。

質問10:ギブスサンプリングはやっぱり必要なんじゃないか

ギブスサンプリングが要るか要らないか(使えるか使えないか)というのは,「モデルが誤差逆伝播によって学習しているかどうか」によって決められます。おっしゃる通り,解析的な数式によってモデル化できる場合はギブスサンプリングする必要はありません。ただし,ニューラルネットワークの枠組みではまた話は違ってきます。例えば,ベイズ推論で事前分布に共役事前分布を設定することが多いのは,解析的にパラメータの更新式を求めるためです。解析的に更新式が求まれば,サンプリングの必要はなくなります。しかし,今回のモデルはニューラルネットワークです。ベイズ推論とは学習の方法が異なりますので(深層ベイズモデリングは別として),ニューラルネットワークを誤差逆伝播で学習させる場合にはサンプリングという確率的な操作があってはならないのです。ですので,VAEではなくエンコーダ・デコーダ型のモデルを誤差逆伝播以外の方法で学習させる場合で,仮定した分布を解析的に解くことができないときには,おっしゃる通りサンプリングを利用するほかないと思います。しかし,このような状況はあまり起こりません。なぜなら,「サンプリングを行うしかない」という状況を作らないように研究者たちは努めているからです。

質問11:結局$p_{\theta}(\hat{x})$を知りたいってことじゃないの?

VAEの数式には$\hat{x}$は出現しません。なぜなら,VAEの基本的なモデリングはデコーダの出力が再構成データの分布の「パラメータ」であり,再構成データそのものではないからです。例えば,デコーダ$p_{\theta}$が入力データと全く異なるデータを出力したとしても,$p_{\theta}$を目的関数にしてしまえばデタラメなデコーダの出力を「尤もらしい」と判断するようなモデルが完成してしまい,オートエンコーダとして成り立たなくなってしまいます。VAEの正しい目的関数は$p_(x)$です。つまり,「デコーダ分布はどれだけ入力データを確からしいと判断できるか」がVAEの目的関数ということなのです。

質問12:Decoderの出力分布がEncoderの入力分布を”忠実に”再現するように学習するってこと?

こちらは,VAEの目的関数が$p_{\theta}(x)$であることに注意すれば分かりやすいと思います。イメージではほぼ同じように思えますが,VAEの目的関数は入力するデータの分布と出力されるデータの分布を近づけるような目的関数ではないです。つまり,「Decoderの出力分布がEncoderの入力分布を”忠実に”再現する」ように学習しているのではなく,「Decoderの分布がEncoderの入力をどれだけ確からしいと判断するか(確からしいと判断できるようにデコーダを形成するのがオートエンコーダ流派の基本思想です)」+「潜在空間がどれだけ仮定した分布に近づいているか」の二つの項からVAEは学習されます。

質問13:$p_{\theta}(\hat{x})=p_{\theta}(\hat{x}|z) p_{\theta}(z)$で目的関数が求められるのではないか

こちらも,上述の通り$p_{\theta}(\hat{x})$の意味するところが不明になってしまいます。

質問14:$p_{\theta}(x|z), p_{\theta}(z), p_{\theta}(x)$が同じ表記なのは混乱を招くだけでは?

興味深い視点,ありがとうございます。この三者は同じ表記ではなくてはなりません。なぜなら,同じニューロンを通過しているからです。同じ重みパラメータ$\theta$を使って行列演算されているからです。ここは,ニューラルネットの特徴的な部分なのですが,対象を$x$にするのか$z$にするのか,はたまた$z$に条件づけられた$x$にするのかで,表す(裏に仮定される)分布が異なるように「できる」という点です。ニューラルネット,おそるべしです。ここを異なる表記にしてしまうと,対象が「$x$」なのか「$z$」なのか「$z$に条件づけられた$x$」なのかで,入力するニューロンが異なるように学習させなくてはならなくなってしまいます。

質問15:$p_{\theta}(z)$は$q_{\phi}(z|x)$と同じなのではないか

近い分布になると思います。なぜなら,この2つの分布を近づけることがVAEの目標の1つだからです。目的関数の片方の項は$KL[q_\varphi (z|x) \| p_\theta (z)]$ですね。これは2つの分布を近づけるようにVAEを学習しましょうという宣言に他なりません。

質問16:$p_{\theta}(.)$は全て同じ構造を表しているのか

確率変数によって表す分布は異なります。ネットワークのパラメータ$\theta$,$\phi$を利用するという状況下で,対象とする確率変数を変えれば表される分布も変わるような「上手い」パラメータを学習するのがキモです。

質問17:$p_{\theta}(.)$や$q_{\phi}(.)$はNN(ニューラルネットワーク)を表しているか

$p_{\theta}(.)$と$q_{\phi}(.)$は確率分布ですので,ニューラルネットワークとはそもそもの概念としての出発点が異なります。しかし,NNを入力-出力機構として捉えた場合に,ニューラルネットワークの入力と出力に様々な確率分布を仮定することができます。逆に,入力と出力に様々な確率分布を仮定することでニューラルネットワークを学習させることも可能になります。確率的にNNを発展させることにより,ベイズ推論のような議論も可能になります。

質問18:$p_{\theta}(.)$に関する分布を全て$p_{\theta}(.)$で表すのは分かりにくいのではないか

これはやはり,全て$p_{\theta}(.)$として表すことに意味があると私は思います。なぜなら,1つのエンコーダ・デコーダでVAEは構成されているからです。これが複数のエンコーダ・デコーダに拡張されれば,$r_{\omega}(.)$などと表記することになると思います。

質問19:やはりデコーダ側のNNのパラメータ$\theta$が$z$に影響を与えるのは不可能なのではないか

デコーダは$z$を入力として学習していきますので,入力の良さもデコーダのパラメータに影響を与えます。つまり,デコーダのパラメータ更新を行う中で$z$も対応して更新されていきます。

質問20:再構成データはどのようにして生成されるのか

本文中にあるVAEの図の1枚目のようにパラメータを出力する場合はサンプリングなどを利用します。これは,ネットワークの末端ですので誤差逆伝播には影響を与えません。2枚目のように再構成データそのものを出力する場合はデコーダの出力そのものが再構成データになります。

実装

以下では,MNISTを例にとって実装の各パートを眺めていきます。最後に全体のコードを載せたいと思います。

必要なライブラリのインポート

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch import optim
import torch.utils as utils
from torchvision import datasets, transforms

チュートリアル通りだと思います。

deviceの定義

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

本記事では,GPUを利用する前提でコーディングしていきます。deviceを定義しておきましょう。

データセットのロード

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Lambda(lambda x: x.view(-1))])

dataset_train = datasets.MNIST(
    '~/mnist', 
    train=True, 
    download=True, 
    transform=transform)
dataset_valid = datasets.MNIST(
    '~/mnist', 
    train=False, 
    download=True, 
    transform=transform)

batch_size = 1000

dataloader_train = utils.data.DataLoader(dataset_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=4)
dataloader_valid = utils.data.DataLoader(dataset_valid,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=4)

チュートリアル通りに読み込みます。

ネットワークの定義

class VAE(nn.Module):
    def __init__(self, x_dim, z_dim):
      super(VAE, self).__init__()
      self.x_dim = x_dim
      self.z_dim = z_dim
      self.fc1 = nn.Linear(x_dim, 20)
      self.bn1 = nn.BatchNorm1d(20)
      self.fc2_mean = nn.Linear(20, z_dim)
      self.fc2_var = nn.Linear(20, z_dim)

      self.fc3 = nn.Linear(z_dim, 20)
      self.drop1 = nn.Dropout(p=0.2)
      self.fc4 = nn.Linear(20, x_dim)

    def encoder(self, x):
      x = x.view(-1, self.x_dim)
      x = F.relu(self.fc1(x))
      x = self.bn1(x)
      mean = self.fc2_mean(x)
      log_var = self.fc2_var(x)
      return mean, log_var

    def sample_z(self, mean, log_var, device):
      epsilon = torch.randn(mean.shape, device=device)
      return mean + epsilon * torch.exp(0.5*log_var)

    def decoder(self, z):
      y = F.relu(self.fc3(z))
      y = self.drop1(y)
      y = torch.sigmoid(self.fc4(y))
      return y

    def forward(self, x, device):
      x = x.view(-1, self.x_dim)
      mean, log_var = self.encoder(x)
      delta = 1e-8
      KL = 0.5 * torch.sum(1 + log_var - mean**2 - torch.exp(log_var))
      z = self.sample_z(mean, log_var, device)
      y = self.decoder(z)
      # 本来はmeanだがKLとのスケールを合わせるためにsumで対応
      reconstruction = torch.sum(x * torch.log(y + delta) + (1 - x) * torch.log(1 - y + delta))
      lower_bound = [KL, reconstruction]
      return -sum(lower_bound), z, y

ネットワークを「__init__」「_encoder」「_sample_z」「decoder」「forward」「loss」で定義しています。「__init__」にはdenseを利用した定義を,「_encoder」には潜在空間のパラメータを得るまでの定義を,「_sample_z」にはエンコーダで得たパラメータから$z$を計算するための処理を,「decoder」には$z$を入力として元の次元まで再現するネットワークの定義を,「forward」には実際の計算機構を,lossには上でお伝えした変分下界の定義を記述しています。

モデルの学習

model = VAE(x_dim=28*28, z_dim=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

num_epochs = 20
loss_list = []
for i in range(num_epochs):
  losses = []
  for x, t in dataloader_train:
      x = x.to(device)
      loss, z, y = model(x, device)
      model.zero_grad()
      loss.backward()
      optimizer.step()
      losses.append(loss.cpu().detach().numpy())
  loss_list.append(np.average(losses))
  print("EPOCH: {} loss: {}".format(i, np.average(losses)))

生成

fig = plt.figure(figsize=(20, 6))

model.eval()
zs = []
for x, t in dataloader_valid:
    for i, im in enumerate(x.view(-1, 28, 28).detach().numpy()[:10]):
      ax = fig.add_subplot(3, 10, i+1, xticks=[], yticks=[])
      ax.imshow(im, 'gray')

    x = x.to(device)
    y, z = model(x)
    zs.append(z)
    y = y.view(-1, 28, 28)
    for i, im in enumerate(y.cpu().detach().numpy()[:10]):
      ax = fig.add_subplot(3, 10, i+11, xticks=[], yticks=[])
      ax.imshow(im, 'gray')
    
    z1to0 = torch.cat([z[1, :] * (i * 0.1) + z[0, :] * ((10 - i) * 0.1) for i in range(10)]).reshape(10, 10)
    y2 = model._decoder(z1to0).view(-1, 28, 28)
    for i, im in enumerate(y2.cpu().detach().numpy()[:20]):
      ax = fig.add_subplot(3, 10, i+21, xticks=[], yticks=[])
      ax.imshow(im, 'gray')
    break

1行目がデータセットオリジナル。2行目は潜在空間から生成した画像。3行目は,2行目の0番目から1番目に割合を変えながら遷移させていったもの。徐々に7から4に移り変わっていることが読み取れます。また,以下のコードを用いれば潜在空間を可視化することができます。([外部リンク]PyTorchでVAEのモデルを実装してMNISTの画像を生成する

from sklearn.manifold import TSNE
from random import random

colors = ["red", "green", "blue", "orange", "purple", "brown", "fuchsia", "grey", "olive", "lightblue"]
def visualize_zs(zs, labels):
  plt.figure(figsize=(10,10))
  points = TSNE(n_components=2, random_state=0).fit_transform(zs)
  for p, l in zip(points, labels):
    plt.scatter(p[0], p[1], marker="${}$".format(l), c=colors[l])
  plt.show()

model.eval()
zs = []
for x, t in dataloader_valid:
    x = x.to(device)
    t = t.to(device)
    # generate from x
    y, z = model(x)
    z = z.cpu()
    t = t.cpu()
    visualize_zs(z.detach().numpy(), t.cpu().detach().numpy())
    break

先ほど,7から4まで変化させていったときに,間に9のようなイメージが再現されました。これは,上の潜在空間でもみて取れると思います。7が集まっているゾーンから,4が集まっているゾーンに直線を引くと,9が集まっているゾーンを通りますね。

まとめ

VAEの簡単な理論的背景と実装をまとめてみました。理論は,ややトリッキーな点が何箇所かあるものの,全体的に分かりやすいモデルになっていると思います。GANと比べて出力が連続的になりやすいことは有名ですが,たしかに生成された画像はぼやけて見えますね。ラベル教師付きVAE(CVAE)にも注目が集まっていますね。

参考文献


[1] Auto-Encoding Variational Bayes(https://arxiv.org/abs/1312.6114)
[2] PyTorchでVAEのモデルを実装してMNISTの画像を生成する

ABOUT ME
zuka
京都大学で機械学習を学んでいます。