본문 바로가기
AI 딥러닝/DLA

[U-Net] U-Net 구조

by 세인트워터멜론 2022. 5. 11.

이미지 세그멘테이션(image segmentation)은 이미지의 모든 픽셀이 어떤 카테고리(예를 들면 자동차, 사람, 도로 등)에 속하는지 분류하는 것을 말한다.

 

 

이미지 전체에 대해 단일 카테고리를 예측하는 이미지 분류(image classification)와는 달리, 이미지 세그멘테이션은 픽셀 단위의 분류를 수행하므로 일반적으로 더 어려운 문제로 인식되고 있다.

 

 

위 그림에서 semantic segmentation은 이미지 내에 있는 객체들을 의미 있는 단위로 분할해내는 것이고, instance segmentation 은 같은 카테고리에 속하는 서로 다른 객체까지 더 분할하여 semantic segmentation 범위를 확장한 것이다.

이미지 세그멘테이션은 의료 이미지 분석(종양 경계 추출 등), 자율주행 차량(도로면, 보행자 감지 등) 및 증강현실과 같은 광범위한 분야에서 사용되고 있다.

딥러닝을 이용한 이미지 세그멘테이션은 수백개의 알고리즘이 제안되어 있다고 하는데, 그 중 U-Net 모델을 Tensorflow 2로 구현해보고자 한다.

U-Net은 'U-Net: Convolutional Networks for Biomedical Image Segmentation' 이라는 논문에서 제안한 구조로서 매우 적은 수의 학습 데이터로도 정확한 이미지 세그멘테이션 성능을 보여주었으며 ISBI 세포 추적 챌린지 2015에서 큰 점수 차이로 우승했다고 한다.

U-Net은 오토인코더(autoencoder)와 같은 인코더-디코더(encoder-decoder) 기반 모델에 속한다. 보통 인코딩 단계에서는 입력 이미지의 특징을 포착할 수 있도록 채널의 수를 늘리면서 차원을 축소해 나가며, 디코딩 단계에서는 저차원으로 인코딩된 정보만 이용하여 채널의 수를 줄이고 차원을 늘려서 고차원의 이미지를 복원한다. 하지만 인코딩 단계에서 차원 축소를 거치면서 이미지 객체에 대한 자세한 위치 정보를 잃게 되고, 디코딩 단계에서도 저차원의 정보만을 이용하기 때문에 위치 정보 손실을 회복하지 못하게 된다.

 

 

U-Net의 기본 아이디어는 저차원 뿐만 아니라 고차원 정보도 이용하여 이미지의 특징을 추출함과 동시에 정확한 위치 파악도 가능하게 하자는 것이다. 이를 위해서 인코딩 단계의 각 레이어에서 얻은 특징을 디코딩 단계의 각 레이어에 합치는(concatenation) 방법을 사용한다. 인코더 레이어와 디코더 레이어의 직접 연결을 스킵 연결(skip connection)이라고 한다.

 

 

원래 논문에서는 신경망 구조를 스킵 연결을 평행하게 두고 가운데를 기준으로 좌우가 대칭이 되도록 레이어를 배치하여, 이름 그대로 U자 형으로 만들었다.

 

 

U-Net 은 인코더 또는 축소경로(contracting path)와 디코더 또는 확장경로(expending path)로 구성되며 두 구조는 서로 대칭적이다. 인코더와 디코더를 연결하는 부분을 브릿지(bridge)라고 한다. 인코더와 디코더에서는 모두 3x3 컨볼루션을 사용한다. 인코더의 자세한 구조는 다음 그림과 같다. 맨 아래의 블록은 브릿지이다.

 

 

그림에서 세로 방향 숫자는 맵(map)의 차원을 표시하고 가로 방향 숫자는 채널 수를 표시한다. 예를 들면 세로 방향 숫자 256x256 과 가로 방향 숫자 128 은 해당 레이어의 이미지가 256x256x128 임을 의미한다. 입력 이미지는 512x512x3 이므로 RGB 3개 채널을 갖고 크기가 512x512 인 이미지를 나타낸다.

 

 

그림에서 파란색 박스가 인코더의 각 단계마다 계속 반복하여 나타나는 것을 볼 수 있는데, 이 박스는 3x3 컨볼루션, Batch Normalization, ReLU 활성화 함수가 차례로 배치된 것을 나타낸다. 이 박스 두 개를 한데 묶어서 한 개의 레이어 블록으로 구현하여 사용하면 편리하다. 이 블록 이름을 ConvBlock이라고 하자.

 

 

다음 코드는 ConvBlock을 구현한 것이다.

 

""" Conv Block """
class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(ConvBlock, self).__init__()

        self.conv1 = Conv2D(n_filters, 3, padding='same')
        self.conv2 = Conv2D(n_filters, 3, padding='same')

        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()

        self.activation = Activation('relu')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)

        return x

 

인코더 그림을 보면 보라색 박스안에 한 개의 ConvBlock이 있고 이 박스가 인코더의 각 단계마다 나타나는 것을 볼 수 있다. 이 박스에서 나오는 출력이 2개인데, 한 개의 출력은 U-Net의 디코더로 복사하기 위한 연결선이며, 또 한 개의 출력은 2x2 max pooling 으로 다운 샘플링(down sampling)하여 인코더의 다음 단계로 내보내는 빨간색 화살선이다. 이 박스도 한 개의 레이어 블록으로 구현하여 사용하면 편리하다. 이 블록 이름을 EncoderBlock이라고 하자.

 

 

다음 코드는 EncoderBlock을 구현한 것이다.

 

""" Encoder Block """
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(EncoderBlock, self).__init__()

        self.conv_blk = ConvBlock(n_filters)
        self.pool = MaxPooling2D((2,2))

    def call(self, inputs):
        x = self.conv_blk(inputs)
        p = self.pool(x)
        return x, p

 

브릿지는 두개의 파란색 박스로만 구성되어 있으므로 1개의 ConvBlock 레이어로 표현할 수 있다.

 

 

디코더의 자세한 구조는 다음 그림과 같다.

 

 

그림에서 파란색 박스 2개는 인코더에 있는 ConvBlock와 동일하다. 녹색 박스는 스킵 연결을 통해서 인코더에 있는 맵을 복사한 것이다. 노란색 박스는 디코더의 하위 단계에서 전치 컨볼루션(transposed convolution)을 통해서 맵의 차원을 두배로 늘리면서 채널 수를 반으로 줄인 것이다. 두 개의 맵을 서로 합쳐서(concatenation) 저차원 이미지 정보뿐만 아니라 고차원 정보도 이용할 수 있는 것이다.

디코더의 그림에서도 회색 박스가 반복적으로 나타나므로 한 개의 레이어 블록으로 구현하여 사용하면 편리하다. 이 블록 이름을 DecoderBlock 이라고 하자.

 

 

다음 코드는 DecoderBlock을 구현한 것이다.

 

""" Decoder Block """
class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(DecoderBlock, self).__init__()

        self.up = Conv2DTranspose(n_filters, (2,2), strides=2, padding='same')
        self.conv_blk = ConvBlock(n_filters)

    def call(self, inputs, skip):
        x = self.up(inputs)
        x = Concatenate()([x, skip])
        x = self.conv_blk(x)

        return x

 

디코더 그림의 맨 상단의 오른쪽 부분은 U-Net의 출력으로서 1x1 컨볼루션으로 특징 맵을 처리하여 입력 이미지의 각 픽셀을 분류하는 세그멘테이션 맵을 생성하는 부분이다. 컨볼루션 필터의 개수는 분류할 카테고리 개수이며 활성 함수로는 카테고리 수가 1개라면 sigmoid 함수를, 여러 개라면 softmax 함수를 사용한다.

EncoderBlock과 DecoderBlock 을 사용하면 U-Net을 다음과 같이 간단히 코드로 구현할 수 있다.

 

""" U-Net Model """
class UNET(tf.keras.Model):
    def __init__(self, n_classes):
        super(UNET, self).__init__()

        # Encoder
        self.e1 = EncoderBlock(64)
        self.e2 = EncoderBlock(128)
        self.e3 = EncoderBlock(256)
        self.e4 = EncoderBlock(512)

        # Bridge
        self.b = ConvBlock(1024)

        # Decoder
        self.d1 = DecoderBlock(512)
        self.d2 = DecoderBlock(256)
        self.d3 = DecoderBlock(128)
        self.d4 = DecoderBlock(64)

        # Outputs
        if n_classes == 1:
            activation = 'sigmoid'
        else:
            activation = 'softmax'

        self.outputs = Conv2D(n_classes, 1, padding='same', activation=activation)

    def call(self, inputs):
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        outputs = self.outputs(d4)

        return outputs

 

U-Net 모델의 전체 코드는 다음과 같다.


unet_model.py

 

# U-Net model
# coded by st.watermelon

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.layers import Activation, BatchNormalization, Concatenate


""" Conv Block """
class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(ConvBlock, self).__init__()

        self.conv1 = Conv2D(n_filters, 3, padding='same')
        self.conv2 = Conv2D(n_filters, 3, padding='same')

        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()

        self.activation = Activation('relu')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)

        return x


""" Encoder Block """
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(EncoderBlock, self).__init__()

        self.conv_blk = ConvBlock(n_filters)
        self.pool = MaxPooling2D((2,2))

    def call(self, inputs):
        x = self.conv_blk(inputs)
        p = self.pool(x)
        return x, p


""" Decoder Block """
class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(DecoderBlock, self).__init__()

        self.up = Conv2DTranspose(n_filters, (2,2), strides=2, padding='same')
        self.conv_blk = ConvBlock(n_filters)

    def call(self, inputs, skip):
        x = self.up(inputs)
        x = Concatenate()([x, skip])
        x = self.conv_blk(x)

        return x


""" U-Net Model """
class UNET(tf.keras.Model):
    def __init__(self, n_classes):
        super(UNET, self).__init__()

        # Encoder
        self.e1 = EncoderBlock(64)
        self.e2 = EncoderBlock(128)
        self.e3 = EncoderBlock(256)
        self.e4 = EncoderBlock(512)

        # Bridge
        self.b = ConvBlock(1024)

        # Decoder
        self.d1 = DecoderBlock(512)
        self.d2 = DecoderBlock(256)
        self.d3 = DecoderBlock(128)
        self.d4 = DecoderBlock(64)

        # Outputs
        if n_classes == 1:
            activation = 'sigmoid'
        else:
            activation = 'softmax'

        self.outputs = Conv2D(n_classes, 1, padding='same', activation=activation)

    def call(self, inputs):
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        outputs = self.outputs(d4)

        return outputs

 

 

 

댓글0