PyTorch 之torch.cat()和torch.stack()的理解

看了下官方手册和网上的一些解释,但还是没太弄明白,于是自己下来在草稿纸上画了一下,然后用程序验证,总算把这两个函数搞懂了些。

首先给出两个tensor: a和b,它们的形状都是(1,3,2)

判断tensor维度的方法

看对应的中括号,有几个左中括号就有几个维度,如:

有三个左括号,那么a就有3个维度

判断每个维度有多少个元素

对于dim=0,我们找到第一个左中括号,并且找到与之匹配的右中括号,看他们之中有多少个元素(凡是被中括号包裹的视作1个元素,其实就是看配对的中括号中有几个逗号,其中的元素个数就是逗号个数加1)

tensor.cat()

将两个(或更多)相同形状的tensor按某一维度(dim)进行拼接,操作是:对两个tensor,将需要拼接的维度中的所有元素接在一起,具体示例如下:

  • tensor.cat((a,b), dim=0)

    找到tensor a和tensor b的第0维中所有的元素,将他们拼在一起:

  • tensor.cat((a,b), dim=1)

    同样,找到两个tensor第一维中的所有元素,然后拼在一起:

  • tensor.cat((a,b), dim=2)

通过以上步骤,我们不难得出结论,对于cat操作,输出的tensor的形状,就是输入的tensor形状在指定维度(dim)上相加即可,输出的tensor总的维度数不会改变,改变的只是相应维度上元素的个数而已。

tensor.stack()

已经说过对于cat()操作,输出的tensor总的维度数不会改变,而stack()操作则会改变输出tensor总的维度数。假设输入两个tensor,分别是a,b,他们的形状都是(3, 2),在stack时,我们会先在指定的dim维度上增加一维,然后在增加的这个维度上对这两个tensor做cat操作,具体如下:

  • tensor.stack((a,b), dim=0)

  • tensor.stack((a,b), dim=1)

  • tensor.stack((a,b), dim=2)

所以,对于stack操作,如果我们想知道输出tensor的维度,我们只需要将输入tensor添加一维,然后按照cat()操作来进行就可以了,至于在哪个维度上添加,就取决于参数dim的值。

您的赞赏将会是我前进的巨大动力