Pytorch实现基于CharRNN的文本分类与生成示例
2020-06-25 08:07:59 来源:易采站长站 作者:易采站长站整理
2.4 nn.GRU
GRU也是一种RNN单元,但它比LSTM简化许多,普通的GRU单元结构如下图所示:

每个单元存在输入x与上一时刻的隐层状态h,输出有y与当前时刻的隐层状态。
rnn = GRU(input_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional)
input_dim 输入word的特征数量,使用embeding时为嵌入的维度
hidden_dim 隐层的单元数
output, hidden = rnn(input, hidden)
input 一批输入数据,shape为[batch, seq_len, input_dim]hidden 上一时刻的隐层状态,shape为[num_layers*num_directions, batch, hidden_dim]output 当前时刻的输出,shape为[batch, seq_len, num_directions*hidden_size]
2.5 损失函数
MSELoss均方误差

输入x,y可以是任意的shape,但要保持相同的shape
CrossEntropyLoss 交叉熵误差

x : 包含每个类的得分,2-D tensor, shape=(batch, n)
class: 长度为batch 的 1D tensor,每个数值为类别的索引(0到 n-1)
3 字符级RNN的分类应用
这里先介绍字符极词向量的训练与使用。语料库使用nltk的names语料库,训练根据人名预测对应的性别,names语料库有两个分类,female与male,每个分类下对应约4000个人名。这个语料库是比较适合字符级RNN的分类应用,因为人名比较短,不能再做分词以使用词向量。
首次使用nltk的names语料库要先下载下来,运行代码nltk.download(‘names’)即可。
字符级RNN模型的词汇表很简单,就是单个字符的集合,对于英文来说,只有26个字母,外加空格等会出现在名字中间的字符,见第14行代码。出于简化的目的,所有名字统一转换为小写。
神经网络很简单,一层RNN网络,用于学习名字序列的特征。一层全连接网络,用于从将高维特征映射到性别的二分类上。这部分代码由CharRNN类实现。这里没有使用embeding层,而是使用字符的one-hot编码,当然使用Embeding也是可以的。
网络的训练和使用封装为Model类,提供三个方法。train(), evaluate(),predict()分别用于训练,评估和预测使用。具体见下面的代码及注释。
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import sklearn
import string
import random
nltk.download('names')
from nltk.corpus import namesUSE_CUDA = torch.cuda.is_available()













闽公网安备 35020302000061号