Vision Transformer(ViT)は画像認識手法の一つで、畳み込み層を用いずに高い精度を出したことから注目を浴びたモデルです。この手法では深層学習を用いた自然言語処理において有名なモデルであるTransformerを画像分類タスクに用いています。今回はいくつかの重要なポイントに気を付けながらVision Transformerの仕組みについて、順を追って解説していきます。
Contents
Vision Transformer(ViT)の仕組み
Vision Transformerの構造は以下左図のようになっています。
Vision Transformerの流れとしては、
- 入力画像をパッチに分割
- パッチをベクトル化
- ベクトルを埋め込み
- CLSトークン追加、位置埋め込み
- Transformer Encoderの処理
- MLPヘッドで分類
となります。以下、順に解説していきます。
入力画像をパッチへ分割
まず、Vision Transformerでは入力画像を以下のように固定サイズのパッチに分割します。なお、下の画像では例として4つに分割していますがもちろん実際はもっと細かく分割します。
ちなみにこれは自然言語処理において文章を単語ごとに区切って単語をベクトル化(単語埋め込み)するのと同じような処理です。つまり、各パッチは自然言語処理における単語に相当します。
また、Transformerは単語ベクトルのような一次元ベクトルを受け取るため、画像パッチも一次元のベクトルに変形する必要があります。つまりH×W×C(高さ×幅×チャンネル数)の元画像をN×(P^2・C)の形に変形します。ここで、NについてN=HW/p^2でNはパッチ数を表し、Pは各パッチの縦もしくは横の大きさです。つまり、1次元ベクトルにする前の各パッチの高さ×幅はP×Pということになります。
この変形は以下の図のようになっています。以降、分割後の画像パッチをのように表します。
パッチのベクトル化
Transformerは全ての層で一定の潜在ベクトル次元数Dを取るため、さらにこの変形したパッチをD次元空間に埋め込みます。すなわち、変形後は(P^2・C)×Dの形となります。
ここまでの入力画像に対する処理をパッチ埋め込み(Patch Embedding)と呼びます。
CLSトークンの追加、位置埋め込み
また、BERTにおける[class]トークンと同様に、パッチ埋め込み結果の先頭に[CLS]トークンを追加します。この[CLS]トークンは最終的に図6のMLP Headにおいて分類を行うために用いられ、学習されるパラメータです。
また、位置情報を保持するために位置埋め込み(Position Embedding)も各パッチに加算されます。位置埋め込みは各パッチの位置に関するベクトルの集合であり、他のパラメータと共に学習されるパラメータです。
ここまでやって、ようやく入力データの前処理が完了です。
Transformer Encoderの処理
次に、Vision Transformerのモデル構造は以下の図のようになっているのでした。以下左図におけるLinear Projection of Flattened Patchesの部分はすでに説明した埋め込みの部分です。
上右図におけるTransformer Encoderの部分は元となった自然言語処理におけるTransformerのエンコーダ部分(下の画像における左半分の部分)とほぼ同じ構造をしています。
これらの画像から、Transformerエンコーダとの相違点として、Vision TransformerではNormがMulti Head AttentionとMLPそれぞれの手前に位置しているということがわかります。なお、NormとしてはLayerNormalizationが用いられています。
以上を踏まえ、モデルの計算を数式で表すと以下のようになります。
上の式(1)におけるZ0は前章で説明した入力です。LNはLayerNormalizationを表し、MSAはMulti Head Self Attentionを表します。式(2)、(3)がエンコーダ部分の処理に該当します。最後のyをMLP Headに入れることで、欲しい出力を得ることができます。
また、Vision TransformerにおいてはMLPは二層で、さらに活性化関数としてGELU関数を採用しています。
事前学習とファインチューニング
もう一つ重要な点として、Vision Transformerでは他の画像認識モデルと同様に大規模データセットを用いて十分な事前学習とファインチューニングを行いますが、いくつか注意すべき点があります。まず大規模データセットで事前学習した後、タスクに合わせてファインチューニングするという流れは同じです。
この際、MLPヘッドを取り除き、新たに初期化されたMLPヘッドを追加します。また、ファインチューニング時の解像度は事前学習時よりも大きくしますが、ファインチューニング時と事前学習時のバッチサイズは同じにします。このため、ファインチューニング時にはシーケンス長さが大きくなってしまうという事態が発生しますが、事前学習した位置埋め込みはファインチューニング時には元画像に応じて二次元補間します。
この事前学習とファインチューニングを行うことで、Vision Transformerは非常に高いパフォーマンスを出すことに成功しています。
Vision Transformerを実装する
ここからはPyTorchによる実装を交えつつ、より深い理解を目指しましょう。
まず、画像を分割し、CLSトークンの追加、位置埋め込みの追加などを行います。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
# PatchEmbeddingクラスの定義
class PatchEmbedding(nn.Module):
def __init__(self, in_channels, patch_size, emb_size, img_size):
super(PatchEmbedding, self).__init__()
self.patch_size = patch_size
self.emb_size = emb_size
self.projection = nn.Linear(patch_size * patch_size * in_channels, emb_size)
self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
self.position_embeddings = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))
def forward(self, x):
b, c, h, w = x.shape
p = self.patch_size
# 画像をパッチに分割
x = x.view(b, c, h // p, p, w // p, p)
x = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(b, -1, p * p * c)
# パッチを埋め込みベクトルに変換
x = self.projection(x)
# CLSトークンの追加
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# 位置エンコーディングの追加
x += self.position_embeddings
return x
次に、Transformer Encoderの主要な構成要素であるMulti Head Attention層を実装します。実装内容としては通常のTransformerで用いられているものと同じであり、詳しくは解説しませんが、QueryとKeyのドット積を取ることで、類似度を表すAttention行列というものを作成し、出力はAttention行列によって重み付けされたValueの線形和となります。
# MultiHeadAttentionクラスの定義
class MultiHeadAttention(nn.Module):
def __init__(self, emb_size, num_heads):
super(MultiHeadAttention, self).__init__()
self.emb_size = emb_size
self.num_heads = num_heads
self.keys = nn.Linear(emb_size, emb_size)
self.queries = nn.Linear(emb_size, emb_size)
self.values = nn.Linear(emb_size, emb_size)
self.unify_heads = nn.Linear(emb_size, emb_size)
def forward(self, x):
b, t, e = x.size()
h = self.num_heads
assert e == self.emb_size, f'Expected input embedding size {self.emb_size}, but got {e}'
# ヘッドごとに分割
keys = self.keys(x).view(b, t, h, e // h).transpose(1, 2)
queries = self.queries(x).view(b, t, h, e // h).transpose(1, 2)
values = self.values(x).view(b, t, h, e // h).transpose(1, 2)
# attention計算
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
attention = torch.softmax(energy / (e ** (1 / 2)), dim=-1)
out = torch.einsum('bhqk, bhkd -> bhqd', attention, values).transpose(1, 2).contiguous()
out = out.view(b, t, e)
return self.unify_heads(out)
このMulti Head Attention層を用いて、図6右のようにTransformer Encoderを実装します。
# TransformerBlockクラスの定義
class TransformerBlock(nn.Module):
def __init__(self, emb_size, num_heads, forward_expansion):
super(TransformerBlock, self).__init__()
self.attention = MultiHeadAttention(emb_size, num_heads)
self.norm1 = nn.LayerNorm(emb_size)
self.norm2 = nn.LayerNorm(emb_size)
#MLP層
self.feed_forward = nn.Sequential(
nn.Linear(emb_size, forward_expansion * emb_size),
nn.ReLU(),
nn.Linear(forward_expansion * emb_size, emb_size)
)
def forward(self, x):
x = self.attention(self.norm1(x)) + x
x = self.feed_forward(self.norm2(x)) + x
return x
これらをまとめ、最終的なViTの実装は以下のようになります。図6左にそれぞれ対応していることがわかります。
class VisionTransformer(nn.Module):
def __init__(self, in_channels, patch_size, emb_size, img_size, num_layers, num_heads, forward_expansion, num_classes):
# VisionTransformerクラスの定義
class VisionTransformer(nn.Module):
def __init__(self, in_channels, patch_size, emb_size, img_size, num_layers, num_heads, forward_expansion, num_classes):
super(VisionTransformer, self).__init__()
self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
self.transformer = nn.Sequential(
*[TransformerBlock(emb_size, num_heads, forward_expansion) for _ in range(num_layers)]
)
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(emb_size),
nn.Linear(emb_size, num_classes)
)
def forward(self, x):
x = self.patch_embedding(x)
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])
x = self.mlp_head(x)
return x
GPU環境ではないので学習させるのは難しいですが、ダミーの画像データを入れてみて、ちゃんと計算結果が出せることは確認できます。
これをCIFAR-10データセットで実際に学習させるコードは以下のようになります(データのロード、モデルの定義などは省略しています)。
# CIFAR-10用のViTモデルのパラメータ設定
img_size = 32
patch_size = 4
in_channels = 3
emb_size = 128
num_layers = 6
num_heads = 8
forward_expansion = 4
num_classes = 10
model = VisionTransformer(in_channels, patch_size, emb_size, img_size, num_layers, num_heads, forward_expansion, num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# 損失関数とオプティマイザの設定
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# 学習ループ
num_epochs = 20
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for inputs, labels in trainloader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader)}')
print('Finished Training')
# テスト
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy on the test images: {100 * correct / total} %')
GPU環境で学習させてみると、だいたい60%程度の精度が出ます。
まとめ
今回はTransformerを画像認識に用いたVision Transformerについて解説しました。E資格の範囲に新しく含まれる箇所なので、よく復習しておくとよいでしょう。Vision Transformerの論文(AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE)も参考にしてみるといいかもしれません。