本記事では深層モデルの過学習を防止する正則化の手法として、ドロップアウトやドロップコネクトについて解説します。
ドロップアウト
ドロップアウト(Dropout)はHintonらにより2012年に提案された手法です(Hinton etal., 2012, “Improving neural networks by preventing co-adaptation of feature detectors”)。ドロップアウトとは、深層ニューラルネットワークモデルの学習時にモデルの一部を学習させないことによって過学習を防ぐ手段です。ドロップアウトでは具体的に何をしているのかというと、以下の図のように学習時に一部のノードのみ学習させないということを行っています。どのノードを学習させないかはミニバッチ毎にランダムに変化させます。
これは数学的な定式化もできます。まず、以下のように入力vベクトル、バイアスを含む重み行列W、出力hベクトル、非線形の活性化関数aからなる全結合層は以下のように表せます。
ドロップアウトは、この出力hに対してバイナリマスク(0と1の要素で構成される行列)をかける操作に該当します。よって、ベルヌーイ分布に従ってバイナリマスクmが決まるとき、最終的な出力h’は
と表せます。ここで*は要素積です。ちなみにベルヌーイ分布というのは確率pで1を、確率1-Pで0をとるという確率分布のことです。
ドロップアウトの効果は二つあり、
- 過学習を防ぐことに役立つ
- 複数モデルの結果を組み合わせることで精度が向上する
です。まず過学習を防げるということについてですが、これはミニバッチ毎に一部のノードが不活性化されることで、パラメータの更新のし過ぎが抑制されるためです。また、複数モデルの結果の組み合わせになるということについてですが、ミニバッチ毎にネットワークのランダムなノードが不活性化されるため、ミニバッチ毎にネットワークは異なるものとなることがわかります。よってドロップアウトを用いた学習は異なるネットワークの結果の組み合わせとなり、汎化性能が向上します。アンサンブル学習に似た効果が得られるというわけです。
なお、ドロップアウト処理は学習時にのみ行う処理であり、推論時には全てのニューロンが有効になります。しかし、そうすると学習の時とは全体のニューロン数が変わってしまうため、推論の際には以下のようにニューロンの出力を1-pでスケーリングしてやる必要があります(もしくは、学習時に各ニューロンの出力を1/(1-p)倍してやることでも同じ効果が得られます)。
ドロップコネクト
次に、ドロップコネクト(DropConnect)についてです。ドロップコネクトはLi Wanらにより2013年に提案されました(Li et al., 2013, “Regularization of Neural Networks Using DropConnect”)。ドロップコネクトではノードを不活性化させるのではなく、代わりにノード間の結合をランダムに切ります。つまり、結合の重みをランダムで0にします。これにより、ノードは前のノードからランダムに入力を受けるわけです。
つまりドロップアウトとなにが違うのかというと、ドロップアウトは全結合層の出力をランダムに0にするのに対し、ドロップコネクトは全結合層の重みをランダムに0にするということです。
こちらもドロップアウト同様、過学習の防止及び複数モデルの組み合わせ学習という効果を得ることができます。ただ、通常はノードよりもコネクションの方が数が多いため、ドロップコネクトの方がより多くのパターンを持ちます。
また、ドロップアウトと同様に数学的な定式化をすると、ドロップアウトのときと同様にバイナリマスクMをベルヌーイ分布に従って生成し、
重みWにバイナリマスクをかけます。
あとはドロップアウトの項で示したように普通に全結合層を計算すればいいというわけです。
ドロップアウトの定式化と比較すると、こちらでは重みに対して結合情報をエンコーディングしているということがわかります。
また、ドロップコネクトでもドロップアウトのときと同様に、推論時には学習時に無効化された重みを補うため重みを1-pでスケーリングします(もしくは、学習時に各重みを1/(1-p)倍してやることでも同じ効果が得られます)。
Pytorchで実装してみる
まず、ドロップアウトやドロップコネクトの効果を確認するために普通のモデルを用いてMNISTデータを10エポック学習させてみます。その学習結果は以下のようになります。学習過程での損失、正解率もプロットします。わざと過学習を発生させるため、少ないデータで学習させています。
なお、以降のコードはモデルの乱数を固定していないため実行ごとに結果が異なり、正式に比較はできませんが、固定すればより正確に比較できます。
少しですが過学習が発生していることがわかります。もう少しデータを偏らせたりしたらより顕著になるかもしれませんが、ドロップアウトやドロップコネクトの効果を見るには十分です。
では、ドロップアウトとドロップコネクトを実装してみましょう。ドロップアウトはpytorchではすでに実装されており、モデル中でtorch.nn.Droputを用いるだけで簡単に使えます。以下はSubclassing APIとしてモデルを定義したとき、ドロップアウトを用いる例です。
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, padding=2)
self.dropout_1 = nn.Dropout(0.5)
self.dropout_2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(32 * 7 * 7, 128)
self.fc2 = nn.Linear(128,10)
def forward(self,x):
x = torch.relu(F.max_pool2d(self.conv1(x), 2))
x = torch.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, x.size(1) * x.size(2) * x.size(3)) #テンソルを平らに(1次元に)する処理
x = self.dropout_1(x) #ドロップアウトを適用
x = torch.relu(self.fc1(x))
x = self.dropout_2(x) #ドロップアウトを適用
x = self.fc2(x)
return x
これを学習させ、その学習過程をプロットすると以下のようになります。
あんまり効果がないように見えますね。
一方、ドロップコネクトはpytorchにデフォルトの機能がないので、カスタムレイヤとして自分で実装してやる必要があります。以下は実装の一例です。重みを確率的に消す処理を行います。
def _weight_drop(module, weights, dropout):
for name_w in weights:
w = getattr(module, name_w)
del module._parameters[name_w]
module.register_parameter(name_w + '_raw', Parameter(w))
original_module_forward = module.forward
def forward(*args, **kwargs):
for name_w in weights:
raw_w = getattr(module, name_w + '_raw')
w = torch.nn.functional.dropout(raw_w, p=dropout, training=module.training)
setattr(module, name_w, w)
return original_module_forward(*args, **kwargs)
setattr(module, 'forward', forward)
class WeightDropLinear(torch.nn.Linear):
def __init__(self, *args, weight_dropout=0.0, **kwargs):
super().__init__(*args, **kwargs)
weights = ['weight']
_weight_drop(self, weights, weight_dropout)
このクラスを用いてドロップコネクトを実装したモデルで学習させてみたものは以下になります。
正直微妙な結果ですね。今回はドロップアウトやドロップコネクトにより過学習が改善するところを見たかったのですが、もっと過学習が顕著に出てるデータを使えば、はっきり効果がわかるかもしれません。
TensorFlowで実装してみる
同じことをTensorFlowでもやってみます。まず普通のモデルで過学習を起こします。タイムアウトになる場合は何回か実行してみてください。
こちらはかなり綺麗に(?)過学習が出ています。
TensorFlowでもドロップアウトの機能は容易に用いることができ、以下のようにモデル中でtf.keras.layers.Dropoutを用いるだけです。
結構改善が見られますね。
一方、TensorFlowでドロップコネクトを用いるにはdropconnect-tensorflowを用いれば簡単に実装することができます。
pip install dropconnect-tensorflow
import tensorflow as tf
from tensorflow.keras import layers,models
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from dropconnect_tensorflow import DropConnectDense
#データを読み込む
mnist = tf.keras.datasets.mnist
(X_train, y_train),(X_test, y_test) = mnist.load_data(path='/mnt/lib/ztodataset/mnist.npz')
#データを少なく取る
df = pd.DataFrame(columns=["label"])
df["label"] = y_train.reshape([-1])
list_0 = df.loc[df.label==0].sample(n=10)#n=10でsampling
list_1 = df.loc[df.label==1].sample(n=10)
list_2 = df.loc[df.label==2].sample(n=10)
list_3 = df.loc[df.label==3].sample(n=10)
list_4 = df.loc[df.label==4].sample(n=10)
list_5 = df.loc[df.label==5].sample(n=10)
list_6 = df.loc[df.label==6].sample(n=10)
list_7 = df.loc[df.label==7].sample(n=10)
list_8 = df.loc[df.label==8].sample(n=10)
list_9 = df.loc[df.label==9].sample(n=10)
label_list = pd.concat([list_0,list_1,list_2,list_3,list_4,list_5,list_6,list_7,list_8,
list_9])
label_list = label_list.sort_index()
label_idx = label_list.index.values
train_label = label_list.label.values
"""
x_trainからlabel用のdataframe.indexを取り出すことでlabelに対応したデータを取り出す。
"""
X_train = X_train[label_idx]
y_train= train_label
#正規化
X_train, X_test = X_train/255.0, X_test/255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Reshape((28, 28, 1), input_shape=(28, 28)),
tf.keras.layers.Conv2D(16, (5, 5), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(32, (5, 5), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Flatten(),
DropConnectDense(128, prob=0.5, activation="relu"), #線形層にドロップコネクトを適用
DropConnectDense(10, prob=0.5, activation="softmax")
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
history = model.fit(X_train, y_train,
batch_size=64,
epochs=50,
verbose=0,
validation_data=(X_test, y_test))
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label = 'val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='lower right')
plt.show()
こちらはなぜかあまり改善が見られませんでした。
まとめ
今回は、ドロップアウトとドロップコネクトについて簡単に解説しました。いずれも非常に重要なところなのでわからないところがある場合はよく復習しておきましょう。