InfoQ Geekathon 大模型技术应用创新大赛 了解详情
写点什么

Keras 和 TensorFlow 之争何时休?

  • 2019-09-30
  • 本文字数:3571 字

    阅读完需:约 12 分钟

Keras和TensorFlow之争何时休?

本文的作者经常在电子邮箱中、社交媒体上,甚至在与深度学习研究人员、从业者和工程师面对面交谈时,会被问到这些问题:我应该在项目中使用 Keras 还是 TensorFlow?TensorFlow 和 Keras 哪个更好?我应该花时间研究 TensorFlow 还是 Keras?你是不是也有相同的疑问?如果有,相信这篇文章会给你答案。


实际上,到 2017 年中,Keras 已经被大规模采用,并与 TensorFlow 集成在一起。这种 TensorFlow + Keras 的组合让你可以:


  1. 使用 Keras 的接口定义模型;

  2. 如果你需要特定的 TensorFlow 功能或者需要实现 Keras 不支持但 TensorFlow 支持的自定义功能,可以回到 TensorFlow。


简单地说,你可以将 TensorFlow 代码直接插入到 Keras 的模型或训练管道中!


但请别误会,我并不是说你就不需要了解 TensorFlow 了。我的意思是,如果你:


  1. 刚开始接触深度学习……

  2. 在为下一个项目选型……

  3. 想知道 Keras 或 TensorFlow 哪个“更好”……


我的建议是先从 Keras 着手,然后深入 TensorFlow,这样可以获得你需要的某些特定功能。


在这篇文章中,我将向你展示如何使用 Keras 训练神经网络,以及如何使用直接构建在 TensorFlow 库中的 Keras + TensorFlow 组合来训练模型。

Keras 与 TF 我该学哪个?

在文章的其余部分,我将继续讨论有关 Keras 与 TensorFlow 的争论以及为什么说这个问题其实是个错误的问题。


我们将使用标准的 keras 模块以及 TensorFlow 的 tf.keras 模块实现一个卷积神经网络(CNN)。


我们将在一个样本数据集上训练 CNN,然后检查结果——你会发现,Keras 和 TensorFlow 可以很融洽地合作。


最重要的是,你将会了解为什么 Keras 与 TensorFlow 之间的争论其实是没有意义的。


尽管从 TensorFlow 宣布将 Keras 集成到官方 TensorFlow 版本中已经一年多时间了,但很多深度学习从业者仍然不知道他们可以通过 tf.keras 子模块访问 Keras,为此我感到很惊讶。


更重要的是,Keras + TensorFlow 的集成是无缝的,你可以直接将 TensorFlow 代码放到 Keras 模型中。


在 TensorFlow 中使用 Keras 将为你带来两全其美的好处:


  1. 你可以使用 Keras 提供的简单直观的 API 来创建模型;

  2. Keras API 与 scikit-learn(被认为是机器学习 API 的“黄金标准”)很像;

  3. Keras API 采用了模块化,易于使用;

  4. 当你需要自定义实现或者更复杂的损失函数时,可以直接进入 TensorFlow,并让代码自动与 Keras 模型集成。


在过去几年中,深度学习研究人员、从业人员和工程师通常需要做出以下选择:


  1. 我是选择易用但难以定制的 Keras 库?

  2. 还是选择难用的 TensorFlow API,并编写更多的代码?


所幸的是,我们不必再纠结了。


如果你发现自己还在问这样的问题,那么请退后一步——你问的是错误的问题——你可以同时拥有这两个框架。



如图所示,导入 TensorFlow(tf),然后调用 tf.keras,可见 Keras 实际上已经成为 TensorFlow 的一部分。


在 tf.keras 中包含 Keras 让你可以使用标准的 Keras 包实现简单的前馈神经网络:



然后使用 tf.keras 子模块实现相同的网络:



这是否意味着你必须使用 tf.keras?标准的 Keras 包是不是已经过时?当然不是。


作为一个库,Keras 仍然可以单独使用,因此未来两者可能会分道扬镳。不过,因为谷歌官方支持 Keras 和 TensorFlow,所以似乎不太可能出现这种情况。


关键是:


如果你习惯使用 Keras 编写代码,那么请继续这样做。


但如果你主要使用的是 TensorFlow,那么应该开始考虑一下 Keras API:


  1. 它内置于 TensorFlow 中;

  2. 它更容易使用;

  3. 当你需要使用 TensorFlow 来实现特定功能时,可以直接将其集成到 Keras 模型中。

我们的样本数据集


CIFAR-10 数据集包含了 10 个分类,我们将它用在我们的演示中。


为简单起见,我们将使用以下方法在 CIFAR-10 数据集上训练两个单独的卷积神经网络:


  1. TensorFlow + Keras;

  2. tf.keras 的 Keras 子模块。


我还将展示如何将自定义的 TensorFlow 代码包含在 Keras 模型中。

我们的项目结构

可以使用 tree 命令在终端中查看我们的项目结构:



pyimagesearch 模块不能通过 pip 安装,请点击文末提供的下载链接。现在让我们看一下该模块的两个重要 Python 文件:


  • minivggnetkeras.py:MiniVGGNet(一个机遇 VGGNet 的深度学习模型)的 Keras 实现。

  • minivggnettf.py:MiniVGGNet 的 TensorFlow + Keras(即 tf.keras)实现。


项目根目录包含两个 Python 文件:


  • train_network_keras.py:Keras 版本的训练脚本。

  • train_network_tf.py:TensorFlow + Keras 版本的训练脚本,几乎与前一个一模一样。


每个脚本都将生成相应的训练准确率和损失:


  • plot_keras.png

  • plot_tf.png

使用 Keras 训练网络


训练的第一步是使用 Keras 实现网络架构。


打开 minivggnetkeras.py 文件,并插入以下代码:



我们先导入构建模型需要的一系列 Keras 包。


然后定义我们的 MiniVGGNetKeras 类:



我们定义了 build 方法、inputShape 和 input。


然后定义卷积神经网络的主要部分:



你会发现我们在应用池化层之前堆叠了一系列卷积、ReLU 激活和批量规范化层,以便减少卷的空间维度。还使用了 Dropout 来减少过拟合。


现在将全连接层添加到网络中:



我们已经使用 Keras 实现了 CNN,现在创建将用于训练的驱动脚本。


打开 train_network_keras.py 并插入以下代码:



我们先导入需要的包。


  • matplotlib 设置为“Agg”,这样就可以将训练结果保存为图像文件。

  • 然后导入 MiniVGGNetKeras 类。

  • 我们使用 scikit-learn 的 LabelBinarizer 进行“独热”编码,并使用 classification_report 打印分类精度。

  • 然后导入数据集。


我们通过 --plot 传入命令行参数,也就是图像的保存路径。


现在让我们加载 CIFAR-10 数据集,并对标签进行编码:



我们先加载和提取训练和测试分割,并将它们转换为浮点数和进行数据缩放。


然后我们对标签进行编码,并初始化 labelNames。


接下来,让我们开始训练模型:



我们先设置训练参数和优化方法。


然后我们使用 MiniVGGNetKeras.build 方法初始化和编译模型。


随后,我们启动了训练程序。


现在让我们来评估网络并生成结果图:



我们基于数据的测试分割来评估网络,并生成 classification_report,最后再导出结果。


注意:通常我会序列化并导出模型,以便可以将其用在图像或视频的处理脚本中,但这里不打算这样做,因为这超出了本文的范围。


打开一个终端并执行以下命令:





我的 CPU 完成一个 epoch 需要 5 分多钟。



我们获得了 75%的准确率——当然不是最先进的,不过它比随机猜测(1/10)要好得多。


对于小型网络来说,我们的准确率算是非常好的了,而且没有发生过拟合。

使用 TensorFlow 和 tf.keras 训练网络

使用 tf.keras 构建的 MiniVGGNet CNN 与我们直接使用 Keras 构建的模型是一样的,除了为演示目的而修改的激活函数。


现在我们已经使用 Keras 库实现并训练了一个简单的 CNN,接下来我们要:


  1. 使用 TensorFlow 的 tf.keras 实现相同的网络;

  2. 在 Keras 模型中包含一个 TensorFlow 激活函数,这个函数不是使用 Keras 实现的。


首先,打开 minivggnettf.py 文件,我们将实现 TensorFlow 版本的 MiniVGGNet:





请注意,导入部分只有一行。tf.keras 子模块包含了我们可以直接调用的所有 Keras 函数。


我想强调一下 Lambda 层——它们用来插入自定义激活函数 CRELU(Concatenated ReLU)。


Keras 并没有实现 CRELU,但 TensorFlow 实现了——通过使用 TensorFlow 和 tf.keras,我们可以使用一行代码将 CRELU 添加到 Keras 模型中。


下一步是编写 TensorFlow + Keras 驱动脚本来训练 MiniVGGNetTF。


打开 train_network_tf.py 并插入以下代码:




然后是解析命令行参数。


接着像之前一样加载数据集。


其余的行都一样——提取训练 / 测试分割和编码标签。


现在让我们开始训练模型:




训练过程几乎是一样的。我们已经实现了完全相同的训练流程,只是这次使用的是 tf.keras。


打开一个终端并执行以下命令:




训练完成后,你将获得类似于下面这样的结果:



通过使用 CRELU 替换 RELU 激活函数,我们获得了 76%的准确率。不过,这 1%的提升可能是因为网络权重的随机初始化,需要通过进一步的交叉验证实验来证明这种准确率的提升确实是因为 CRELU。


不管怎样,原始准确率并不是本节的重点。我们需要关注的是如何在 Keras 模型内部使用 TensorFlow 激活函数替换标准的 Keras 激活函数!


你也可以使用自己的自定义激活函数、损失 / 成本函数或层。

总 结

在这篇文章中,我们讨论了 Keras 和 TensorFlow 相关的问题,包括:


  • 我应该在项目中使用 Keras 还是 TensorFlow?

  • TensorFlow 和 Keras 哪个更好?

  • 我应该花时间研究 TensorFlow 还是 Keras?


最后我们发现,在 Keras 和 TensorFlow 之间做出选择变得不那么重要。


因为 Keras 库已经通过 tf.keras 模块直接集成到 TensorFlow 中了。


相关代码下载:


https://app.monstercampaigns.com/c/hvovin011avqlrtdtz0j/


原文链接:


https://www.pyimagesearch.com/2018/10/08/keras-vs-tensorflow-which-one-is-better-and-which-one-should-i-learn/


活动推荐:

2023年9月3-5日,「QCon全球软件开发大会·北京站」 将在北京•富力万丽酒店举办。此次大会以「启航·AIGC软件工程变革」为主题,策划了大前端融合提效、大模型应用落地、面向 AI 的存储、AIGC 浪潮下的研发效能提升、LLMOps、异构算力、微服务架构治理、业务安全技术、构建未来软件的编程语言、FinOps 等近30个精彩专题。咨询购票可联系票务经理 18514549229(微信同手机号)。

2019-09-30 11:332785
用户头像

发布了 731 篇内容, 共 422.0 次阅读, 收获喜欢 1988 次。

关注

评论

发布
暂无评论
发现更多内容

Java技术:SpringBoot实现邮件发送功能

天使不哭

Java email #开源 8月月更

什么是Linux内核,怎么才能玩转它?

简说Linux内核

Linux内核 进程管理 嵌入式开发 设备驱动

《MySQL入门很轻松》第4章:数据表中存放的数据类型

乌龟哥哥

8月月更

如何正确理解线程机制中常见的I/O模型,各自主要用来解决什么问题?

PivotalCloud

Linux Linux Kenel

FileZilla搭建FTP服务器图解教程

天使不哭

#开源 8月月更

你有对象类,我有结构体,Go lang1.18入门精炼教程,由白丁入鸿儒,go lang结构体(struct)的使用EP06

刘悦的技术博客

Go golang 编程语言 Go web golang 面试

头脑风暴:单词拆分

HelloWorld杰少

算法 LeetCode 数据结构, 8月月更

数据治理(五):元数据管理

Lansonli

大数据 数据治理 8月月更

阿里云架构师金云龙:基于云XR平台的视觉计算应用部署

阿里云弹性计算

视觉计算 计算巢 云XR平台 GPU实例

上海一科技公司刷单被罚22万,揭露网络刷单灰色产业链

石头IT视角

SRv6性能测量

穿过生命散发芬芳

8月月更 SRv6

10min快速回顾C++语法(一)

timerring

c++ 算法 8月月更

Kubernetes YAML编写 讲解

CTO技术共享

开源 签约计划第三季 8月月更

数据库治理利器:动态读写分离

阿里巴巴云原生

数据库 阿里云 微服务 云原生

每日一R「02」所有权与 Move 语义

Samson

签约计划第三季 8月月更 ​Rust

带着昇腾去旅行:一日看尽金陵城里的AI胜景

脑极体

每天一个CSS小特效,文字闪烁——【钢铁侠:爱你三千遍】

前端小刘不怕牛牛

JavaScript html/css 8月月更

RocketMQ Binder集成消息订阅

急需上岸的小谢

8月月更

前端食堂技术周刊第 47 期:Docusaurus 2.0 、7 月登陆网络平台的新内容 、Nuxt.js 团队的轮子库

童欧巴

JavaScript 前端

操作系统:SSH协议知识介绍

天使不哭

Linux SSH #开源 8月月更

Kubernetes服务接入Istio

CTO技术共享

开源 签约计划第三季 8月月更

Kubernetes 60个为什么

CTO技术共享

开源 签约计划第三季 8月月更

小程序+自定义插件的关键性

Geek_99967b

小程序

开发者必备:一文快速熟记【数据库系统】和【软件开发模型】常用知识点

小阿杰

软件开发流程 软件开发原则 数据库系统 签约计划第三季

Kubernetes 开发环境比对

CTO技术共享

开源 签约计划第三季 8月月更

大型分布式存储方案MinIO介绍,看完你就懂了!

天使不哭

存储 MINO #开源 8月月更

781. 森林中的兔子

小卢要刷力扣题

力扣 8月月更

什么是服务治理

阿泽🧸

服务治理 8月月更

程序员从佩洛西窜访事件中可以学到什么?

慕枫技术笔记

思维 构架 8月月更

一文教会你快速上手 Vim

昆吾kw

vim Linux

全面解析FPGA基础知识

向阳逐梦

签约计划第三季

  • 扫码添加小助手
    领取最新资料包
Keras和TensorFlow之争何时休?_AI_Adrian Rosebrock_InfoQ精选文章