PyTorch Example of Text Classification and Generation with CharRNN
rnn = nn.RNN(input_dim, hidden_dim, num_layers=1, batch_first=False, bidirectional=False)
input_dim: number of features per input token; when using an embedding, this is the embedding dimension
hidden_dim: number of hidden units, which determines the size of the RNN's output
num_layers: number of stacked RNN layers
batch_first: if True the first dimension is batch, otherwise the first dimension is seq_len; defaults to False
bidirectional: whether the RNN is bidirectional; defaults to False
output, hidden = rnn(input, hidden)
input: a batch of input data, shape [batch, seq_len, input_dim] (with batch_first=True)
hidden: the hidden state from the previous time step, shape [num_layers * num_directions, batch, hidden_dim]
output: the output at each time step, shape [batch, seq_len, num_directions * hidden_dim]
import torch
from torch import nn

vocab_size = 5
embed_dim = 3
hidden_dim = 8
embedding = nn.Embedding(vocab_size, embed_dim)
rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
sents = [[1, 2, 4],
         [2, 3, 4]]
embeded = embedding(torch.LongTensor(sents))  # shape=(batch=2, seq_len=3, embed_dim=3)
h0 = torch.zeros(1, embeded.size(0), hidden_dim)  # shape=(num_layers*num_directions, batch, hidden_dim)
out, hidden = rnn(embeded, h0)  # out.shape=(2,3,8), hidden.shape=(1,2,8)
print(out, hidden)
'''
tensor([[[-0.1556, -0.2721, 0.1485, -0.2081, -0.2231, -0.1459, -0.0319, 0.2617],
[-0.0274, 0.1561, -0.0509, -0.1723, -0.2678, -0.2616, 0.0786, 0.4124],
[ 0.2346, 0.4487, -0.1409, -0.0807, -0.0232, -0.4975, 0.4244, 0.8337]],
[[ 0.0879, 0.1122, 0.1502, -0.3033, -0.2715, -0.1191, 0.1367, 0.5275],
[ 0.2258, 0.4395, -0.1365, 0.0135, -0.0777, -0.5221, 0.4683, 0.8115],
[ 0.0158, 0.3471, 0.0742, -0.0550, -0.0098, -0.5521, 0.5923, 0.8782]]], grad_fn=<TransposeBackward0>)
tensor([[[ 0.2346, 0.4487, -0.1409, -0.0807, -0.0232, -0.4975, 0.4244, 0.8337],
[ 0.0158, 0.3471, 0.0742, -0.0550, -0.0098, -0.5521, 0.5923, 0.8782]]], grad_fn=<ViewBackward>)
'''
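When bidirectional=True, the forward and backward outputs are concatenated, so the last dimension of output doubles to num_directions*hidden_dim and the first dimension of hidden gains a factor of 2. A minimal sketch, reusing the embeded tensor from the example above (the variable name birnn is just for illustration):

birnn = nn.RNN(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
h0 = torch.zeros(2, embeded.size(0), hidden_dim)  # num_layers*num_directions = 1*2 = 2
out, hidden = birnn(embeded, h0)
print(out.shape)     # torch.Size([2, 3, 16]), last dim = num_directions*hidden_dim
print(hidden.shape)  # torch.Size([2, 2, 8])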
2.3 nn.LSTM
LSTM is a variant of the RNN that adds a memory cell to the unit structure, as shown in the figure below:

[Figure: LSTM cell structure]

Each cell takes an input x, the previous hidden state h, and the previous memory c, and outputs y together with the current hidden state and the current memory c. Its usage is similar to that of the RNN.
lstm = nn.LSTM(input_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=False)
input_dim: number of features per input word; when using an embedding, this is the embedding dimension
hidden_dim: number of hidden units
output, (hidden, cell) = lstm(input, (hidden, cell))
input: a batch of input data, shape [batch, seq_len, input_dim] (with batch_first=True)
hidden: the hidden state, shape [num_layers * num_directions, batch, hidden_dim]
cell: the memory (cell) state, shape [num_layers * num_directions, batch, hidden_dim]
output: the output at each time step, shape [batch, seq_len, num_directions * hidden_dim]
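A minimal runnable sketch of nn.LSTM in the same toy setup as the nn.RNN example above (same vocab_size, embed_dim, and hidden_dim values); the only difference in usage is the extra cell state c0:

import torch
from torch import nn

vocab_size = 5
embed_dim = 3
hidden_dim = 8
embedding = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
sents = [[1, 2, 4],
         [2, 3, 4]]
embeded = embedding(torch.LongTensor(sents))      # (batch=2, seq_len=3, embed_dim=3)
h0 = torch.zeros(1, embeded.size(0), hidden_dim)  # (num_layers*num_directions, batch, hidden_dim)
c0 = torch.zeros(1, embeded.size(0), hidden_dim)  # same shape as h0
out, (hidden, cell) = lstm(embeded, (h0, c0))
print(out.shape)     # torch.Size([2, 3, 8])
print(hidden.shape)  # torch.Size([1, 2, 8])
print(cell.shape)    # torch.Size([1, 2, 8])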