生成モデルの代表格としてGAN(Generative Adversarial Network)があります。 今回はGANのモデルの中でも、ある画像から別の画像を生成することができるCycle GANと呼ばれるモデルを用いて葉書のデザインを行いたいと思います。おしゃれな風景画や花の絵が付いていたノートや絵葉書を今までに購入したことがある人がいると思います。今回はそのような商品に使えそうな風景画をベースとなる写真から生成してみたいと思います。
Cycle GANについて
概要
今回は生成モデルを用いて何ができるかの概要を掴むことをゴールとします。しかし具体的にどのようにして写真から風景画を生成できるかが気になる人がある程度いらっしゃると思いますので、ここで簡単に説明します。 Cycle GANでは元の画像からペアとなる画像(今回は生成する画像の訓練データ)を使わずにソースドメインXからターゲットドメインYへの変換を行います。訓練データを必要としないので、いわゆる教師無し学習です。例えば、馬の画像からシマウマの画像への変換は、ソースドメインが馬の画像で、ターゲットドメインがシマウマの画像となります。
損失関数について
例えば、ウマからシマウマへの画像変換においては、ウマとシマウマの2つの画像群(訓練データはウマとシマウマの画像は一対一である必要はありません)を用いて、ウマ画像からシマウマ画像へ自動変換するネットワークを学習します。本物画像の分布と見分けがつかない分布を生成する生成器を学習させます。また、一方向のみではなく、シマウマ画像からウマ画像へ逆変換するネットワークも同様に学習します。そして、元の画像から生成した画像を逆変換した画像が元の画像と極端にずれていないかを検討します。
よって、Cycle GANでは以下の項を目的関数に反映させます。
- XからYへの生成器をG、生成イメージG(X)とターゲットイメージYを比較し本物か偽物かを判定する識別器をDXとした場合の生成器と識別器の敵対損失
- YからXへの生成器をF、生成イメージF(Y)とターゲットイメージXを比較し本物か偽物かを判定する識別器DYとした場合の生成器と識別器の敵対損失
- 生成器GとFの矛盾を防ぐためのCycle Consistency Loss(XをG、Fの順に変換をかけてXに戻ってくるかどうかを、変換後の出力F(G(X))とXの絶対値を基準として評価する。逆も同様).(図2 (b),(c)に対応)
- Xから生成器Fを得てF(X)を作った時、もとのXとの絶対値を損失関数とするidentity Mapping Loss
参考までに損失関数の数式も表記します。
XからYに変換した時の生成画像の識別器
YからXに変換した時の生成画像の識別器
XからYへの変換
YからXへの変換
Cycle Consistency
Identity Loss
生成器、識別器の損失関数は以下のようになります。
識別器
生成器
上で定義した二つの損失関数をもとにG、F、DX、DYを最適化を行います。
CycleGANの実装
画像の前処理
今回は風景の写真を風景画に変換できるようにモデルの訓練を行います。モデルの訓練には、風景と写真の二つの画像データセットを用います。画像のソース
ではどのような画像が入っているか下のプログラムを実行して確認してみましょう。
右に写真、左に絵画風の画像が表示されました。次に、訓練用データとテスト用データの画像サイズ変更、クロッピング、左右反転化を行う関数を作成し、前処理を行います。
処理後の画像を比較してみましょう。※ランダムなので、適応されない場合もあります。
モデルの作成
これでデータセットの準備は整いました。では次にGANのモデルを作成しましょう。今回使用するGANのGeneratorのモデルは pix2pix.unet_generator()
と呼ばれる畳み込み層にUNet(Olafらによって生物医学のために開発された、セマンティックセグメンテーション用のモデル)と呼ばれるネットワークが用いられています。UNetの代わりにResNetを使っても構いません。UNetについての説明はここでは省略しますが、下にネットワークの概要図を示します。
それでは、概要の部分で説明した通り、二つの生成器と識別器のモデルを作成します。
では読み込んだモデルは学習させない状態ではどのような挙動を示すのでしょうか?写真から絵画へ、絵画から絵に変換する生成モデルを使ってみましょう。
DCGANの回でも試した通り、全く学習させていない状態では入力に対して全くデタラメな出力が返ってきます。二つの生成モデルが適切に訓練されると、写真から絵画のような画像、または絵画から写真のような画像を生成することができます。
損失、最適化手法の選択
では次に損失関数について考えます。今回はGenerator Discriminatorともに2組存在するので、損失関数もそれに対応して2組必要となります。 それに加えて今回はXからYへの変換、YからXへの変換において対応関係を担保するためにサイクル一貫性損失、アイデンティティー損失を追加で考えます。始めに示した数式をもとに損失関数を実装すると以下のように書くことができます。
LAMBDA = 10
#交差エントロピー
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
#discriminatorの損失
def discriminator_loss(real, generated):
real_loss = loss_obj(tf.ones_like(real), real)
generated_loss = loss_obj(tf.zeros_like(generated), generated)
total_disc_loss = real_loss + generated_loss
return total_disc_loss * 0.5
#generatorの損失
def generator_loss(generated):
return loss_obj(tf.ones_like(generated), generated)
#サイクル一貫性損失
def calc_cycle_loss(real_image, cycled_image):
loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
return LAMBDA * loss1
#アイデンティティー損失
def identity_loss(real_image, same_image):
loss = tf.reduce_mean(tf.abs(real_image - same_image))
return LAMBDA * 0.5 * loss
定義した損失関数を最適化するための最適手法を選択します。勾配降下法にも種類があり、その中でもよく使われるAdam
を選択します。記述方法は前回のDCGANの最適化手法と同じです。ただ、今回は学習させるネットワークは合計で4つであることに注意が必要です。
#generator g fの最適化手法
#Adamと呼ばれる最適化手法を用います
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
#discriminator x yの最適化手法
#Adamと呼ばれる最適化手法を用います
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
モデルの訓練
ではモデルの訓練に入っていきましょう。少し難しい書き方をしていますので、ここはざっくりと読み飛ばしてもらって構いません。
#途中で学習が中断した時用に記録を残す
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
generator_f=generator_f,
discriminator_x=discriminator_x,
discriminator_y=discriminator_y,
generator_g_optimizer=generator_g_optimizer,
generator_f_optimizer=generator_f_optimizer,
discriminator_x_optimizer=discriminator_x_optimizer,
discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
#最新版のチェックポイントをここで保存します
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
#予測画像を出力するための関数
def generate_images(model, test_input):
prediction = model(test_input)
plt.figure(figsize=(12, 12))
display_list = [test_input[0], prediction[0]]
title = ['Input Image', 'Predicted Image']
for i in range(2):
plt.subplot(1, 2, i+1)
plt.title(title[i])
plt.imshow(display_list[i] * 0.5 + 0.5)
plt.axis('off')
plt.show()
#deviceの設定(デフォルトはcpuでcudaが使えるのであれば、cudaを使います)
device = ['/device:GPU:0' if tf.test.gpu_device_name() =='/device:GPU:0' else '/cpu:0'][0]
#学習エポック
EPOCHS = 20
@tf.function
#訓練を実行するための関数
def train_step(real_x, real_y ,device):
with tf.device(device):
with tf.GradientTape(persistent=True) as tape:
# Generator G X -> Y の変換
# Generator F Y -> X の変換
# 一方向の変換だけではなく、それを逆変換して元に戻す処理も行う(cycle一貫性のため)
fake_y = generator_g(real_x, training=True)
cycled_x = generator_f(fake_y, training=True)
fake_x = generator_f(real_y, training=True)
cycled_y = generator_g(fake_x, training=True)
# same_x と same_y がアイデンティティー損失に使われます
same_x = generator_f(real_x, training=True)
same_y = generator_g(real_y, training=True)
#識別器xに本物のxを通した時の結果
disc_real_x = discriminator_x(real_x, training=True)
#識別器yに本物のyを通した時の結果
disc_real_y = discriminator_y(real_y, training=True)
#識別器xに偽物(生成した)のxを通した時の結果
disc_fake_x = discriminator_x(fake_x, training=True)
#識別器yに偽物(生成した)のyを通した時の結果
disc_fake_y = discriminator_y(fake_y, training=True)
#損失を計算します
gen_g_loss = generator_loss(disc_fake_y)
gen_f_loss = generator_loss(disc_fake_x)
#サイクル一貫性
total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
# generatorの合計損失 = 敵対損失 + サイクル一貫性損失 + アイデンティティー損失
total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)
# discriminatorの損失
disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
#generatorとdiscriminatorにおいて勾配を計算する
generator_g_gradients = tape.gradient(total_gen_g_loss,
generator_g.trainable_variables)
generator_f_gradients = tape.gradient(total_gen_f_loss,
generator_f.trainable_variables)
discriminator_x_gradients = tape.gradient(disc_x_loss,
discriminator_x.trainable_variables)
discriminator_y_gradients = tape.gradient(disc_y_loss,
discriminator_y.trainable_variables)
#勾配をoptimizerに作用させる
generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
generator_g.trainable_variables))
generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
generator_f.trainable_variables))
discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
discriminator_x.trainable_variables))
discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
start = time.time()
n = 0
for image_x, image_y in tf.data.Dataset.zip((train_x, train_y)):
train_step(image_x, image_y)
if n % 10 == 0:
print ('.', end='')
n += 1
clear_output(wait=True)
#途中経過を表示する(画像がそれっぽくなっていくのがわかる)
generate_images(generator_f, sample_y)
#5エポック毎にモデルの保存を行う
if (epoch + 1) % 5 == 0:
ckpt_save_path = ckpt_manager.save()
print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
ckpt_save_path))
print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
time.time()-start))
このブロックで行なっている処理は、以下の通りです。
- 画像を生成し、識別器で評価
- 識別器の損失関数を求め、勾配を計算して識別器の重みを更新
- 画像を生成し、生成器で評価
- 生成器の損失関数を求め、勾配を計算して生成器の重みを更新
- 1~4のサイクルで計算した各ネットワークでの損失関数を記録
- 5までのステップをnum_epochで指定した回数だけ繰り返す
行なっている処理は前回のDCGANの学習と本質的には変わらないのですが、ネットワークの個数が2個から4個に増えました。また、サイクル一貫性損失やアイデンティティー損失の評価のために、ドメインからターゲットへの変換のみではなく、ターゲットからドメインへの変換と計2回の変換が必要となります。
上の処理を行うと、GPU環境内でも時間がかかりますので、あらかじめ学習した学習済みモデルを読み込んで学習は終了したものとして以下進めていきます。自分で学習したいという方は、上のコードをご自分の環境で実行してみてください。
写真から絵画への変換
今回は事前にGPUを用いて学習をさせたネットワークを用います。それでは、学習済みのモデルを読み込みます。
一つのネットワークでも計5400万個の重みを学習していることになります。訓練時には、生成器と識別器の合計4個のネットワークがありますので、ざっくりと計算すると2億個の重みを学習していることになります。単純なネットワークと比較して規模がとても大きいことがわかります。では、最後に上のモデルを用いて実際に画像変換を行いましょう。葉書に使えそうなデザインを生成してみましょう。
今回は、宮城県、山形県の有名な観光名所で撮られた写真から絵画のようなデザインを生成してみましょう。どんな写真を使うか可視化してみましょう。
上から順番に、
- 仙台城伊達政宗像
- 宮城県松島町の紅葉
- 山形県山寺市立山寺(冬)
- 山形県山寺市立山寺(夏)
- 宮城県松島町円通院
- 宮城県大崎市鳴子狭
の写真が表示されました。今回はこの中からランダムな写真を選んで、絵画のような形式に変換してみたいと思います。(本当は6枚全て変換したいのですが、実行環境のリソースの観点からランダムに1枚を選ぶ形式となっております。)
全ての写真を変換すると以下のようになります。上の実装プログラムでは、下の図のどれか一つが表示されます。
6枚のうち、上面の3枚は絵画のような画像に変換されていることがわかります。下面の3枚はパッと見てあまり変化がないようにに見えます。今回はGPUなどの計算リソースの関係より、学習ループは20回しか回していませんので、学習回数を増やしてみるのも方法の一つです。
まとめ
今回は生成モデルの応用編ということでCycleGANと呼ばれるモデルについて学習しましょう。原理は通常のGANと全く同じですが、画像から画像の変換ということで生成器、識別器が増え、更にそれらを評価する目的関数の種類も増えたので少し大変だったと思います。しかし、任意の写真から絵を作り出すということができるのには驚いた方もいらっしゃるのではないでしょうか。実際に、世の中に出回っている「写真から肖像画」を作るアプリなどは、CycleGANをはじめとするGANモデルから作られています。興味のある人は生成モデルについて更に勉強してみてください。
参考文献
[1] Balraj Ashwath. (2020) “Moneto2Photo CycleGAN’s Monet Paintings & Natural Photos Dataset”. Kaggle
[2] Jun-Yan Zhu, Taesung Park, Phillip Isola Alexei A. Efros. (2020) “Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks”. Arxiv
[3] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. (2015) “U-Net: Convolutional Networks for Biomedical Image Segmentation”. Arxiv
[4] TensorFlow Core. “CycleGAN”. TensorFlow