시작은 미약하였으나 , 그 끝은 창대하리라

[1차원 추가 및 제거] Pytorch squeeze / unsqueeze 본문

인공지능/딥러닝 및 파이토치 기타 정리

[1차원 추가 및 제거] Pytorch squeeze / unsqueeze

애플파ol 2023. 12. 23. 13:09

상황: 딥러닝을 하다보면 1차원을 추가 및 제거를 위한 단계가 필요하게 된다. 

 

1. squeeze : 1차원 제거 역할.

import torch


x= torch.ones(5,4,1,4,1)

x1 = x.squeeze() # 모든 1차원 제거
print(x1.size()) # torch.Size([5, 4, 4])

x2= x.squeeze(dim = 2)
print(x2.size())   # torch.Size([5, 4, 4, 1])

x3= x.squeeze(dim = -1) # dim=4 와 같음
print(x3.size())  # torch.Size([5, 4, 1, 4])

x4= x.squeeze(dim = 1)  # 잘못된 차원 삭제 불가능.
print(x4.size())  # torch.Size([5, 4, 1, 4, 1])

참고 : torch.squeeze(tensor, dim) 형태도 가능.

 

2. unsqueeze : 1차원 추가 역할.

import torch

# 초기 텐서 생성
x = torch.ones(7,3,4)

# unsqueeze로 차원 추가
x1 = x.unsqueeze(dim=2)  # 2번째 차원에 1차원 추가
print(x1.size())  # torch.Size([7, 3, 1, 4])

x2 = x.unsqueeze(dim=-1)  # 마지막 차원에 1차원 추가 
print(x2.size())  # torch.Size([7, 3, 4, 1])

x3 = x.unsqueeze(dim=0)  # 첫 번째 차원에 1차원 추가
print(x3.size())  # torch.Size([1, 7, 3, 4])

x4 = x.unsqueeze(dim=4)  # 차원 넘어가면 오류발생.
print(x4.size())  # 오류 발생.

참고 : torch.unsqueeze(tensor, dim) 형태도 가능.

Comments