アカデミック

【Pythonお悩み解決】Pytorchのモデル更新を止めたい。

この記事は, Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんでいる方々の参考になれば嬉しいです。Pythonつまずきポイント集の目次は以下のページをご覧ください。

【超初心者お悩み解決】Pythonつまずきポイント記事まとめページ この記事は,Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんで...

本記事で紹介する解決策がBestという保証はできません。正確な情報を発信するように心掛けていますが図らずも誤った情報を記載してしまう場合があります。もしご指摘等がありましたら,コメント欄またはお問い合わせページよりご連絡下さい。

読みたい場所へジャンプ!

環境

●Ubuntu 18.04
●Python 3.7.3
●conda 4.8.3
●pytorch 1.2.0

実現したいこと

転移学習やファインチューニング,知識蒸留などで学習済みモデルの更新を止めたい。だけど,やり方が分からない。

方法

for param in model.parameters():
    param.requires_grad = False

全ての重みパラメータの更新を止めたい場合,これで一発OKです。一部の重みパラメータを止めたい場合は,

lstm.weight_hh_l0.requires_grad = False

というように直接重みの指定をしてあげます。止めたい重みの名前は,以下のようにしてprintしてあげると分かりやすいと思います。LSTMの命名規則は独特であるため,注意が必要です。(lstm.weight_ih_l0という具合)

for param in model.parameters():
    print(param)

注意点

Pytorchの「model.eval()」には,モデルの更新を止める働きはありません。ここに私はハマりました。評価モードにしているから,モデルの更新はもちろん止まっているだろうと思っていたのです。転移学習や知識蒸留などを行うときには,モデルの更新をストップしないと,知識が平滑化されてしまう恐れがあります。注意が必要です。

COMMENT

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です