PyTorch로 딥러닝 모델을 다루다 보면 차원(Dimension) 을 추가하거나 제거해야 할 때가 자주 있습니다.
이때 자주 등장하는 것이 바로 unsqueeze() 와 squeeze() 함수입니다.
torch.unsqueeze()
" 차원 하나를 살짝 껴 넣는다!"
- unsqueeze는 특정 위치(dim) 에 새로운 차원을 추가합니다.
- 추가된 차원은 크기가 1입니다.
1. torch.unsqueeze(x, dim=0)
x = torch.tensor([1, 2, 3])
# 0번째 축에 차원 추가
x_unsqueezed = torch.unsqueeze(x, dim=0)
print(x.shape) # torch.Size([3])
print(x_unsqueezed.shape) # torch.Size([1, 3])
2. torch.unsqueeze(x, dim=1)
x = torch.tensor([1, 2, 3])
# 1번째 축에 차원 추가
x_unsqueezed = torch.unsqueeze(x, dim=1)
print(x.shape) # torch.Size([3])
print(x_unsqueezed.shape) # torch.Size([3, 1])
torch.squeeze()
" 쓸모없는 1을 짜내서 없애버린다!"
- squeeze는 크기가 1인 차원을 찾아서 제거합니다.
- 특정 차원을 지정할 수도 있습니다.
1. torch.squeeze(x) : 크기 1인 차원을 전부 제거
x = torch.tensor([[[1], [2], [3]]])
# 모든 크기 1 차원 제거
x_squeezed = torch.squeeze(x)
print(x.shape) # torch.Size([1, 3, 1])
print(x_squeezed.shape) # torch.Size([3])
2. torch.squeeze(x, dim=0) : 지정한(0번째) 차원만 제거
x = torch.tensor([[[1], [2], [3]]])
# dim=0 (첫 번째 축만 제거)
x_squeezed_dim0 = torch.squeeze(x, dim=0)
print(x.shape) # torch.Size([1, 3, 1])
print(x_squeezed_dim0.shape) # torch.Size([3, 1])
'Framework > PyTorch' 카테고리의 다른 글
[PyTorch] cat() vs stack() (0) | 2025.04.29 |
---|