<
>

【PyTorch学习】PyTorch基础知识

2020-06-28 07:49:00 来源:易采站长站 作者:易采站长站整理


import torch.nn as nn

class net_name(nn.Module):
def __init__(self, other_arguments):
super(net_name, self).__init__()
self.convl = nn.Conv2d(in_channels, out_channels, kernel_size)
# 其他网路层

def forward(self, x):
x = self.convl(x)
return x

这样就建立了一个计算图,并且这个结构可以复用多次,每次调用就相当于用该计算图定义的相同参数做一次前向传播,这得益于PyTorch的自动求导功能,所以我们不需要自己编写反向传播,而所有的网络层都是由nn这个包得到的,比如线性层nn.Linear。

定义完模型之后,我们需要通过nn这个包来定义损失函数。常见的损失函数都已经定义在了nn中,比如均方误差、多分类的交叉熵,以及二分类的交叉熵等等,调用这些已经定义好的损失函数也很简单:


criterion = nn.CrossEntropyLoss()

loss = criterion(output, target)

criterion = nn.CrossEntropyLoss() loss = criterion(output, target)

2.5 torch.optim(优化)

在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化(或最大化),优化算法就是一种调整模型参数更新的策略。优化算法分为两大类。

1. 一阶优化算法

这种算法使用各个参数的梯度值来更新参数,最常用的一阶优化算法是梯度下降。所谓的梯度就是导数的多变量表达式,函数的梯度形成了一个向量场,同时也是一个方向,这个方向上方向导数最大,且等于梯度。梯度下降的功能是通过寻找最小值,控制方差,更新模型参数,最终使模型收敛,网络的参数更新公式是:

2.6 模型的保存和加载

在PyTorch里面使用torch.save来保存模型的结构和参数,有两种保存方式:

保存整个模型的结构信息和参数信息,保存的对象是模型 model;
保存模型的参数,保存的对象是模型的状态model.state_dict()。

可以这样保存,save的第一个参数是保存对象,第二个参数是保存路径及名称:


torch.save(model, './model.pth')

torch.save(model.state_dict(), './model_state.pth')

加载模型有两种方式对应于保存模型的方式:

加载完整的模型结构和参数信息,使用 load_model=torch.load('model.pth'),在网络较大的时候加载的时间比较长,同时存储空间也比较大;
加载模型参数信息,需要先导入模型的结构,然后通过 model.load_state_dic(torch.load('model_state.pth'))来导入。

暂时禁止评论

微信扫一扫

易采站长站微信账号