看完秒懂torch.stack()
2020-06-28 10:23:22 来源:易采站长站 作者:易采站长站整理
torch.stack ()在这里插入图片描述一、准备数据二、dim=0三、dim=1四、dim=2
一、准备数据
首先把基本的数据准备好:
import torch
import numpy as np
# 创建3*3的矩阵,a、b
a=np.array([[1,2,3],[4,5,6],[7,8,9]])
b=np.array([[10,20,30],[40,50,60],[70,80,90]])
# 将矩阵转化为Tensor
a = torch.from_numpy(a)
b = torch.from_numpy(b)
# 打印a、b、c
print(a)
print(b)
output:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]], dtype=torch.int32)
二、dim=0
首先,一起来看看dim=0的时候,结果会是怎么样
d = torch.stack((a, b), dim=0)
print(d)
print(d.size())
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]], [[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]], dtype=torch.int32)
torch.Size([2, 3, 3])
观察结果,可以得出结论:
当
dim = 0,原来的每一个矩阵也变成了一个维度一个矩阵看做一个整体,有几个矩阵,新的维度就是几,第几个矩阵就是第几维;
如下,取出第1维度的矩阵(下标从0开始):
d[0]tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
可以很清楚的看到这就是
stack前的第一个矩阵。三、dim=1
那么,dim=1的时候,结果会是怎么样
d = torch.stack((a, b), dim=1)
print(d)
print(d.size())
tensor([[[ 1, 2, 3],
[10, 20, 30]], [[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]], dtype=torch.int32)
torch.Size([3, 2, 3])
为了观察方便,把
原始数据也再拿过来对照。
# 原始数据
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]], dtype=torch.int32)
可以得出结论:
将每个矩阵的第一行组成第一维矩阵,依次下去,每个矩阵的第n行组成第n维矩阵。size=(n,i,y)













闽公网安备 35020302000061号