当前位置: 首页 > 工具软件 > TensorFlow.js > 使用案例 >

tensorflowjs保存并加载tf.Model

西门伟
2023-12-01

将tf.Model:保存到Web浏览器的本地存储。本地存储是标准的客户端数据存储。保存在那里的数据可以在同一页面的多个负载中持续存在。
假设你有一个tf.Model名为的对象model。无论是从头开始使用Layers API还是从预训练的Keras模型加载/微调,您都可以使用一行代码将其保存到本地存储:

const saveResult = await model.save('localstorage://my-model-1');
  • 该save方法采用类似于URL的字符串参数,该参数以方案开头。在这种情况下,我们使用该localstorage://方案指定将模型保存到本地存储
  • 该计划之后是一条路径。在保存到本地存储的情况下,路径只是一个任意字符串,用于唯一标识要保存的模型。例如,当您从本地存储加载模型时,将使用它。
  • 该save方法是异步的,因此您需要使用then或者await如果其完成形成其他操作的前提条件。
  • 返回值model.save是一个JSON对象,它携带一些可能有用的信息,例如模型拓扑和权重的字节大小。
  • 任何tf.Model,无论它是否由 tf.sequential 构成,它包含哪些类型的层,都可以这种方式保存。
    下表列出了所有当前支持的保存模型目的地及其相应的方案和示例。
保存目的地   方案字符串   代码示例
本地存储(浏览器)   localstorage://  await model.save('localstorage://my-model-1');
IndexedDB(浏览器)  indexeddb://    await model.save('indexeddb://my-model-1');
触发文件下载(浏览器) downloads://    await model.save('downloads://my-model-1');
HTTP请求(浏览器) http:// 要么 https://  await model.save('http://model-server.domain/upload');
文件系统(Node.js)   file://           await model.save('file:///tmp/my-model-1');

IndexedDB

IndexedDB 是大多数主流Web浏览器支持的另一个客户端数据存储。与本地存储不同,它更好地支持存储大型二进制数据(BLOB)和更大的配额。因此,tf.Model与本地存储相比,保存到IndexedDB通常可以提供更好的存储效率和更大的大小限制。

文件下载

该downloads://方案后面的字符串是将要下载的文件名称的前缀。例如,该行将 model.save(‘downloads://my-model-1’)导致浏览器下载两个共享相同文件名前缀的文件:

  • 一个名为的文本JSON文件my-model-1.json,它在其modelTopology字段中包含模型的拓扑,并在其字段中显示权重清单 weightsManifest。
  • 一个二进制文件,带有名为的权重值my-model-1.weights.bin。
    这些文件的格式与tensorflowjs转换器从Keras HDF5文件转换的工件格式相同。

注意:某些浏览器要求用户在同时下载多个文件之前授予权限。

HTTP请求

如果tf.Model.save使用HTTP / HTTPS URL调用,则模型的拓扑和权重将通过POST请求发送到指定的HTTP服务器 。POST请求的主体具有一个名为的格式 multipart/form-data。它是用于将文件上载到服务器的标准MIME格式。正文由两个文件组成,文件名model.json和文件名 model.weights.bin。文件格式与downloads://方案触发的下载文件格式相同(参见上文)。此 文档字符串 包含一个Python代码片段,演示了如何使用烧瓶 Web框架以及Keras和TensorFlow来处理源自save请求的有效负载并将其重新构建为服务器内存中的Keras Model对象。

通常,您的HTTP服务器对请求有特殊约束和要求,例如HTTP方法,标头和身份验证凭据。您可以save通过将URL字符串参数替换为调用来获得对请求的这些方面的细粒度控制tf.io.browserHTTPRequest。它是一个更详细的API,但它在控制由此产生的HTTP请求时提供了更大的灵活性save。例如:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'}}));

TensorFlow.js可以在Node.js中使用。有关更多详细信息,请参阅 tfjs-node项目。与Web浏览器不同,Node.js可以直接访问本地文件系统。因此,您可以将tf.Models 保存到文件系统,就像在Keras中将模型保存到磁盘一样。要执行此操作,请使用file://URL方案,然后使用 要保存模型工件的目录路径,例如:

await model.save('file:///tmp/my-model-1');

上面的命令将在目录中生成一个model.json文件和一个weights.bin文件/tmp/my-model-1。这两个文件的格式与上面的“文件下载”和“HTTP请求”部分中描述的文件格式相同。保存模型后,可以将其加载回运行TensorFlow.js的Node.js程序,或者为TensorFlow.js的浏览器版本提供服务。要实现前者,请tf.loadModel()使用model.json文件路径调用:

const model = await tf.loadModel('file:///tmp/my-model-1/model.json');

将保存的tf.Model转换为Keras格式

  • 通过文件从Web浏览器下载,使用该downloads://方案
  • 使用该file://方案将模型直接写入Node.js中的本机文件系统 。使用tensorflowjs转换器,您可以将这些文件转换为HDF5格式,然后可以将其加载到Python中的Keras中。例如:
pip install tensorflowjs

tensorflowjs_converter \
    --input_format tensorflowjs --output_format keras \
    ./my-model-1.json /tmp/my-model-1.h5
 类似资料: