用 LSTM 预测字符序列
在《循环神经网络(Recurrent Neural Network)简介》一文中我们了解了什么是 RNN。本文用 TensorFlow 实现一个超级简单的字符预测模型,并对代码进行详细说明,防止自己以后忘记( ╯□╰ )。
首先定义 RNN 网络:
class RNN:
    """Stacked-LSTM sequence model built as a TensorFlow 1.x graph.

    Constructing an instance creates the input/target/state placeholders,
    the multi-layer LSTM, a fully connected output layer, the
    cross-entropy loss and an RMSProp training op.

    Args:
        in_size: Size of the last axis of the input placeholder (the
            number of distinct characters when inputs are one-hot).
        cell_size: Number of units in each LSTM cell.
        num_layers: Number of stacked LSTM layers.
        out_size: Size of the last axis of the outputs and targets.
        sess: Session used to run the training and prediction ops.
        lr: Learning rate handed to the RMSProp optimizer.
    """

    def __init__(self, in_size, cell_size, num_layers, out_size, sess,
                 lr=0.003):
        self.in_size = in_size
        self.cell_size = cell_size
        self.num_layers = num_layers
        self.out_size = out_size
        self.sess = sess
        self.lr = lr

        # The cells are created with state_is_tuple=False, so the whole
        # recurrent state is one flat vector: two cell_size-wide parts
        # per layer.
        flat_state_size = num_layers * 2 * cell_size

        # State returned by the most recent prediction call; it is fed
        # back in on the next call so context carries over.
        self.last_state = np.zeros([flat_state_size])

        # Input batch, laid out as (batch, time_step, in_size).
        self.inputs = tf.placeholder(
            tf.float32, shape=[None, None, in_size])

        self.lstm_cells = []
        for _ in range(num_layers):
            self.lstm_cells.append(
                rnn.BasicLSTMCell(cell_size, state_is_tuple=False))
        self.lstm = rnn.MultiRNNCell(self.lstm_cells, state_is_tuple=False)

        # Initial recurrent state, one flat row per batch entry.
        self.init_state = tf.placeholder(
            tf.float32, [None, flat_state_size])

        # Unroll the stacked LSTM over the time axis.
        outputs, self.new_state = tf.nn.dynamic_rnn(
            self.lstm,
            self.inputs,
            initial_state=self.init_state,
            dtype=tf.float32)

        # Fully connected layer applied to every (batch, time) position.
        w = tf.Variable(tf.random_normal([cell_size, out_size], stddev=0.01))
        b = tf.Variable(tf.random_normal([out_size], stddev=0.01))
        flat_outputs = tf.reshape(outputs, [-1, cell_size])
        reshaped_outputs = tf.matmul(flat_outputs, w) + b

        # Turn the logits into probabilities and restore the
        # (batch, time_step, out_size) layout.
        fc = tf.nn.softmax(reshaped_outputs)
        shape = tf.shape(outputs)
        self.final_outputs = tf.reshape(
            fc, [shape[0], shape[1], out_size])

        # Labels, laid out as (batch, time_step, out_size).
        self.targets = tf.placeholder(tf.float32, [None, None, out_size])

        # Mean cross-entropy over all (batch, time) positions.
        per_position_loss = tf.nn.softmax_cross_entropy_with_logits(
            logits=reshaped_outputs,
            labels=tf.reshape(self.targets, [-1, out_size]))
        self.cost = tf.reduce_mean(per_position_loss)
        self.optimizer = tf.train.RMSPropOptimizer(lr).minimize(self.cost)

    def train(self, inputs, targets):
        """Run one optimization step and return the batch loss.

        Args:
            inputs: Array fed to the input placeholder; its first axis
                is the batch size.
            targets: Array fed to the target placeholder.

        Returns:
            The value of the cost tensor for this batch.
        """
        # Every training batch starts from an all-zero recurrent state,
        # shaped (batch_size, num_layers * 2 * cell_size).
        zero_state = np.zeros(
            [inputs.shape[0], self.num_layers * 2 * self.cell_size])
        feed = {
            self.inputs: inputs,
            self.targets: targets,
            self.init_state: zero_state,
        }
        _, batch_loss = self.sess.run(
            [self.optimizer, self.cost], feed_dict=feed)
        return batch_loss

    def get_next_char_pro(self, x, init=False):
        """Return the output probabilities that follow input ``x``.

        Args:
            x: One-hot encoding of a single character, shape
                ``(1, in_size)``.
            init: When true, start from an all-zero state; otherwise
                continue from the state stored by the previous call.

        Returns:
            The softmax output for the first time step of the first
            batch entry.
        """
        start_state = (
            np.zeros([self.num_layers * 2 * self.cell_size])
            if init else self.last_state)
        probabilities, state_after = self.sess.run(
            [self.final_outputs, self.new_state],
            feed_dict={self.inputs: [x], self.init_state: [start_state]})
        # Remember the state so the next call keeps the context.
        self.last_state = state_after[0]
        return probabilities[0][0]
处理数据:
def generate_one_hot_data(data, vocabulary):
    """Encode a sequence of symbols as one-hot rows.

    Args:
        data: Sized iterable of symbols, e.g. a string of characters.
        vocabulary: Sequence of known symbols. A symbol's column is the
            index of its first occurrence in this sequence.

    Returns:
        A float numpy array of shape ``(len(data), len(vocabulary))``
        where row ``r`` holds 1.0 in the column of the ``r``-th symbol
        and 0.0 everywhere else.

    Raises:
        ValueError: If a symbol in ``data`` is not in ``vocabulary``.
    """
    # Build the symbol -> column table once instead of scanning the
    # vocabulary for every symbol. setdefault keeps the first occurrence,
    # matching what a sequence .index() lookup would return.
    column_of = {}
    for column, symbol in enumerate(vocabulary):
        column_of.setdefault(symbol, column)

    encoded = np.zeros([len(data), len(vocabulary)])
    for row, symbol in enumerate(data):
        column = column_of.get(symbol)
        if column is None:
            raise ValueError(
                '{!r} is not in vocabulary'.format(symbol))
        encoded[row, column] = 1.0
    return encoded
# Read the whole training corpus into memory.
path = 'data/data.txt'
with open(path, 'r') as f:
    data = f.read()

# Lower-case everything to shrink the character set.
data = data.lower()

# Sorted list of the distinct characters. Sorting gives every character
# the same index on every run; the iteration order of a set of strings
# depends on string hashing, which can differ between interpreter runs.
vocabulary = sorted(set(data))

# One row per character of the corpus,
# shape (len(data), len(vocabulary)).
one_hot_data = generate_one_hot_data(data, vocabulary)
定义参数:
# The model reads and predicts one-hot characters, so both ends of the
# network are as wide as the vocabulary.
in_size = len(vocabulary)
out_size = in_size

# Network and batching hyper-parameters.
cell_size = 128
num_layers = 2
batch_size = 64
time_steps = 128
NUM_TRAIN_BATCHES = 5000

# Allocate GPU memory on demand rather than reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=config)

model = RNN(
    in_size=in_size,
    cell_size=cell_size,
    num_layers=num_layers,
    out_size=out_size,
    sess=sess,
)
sess.run(tf.global_variables_initializer())

# Buffers that are refilled for every training batch.
inputs = np.zeros((batch_size, time_steps, in_size))
targets = np.zeros((batch_size, time_steps, out_size))

# Candidate start offsets for a training window in the corpus.
possible_batch_start_ids = range(len(data) - time_steps - 1)
生成训练数据并训练模型:
# Offsets 0 .. time_steps - 1 inside one training window.
step_offsets = np.arange(time_steps)

for i in range(NUM_TRAIN_BATCHES):
    # Pick batch_size distinct window start positions at random.
    batch_start_ids = random.sample(possible_batch_start_ids, batch_size)

    # (batch_size, time_steps) grid of corpus positions, one row per
    # window; the targets are the same windows shifted one character on.
    window_ids = (
        np.asarray(batch_start_ids, dtype=np.intp)[:, np.newaxis]
        + step_offsets)
    inputs[...] = one_hot_data[window_ids]
    targets[...] = one_hot_data[window_ids + 1]

    loss = model.train(inputs, targets)
    if i % 100 == 0:
        print('loss: {:.5f} of batch {}'.format(loss, i))
测试:
# Generate a sentence that starts with 'we '.
TEST_PREFIX = 'we '

# Feed the prefix one character at a time; the first call resets the
# model state, and `out` ends up as the distribution after the prefix.
for i, prefix_char in enumerate(TEST_PREFIX):
    out = model.get_next_char_pro(
        generate_one_hot_data(prefix_char, vocabulary), i == 0)

# Sample 200 characters, feeding each sampled character back in.
sampled_chars = []
for i in range(200):
    index = np.random.choice(len(vocabulary), p=out)
    pred = vocabulary[index]
    sampled_chars.append(pred)
    out = model.get_next_char_pro(generate_one_hot_data(pred, vocabulary))

gen_str = TEST_PREFIX + ''.join(sampled_chars)
print(gen_str)
训练结果(注:下面的日志一直记录到第 9900 个 batch,超过了上文设置的 NUM_TRAIN_BATCHES = 5000,应是另一次训练批数更多的运行输出):
loss: 3.12114 of batch 100
loss: 3.06464 of batch 200
loss: 2.49030 of batch 300
loss: 2.20478 of batch 400
loss: 2.02495 of batch 500
loss: 1.88005 of batch 600
loss: 1.74067 of batch 700
loss: 1.70199 of batch 800
loss: 1.69426 of batch 900
loss: 1.54315 of batch 1000
loss: 1.55390 of batch 1100
loss: 1.48157 of batch 1200
loss: 1.49068 of batch 1300
loss: 1.48002 of batch 1400
loss: 1.44718 of batch 1500
loss: 1.44463 of batch 1600
loss: 1.44997 of batch 1700
loss: 1.41875 of batch 1800
loss: 1.39939 of batch 1900
loss: 1.40289 of batch 2000
loss: 1.36561 of batch 2100
loss: 1.37862 of batch 2200
loss: 1.39195 of batch 2300
loss: 1.36577 of batch 2400
loss: 1.33385 of batch 2500
loss: 1.36705 of batch 2600
loss: 1.28989 of batch 2700
loss: 1.35604 of batch 2800
loss: 1.30150 of batch 2900
loss: 1.30576 of batch 3000
loss: 1.30832 of batch 3100
loss: 1.28968 of batch 3200
loss: 1.28770 of batch 3300
loss: 1.27296 of batch 3400
loss: 1.31154 of batch 3500
loss: 1.30212 of batch 3600
loss: 1.29709 of batch 3700
loss: 1.28448 of batch 3800
loss: 1.28766 of batch 3900
loss: 1.25970 of batch 4000
loss: 1.27008 of batch 4100
loss: 1.29763 of batch 4200
loss: 1.25666 of batch 4300
loss: 1.29813 of batch 4400
loss: 1.26807 of batch 4500
loss: 1.23903 of batch 4600
loss: 1.21010 of batch 4700
loss: 1.27084 of batch 4800
loss: 1.27161 of batch 4900
loss: 1.25789 of batch 5000
loss: 1.22986 of batch 5100
loss: 1.24404 of batch 5200
loss: 1.27089 of batch 5300
loss: 1.23036 of batch 5400
loss: 1.25348 of batch 5500
loss: 1.23626 of batch 5600
loss: 1.21493 of batch 5700
loss: 1.20419 of batch 5800
loss: 1.23771 of batch 5900
loss: 1.20754 of batch 6000
loss: 1.23489 of batch 6100
loss: 1.20233 of batch 6200
loss: 1.20366 of batch 6300
loss: 1.23586 of batch 6400
loss: 1.21687 of batch 6500
loss: 1.19479 of batch 6600
loss: 1.21297 of batch 6700
loss: 1.23598 of batch 6800
loss: 1.19476 of batch 6900
loss: 1.21584 of batch 7000
loss: 1.22816 of batch 7100
loss: 1.19449 of batch 7200
loss: 1.19346 of batch 7300
loss: 1.23466 of batch 7400
loss: 1.18541 of batch 7500
loss: 1.19469 of batch 7600
loss: 1.21069 of batch 7700
loss: 1.19641 of batch 7800
loss: 1.15550 of batch 7900
loss: 1.19861 of batch 8000
loss: 1.22582 of batch 8100
loss: 1.19766 of batch 8200
loss: 1.19041 of batch 8300
loss: 1.15410 of batch 8400
loss: 1.13109 of batch 8500
loss: 1.16434 of batch 8600
loss: 1.19457 of batch 8700
loss: 1.18558 of batch 8800
loss: 1.18043 of batch 8900
loss: 1.17171 of batch 9000
loss: 1.17663 of batch 9100
loss: 1.14107 of batch 9200
loss: 1.20001 of batch 9300
loss: 1.16926 of batch 9400
loss: 1.14761 of batch 9500
loss: 1.15305 of batch 9600
loss: 1.20601 of batch 9700
loss: 1.16141 of batch 9800
loss: 1.15704 of batch 9900
一些输出:
prefix:she
sheated here do in the weary blood
unless he single thou mayst break your ancients dren
lads, then i in to thy fair age.
so i say the nursed with mine eyes from my sleep.
clarence:
how is it pale and mo
prefix:the
the tears of mine.
hark! they may providy to the extremement,
make an office of day, like to be so run out.
dost than the heaven and your pardon did forght
the crown: all shall be too close hather change
prefix:i
in heavy part:
in your comforts of our state exposed
for in the kin a phalteres times disprison.
isabella:
dispatch and great to do a plance:
and now they do repent us to be pitteth caser
which in my
prefix:i
it confined him to his necessities,
but to an e this young belonged,'
as it susminine of france of vices of folly,
when though 'i trust up me not, our reward,
sir, you shall be absent disposed.
warwic
prefix:me
me;
for will is angelo, take it,--and thou knows he did:
one catesby!
pronirerd:
uppecured for me hencefully, you make one flight,
making down thee thought in who thy by the world,
but name, though hat