시작은 미약하였으나 , 그 끝은 창대하리라

[Pytorch 스크래치 코드] Train Test split 본문

인공지능/딥러닝 스크래치 코드

[Pytorch 스크래치 코드] Train Test split

애플파ol 2023. 12. 10. 20:03

 

✓ Tensorflow 에는 Train Test split 이라는 API가 존재하지만 Pytorch에서는 존재하지 않는다.

✓ 대표적인 데이터셋은 처음부터 Train, Validation, Test 를 주어서 데이터를 분리할 필요가 없다, 하지만 우리가 직접 데이터를 수집해서 학습을 한다면 Train, Validation, Test 을 나눠야한다. 

✓ 이를 위해 이번 글에서는 Pytorch 를 사용하여 Train Validaion Test로 나누는 방법을 소개한다. 

 

➢  주의 사항 : 실험을 동일하게 재현하고 싶으면 random seed 값을 고정하고 돌려야함. 

 

❏ Train, Validation, Test 로 나누기

from torch.utils.data import DataLoader, random_split


# dataset Load
dataset=CustomDataset("./dataset.csv")

# 전체 데이터셋 크기
dataset_size = len(dataset)

# 훈련 데이터셋 크기
train_ratio = 0.6
train_size = int(train_ratio * dataset_size)

# 검증 데이터셋 크기
val_ratio = 0.2
val_size = int(val_ratio * dataset_size)

# 테스트 데이터셋 크기 (나머지 부분)
test_size = dataset_size - train_size - val_size

# 훈련, 검증, 테스트 데이터셋으로 나누기
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# DataLoader를 사용하여 데이터를 배치 단위로 로드
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

 

❏ Train, Validation 로 나누기 (다른 지정된 Test 셋이 있을때)

from torch.utils.data import DataLoader, random_split

# dataset Load
dataset=CustomDataset("./dataset.csv")

# 전체 데이터셋 크기
dataset_size = len(dataset)

# train,test 분리 
val_ratio = 0.2 
val_size = int(val_ratio * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])


# Create dataloaders for train and validation
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

 

❏ seed 값 고정: https://put-idea.tistory.com/88

 

[Pytorch 스크래치 코드] 실험 재현을 위한 Seed 고정

✓ ablation study를 진행하다보면 다른 seed값이 모두 동일하게 고정적으로 설정이 되있어야 한다. ✓ 이를 방지하기 위해 우리는 seed 값 고정을 하여 수행한다. ❏ 실험 재현을 위한 seed 값 고정 코

put-idea.tistory.com

 

 

참고하면 좋은 글:  https://put-idea.tistory.com/85

 

[Pytorch 스크래치 코드] Custom Dataset, DataLoader

➢ Pytorch 에서 dataset을 쉽게 다룰 수 있도록 모듈을 제공하고 있다. 아래와 같이 두가지 Step으로 구성 된다. ❏ (Step1) CustomDataset 뼈대 from torch.utils.data import Dataset # 데이터셋 상속 class CustomDataset_na

put-idea.tistory.com

 

Comments