torch.cat()函数的官方解释,详解以及例子
2020-06-28 09:36:12 来源:易采站长站 作者:易采站长站整理
可以直接看最下面的例子,再回头看前面的解释,就很明白了。
在
pytorch中,常见的拼接函数主要是两个,分别是:
stack()
cat()一般
torch.cat()是为了把函数
torch.stack()得到
tensor进行拼接而存在的。区别参考链接torch.stack(),但是本文主要说
cat()。前言
和
python中的内置函数
cat(), 在使用和目的上,是没有区别的。1. cat()官方解释
—-
torch.cat(inputs, dim=0) → Tensor函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。
outputs = torch.cat(inputs, dim=0) # → Tensor
参数
inputs : 待连接的张量序列,可以是任意相同
Tensor类型的python 序列dim : 选择的扩维, 必须在
0到
len(inputs[0])之间,沿着此维连接张量序列。2. 重点
输入数据必须是序列,序列中数据是任意相同的
shape的同类型
tensor维度不可以超过输入数据的任一个张量的维度
3.举例子
准备数据,每个的
shape都是
[2,3]
# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape # torch.Size([2, 3])
合成
inputs
'inputs为2个形状为[2 , 3]的矩阵 '
inputs = [x1, x2]print(inputs)
'打印查看'
[tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32),
tensor([[12, 22, 32],
[22, 32, 42]], dtype=torch.int32)]3.查看结果, 测试不同的
dim拼接结果
In [1]: torch.cat(inputs, dim=0).shape
Out[1]: torch.Size([4, 3])In [2]: torch.cat(inputs, dim=1).shape
Out[2]: torch.Size([2, 6])
In [3]: torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
暂时禁止评论













闽公网安备 35020302000061号