시작은 미약하였으나 , 그 끝은 창대하리라

[Pytorch 스크래치 코드] Early Stopping 본문

인공지능/딥러닝 스크래치 코드

[Pytorch 스크래치 코드] Early Stopping

애플파ol 2023. 12. 9. 13:49

➢ Loss, Accuracy 에 따라 코드가 일부 상이함으로 주의.

 

✓ Pytorch 는 Tensorflow 처럼 Early stopping API 를 제공하지 않는다. 결국 early stopping을 위해 Pytorch 에서는 직접 구현을 해야 함. 

 

✓ train.py (학습코드) 에 직접 추가할수도 있지만, 필자는 가독성을 위해서 class로 만들어서 train.py에 호출해서 사용한다. 

 

✓ 이해를 위한 선행 지식 (예시 참고) :

  • 'def __init__ ' 기능 : class 인스턴스를 생성 할 때, 관련된 데이터를 초기화(initialization) 하는 함수
  • 'def __call__ ' 기능 : 해당 인스턴스를 함수처럼 호출 가능. 
# 이해를 위한 예시 코드

# 클래스 생성
class ExampleClass:
    def __init__(self, value):
        self.value = value

    def __call__(self, x):
        return self.value * x


# 인스턴스 생성 및 value=5 초기화
example_instance = ExampleClass(5)

# __call__ method(함수) 호출
result = example_instance(3)

# 결과 출력 =15
print(result)

 

 

 

Val_loss Early Stopping : validation loss 값 기준 으로 early stopping 수행.

import torch
import numpy as np

class EarlyStopping:
    """주어진 patience 이후로 validation loss가 개선되지 않으면 학습을 조기 중지"""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): validation loss가 개선된 후 기다리는 기간
                            Default: 7
            verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
                            Default: False
            delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
                            Default: 0
            path (str): checkpoint저장 경로
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''validation loss가 감소하면 모델을 저장한다.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

 

 

 Val_acc Early Stopping : validation accuracy 값 기준 으로 early stopping 수행.

→ def __call__ 부분의 score = val_accuracy 부분이 변경됨.

import torch
import numpy as np

class EarlyStopping:
    """주어진 patience 이후로 validation accuracy가 개선되지 않으면 학습을 조기 중지"""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): validation loss가 개선된 후 기다리는 기간
                            Default: 7
            verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
                            Default: False
            delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
                            Default: 0
            path (str): checkpoint저장 경로
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
    
    def __call__(self, val_accuracy, model):
        score = val_accuracy

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_accuracy, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_accuracy, model)
            self.counter = 0

    def save_checkpoint(self, val_accuracy, model):
        '''validation accuracy가 증가하면 모델을 저장한다.'''
        if self.verbose:
            print(f'Validation accuracy increased ({self.best_score:.6f} --> {val_accuracy:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)

 

 

 학습코드에 적용 예시 :

( 아래 적용 코드를 이해하기 위한 환경 이해) 

✓ model 명 = vit 임.

✓ 데이터 로더 = trainloader, valloader

✓ layers 폴더안에 Earlystopping.py 파일 존재. class명은 EarlyStopping 임.

✓ 디렉토리 구조

    - layers

    I   -Earlystopping.py

    I

    -  train.py

 

 

1 ) Val_loss Early Stopping 사용시 적용 예시

# train.py 

from layers.Earlystopping import EarlyStopping

###
###  기타 코드들 생략.
###  def train, def evaluate, 등등
###

patience=10
early_stopping = EarlyStopping(patience = patience, verbose = True)

os.makedirs('./pt', exist_ok=True)

best_val_loss = float('inf')  # Initialize with a large value to ensure the first validation loss will be lower

for epoch in range(1,epochs+1):
    train(vit,trainloader,optimizer,log_interval=5)
    val_loss,val_accuracy=evaluate(vit,valloader)
    print("\n[Epoch: {}],\t Test Loss : {:.4f},\tTest Accuracy :{:.2f} % \n".format
          (epoch, val_loss,val_accuracy))
    
    early_stopping(val_loss, vit)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Save the model when the validation loss improves
        model_path = f"{'./pt/'}model_epoch_{epoch}_Accuracy_{val_accuracy:.2f}.pt"
        torch.save(vit.state_dict(), model_path)    

    if early_stopping.early_stop:
        print("Early stopping")
        break

 

2) Val_acc Early Stopping 사용시 적용 예시

 

# train.py 

from layers.Earlystopping import EarlyStopping

###
###  기타 코드들 생략.
###  def train, def evaluate, 등등
###

patience=10
early_stopping = EarlyStopping(patience = patience, verbose = True)

os.makedirs('./pt', exist_ok=True)

best_val_accuracy = 0.0  # 최초의 최고 정확도를 나타내는 값으로 초기화

for epoch in range(1, epochs+1):
    train(vit, trainloader, optimizer, log_interval=5)
    val_loss, val_accuracy = evaluate(vit, valloader)
    print("\n[Epoch: {}],\t Test Loss : {:.4f},\tTest Accuracy : {:.2%}\n".format(epoch, val_loss, val_accuracy))

    early_stopping(val_accuracy, vit)

    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        # Save the model when the validation accuracy improves
        model_path = f"./pt/model_epoch_{epoch}_Accuracy_{val_accuracy:.2%}.pt"
        torch.save(vit.state_dict(), model_path)

    if early_stopping.early_stop:
        print("Early stopping")
        break
Comments