본문 바로가기
Framework/PyTorch

[PyTorch] cat() vs stack()

by ngool 2025. 4. 29.

딥러닝 실습을 하다 보면 데이터를 합치는 작업이 꼭 필요합니다.
PyTorch에서는 주로 torch.cat()torch.stack() 두 가지 함수를 사용하는데, 저도 매번 헷갈려서 이번에 한번 쭉 정리해보려 합니다.


torch.cat()

"이미 있는 축(axis) 방향으로 나란히 연결한다."
  • cat은 같은 차원끼리 이어붙이는 함수입니다.
  • 차원 수는 변하지 않습니다.
  • input들의 shape이 모든 차원에서 일치해야 하며, 붙이려는 차원만 다를 수 있습니다.

1. torch.cat((a, b), dim=0)

a = torch.tensor([[1, 2], [3, 4]])  # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]])  # shape: (2, 2)

# dim=0 (행 방향 연결)
out1 = torch.cat((a, b), dim=0)
print(out1)  # shape: (4, 2)

 

2. torch.cat((a, b), dim=1)

a = torch.tensor([[1, 2], [3, 4]])   # (2, 2)
b = torch.tensor([[5, 6], [7, 8]])   # (2, 2)

# dim=1 (열 방향 연결)
out2 = torch.cat((a, b), dim=1)
print(out2)  # shape: (2, 4)


torch.stack()

"새로운 차원을 만들어 쌓아 올린다."
  • stack은 새로운 차원(axis)추가해서 쌓는 함수입니다.
  • 차원 수가 1 증가합니다.
  • input들의 shape이 완전히 동일해야 합니다.

1. torch.stack((a, b), dim=0)

a = torch.tensor([[1, 2], [3, 4]])  # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]])  # shape: (2, 2)

# dim=0 (행 방향 연결)
out3 = torch.stack((a, b), dim=0)
print(out3)  # shape: (2, 2, 2)

 

2. torch.stack((a, b), dim=1)

a = torch.tensor([[1, 2], [3, 4]])  # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]])  # shape: (2, 2)

# dim=1 (열 방향 연결)
out4 = torch.stack((a, b), dim=1)
print(out4)  # shape: (2, 2, 2)

'Framework > PyTorch' 카테고리의 다른 글

[PyTorch] unsqueeze() vs squeeze()  (0) 2025.04.29