본문 바로가기
프로그래밍/TensorFlow2

Model Subclassing 멀티 입력 신경망 모델 구현 방법

by 세인트 워터멜론 2021. 3. 16.

Model Subclassing API를 사용하여 입력을 여러 개 갖는 즉, 멀티 입력 신경망 모델을 어떻게 구현하고 빌드(build)할 수 있을까.

 

 

강화학습의 DDPG알고리즘에서는 행동가치 함수(actor-value function)를 크리틱(critic) 신경망으로 구현한다. 크리틱 신경망은 입력으로 상태(state)와 행동(action)등 두 개를 받는데, 이를 Model Subclassing API를 이용해서 구현해 보자. 구현해야 할 신경망 구조는 다음 그림과 같다. 상태변수를 첫번째 은닉층에서 처리한 후 두번째 은닉층에서 행동과 병합하는 구조다.

 

 

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, concatenate

class Critic(Model):

    def __init__(self):
        super(Critic, self).__init__()

        self.x1 = Dense(64, activation='relu', name='state_input')
        self.x2 = Dense(32, activation='linear')
        self.a1 = Dense(32, activation='linear', name='action_input')
        self.h3 = Dense(16, activation='relu')
        self.q = Dense(1, activation='linear', name='action_value')


    def call(self, state_action):
        state = state_action[0]
        action = state_action[1]
        x = self.x1(state)
        x = self.x2(x)
        a = self.a1(action)
        h = concatenate([x, a], axis=-1)
        x = self.h3(h)
        q = self.q(x)
        return q

 

여기서 __init__(self) 는 객체가 생성될 때 호출되는 함수이고, call(self) 는 객체 변수를 실행할 때 호출되는 함수다.

call 함수에서는 입력을 한 개의 텐서 state_action 으로 받아서 앞 단은 state 로, 뒷 단은 action 으로 분리한다.

 

        state = state_action[0]
        action = state_action[1]

 

critic 객체는 다음과 같이 생성한다.

 

critic = Critic()

 

critic 을 빌드하려면 입력 데이터를 넣고 critic 을 호출하든가 아니면 입력 사이즈를 정해줘야 한다. 다음과 같이 layers.Input 을 이용하여 입력 사이즈를 정해주고 모델을 빌드한다. 여기서는 입력이 두 개이므로 두 개 모두 사이즈를 정해준다.

 

state_in = Input((3,))
action_in = Input((1,))
critic([state_in, action_in])

 

빌드가 됐는지 확인하기 위해서 모델 서머리(summary)를 해본다. 성공적으로 빌드된 것을 확인할 수 있다.

 

critic.summary()

 

Output:

Model: "critic"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
state_input (Dense)          multiple                  256       
_________________________________________________________________
dense (Dense)                multiple                  2080      
_________________________________________________________________
action_input (Dense)         multiple                  64        
_________________________________________________________________
dense_1 (Dense)              multiple                  1040      
_________________________________________________________________
action_value (Dense)         multiple                  17        
=================================================================
Total params: 3,457
Trainable params: 3,457
Non-trainable params: 0
_________________________________________________________________

 

 

 

'프로그래밍 > TensorFlow2' 카테고리의 다른 글

Mac M1 Pro 에 Tensorflow, Gym, Mujoco 설치하기  (0) 2022.09.19
TensorFlow2에서 1차, 2차 편미분 계산  (0) 2021.07.25
tf.reduce_sum 함수  (0) 2021.03.12
텐서와 변수 - 3  (0) 2021.02.11
텐서와 변수 - 2  (0) 2021.02.10

댓글