当前位置: 首页 > news >正文

订阅号怎么做免费的视频网站吗网站建站如何做seo

订阅号怎么做免费的视频网站吗,网站建站如何做seo,东营市造价信息网,如何把wordpress头部去掉使用 PyTorch 数据读取#xff0c;JAX 框架来训练一个简单的神经网络 本文例程部分主要参考官方文档。 JAX简介 JAX 的前身是 Autograd #xff0c;也就是说 JAX 是 Autograd 升级版本#xff0c;JAX 可以对 Python 和 NumPy 程序进行自动微分。可以通过 Python的大量特征…使用 PyTorch 数据读取JAX 框架来训练一个简单的神经网络 本文例程部分主要参考官方文档。 JAX简介 JAX 的前身是 Autograd 也就是说 JAX 是 Autograd 升级版本JAX 可以对 Python 和 NumPy 程序进行自动微分。可以通过 Python的大量特征子集进行区分包括循环、分支、递归和闭包语句进行自动求导也可以求三阶导数(三阶导数是由原函数导数的导数的导数。 所谓三阶导数即原函数导数的导数的导数将原函数进行三次求导)。通过 grad JAX 支持反向模式和正向模式的求导而且这两种模式可以任意组合成任何顺序具有一定灵活性。 另一个特点是基于 XLA 的 JIT 即时编译大大提高速度。 需要注意的是JAX 仅提供计算时的优化相当于是一个支持自动微分和 JIT 编译的 NumPy。也就是说数据处理 Dataloader 等其他框架都会提供的 utils 功能这里是没有的。所幸 JAX 可以比较好的支持 PyTorch、 TensorFlow 等主流框架的数据读取。本文就将基于 PyTorch 的数据读取工具和 JAX 框架来训练一个简单的神经网络。 以下是国内优秀的机器学习框架 OneFlow 同名公司的创始人袁进辉老师在知乎上的一个评价 如果说tensorflow 是主打lazy, 偏functional 的思想但实现的臃肿面目可憎pytorch 则主打eager, 偏imperative 编程但内核简单可视为支持gpu的numpy, 加上一个autograd。JAX 像是这俩框架的混合体取了tensorflow的functional和PyTorch的精简即支持gpu的 numpy, 具有autograd功能非常追求函数式编程的思想强调无状态immutable加上JIT修饰符后就是lazy可以使用xla对计算流程进行静态分析和优化。当然JAX不带jit也可像pytorch那种命令式编程和eager执行。 JAX有可能和PyTorch竞争。 安装 安装可以通过源码编译也可以直接 pip。源码编译详见[官方文档: Building from source][2]对于官方没有提供预编译包的 cuda-cudnn 版本组合只能通过自己源码构建。pip的方式比较简单在 github 仓库的 README 文档中就有介绍。要注意不同于 PyTorch 等框架JAX 不会再 pip 安装中绑定 CUDA 或 cuDNN 进行安装若未安装需要自己先手动安装。仅使用 CPU 的版本也有支持。 笔者是 CUDA11.1CUDNN 8.2安装如下 pip install --upgrade pip pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html前面已经提到过本文会借用 PyTorch 的数据处理工具因此 torch 和 torchvision 也是必不可少的已经安装的可跳过 pip install torch torchvision构建简单的神经网络训练 框架安装完毕我们正式开始。接下来我们使用 JAX 在 MNIST 上指定和训练一个简单的 MLP 进行计算用 PyTorch 的数据加载 API 来加载图像和标签。 import jax.numpy as jnp from jax import grad, jit, vmap from jax import random超参数 # 本函数用来随机初始化网络权重 def random_layer_params(m, n, key, scale1e-2):w_key, b_key random.split(key)return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))# 初始化各个全连接层 def init_network_params(sizes, key):keys random.split(key, len(sizes))return [random_layer_params(m, n, k) for m, n, k in zip(sizes[: -1], sizes[1: ], keys)]layer_sizes [784, 512, 512, 10] step_size 0.01 num_epochs 8 batch_size 128 n_targets 10 params init_network_params(layer_sizes, random.PRNGKey(0))自动分批次预测 对于小批量我们稍后将使用 JAX 的 vmap 函数来自动处理而不会降低性能。我们现在先准备一个单张图像推理预测函数 from jax.scipy.special import logsumexpdef relu(x):return jnp.maximum(0, x)# 对单张图像进行推理的函数 def predict(params, image):activations imagefor w, b in params[: -1]:outputs jnp.dot(w, activations) bactivations relu(outputs)final_w, final_b params[-1]logits jnp.dot(final_w, activations) final_breturn logits - logsumexp(logits)这个函数应该只能用来处理单张图像推理预测而不能批量处理我们简单测试一下对于单张 random_flattened_images random.normal(random.PRNGKey(1), (28 * 28,)) preds predict(params, random_flattened_images) print(preds.shape)输出 (10,)对于批次 random_flattened_images random.normal(random.PRNGKey(1), (10, 28 * 28)) try:preds predict(params, random_flattened_images) except TypeError:print(Invalid shapes!)输出 Invalid shapes!现在我们使用 vmap 来使它能够处理批量数据 # 用 vmap 来实现一个批量版本 batched_predict vmap(predict, in_axes(None, 0))# batched_predict 的调用与 predict 相同 batched_preds batched_predict(params, random_flattened_images) print(batched_preds.shape)输出 (10, 10)现在我们已经做好了准备工作接下来就是要定义一个神经网络并且进行训练了我们已经构建了的自动批处理版本的 predict 函数并且将在损失函数中也使用它。我们将使用 grad 来得到损失关于神经网络参数的导数。而且这一切都可以用 jit 进行加速。 实用工具函数和损失函数 def one_hot(x, k, dtypejnp.float32):构建一个 x 的 k 维 one-hot 编码.return jnp.array(x[:, None] jnp.arange(k), dtype)def accuracy(params, images, targets):target_class jnp.argmax(targets, axis1)predicted_class jnp.argmax(batched_predict(params, images), axis1)return jnp.mean(predicted_class target_class)def loss(params, images, targets):preds batched_predict(params, images)return -jnp.mean(preds * targets)jit def update(params, x, y):grads grad(loss)(params, x, y)return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]使用 PyTorch 进行数据读取 JAX 是一个专注于程序转换和支持加速的 NumPy对于数据的读取已经有很多优秀的工具了这里我们就直接用 PyTorch 的 API。我们会做一个小的 shim 来使得它能够支持 NumPy 数组。 import numpy as np from torch.utils import data from torchvision.datasets import MNISTdef numpy_collate(batch):if isinstance(batch[0], np.ndarray):return np.stack(batch)elif isinstance(batch[0], (tuple, list)):transposed zip(*batch)return [numpy_collate(samples) for samples in transposed]else:return np.array(batch)class NumpyLoader(data.DataLoader):def __init__(self, dataset, batch_size1,shuffleFalse, samplerNone,batch_samplerNone, num_workers0,pin_memoryFalse, drop_lastFalse,timeout0, worker_init_fnNone):super(self.__class__, self).__init__(dataset,batch_sizebatch_size,shuffleshuffle,samplersampler,batch_samplerbatch_sampler,collate_fnnumpy_collate,num_workersnum_workers,pin_memorypin_memory,drop_lastdrop_last,timeouttimeout,worker_init_fnworker_init_fn)class FlattenAndCast(object):def __call__(self, pic):return np.ravel(np.array(pic, dtypejnp.float32))接下来借助 PyTorch 的 datasets定义我们自己的 dataset mnist_dataset MNIST(/tmp/mnist/, downloadTrue, transformFlattenAndCast()) training_generator NumpyLoader(mnist_dataset, batch_sizebatch_size, num_workers0)此处应该输出一堆下载 MNIST 数据集的信息就不贴了。 接下来分别拿到整个训练集和整个测试集下面会用于测准确率 train_images np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1) train_labels one_hot(np.array(mnist_dataset.train_labels), n_targets)mnist_dataset_test MNIST(/tmp/mnist/, downloadTrue, trainFalse) test_images jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtypejnp.float32) test_labels one_hot(np.array(mnist_dataset_test.test_labels), n_targets)开始训练 import time for epoch in range(num_epochs):start_time time.time()for x, y in training_generator:y one_hot(y, n_targets)params update(params, x, y)epoch_time time.time() - start_timetrain_acc accuracy(params, train_images, train_labels)test_acc accuracy(params, test_images, test_labels)print(Epoch {} in {:0.2f} sec.format(epoch, epoch_time))print(Training set accuracy {}.format(train_acc))print(Test set accuracy {}.format(test_acc))输出 Epoch 0 in 3.29 sec Training set accuracy 0.9156666994094849 Test set accuracy 0.9196999669075012 ... Epoch 7 in 1.78 sec Training set accuracy 0.9736666679382324 Test set accuracy 0.9670999646186829在本文的过程中我们已经使用了整个 JAX APIgrad 用于自动微分、jit 用于加速、vmap 用于自动矢量化。我们使用 NumPy 来进行我们所有的计算并从 PyTorch 借用了出色的数据加载器并在 GPU 上运行了整个过程。 Ref https://juejin.cn/post/6994695537316331556 https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html https://jax.readthedocs.io/en/latest/developer.html#building-from-source
http://www.yutouwan.com/news/480092/

相关文章:

  • 如何用php做网站做网站王仁杰
  • asp企业网站模板下载苏州网站设计公司山东济南兴田德润什么活动
  • 在网站的标题上怎么做图标电源网站模版
  • 苏州做网站公司排名网站建设合同属于技术服务么
  • 福州企业制作网站mysql 网站空间
  • 自助建站管理平台蜂聘原360建筑网
  • 中国建设银行网站怎么改支付密码是什么网站开发平台
  • 网站建设的目标和需求分析科技有限公司 翻译
  • 深圳最好的营销网站建设公司哪家好网站权重一直做不上去
  • 学校网站开发协议怎么样做推广网站
  • 可做装饰推广的网站wordpress 图片点击放大
  • 网站开发可行性分析报告台州网站建设系统
  • 公司网站续费一年多少钱如何进行网络销售
  • 如何建一个商业网站叮当app制作平台下载
  • 保定网站建设seo优化营销天空彩票网站怎么做
  • 青岛企业网站开发ag1042入口快速入口
  • 可以做试卷的网站英语怎么说程序员外包公司是什么意思
  • 网站建设技术问题永久使用免费虚拟主机
  • 网站建设学什么书在邯郸开互联网公司
  • 台州网站公司那里好运城市网站建设
  • 网站群管理平台免费传奇网站免费传奇
  • 哪些行业做网站最重要网时代教育培训机构怎么样
  • 专业做家具的网站有哪些木兰姐网站建设
  • 怎样在百度建立自己的网站网页设计秀丽南宁
  • 企业网站的常见服务做一个网站的策划方案
  • 网站开发用什么笔记本医疗机械网站怎么做
  • 做网站的视频教学焦作网站开发
  • 泰安网站建设xtempire广元市剑阁县建设局网站
  • 盐城网站开发厂商顶尖网站建设公司
  • 响应式网站用什么做建立网站的主要步骤