剑客
关注科技互联网

Keras.js可以让你使用浏览器在GPU上运行Keras模型

本项目可以让你使用 WebGL 在 GPU 驱动的、你的浏览器上运行训练好的 Keras 模型。模型直接根据 Keras JSON 格式配置文件和关联的 HDF5 权重而序列化(serialized)。

项目地址: https://github.com/transcranial/keras-js

互动演示

  • 用于 MNIST 的基本卷积网络

  • 在 MNIST 上训练的卷积变自编码器(Convolutional Variational Autoencoder)

  • 在 ImageNet 上训练的 50 层的残差网络(Residual Network)

  • 在 ImageNet 上训练的 Inception V3

  • 用于 IMDB 情绪分类的双向 LSTM

Keras.js可以让你使用浏览器在GPU上运行Keras模型

Keras.js可以让你使用浏览器在GPU上运行Keras模型

Keras.js可以让你使用浏览器在GPU上运行Keras模型

Keras.js可以让你使用浏览器在GPU上运行Keras模型

为什么要做这个项目?

  • 消除对后端基础设施或 API 调用的需求

  • 完全将计算卸载到客户端浏览器

  • 互动应用程序

使用方法

查看 demos/src/ 获取真实案例的源代码。

1.对 Model 和 Sequential 都适用

model = Sequential()
model.add(...)
...
...model = Model(input=..., output=...)

一旦训练完成,保存权重和导出模型架构配置:

model.save_weights('model.hdf5')
with open('model.json', 'w') as f:
  f.write(model.to_json())

参见演示的 jupyter notebooks 了解详情: demos/notebooks/

2.在 HDF5 权重文件上运行编码器脚本:

$ python encoder.py /path/to/model.hdf5

这将在同一个文件夹中产生两个用作 HDF5 权重的文件:model_weights.buf 和 model_metadata.json

3.Keras.js 所需的三个文件:

  • 模型文件: model.json

  • 权重文件: model_weights.buf

  • 权重元数据文件: model_metadata.json

4.GPU 支持由 weblas 驱动。将 Keras.js 和 Weblas 库包含进去:

<script src="lib/weblas.js"></script>
<script src="dist/keras.js"></script>

5.创建新模型

实例化时,数据通过 XHR(相同域或要求 CORS)加载,层被初始化为有向无环图。当这些步骤完成之后,类方法 ready() 返回一个解决问题的 Promise。然后,使用 perdict() 让数据通过模型,这也会返回一个 Promise。

const model = new KerasJS.Model({
  filepaths: {
    model: 'url/path/to/model.json',
    weights: 'url/path/to/model_weights.buf',
    metadata: 'url/path/to/model_metadata.json'
  }
  gpu: true})model.ready().then(() => {  
  // input data object keyed by names of the input layers
  // or `input` for Sequential models
  // values are the flattened Float32Array data
  // (input tensor shapes are specified in the model config)
  const inputData = {    'input_1': new Float32Array(data)
  }  
  // make predictions
  // outputData is an object keyed by names of the output layers
  // or `output` for Sequential models
  model.predict(inputData).then(outputData => {    
  // e.g.,
    // outputData['fc1000']
  })
})

可用的层

  • 高级激活: LeakyReLU, PReLU, ELU, ParametricSoftplus, ThresholdedReLU, SReLU

  • 卷积: Convolution1D, Convolution2D, AtrousConvolution2D, SeparableConvolution2D, Deconvolution2D, Convolution3D, UpSampling1D, UpSampling2D, UpSampling3D, ZeroPadding1D, ZeroPadding2D, ZeroPadding3D

  • 内核: Dense, Activation, Dropout, SpatialDropout2D, SpatialDropout3D, Flatten, Reshape, Permute, RepeatVector, Merge, Highway, MaxoutDense

  • 嵌入: Embedding

  • 归一化: BatchNormalization

  • 池化: MaxPooling1D, MaxPooling2D, MaxPooling3D, AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalMaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling2D, GlobalAveragePooling2D

  • 循环: SimpleRNN, LSTM, GRU

  • 包装器:Bidirectional, TimeDistributed

还没有实现的层

目前还不能直接实现 Lambda,但最终会创建一个通过 JavaScript 定义计算逻辑的机制。

  • 内核: Lambda

  • 卷积: Cropping1D, Cropping2D, Cropping3D

  • 本地连接: LocallyConnected1D, LocallyConnected2D

  • 噪声 :GaussianNoise, GaussianDropout

备注

WebWorker 及其限制

Keras.js 可以与主线程分开单独运行在 WebWorker 中。因为 Keras.js 会执行大量同步计算,这可以防止该 UI 受到影响。但是,WebWorker 的最大限制之一是缺乏 <canvas> 访问(所以要用 WebGL)。所以在单独的线程中运行 Keras.js 的好处被必须运行在 CPU 模式中的要求抵消了。换句话说,在 GPU 模式中运行的 Keras.js 只能运行在主线程上。

WebGL MAX_TEXTURE_SIZE

在 GPU 模式中,张量对象被编码成了计算之前的 WebGL textures。这些张量的大小由 gl.getParameter(gl.MAX_TEXTURE_SIZE) 限定,这会根据硬件或平台的状况而有所不同。参考 http://webglstats.com/ 了解典型的预期值。在 im2col 之后,卷积层中可能会有一个问题。比如在 Inception V3 网络演示中,第一层卷积层中 im2col 创造了一个 22201 x 27 的矩阵,并在第二层和第三层卷积层中创造 21609 x 288 的矩阵。第一个维度上的大小超过了MAX_TEXTURE_SIZE 的最大值 16384,所以必须被分割开。根据权重为每一个分割开的张量执行矩阵乘法,然后再组合起来。在这个案例中,当createWeblasTensor() 被调用时,Tensor 对象上会提供一个 weblasTensorsSplit 属性。了解其使用的例子可查看 src/layers/convolutional/Convolution2D.js

开发/测试

对于每一个实现的层都存在广泛的测试。查看 notebooks/ 获取为所有这些测试生成数据的 jupyter notebooks。

$ npm install

要运行所有测试,执行 npm run server 并访问 http://localhost:3000/test/。所有的测试都会自动运行。打开你的浏览器的开发工具获取额外的测试数据信息。

对于开发,请运行:

$ npm run watch

编辑 src/ 中的任意文件都会触发 webpack 来更新 dist/keras.js

要创建生产型的 UMD webpack 版本,输出到 dist/keras.js ,运行:

$ npm run build

证书

MIT

分享到:更多 ()

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址