【PyTorch学习】PyTorch基础知识
2020-06-28 07:49:00 来源:易采站长站 作者:易采站长站整理
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0,
)
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
Epoch: 0 | Step: 0 | batch x: [6. 8. 9. 2. 5.] | batch y: [5. 3. 2. 9. 6.]Epoch: 0 | Step: 1 | batch x: [10. 4. 7. 3. 1.] | batch y: [ 1. 7. 4. 8. 10.]Epoch: 1 | Step: 0 | batch x: [ 5. 10. 2. 6. 7.] | batch y: [6. 1. 9. 5. 4.]Epoch: 1 | Step: 1 | batch x: [9. 1. 8. 4. 3.] | batch y: [ 2. 10. 3. 7. 8.]Epoch: 2 | Step: 0 | batch x: [ 5. 7. 1. 10. 9.] | batch y: [ 6. 4. 10. 1. 2.]Epoch: 2 | Step: 1 | batch x: [6. 8. 2. 4. 3.] | batch y: [5. 3. 9. 7. 8.]3. torch.utils.data.DataLoader
PyTorch中提供了一个简单的办法来做这个事情,通过torch.utils.data.DataLoader来定义一个新的迭代器,如下:
from torch.utils.data import DataLoaderdataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=defaulf_collate)
其中的参数都很清楚,只有collate_fn是标识如何取样本的,我们可以定义自己的函数来准确地实现想要的功能,默认的函数在一般情况下都是可以使用的。
(需要注意的是,Dataset类只相当于一个打包工具,包含了数据的地址。真正把数据读入内存的过程是由Dataloader进行批迭代输入的时候进行的。)
4. torchvision.datasets.ImageFolder
另外在torchvison这个包中还有一个更高级的有关于计算机视觉的数据读取类:ImageFolder,主要功能是处理图片,且要求图片是下面这种存放形式:
root/dog/xxx.pngroot/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/asd/png
root/cat/zxc.png
之后这样来调用这个类:
from torchvision.datasets import ImageFolderdset = ImageFolder(root='root_path', transform=None, loader=default_loader)
其中 root 需要是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别:transform 和 target_transform 是图片增强,后面我们会详细介绍;loader是图片读取的办法,因为我们读取的是图片的名字,然后通过 loader 将图片转换成我们需要的图片类型进入神经网络。
2.4 nn.Module(模组)
在PyTorch里面编写神经网络,所有的层结构和损失函数都来自于torch.nn,所有的模型构建都是从这个基类nn.Module继承的,于是有了下面这个模板。













闽公网安备 35020302000061号