딥러닝 실습을 하다 보면 데이터를 합치는 작업이 꼭 필요합니다.
PyTorch에서는 주로 torch.cat()과 torch.stack() 두 가지 함수를 사용하는데, 저도 매번 헷갈려서 이번에 한번 쭉 정리해보려 합니다.
torch.cat()
"이미 있는 축(axis) 방향으로 나란히 연결한다."
- cat은 같은 차원끼리 이어붙이는 함수입니다.
- 차원 수는 변하지 않습니다.
- input들의 shape이 모든 차원에서 일치해야 하며, 붙이려는 차원만 다를 수 있습니다.
1. torch.cat((a, b), dim=0)
a = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # shape: (2, 2)
# dim=0 (행 방향 연결)
out1 = torch.cat((a, b), dim=0)
print(out1) # shape: (4, 2)
2. torch.cat((a, b), dim=1)
a = torch.tensor([[1, 2], [3, 4]]) # (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # (2, 2)
# dim=1 (열 방향 연결)
out2 = torch.cat((a, b), dim=1)
print(out2) # shape: (2, 4)
torch.stack()
"새로운 차원을 만들어 쌓아 올린다."
- stack은 새로운 차원(axis) 을 추가해서 쌓는 함수입니다.
- 차원 수가 1 증가합니다.
- input들의 shape이 완전히 동일해야 합니다.
1. torch.stack((a, b), dim=0)
a = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # shape: (2, 2)
# dim=0 (행 방향 연결)
out3 = torch.stack((a, b), dim=0)
print(out3) # shape: (2, 2, 2)
2. torch.stack((a, b), dim=1)
a = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # shape: (2, 2)
# dim=1 (열 방향 연결)
out4 = torch.stack((a, b), dim=1)
print(out4) # shape: (2, 2, 2)
'Framework > PyTorch' 카테고리의 다른 글
[PyTorch] unsqueeze() vs squeeze() (0) | 2025.04.29 |
---|