<
>

PyTorch实现基于CharRNN的文本分类与生成示例

2020-06-25 08:07:59 来源:易采站长站 作者:易采站长站整理


name, sexy = random.choice(test_set)
name_tensor = torch.tensor([name], dtype=torch.float, device=device)

pred = self.model(name_tensor)
if torch.argmax(pred).item() == sexy:
correct += 1

print('Evaluating: test accuracy is {}%'.format(correct/10.0))

def predict(self, name):
    """Classify *name* and print "<name> is female" or "<name> is male".

    The name is lower-cased and featurized with ``name2vec`` (defined
    elsewhere in the file; its output shape is not visible here), wrapped
    as a batch of one float tensor on ``device``, and passed through
    ``self.model`` with gradient tracking disabled. An arg-max index of 0
    is reported as female; any other index as male. Returns None.
    """
    batch = torch.tensor(
        [name2vec(name.lower())], dtype=torch.float, device=device
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = self.model(batch)
    label = 'female' if torch.argmax(scores).item() == 0 else 'male'
    print('{} is {}'.format(name, label))

if __name__ == "__main__":
    # NOTE(review): the meaning of the argument 10 is defined by
    # Model.__init__, which is not shown here — confirm (e.g. epoch count).
    model = Model(10)
    data_set = load_data()
    # NOTE(review): assumes sklearn.model_selection is imported in the
    # (unseen) header of this script. No random_state is passed, so the
    # split — and the reported accuracy — vary from run to run.
    train, test = sklearn.model_selection.train_test_split(data_set)
    model.train(train)
    model.evaluate(test)
    for sample_name in ("Jim", "Kate"):
        model.predict(sample_name)

    # Example output from one run:
    #   Evaluating: test accuracy is 82.6%
    #   Jim is male
    #   Kate is female

4 基于字符级RNN的文本生成

文本生成的思想是：通过让神经网络学习预测下一个输出字符来训练权重参数。这里我们仍使用names语料库，尝试训练一个能生成指定性别人名的神经网络。与分类不同的是，分类只计算最终状态输出的误差，而生成要计算序列每一步输出上的误差，因此训练时要把字符逐个输入到网络中。由于是根据性别来生成人名，因此把性别的one-hot向量拼接（concat）到输入数据里，作为训练数据的一部分。

模型由类CharRNN实现，模型的训练和使用由Model类实现，提供了train()和sample()两个方法：前者用于训练模型，后者用于从训练好的模型中采样生成人名。


# coding=utf-8
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import string
import random
import nltk
nltk.download('names')
from nltk.corpus import names

USE_CUDA = torch.cuda.is_available()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# Character vocabulary: lowercase letters, hyphen, space and apostrophe,
# followed by "!", which is used as the end-of-name marker.
chars = string.ascii_lowercase + '-' + ' ' + "'" + '!'

# NOTE(review): presumably the CharRNN hidden-state size — confirm at the
# point of use (not visible here).
hidden_dim = 128
# One output unit per vocabulary character.
output_dim = len(chars)

# Encode a name as a one-hot matrix, one row per character:
# "abc" -> [[1, 0, 0, ...], [0, 1, 0, ...], [0, 0, 1, ...]]
def name2input(name):
    """Return a one-hot int64 matrix of shape (len(name), len(chars)).

    Row ``i`` has a single 1 at the position of ``name[i]`` in ``chars``.

    Raises:
        ValueError: if *name* contains a character that is not in ``chars``
            (raised by ``str.index``).
    """
    # Iterating a str yields one-character strings, so no empty-string
    # filtering is needed.
    ids = [chars.index(c) for c in name]
    # np.int64 instead of the np.long alias, which was removed in
    # NumPy 1.24 (and only means C long, not int64, in NumPy 2.x).
    a = np.zeros(shape=(len(ids), len(chars)), dtype=np.int64)
    for row, idx in enumerate(ids):
        a[row, idx] = 1
    return a

# Encode a name as a list of character indices: "abc" -> [0, 1, 2]
def name2target(name):
    """Return the index in ``chars`` of every character of *name*, in order.

    Raises:
        ValueError: if *name* contains a character that is not in ``chars``
            (raised by ``str.index``).
    """
    return [chars.index(c) for c in name]

# Encode gender as a one-hot row: female (0) -> [[1, 0]], male (1) -> [[0, 1]]
def sexy2input(sexy):
    """Return a (1, 2) int64 one-hot array for the gender label *sexy*.

    Args:
        sexy: 0 for female, 1 for male (the labels produced by load_data).

    Raises:
        ValueError: if *sexy* is not 0 or 1. Without this check a value of
            -1 would silently be encoded as male via negative indexing.
    """
    if sexy not in (0, 1):
        raise ValueError('sexy must be 0 (female) or 1 (male), got {!r}'.format(sexy))
    # np.int64 instead of the np.long alias, which was removed in NumPy 1.24.
    a = np.zeros(shape=(1, 2), dtype=np.int64)
    a[0, sexy] = 1
    return a

def load_data():
female_file, male_file = names.fileids()

f1_names = names.words(female_file)
f2_names = names.words(male_file)

data_set = [(name.lower(), 0) for name in f1_names] + [(name.lower(), 1) for name in f2_names] random.shuffle(data_set)

暂时禁止评论

微信扫一扫

易采站长站微信账号