본문 바로가기
AI/강화학습

Soft Actor Critic (SAC) 알고리즘 - 1

by 세인트워터멜론 2021. 5. 29.

행동가치 함수에 대한 소프트 벨만 방정식은 다음과 같다.

 

\[ \begin{align} Q_{soft}^\pi (\mathbf{x}_t, \mathbf{u}_t ) & \gets r_t + \gamma \ \mathbb{E}_{\mathbf{x}_{t+1} \sim p(\mathbf{x}_{t+1} | \mathbf{x}_t, \mathbf{u}_t ), \ \mathbf{u}_{t+1} \sim \pi (\mathbf{u}_{t+1} | \mathbf{x}_{t+1} ) } \tag{1} \\ & \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \left[ Q_{soft}^\pi (\mathbf{x}_{t+1}, \mathbf{u}_{t+1} )- \alpha \log \pi(\mathbf{u}_{t+1} | \mathbf{x}_{t+1} ) \right] \end{align} \]

 

식 (1)은 현재 정책 \(\pi\) 에 관한 행동가치 함수이므로 \(\pi\) 가 주어졌을 때 업데이트를 반복하면 풀 수 있다.

 

 

정책 개선을 위한 식은 다음과 같다.

 

\[ \begin{align} \pi (\mathbf{u}_t | \mathbf{x}_t ) & \gets \arg\max_\pi⁡ \ \mathbb{E}_{\mathbf{u}_t \sim \pi (\mathbf{u}_t | \mathbf{x}_t ) } \left[ Q_{soft}^\pi (\mathbf{x}_t, \mathbf{u}_t )- \alpha \log \pi(\mathbf{u}_t | \mathbf{x}_t ) \right] \tag{2} \\ \\ & = \arg\min_\pi⁡ D_{KL} \left( \pi (\mathbf{u}_t | \mathbf{x}_t ) \parallel \frac{ \exp \left( \frac{1}{\alpha} Q_{soft}^\pi (\mathbf{x}_t, \mathbf{u}_t ) \right) }{ Z(\mathbf{x}_t ) } \right) \end{align} \]

 

식 (1)로 \(Q_{soft}^\pi (\mathbf{x}_t, \mathbf{u}_t )\) 가 수렴할 때까지 반복 계산하고, 수렴한 후에 식 (2)로 정책을 업데이트 하는 과정을 정책 이터레이션이라고 한다.

정책 이터레이션을 수행하려면 환경 모델이 필요하다. 하지만 모델프리(model-free) 강화학습에서는 환경 모델을 이용하여 행동가지 함수를 계산하는 것이 아니라, 에이전트가 환경에 행동을 가해서 생성한 데이터를 기반으로 하여 추정한다.

이제 소프트 행동가치 함수와 정책을 추정하기 위해서 신경망을 이용하기로 하자. 행동가치 함수를 추정하기 위한 신경망을 Q 신경망 또는 크리틱(critic) 신경망이라 하고 파라미터를 \(\phi\) 로, 정책을 추정하기 위한 신경망을 액터 신경망이라 하고 파라미터를 \(\theta\) 로 표기한다. 그리고 추정된 행동가치를 \(Q_\phi (\mathbf{x}_t, \mathbf{u}_t)\) 로, 정책을 \(\pi_\theta (\mathbf{u}_t | \mathbf{x}_t)\) 로 표기하자.

이와 같이 신경망을 이용하여 엔트로피 최대화 문제의 해를 찾는 알고리즘을 SAC(Soft Actor Critic) 알고리즘이라고 한다. SAC알고리즘에서는 소프트 행동가치가 수렴할 때까지 기다리지 않고 정책 평가와 정책 개선을 번갈아 가며 한 번씩 업데이트하는 방법을 사용한다.

Q 신경망의 손실함수는 소프트 행동가치 함수를 정확히 추정할 수 있는 파라미터 \(\phi\) 를 갖도록 정해져야 한다. 따라서 손실함수는 소프트 행동가치 추정값 \(Q_\phi (\mathbf{x}_t, \mathbf{u}_t)\) 와 소프트 행동가치의 참값 \(Q_{soft} (\mathbf{x}_t, \mathbf{u}_t )\) 의 차이가 최소가 되도록 정하면 되므로 다음과 같이 Q 신경망의 손실함수를 설정한다.

 

\[ L_Q (\phi)= \mathbb{E}_{(\mathbf{x}_i, \mathbf{u}_i) \sim \mathcal{D} } \left[ \frac{1}{2} \left( Q_\phi (\mathbf{x}_i, \mathbf{u}_i ) - Q_{soft} (\mathbf{x}_i, \mathbf{u}_i ) \ | \ \mathbf{x}_i, \mathbf{u}_i \right)^2 \right] \tag{3} \]

 

그런데 여기서 소프트 행동가치의 참값 \(Q_{soft} (\mathbf{x}_i, \mathbf{u}_i )\) 를 알지 못하므로, 식 (1)을 이용하여 다음과 같은 타깃을 설정한다.

 

\[ \begin{align} & Q_{soft} (\mathbf{x}_i, \mathbf{u}_i ) \approx q_i \tag{4} \\ \\ & \ \ \ \ \ = r(\mathbf{x}_i, \mathbf{u}_i) + \gamma \ \mathbb{E}_{(\mathbf{x}_i, \mathbf{u}_i, \mathbf{x}_{i+1} ) \sim \mathcal{D} } \left[ \begin{pmatrix} Q_{\phi^\prime} (\mathbf{x}_{i+1}, \mathbf{u}_{i+1} ) \\ - \alpha \log \pi_\theta (\mathbf{u}_{i+1} | \mathbf{x}_{i+1} ) \end{pmatrix} \ \mid \ \mathbf{x}_i, \mathbf{u}_i, \mathbf{x}_{i+1} \right] \end{align} \]

 

여기서 \(\phi^\prime\) 은 DQN과 DDPG에서 제기되었던 문제를 해결하고자 도입한 타깃 Q 신경망 파라미터다. 그러면 Q 신경망의 손실함수는 다음과 같이 된다.

 

\[ L_Q (\phi) = \ \mathbb{E}_{(\mathbf{x}_i, \mathbf{u}_i, \mathbf{x}_{i+1} ) \sim \mathcal{D} } \left[ \frac{1}{2} \left( Q_\phi (\mathbf{x}_{i+1}, \mathbf{u}_{i+1} ) - q_i \ \mid \ \mathbf{x}_i, \mathbf{u}_i, \mathbf{x}_{i+1} \right)^2 \right] \tag{5} \]

 

여기서 바깥쪽에 있는 기댓값은 리플레이 버퍼 \(\mathcal{D}\) 에서 샘플링한 \(N\)개의 데이터 \( ( \mathbf{x}_i, \mathbf{u}_i, \mathbf{x}_{i+1}) \) 를 추출하여 계산할 수 있지만, 안쪽에 있는 \(q_i\) 는 \(\pi_\theta (\mathbf{u}_{i+1} | \mathbf{x}_{i+1} ) \) 에 기반한 기댓값으로서 온-폴리시(on-policy)이므로 해당 계산에 쓰이는 \(\mathbf{u}_{i+1}\) 는 현재의 정책 \(\pi_\theta\) 에서 샘플링되어야 한다. 즉 \(\mathbf{x}_{i+1}\) 에 대한 행동인 \(\mathbf{u}_{i+1}\) 는 정책 \(\pi_\theta\) 를 따르도록 해야 한다.

 

 

손실함수 \(L_Q (\phi)\) 를 최소화하는 파라미터 \(\phi\) 는 다음과 같이 경사하강법으로 구할 수 있다.

 

\[ \phi \gets \phi - \alpha_\phi \nabla_\phi L_Q (\phi) \tag{6} \]

 

손실함수의 그래디언트는 다음과 같이 계산한다.

 

\[ \nabla_\phi L_Q (\phi)= \sum_i \nabla_\phi Q_\phi (\mathbf{x}_i, \mathbf{u}_i )\begin{pmatrix} Q_\phi (\mathbf{x}_i, \mathbf{u}_i ) -\gamma ( \ Q_{\phi^\prime} (\mathbf{x}_{i+1}, \mathbf{u}_{i+1} ) \\ -\alpha \log \pi_\theta (\mathbf{u}_{i+1} | \mathbf{x}_{i+1} ) \ ) \end{pmatrix} \tag{7} \]

 

 

 

액터 신경망의 손실함수도 정책을 정확히 추정할 수 있는 파라미터 \(\theta\) 를 갖도록 정해져야 한다. 식 (2)에 의하면 최소화해야 할 손실함수를 다음과 같이 정하면 된다.

 

\[ L_\pi (\theta)= \mathbb{E}_{\mathbf{x}_i \sim \mathcal{D}} \left[ \mathbb{E}_{\mathbf{u}_i \sim \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) } \left[ \alpha \log \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) - Q_\phi (\mathbf{x}_i, \mathbf{u}_i ) \right] \ \mid \ \mathbf{x}_i \ \right] \tag{8} \]

 

마찬가지로 여기서 바깥쪽에 있는 기댓값은 리플레이 버퍼 \(\mathcal{D}\) 에서 샘플링한 \(N\)개의 데이터 \(\mathbf{x}_i\) 를 추출하여 계산할 수 있지만, 안쪽에 있는 \(\pi_\theta (\mathbf{u}_i | \mathbf{x}_i )\) 에 기반한 기댓값은 현재의 정책 \(\pi_\theta\) 에서 샘플링해야 한다.

손실함수 \(L_\pi (\theta)\) 의 그래디언트는 다음과 같이 계산한다.

 

\[ \nabla_\theta L_\pi (\theta)= \sum_i \nabla_\theta \mathbb{E}_{\mathbf{u}_i \sim \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) } \left[ \alpha \log \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) - Q_\phi (\mathbf{x}_i, \mathbf{u}_i ) \right] \tag{9} \]

 

그런데 여기서 한가지 문제가 있다. 바로 안쪽에 있는 기댓값의 미분을 샘플링 평균으로 계산할 수 없는 것이다. 다음 수식을 보면 이유가 명확해진다.

 

\[ \begin{align} & \nabla_\theta \mathbb{E}_{\mathbf{u}_i \sim \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) } \left[ \log⁡ \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) \right] \tag{10} \\ \\ & \ \ \ \ \ = \nabla_\theta \int_{\mathbf{u}_i} \log⁡ \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) d\mathbf{u}_i \\ \\ & \ \ \ \ \ = \int_{\mathbf{u}_i} (1+\log⁡ \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) ) \nabla_\theta \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) d\mathbf{u}_i \end{align} \]

 

이에 대한 해결책으로서 재파라미터화 트릭(reparameterization trick)이라는 방법을 사용한다. 이 방법에 의하면 정책을 다음과 같은 함수로 만든다.

 

\[ \mathbf{u}_i^j= \mathbf{f}_\theta (\mathbf{x}_i, \eta_j) \tag{11} \]

 

여기서 \(\eta_j\) 는 노이즈 벡터로서 보통 가우시안 분포로 가정한다. 예를 들면 정책 \(\pi_\theta (\mathbf{u}_i | \mathbf{x}_i )\) 를 평균 \(\mu_\theta (\mathbf{x}_i)\) 와 공분산 \(\sigma_\theta^2 (\mathbf{x}_i)\) 를 갖는 가우시안 분포라고 가정하면, 행동을 다음과 같이 표현할 수 있다.

 

\[ \begin{align} \mathbf{u}_i^j & \sim \mathcal{N} \left( \mu_\theta (\mathbf{x}_i ), \sigma_\theta^2 (\mathbf{x}_i ) \right) \tag{12} \\ \\ & = \mu_\theta (\mathbf{x}_i ) + \sigma_\theta (\mathbf{x}_i ) \eta_j, \ \ \ \ \ \eta_j \sim \mathcal{N}(0, I) \end{align} \]

 

그러면 손실함수 식 (8)은 다음과 같이 쓸 수 있다.

 

\[ L_\pi (\theta)= \mathbb{E}_{\mathbf{x}_i \sim \mathcal{D}} \left[ \mathbb{E}_{\eta \sim \mathcal{N}} \left[ \alpha \log \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) - Q_\phi (\mathbf{x}_i, \mathbf{u}_i ) \right] \ \mid \ \mathbf{x}_i \ \right] \tag{13} \]

 

 

 

미분의 연쇄법칙을 사용하면 손실함수 \(L_\pi (\theta)\) 의 그래디언트는 다음과 같이 계산할 수 있다.

 

\[ \begin{align} & \nabla_\theta L_\pi (\theta) = \sum_i \sum_j \nabla_\theta \left[ \alpha \log \pi_\theta (\mathbf{u}_i^j | \mathbf{x}_i ) - Q_\phi (\mathbf{x}_i, \mathbf{u}_i^j ) \right] \tag{14} \\ \\ & \ \ \ \ = \sum_i \sum_j \left[ \alpha \nabla_\theta \log \pi_\theta (\mathbf{u}_i^j | \mathbf{x}_i ) - \nabla_\theta Q_\phi (\mathbf{x}_i, \mathbf{u}_i^j ) \right] \\ \\ & \ \ \ \ = \sum_i \sum_j \begin{bmatrix} \alpha \nabla_\theta \log \pi_\theta (\mathbf{u}_i^j | \mathbf{x}_i ) \\ + \alpha \nabla_\theta \mathbf{f}_\theta (\mathbf{x}_i, \eta_j ) \nabla_{\mathbf{u}_i} \log \pi_\theta (\mathbf{u}_i^j | \mathbf{x}_i ) \\ - \nabla_\theta \mathbf{f}_\theta (\mathbf{x}_i, \eta_j ) \nabla_{\mathbf{u}_i} Q_\phi (\mathbf{x}_i, \mathbf{u}_i^j ) \end{bmatrix} \\ \\ & \ \ \ \ = \sum_i \sum_j \begin{bmatrix} \alpha \nabla_\theta \log⁡ \pi_\theta (\mathbf{u}_i^j | \mathbf{x}_i ) \\ + \nabla_\theta \mathbf{f}_\theta (\mathbf{x}_i, \eta_j ) \left( \alpha \nabla_{\mathbf{u}_i} \log⁡ \pi_\theta (\mathbf{u}_i^j | \mathbf{x}_i ) - \nabla_{\mathbf{u}_i} Q_\phi (\mathbf{x}_i, \mathbf{u}_i^j ) \right) \end{bmatrix} \end{align} \]

 

만약 재파라미터화 트릭으로 한 개의 샘플만을 추출하여 이용한다면 식 (14)는 다음과 같이 된다.

 

\[ \nabla_\theta L_\pi (\theta) = \sum_i \begin{bmatrix} \alpha \nabla_\theta \log⁡ \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) \\ + \nabla_\theta \mathbf{f}_\theta (\mathbf{x}_i, \eta_i ) \left( \alpha \nabla_{\mathbf{u}_i} \log⁡ \pi_\theta (\mathbf{u}_i | \mathbf{x}_i ) - \nabla_{\mathbf{u}_i} Q_\phi (\mathbf{x}_i, \mathbf{u}_i) \right) \end{bmatrix} \tag{15} \]

 

이제 손실함수 \(L_\pi (\theta)\) 를 최소화하는 파라미터 \(\theta\) 는 다음과 같이 경사하강법으로 구할 수 있다.

 

\[ \theta \gets \theta - \alpha_\theta \nabla_\theta L_\pi (\theta) \tag{16} \]

 

 

 

댓글0