在本篇文章中,作者将在 Rust 上移植一个简单的神经网络实现。他的目标是探索 Rust 中的数据科学工作流在性能和工程学上的表现。
Python 实现
第一章描述了一个非常简单的单层神经网络。这个神经网络可以使用基于随机梯度下降的机器学习算法,对来自于MNIST数据集的手写数字进行分类。这听起来挺复杂,这些东西也确实在上世纪 80 年代中期是最先进的,但是实际上,这全部是由一段150行的Python代码做出来的,而且这些代码有很多人评论过。
如果你已经知道了这一节的内容(神经网络基础知识),那么我建议你可以跳过去,当然如果想再复习一下神经网络的基础知识也是可以看这一节的。或者不要只关注代码,没有必要特别细致的理解代码为什么以现有的方式运行,而应该关注 Python 方法和 Rust 方法的不同。
代码中的基础数据容器是一个 Network 类,它表示一个对层数和每层网络数量可控制的神经网络。在 Network 类的内部,用 2D NumPy 数组组成的列表表示类的数据。网络各层用一个二维的权重数组和一个一维的偏差数组来表示,这些数组被包含在叫做 biases 和 weights 的 Network 类的属性中。这些都是二维数组列表。biases 属性是列向量,但是会利用虚拟维度被存储成二维的数组。Network 的构造函数如下:
在这个简单的实现中,属性 biases 和 weights 通过描述标准正态分布来初始化。正态分布的均值为 0,标准差为 1。我们还可以清楚地看到,biases 是如何被初始化为列向量的。
Network 类公开了两个可以被用户直接调用的方法。第一个方法是 evaluate 方法,这个方法可以通过网络尝试识别一系列测试图片中的数字,然后基于先验已知的正确结果,对识别的结果进行打分。第二个方法是 SGD 方法,这个方法可以通过遍历一组图片来执行随机梯度下降的过程。这个过程包括:将整组图像分解成小的类别,基于各个小类别图像来更新网络状态,更新用户指定的学习率,eta,以及在用户随机指定数量的一系列小类别图像上重新运行以上的训练步骤。每组小分类图像和网络的更新的核心算法如下代码所示:
对于小类别集合中的每个经受训练的图像,我们通过反向传播算法(在 backprop 函数中实现)积累成本函数的梯度估计。当程序跑完了小分类图像集合,会根据估计梯度调整权重和偏差。因为我们希望得到小分类集合中所有估计的平均梯度,所以要更新的数据包括分母中的 len(mini_batch)值。我们还可以通过调整学习速率和 eta 值来控制权重和偏差的更新速度,从而可以全局调整每个小分类集合更新的大小。backprop 函数开始于给定输入图像的网络预期,然后通过网络向后面运行,以在网络中通过层来传播这些错误,最后计算出神经网络的成本梯度。这需要大量的数据调整,这也是我在移植到 Rust 时,花费了大量时间的地方。不过我认为我深入这块花费了过于长的时间。如果你想更深入的了解细节,请看这本书的第2章。
Rust 实现
第一步要搞清楚如何加载数据。这一步太繁琐了,我写了一篇专门的文章来介绍。按照顺序,我必须弄清楚如何用 Rust 实现 Python 的 Network 类。最终我决定使用结构体(struct):
和 Python 实现大致一样,依据每层神经网络的数量初始化结构体。
有一个和 Python 实现的不同点。我们在 Python 中,使用 numpy.random.randn来初始化权重和偏差,而在 Rust 中,我们使用 ndarray::Array::random 函数接受一个 rand::distribution::Distributionas 类型的参数和一个其他参数,并允许选择任意分布,来完成初始化。在这种情况下,我们使用了 rand::distributions::StandardNormal 做分布。值得注意的是,这使用了在三个不同包中定义的接口:其中两个接口,ndarray 自身和 ndarray-rand 由 ndarray 的作者维护,剩下一个由其它的开发者维护。
这些统一类包的优势
原则上,一个好处是,随机数生成不会在 ndarray 代码库中被单独隔离,如果新的随机数分布或者功能被加到了 rand 中,ndarray 和在 Rust 系统中的类都可以同等的使用。另一方面,需要为了各种装箱操作,在不同文件之间进行引用,而不是可以在一个集中的地方进行查看,这样增加了一些学习成本。我也有个特殊情况,也算是运气不好,rand 发布了改变公共 api 的新版本时,我正在开发这个工程。这导致了,依赖于 0.6 版本的 rand 的 ndarray-rand 和依赖于 0.7 版本的 rand 的我的工程,之间的不一致。
我了解到 cargo 和 rust 的构建系统可以很好地处理这类问题,但至少在这种情况下,我遇到了一个令人困惑的错误信息。这则错误信息是关于我的随机分布如何不满足 Distribution(分布)特征的。然而,我的分布式正确的,它的随机分布特征满足了 0.7 版本的 rand,而不是 ndarray-rand 依赖的 rand0.6 版本。但是因为装箱的版本信息不出现在错误信息当中,所以会让人非常的困惑。我最后提交了一个issue。我发现对于 Rust 语言来说,来自包装箱的各种不一致接口的让人困惑的错误信息,是一个长期存在的问题。希望在未来,Rust 可以产生更多有用的错误信息。
最后,作为一个新用户,这种关注点的分离给我(的理解上)带来了许多阻力。在 Python 中,我可以简单的做 import numpy 操作便可以完成导入。我认为 NumPy 在完全单片化的方向上走的太远了。它最初被编写的时候,用 C 语言扩展的 Python 代码,在打包和发布上,要比现在困难的多。我认为在一个极端方向上走的太远,会让一个语言或者工具的生态系统变得很难学习。
类型和所有权
下一步我将详细介绍 update_mini_batch 的 Rust 版本:
该函数使用了我定义的两个简短的辅助函数,使得代码更简洁了一些:
和 Python 的实现版本相比较,调用 update_mini_batch 的方式有些不同。没有直接传递对象列表,反而传递了对完整集合中全套训练数据和一份索引的引用。这样,更容易理解没有触发器的借用检查器。
创建 nabla_b 和 nabla_win zero_vec_like 和我们在 Python 中使用的列表解析非常类似。有一个挫折让我有些沮丧,因为如果我尝试使用 Array2::zeros 创建一个用 0 填充的数组,并将它传递给一个特定形状的 slice 或 Vec,我就会得到一个 ArrayD 的实例。为了获得 Array2 对象(显然这是一个二维数组,而不是一个通用的 D 纬数组),我需要向 Array::zeros 传递一个元素。然而,由于 ndarray::shape 返回一个切片(slice),我需要使用 to_tuple 函数,将这个切片转换为一个元组。这些事情在 Python 中可以被隐藏,但是在 Rust 中,切片(slice)和元组(tuple)之间的不同(造成的影响)变得非常大,就和在这个 API 中的情况一样。
通过反向传播对估计的权重和偏差进行更新的代码具有和 Python 实现的版本非常相似的结构。我们在小分类中训练每个示例图像,并获得二次成本梯度的估计值作为偏差和权重相关的一个函数:
然后累积这些估计值:
等到我们完成了小分类的处理,我们就会根据学习率更新权重和偏差。
这个例子说明了,Rust 对数组数据在工程上的处理和 Python 相比有区别的。首先,我们不用浮点数 eta/nbatch 乘以数组,而是使用 Array::mapv,并定义一个内连闭包,以便在整个数组上以矢量化的方式进行映射。在 Python 中,由于方法调用比较慢,所以这些事情不会处理的太快。而在 Rust 中,则不会出现这种情况。当我们减去时,还需要借用带 &符号 mapv 的返回值,以免我们在迭代它时消耗数组数据。在 Rust 中,需要仔细考虑函数是否会消耗数据或者引入引用,这导致在概念上,用 Rust 编写这种代码,比在 Python 中要求更多。另一方面,我对我的代码的正确性并且能够编译通过,有了更高的信心。我不确定的是,我写这段代码很费力的原因,是因为 Rust 真的更难写,还是因为我在 Python 和 Rust 上经验的不同。
用 Rust 重写这些代码,然后一切都会好起来
在这里,我留下了一些东西,比我开始使用的未经优化的 Python 版本代码更快。然而,相比 10 倍或是更快的速度,人们可能更期望从像 Python 这样的动态解释性语言转变为像 Rust 这样的编译性能导向语言,并且我也只观察到了 2 倍的加速。可以去理解一下我为什么要测量 Rust 语言的性能表现。幸运的是,这里有一个非常方便的项目,可以为 Rust 工程生成火焰图:flamegraph。这里添加了一个 flamegraph 的子命令 cargo,因此只需要在包中执行 cargo flamegraph 即可运行代码,就会编写出一个可以在浏览器中执行的火焰图 svg 文件(原图为可交互的 svg 脚本,如果希望尝试,可以查看原网页)。
如果你之前还没有看过一个火焰图,(我解释下),在例程中发生的程序运行时间与与该例程的条形宽度成正比。主函数位于图形的底部,主函数调用的函数在图形的顶部。这样你就可以简单查看哪些函数占用了程序中最多的时间。图中非常宽的东西代表了花费最多时间的地方。在调用栈中非常高和宽的函数,在代码上花费了大量的时间。看一下上面的火焰图,我们可以发现一般的时间都花费在了像名字叫 dgemm_kernel_HASWELL 的这类函数身上,这类函数是 OpenBLAS 的线性代数类库,剩下的时间,花费在 update_mini_batch 的数组和分配数组之间的添加上。我程序的其它所有部分,对运行时间的贡献可以忽略不计。
如果我们为 Python 代码做一个类似的火焰图,我们会发现一个相似的情况:大部分时间花费在了做线性代数函数上去(在 np.dot 反向传播例程中调用的地方)。因此,由于不管是 Rust 还是 Python 花费的时间大部分都在数值性的线性代数库上,我们就不能够希望得到一个 10 倍加速的结果。
实际情况比这更糟糕。这本书中的一个练习是重写了使用向量化矩阵乘法的 Python 代码。在这个方法中,每个小分类中的所有样例的反向传播发生在单组矢量化矩阵乘法运算中。这需要能够在 3 维和 2 维数组之间进行矩阵乘法。由于每个矩阵乘法运算使用的数据量大于非向量化的情况,OpenBLAS 可以更有效地使用 CPU 缓存和寄存器,基本上可以更好地利用我笔记本电脑上的可用 CPU 资源。重写的 Python 版本要比 Rust 版本更快,又快了大约两倍左右。
理论上,可以对 Rust 代码进行相同的优化。但是对于高于 2 维(的矩阵)的情况,ndarraycrate还不支持矩阵乘法。也可以使用像rayon这样的库在小批量更新上使用线程并行化。我在我的笔记本上尝试这个(并行化)没有看到任何的加速,但是可能在具有更多 CPU 线程的更强大的机器上会有作用。我也可以尝试使用一个不同的线性代数函数实现,例如,有TensorFlow和Torch的 Rust 构建,但是在这种情况下,我觉得我也可以使用那些库的 Python 构建。
Rust 是否适合数据科学的工作流?
现在我不得不说,答案是”未知“。在未来,当我需要编写具有小依赖性的低级别优化代码时,我肯定会使用 Rust。但是,如果把 Rust 作为 Python 和 C++的完全替代品,还需要一个更稳定和完善的类库生态系统。
评论