본문 바로가기
AI 딥러닝/Sequence

HiPPO - 3

by 깊은대학 2025. 1. 12.

이전 게시글(https://pasus.tistory.com/363)에서 함수 f(x), 0<xt 를 N차원 부분 함수공간으로 투사한 근사 함수 g(x), 0<xt 를 다음과 같이 유도하였다.

 

(1)g(x)=n=0N1cn(t)(2n+1)Pn(2xt1)c˙(t)=1tAc(t)+1tBf(t)

 

식 (1)에 의하면 HiPPO는 본질적으로 연속시간(continuous-time) 에서 정의된 상미분 방정식(ODE)을 기반으로 데이터의 상태를 업데이트한다. 그러나 센서 신호, 주가 데이터 등의 실제 데이터는 대부분 샘플링(sampling)으로 수집되므로, HiPPO를 실제 시스템에 적용하려면 연속시간 히포 미분방정식(continuous-time HiPPO ODE)을 이산시간(discrete-time) 차분 방정식(difference equation)으로 바꿀 필요가 있다.

 

 

히포 미분방정식은 시간에 따라 변화하는 시변(time-varying) 시스템으로, 시간 t 에 따라 행렬 1tA1tB 가 변화한다. 따라서 일반적인 선형 시불변(linear time-invariant, LTI) 시스템과는 달리 이산화 과정에서 주의가 필요하다. 그러나 논문에서 언급된 forward Euler, backward Euler, bilinear 등의 방법은 시변 시스템에도 적용 가능하다. 하지만 ZOH (Zero-Order Hold) 방법의 경우에는 샘플링 구간 동안 행렬을 고정시키든가 또는 평균값을 사용해야 한다(https://pasus.tistory.com/321).

여기서는 forward Euler, backward Euler, bilinear에 대해서만 알아보고자 한다.

미분을 다음과 같이 근사화하는 방법을 오일러 근사(Euler's approximation)법이라고 한다.

 

(2)c˙(tk)c(tk+1)c(tk)Δt(3)c˙(tk)c(tk)c(tk1)Δt

 

여기서 t=tkk 번째 샘플링 싯점으로서 tk=kΔt 이며, Δt=tk+1tk 는 샘플링 간격이다. k 는 시간스텝(time step)을 나타내는 인덱스로서 정수값을 갖는다. 식 (2)는 c˙(tk) 를 근사화할 때 미래시간과 현재시간의 c 값을 이용했으므로 forward Euler 근사라고 하고, 식 (3)은 현재시간과 과거시간의 c 값을 이용했으므로 backward Euler 근사라고 한다.

식 (2)를 (1)에 대입하면 다음과 같다.

 

(4)c˙(tk)c(tk+1)c(tk)Δt=1tkAc(tk)+1tkBf(tk)

 

위 식은 다음과 같이 정리된다.

 

(5)c(tk+1)=(IΔttkA)c(tk)+ΔttkBf(tk)

 

이산시간 시스템의 기호를 따르기 위하여 다음과 같이 정의하면,

 

(6)ck+1=c(tk+1),    ck=c(tk),    fk=f(tk),    Δttk=ΔtkΔt=1k

 

다음과 같이 식 (1)에 대한 forward Euler 방식의 이산시간 등가 모델을 얻을 수 있다.

 

(7)ck+1=(I1kA)ck+1kBfk

 

이번에는 식 (3)을 (1)에 대입하면 다음과 같다.

 

(8)c˙(tk)c(tk)c(tk1)Δt=1tkAc(tk)+1tkBf(tk)

 

위 식을 정리하면 다음과 같다.

 

(9)c(tk)=(I+ΔttkA)1c(tk1)+(I+ΔttkA)1ΔttkBf(tk)

 

이산시간 시스템의 기호를 따르면 다음과 같이 식 (1)에 대한 backward Euler 방식의 이산시간 등가 모델을 얻을 수 있다.

 

(10)ck=(I+1kA)1ck1+(I+1kA)11kBfk

 

식 (7)과 식 (10)은 fk 가 가해지는 싯점에 차이가 있다. 식 (7)에서는 ck+1 를 계산할 때 fk 가 적용된 반면 식 (10)에서는 ck 를 계산할 때 fk 가 적용되었다. 제어 분야에서는 식 (7)의 형식이 일반적으로 사용되는 반면에, RNN 계열에서는 식 (10)의 형식이 사용된다. 식 (10) 은 N×N 행렬의 역행렬을 계산해야 하므로 투사(projection)된 차원이 클 경우 계산상의 문제를 야기할 수 있다.

 

 

Bilinear 변환은 tktk+1 또는 tktk1 사이를 평균화하는 방법이다. 식 (1)의 c(t) 항을 샘플링 싯점 t=tk 에서 다음과 같이 근사화한다.

 

(11)1tAc(t)12(1tk+1Ac(tk+1)+1tkAc(tk))(12)1tAc(t)12(1tkAc(tk)+1tk1Ac(tk1))

 

식 (11)은 미래시간과 현재시간의 값을 이용한 forward 평균 근사이고, 식 (12)는 현재시간과 과거시간의 값을 이용한 backward 평균 근사다.

식 (11)과 (2)를 식 (1)에 대입하면 다음과 같다.

 

(13)c(tk+1)c(tk)Δt=12(1tk+1Ac(tk+1)+1tkAc(tk))+1tkBf(tk)

 

위 식은 다음과 같이 정리된다.

 

(14)c(tk+1)=(I+Δt2tk+1A)1(IΔt2tkA)c(tk)+(I+Δt2tk+1A)1ΔttkBf(tk)

 

이산시간 시스템의 기호를 따르면 다음과 같이 식 (1)에 대한 forward bilinear 방식의 이산시간 등가 모델을 얻을 수 있다.

 

(15)ck+1=(I+12(k+1)A)1(I12kA)ck+(I+12(k+1)A)11kBfk

 

이번에는 식 (12)와 (3)을 식 (1)에 대입하면 다음과 같다.

 

(16)c(tk)c(tk1)Δt=12(1tkAc(tk)+1tk1Ac(tk1))+1tkBf(tk)

 

위 식은 다음과 같이 정리된다.

 

(17)c(tk)=(I+Δt2tkA)1(IΔt2tk1A)c(tk)+(I+Δt2tkA)1ΔttkBf(tk)

 

이산시간 시스템의 기호를 따르면 다음과 같이 식 (1)에 대한 backward bilinear 방식의 이산시간 등가 모델을 얻을 수 있다.

 

(18)ck=(I+12kA)1(I12(k1)A)ck1+(I+12kA)11kBfk

 

식 (15)와 식 (18)은 fk 가 가해지는 싯점에 차이가 있다. 식 (15)에서는 ck+1 를 계산할 때 fk 가 적용된 반면 식 (18)에서는 ck 를 계산할 때 fk 가 적용되었다. 또한 식 (15)와 (18)은 N×N 행렬의 역행렬을 계산해야 하므로 투사(projection)된 차원이 클 경우 계산상의 문제를 야기할 수 있다.

계수 차분방정식 (8), (10), (15), (18)에 의하면 수식 상에 샘플링 간격 Δt 가 없으므로 샘플링 간격에 의존하지 않는다는 것을 알 수 있다. 또한 HiPPO는 시간 스케일에 독립적(timescale invariant)이며 어떤 시간 범위에서도 성능을 유지할 수 있다. 시간 스케일에 의존적이면 초당 1회 데이터를 수집하던 센서가 초당 10회 데이터를 수집하면 성능이 저하될 수도 있다.

HiPPO가 시간 스케일에 독립적이라는 것은 다음과 같이 증명할 수 있다. 함수 f(x) 에서 시간 스케일이 변경된 함수를 f~(x)=f(αx) 라고 하자. 그러면 계수 적분방정식에 의해서 f~(x) 의 계수 c~n(t) 는 다음과 같이 계산할 수 있다.

 

(19)c~n(t)=0tf~(x)(2n+1)Pn(2xt1)1t dx=0tf(αx)(2n+1)Pn(2αxαt1)1αtα dx=0αtf(y)(2n+1)Pn(2yαt1)1αt dy=cn(αt)

 

여기서 y=αx 로 변수를 변환했다. 따라서 함수의 시간 스케일을 바꾸면 동일한 스케일로 계수도 바뀐다는 것을 알 수 있다.

HiPPO는 시간에 따라 변화하는 데이터를 직교 다항식 기반으로 최적화하여 식 (8), (10), (15) 또는 (18)을 이용하여 계수 ck=c(tk) 의 형태로 함축적으로 저장한다. 그리고 원래 함수 f(x)c(tk) 를 이용하여 다음과 같은 근사 함수 g(x) 로 복원할 수 있다.

 

(20)g(xi)=n=0N1cn(tk)(2n+1)Pn(2xitk1),    xitk

 

다음은 x[0, 150] 구간에서 Δt=0.1, N=10 을 사용하여 f(x)=cos(x20)sin(x5) 을 HiPPO 로 근사한 것이다. 아래 그림은 tk=75 에서 근사한 것이다.

 

 

다음 그림은 tk=150 에서 근사한 것이다.

 

 

다음 그림은 N=20 을 사용하여 tk=150 에서 근사한 것이다.

 

 

 

 

다음은 해당 매트랩 코드다.

 

% Comparison among HiPPO ODE, forward Euler, backward Euler, 
% and Bilinear forwad and backward methods
%
% (c) st.watermelon

clear; 

% HiPPO setup
N = 10; 
tf = 150;
dt = 0.1; % time interval
t_step = dt:dt:tf;

% HiPPO A, B
[A, B] = hippo_AB(N);

% input signal
f = @(t) cos(t/20).*sin(t/5);

ft_orig = f(t_step);
ft = ft_orig + 0.1 * randn(1,length(t_step));

% HiPPO ODE 
hippo_ode = @(t, c) (-A / t) * c + (B / t) * f(t);

% initial condition
c0 = zeros(N, 1);

% 0. HiPPO ODE
[time, c_ode] = ode45(hippo_ode, t_step, c0);


% 1. forward Euler
c_euler_fwd = zeros(N, 1);
c_euler_fwd_history = zeros(length(t_step), N); 
for k = 1:length(t_step)-1
    c_euler_fwd = (eye(N) - A/k) * c_euler_fwd + (B/k) * ft(k); 
    c_euler_fwd_history(k+1, :) = c_euler_fwd'; 
end

% 2. backward Euler
c_euler_bwd = zeros(N, 1);
c_euler_bwd_history = zeros(length(t_step), N);
for k = 2:length(t_step)
    M = eye(N) + A/k;
    invM = inv(M);
    c_euler_bwd = invM * c_euler_bwd + invM * (B/k) * ft(k); 
    c_euler_bwd_history(k, :) = c_euler_bwd'; 
end

% 3. bilinear forward 
c_bi_fwd = zeros(N, 1);
c_bi_fwd_history = zeros(length(t_step), N);
for k = 1:length(t_step)-1
    M = eye(N) + A/(2*(k+0));
    invM = inv(M);
    c_bi_fwd = invM * (eye(N) - A/(2*k)) * c_bi_fwd + invM * (B/k) * ft(k); 
    c_bi_fwd_history(k+1, :) = c_bi_fwd'; 
end

% 4. bilinear backward 
c_bi_bwd = zeros(N, 1);
c_bi_bwd_history = zeros(length(t_step), N);
for k = 2:length(t_step)
    M = eye(N) + A/(2*k);
    invM = inv(M);
    c_bi_bwd = invM * (eye(N) - A/(2*(k-0))) * c_bi_bwd + invM * (B/k) * ft(k); 
    c_bi_bwd_history(k, :) = c_bi_bwd'; 
end


%%
% fitting

cut = round(length(t_step)*1/2);
g = hippo_fit(c_bi_fwd_history(1:cut,:), dt);


% plotting
figure, plot(t_step, ft, '.', 'color', [.7 .7 .7]), hold on
plot(t_step, ft_orig, 'r')
plot(t_step(1:cut), g,'b', 'LineWidth', 1.2)
legend('noisy f(t)','f(t)','g(t)')
xlabel('Time');
title(['N = ' num2str(N)]);

%%

function g = hippo_fit(c, dt)
%
% g = hippo_fit(c, dt)
% HiPPO fitting
% input
%   c(t): coefficients
%   dt: sampling time interval
% output
%   g: fitted function
%
% (c) st.watermelon

[k, N] = size(c);

g = [];
for ii=1:k
    x = ii*dt;
    sum = 0;
    for n=0:N-1
        p=LegendrePoly(n);
        Pn=polyval(p,2*x/(k*dt)-1);
        sum = sum + c(k, n+1)*sqrt(2*n+1)*Pn;
    end
    g = [g; sum];
end

end


%%
function lpoly = LegendrePoly(n)
%
% lpoly = LegendrePoly(n)
% N-th order Legendre polynomials
%
% input: order of Legendre polynomials
% output: polynomial coefficients
%
% (c) st.watermelon

for ii=0:n
    
    if ii==0
        p{ii+1}=1;
    elseif ii==1
        p{ii+1}=[1 0];
    else
        p{ii+1}=(2*ii-1)/ii*[p{ii} 0]-(ii-1)/ii*[0 0 p{ii-1}];
    end
    
end

lpoly=p{n+1};

end

'AI 딥러닝 > Sequence' 카테고리의 다른 글

HiPPO - 2  (0) 2025.01.09
HiPPO - 1  (0) 2025.01.08
[PtrNet] Pointer Net 구조  (0) 2023.09.12
[seq2seq] 어텐션이 포함된 seq2seq 모델  (0) 2023.08.23
[seq2seq] 간단한 seq2seq 모델 구현  (0) 2023.08.17

댓글