订阅号怎么做免费的视频网站吗,网站建站如何做seo,东营市造价信息网,如何把wordpress头部去掉使用 PyTorch 数据读取#xff0c;JAX 框架来训练一个简单的神经网络
本文例程部分主要参考官方文档。
JAX简介
JAX 的前身是 Autograd #xff0c;也就是说 JAX 是 Autograd 升级版本#xff0c;JAX 可以对 Python 和 NumPy 程序进行自动微分。可以通过 Python的大量特征…使用 PyTorch 数据读取JAX 框架来训练一个简单的神经网络
本文例程部分主要参考官方文档。
JAX简介
JAX 的前身是 Autograd 也就是说 JAX 是 Autograd 升级版本JAX 可以对 Python 和 NumPy 程序进行自动微分。可以通过 Python的大量特征子集进行区分包括循环、分支、递归和闭包语句进行自动求导也可以求三阶导数(三阶导数是由原函数导数的导数的导数。 所谓三阶导数即原函数导数的导数的导数将原函数进行三次求导)。通过 grad JAX 支持反向模式和正向模式的求导而且这两种模式可以任意组合成任何顺序具有一定灵活性。
另一个特点是基于 XLA 的 JIT 即时编译大大提高速度。
需要注意的是JAX 仅提供计算时的优化相当于是一个支持自动微分和 JIT 编译的 NumPy。也就是说数据处理 Dataloader 等其他框架都会提供的 utils 功能这里是没有的。所幸 JAX 可以比较好的支持 PyTorch、 TensorFlow 等主流框架的数据读取。本文就将基于 PyTorch 的数据读取工具和 JAX 框架来训练一个简单的神经网络。
以下是国内优秀的机器学习框架 OneFlow 同名公司的创始人袁进辉老师在知乎上的一个评价 如果说tensorflow 是主打lazy, 偏functional 的思想但实现的臃肿面目可憎pytorch 则主打eager, 偏imperative 编程但内核简单可视为支持gpu的numpy, 加上一个autograd。JAX 像是这俩框架的混合体取了tensorflow的functional和PyTorch的精简即支持gpu的 numpy, 具有autograd功能非常追求函数式编程的思想强调无状态immutable加上JIT修饰符后就是lazy可以使用xla对计算流程进行静态分析和优化。当然JAX不带jit也可像pytorch那种命令式编程和eager执行。 JAX有可能和PyTorch竞争。 安装
安装可以通过源码编译也可以直接 pip。源码编译详见[官方文档: Building from source][2]对于官方没有提供预编译包的 cuda-cudnn 版本组合只能通过自己源码构建。pip的方式比较简单在 github 仓库的 README 文档中就有介绍。要注意不同于 PyTorch 等框架JAX 不会再 pip 安装中绑定 CUDA 或 cuDNN 进行安装若未安装需要自己先手动安装。仅使用 CPU 的版本也有支持。
笔者是 CUDA11.1CUDNN 8.2安装如下
pip install --upgrade pip
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html前面已经提到过本文会借用 PyTorch 的数据处理工具因此 torch 和 torchvision 也是必不可少的已经安装的可跳过
pip install torch torchvision构建简单的神经网络训练
框架安装完毕我们正式开始。接下来我们使用 JAX 在 MNIST 上指定和训练一个简单的 MLP 进行计算用 PyTorch 的数据加载 API 来加载图像和标签。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random超参数
# 本函数用来随机初始化网络权重
def random_layer_params(m, n, key, scale1e-2):w_key, b_key random.split(key)return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))# 初始化各个全连接层
def init_network_params(sizes, key):keys random.split(key, len(sizes))return [random_layer_params(m, n, k) for m, n, k in zip(sizes[: -1], sizes[1: ], keys)]layer_sizes [784, 512, 512, 10]
step_size 0.01
num_epochs 8
batch_size 128
n_targets 10
params init_network_params(layer_sizes, random.PRNGKey(0))自动分批次预测
对于小批量我们稍后将使用 JAX 的 vmap 函数来自动处理而不会降低性能。我们现在先准备一个单张图像推理预测函数
from jax.scipy.special import logsumexpdef relu(x):return jnp.maximum(0, x)# 对单张图像进行推理的函数
def predict(params, image):activations imagefor w, b in params[: -1]:outputs jnp.dot(w, activations) bactivations relu(outputs)final_w, final_b params[-1]logits jnp.dot(final_w, activations) final_breturn logits - logsumexp(logits)这个函数应该只能用来处理单张图像推理预测而不能批量处理我们简单测试一下对于单张
random_flattened_images random.normal(random.PRNGKey(1), (28 * 28,))
preds predict(params, random_flattened_images)
print(preds.shape)输出
(10,)对于批次
random_flattened_images random.normal(random.PRNGKey(1), (10, 28 * 28))
try:preds predict(params, random_flattened_images)
except TypeError:print(Invalid shapes!)输出
Invalid shapes!现在我们使用 vmap 来使它能够处理批量数据
# 用 vmap 来实现一个批量版本
batched_predict vmap(predict, in_axes(None, 0))# batched_predict 的调用与 predict 相同
batched_preds batched_predict(params, random_flattened_images)
print(batched_preds.shape)输出
(10, 10)现在我们已经做好了准备工作接下来就是要定义一个神经网络并且进行训练了我们已经构建了的自动批处理版本的 predict 函数并且将在损失函数中也使用它。我们将使用 grad 来得到损失关于神经网络参数的导数。而且这一切都可以用 jit 进行加速。
实用工具函数和损失函数
def one_hot(x, k, dtypejnp.float32):构建一个 x 的 k 维 one-hot 编码.return jnp.array(x[:, None] jnp.arange(k), dtype)def accuracy(params, images, targets):target_class jnp.argmax(targets, axis1)predicted_class jnp.argmax(batched_predict(params, images), axis1)return jnp.mean(predicted_class target_class)def loss(params, images, targets):preds batched_predict(params, images)return -jnp.mean(preds * targets)jit
def update(params, x, y):grads grad(loss)(params, x, y)return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]使用 PyTorch 进行数据读取
JAX 是一个专注于程序转换和支持加速的 NumPy对于数据的读取已经有很多优秀的工具了这里我们就直接用 PyTorch 的 API。我们会做一个小的 shim 来使得它能够支持 NumPy 数组。
import numpy as np
from torch.utils import data
from torchvision.datasets import MNISTdef numpy_collate(batch):if isinstance(batch[0], np.ndarray):return np.stack(batch)elif isinstance(batch[0], (tuple, list)):transposed zip(*batch)return [numpy_collate(samples) for samples in transposed]else:return np.array(batch)class NumpyLoader(data.DataLoader):def __init__(self, dataset, batch_size1,shuffleFalse, samplerNone,batch_samplerNone, num_workers0,pin_memoryFalse, drop_lastFalse,timeout0, worker_init_fnNone):super(self.__class__, self).__init__(dataset,batch_sizebatch_size,shuffleshuffle,samplersampler,batch_samplerbatch_sampler,collate_fnnumpy_collate,num_workersnum_workers,pin_memorypin_memory,drop_lastdrop_last,timeouttimeout,worker_init_fnworker_init_fn)class FlattenAndCast(object):def __call__(self, pic):return np.ravel(np.array(pic, dtypejnp.float32))接下来借助 PyTorch 的 datasets定义我们自己的 dataset
mnist_dataset MNIST(/tmp/mnist/, downloadTrue, transformFlattenAndCast())
training_generator NumpyLoader(mnist_dataset, batch_sizebatch_size, num_workers0)此处应该输出一堆下载 MNIST 数据集的信息就不贴了。
接下来分别拿到整个训练集和整个测试集下面会用于测准确率
train_images np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels one_hot(np.array(mnist_dataset.train_labels), n_targets)mnist_dataset_test MNIST(/tmp/mnist/, downloadTrue, trainFalse)
test_images jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtypejnp.float32)
test_labels one_hot(np.array(mnist_dataset_test.test_labels), n_targets)开始训练
import time
for epoch in range(num_epochs):start_time time.time()for x, y in training_generator:y one_hot(y, n_targets)params update(params, x, y)epoch_time time.time() - start_timetrain_acc accuracy(params, train_images, train_labels)test_acc accuracy(params, test_images, test_labels)print(Epoch {} in {:0.2f} sec.format(epoch, epoch_time))print(Training set accuracy {}.format(train_acc))print(Test set accuracy {}.format(test_acc))输出
Epoch 0 in 3.29 sec
Training set accuracy 0.9156666994094849
Test set accuracy 0.9196999669075012
...
Epoch 7 in 1.78 sec
Training set accuracy 0.9736666679382324
Test set accuracy 0.9670999646186829在本文的过程中我们已经使用了整个 JAX APIgrad 用于自动微分、jit 用于加速、vmap 用于自动矢量化。我们使用 NumPy 来进行我们所有的计算并从 PyTorch 借用了出色的数据加载器并在 GPU 上运行了整个过程。
Ref
https://juejin.cn/post/6994695537316331556
https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
https://jax.readthedocs.io/en/latest/developer.html#building-from-source