본문 바로가기
프로그래밍/TensorFlow2

TensorFlow2에서 1차, 2차 편미분 계산

by 세인트 워터멜론 2021. 7. 25.

자동으로 미분을 계산하려면 Tensorflow는 순방향 패스 과정에서 어떤 순서로 어떤 연산을 했는지 기억해야 한다. 그런 다음, 역방향 패스 중에 이 연산 목록을 역순으로 이동해 가며 미분(derivative)을 계산한다.

 

 

Tensorflow는 자동으로 미분을 계산하기 위해서 tf.GradientTape API를 제공한다. Tensorflow는 tf.GradientTape의 컨텍스트 내에서 실행 된 관련 연산을 '테이프'에 '기록'한다. 그런 다음 해당 테이프를 거꾸로 돌려서 기록된 연산의 미분을 계산한다. tf.Variable로 생성된 변수에 대해서만 미분할 수 있는데 상수인 경우에도 watch() 메쏘드를 이용하면 미분을 계산할 수 있다.

예를 들어서 \(y(x,t)=x^3+2t\) 를 \(x\) 에 대해서 편미분하면, \(\frac{\partial y}{\partial x}=3x^2\) 이다. 한번 더 편미분하면 \( \frac{\partial^2 y}{\partial x^2} =6x\) 이다. \(x=1\) 에서 1차 편미분값과 2차 편미분값을 구하면 각각 \(3\) 과 \(6\) 이다. 한편, \(t\) 에 대해서 편미분하면, \( \frac{\partial y}{\partial t}=2\) 이다.

 

 

Tensorflow2로 미분을 구현해 보자. 먼저 \(1.0\) 으로 할당된 상수 \(x\) 와 \(t\) 를 생성한다.

 

x = tf.constant(1.0)
t = tf.constant(1.0)

 

다음으로 tf.GradientTape 안에 미분 계산에 필요한 모든 연산을 '기록'한다.

 

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    tape.watch(t)
    y = x*x*x + 2*t
    y_x = tape.gradient(y, x)

 

tape.watch(x) 와 tape.watch(t)는 \(x\) 와 \(t\) 가 상수이기 때문에 필요하다. 기본적으로 tape.gradient() 메쏘드가 호출되면 GradientTape에 포함된 리소스가 해제된다. 따라서 2번 이상의 미분을 계산하려면 persistent=True 를 설정해야 한다.

미분 연산을 테이프 안에서 수행하면 CPU와 메모리 자원을 많이 차지하기 때문에 보통 테이프 밖에서 미분 계산을 수행한다. 하지만 여기서는 테이프 안에 1차 편미분 \(\frac{\partial y}{\partial x}\) 도 기록했는데 이것은 2차 편미분 \(\frac{\partial^2 y}{\partial x^2}\) 계산에 필요하기 때문이다.

이제 테이프의 기록을 이용하여 테이프 밖에서 2차 편미분 \(\frac{\partial^2 y}{\partial x^2}\) 와 시간에 대한 편미분 \(\frac{\partial y}{\partial t}\) 를 계산한다. 그리고 미분 계산이 끝나면 테이프를 삭제한다.

 

y_xx = tape.gradient(y_x, x)      
y_t = tape.gradient(y, t)      
del tape

 

전체 코드는 다음과 같다.

 

# higher order derivatives test
# coded by St.Watermelon

import tensorflow as tf

x = tf.constant(1.0)
t = tf.constant(1.0)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    tape.watch(t)
    y = x*x*x + 2*t
    y_x = tape.gradient(y, x)     # y_x = 3*x*x = 3
y_xx = tape.gradient(y_x, x)      # y_xx = 6*x = 6
y_t = tape.gradient(y, t)         # y_t=2
del tape

print("y_x=", y_x)
print("y_xx=", y_xx)
print("y_t=", y_t)

 

 

 

댓글