본문 바로가기
프로그래밍/TensorFlow2

TensorFlow2에서 1차, 2차 편미분 계산

by 깊은대학 2021. 7. 25.

자동으로 미분을 계산하려면 Tensorflow는 순방향 패스 과정에서 어떤 순서로 어떤 연산을 했는지 기억해야 한다. 그런 다음, 역방향 패스 중에 이 연산 목록을 역순으로 이동해 가며 미분(derivative)을 계산한다.

 

 

Tensorflow는 자동으로 미분을 계산하기 위해서 tf.GradientTape API를 제공한다. Tensorflow는 tf.GradientTape의 컨텍스트 내에서 실행 된 관련 연산을 '테이프'에 '기록'한다. 그런 다음 해당 테이프를 거꾸로 돌려서 기록된 연산의 미분을 계산한다. tf.Variable로 생성된 변수에 대해서만 미분할 수 있는데 상수인 경우에도 watch() 메쏘드를 이용하면 미분을 계산할 수 있다.

예를 들어서 \(y(x,t)=x^3+2t\) 를 \(x\) 에 대해서 편미분하면, \(\frac{\partial y}{\partial x}=3x^2\) 이다. 한번 더 편미분하면 \( \frac{\partial^2 y}{\partial x^2} =6x\) 이다. \(x=1\) 에서 1차 편미분값과 2차 편미분값을 구하면 각각 \(3\) 과 \(6\) 이다. 한편, \(t\) 에 대해서 편미분하면, \( \frac{\partial y}{\partial t}=2\) 이다.

 

 

Tensorflow2로 미분을 구현해 보자. 먼저 \(1.0\) 으로 할당된 상수 \(x\) 와 \(t\) 를 생성한다.

 

x = tf.constant(1.0)
t = tf.constant(1.0)

 

다음으로 tf.GradientTape 안에 미분 계산에 필요한 모든 연산을 '기록'한다.

 

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    tape.watch(t)
    y = x*x*x + 2*t
    y_x = tape.gradient(y, x)

 

tape.watch(x) 와 tape.watch(t)는 \(x\) 와 \(t\) 가 상수이기 때문에 필요하다. 기본적으로 tape.gradient() 메쏘드가 호출되면 GradientTape에 포함된 리소스가 해제된다. 따라서 2번 이상의 미분을 계산하려면 persistent=True 를 설정해야 한다.

미분 연산을 테이프 안에서 수행하면 CPU와 메모리 자원을 많이 차지하기 때문에 보통 테이프 밖에서 미분 계산을 수행한다. 하지만 여기서는 테이프 안에 1차 편미분 \(\frac{\partial y}{\partial x}\) 도 기록했는데 이것은 2차 편미분 \(\frac{\partial^2 y}{\partial x^2}\) 계산에 필요하기 때문이다.

이제 테이프의 기록을 이용하여 테이프 밖에서 2차 편미분 \(\frac{\partial^2 y}{\partial x^2}\) 와 시간에 대한 편미분 \(\frac{\partial y}{\partial t}\) 를 계산한다. 그리고 미분 계산이 끝나면 테이프를 삭제한다.

 

y_xx = tape.gradient(y_x, x)      
y_t = tape.gradient(y, t)      
del tape

 

전체 코드는 다음과 같다.

 

# higher order derivatives test
# coded by St.Watermelon

import tensorflow as tf

x = tf.constant(1.0)
t = tf.constant(1.0)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    tape.watch(t)
    y = x*x*x + 2*t
    y_x = tape.gradient(y, x)     # y_x = 3*x*x = 3
y_xx = tape.gradient(y_x, x)      # y_xx = 6*x = 6
y_t = tape.gradient(y, t)         # y_t=2
del tape

print("y_x=", y_x)
print("y_xx=", y_xx)
print("y_t=", y_t)

 

 

 

댓글