Model Subclassing API를 사용하여 입력을 여러 개 갖는 즉, 멀티 입력 신경망 모델을 어떻게 구현하고 빌드(build)할 수 있을까.
강화학습의 DDPG알고리즘에서는 행동가치 함수(actor-value function)를 크리틱(critic) 신경망으로 구현한다. 크리틱 신경망은 입력으로 상태(state)와 행동(action)등 두 개를 받는데, 이를 Model Subclassing API를 이용해서 구현해 보자. 구현해야 할 신경망 구조는 다음 그림과 같다. 상태변수를 첫번째 은닉층에서 처리한 후 두번째 은닉층에서 행동과 병합하는 구조다.
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, concatenate
class Critic(Model):
def __init__(self):
super(Critic, self).__init__()
self.x1 = Dense(64, activation='relu', name='state_input')
self.x2 = Dense(32, activation='linear')
self.a1 = Dense(32, activation='linear', name='action_input')
self.h3 = Dense(16, activation='relu')
self.q = Dense(1, activation='linear', name='action_value')
def call(self, state_action):
state = state_action[0]
action = state_action[1]
x = self.x1(state)
x = self.x2(x)
a = self.a1(action)
h = concatenate([x, a], axis=-1)
x = self.h3(h)
q = self.q(x)
return q
여기서 __init__(self) 는 객체가 생성될 때 호출되는 함수이고, call(self) 는 객체 변수를 실행할 때 호출되는 함수다.
call 함수에서는 입력을 한 개의 텐서 state_action 으로 받아서 앞 단은 state 로, 뒷 단은 action 으로 분리한다.
state = state_action[0]
action = state_action[1]
critic 객체는 다음과 같이 생성한다.
critic = Critic()
critic 을 빌드하려면 입력 데이터를 넣고 critic 을 호출하든가 아니면 입력 사이즈를 정해줘야 한다. 다음과 같이 layers.Input 을 이용하여 입력 사이즈를 정해주고 모델을 빌드한다. 여기서는 입력이 두 개이므로 두 개 모두 사이즈를 정해준다.
state_in = Input((3,))
action_in = Input((1,))
critic([state_in, action_in])
빌드가 됐는지 확인하기 위해서 모델 서머리(summary)를 해본다. 성공적으로 빌드된 것을 확인할 수 있다.
critic.summary()
Output:
Model: "critic"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
state_input (Dense) multiple 256
_________________________________________________________________
dense (Dense) multiple 2080
_________________________________________________________________
action_input (Dense) multiple 64
_________________________________________________________________
dense_1 (Dense) multiple 1040
_________________________________________________________________
action_value (Dense) multiple 17
=================================================================
Total params: 3,457
Trainable params: 3,457
Non-trainable params: 0
_________________________________________________________________
'프로그래밍 > TensorFlow2' 카테고리의 다른 글
Mac M1 Pro 에 Tensorflow, Gym, Mujoco 설치하기 (0) | 2022.09.19 |
---|---|
TensorFlow2에서 1차, 2차 편미분 계산 (0) | 2021.07.25 |
tf.reduce_sum 함수 (0) | 2021.03.12 |
텐서와 변수 - 3 (0) | 2021.02.11 |
텐서와 변수 - 2 (0) | 2021.02.10 |
댓글