<
>

看完秒懂torch.stack()

2020-06-28 10:23:22 来源:易采站长站 作者:易采站长站整理

torch.stack ()在这里插入图片描述一、准备数据二、dim=0三、dim=1四、dim=2
在这里插入图片描述
一、准备数据

首先把基本的数据准备好:

import torch
import numpy as np
# 创建3*3的矩阵,a、b
a=np.array([[1,2,3],[4,5,6],[7,8,9]])
b=np.array([[10,20,30],[40,50,60],[70,80,90]])
# 将矩阵转化为Tensor
a = torch.from_numpy(a)
b = torch.from_numpy(b)
# 打印a、b、c
print(a)
print(b)

output:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]], dtype=torch.int32)


二、dim=0

首先,一起来看看dim=0的时候,结果会是怎么样

d = torch.stack((a, b), dim=0)
print(d)
print(d.size())
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],

[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]], dtype=torch.int32)
torch.Size([2, 3, 3])

观察结果,可以得出结论:

dim = 0
,原来的每一个矩阵也变成了一个维度
一个矩阵看做一个整体,有几个矩阵,新的维度就是几,第几个矩阵就是第几维;

如下,取出第1维度的矩阵(下标从0开始):

d[0]tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)

可以很清楚的看到这就是

stack
前的第一个矩阵。

三、dim=1

那么,dim=1的时候,结果会是怎么样

d = torch.stack((a, b), dim=1)
print(d)
print(d.size())
tensor([[[ 1, 2, 3],
[10, 20, 30]],

[[ 4, 5, 6],
[40, 50, 60]],

[[ 7, 8, 9],
[70, 80, 90]]], dtype=torch.int32)
torch.Size([3, 2, 3])

为了观察方便,把

原始数据
也再拿过来对照。

# 原始数据
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]], dtype=torch.int32)

可以得出结论:

将每个矩阵的第一行组成第一维矩阵,依次下去,每个矩阵的第n行组成第n维矩阵。size=(n,i,y)

暂时禁止评论

微信扫一扫

易采站长站微信账号