<
>

torch.stack()的官方解释,详解以及例子

2020-06-28 08:15:16 来源:易采站长站 作者:易采站长站整理

可以直接看最下面的例子,再回头看前面的解释

pytorch
中,常见的拼接函数主要是两个,分别是:

stack()

cat()

实际使用中,这两个函数互相辅助:关于

cat()
参考这个torch.cat(),但是本文主要说
stack()

前言

函数的意义:使用

stack
是为了保留–[1. 序列(先后)] 和 [2. 张量矩阵] 信息, 常出现在自然语言处理(
NLP
)和图像卷积神经网络(
CV
)中。

1. stack()官方解释

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠

outputs = torch.stack(inputs, dim=0)  # → Tensor

参数

inputs : 待连接的张量序列。
注:

python
的序列数据只有
list
tuple

dim : 新的维度, 必须在

0
len(outputs)
之间。
注:
len(outputs)
是生成数据的维度大小,也就是
outputs
的维度值。

2. 重点

函数中的输入

inputs
只允许是序列;且序列内部的张量元素,必须
shape
相等

—-举例:

[tensor_1, tensor_2,..]
或者
(tensor_1, tensor_2,..)
,且必须
tensor_1.shape == tensor_2.shape

dim
是选择生成的维度,必须满足
0<=dim<len(outputs)
len(outputs)
是输出后的
tensor
的维度大小

不懂的看例子,再回过头看就懂了。

3. 例子

按下面的三步:准备数据,合成

inputs
,查看结果。

1.准备数据,每个的

shape
              
暂时禁止评论

微信扫一扫

易采站长站微信账号