<
>

【PyTorch学习】PyTorch基础知识

2020-06-28 07:49:00 来源:易采站长站 作者:易采站长站整理


x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
print(y)

y.backward(torch.FloatTensor([1, 0.1, 0.01]))
print(x.grad)


tensor([ 2.3863, 1.3822, -2.5512], grad_fn=)
tensor([2.0000, 0.2000, 0.0200])

相当于给出了一个三维向量去做运算,这时候得到的结果y就是一个向量,这里对这个向量求导就不能直接写成 y.backward(),这样程序是会报错的。这个时候需要传入参数声明,比如y.backward(torch.FloatTensor([1, 1, 1])),这样得到的结果就是它们每个分量的梯度,或者可以传入y.backward(torch.FloatTensor([1, 0.1, 0.01])),这样得到的梯度就是它们原本的梯度分别乘上1,0.1和0.01。

2.3 Dataset(数据集)

在处理任何机器学习问题之前都需要数据读取,并进行预处理。PyTorch提供了很多工具使得数据的读取和预处理变得很容易。接下来介绍 Dataset,TensorDataset,DataLoader,ImageFolder的简单用法。

1. torch.utils.data.Dataset

它是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__len__和__getitem__这两个函数:


from torch.utils.data import Dataset
import pandas as pd

class myDataset(Dataset):
def __init__(self, csv_file, txt_file, root_dir, other_file):
self.csv_data = pd.read_csv(csv_file)
with open(txt_file, 'r') as f:
data_list = f.readlines()
self.txt_data = data_list
self.root_dir = root_dir

def __len__(self):
return len(self.csv_data)

def __getitem__(self, idx):
data = (self.csv_data[idx], self.txt_data[idx])
return data

通过上面的方式,可以定义我们需要的数据类,可以通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者是多线程去读取数据。

2. torch.utils.data.TensorDataset

它继承自Dataset,新版把之前的data_tensor和target_tensor去掉了,输入变成了可变参数,也就是我们平常使用*args。


# 原版使用方法

train_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

# 新版使用方法

train_dataset = Data.TensorDataset(x, y)


import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(

暂时禁止评论

微信扫一扫

易采站长站微信账号