앞서 살펴본 바와 같이(https://pasus.tistory.com/263) VAE의 손실함수는 복원손실 항과 정규화 손실 항으로 구성된다.
식 (1)의 VAE 손실함수에서 첫번째 항인 복원손실은 복원(reconstruction)된 데이터가 원본 데이터와 얼마나 유사한지를 나타내는 오차다. 원본 데이터의 확률분포에 따라서 원본 데이터와 복원된 데이터 간의 교차 엔트로피나 L2 놈(norm)을 사용하여 복원손실을 계산할 수 있다.
두번째 항인 정규화(regularization) 손실은 고차원 입력 데이터를 저차원 잠재변수 공간(latent space)으로 인코딩할 때 잠재변수의 확률분포가 사전에(prior) 규정한 확률분포
AE(오토인코더)는 복원손실만 최소화하지만 VAE는 KL발산 항도 함께 최소화하여 잠재변수의 공간에도 제약을 가한다. 즉 KL발산 항은 잠재공간의 구조에 제약을 가하는 정규화(regularization) 기능을 갖는다.

잠재변수의 확률분포를 정규 가우시안 분포
AE(오토인코더)는 기존의 선형 차원축소(dimension reduction) 방법인 POD(Proper Orthogonal Decomposition)를 비선형 방법으로 일반화시킨 것으로 해석되며 최근 수많은 연구에서 고도의 비선형 문제에 대한 차원축소 모델을 구축하는데 사용되고 있다. 하지만 POD가 직교 기반의 선형공간에서 잠재변수 간의 독립성을 보장하는 것과는 달리 AE의 경우에는 복원손실을 최소화하는 것에만 목표를 두고 잠재공간의 구조를 고려하지 않기 때문에 잠재변수 간의 독립성을 보장할 수 없다.
서로 독립인 또는 상관관계가 없는 잠재변수가 이상적인 이유는 각 잠재변수가 가지고 있는 물리적 의미나 데이터의 특성을 해석하기가 용이하기 때문이다.
이런 관점에서 볼 때 VAE는 KL발산 항을 통하여 차원축소 방법으로서의 AE의 단점을 개선한 것으로 해석할 수 있다. 구글 딥마인드에서 발표한 논문 '
식 (2)는 VAE의 변형 알고리즘인
관련 논문에 의하면 가중치
다음은 MNIST 데이터셋을 이용한

다음은 복원 성능을 본 것이다.

다음은 잠재공간의 일정 영역을 스캔한 결과다.

다음은 Tensorflow 2로 작성한
vae_model.py
# beta-VAE model for MNIST
# coded by st.watermelon
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Conv2DTranspose, Reshape
from tensorflow.keras.models import Model
""" beta-VAE encoder """
class Encoder(Model):
def __init__(self, latent_dim):
super(Encoder, self).__init__()
self.h1 = Conv2D(32, 3, strides=2, padding='same', activation='relu')
self.h2 = Conv2D(64, 3, strides=2, padding='same', activation='relu')
self.h3 = Dense(16, activation='relu')
self.flatten = tf.keras.layers.Flatten()
self.mu = Dense(latent_dim)
self.log_var = Dense(latent_dim)
def call(self, x):
x = self.h1(x) # input size (28,28,1)
x = self.h2(x) # (7,7,64)
x = self.flatten(x) # (latent_dim*7*7*64)
x = self.h3(x)
mu = self.mu(x)
log_var = self.log_var(x)
return mu, log_var
""" decoder """
class Decoder(Model):
def __init__(self):
super(Decoder, self).__init__()
self.fc = Dense(7*7*64, activation='relu')
self.reshape = Reshape((7, 7, 64))
self.h1 = Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')
self.h2 = Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu')
self.output_layer = Conv2DTranspose(1, 3, padding='same', activation='sigmoid')
def call(self, z):
x = self.fc(z)
x = self.reshape(x)
x = self.h1(x)
x = self.h2(x)
x = self.output_layer(x)
return x
""" beta-VAE """
class betaVAE(Model):
def __init__(self, latent_dim):
super(betaVAE, self).__init__()
self.latent_dim = latent_dim
self.encoder = Encoder(latent_dim)
self.decoder = Decoder()
def call(self, x):
mu, log_var = self.encoder(x)
encoded = self.reparameterize(mu, log_var)
decoded = self.decoder(encoded)
return decoded, encoded, mu, log_var
def reparameterize(self, mu, log_var):
epsilon = tf.random.normal(shape=tf.shape(mu))
return epsilon * tf.exp(0.5*log_var) + mu
if __name__ == "__main__":
a = betaVAE(2)
a.build(input_shape=(None, 28,28,1))
a.encoder.summary()
a.decoder.summary()
vae_train.py
# beta-VAE model for MNIST
# coded by st.watermelon
import os
import tensorflow as tf
from tensorflow.keras.models import Model
from vae_model import betaVAE
import matplotlib.pyplot as plt
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
""" loading training and test data """
def get_data():
mnist = tf.keras.datasets.mnist
(x_train, _), (x_test, _) = mnist.load_data()
# adjusting to 0 ~ 1.0
x_train = x_train / 255.0
x_test = x_test / 255.0
# reshaping
x_train = x_train.reshape(-1,28,28,1)
x_test = x_test.reshape(-1,28,28,1)
return x_train, x_test
""" setup beta-VAE agent """
class betaVAEagent(Model):
def __init__(self, latent_dim, beta):
super(betaVAEagent, self).__init__()
self.latent_dim = latent_dim
self.beta = beta
self.beta_vae = betaVAE(latent_dim)
self.beta_vae.build(input_shape=(None, 28, 28, 1)) # (batch, H, W, C)
lr = 1e-3
self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
def loss_function(self, x):
x_hat, z, mu, log_var = self.beta_vae(x)
recons_loss = tf.reduce_mean(
tf.reduce_sum(
#tf.keras.losses.binary_crossentropy(x, x_hat), axis=(1, 2)
tf.keras.losses.MSE(x, x_hat), axis=(1, 2)
)
)
kl_loss = -0.5 * (1 + log_var - tf.square(mu) - tf.exp(log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
loss = recons_loss + self.beta * kl_loss
return loss, recons_loss, kl_loss
def train_step(self, x):
with tf.GradientTape() as tape:
loss, recons_loss, kl_loss = self.loss_function(x)
gradients = tape.gradient(loss, self.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
return {'total_loss': loss, 'recons_loss': recons_loss, 'kl_loss': kl_loss}
def save_weights(self):
self.beta_vae.save_weights("./save_weights/vae.h5")
def train(self, train_dataset, epochs):
train_loss_history = []
kl_loss_history = []
recons_loss_history = []
for epoch in range(1, epochs + 1):
print(f"Epoch {epoch}/{epochs}")
epoch_train_loss = []
epoch_kl_loss = []
epoch_recons_loss = []
for step, batch in enumerate(train_dataset):
loss_dict = self.train_step(batch)
epoch_train_loss.append(loss_dict['total_loss'].numpy())
epoch_recons_loss.append(loss_dict['recons_loss'].numpy())
epoch_kl_loss.append(loss_dict['kl_loss'].numpy())
if step % 100 == 0:
loss = loss_dict['total_loss']
recons_loss = loss_dict['recons_loss']
kl_loss = loss_dict['kl_loss']
print(f"Step {step}, Total Loss: {loss}, Recon Loss: {recons_loss}, KL Loss: {kl_loss}")
train_loss_history.append(tf.reduce_mean(epoch_train_loss))
kl_loss_history.append(tf.reduce_mean(epoch_kl_loss))
recons_loss_history.append(tf.reduce_mean(epoch_recons_loss))
# Save encoder and decoder weights after each epoch
self.save_weights()
# plot loss history
plt.plot(train_loss_history, label='Training Loss')
plt.plot(kl_loss_history, label='KL Loss')
plt.plot(recons_loss_history, label='Recons Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
def main():
x_train, _ = get_data()
train_size = x_train.shape[0]
batch_size = 128
train_dataset = (tf.data.Dataset.from_tensor_slices(x_train)
.shuffle(train_size).batch(batch_size))
latent_dim = 2
beta = 1
agent = betaVAEagent(latent_dim, beta)
epochs = 30
agent.train(train_dataset, epochs)
if __name__ == "__main__":
main()
'AI 딥러닝 > DLA' 카테고리의 다른 글
[LLM] Ollama 모델에서 OpenAI Chat API 사용하기 (0) | 2025.02.15 |
---|---|
[LLM] Ollama Web-UI 설치 (0) | 2024.02.25 |
[VAE] 변이형 오토인코더(Variational Autoencoder) (0) | 2023.04.30 |
[U-Net] 망막 혈관 세그멘테이션 (Retinal Vessel Segmentation) (0) | 2022.05.11 |
[U-Net] U-Net 구조 (0) | 2022.05.11 |
댓글