アカデミック

【Pythonお悩み解決】binary_cross_entropyとbinary_cross_entropy_with_logitsの挙動がおかしい

この記事は, Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんでいる方々の参考になれば嬉しいです。Pythonつまずきポイント集の目次は以下のページをご覧ください。

【超初心者お悩み解決】Pythonつまずきポイント記事まとめページ この記事は,Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんで...

本記事で紹介する解決策がBestという保証はできません。正確な情報を発信するように心掛けていますが図らずも誤った情報を記載してしまう場合があります。もしご指摘等がありましたら,コメント欄またはお問い合わせページよりご連絡下さい。

読みたい場所へジャンプ!

環境

●Ubuntu 16.04
●Python 3.7.3
●conda 4.7.12
●pytorch 1.2.0

現象

原因

binary_cross_entropy_with_logitsを本来の用途で用いる場合には問題ありませんでした。しかし,binary_cross_entropyと同様のはたらきをさせるために,pos_weightにtorch.ones([(クラス数)])を指定したところ,binary_cross_entropyとは異なる挙動を示しました。

こちらに関しては,原因は不明です。本来,binary_cross_entropy_with_logitsはラベルに偏りがある場合に,正例に重み付けをするようなはたらきをします。ですので,重みであるpos_weightにtorch.ones([(クラス数)])を指定すればbinary_cross_entropyと等価になるはずです。

しかし,私の環境ではなぜかbinary_cross_entropy_with_logitsの挙動がおかしく,バックプロップが安定しませんでした。それどころか,各イテレーションのvalidationでF値が0.2%まで低下するなど,明らかにおかしい挙動を示しました。

解決方法

素朴なbinary_cross_entropyを利用したい場合は,binary_cross_entropy_with_logitsで重みを1に指定するのではなく,素直にbinary_cross_entropyを利用しましょう。

COMMENT

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です