본문 바로가기
AI 딥러닝/Sequence

[PtrNet] Pointer Net 구조

by 깊은대학 2023. 9. 12.

조합 최적화(combinatorial optimization)는 개별 개체의 조합으로 이루어진 목적함수의 최대값(또는 최소값)을 구하는 문제이다. 대표적인 예로서는 TSP(traveling salesman problem, 순회외판원문제), Job-shop Scheduling, Knapsack Problem(배낭문제) 등이 있다.

참고로 세가지 문제를 간략히 설명하면 다음과 같다.

TSP 는 \(n\) 개의 서로 다른 도시의 좌표 \((x, y)\) 가 주어졌을 때, 각 도시를 한번씩 모두 방문하는 최단 경로를 찾는 문제다.

Job-shop Scheduling은 수행해야 하는 일련의 작업과 이러한 작업을 수행하는 데 필요한 도구 세트가 주어졌을 때, 모든 작업이 완료될 때까지 걸리는 총 시간을 최소화하기 위해서 어떤 작업을 어떤 도구를 사용하여 언제 수행해야 하는지에 대한 스케줄링 문제다.

Knapsack Problem 은 가격과 무게가 서로 다른 \(N\) 개의 아이템과 최대 \(W\) 의 무게를 담을 수 있는 가방이 주어졌을 때, 가격의 총액이 최대가 되도록 아이템을 선별하여 가방에 넣는 문제다.

 

 

조합 최적화 문제의 대부분은 NP-hard문제에 해당하기 때문에 다항식 시간 최적해를 구할 수 없다. 이 때문에 개별 개체의 수가 매우 큰 경우에는 최적해 대신에 빠르고 효율적으로 계산할 수 있는 근사해가 선호되며, 이에 관한 알고리즘이 많이 나와있다.

전통적인 알고리즘 외에 신경망을 이용하여 조합 최적화 문제를 풀기 위한 시도가 있는데, Oriol Vinyals 가 제안한 포인터넷(pointer network)이 대표적이다. 관련 내용은 2017년에 출간된 논문 'Pointer Networks, O. Vinyals, M. Fortunato, and N. Jaitly' 에 있다.

포인터넷은 어텐션이 포함된 seq2seq 모델 (https://pasus.tistory.com/291) 을 변형한 것이다. 한 시퀀스를 다른 시퀀스로 변환하는 대신 입력 시퀀스의 요소에 대한 일련의 포인터를 생성한다. seq2seq 모델에서 어텐션이 입력 요소에 가중치를 두었다면 포인터넷에서는 입력 시퀀스의 특정 위치를 선택(포인팅)하여 선택된 입력 시퀀스의 요소를 출력한다.

 

 

논문 'Pointer Networks' 에서는 세가지 다른 조합 최적화 문제에 대한 해를 학습하는 데 사용될 수 있음을 보여줬지만, 여기서는 그 중 TSP를 예로 들어서 포인터넷의 구조를 구체적으로 설명하고 Tensorflow2로 구현해 보고자 한다.

TSP는 수십 년 동안 연구된 조합 최적화 문제다. TSP는 이동 경로 계획, 제조업에서의 생산 계획, 물류에서의 적재 계획, 마이크로칩 설계, 유전학 등 응용 분야가 꽤 넓다. TSP에서 세일즈맨이 순회해야 할 도시의 수가 적다면 전수 조사를 통해 최적의 이동 경로를 계산할 수 있지만 그렇지 않고 일정 수준 이상으로 도시의 수가 커진다면 최적해를 계산하는 데는 시간과 비용이 많이 발생한다. 이 때문에 최적해 대신에 근사해를 빠르게 계산해 주는 알고리즘이 많이 제안되었다.

TSP 학습 및 테스트 데이터셋은 다음 사이트에 공개되어 있다.

 

https://drive.google.com/drive/folders/0B2fg8yPGn2TCMzBtS0o4Q2RJaEU?resourcekey=0-46fqXNrTmcUA4MfT6GLcIg

 

데이터셋에는 \(n\) 개 도시의 2D 좌표와 이 도시 집합에 대한 해당 TSP 최적 경로가 나와 있다. 도시의 개수 \(n\) 은 \(5\) 에서 \(20\) 까지이며 각 \(n\) 에 대해 100,000개의 경로가 있다. 데이터의 구조는 다음과 같다.

 

\[ x_1 \ y_1 \ x_2 \ y_2 \ x_3 \ y_3 \ x_4 \ y_4 \ x_5 \ y_5 \ \mbox{output} \ 1 \ 4 \ 2 \ 3 \ 5 \ 1 \]

 

먼저 도시의 2D 좌표 \((x_1 \ y_1 \ x_2 \ y_2 \ x_3 \ y_3 \ x_4 \ y_4 \ x_5 \ y_5 )\) 가 있고, 그 다음에는 "output", 그리고 도시에 대한 방문 순서 \((1 \ 4 \ 2 \ 3 \ 5 \ 1)\) 가 있다. 세일즈맨이 도시를 방문하는 순서는 도시 1부터 5까지의 자연수 시퀀스로 표현한다. 시퀀스의 각 번호는 도시의 위치에 해당한다. 위 예에서 세일즈맨은 1번 도시에서 시작하여 4, 2, 3, 5번 도시 순으로 이동한 후 다시 1번 도시로 돌아온다. 도시의 좌표는 \((x_i, y_i)\) 로서 4번째 도시라면 \((x_4, y_4)\) 가 된다.

 

 

이제 포인터넷의 구조를 구체적으로 살펴보자. 먼저 입력 시퀀스는 세일즈맨이 순회하여야 할 도시의 좌표 \((x_t, y_t ), \ \ t=1,...,n\) 이고 포인터넷의 인코더는 입력 시퀀스의 각 성분을 차례로 입력으로 받는다. 입력 벡터의 차원이 \(2\) , 입력 시퀀스의 길이가 \(n\) 이므로, 외부 입력은 \(\mathbf{x}_1, \ \mathbf{x}_2, \ \mathbf{x}_3, ..., , \mathbf{x}_n \in \mathbb{R}^2\) 로 한다. 은닉상태 (hidden state)의 차원은 \(128\) 로 하겠다. 즉 \(\mathbf{h}_t \in \mathbb{R}^{128}\). 다음은 인코더 LSTM모델을 나타낸 것이다.

 

 

인코더는 매 시간스텝 \(t\) 마다 은닉상태 \(\mathbf{h}_t\) 를 이용하여 키(key)를 생성한다. 어텐션이 포함된 seq2seq 모델 (https://pasus.tistory.com/291) 과는 달리 밸류(value)는 사용하지 않는다. 키(key)는 은닉상태의 선형 또는 비선형 함수로 생성할 수 있지만 여기서는 은닉상태를 그대로 키(key)로 사용한다.

 

\[ \mathbf{k}_t = \mathbf{h}_t \]

 

 

인코더 모델은 어텐션이 포함된 seq2seq 모델과 동일하며 이를 Tensorflow2로 구현하면 다음과 같다.

 

class Encoder(Model):

    def __init__(self, hidden_state_dim):
        super(Encoder, self).__init__()

        self.lstm = LSTM(units=hidden_state_dim, return_sequences=True, return_state=True)

    def call(self, enc_input):
        # enc_hiddens=(batch, seq_len, hidden_dim),
        # h_st=(batch, hidden_dim), c_st=(batch, hidden_dim)
        enc_hiddens, h_st, c_st = self.lstm(enc_input)
        enc_states = [h_st, c_st]
        return enc_hiddens, enc_states

 

return_sequences와 return_state를 모두 True로 설정하여 각 시간스텝마다 은닉상태를 출력하고 마지막 은닉 상태, 마지막 셀상태 값을 출력하도록 하였다.

 

 

디코더는 인코더와 마찬가지로 외부 입력으로 도시의 좌표를 받으므로 입력 벡터의 길이는 \(2\) 이다. 입력 시퀀스의 길이가 \(n\) 이므로, 외부 입력은 \(\mathbf{x}'_0, \mathbf{x}'_1, \mathbf{x}'_2, \mathbf{x}'_3, ... , \mathbf{x}'_n \in \mathbb{R}^2\) 로 한다.

디코더의 첫번째 외부 입력은 특수기호 [SOS] 이다. 출력 시퀀스의 시작을 나타내는데 사용되는 것으로 start of sequence를 의미한다. 여기서는 [SOS]를 벡터 \(\mathbf{x}'_0=[0, \ 0]\) 으로 정하겠다. 보통 출력 시퀀스의 끝을 표시하는 데에도 특수 기호인 (end of sequence)를 사용하는데 본 예제는 출력 시퀀스의 길이가 정해져 있으므로 를 사용하지 않는다.

디코더는 특정 시간스텝에서 필요한 정보를 인코더가 가지고 있는지 여부를 인코더의 각 스텝에게 물어볼 질의(query) 벡터 \(\mathbf{q}_l\) 를 생성한다. 질의벡터는 해당 시간스텝의 디코더 은닉상태를 그대로 사용한다.

 

 

디코더의 출력 시퀀스는 세일즈맨의 경로를 최소화 하는 도시를 순서대로 방문하는 것이므로 출력의 차원은 도시의 일련 번호를 나타내는 숫자 (1번 도시, 3번 도시 등)를 원핫(one-hot) 인코딩한 것으로 \(n\) 이다. 따라서 출력의 차원은 \(\mathbf{y}_0, \mathbf{y}_1, \mathbf{y}_2, ... , \mathbf{y}_n \in \mathbb{R}^n\) 이다. 은닉상태의 차원은 인코더와 동일하게 \(\mathbf{s}_l \in \mathbb{R}^{128}\) 로 한다.

 

 

디코더의 특정 시간스텝에서 필요한 정보를 인코더가 가지고 있는지 여부를 인코더의 각 스텝에 질의(query)하면 어텐션 매커니즘이 디코더의 질의(query)에 대해 인코더의 키(key)를 모두 비교하여 유사한 정도를 점수(score)로 계산하게 된다. 디코더 시간스텝 \(l\) 에서의 질의(query) 벡터와 인코더 시간스텝 \(t\) 에서의 키(key) 벡터의 유사도를 나타내는 어텐션 점수(attention score) \(e_{t,l}\) 은 바다나우(Bahdanau)가 제안한 concat 방식으로 계산한다.

 

\[ \mbox{concat: } \ e_{t,l}= V^T \tanh (W_1 \mathbf{h}_t+ W_2 \mathbf{s}_l ) \]

 

여기서 행렬 \(W_1, W_2, V\) 등은 모두 학습 대상이다.

 

 

인코더의 모든 시간스텝에서 계산한 어텐션 점수는 소프트맥스(softmax) 함수를 통해 인코더의 모든 시간스텝에 대한 일종의 확률분포로 만든다. 이 확률 값 \(\alpha_{t,l}\) 을 어텐션 가중값(weighting) 이라고 한다.

 

\[ \alpha_{t,l}= \mbox{softmax} (e_{t,l} )= \frac{ \exp⁡ (e_{t,l} )}{\sum_{t'} \exp (e_{t',l} ) } \]

 

포인터넷에서 어텐션 메커니즘의 출력은 입력 시퀀스의 길이가 동일한 소프트맥스(softmax) 분포다.

 

 

어텐션 모듈을 Tensorflow2로 구현하면 다음과 같다.

 

class Attention(Layer):

    def __init__(self, attn_units):
        super(Attention, self).__init__()
        self.W1 = Dense(attn_units, use_bias=False)  # key
        self.W2 = Dense(attn_units, use_bias=False)  # query
        self.V = Dense(1, use_bias=False)

    def call(self, enc_hiddens, dec_hidden):
        # enc_hiddens = (batch, seq_len_enc, hidden_state_dim)
        # dec_hidden = (batch, hidden_state_dim)
        # dec_hidden_exp = (batch, 1, hidden_state_dim)
        dec_hidden_exp = tf.expand_dims(dec_hidden, 1)
        keys = self.W1(enc_hiddens)  # (batch, seq_len, attn_units)
        query = self.W2(dec_hidden_exp)  # (batch, 1, attn_units)
        tanh_output = tf.nn.tanh(keys + query)
        score = self.V(tanh_output)  # (batch, seq_len_enc, 1)
        attention_weights = tf.nn.softmax(score, axis=1)  # (batch, seq_len_enc, 1)

        return attention_weights

 

이제 디코더 시간스텝 \(l\) 에서의 질의(query)에 대한 답신으로서 어텐션 가중값을 디코더로 보내준다. 어텐션 가중값을 수신한 디코더의 시간스텝 \(l\) 에서는 이것을 출력으로 내보낸다. 최종적으로 디코더의 출력에서 최대 확률을 갖는 요소를 \(\mbox{argmax}\) 로 선택하면 시간스텝 \(l\) 에서 세일즈맨이 방문해야 할 도시가 선택된다.

 

 

이런 식으로 디코더의 시간스텝 \(l=0\) 부터 \(l=n\) 까지 위와 같은 과정을 반복하여 출력 시퀀스를 계산하면 된다. 포인터넷의 디코더 모델을 Tensorflow2로 구현하면 다음과 같다.

 

class Decoder(Model):

    def __init__(self, attn_units, hidden_state_dim):
        super(Decoder, self).__init__()

        self.lstm = LSTM(units=hidden_state_dim, return_sequences=True, return_state=True)
        self.pointer = Attention(attn_units)

    def call(self, dec_input, enc_hiddens, enc_states):
        # enc_hiddens = [h1, h2, ..., hn] = (batch, seq_len_enc, hidden_state_dim)
        # dec_input=(batch, seq_len_dec, input_dim)
        # enc_states = [h_st, c_st]
        dec_hiddens, h_st, c_st = self.lstm(dec_input, initial_state=enc_states)
        seq_len_dec = tf.shape(dec_hiddens)[1]

        # Use tf.TensorArray to handle dynamic loops in TensorFlow
        dec_out_array = tf.TensorArray(dtype=tf.float32, size=seq_len_dec, dynamic_size=True)

        for t in tf.range(seq_len_dec):
            dec_hidden_t = dec_hiddens[:, t, :]
            out_t = self.pointer(enc_hiddens, dec_hidden_t) # (batch, seq_len_enc, 1)
            out_t = tf.squeeze(out_t, -1) # (batch, seq_len_enc)
            dec_out_array = dec_out_array.write(t, out_t)

        dec_out = dec_out_array.stack()  # Convert TensorArray to tensor
        dec_out = tf.transpose(dec_out, perm=[1, 0, 2])  # Transpose to get (batch, seq_len_dec, seq_len_enc)

        return dec_out, h_st, c_st

 

인코더와 어텐션, 그리고 디코더를 연결하여 포인터넷의 모델을 Tensorflow2로 구현하면 다음과 같다.

 

class PtrNet(Model):

    def __init__(self, attn_units, hidden_state_dim):
        super(PtrNet, self).__init__()

        self.encoder = Encoder(hidden_state_dim)
        self.decoder = Decoder(attn_units, hidden_state_dim)

    def call(self, enc_dec_input):
        enc_input = enc_dec_input[0]
        dec_input = enc_dec_input[1]
        enc_hiddens, enc_states = self.encoder(enc_input)
        dec_out, _, _ = self.decoder(dec_input, enc_hiddens, enc_states) # (batch, seq_len_dec, seq_len_enc)

        return dec_out

 

 

 

학습 단계에서는 출력 시간스텝 \(l\) 마다 입력 \(\mathbf{x}'_l\) 을 공급하는 'teacher forcing' 이라는 방법을 사용한다. 반면 실행 단계에서는 이전 시간스텝 \(l-1\) 에서 생성한 예측값 \(\hat{\mathbf{y}}_{l-1}\) 을 입력 \(\mathbf{x}'_l\) 에 공급한다.

 

 

입출력 데이터 시퀀스는 다음과 같이 준비한다. 5개 도시로 구성된 TSP를 예로 들어서 설명하겠다. 입력 데이터가 다음과 같다고 하자.

 

0.597640340831 0.811472963882 0.378392540696 0.188165799059 0.442849630371 0.291815169363 0.363774192611 0.596244647907 0.955211930984 0.947031856076 output 1 4 2 3 5 1

 

그러면 학습용으로 다음 세가지 시퀀스를 준비해야 한다.

 

\[ \begin{align} \mathbf{x}_{enc} = & [ [0.59764034 \ 0.81147295] \\ & \ [0.37839255 \ 0.1881658 ] \\ & \ [0.44284964 \ 0.29181516] \\ & \ [0.36377418 \ 0.59624463] \\ & \ [0.95521194 \ 0.94703186] ] \\ \\ \mathbf{x}'_{dec} = & [ [0. \ 0. ] \\ & \ [0.59764034 \ 0.81147295] \\ & \ [0.36377418 \ 0.59624463] \\ & \ [0.37839255 \ 0.1881658 ] \\ & \ [0.44284964 0.29181516] \\ & \ [0.95521194 0.94703186] ] \\ \\ \mathbf{y} = & [ [1. \ 0. \ 0. \ 0. \ 0.] \\ & \ [0. \ 0. \ 0. \ 1. \ 0.] \\ & \ [0. \ 1. \ 0. \ 0. \ 0.] \\ & \ [0. \ 0. \ 1. \ 0. \ 0.] \\ & \ [0. \ 0. \ 0. \ 0. \ 1.] \\ & \ [1. \ 0. \ 0. \ 0. \ 0.] ] \\ \\ \mathbf{y}_{raw} = & [1, \ 4, \ 2, \ 3, \ 5, \ 1] \end{align} \]

 

\(\mathbf{x}_{enc}\) 는 인코더의 입력 시퀀스이며 5개 도시의 좌표를 나타낸다. \(\mathbf{x}'_{dec}\) 는 디코더의 입력 시퀀스로서 세일즈맨이 최적으로 순회해야 할 도시의 좌표다. 맨 앞의 성분 \([0, \ 0]\) 은 [SOS]를 표시한 것이다. 순회할 도시의 순서는 \(\mathbf{y}_{raw}\) 에 있다. \(\mathbf{y}\) 는 \(\mathbf{y}_{raw}\) 를 원핫(one-hot) 인코딩한 것으로서 디코더가 생성해야 하는 참값으로 사용한다.

5개 도시로 구성된 TSP에 포인터넷 모델을 적용하여 데이터셋 100,000개, 이폭 100으로 학습한 결과는 다음과 같다. 학습 시간이 많이 걸리는 관계로 5개 도시로 구성된 TSP 데이터를 사용했지만 10개 도시 또는 50개 도시로 구성된 TSP에도 코드를 그대로 적용할 수 있다.

 

 

 

10개의 테스트 데이터셋으로 성능을 테스트해 보았다. 세이즈맨이 순회한 도시의 총 거리를 비교하는 것이 맞겠으나 여기서는 편의상 정답으로 제시된 최적의 도시 방문 순서와 비교하였다.

 

 

논문 'Pointer Networks' 에서는 빔 검색(beam search)을 사용하여 정확도를 높였으나, 여기서는 사용하지 않았다.

논문에서 제안한 방법의 단점은 지도학습을 사용했다는데 있을 것이다. 이러한 방식으로 학습하면 정답 데이터가 필요하기 때문에 NP-hard 문제에는 바람직하지 않다. 이런 단점을 극복하기 위해서 강화학습(RL) 기법을 포인터넷에 적용한 논문이 발표되었는데 이에 대해서는 나중에 살펴보도록 하자.

 

 

다음은 Tensorflow2 전체 코드다.

 

gen_data.py

 

# generating TSP data for pointer net
# coded by st.watermelon

import numpy as np
import tensorflow as tf

def get_train_data(tsp_len, data_size=100000):
    if tsp_len == 10:
        X_encoder = np.loadtxt('data/tsp_10_train_exact.txt', usecols=range(0,20))
        y_train = np.loadtxt('data/tsp_10_train_exact.txt', usecols=range(21,32))
    elif tsp_len == 5:
        X_encoder = np.loadtxt('data/tsp5.txt', usecols=range(0,10))
        y_train = np.loadtxt('data/tsp5.txt', usecols=range(11,17))

    X_encoder = X_encoder[0:data_size]
    y_train = y_train[0:data_size]

    X_encoder = X_encoder.reshape((X_encoder.shape[0], tsp_len, 2)) # (train_size, seq_len, input_dim)
    y_train = y_train - 1
    y_train = y_train.astype(int) # (train_size, seq_len+1)

    X_decoder = np.zeros((X_encoder.shape[0], tsp_len+1, 2))
    yy_train = list()
    for kk in range(y_train.shape[0]):
        yy_train.append(one_hot_encode(y_train[kk], tsp_len))
        X_decoder[kk, 1:, :] = X_encoder[kk, y_train[kk, 0:-1], :]

    y = np.array(yy_train) # (train_size, seq_len_enc+1, seq_len_enc),
                           # seq_len_dec=seq_len_enc+1, tsp_len=seq_len_enc

    # convert the datasets to float32
    X_encoder = tf.cast(X_encoder, tf.float32)
    X_decoder = tf.cast(X_decoder, tf.float32)
    y = tf.cast(y, tf.float32)

    return X_encoder, X_decoder, y


def get_test_data(tsp_len):
    if tsp_len == 10:
        X_encoder = np.loadtxt('data/tsp_10_test_exact.txt', usecols=range(0, 20))
        y_test = np.loadtxt('data/tsp_10_test_exact.txt', usecols=range(21, 32))
    elif tsp_len == 5:
        X_encoder = np.loadtxt('data/tsp5_test.txt', usecols=range(0,10))
        y_test = np.loadtxt('data/tsp5_test.txt', usecols=range(11,17))

    X_encoder = X_encoder.reshape((X_encoder.shape[0], tsp_len, 2))
    y_test = y_test - 1
    y_test = y_test.astype(int)

    X_decoder = np.zeros((X_encoder.shape[0], tsp_len+1, 2))
    yy_test = list()
    for kk in range(y_test.shape[0]):
        yy_test.append(one_hot_encode(y_test[kk], tsp_len))
        X_decoder[kk, 1:, :] = X_encoder[kk, y_test[kk, 0:-1], :]

    y = np.array(yy_test) # (None, seq_len_enc+1, seq_len_enc)

    # convert the datasets to float32
    X_encoder = tf.cast(X_encoder, tf.float32)
    X_decoder = tf.cast(X_decoder, tf.float32)
    y = tf.cast(y, tf.float32)

    return X_encoder, X_decoder, y


# one hot encode sequence
def one_hot_encode(sequence, n_tokens):
    encoding = list()
    for value in sequence:
        vector = [0 for _ in range(n_tokens)]
        vector[value] = 1
        encoding.append(vector)
    return np.array(encoding)


# decode a one hot encoded string
def one_hot_decode(encoded_seq):
    return [np.argmax(vector) for vector in encoded_seq]

 

ptr_model.py

 

# Pointer Net
# coded by st.watermelon

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, LSTM, Dense, Input


class Encoder(Model):

    def __init__(self, hidden_state_dim):
        super(Encoder, self).__init__()

        self.lstm = LSTM(units=hidden_state_dim, return_sequences=True, return_state=True)

    def call(self, enc_input):
        # enc_hiddens=(batch, seq_len, hidden_dim),
        # h_st=(batch, hidden_dim), c_st=(batch, hidden_dim)
        enc_hiddens, h_st, c_st = self.lstm(enc_input)
        enc_states = [h_st, c_st]
        return enc_hiddens, enc_states


class Attention(Layer):

    def __init__(self, attn_units):
        super(Attention, self).__init__()
        self.W1 = Dense(attn_units, use_bias=False)  # key
        self.W2 = Dense(attn_units, use_bias=False)  # query
        self.V = Dense(1, use_bias=False)

    def call(self, enc_hiddens, dec_hidden):
        # enc_hiddens = (batch, seq_len_enc, hidden_state_dim)
        # dec_hidden = (batch, hidden_state_dim)
        # dec_hidden_exp = (batch, 1, hidden_state_dim)
        dec_hidden_exp = tf.expand_dims(dec_hidden, 1)
        keys = self.W1(enc_hiddens)  # (batch, seq_len, attn_units)
        query = self.W2(dec_hidden_exp)  # (batch, 1, attn_units)
        tanh_output = tf.nn.tanh(keys + query)
        score = self.V(tanh_output)  # (batch, seq_len_enc, 1)
        attention_weights = tf.nn.softmax(score, axis=1)  # (batch, seq_len_enc, 1)

        return attention_weights


class Decoder(Model):

    def __init__(self, attn_units, hidden_state_dim):
        super(Decoder, self).__init__()

        self.lstm = LSTM(units=hidden_state_dim, return_sequences=True, return_state=True)
        self.pointer = Attention(attn_units)

    def call(self, dec_input, enc_hiddens, enc_states):
        # enc_hiddens = [h1, h2, ..., hn] = (batch, seq_len_enc, hidden_state_dim)
        # dec_input=(batch, seq_len_dec, input_dim)
        # enc_states = [h_st, c_st]
        dec_hiddens, h_st, c_st = self.lstm(dec_input, initial_state=enc_states)
        seq_len_dec = tf.shape(dec_hiddens)[1]

        # Use tf.TensorArray to handle dynamic loops in TensorFlow
        dec_out_array = tf.TensorArray(dtype=tf.float32, size=seq_len_dec, dynamic_size=True)

        for t in tf.range(seq_len_dec):
            dec_hidden_t = dec_hiddens[:, t, :]
            out_t = self.pointer(enc_hiddens, dec_hidden_t) # (batch, seq_len_enc, 1)
            out_t = tf.squeeze(out_t, -1) # (batch, seq_len_enc)
            dec_out_array = dec_out_array.write(t, out_t)

        dec_out = dec_out_array.stack()  # Convert TensorArray to tensor
        dec_out = tf.transpose(dec_out, perm=[1, 0, 2])  # Transpose to get (batch, seq_len_dec, seq_len_enc)

        return dec_out, h_st, c_st


class PtrNet(Model):

    def __init__(self, attn_units, hidden_state_dim):
        super(PtrNet, self).__init__()

        self.encoder = Encoder(hidden_state_dim)
        self.decoder = Decoder(attn_units, hidden_state_dim)

    def call(self, enc_dec_input):
        enc_input = enc_dec_input[0]
        dec_input = enc_dec_input[1]
        enc_hiddens, enc_states = self.encoder(enc_input)
        dec_out, _, _ = self.decoder(dec_input, enc_hiddens, enc_states) # (batch, seq_len_dec, seq_len_enc)

        return dec_out

 

ptr_train.py

 

# Pointer Net train
# coded by st.watermelon

from ptr_model import PtrNet
from gen_data import *
import os
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense


""" setup pointer net agent """
class PtrAgent(Model):

    def __init__(self, seq_len_enc):
        super(PtrAgent, self).__init__()

        # hyperparameters
        self.TRAIN_SIZE = 100000 # 100000
        self.ATTN_UNITS = 128
        self.SEQ_LEN_ENC = seq_len_enc
        self.SEQ_LEN_DEC = seq_len_enc+1
        self.INPUT_DIM = 2
        self.BATCH_SIZE = 128
        self.HIDDEN_STATE_DIM = 128
        self.LEARNING_RATE = 5e-4
        # create pointer net
        self.ptr = PtrNet(self.ATTN_UNITS, self.HIDDEN_STATE_DIM)

        # compile
        self.ptr.compile(
            loss=tf.keras.losses.categorical_crossentropy,
            optimizer=tf.keras.optimizers.Adam(self.LEARNING_RATE),
            metrics=['accuracy']
        )

        encoder_input = Input((self.SEQ_LEN_ENC, self.INPUT_DIM))
        decoder_input = Input((self.SEQ_LEN_DEC, self.INPUT_DIM))

        self.ptr([encoder_input, decoder_input])
        self.ptr.encoder.summary()
        self.ptr.decoder.summary()
        #self.ptr.summary()

    def train(self, epochs):
        X_encoder, X_decoder, y = get_train_data(self.SEQ_LEN_ENC, self.TRAIN_SIZE)

        history = self.ptr.fit(
            [X_encoder, X_decoder],
            y,
            batch_size=self.BATCH_SIZE,
            epochs=epochs,
            validation_split=0.2,
            verbose=2
        )

        # save
        self.ptr.save_weights("./save_weights/ptr.h5")


        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.plot(history.history['loss'], 'b-', label='loss')
        plt.plot(history.history['val_loss'], 'r-', label='val_loss')

        plt.xlabel('epoch')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(history.history['accuracy'], 'g-', label='accuracy')
        plt.plot(history.history['val_accuracy'], 'k-', label='val_accuracy')
        plt.xlabel('epoch')
        plt.legend()

        plt.show()


    def pointed_seq(self, input_seq, true_pointing):
        # input_seq=(batch, seq_len_enc, input_dim)
        # true_pointing=(batch, seq_len_dec, seq_len_eec)
        if os.path.exists('./save_weights/ptr.h5'):
            self.ptr.load_weights("./save_weights/ptr.h5")
        else:
            return 0

        print('Exact  \t\t\t   Predicted \t\t T/F')
        correct = 0
        n_seq = input_seq.shape[0]
        for seq_idx in range(0, n_seq):
            # take one sequence at a time
            decoded_seq = self.single_pointed_seq(
                tf.reshape(input_seq[seq_idx], (1, self.SEQ_LEN_ENC, self.INPUT_DIM)))

            if (one_hot_decode(true_pointing[seq_idx]) == decoded_seq):
                correct += 1
            print(np.array(one_hot_decode(true_pointing[seq_idx]))+1, '\t',
                  np.array(decoded_seq)+1,
                  '\t', one_hot_decode(true_pointing[seq_idx]) == decoded_seq)
        print('Accuracy: ', correct / n_seq)


    def single_pointed_seq(self, single_input_seq):
        # single_input_seq=(1, seq_len_enc, input_dim)
        # encode the input sequence as context vector
        enc_hiddens, enc_states = self.ptr.encoder(single_input_seq)

        # make <sos> (the first decoder input).
        dec_input_seq = np.zeros((1, 1, self.INPUT_DIM)) # (1, seq_len, input_dim)

        # looping for a batch of sequences
        stop_condition = False
        decoded_seq = list()

        while not stop_condition:
            # decode the input
            dec_out, h_st, c_st = self.ptr.decoder(dec_input_seq, enc_hiddens, enc_states)
            # dec_out = (1, 1, seq_len_enc)
            # predict the decoder output using argmax
            sampled_digit = np.argmax(dec_out[0, -1, :])
            # add the predicted output to output sequence
            decoded_seq.append(sampled_digit)

            # stop condition: hit max length
            if (len(decoded_seq) == self.SEQ_LEN_DEC):
                stop_condition = True

            # update the decoder input sequence for the next LSTM cell
            dec_input_seq[0, 0, :] = single_input_seq[0, sampled_digit, :]

            # update context value
            enc_states = [h_st, c_st]

        return decoded_seq


if __name__ == "__main__":
    seq_len = 5
    agent = PtrAgent(seq_len)
    agent.train(100)

 

ptr_load_play.py

 

# Pointer net
# coded by st.watermelon

from ptr_train import PtrAgent
from gen_data import *

def main():
    seq_len = 5
    agent = PtrAgent(seq_len)
    input_seq, _, true_pointing = get_test_data(seq_len)
    # choose 10 sequences
    start = 1000# 2000 # 0 #1000  # 2000 -best
    agent.pointed_seq(input_seq[0+start:10+start, :], true_pointing[0+start:10+start, :])


if __name__ == "__main__":
    main()

 

댓글