➢ Loss, Accuracy 에 따라 코드가 일부 상이함으로 주의.
✓ Pytorch 는 Tensorflow 처럼 Early stopping API 를 제공하지 않는다. 결국 early stopping을 위해 Pytorch 에서는 직접 구현을 해야 함.
✓ train.py (학습코드) 에 직접 추가할수도 있지만, 필자는 가독성을 위해서 class로 만들어서 train.py에 호출해서 사용한다.
✓ 이해를 위한 선행 지식 (예시 참고) :
- 'def __init__ ' 기능 : class 인스턴스를 생성 할 때, 관련된 데이터를 초기화(initialization) 하는 함수
- 'def __call__ ' 기능 : 해당 인스턴스를 함수처럼 호출 가능.
# 이해를 위한 예시 코드
# 클래스 생성
class ExampleClass:
def __init__(self, value):
self.value = value
def __call__(self, x):
return self.value * x
# 인스턴스 생성 및 value=5 초기화
example_instance = ExampleClass(5)
# __call__ method(함수) 호출
result = example_instance(3)
# 결과 출력 =15
print(result)
❏ Val_loss Early Stopping : validation loss 값 기준 으로 early stopping 수행.
import torch
import numpy as np
class EarlyStopping:
"""주어진 patience 이후로 validation loss가 개선되지 않으면 학습을 조기 중지"""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
"""
Args:
patience (int): validation loss가 개선된 후 기다리는 기간
Default: 7
verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
Default: False
delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
Default: 0
path (str): checkpoint저장 경로
Default: 'checkpoint.pt'
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
self.path = path
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
'''validation loss가 감소하면 모델을 저장한다.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
❏ Val_acc Early Stopping : validation accuracy 값 기준 으로 early stopping 수행.
→ def __call__ 부분의 score = val_accuracy 부분이 변경됨.
import torch
import numpy as np
class EarlyStopping:
"""주어진 patience 이후로 validation accuracy가 개선되지 않으면 학습을 조기 중지"""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
"""
Args:
patience (int): validation loss가 개선된 후 기다리는 기간
Default: 7
verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
Default: False
delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
Default: 0
path (str): checkpoint저장 경로
Default: 'checkpoint.pt'
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
self.path = path
def __call__(self, val_accuracy, model):
score = val_accuracy
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_accuracy, model)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_accuracy, model)
self.counter = 0
def save_checkpoint(self, val_accuracy, model):
'''validation accuracy가 증가하면 모델을 저장한다.'''
if self.verbose:
print(f'Validation accuracy increased ({self.best_score:.6f} --> {val_accuracy:.6f}). Saving model ...')
torch.save(model.state_dict(), self.path)
❏ 학습코드에 적용 예시 :
( 아래 적용 코드를 이해하기 위한 환경 이해)
✓ model 명 = vit 임.
✓ 데이터 로더 = trainloader, valloader
✓ layers 폴더안에 Earlystopping.py 파일 존재. class명은 EarlyStopping 임.
✓ 디렉토리 구조
- layers
I -Earlystopping.py
I
- train.py
1 ) Val_loss Early Stopping 사용시 적용 예시
# train.py
from layers.Earlystopping import EarlyStopping
###
### 기타 코드들 생략.
### def train, def evaluate, 등등
###
patience=10
early_stopping = EarlyStopping(patience = patience, verbose = True)
os.makedirs('./pt', exist_ok=True)
best_val_loss = float('inf') # Initialize with a large value to ensure the first validation loss will be lower
for epoch in range(1,epochs+1):
train(vit,trainloader,optimizer,log_interval=5)
val_loss,val_accuracy=evaluate(vit,valloader)
print("\n[Epoch: {}],\t Test Loss : {:.4f},\tTest Accuracy :{:.2f} % \n".format
(epoch, val_loss,val_accuracy))
early_stopping(val_loss, vit)
if val_loss < best_val_loss:
best_val_loss = val_loss
# Save the model when the validation loss improves
model_path = f"{'./pt/'}model_epoch_{epoch}_Accuracy_{val_accuracy:.2f}.pt"
torch.save(vit.state_dict(), model_path)
if early_stopping.early_stop:
print("Early stopping")
break
2) Val_acc Early Stopping 사용시 적용 예시
# train.py
from layers.Earlystopping import EarlyStopping
###
### 기타 코드들 생략.
### def train, def evaluate, 등등
###
patience=10
early_stopping = EarlyStopping(patience = patience, verbose = True)
os.makedirs('./pt', exist_ok=True)
best_val_accuracy = 0.0 # 최초의 최고 정확도를 나타내는 값으로 초기화
for epoch in range(1, epochs+1):
train(vit, trainloader, optimizer, log_interval=5)
val_loss, val_accuracy = evaluate(vit, valloader)
print("\n[Epoch: {}],\t Test Loss : {:.4f},\tTest Accuracy : {:.2%}\n".format(epoch, val_loss, val_accuracy))
early_stopping(val_accuracy, vit)
if val_accuracy > best_val_accuracy:
best_val_accuracy = val_accuracy
# Save the model when the validation accuracy improves
model_path = f"./pt/model_epoch_{epoch}_Accuracy_{val_accuracy:.2%}.pt"
torch.save(vit.state_dict(), model_path)
if early_stopping.early_stop:
print("Early stopping")
break
'인공지능 (기본 딥러닝) > 딥러닝 스크래치 코드' 카테고리의 다른 글
[Pytorch 스크래치 코드] 회귀문제 Train, Validation 함수 (1) | 2023.12.17 |
---|---|
[Pytorch 스크래치 코드] 분류문제 Train, Validation 함수 (0) | 2023.12.16 |
[Pytorch 스크래치 코드] Train Test split (1) | 2023.12.10 |
[Pytorch 스크래치 코드] 실험 재현을 위한 Seed 고정 (0) | 2023.12.09 |
[Pytorch 스크래치 코드] Custom Dataset, DataLoader (1) | 2023.11.26 |