この記事は, Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんでいる方々の参考になれば嬉しいです。Pythonつまずきポイント集の目次は以下のページをご覧ください。
【超初心者お悩み解決】Pythonつまずきポイント記事まとめページ
この記事は,Pythonを利用して研究を行なっていく中で私がつまずいてしまったポイントをまとめていくものです。同じような状況で苦しんで...
スポンサーリンク
環境
●Ubuntu 18.04
●Python 3.7.3
●conda 4.8.3
●pytorch 1.2.0
実現したいこと
転移学習やファインチューニング,知識蒸留などで学習済みモデルの更新を止めたい。だけど,やり方が分からない。
方法
for param in model.parameters():
param.requires_grad = False
全ての重みパラメータの更新を止めたい場合,これで一発OKです。一部の重みパラメータを止めたい場合は,
lstm.weight_hh_l0.requires_grad = False
というように直接重みの指定をしてあげます。止めたい重みの名前は,以下のようにしてprintしてあげると分かりやすいと思います。LSTMの命名規則は独特であるため,注意が必要です。(lstm.weight_ih_l0という具合)
for param in model.parameters():
print(param)
注意点
Pytorchの「model.eval()」には,モデルの更新を止める働きはありません。ここに私はハマりました。評価モードにしているから,モデルの更新はもちろん止まっているだろうと思っていたのです。転移学習や知識蒸留などを行うときには,モデルの更新をストップしないと,知識が平滑化されてしまう恐れがあります。注意が必要です。