この記事では、基礎的なニューラルネットワークにおける勾配計算について解説します。
今は便利な機械学習ライブラリがたくさん用意されているので、その中で順伝播、逆伝播による勾配計算、パラメータ更新といった処理を自動でやってくれます。そのため、計算の中身についてはあまり意識していない人も多いのではないでしょうか。
そこで今回は、シグモイド関数を用いた二値分類モデルと、ソフトマックス関数を用いたマルチクラス分類モデルの2つについて、誤差逆伝播法によってどのようにコスト関数の勾配が求められるのかを導出しながら説明したいと思います。
※ニューラルネットワークの基礎的な知識は持っていることを想定していますので、不安がある方は以下の用語集などを一度見てみてください。
・活性化関数:活性化関数、シグモイド関数、ソフトマックス関数
・勾配降下法(最急降下法):最急降下法
・誤差逆伝播法:誤差逆伝播法
ニューラルネットワークとは?
まずはニューラルネットワークの順伝播と逆伝播について簡単におさらいです。今回の解説では、図のように全結合層が一つだけあるシンプルなニューラルネットワークを例に説明します(隠れ層もなしです)。
、は全結合層の重みとバイアス、は活性化関数です。が出力層(予測値)であり、と正解ラベルの誤差をコスト関数で求めます。
(や、の形、および活性化関数やコスト関数は解きたいタスクによって変わります。)
このネットワークの順伝播は
と表されます。
勾配降下法によって重みとバイアスを更新するために、目的関数の勾配、を求めなければなりません。
これらの勾配は、微分の連鎖率を使った誤差逆伝播法によって次のように求められます。
上記の通り、まずを求め、そこから、が求まります。
ここまではわかっている方も多いと思いますが、今回は特にこのの導出に焦点を当てて、実際にこれらの値を計算して求めてみましょう。
二値分類モデルの場合、マルチクラス分類モデルの場合の順に、計算の流れを追って説明していきます。
1. シグモイド関数を用いた二値分類の場合
まずは二値分類モデルの場合です。二値分類の場合は、分類したいデータに対して正解のクラスを示すラベルが0か1の値で与えられます。
モデルは0~1の間の値を出力することで、例えば「出力値が0.5以上ならクラス1に分類し、0.5未満ならクラス0に分類する」のように人間(設計者)が決めた閾値に従って分類します。
よって、具体的には次のようなネットワークになります。
出力層のユニットは一つで、活性化関数をシグモイド関数にすることで0~1の間の値を出力することができます。
コスト関数は一般的に交差エントロピー関数が使われ、訓練データ数をm個とすると
となります。
(添え字は、訓練データ集合の中の番目のデータであることを示します。)
ではさっそくを求めていきますが、誤差逆伝播法による求め方を改めて書くとです。
つまり、との2つがわかればも計算できます。この2つを順に求めていきましょう。
まずはです。
(※ここからは、訓練データ中の番目のデータに対して考えていきます。)
これは上記のコスト関数をそのままで微分するだけなので簡単で、以下のようになります。
(対数の微分はとなることを用います。)
次にです。これは、順伝播の式をで微分したものです。つまり、活性化関数を微分したものになります。
ここではシグモイド関数を用いているので、途中計算は省きますが
と表すことができます。
よって、は
と求めることができました。番目だけでなく全てのデータに対して同じ結果になるので、データ集合全体で書くと
となります。
2. ソフトマックス関数を用いたマルチクラス分類の場合
マルチクラス分類の場合は、分類したいデータに対して正解ラベルがone-hotベクトルで与えられます。one-hotベクトルは、クラス数分の要素があり、正解クラスの位置だけ1、それ以外は0になっているベクトルです。これに対してモデルは、クラス数分の要素数を持ち、全要素の合計が1になるベクトルを出力するように設計し、出力ベクトルにおいて値が一番大きい位置のクラスを予測値とします。
つまり、具体的には次のようなネットワークです。ここでは、クラス数が3個の場合を例に考えます。
出力層はクラス数分のユニットで、活性化関数はソフトマックス関数です。
コスト関数は交差エントロピー関数(先ほどとは少し形が違います)を使用します。
は予測値および正解ラベルの何番目の要素かを表し、今はクラスが3つある場合を考えているのでです。
では本題のを求めたいのですが、ここでは3つあるユニットのうち1つ、に関する勾配を導出します。
(※先ほどと同じように番目のデータに関して求めますが、式が見づらくなるためここからは添え字は省略し、クラス数に関する添え字のみ付けます。)
さて、微分の連鎖率によって普通だったらとなりそうですが、実は違います(ここが勘違いしやすく、最も大事なポイントです)。
ソフトマックス関数の計算上、は後ろの全てに関わっています。そのため、の勾配も全てから伝播してきた勾配を足す必要があり、連鎖率の式は
となります。一つ一つの項を見ればさっきまでとほぼ同じ形ですが、このようにの全てを考慮する必要があるのです。
は、それぞれコスト関数の式を微分するだけなので簡単に計算することができ、
です。
続いてですが、これは活性化関数の微分です。
計算過程は省きますが、ソフトマックス関数の微分は、の何番目の値をの何番目の値で微分するかという組み合わせによって結果が変わり、次のように書き表せます。
つまり、同じ位置同士の場合と違う位置の場合で変わってくるのです。 今回はでを微分しているので、上の式に当てはめると
です。
よって、は
と計算できます。ここで、正解ラベルは、正解クラスのみ1、それ以外は0のベクトルでした。そのため、正解クラスが何であっても必ずになります。
したがって上の続きは、
となります。
他のユニットに関しても同じように計算できて
となり、当然番目だけでなく全てのデータに対して当てはまるので、全体で表すと
となります。
気づいた方もいるかもしれませんが、実はこれは最初に導出した二値分類モデルと同じ結果です。
活性化関数もコスト関数の式も違うのに、の値が同じになりました。
結果だけ見るとそのせいで混乱する人も多いのですが、ここまで導出過程を追ってくれた方は大丈夫だと思います。
特にソフトマックス関数を用いたマルチクラス分類モデルのケースは詳しい導出過程を説明している資料が(特にネット上には)少ないので、ぜひ覚えておきましょう。
wとbに関する勾配
さて、誤差逆伝播法の最終的な目的はを求めることでした。
最初にも確認しましたが、微分の連鎖率から
のように求められます。
は、順伝播の式をそれぞれで微分すればいいので簡単に求まります。
実際に計算すると、さっき求めたも合わせて、次のようになります。
このように求まった勾配を用いて、勾配降下法によってパラメータを更新することになります。
まとめ
今回はニューラルネットワークにおいて、シグモイド関数を用いた二値分類の場合とソフトマックス関数を用いたマルチクラス分類の場合それぞれの誤差逆伝播法による勾配計算を解説しました。
実用上はライブラリが自動でやってくれるのでここまで理解せずとも使うことはできますが、中でどのような計算が行われているのか知っておくことは大事です。
ここでは極めてシンプルな例を用いて説明しましたが、層を深くしたり、他の活性化関数を使ったりしても基本的には同じように勾配を求めることができます。