Pytorch学习之torch用法----比较操作(Comparison Ops)
2020-06-28 12:01:17 来源:易采站长站 作者:王旭
1. torch.eq(input, other, out=None)
说明: 比较元素是否相等,第二个参数可以是一个数,或者是第一个参数同类型形状的张量
参数:
input(Tensor) ---- 待比较张量
other(Tenosr or float) ---- 比较张量或者数
out(Tensor,可选的) ---- 输出张量
返回值: 一个torch.ByteTensor张量,包含了每个位置的比较结果(相等为1,不等为0)
>>> a = torch.Tensor([[1, 2], [3, 4]]) >>> b = torch.Tensor([[1, 1], [4, 4]]) >>> torch.eq(a, b) tensor([[1, 0], [0, 1]], dtype=torch.uint8)
2. torch.equal(tensor1, tensor2, out=None)
说明: 如果两个张量有相同的形状和元素值,则返回true,否则False
参数:
tensor1(Tenosr) ---- 比较张量1
tensor2(Tensor) ---- 比较张量2
out(Tensor,可选的) ---- 输出张量
>>> a = torch.Tensor([1, 2]) >>> b = torch.Tensor([1, 2]) >>> torch.equal(a, b) True
3. torch.ge(input, other, out=None)
说明: 逐元素比较input和other,即是否input >= other。
参数:
input(Tensor) ---- 待对比的张量
other(Tensor or float) ---- 对比的张量或float值
out(Tensor,可选的) ---- 输出张量,
>>> a = torch.Tensor([[1, 2], [3, 4]]) >>> b = torch.Tensor([[1, 1], [4, 4]]) >>> torch.ge(a, b) tensor([[1, 1], [0, 1]], dtype=torch.uint8)
4. torch.gt(input, other, out=None)
说明: 逐元素比较input和other,即是否input > other
参数:
input(Tensor) ---- 要对比的张量
other(Tensor or float) ---- 要对比的张量或float值
out(Tensor,可选的) ---- 输出张量
>>> a = torch.Tensor([[1, 2], [3, 4]]) >>> b = torch.Tensor([[1, 1], [4, 4]]) >>> torch.gt(a, b) tensor([[0, 1], [0, 0]], dtype=torch.uint8)
5. torch.kthvalue(input, k, dim=None, out=None)
说明: 取输入张量input指定维度上第k个最小值。如果不指定dim。默认为最后一维。返回一个元组(value, indices), 其中indices是原始输入张量中沿dim维的第k个最小值下标。
参数:
input(Tensor) ---- 要对比的张量
k(int) ---- 第k个最小值
dim(int, 可选的) ---- 沿着此维度进行排序
out(tuple,可选的) ---- 输出元组
>>> x = torch.arange(1, 6) >>> x tensor([1, 2, 3, 4, 5]) >>> torch.kthvalue(x, 4) torch.return_types.kthvalue( values=tensor(4), indices=tensor(3)) >>> torch.kthvalue(x, 1) torch.return_types.kthvalue( values=tensor(1), indices=tensor(0))
6. torch.le(input, other, out=None)
说明: 逐元素比较input和other,即是否input <= other.
参数:
input(Tenosr) ---- 要对比的张量
other(Tensor or float) ---- 对比的张量或float值













闽公网安备 35020302000061号