DEVELOPER’s BLOG
技術ブログ
【Python】TensorFlow / SegmentationチュートリアルでU-netを学習させる
はじめに
前回の「画像セグメンテーションのためのU-net概要紹介」では画像のクラス分類のタスクを、画像のSegmentationのタスクにどう発展させるかを解説し、SegmentationのネットワークであるU-netの理論ついて簡単に解説しました。
今回はTensorFlowのSegmentationのチュートリアルを行いながら、実際にU-netを学習させてみたいと思います。
尚、本記事ではTensorflowの詳しい解説は行いません。
参考 : https://www.tensorflow.org/tutorials/images/segmentation
Segmentationとは
ある物体が画像内に含まれている時、画像のどこにあるのかを推定するタスクのことです。
言い換えると「画像のピクセルがそれぞれ何かを推定する」タスクのことです。
今回はOxford-IIIT Pet Datasetというデータセットを用いて学習を行い、ピクセルごとに以下のようなクラス分けを行います。
- Class 0 : 動物のピクセル
- Class 1 : 動物とその他の境界線のピクセル
- Class 2 : その他のピクセル
目標は以下のような出力を得ることです。(左:入力画像 右:出力画像)
ライブラリのインポート
import tensorflow as tf import sys from IPython.display import display from IPython.display import HTML from PIL import Image # sys.modules['Image'] = Image from __future__ import absolute_import , division, print_function, unicode_literals from tensorflow_examples.models.pix2pix import pix2pix import tensorflow_datasets as tfds tfds.disable_progress_bar() from IPython.display import clear_output import matplotlib.pyplot as plt
データの読み込み&可視化
dataset, info = tfds.load('oxford_iiit_pet:3.0.0', with_info=True) def normalize(input_image,input_mask): input_image = tf.cast(input_image, tf.float32) / 255 input_mask -= 1 return input_image,input_mask @tf.function def load_image_train(datapoint): input_image = tf.image.resize(datapoint['image'],(128,128)) input_mask = tf.image.resize(datapoint['segmentation_mask'],(128,128)) if tf.random.uniform(()) > 0.5: input_image = tf.image.flip_left_right(input_image) input_mask = tf.image.flip_left_right(input_mask) input_image,input_mask = normalize(input_image, input_mask) return input_image, input_mask def load_image_test(datapoint): input_image = tf.image.resize(datapoint['image'], (128, 128)) input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) input_image, input_mask = normalize(input_image, input_mask) return input_image, input_mask TRAIN_LENGTH = info.splits['train'].num_examples BATCH_SIZE = 64 BUFFER_SIZE = 1000 STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) test = dataset['test'].map(load_image_test) train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) test_dataset = test.batch(BATCH_SIZE) def display(display_list): plt.figure(figsize=(15, 15)) title = ['Input Image', 'True Mask', 'Predicted Mask'] for i in range(len(display_list)): plt.subplot(1, len(display_list), i+1) plt.title(title[i]) plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i])) plt.axis('off') plt.show() for image, mask in train.take(88): sample_image, sample_mask = image, mask display([sample_image, sample_mask])
U-netの構築
U-netのencorder部分には今回のデータセットとは別のデータセットを学習したのMovileNetを用います。
MovileNetの重みは学習で更新しないように更新しておきます。
このようにすることU-netの精度を上げることができます。(転移学習)
decorder部分には未学習のpix2pixを用います。
pix2pixの重みは学習により更新されます。
OUTPUT_CHANNELS = 3 # encorder部分には学習済みのMovileNet base_model = tf.keras.applications.MobileNetV2(input_shape=[128,128,3],include_top=False) layer_names = [ 'block_1_expand_relu', 'block_3_expand_relu', 'block_6_expand_relu', 'block_13_expand_relu', 'block_16_project' ] layers = [base_model.get_layer(name).output for name in layer_names] down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers) # MovileNetの重みは固定 down_stack.trainable = False # decorder部分にはpix2pixを用いる up_stack = [ pix2pix.upsample(512, 3), # 4x4 -> 8x8 pix2pix.upsample(256, 3), # 8x8 -> 16x16 pix2pix.upsample(128, 3), # 16x16 -> 32x32 pix2pix.upsample(64, 3), # 32x32 -> 64x64 ] def unet_model(output_channels): # This is the last layer of the model last = tf.keras.layers.Conv2DTranspose( output_channels, 3, strides=2, padding='same', activation='softmax') #64x64 -> 128x128 inputs = tf.keras.layers.Input(shape=[128, 128, 3]) x = inputs # Downsampling through the model skips = down_stack(x) x = skips[-1] skips = reversed(skips[:-1]) # Upsampling and establishing the skip connections for up, skip in zip(up_stack, skips): x = up(x) concat = tf.keras.layers.Concatenate() x = concat([x, skip]) x = last(x) return tf.keras.Model(inputs=inputs, outputs=x)
U-netの学習
実際にU-netが学習していく様子を眺めてみましょう。
epochが進むにつれ、正しく予測できているのがわかります。
class DisplayCallback(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch,logs=None): clear_output(wait=True) ims.append([sample_image, sample_mask, create_mask(model.predict(sample_image[tf.newaxis, ...]))]) print ('\nSample Prediction after epoch {}\n'.format(epoch+1)) def create_mask(pred_mask): pred_mask = tf.argmax(pred_mask, axis=-1) pred_mask = pred_mask[..., tf.newaxis] return pred_mask[0] def show_predictions(dataset=None, num=1): if dataset: for image, mask in dataset.take(num): pred_mask = model.predict(image) display([image[0], mask[0], create_mask(pred_mask)]) ims.append([image[0], mask[0], create_mask(pred_mask)]) else: display([sample_image, sample_mask, create_mask(model.predict(sample_image[tf.newaxis, ...]))]) ims.append([sample_image, sample_mask, create_mask(model.predict(sample_image[tf.newaxis, ...]))]) ims = [] EPOCHS = 100 VAL_SUBSPLITS = 5 VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS model = unet_model(OUTPUT_CHANNELS) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model_history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_steps=VALIDATION_STEPS, validation_data=test_dataset, callbacks=[DisplayCallback()]) import matplotlib.animation as animation def make_animation(ims): %matplotlib nbagg fig, (ax0, ax1, ax2) = plt.subplots(1,3,figsize=(14.0, 8.0)) ax0.axis('off') ax1.axis('off') ax2.axis('off') ax0.set_title('Input Image') ax1.set_title('True Mask') ax2.set_title('Predicted Mask') ims2 = [] for epoch,im in enumerate(ims): im0, = [ax0.imshow(tf.keras.preprocessing.image.array_to_img(im[0]))] im1, = [ax1.imshow(tf.keras.preprocessing.image.array_to_img(im[1]))] im2, = [ax2.imshow(tf.keras.preprocessing.image.array_to_img(im[2]))] ims2.append([im0,im1,im2]) ani = animation.ArtistAnimation(fig, ims2, interval=50, repeat_delay=1000) return ani
学習の様子
学習の様子を見てみます。epochが進むごとに精度が増していることがわかります。
ani1 = make_animation(ims) HTML(ani1.to_jshtml())
おまけ
U-netにはskip-conectionという手法が使われています。
encorder部分で畳み込みをして失ってしまった画像内の位置情報を保持する役割を持ちます。
U-netにskip-conectionが無い場合も比較してみましょう。
up_stack = [ pix2pix.upsample(512, 3), # 4x4 -> 8x8 pix2pix.upsample(256, 3), # 8x8 -> 16x16 pix2pix.upsample(128, 3), # 16x16 -> 32x32 pix2pix.upsample(64, 3), # 32x32 -> 64x64 ] def unet_model_no_sc(output_channels): # This is the last layer of the model last = tf.keras.layers.Conv2DTranspose( output_channels, 3, strides=2, padding='same', activation='softmax') #64x64 -> 128x128 inputs = tf.keras.layers.Input(shape=[128, 128, 3]) x = inputs x = down_stack(x) x = x[-1] for up in up_stack: x = up(x) x = last(x) return tf.keras.Model(inputs=inputs, outputs=x) ims = [] EPOCHS = 100 VAL_SUBSPLITS = 5 VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS model = unet_model_no_sc(OUTPUT_CHANNELS) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model_history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_steps=VALIDATION_STEPS, validation_data=test_dataset, callbacks=[DisplayCallback()])
学習の様子
こちらも学習の経過を眺めてみます。
skip-conectionがある場合と比べ、学習が遅いばかりか精度が悪いことがわかります。
Twitter・Facebookで定期的に情報発信しています!
Follow @acceluniverse