码农的后花园

V1

2022/08/20阅读:16主题:橙心

AI创作诗词

很久以来,我们都想让机器自己创作诗歌,当无数作家、编辑还没有抬起笔时,AI已经完成了数千篇文章。现在,这里是第一步....

这诗做的很有感觉啊,这都是勤奋的结果啊,基本上学习了全唐诗的所有精华才有了这么牛逼的能力,这一般人能做到?

甚至还可以模仿周杰伦创作歌词 !! 怎么说,目前由于缺乏训练文本,导致我们的AI做的歌词有点....额,还好啦,有那么一点忧郁之风。

1.下载代码和数据集

Github地址: https://github.com/jinfagang/tensorflow_poems

数据集: 存放于项目的data文件夹内

2.环境导入

import os
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems, generate_batch
import argparse
from pathlib import Path

3.参数设置

parser = argparse.ArgumentParser()
#type是要传入的参数的数据类型  help是该参数的提示信息
parser.add_argument('--batch_size', type=int, help='batch_size',default=64)
parser.add_argument('--learning_rate', type=float, help='learning_rate',default=0.0001)
parser.add_argument('--model_dir', type=Path, help='model save path.',default='./model')
parser.add_argument('--file_path', type=Path, help='file name of poems.',default='./data/poems.txt')
parser.add_argument('--model_prefix', type=str, help='model save prefix.',default='poems')
parser.add_argument('--epochs', type=int, help='train how many epochs.',default=126)

args = parser.parse_args(args=[])

4.训练

下载的代码中的./model/中包含最新的训练模型,再次训练会接着训练。如果训练路径报错,需要删除./model的模型,重新开始训练。

def run_training():
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(args.file_path)
    batches_inputs, batches_outputs = generate_batch(args.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [args.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [args.batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=args.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(args.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            n_chunk = len(poems_vector) // args.batch_size
            for epoch in range(start_epoch, args.epochs):
                n = 0
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                #if epoch % 5 == 0:
                saver.save(sess, os.path.join(args.model_dir, args.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(args.model_dir, args.model_prefix), global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
            
run_training()

5.诗词生成

import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems
import numpy as np

start_token = 'B'
end_token = 'E'
model_dir = './model/'
corpus_file = './data/poems.txt'

lr = 0.0002

def to_word(predict, vocabs):
    predict = predict[0]       
    predict /= np.sum(predict)
    sample = np.random.choice(np.arange(len(predict)), p=predict)
    if sample > len(vocabs):
        return vocabs[-1]
    else:
        return vocabs[sample]


def gen_poem(begin_word):
    batch_size = 1
    print('## loading corpus from %s' % model_dir)
    tf.reset_default_graph()
    
    poems_vector, word_int_map, vocabularies = process_poems(corpus_file)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)

        x = np.array([list(map(word_int_map.get, start_token))])

        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        word = begin_word or to_word(predict, vocabularies)
        poem_ = ''

        i = 0
        while word != end_token:
            poem_ += word
            i += 1
            if i > 24:
                break
            x = np.array([[word_int_map[word]]])
            [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                             feed_dict={input_data: x, end_points['initial_state']: last_state})
            word = to_word(predict, vocabularies)

        return poem_


def pretty_print_poem(poem_):
    poem_sentences = poem_.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 10:
            print(s + '。')

6.测试运行

7. 运行环境

本次使用框架TensorFlow1.13.1,本项目可以在华为提供的JupyterLab环境中运行。 参考华为的实践案例:《AI作诗》:http://su.modelarts.club/dqTT https://developer.huaweicloud.com/signup/e4240e984d1c4d20bfcc83e7f7648b6c?

后台回复关键字:项目实战,可下载完整代码。

分类:

人工智能

标签:

自然语言处理

作者介绍

码农的后花园
V1

公众号:码农的后花园