자동으로 미분을 계산하려면 Tensorflow는 순방향 패스 과정에서 어떤 순서로 어떤 연산을 했는지 기억해야 한다. 그런 다음, 역방향 패스 중에 이 연산 목록을 역순으로 이동해 가며 미분(derivative)을 계산한다.
Tensorflow는 자동으로 미분을 계산하기 위해서 tf.GradientTape API를 제공한다. Tensorflow는 tf.GradientTape의 컨텍스트 내에서 실행 된 관련 연산을 '테이프'에 '기록'한다. 그런 다음 해당 테이프를 거꾸로 돌려서 기록된 연산의 미분을 계산한다. tf.Variable로 생성된 변수에 대해서만 미분할 수 있는데 상수인 경우에도 watch() 메쏘드를 이용하면 미분을 계산할 수 있다.
예를 들어서 \(y(x,t)=x^3+2t\) 를 \(x\) 에 대해서 편미분하면, \(\frac{\partial y}{\partial x}=3x^2\) 이다. 한번 더 편미분하면 \( \frac{\partial^2 y}{\partial x^2} =6x\) 이다. \(x=1\) 에서 1차 편미분값과 2차 편미분값을 구하면 각각 \(3\) 과 \(6\) 이다. 한편, \(t\) 에 대해서 편미분하면, \( \frac{\partial y}{\partial t}=2\) 이다.
Tensorflow2로 미분을 구현해 보자. 먼저 \(1.0\) 으로 할당된 상수 \(x\) 와 \(t\) 를 생성한다.
x = tf.constant(1.0)
t = tf.constant(1.0)
다음으로 tf.GradientTape 안에 미분 계산에 필요한 모든 연산을 '기록'한다.
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
tape.watch(t)
y = x*x*x + 2*t
y_x = tape.gradient(y, x)
tape.watch(x) 와 tape.watch(t)는 \(x\) 와 \(t\) 가 상수이기 때문에 필요하다. 기본적으로 tape.gradient() 메쏘드가 호출되면 GradientTape에 포함된 리소스가 해제된다. 따라서 2번 이상의 미분을 계산하려면 persistent=True 를 설정해야 한다.
미분 연산을 테이프 안에서 수행하면 CPU와 메모리 자원을 많이 차지하기 때문에 보통 테이프 밖에서 미분 계산을 수행한다. 하지만 여기서는 테이프 안에 1차 편미분 \(\frac{\partial y}{\partial x}\) 도 기록했는데 이것은 2차 편미분 \(\frac{\partial^2 y}{\partial x^2}\) 계산에 필요하기 때문이다.
이제 테이프의 기록을 이용하여 테이프 밖에서 2차 편미분 \(\frac{\partial^2 y}{\partial x^2}\) 와 시간에 대한 편미분 \(\frac{\partial y}{\partial t}\) 를 계산한다. 그리고 미분 계산이 끝나면 테이프를 삭제한다.
y_xx = tape.gradient(y_x, x)
y_t = tape.gradient(y, t)
del tape
전체 코드는 다음과 같다.
# higher order derivatives test
# coded by St.Watermelon
import tensorflow as tf
x = tf.constant(1.0)
t = tf.constant(1.0)
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
tape.watch(t)
y = x*x*x + 2*t
y_x = tape.gradient(y, x) # y_x = 3*x*x = 3
y_xx = tape.gradient(y_x, x) # y_xx = 6*x = 6
y_t = tape.gradient(y, t) # y_t=2
del tape
print("y_x=", y_x)
print("y_xx=", y_xx)
print("y_t=", y_t)
'프로그래밍 > TensorFlow2' 카테고리의 다른 글
Mac M1 Pro 에 Tensorflow, Gym, Mujoco 설치하기 (0) | 2022.09.19 |
---|---|
Model Subclassing 멀티 입력 신경망 모델 구현 방법 (0) | 2021.03.16 |
tf.reduce_sum 함수 (0) | 2021.03.12 |
텐서와 변수 - 3 (0) | 2021.02.11 |
텐서와 변수 - 2 (0) | 2021.02.10 |
댓글