本記事では、画像認識の代表的な手法であるセマンティックセグメンテーションについて解説します。具体的な手法としてFCN、U-Netを取り上げる他、インスタンスセグメンテーション、パノプティックセグメンテーションについても解説します。
セマンティックセグメンテーションとは
セマンティックセグメンテーションとは、画像のピクセル一つ一つに対してラベル付けしていくことで、画像全体ではなくピクセル単位で分類を行う手法です。画像の分類や画像に写っている物体を検出するのではなく、ピクセル単位で分類を行うことで不規則な形状の物体も容易に識別することができます。ただ、以下の画像のように、同じ種類の物体は区別しません。そのため、同じ種類の物体同士が隣接していると区別できません。注意点として、セグメンテーションでは物体検出と異なりピクセル単位での分類を行うため、ピクセル単位での情報が付加されたデータセットが必要です。
セマンティックセグメンテーションは近年では自動運転、医療現場における被写体の識別などでよく用いられています。
セマンティックセグメンテーションにはいくつかの手法がありますが、今回はその中でもFCN(Fully Convolutional Network ; 全畳み込みネットワーク)と、エンコーダ-デコーダ構造を用いたU-Netについて解説します。
FCN
FCNはCNN(Convolutional Neural Network ; 畳み込みネットワーク)をセマンティックセグメンテーションに利用した手法で、大きな特徴として全結合層を用いず、畳み込み層のみで構成されたモデルを用います。以下の図のようにVGG16など優れた画像分類モデルの全結合層を全て1×1の畳み込み層に置き換えます。
また、最後にアップサンプリング層を追加します。このアップサンプリング層は畳み込みの逆のプロセスを行うので、Transpose Convolutional LayerやDeconvolutional Layerと呼ばれます。
しかし、全ての層が畳み込み層であると層が深くなるにつれて情報の損失が大きくなります。これを解決するために、上層のプーリング層の出力をアップサンプリングした出力に加算するという処理(スキップ接続)を行います。以下の図のように、FCN8sでは上層の二つのプーリング層からの出力をアップサンプリングしたものに足し合わせていることがわかります。
余談ですが、VGG16など元となったモデルと共通する畳み込み層は学習済みのパラメータを用いて初期化することで、学習を高速化できます。
U-Net
U-Netも代表的なセマンティックセグメンテーション手法です。全ての層が畳み込み層で構成されている点はFCNと同じです。U-Net自体の構造は以下の図のようになっています。この構造がUの字型をしているためU-Netという名前がついているようです。
上の図の左側がエンコーダ、右側がデコーダに対応します。エンコーダ側は通常の畳み込み層を用いたネットワークになっていますが、デコーダ側ではプーリングの代わりにアップサンプリングが行われていることがわかります。また、エンコーダ側からデコーダ側へ矢印が伸びています。これはスキップ接続で、特徴量の伝播過程で情報が失われてしまうことを防ぐためにエンコーダ側の各層で出力される特徴マップをデコーダ側の対応する層の特徴マップに連結します。ここで、U-Netにおけるスキップ接続は厳密にはFCNにおけるスキップ接続とは異なっています。
パノプティックセグメンテーション
最後に、パノプティックセグメンテーションについて説明します。
画像のセグメンテーション手法には、セマンティックセグメンテーションの他にもインスタンスセグメンテーション、パノプティックセグメンテーションという二つの手法があります。既に述べた通り、セマンティックセグメンテーションは以下のように同じ種類の物体は区別しないため、同じ種類の物体同士が隣接していると区別できません。
一方、インスタンスセグメンテーションは画像中における全ての物体の領域を特定し、個体ごとに物体の種類を分類する手法です。そのため、セマンティックセグメンテーションとは異なり、同じ種類の物体同士を区別することができますが、空や道路などの不定形のものに対しては分類を行いません。
これらセマンティックセグメンテーション及びインスタンスセグメンテーションを組み合わせたものがパノプティックセグメンテーションです。画像中の全てのピクセルについて分類を行うだけでなく、数えられる物体に対しては個々の物体を区別して認識します。
実装してみる
では、実際にセマンティックセグメンテーションを実装してみましょう。モデルを学習させるのはとても時間がかかるため、今回はtorchvisionで公開されているモデルをダウンロードして使ってみることにします。
from torchvision.models.segmentation import fcn_resnet101, FCN_ResNet101_Weights
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision import transforms as transforms
import seaborn as sns
# モデルの初期化
weights = FCN_ResNet101_Weights.DEFAULT
model = fcn_resnet101(weights=weights)
# モデルを評価モードにする
model.eval()
実際に実行してみると、線形層はなく、畳み込み層のみで構成されていることがわかります。
今回入力する画像は以下の3チャンネルで、縦×横=612×640のものになります。
from PIL import Image, ImageFilter
img = Image.open('cat-551554_640.jpg')
img
推論させる画像には前処理を施す必要がありますが、torchvisionには学習済みのパラメータに紐づけられて自動で前処理を行う機能があります。
from torchvision.models.segmentation import fcn_resnet101, FCN_ResNet101_Weights
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision import transforms as transforms
import seaborn as sns
# 重みと変換の初期化
weights = FCN_ResNet101_Weights.DEFAULT
preprocess = weights.transforms()
img = torch.zeros((3, 612, 640), dtype=torch.int8)
# 入力画像への適用
img_transformed = preprocess(img)
モデルに合わせ、自動で画像のリサイズ、標準化などを行ってくれます。
また、モデルは(バッチサイズ, 3, H, W) の入力が想定されているようなのでバッチ次元を追加した上で、実際にモデルに入力して出力の形状を見てみましょう。
from PIL import Image, ImageFilter
img = Image.open('cat-551554_640.jpg')
# モデルの初期化
weights = FCN_ResNet101_Weights.DEFAULT
model = fcn_resnet101(weights=weights)
# モデルを評価モードにする
model.eval()
# 推論用の前処理
preprocess = weights.transforms()
# 前処理変換の適用、バッチ次元追加
batch = preprocess(img).unsqueeze(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
img_batch = batch.to(device)
# 推論
output = model(img_batch)['out']
print(output.size())
出力の形状は[1, 21, 520, 543]で、縦×横が520×543になります。よって、アップサンプリングで元のサイズに戻します。
from PIL import Image, ImageFilter
img = Image.open('cat-551554_640.jpg')
# モデルの初期化
weights = FCN_ResNet101_Weights.DEFAULT
model = fcn_resnet101(weights=weights)
# モデルを評価モードにする
model.eval()
# 推論用の前処理
preprocess = weights.transforms()
# 前処理変換の適用、バッチ次元追加
batch = preprocess(img).unsqueeze(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
img_batch = batch.to(device)
# 推論
output = model(img_batch)['out']
print(output.size())
# もとの大きさに戻す
output = F.upsample(output, size=(612, 640), mode='bilinear')
print(output.size())
この出力はバッチサイズ1、チャンネル数21、縦×横=612×640です。バッチサイズはもともとモデルに入力するために追加したものなので消し、推論結果は確率であり最も高い確率のクラスを推論値とするのでチャンネル方向にargmaxを取ると[612, 640]の二次元データが得られます。
out = torch.argmax(output[0], dim=0)
推論結果を可視化するために、クラスごとにRGB値を割り当てます。
def decode_segmap(image, nc=21):
label_colors = np.array([(0, 0, 0), # 0=背景
(128, 0, 0), # 1=飛行機
(0, 128, 0), # 2=自転車
(128, 128, 0), # 3=鳥
(0, 0, 128), # 4=ボート
(128, 0, 128), # 5=瓶
(0, 128, 128), # 6=バス
(128, 128, 128), # 7=車
(64, 0, 0), # 8=猫
(192, 0, 0), # 9=椅子
(64, 128, 0), # 10=牛
(192, 128, 0), # 11=テーブル
(64, 0, 128), # 12=犬
(192, 0, 128), # 13=馬
(64, 128, 128), # 14=バイク
(192, 128, 128), # 15=人
(0, 64, 0), # 16=鉢植え
(128, 64, 0), # 17=羊
(0, 192, 0), # 18=ソファ
(128, 192, 0), # 19=列車
(0, 64, 128)]) # 20=テレビ
r = np.zeros_like(image).astype(np.uint8)
g = np.zeros_like(image).astype(np.uint8)
b = np.zeros_like(image).astype(np.uint8)
for l in range(0, nc):
idx = image == l
r[idx] = label_colors[l, 0]
g[idx] = label_colors[l, 1]
b[idx] = label_colors[l, 2]
rgb = np.stack([r, g, b], axis=2)
return rgb
推論させたものを可視化してみましょう。
from PIL import Image, ImageFilter
img = Image.open('cat-551554_640.jpg')
# モデルの初期化
weights = FCN_ResNet101_Weights.DEFAULT
model = fcn_resnet101(weights=weights)
# モデルを評価モードにする
model.eval()
# 推論用の前処理
preprocess = weights.transforms()
# 前処理変換の適用、バッチ次元追加
batch = preprocess(img).unsqueeze(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
img_batch = batch.to(device)
# 推論
output = model(img_batch)['out']
# もとの大きさに戻す
output = F.upsample(output, size=(612, 640), mode='bilinear')
out = torch.argmax(output[0], dim=0)
# 推論結果のデコード
def decode_segmap(image, nc=21):
label_colors = np.array([(0, 0, 0), # 0=背景
(128, 0, 0), # 1=飛行機
(0, 128, 0), # 2=自転車
(128, 128, 0), # 3=鳥
(0, 0, 128), # 4=ボート
(128, 0, 128), # 5=瓶
(0, 128, 128), # 6=バス
(128, 128, 128), # 7=車
(64, 0, 0), # 8=猫
(192, 0, 0), # 9=椅子
(64, 128, 0), # 10=牛
(192, 128, 0), # 11=テーブル
(64, 0, 128), # 12=犬
(192, 0, 128), # 13=馬
(64, 128, 128), # 14=バイク
(192, 128, 128), # 15=人
(0, 64, 0), # 16=鉢植え
(128, 64, 0), # 17=羊
(0, 192, 0), # 18=ソファ
(128, 192, 0), # 19=列車
(0, 64, 128)]) # 20=テレビ
r = np.zeros_like(image).astype(np.uint8)
g = np.zeros_like(image).astype(np.uint8)
b = np.zeros_like(image).astype(np.uint8)
for l in range(0, nc):
idx = image == l
r[idx] = label_colors[l, 0]
g[idx] = label_colors[l, 1]
b[idx] = label_colors[l, 2]
rgb = np.stack([r, g, b], axis=2)
return rgb
# セグメンテーションマップのデコード
segmap = decode_segmap(out.cpu().numpy())
# 結果の表示
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(Image.open('cat-551554_640.jpg'))
ax[0].set_title('Original Image')
ax[1].imshow(segmap)
ax[1].set_title('Segmentation Map')
plt.show()
結果は以下のようになります。
だいたいピクセルごとの分類ができていることがわかります。
まとめ
今回はセマンティックセグメンテーション、中でもFCNやU-Netについて解説した他、パノプティックセグメンテーションについても触れました。これらはE資格においても問われる可能性がある重要な箇所なので、よく復習しておくとよいでしょう。