<
>

PyTorch实现基于CharRNN的文本分类与生成示例

2020-06-25 08:07:59 来源:易采站长站 作者:易采站长站整理


print(data_set[:10])
return data_set

'''
[('yoshiko', 0), ('timothea', 0), ('giorgi', 1), ('thedrick', 1), ('tessie', 0), ('keith', 1), ('carena', 0), ('anthea', 0), ('cathyleen', 0), ('almeta', 0)]'''
class CharRNN(nn.Module):
    """Character-level GRU conditioned on a 2-dim one-hot gender vector.

    At every time step the gender vector is concatenated in front of the
    character encoding, the GRU output is projected to ``output_size``
    scores, and those scores are turned into probabilities with softmax.
    """

    def __init__(self, vocab_size, hidden_size, output_size):
        super(CharRNN, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # The input also carries the 2-dim one-hot gender vector, hence +2.
        self.rnn = nn.GRU(vocab_size + 2, hidden_size, batch_first=True)
        # NOTE: the attribute name "liner" (sic) is kept so that existing
        # state_dict checkpoints keep loading.
        self.liner = nn.Linear(hidden_size, output_size)

    def forward(self, sexy, name, hidden=None):
        """Run the GRU over the given step(s).

        Args:
            sexy: gender tensor, presumably (batch, seq, 2) since it is
                concatenated with ``name`` along dim 2 — confirm with caller.
            name: character tensor, presumably (batch, seq, vocab_size).
            hidden: optional GRU state of shape (1, batch, hidden_size);
                a zero state is used when omitted.

        Returns:
            ``(probs, hidden)``. ``probs`` is the softmax output flattened
            with ``view(1, -1)``, i.e. (1, output_size) for a single-sample,
            single-step call. ``hidden`` is the updated GRU state.
        """
        if hidden is None:
            # Zero initial state sized to the actual batch (GRU h_0 is
            # (num_layers, batch, hidden)) and placed on the input's device.
            hidden = torch.zeros(1, name.size(0), self.hidden_size,
                                 device=name.device)
        # Prepend the gender vector to each character's features.
        rnn_input = torch.cat([sexy, name], dim=2)
        output, hidden = self.rnn(rnn_input, hidden)
        output = self.liner(output)
        # The functional dropout defaults to training=True; tie it to the
        # module mode so eval()/sampling produces deterministic outputs.
        output = F.dropout(output, 0.3, training=self.training)
        # NOTE(review): callers feed these probabilities to
        # nn.CrossEntropyLoss, which applies log_softmax itself (double
        # normalization). Kept because other callers may rely on
        # probabilities being returned.
        output = F.softmax(output, dim=2)
        return output.view(1, -1), hidden

class Model:
def __init__(self, epoches):
    """Build the CharRNN on the configured device and store the epoch count.

    Args:
        epoches: number of epochs ``train`` iterates over.
    """
    self.epoches = epoches
    # nn.Module.to() moves the parameters in place and returns the module itself.
    self.model = CharRNN(len(chars), hidden_dim, output_dim).to(device)

def train(self, train_set, steps_per_epoch=1000):
    """Train the wrapped CharRNN with teacher forcing on random samples.

    Each step draws one ``(name, sexy)`` pair from ``train_set``. It then
    feeds the name one character at a time and trains the model to predict
    the next character, using ``'!'`` as the end-of-name target. For
    "kate", the inputs are k, a, t, e and the targets are a, t, e, !.

    Args:
        train_set: sequence of ``(name, sexy)`` pairs.
        steps_per_epoch: random samples drawn per epoch. Defaults to 1000,
            the value previously hard-coded.

    Raises:
        ValueError: if ``steps_per_epoch`` is not positive.
    """
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    self.model.train()  # make sure dropout is active while training
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.001)

    for epoch in range(self.epoches):
        total_loss = 0.0
        for _ in range(steps_per_epoch):
            name, sexy = random.choice(train_set)
            if not name:
                # No characters means no predictions and nothing to back-propagate.
                continue
            optimizer.zero_grad()
            hidden = torch.zeros(1, 1, hidden_dim, device=device)
            # The gender encoding is identical for every character; build it once.
            sexy_tensor = torch.tensor([sexy2input(sexy)], dtype=torch.float, device=device)
            loss = 0
            for char_in, char_out in zip(name, name[1:] + '!'):
                name_tensor = torch.tensor([name2input(char_in)], dtype=torch.float, device=device)
                target_tensor = torch.tensor(name2target(char_out), dtype=torch.long, device=device)
                pred, hidden = self.model(sexy_tensor, name_tensor, hidden)
                loss += loss_func(pred, target_tensor)

            loss.backward()
            optimizer.step()

            # len(name) predictions were made. Using .item() keeps the running
            # total a plain float instead of chaining every step's autograd graph.
            total_loss += loss.item() / len(name)

        print("Training: in epoch {} loss {}".format(epoch, total_loss / steps_per_epoch))

def sample(self, sexy, start):
max_len = 8
result = [] with torch.no_grad():
hidden = None
for c in start:
sexy_tensor = torch.tensor([sexy2input(sexy)], dtype=torch.float, device=device)
name_tensor = torch.tensor([name2input(c)], dtype=torch.float, device=device)
pred, hidden = self.model(sexy_tensor, name_tensor, hidden)

c = start[-1] while c != '!':
sexy_tensor = torch.tensor([sexy2input(sexy)], dtype=torch.float, device=device)
name_tensor = torch.tensor([name2input(c)], dtype=torch.float, device=device)

暂时禁止评论

微信扫一扫

易采站长站微信账号