アカデミック

【初学者向け】Gumbel Sigmoidの導出。

この記事では,Gumbel Sigmoidの導出方法についてお伝えしていきます。本来,誤差逆伝播可能な形でArgmax処理を行おうというモチベーションで考案された手法がGumbel分布からのサンプリングを利用したGumbel Max Trick[1]です。本記事では,多クラス分類を前提としたGumbel Softmaxを単なる1クラスの閾値処理に利用するためのGumbel Sigmoidを導出してみたいと思います。

Gumbel Softmaxにこのような処理を加えれば,閾値処理に利用できるGumbel Sigmoidが導出できるよ!

と堂々とお話し(自慢)できることを目標に執筆していきます。分かりやすさを重視しているため,正確性に欠ける表現もありますが大目にみてください。

間違えている箇所がございましたらご指摘ください。随時更新予定です。他のサーベイまとめ記事はコチラのページをご覧ください。

結論

まずは結論からです。ニューラルネットワークの活性化関数に通す前の出力を$\phi$,Gumbel分布からのサンプリングを$g_1, g_2$,温度パラメータを$\tau$とすれば,Gumbel Sigmoidの出力値$y$は以下のように表されます。

【Gumbel Sigmoid】

\begin{align}
y &= {\rm sigmoid}\left( (\phi + g_1 – g_2)/\tau \right)
\end{align}

は!?

となると思います。今から説明していきます。

Gumbel分布

文献[1]で示されている通り,Gumbel分布からのサンプリングは以下のように表されます。

\begin{align}
u &\sim \rm{Uniform}(0, 1)\\
g &= -\log\{ -\log (u) \}\\
y_i &= \frac{\exp\{ \left( \log(\pi_i)+g_i \right)/\tau \}}{\sum_{j=1}^K \exp\{ \left( \log(\pi_j)+g_j \right)/\tau \}}\\
&= {\rm Softmax}\left( \left( \log(\pi_i) + g_i \right) /\tau \right)
\end{align}

このGumbel分布からのサンプリングを,Sigmoidを使って表したいというのが本記事のモチベーションです。こちらの記事「【初学者向け】SoftmaxとSigmoidの関係とは。」でも説明している通り,SigmoidはSoftmaxにおいて$K=2$,$x_0=x$,$x_1=0$とした場合に相当します。($x_0$と$x_1$は2クラス版Softmaxに対する2つの入力値を示しています。詳しくは上の記事をご覧ください。)

つまりは,クラス数を$2$にして,片方の入力を$0$に固定してしまえば,SigmoidをSoftmaxとして捉えることができるということです。今回も,そのアイディアを利用します。しかし,上記定義式は,Softmaxの入力をlog確率にしています。そのため,上の記事で展開した議論が成り立たず(logをとってしまえば確率値の余事象の関係が成立しなくなりますよね),このままでは先ほどのアイディアを流用することはできません。

そこで,まず最初に「log確率を任意の変数(実数)に置き換えてもGumbel Softmaxは機能する」ことを示します。

log確率を任意の変数に置き換えられる証明

変数を整理します。$\boldsymbol{\phi}$をニューラルネットワークの活性化関数に噛ませる前の生の出力,$\boldsymbol{\pi}$をニューラルネットワークの出力にSoftmaxを噛ませた確率値(クラス数$K$に関して足せば$1$になります)とします。

\begin{align}
\boldsymbol{\phi} &= \{ \phi_1,\cdots, \phi_K \}\\
\boldsymbol{\pi} &= \{ \pi_1,\cdots, \pi_K \}\\
\pi_i &= \frac{\exp \phi_i}{\sum_{j=1}^{K} \exp \{ \phi_j \}}
\end{align}

今示したいことは「log確率である$\log \pi_i$は$\phi_i$に置き換えられるのか」という点です。そのためには,$\pi_i$を先ほどのGumbel Softmaxの定義式に代入して,同様の形式の式が得られればOKということになります。

実際に計算してみます。煩雑になりそうなため,Softmaxの分母が定数として

\begin{align}
\sum_{j=1}^{K} \{\exp \{ \phi_j \} \}&= A
\end{align}

と表記することにします。また,以下のように$GS(\log \pi_i)$を表記します。GSはGumbel Softmaxの頭文字です。

\begin{align}
GS(\log \pi_i) &= \frac{\exp\{ \left( \log(\pi_i)+g_i \right)/\tau \}}{\sum_{j=1}^K \exp\{ \left( \log(\pi_j)+g_j \right)/\tau \}}
\end{align}

すると,示したいことはシンプルで「$GS(\log \pi_i)=GS(\phi_i)$」を示せばOKということになります。

\begin{align}
GS(\log(\pi_i))
&= \frac{\exp\{ \left( \log(\pi_i)+g_i \right)/\tau \}}{\sum_{j=1}^K \exp\{ \left( \log(\pi_j)+g_j \right)/\tau \}}\\
&= \frac{\exp\{ \left( \log(\frac{\exp \phi_i}{N})+g_i \right)/\tau \}}{\sum_{j=1}^K \exp\{ \left( \log(\frac{\exp \phi_i}{N})+g_j \right)/\tau \}}\\
&= \frac{\exp\{ \left( \phi_i – \log N +g_i \right)/\tau \}}{\sum_{j=1}^K \exp\{ \left( \phi_i – \log N + g_j \right)/\tau \}}\\
&= \frac{\exp \{-\frac{\log N}{\tau} \} \exp\{ \left(\phi_i +g_i \right)/\tau \}}{\exp \{-\frac{\log N}{\tau} \} \sum_{j=1}^K \exp\{ \left( \phi_i +g_j \right)/\tau \}}\\
&= \frac{ \exp\{ \left( \phi_i +g_i \right)/\tau \}}{\sum_{j=1}^K \exp\{ \left( \phi_i +g_j \right)/\tau \}}\\
&= GS(\phi_i)
\end{align}

無事,「$GS(\log \pi_i)=GS(\phi_i)$」を示せました。

SigmoidをSoftmaxとして捉える

さて,これで先ほどのアイディア「クラス数を$2$にして,片方の入力を$0$に固定してしまえば,SigmoidをSoftmaxとして捉えることができる」を適用する準備が整いました。まずは,変数を揃えましょう。片方の変数を$\phi$,もう片方の変数を$0$とします。

実際に$GS(\log \pi_i)=GS(\phi_i)$に$i=0,1$と$\phi_0=\phi$,$\phi_1=0$を代入します。求める値を$y$とすれば,今回知りたいのは$\phi$に関しての出力であるため,$y=GS(\phi_0)=GS(\phi)$となります。

\begin{align}
y &= GS(\phi_0)\\
&= GS(\phi)\\
&= \frac{\exp\{ (\phi + g_1) / \tau\}}{\exp\{ (\phi + g_1) / \tau \} + \exp\{ (g_2) / \tau \}}\\
&= \frac{1}{1+\exp \{ \frac{g_2}{\tau} – \frac{\phi + g_1}{\tau} \}}\\
&= \frac{1}{1+\exp \{ -(\phi + g_1 – g_2)/\tau \}}\\
&= {\rm sigmoid}\left( (\phi + g_1 – g_2)/\tau \right)
\end{align}

以上で,Gumbel Sigmoidの導出が完了しました。

参考文献

[1] Jang, Eric, Shixiang Gu, and Ben Poole. “Categorical reparameterization with gumbel-softmax.” arXiv preprint arXiv:1611.01144 (2016).APA

ABOUT ME
zuka
京都大学で機械学習を学んでいます。

COMMENT

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

※ Please enter your comments in Japanese to prevent spam.