当前位置:首页 » 《随便一记》 » 正文

AI对联生成案例(二)

3 人参与  2024年02月11日 16:21  分类 : 《随便一记》  评论

点击全文阅读


模型训练

有了处理好的数据,我们就可以进行训练了。你可以选择本地训练或在OpenPAI上训练

OpenPAI上训练

OpenPAI 作为开源平台,提供了完整的 AI 模型训练和资源管理能力,能轻松扩展,并支持各种规模的私有部署、云和混合环境。因此,我们推荐在OpenPAI上训练。

完整训练过程请查阅: 在OpenPAI上训练

本地训练

如果你的本地机器性能较好,也可以在本地训练。

模型训练的代码请参考 train.sh。

训练过程依然调用t2t模型训练命令:。具体命令如下:t2t_trainer

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>TRAIN_DIR=./outputLOG_DIR=${TRAIN_DIR}DATA_DIR=./data_dirUSR_DIR=./usr_dirPROBLEM=translate_up2downMODEL=transformerHPARAMS_SET=transformer_smallt2t-trainer \--t2t_usr_dir=${USR_DIR} \--data_dir=${DATA_DIR} \--problem=${PROBLEM} \--model=${MODEL} \--hparams_set=${HPARAMS_SET} \--output_dir=${TRAIN_DIR} \--keep_checkpoint_max=1000 \--worker_gpu=1 \--train_steps=200000 \--save_checkpoints_secs=1800 \--schedule=train \--worker_gpu_memory_fraction=0.95 \--hparams="batch_size=1024" 2>&1 | tee -a ${LOG_DIR}/train_default.log</code></span></span></span></span>

各项参数的作用和取值分别如下:

t2t_usr_dir:如前一小节所述,指定了处理对联问题的模块所在的目录。

data_dir:训练数据目录

problem:问题名称,即translate_up2down

model:训练所使用的 NLP 算法模型,本案例中使用 transformer 模型

hparams_set:transformer 模型下,具体使用的模型。transformer 的各种模型定义在 tensor2tensor/models/transformer.py 文件夹内。本案例使用 transformer_small 模型。

output_dir:保存训练结果

keep_checkpoint_max:保存 checkpoint 文件的最大数目

worker_gpu:是否使用 GPU,以及使用多少 GPU 资源

train_steps:总训练次数

save_checkpoints_secs:保存 checkpoint 的时间间隔

schedule:将要执行的 方法,比如:train, train_and_evaluate, continuous_train_and_eval,train_eval_and_decode, run_std_servertf.contrib.learn.Expeiment

worker_gpu_memory_fraction:分配的 GPU 显存空间

hparams:定义 batch_size 参数。

好啦,我们输入完命令,点击回车,训练终于跑起来啦!如果你在拥有一块 K80 显卡的机器上运行,只需5个小时就可以完成训练。如果你只有 CPU ,那么你只能多等几天啦。 我们将训练过程运行在 Microsoft OpenPAI 分布式资源调度平台上,使用一块 K80 进行训练。

如果你想利用OpenPAI平台训练,可以查看在OpenPAI上训练。

4小时24分钟后,训练完成,得到如下模型文件:

检查站型号.ckpt-200000.data-00000-of-00003型号.ckpt-200000.data-00001-of-00003型号.ckpt-200000.data-00002-of-00003型号.ckpt-200000.index型号.ckpt-200000.meta

我们将使用该模型文件进行模型推理。

模型推理

在这一阶段,我们将使用上述训练得到的模型文件进行模型推理,利用上联生成下联。

新建推理脚本文件inference.sh

点击查看 inference.sh 的代码。

在推理之前,需要注意如下几个目录:

TRAIN_DIR:上述的训练模型文件存放的目录。DATA_DIR:训练字典文件存放目录,即之前提到的。merge.txt.vocab.cleanUSR_DIR:自定义问题的存放目录,即之前提到的文件。merge_vocab.py
<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>TRAIN_DIR=./outputDATA_DIR=./data_dirUSR_DIR=./usr_dirDECODE_FILE=./decode_this.txtPROBLEM=translate_up2downMODEL=transformerHPARAMS=transformer_smallBEAM_SIZE=4ALPHA=0.6poet=$1new_chars=""for ((i=0;i < ${#poet} ;++i))donew_chars="$new_chars ${poet:i:1}"doneecho $new_chars > decode_this.txtecho "生成中..."t2t-decoder \--t2t_usr_dir=$USR_DIR \  --data_dir=$DATA_DIR \  --problem=$PROBLEM \  --model=$MODEL \  --hparams_set=$HPARAMS \  --output_dir=$TRAIN_DIR \  --decode_from_file=$DECODE_FILE \  --decode_to_file=result.txt >> /dev/null 2>&1echo $new_charscat result.txt</code></span></span></span></span>

开始推理

给增加可执行权限inference.sh

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>chmod +x ./inference.sh</code></span></span></span></span>

使用如下命令推理

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>./inference.sh [上联]</code></span></span></span></span>

例如,

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>./inference.sh 西子湖边逢暮雨</code></span></span></span></span>

等待推理完成后,你可能会得到下面的输出。当然,下联的生成和你的训练集、迭代次数等都有关系,因此大概率不会有一样的结果。

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>生成中...西 子 湖 边 逢 暮 雨故 里 乾 坤 日 盖 章</code></span></span></span></span>

推理结果也保存到了文件中。result.txt

搭建后端服务

训练好了模型,我们显然不能每次都通过命令行来调用,既不用户友好,又需要多次加载模型。因此,我们可以通过搭建一个后端服务,将模型封装成一个api,以便构建应用。

我们后端服务架构如下:

首先,利用为我们的模型开启服务,再通过Flask构建一个Web应用接收和响应http请求,并与我们的模型服务通信获取推理结果。tensorflow-serving-api

开启模型服务

开启模型服务有以下几个步骤:

安装tensorflow-serving-api

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>pip3 install tensorflow-serving-api==1.14.0echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.listcurl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -sudo apt-get update && sudo apt-get install tensorflow-model-server</code></span></span></span>

注意:

安装会自动安装的cpu版本,会覆盖版本。tensorflow-serving-apitensorflowtensorflow-gpu如果有依赖缺失,请查阅:https://medium.com/@noone7791/how-to-install-tensorflow-serving-load-a-saved-tf-model-and-connect-it-to-a-rest-api-in-ubuntu-48e2a27b8c2a。

导出我们训练好的模型

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>t2t-exporter --model=transformer  \        --hparams_set=transformer_small  \        --problem=translate_up2down  \        --t2t_usr_dir=./usr_dir \        --data_dir=./data_dir \        --output_dir=./output</code></span></span></span>

启动服务

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>tensorflow_model_server --port=9000 --model_name=up2down --model_base_path=$HOME/output/export</code></span></span></span>

此处需要注意,

--port:服务开启的端口--model_name:模型名称,可自定义,会在后续使用到--model_base_path:导出的模型的目录

至此,模型服务已成功启动。

在Python中调用

启动模型服务后,完成以下步骤即可在Python中调用模型完成推理。

首先,新建目录,并将文件按如下目录结构放置。service

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>service \  config.json  up2down_model \    up2down_model.py    data \      __init__.py      merge.txt.vocab.clean      merge_vocab.py</code></span></span></span></span>

其中,字典文件和需拷贝到目录。merge.txt.vocab.cleanmerge_vocab.pyservice\up2down_model\data

此外,我们将与模型服务通信获取下联的函数封装在了up2down_model.py中,下载该文件后拷贝到目录。service\up2down_model

另外,我们需要修改config.json文件为对应的内容:

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>{    "t2t_usr_dir":"./up2down_model/data",    "problem":"translate_up2down",    "model_name":"up2down",    "server_address":"127.0.0.1:9000"}</code></span></span></span></span>
t2t_usr_dir:对联问题模块的定义文件及字典的存放目录model_name:开启时定义的模型名称tensorflow-serving-apiproblem:定义的问题名称server_address: 服务开启的地址及端口

最后,在目录下新建Python文件,通过以下两行代码即可完成模型的推理并生成下联。service

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>from up2down_model.up2down_model import up2downup2down.get_down_couplet([upper_couplet])</code></span></span></span></span>

由于服务开启后无需再次加载模型和其余相关文件,因此模型推理速度非常快,适合作为应用的接口调用。

搭建Flask Web应用

利用Flask,我们可以快速地用Python搭建一个Web应用,实现对联生成。

主要分为以下几个步骤:

安装flask

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>pip3 install flask</code></span></span></span>

搭建服务

我们在目录下新建一个文件,内容如下:serviceapp.py

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>from flask import Flaskfrom flask import requestfrom up2down_model.up2down_model import up2downapp = Flask(__name__)@app.route('/',methods=['GET'])def get_couplet_down():    couplet_up = request.args.get('upper','')    couplet_down = up2down.get_down_couplet([couplet_up])    return couplet_up + "," + couplet_down[0]</code></span></span></span>

由于我们把推理下联的功能封装在中,因此通过几行代码我们就实现了一个web服务。up2down_model.py

启动服务

在测试环境中,我们使用flask自带的web服务即可(注:生产环境应使用uwsgi+nginx部署,有兴趣的同学可以自行查阅资料)。

使用以下两条命令:

In Ubuntu,

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>export FLASK_APP=app.pypython -m flask run</code></span></span></span>

In Windows,

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>set FLASK_APP=app.pypython -m flask run</code></span></span></span>

此时,服务就启动啦。

我们仅需向后端 http://127.0.0.1:5000/ 发起get请求,并带上上联参数,即可返回生成的对联到前端。upper

示例,

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>http://127.0.0.1:5000/?upper=海内存知己</code></span></span></span>

返回结果如图:

后端服务的完整代码请参考:.\src\service

案例拓展

至此,我们已经学会了小程序的核心部分:训练模型、推理模型及搭建后端服务使用模型。由于小程序的其余实现部分涉及比较多的开发知识,超出了NLP的范畴,因此我们不再详细介绍,而是在该部分简单讲解其实现思路,对上层应用开发感兴趣的同学可以参考并实现。

实体提取

当用户通过小程序上传图片时,程序需要从图片中提取出能够描述图片的信息。 本案例利用了微软的Cognitive Service完成从上传的图片中提取实体的工作。上传图片后,程序会调用微软的Cognitive Service并将结果返回。

下面是返回结果的示例:

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>{    'tags':        [            {'name': 'person', 'confidence': 0.99773770570755},            {'name': 'birthday cake', 'confidence': 0.992998480796814},            {'name': 'food', 'confidence': 0.9029457569122314},            ...        ],    'description':        [            'person', 'woman', 'holding', 'smiling', ...        ]}</code></span></span></span></span>

返回结果中包含了和字段,里面分别包含了该图片的意象。tagsdescription

筛选及翻译Tag

可以看到,调用cognitive service以后,会返回大量的tags,而我们需要从中挑选出符合要求的tag。在这个阶段,我们有两个目标:

找到能准确描述图片内容的tag找到概括性强的tag,

首先,我们为了找出能准确描述图片内容的tag,我们取了返回结果中和中都存在的tag作为对图片描述的tag。这样就初步筛选出了更贴近图片内容的tag。tagsdescription

从直观上理解,概括能力越好的tag自然是出现频率越高的。因此,我们构建了一个高频词典,收集了出现频率前500的tag,并给出了对应的中文翻译。我们仅保留并翻译在词典内的tag,而不在词典内的tag会在这个阶段被进一步地过滤掉。

在高频词典的构建中,我们对中文翻译做了改进,使其与古文意象更接近,便于搜索出对应的上联。因此,高频词典不再是纯粹的中英互译的词典,而是英文tag到相关意象的映射。例如,我们将'building'映射为'楼','skiing'映射为'雪','day'映射为'昼'等。

利用这样的高频词典,就完成了翻译及过滤tag的过程。

思考:会不会出现过滤后的tag太少的情况?

为此,我们做实验统计了两个指标,若仅保留前500个高频tag,tag覆盖约为100%,tag平均覆盖数约为10个/张。

( 注:tag覆盖率 = 至少有一个tag在高频词典内的图片数 / 总图片数 * 100% , tag平均覆盖数 = 每张图片中在高频词典内的tag数之和 / 总图片数 * 100% )

因此可以确保极大多数的图片是不会全部tag都被过滤掉的,并且剩余的tag数量适中。

上联匹配

提取完实体信息,我们的目标是找出与实体匹配程度较高的上联数据。于是,我们希望尽量找出包含两个tag的上联数据,这样能够保证匹配程度较高。

匹配分为如下几个步骤:

分别找出包含每个tag的上联的索引

例如,假设通过上一步的翻译及过滤最终得到了:'天', '草','沙滩'这几个tag,我们需要分别找出包含这几个tag的上联的索引,如:

'天':{ 3, 74, 237, 345, 457, 847 }'草':{ 23, 74, 455, 674, 54, 87, 198 }'沙滩':{ 86, 87, 354, 457 }

找出包含两个tag的对每组索引分别取交集

例如,

'天' + '草':{ 74 }'天' + '沙滩': { 457 }'草' + '沙滩':{ 87 }

合并取交集的结果

例如,得到结果{ 74, 457, 87 }。

若交集为空,则随机从各自tag中选取部分索引。

从上面的结果中随机选出上联数据。

通过以上几个步骤,我们可以在确保至少包含一个tag的同时,尽可能找出包含两个tag的上联。

下联生成

得到了上联以后,我们可以利用上面开启模型服务中提到的方法生成下联。

搭建后端

后端部分的实现也可以参考上述的搭建Flask Web应用或Flask中文文档。

在部署至生产环境时,可以使用uwsgi+nginx的方式。

总结

本案例利用深度学习方法构建了一个上联预测下联的对联生成模型。首先通过词嵌入对数据集编码,再利用已编码的数据训练一个Encoder-Decoder模型,从而实现对联生成的功能。另外,该案例还结合微软Cognitive Service中的目标检测,对用户上传图片进行分析,利用分析结果匹配上联,再通过训练好的模型生成下联。最后,搭建后端服务实现完整的应用功能。该案例很好地演示了从模型选择、训练、推理到搭建后端服务等完整的应用开发流程,将理论与实践结合。

   在线教程

麻省理工学院人工智能视频教程 – 麻省理工人工智能课程人工智能入门 – 人工智能基础学习。Peter Norvig举办的课程EdX 人工智能 – 此课程讲授人工智能计算机系统设计的基本概念和技术。人工智能中的计划 – 计划是人工智能系统的基础部分之一。在这个课程中,你将会学习到让机器人执行一系列动作所需要的基本算法。机器人人工智能 – 这个课程将会教授你实现人工智能的基本方法,包括:概率推算,计划和搜索,本地化,跟踪和控制,全部都是围绕有关机器人设计。机器学习 – 有指导和无指导情况下的基本机器学习算法机器学习中的神经网络 – 智能神经网络上的算法和实践经验斯坦福统计学习

有需要的小伙伴,可以点击下方链接免费领取或者V扫描下方二维码免费领取?

请添加图片描述

人工智能书籍

OpenCV(中文版).(布拉德斯基等)OpenCV+3计算机视觉++Python语言实现+第二版OpenCV3编程入门 毛星云编著数字图像处理_第三版人工智能:一种现代的方法深度学习面试宝典深度学习之PyTorch物体检测实战吴恩达DeepLearning.ai中文版笔记计算机视觉中的多视图几何PyTorch-官方推荐教程-英文版《神经网络与深度学习》(邱锡鹏-20191121)…
在这里插入图片描述

第一阶段:零基础入门(3-6个月)

新手应首先通过少而精的学习,看到全景图,建立大局观。 通过完成小实验,建立信心,才能避免“从入门到放弃”的尴尬。因此,第一阶段只推荐4本最必要的书(而且这些书到了第二、三阶段也能继续用),入门以后,在后续学习中再“哪里不会补哪里”即可。

第二阶段:基础进阶(3-6个月)

熟读《机器学习算法的数学解析与Python实现》并动手实践后,你已经对机器学习有了基本的了解,不再是小白了。这时可以开始触类旁通,学习热门技术,加强实践水平。在深入学习的同时,也可以探索自己感兴趣的方向,为求职面试打好基础。

第三阶段:工作应用

这一阶段你已经不再需要引导,只需要一些推荐书目。如果你从入门时就确认了未来的工作方向,可以在第二阶段就提前阅读相关入门书籍(对应“商业落地五大方向”中的前两本),然后再“哪里不会补哪里”。

 有需要的小伙伴,可以点击下方链接免费领取或者V扫描下方二维码免费领取?

在这里插入图片描述


点击全文阅读


本文链接:http://zhangshiyu.com/post/68171.html

<< 上一篇 下一篇 >>

  • 评论(0)
  • 赞助本站

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

关于我们 | 我要投稿 | 免责申明

Copyright © 2020-2022 ZhangShiYu.com Rights Reserved.豫ICP备2022013469号-1