인공지능 (Deep Learning)/딥러닝 및 파이토치 기타 정리
[1차원 추가 및 제거] Pytorch squeeze / unsqueeze
애플파ol
2023. 12. 23. 13:09
상황: 딥러닝을 하다보면 1차원을 추가 및 제거를 위한 단계가 필요하게 된다.
1. squeeze : 1차원 제거 역할.
import torch
x= torch.ones(5,4,1,4,1)
x1 = x.squeeze() # 모든 1차원 제거
print(x1.size()) # torch.Size([5, 4, 4])
x2= x.squeeze(dim = 2)
print(x2.size()) # torch.Size([5, 4, 4, 1])
x3= x.squeeze(dim = -1) # dim=4 와 같음
print(x3.size()) # torch.Size([5, 4, 1, 4])
x4= x.squeeze(dim = 1) # 잘못된 차원 삭제 불가능.
print(x4.size()) # torch.Size([5, 4, 1, 4, 1])
참고 : torch.squeeze(tensor, dim) 형태도 가능.
2. unsqueeze : 1차원 추가 역할.
import torch
# 초기 텐서 생성
x = torch.ones(7,3,4)
# unsqueeze로 차원 추가
x1 = x.unsqueeze(dim=2) # 2번째 차원에 1차원 추가
print(x1.size()) # torch.Size([7, 3, 1, 4])
x2 = x.unsqueeze(dim=-1) # 마지막 차원에 1차원 추가
print(x2.size()) # torch.Size([7, 3, 4, 1])
x3 = x.unsqueeze(dim=0) # 첫 번째 차원에 1차원 추가
print(x3.size()) # torch.Size([1, 7, 3, 4])
x4 = x.unsqueeze(dim=4) # 차원 넘어가면 오류발생.
print(x4.size()) # 오류 발생.
참고 : torch.unsqueeze(tensor, dim) 형태도 가능.