<
>

torch.stack()的官方解释,详解以及例子

2020-06-28 08:15:16 来源:易采站长站 作者:易采站长站整理

都是
[2,3]

# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape # torch.Size([2, 3])
# x3
x3 = torch.tensor([[13,23,33],[23,33,43]],dtype=torch.int)
x3.shape # torch.Size([2,3])
# x4
x4 = torch.tensor([[14,24,34],[24,34,44]],dtype=torch.int)
x4.shape # torch.Size([2,3])

2.合成

inputs

'inputs为4个形状为[2 , 3]的矩阵 '
inputs = [x1, x2, x3, x4]print(inputs)
# 打印看看结构。是4个张量
[tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32),
tensor([[12, 22, 32],
[22, 32, 42]], dtype=torch.int32),
tensor([[13, 23, 33],
[23, 33, 43]], dtype=torch.int32),
tensor([[14, 24, 34],
[24, 34, 44]], dtype=torch.int32)]

3.查看结果, 测试不同的

dim
拼接结果

'选择的 0<=dimlen(outputs),所以报错'
In [4]: torch.stack(inputs, dim=3).shape
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

大家可以复制代码运行一下就会发现:这个拼接后的维度大小4根据不同的

dim
一直变化。

dimshape
0[4, 2, 3]
1[2, 4, 3]
2[2, 3, 4]
3溢出报错

4. 总结

函数作用:
函数

stack()
序列数据内部的张量进行扩维拼接,指定维度由我们选择、大小是生成后数据的维度区间。

存在意义:
在自然语言处理和卷及神经网络中, 通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用

stack

研究自然语言处理的同学一般知道,在循环神经网络中,网络的输出数据通常是:包含了

n
个数据大小
[batch_size, num_outputs]
list
,这个和
[n, batch_size, num_outputs]
是完全不一样的!!!!不利于计算,需要使用
stack
进行拼接,保留–[1.时间步]和–[2.张量的矩阵乘积属性]。

作者:模糊包

暂时禁止评论

微信扫一扫

易采站长站微信账号