<
>

Pytorch实现神经网络的分类方式

2020-06-25 08:09:51 来源:易采站长站 作者:易采站长站整理


#6.迭代训练
for epoch in range(2):
for step, (batch_x, batch_y) in enumerate(loader):
out = net(batch_x)#输入训练集,获得当前迭代输出值
loss = loss_func(out, batch_y)#获得当前迭代的损失

optimizer.zero_grad()#清除上次迭代的更新梯度
loss.backward()#反向传播
optimizer.step()#更新权重

if step%200==0:
plt.cla()#清空之前画布上的内容
entire_out = net(x)#测试整个训练集
#获得当前softmax层最大概率对应的索引值
pred = torch.max(F.softmax(entire_out), 1)[1] #将二维压缩为一维
pred_y = pred.data.numpy().squeeze()
label_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
accuracy = sum(pred_y == label_y)/y.size()
print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
#在指定位置添加文本
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
plt.pause(2)#图像显示时间

#7.保存模型结构和参数
torch.save(net, 'net.pkl')
#7.只保存模型参数
# torch.save(net.state_dict(), 'net_param.pkl')

plt.ioff()#关闭画布
plt.show()

if __name__ == '__main__':
save()

2. 读取已训练好的模型测试数据


import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F

#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )

x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())

#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )

#将Tensor放入Variable中
x,y = Variable(x), Variable(y)

#载入模型和参数
def restore_net():
net = torch.load('net.pkl')
#获得载入模型的预测输出
pred = net(x)
# 获得当前softmax层最大概率对应的索引值
pred = torch.max(F.softmax(pred), 1)[1] # 将二维压缩为一维
pred_y = pred.data.numpy().squeeze()
label_y = y.data.numpy()
accuracy = sum(pred_y == label_y) / y.size()
print("准确率为:",accuracy)
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
net = torch.nn.Sequential(
torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
torch.nn.ReLU(),#隐层非线性化

暂时禁止评论

微信扫一扫

易采站长站微信账号