本記事では、深層学習において重要なテクニックの一つであるデータオーグメンテーション(データ拡張)について解説します。PythonのディープラーニングフレームワークであるPyTorchを用いた簡単な実装方法についても紹介します。
データ拡張とは
深層学習では非常に多くのデータが必要とされますが、データが少ないときもあります。そんなときにデータを増やすための手段の一つがデータ拡張で、画像データにおいて用いられます。どのようにデータを増やすのかですが、すでに存在する実際のデータに対して少しだけ変化を加えたものをたくさん作ることで、データ数を”水増し”します。しかしただ闇雲に増やせばよいというわけではなく、テストしたときによりよい精度を発揮するためにはどのような変化を加えるかも考慮する必要があります。
今回はデータ拡張でデータに加える変化には具体的にどのようなものがあるのかを解説しつつ、その実装も紹介します。
PyTorchを用いた実装
PyTorchでデータ拡張を行う場合、主にtorchvisionというコンピュータビジョンを扱うためのライブラリを用います。なお、データ拡張の実装にはあらかじめデータの数自体を増やす「オフライン」の方法と、学習時にミニバッチ毎に変換を加えることで疑似的にデータ数を増やす「オンラインの方法」とがありますが、オンラインの方法の方が実際の画像枚数が増えない分メモリを食わないという利点があります。torchvisionを用いればオンラインのデータ拡張を行えますが、今回は単純に一枚の画像のみに対して処理を行うだけとします。
今回処理を行う画像データは、次のコードを実行して表示されるものを用いることにします。
デフォルトのコード
import matplotlib.pyplot as plt
import cv2
# OpenCVを使って画像を読み込む
img = cv2.imread('/mnt/lib/ztodataset/hk_nami.jpg')
#BGRをRGBに変換
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#読み込んだ画像を表示
plt.imshow(img_rgb)

では、データ拡張にはどのような手法があるかを紹介しながら実装していきます。
Random Flip
Random Flip は、ランダムな確率で画像を反転させる処理です。HorizontalFlipとVerticalFlipの二種類があり、HorizontalFlipは画像を水平方向に反転させ、VerticalFlipは画像を垂直方向に反転させます。
では、Random Flip処理を行ってみましょう。Random Flipの処理はtorchvision.transforms.RandomHorizontalFlip(水平方向)もしくはtorchvision.transforms.RandomVerticalFlip(垂直方向)で実装することができます。引数には反転処理を行う確率を設定します(今回は1にしてあるので必ず反転処理を行います)。また、以下では画像を正しく表示するため、まずtransforms.ToPILImageでPIL画像に変換しています。このように、torchvision.transformsの処理を行った画像はPIL画像に直してあげる必要があります(PILとは、Pythonの画像処理ライブラリであるPillowのことです)。
デフォルトのコード
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
# OpenCVを使って画像を読み込む
img = cv2.imread('/mnt/lib/ztodataset/hk_nami.jpg')
#BGRをRGBに変換
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#行う変換を定義
transform = transforms.Compose([
transforms.ToPILImage(), #PIL画像に変換する処理
transforms.RandomHorizontalFlip(p = 1)
])
trans_img = transform(img_rgb)
#画像描画用の関数
def visualize(original, augmented):
fig = plt.figure()
plt.subplot(1,2,1)
plt.title('Original image')
plt.imshow(original)
plt.subplot(1,2,2)
plt.title('Augmented image')
plt.imshow(augmented)
visualize(img_rgb, trans_img)

Random Flip適用前と比較すると、確かに左右反転していることがわかります。
Random Erase
Random Eraseは画像の一部をランダムな確率で消去する処理です。torchvisionではtorchvision.transforms.RandomErasingで実装することができます。Random Flipのときと同様に実装してみましょう。なお、torchvision.transforms.RandomErasingはPIL画像に適用することができずTensorデータに対し適用できるため、ここではまず画像をTensorデータにしてRandom Erase処理を行ったのちPIL画像に変換しています。
デフォルトのコード
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
# OpenCVを使って画像を読み込む
img = cv2.imread('/mnt/lib/ztodataset/hk_nami.jpg')
#BGRをRGBに変換
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#データに加える変換を定義
transform = transforms.Compose([
transforms.ToTensor(),
transforms.RandomErasing(p=1),
transforms.ToPILImage() #tensorに変換して処理を行った後、正しく表示するためPIL形式にする
])
trans_img = transform(img_rgb)
#画像描画用の関数
def visualize(original, augmented):
fig = plt.figure()
plt.subplot(1,2,1)
plt.title('Original image')
plt.imshow(original)
plt.subplot(1,2,2)
plt.title('Augmented image')
plt.imshow(augmented)
visualize(img_rgb, trans_img)

なお、torchvision.transforms.RandomErasingは引数を指定することで処理を行う確率以外にも消去する面積の大きさなども変えることができます(詳しくはtorchvision公式ドキュメントを参照してください)。
Random Crop
Random Cropは画像のランダムな一部を切り抜く処理を行います。torchvisionではtransforms.RandomCropで実装でき、引数には切り取った後の画像サイズを指定します。
デフォルトのコード
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
# OpenCVを使って画像を読み込む
img = cv2.imread('/mnt/lib/ztodataset/hk_nami.jpg')
#BGRをRGBに変換
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#データに加える変換を定義
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomCrop(size = 200)
])
trans_img = transform(img_rgb)
#画像描画用の関数
def visualize(original, augmented):
fig = plt.figure()
plt.subplot(1,2,1)
plt.title('Original image')
plt.imshow(original)
plt.subplot(1,2,2)
plt.title('Augmented image')
plt.imshow(augmented)
visualize(img_rgb, trans_img)

縦軸、横軸の値を見ると切り取っていることがよくわかります。
Random Contrast, Random Brightness
Random Contrastは画像のコントラストをランダムに変更する処理で、Random Brightnessもその名の通り画像の明るさをランダムに変更する処理です。torchvisionにおいては、いずれの処理もtorchvision.transforms.ColorJitterの引数にそれぞれ数値を指定することで実装することができます。明るさ、コントラストは引数に指定された数値の範囲からランダムに決められます。
デフォルトのコード
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
# OpenCVを使って画像を読み込む
img = cv2.imread('/mnt/lib/ztodataset/hk_nami.jpg')
#BGRをRGBに変換
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#データに加える変換を定義
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.ColorJitter(brightness=0.3, contrast=0.5) #今回は明るさ、コントラストを変更
])
trans_img = transform(img_rgb)
#画像描画用の関数
def visualize(original, augmented):
fig = plt.figure()
plt.subplot(1,2,1)
plt.title('Original image')
plt.imshow(original)
plt.subplot(1,2,2)
plt.title('Augmented image')
plt.imshow(augmented)
visualize(img_rgb, trans_img)

ColorJitterは他にも画像の彩度や色相を変化させる処理も行えます。
Random Rotate
Random Rotateは画像をランダムな角度だけ回転させる処理です。torchvisionではtransforms.RandomRotationで実装することができます。
デフォルトのコード
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
# OpenCVを使って画像を読み込む
img = cv2.imread('/mnt/lib/ztodataset/hk_nami.jpg')
#BGRをRGBに変換
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#データに加える変換を定義
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomRotation(degrees = 50) #引数には回転する角度の範囲を指定
])
trans_img = transform(img_rgb)
#画像描画用の関数
def visualize(original, augmented):
fig = plt.figure()
plt.subplot(1,2,1)
plt.title('Original image')
plt.imshow(original)
plt.subplot(1,2,2)
plt.title('Augmented image')
plt.imshow(augmented)
visualize(img_rgb, trans_img)

transforms.RandomRotationは引数に回転する角度の範囲を指定する必要があります。他にも、回転中心の座標を指定することなども可能です。
まとめ
今回はデータ拡張で頻繁に用いられる手法のいくつかを簡単に実装しました。データ拡張の手段は今回紹介したもの以外にも多数存在するので、興味がある方は個人で調べてみてもよいでしょう。