TensorFlow Lite 设备端训练

郁光熙
2023-12-01

TensorFlow Lite 是 Google 的机器学习框架,用于在多种设备和平台上部署机器学习模型,例如移动设备(iOS 和 Android)、桌面设备和其他边缘设备。

  • TensorFlow Lite

    https://www.tensorflow.google.cn/lite

最近,我们又添加了在浏览器中运行 TensorFlow Lite 模型的支持。要使用 TensorFlow Lite 构建应用,您可以利用 TensorFlow Hub 中的现成模型,或者使用转换器将现有的 TensorFlow 模型转换为 TensorFlow Lite 模型。

  • 构建应用

    https://www.tensorflow.google.cn/lite/guide#development_workflow

  • TensorFlow Hub

    https://hub.tensorflow.google.cn/s?deployment-format=lite

  • 转换器

    https://tensorflow.google.cn/lite/convert/index

模型部署到应用中后,您可以基于输入数据在该模型上运行推理。

  • 运行推理

    https://tensorflow.google.cn/lite/guide#2_run_inference

除运行推理外,TensorFlow Lite 现在还支持在设备端训练模型。设备端训练支持有趣的个性化用例,其中模型可以根据用户需求进行微调。例如,您可以部署一个图像分类模型,允许用户使用迁移学习对模型进行微调来识别鸟类,同时允许其他用户重新训练该模型来识别水果。这项新功能在 TensorFlow 2.7 及以上版本中提供,现在可用于 Android 应用,并会在未来增加对 iOS 的支持。

  • 迁移学习

    https://developers.google.com/machine-learning/glossary#transfer-learning

设备端训练也是根据分散式数据训练全局模型的联合学习用例的必要基础。本文文章不会涉及到联合学习,而是侧重帮助您在 Android 应用中集成设备端训练。

  • 联合

    https://ai.googleblog.com/2017/04/federated-learning-collaborative.html

参考 Colab 和 Android 示例应用,向您介绍设备端学习的端到端实现路径,引导您完成图像分类模型的微调。

  • Colab

    https://tensorflow.google.cn/lite/examples/on_device_training/overview

  • Android 示例应用

    https://github.com/tensorflow/examples/tree/master/lite/examples/model_personalization

对早期方法的改进

我们在 2019 年的文章中介绍了设备端训练的概念,并展示了一个在 TensorFlow Lite 中进行设备端训练的示例。但是,当时存在几个限制。比如,自定义模型结构和优化器并不容易。您还必须处理多个物理 TensorFlow Lite (.tflite) 模型,而不是单个 TensorFlow Lite 模型。同样,存储和更新训练权重也没有简单的方法。我们最新的 TensorFlow Lite 版本提供更便捷的设备端训练选项,简化了这个过程,接下来就给大家介绍一下。

  • 文章

    https://blog.tensorflow.google.cn/2019/12/example-on-device-model-personalization.html

它是怎样实现的呢?

要部署内置设备端训练的 TensorFlow Lite 模型,简要步骤如下:

  • 构建用于训练和推理的 TensorFlow 模型

  • 将 TensorFlow 模型转换为 TensorFlow Lite 格式

  • 将模型集成到您的 Android 应用中

  • 在应用中调用模型训练,与调用模型推理的方式类似

具体步骤如下。

构建用于训练和推理的 TensorFlow 模型

TensorFlow Lite 模型应当同时支持模型推理和模型训练,训练通常涉及将模型的权重保存到文件系统,并从文件系统中恢复权重。这样做是为了在每个训练周期结束后保存训练权重,以便下个训练周期可以使用前一个周期的权重,而不是从头开始训练。

  • 一个使用训练数据训练模型的 train 函数。如下的 train 函数进行预测,计算损失(或误差),使用 tf.GradientTape() 记录自动微分的操作并更新模型的参数。

  • train

    https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb#scrollTo=d8577c80&line=38&uniqifier=1

  • 自动微分

    https://tensorflow.google.cn/guide/autodiff#automatic_differentiation_and_gradients

# The `train` function takes a batch of input images and labels.
@tf.function(input_signature=[
     tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
     tf.TensorSpec([None, 10], tf.float32),
 ])
def train(self, x, y):
   with tf.GradientTape() as tape:
     prediction = self.model(x)
     loss = self._LOSS_FN(prediction, y)
   gradients = tape.gradient(loss, self.model.trainable_variables)
   self._OPTIM.apply_gradients(
       zip(gradients, self.model.trainable_variables))
   result = {"loss": loss}
   for grad in gradients:
     result[grad.name] = grad
   return result

  • 一个调用模型推理的 infer 函数或 predict 函数。这和您目前使用 TensorFlow Lite 进行推理的方法类似。

  • infer

    https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb#scrollTo=d8577c80&line=38&uniqifier=1

@tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
 def predict(self, x):
   return {
       "output": self.model(x)
   }

  • 一个 save/restore 函数,将训练权重(即模型使用的参数)以 Checkpoints 格式保存到文件系统。该 save 函数的代码如下所示。

  • save/restore

    https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb#scrollTo=d8577c80&line=38&uniqifier=1

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
 def save(self, checkpoint_path):
   tensor_names = [weight.name for weight in self.model.weights]
   tensors_to_save = [weight.read_value() for weight in self.model.weights]
   tf.raw_ops.Save(
       filename=checkpoint_path, tensor_names=tensor_names,
       data=tensors_to_save, name='save')
   return {
       "checkpoint_path": checkpoint_path
   }

转换为 TensorFlow Lite 格式

您可能已经熟悉将 TensorFlow 模型转换为 TensorFlow Lite 格式的工作流。设备端训练的一些低级功能(例如,存储模型参数的变量)仍处于实验阶段,而其他(例如,权重序列化)目前依赖于 TF Select 运算符,因此您需要在转换过程中设置这些标志。您可以在 Colab 中找到所有需要设置标志的示例。

  • 转换

    https://tensorflow.google.cn/lite/convert

  • TF Select

    https://tensorflow.google.cn/lite/guide/ops_select

  • Colab

    https://www.tensorflow.org/lite/examples/on_device_training/overview

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
   tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
   tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

将模型集成到您的 Android 应用中

将模型转换为 TensorFlow Lite 格式后,您就可以将模型集成到应用中了!更多详细信息,请参阅 Android 应用示例。

  • Android

    https://github.com/tensorflow/examples/tree/master/lite/examples/model_personalization

在应用中调用模型训练和推理

在 Android 中,可以使用 Java 或 C++ API 执行 TensorFlow Lite 设备端训练。您可以创建一个 TensorFlow Lite Interpreter 的实例来加载模型和驱动模型训练任务。我们先前已经定义了多个 tf.functions:可以使用 TensorFlow Lite 对签名的支持来调用这些函数,签名允许单个 TensorFlow Lite 模型支持多个“入口”点。例如,我们为设备端训练定义了一个 train 函数, 这是模型的其中一个签名。通过指定签名的名称 (“train”)使用 TensorFlow Lite 的 runSignature 方法,即可调用 train 函数:

  • Interpreter

    https://tensorflow.google.cn/lite/guide/inference#load_and_run_a_model_in_java

  • 签名

    https://tensorflow.google.cn/lite/guide/signatures

// Run training for a few steps.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
    for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
        Mapinputs = new HashMap<>>();
        inputs.put("x", trainImageBatches.get(batchIdx));
        inputs.put("y", trainLabelBatches.get(batchIdx));

        Mapoutputs = new HashMap<>();
        FloatBuffer loss = FloatBuffer.allocate(1);
        outputs.put("loss", loss);

        interpreter.runSignature(inputs, outputs, "train");

        // Record the last loss.
        if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
    }
}

同样,下面的示例展示了如何使用模型的“infer”签名调用推理函数:

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());
    FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());

    // Fill the test data.

    // Run the inference.
    Mapinputs = new HashMap<>>();
    inputs.put("x", testImages.rewind());
    Mapoutputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");
    output.rewind();

    // Process the result to get the final category values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j))
                index = testLabels[j];
        }
        testLabels[i] = index;
    }
}

就这么简单!现在您拥有了一个可以使用设备端训练的 TensorFlow Lite 模型。

 类似资料: