시작은 미약하였으나 , 그 끝은 창대하리라

AutoFormer 코드 설명 및 적용. 본문

인공지능/딥러닝 사이드 Project

AutoFormer 코드 설명 및 적용.

애플파ol 2023. 7. 16. 22:50

 

참고글: https://huggingface.co/blog/autoformer

Autoformer 공식 github: https://github.com/thuml/Autoformer

논문: https://arxiv.org/pdf/2106.13008.pdf  

필자의 Autoformer_encdoer를 이용한 시계열 예측  github: https://github.com/YongTaeIn/Autoformer_encoder_time_series/tree/master

✓ Auto Former 는 크게 두가지 기술이 들어간다. 

    -> 그리고 아래 두가지 기술에 대해 적용하는 방법을 설명할 것이다.

1. Series Decomposition (시계열 분해)

2  Attention layer - > Auto Correlation

 

출처: 논문 https://arxiv.org/pdf/2106.13008.pdf

✓ Series Decomposition 코드 및 설명.

<코드>

import torch
from torch import nn

class DecompositionLayer(nn.Module):
    """
    Returns the trend and the seasonal parts of the time series.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=0) # moving average 

    def forward(self, x):
        """Input shape: Batch x Time x EMBED_DIM"""
        # padding on the both ends of time series
        num_of_pads = (self.kernel_size - 1) // 2
        front = x[:, 0:1, :].repeat(1, num_of_pads, 1)
        end = x[:, -1:, :].repeat(1, num_of_pads, 1)
        x_padded = torch.cat([front, x, end], dim=1)

        # calculate the trend and seasonal part of the series
        x_trend = self.avg(x_padded.permute(0, 2, 1)).permute(0, 2, 1)
        x_seasonal = x - x_trend
        return x_seasonal, x_trend

 

trend= 추세(경향) ,seasonal = 계절성(반복되는 특징)

Moving average(이동 평균)을 위해 AvgPool을 사용했다.(공식 github를 보면 seasonal 만 사용함.)

무튼 위의 코드를 사용하면 x_seasonal, x_trend 가 나온다. 

 

 

 

 

✓ Auto correlation 코드 및 설명.

     ➢ Attention (Autocorrelation) Mechanism

(a) Vanilla self attention vs (b) Autocorrelation mechanism

일반적인 self attetntion은 좌측과 같이 point wise를 통해 attention을 뽑아내는 반면 우측의 auto-correlation은 series wise , sub series간 종석성을 찾아서 수행한다.

   

 

   ➢ Frequency Domain Mechanism

 

FT : Fourier Transform (푸리에 변환) 을 적용한 것이다... 뭐라해야하지..학부때 들었던 디신처(디지털신호처리), 통신이론 수업때 배운 기억을 꺼내보면 푸리에 변환이란, 시간축의 신호를 주파수 축의 신호로 바꾸는 것이다. 

왜 할까? 싶을수도 있다. 내가 해줄수 있는 답변은... 아마 교수님이 말씀해준 말이겠지?  그 이유는 시간축상에서 해석이 안되는 신호를 주파수 축으로 바꿔서 해석하면 해석이 가능하기 때문이다. 

(inverse 가 붙는 것은 말그래도 다시 원래 상태로 돌리는 과정이다. )

 

FFT : Fast Fourier Transform (고속 푸리에 변환) : 어렵지 않다. 고속 이라는것이다. 시간 복잡도를 줄이기 위해 FT 가 아닌 FFT를 사용한 것이다. 왜냐? 트랜스포머 모델의 계산량을 줄이고 싶기때문이다. 

 

 

    ➢ Time Delay Aggregation

 

바닐라 트랜스포머는 Scaled Dot-Product Attention 으로 계산을 해나가지만 Auto correlation 에서는 위의 그림과 같은 과정을 통해서 진행하다. 

V(value) 에 대해서 rolling을 수행한다. rolling 은 앞에 부분을 잘라서 뒷부분을 연결하는 과정으로  τ_1, τ_2 등과 같이 규칙이 반복되는 부분을 기준으로 실행이 된다.

(τ 의 값은 R_Q,K 를 통해 구한것중 가장 높은값  K개 를 구하는 것이다  수식(1)),

그리고 그렇게 구한 값들을 softmax를 취한 후 (수식2) ,

마지막으로  곱한후 가중합을 하여 구한다 수식(3).

 

 

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
from math import sqrt
import os


class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.
    """
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\
            .repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\
            .repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)


class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn

위와 같이 코드가 구성이 되여 있다.

아래코드는 필자의 github에서 encoder_layer 부분의 코드만 가져온것이다. mask_flag를 안써서 False로 해둔것이다)

import torch
import torch.nn as nn

from layers.multi_head_attention import MultiHeadAttention
from layers.positioin_wise import PositionWise
from layers.AutoCorrelation import AutoCorrelation,AutoCorrelationLayer


class EncoderLayer(nn.Module):
    def __init__(self, d_model, head, d_ff, dropout):
        super().__init__()
        #self.attention = MultiHeadAttention(d_model,head)
        self.attention_1= AutoCorrelationLayer(AutoCorrelation(False,factor=4, 
                                attention_dropout=0.1,
                                output_attention=False),
                                d_model=d_model,n_heads=head)

        self.layerNorm1 = nn.LayerNorm(d_model)

        self.ffn = PositionWise(d_model,d_ff)
        self.layerNorm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):  # (self, x, padding_mask)
        # residual connection을 위해 잠시 담아둔다.
        residual = x

        # 1. multi-head attention (self attention)
        # x, attention_score = self.attention(q=x, k=x, v=x) # (q=x, k=x, v=x, mask=padding_mask)
        x, attention_score = self.attention_1(queries=x, keys=x, values=x, attn_mask=None)
        
        # 2. add & norm
        x = self.dropout(x) + residual
        x = self.layerNorm1(x)

        residual = x

        # 3. feed-forward network
        x = self.ffn(x)

        # 5. add & norm
        x = self.dropout(x) + residual
        x = self.layerNorm2(x)

        return x, attention_score

 

 

 

 

전체 코드는 Github를 통해 사용법을 확인할 수 있으며 나는 autoformer에  encoder 부분만 사용하기때문에 encoder부분으로만 결과를 추출하는 코드로 작성되어있다. 

 

 

필자의 Autoformer_encdoer를 이용한 시계열 예측  github: https://github.com/YongTaeIn/Autoformer_encoder_time_series/tree/master

Comments