1.6.9.3 Using the Java Client

小牛编辑
2023-12-01

Introduction

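A Java application can call the prediction service that TensorFlow Serving exposes for a loaded model directly. To do so, we need to write a gRPC client in Java.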

Complete Example

Here is an example that exports a model and then queries it from Java: https://github.com/tobegit3hub/deep_recommend_system/tree/master/java_predict_client

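Build it with Maven. To target a different model you only need to modify one Java file; the external dependencies are already managed, so we recommend adapting this project rather than starting from scratch.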

How the Java Client Works

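For Java, both the gRPC server and the gRPC client are implemented in a project separate from the main grpc repository; the code is at https://github.com/grpc/grpc-java . To use it you need to pull in the gRPC classes. We recommend managing dependencies with Maven by adding the following to pom.xml.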

<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-netty</artifactId>
  <version>1.0.0</version>
</dependency>
<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-protobuf</artifactId>
  <version>1.0.0</version>
</dependency>
<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-stub</artifactId>
  <version>1.0.0</version>
</dependency>

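Using gRPC also requires the Java code that protobuf generates. Generating it by hand with a command and then copying jar files around is hard to manage. Instead, use a Maven plugin: copy the proto files into the designated directory, and the Java files are generated automatically into the target directory at build time.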

<build>
  <extensions>
    <extension>
      <groupId>kr.motd.maven</groupId>
      <artifactId>os-maven-plugin</artifactId>
      <version>1.4.1.Final</version>
    </extension>
  </extensions>
  <plugins>
    <plugin>
      <groupId>org.xolstice.maven.plugins</groupId>
      <artifactId>protobuf-maven-plugin</artifactId>
      <version>0.5.0</version>
      <configuration>
        <!--
          The version of protoc must match protobuf-java. If you don't depend on
          protobuf-java directly, you will be transitively depending on the
          protobuf-java version that grpc depends on.
        -->
        <protocArtifact>com.google.protobuf:protoc:3.0.0:exe:${os.detected.classifier}</protocArtifact>
        <pluginId>grpc-java</pluginId>
        <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.0.0:exe:${os.detected.classifier}</pluginArtifact>
      </configuration>
      <executions>
        <execution>
          <goals>
            <goal>compile</goal>
            <goal>compile-custom</goal>
          </goals>
        </execution>
      </executions>
    </plugin>
  </plugins>
</build>

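Note that we need to include the proto files from both the TensorFlow Serving and TensorFlow projects. Because we are not building with bazel, the import paths inside the proto files need to be adjusted; see the complete project above for reference.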

Constructing a TensorProto Object

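The request interface is defined with protobuf, but we still need to build the TensorProto object from the protobuf-generated code. A TensorProto is essentially a multi-dimensional array; C++ and Python both provide functions that create one directly.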

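In Java, we can define a multi-dimensional array and then build the TensorProto by following this Stack Overflow answer: http://stackoverflow.com/questions/39443019/how-can-i-create-tensorproto-for-tensorflow-in-java . The code below builds a two-dimensional TensorProto.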

// Generate features TensorProto
float[][] featuresTensorData = new float[][]{
    {10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
    {10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
};
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT);

for (int i = 0; i < featuresTensorData.length; ++i) {
    for (int j = 0; j < featuresTensorData[i].length; ++j) {
        featuresTensorBuilder.addFloatVal(featuresTensorData[i][j]);
    }
}

TensorShapeProto.Dim dim1 = TensorShapeProto.Dim.newBuilder().setSize(2).build();
TensorShapeProto.Dim dim2 = TensorShapeProto.Dim.newBuilder().setSize(9).build();
TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim1).addDim(dim2).build();
featuresTensorBuilder.setTensorShape(shape);
TensorProto featuresTensorProto = featuresTensorBuilder.build();

Note that in addition to the data, we must set the shape and dtype manually; otherwise the server cannot parse the TensorProto into a tensor.
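The relationship between the flattened values and the declared shape can be sketched without any TensorFlow classes. The helper below is hypothetical (the class `TensorLayout` and its methods are illustrative names, not part of TensorFlow or the example project). It flattens a rectangular 2-D array in row-major order, which is the order in which the nested loops above call `addFloatVal`, and derives the two dimension sizes that would be passed to `TensorShapeProto`.

```java
import java.util.Arrays;

/** Hypothetical helper for illustration; not part of TensorFlow or the example project. */
public class TensorLayout {

    /** Returns {rows, columns}; rejects ragged arrays because a tensor must be rectangular. */
    public static long[] shapeOf(float[][] data) {
        int columns = data.length == 0 ? 0 : data[0].length;
        for (float[] row : data) {
            if (row.length != columns) {
                throw new IllegalArgumentException("ragged array: all rows must have the same length");
            }
        }
        return new long[]{data.length, columns};
    }

    /** Flattens in row-major order: all of row 0, then all of row 1, and so on. */
    public static float[] flatten(float[][] data) {
        long[] shape = shapeOf(data);
        float[] flat = new float[(int) (shape[0] * shape[1])];
        int next = 0;
        for (float[] row : data) {
            for (float value : row) {
                flat[next++] = value;
            }
        }
        return flat;
    }

    public static void main(String[] args) {
        float[][] features = {
            {10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
            {10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
        };
        long[] shape = shapeOf(features);
        float[] flat = flatten(features);
        System.out.println("shape = " + Arrays.toString(shape));
        System.out.println("value count = " + flat.length);
    }
}
```

The number of flattened values should equal the product of the dimension sizes (2 x 9 = 18 for the data above). If the two disagree, the TensorProto does not describe a valid tensor.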

Generating a TensorProto from Image Files

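In scenarios such as image classification, we need to read image files and turn them into a TensorProto before we can send a gRPC request to TensorFlow Serving. Here is a Java example; it has been tested with jpg and png images.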

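Here is a complete example of training a CNN model and running inference with it. The Java client reads local files directly and sends them to the service for prediction and classification: https://github.com/tobegit3hub/deep_cnn/tree/master/java_predict_client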

// Read the image files into a 4-D array: [batch][height][width][channel]
int[][][][] featuresTensorData = new int[2][32][32][3];

String[] imageFilenames = new String[]{"../data/inference/Mew.png", "../data/inference/Pikachu.png"};

for (int i = 0; i < imageFilenames.length; i++) {

    // Convert image file to multi-dimension array
    File imageFile = new File(imageFilenames[i]);
    try {
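        BufferedImage image = ImageIO.read(imageFile);
        if (image == null) {
            // ImageIO.read returns null when no registered reader can decode the file
            throw new IOException("Unsupported image format: " + imageFile.getPath());
        }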
        logger.info("Start to convert the image: " + imageFile.getPath());

        int imageWidth = 32;
        int imageHeight = 32;

        for (int row = 0; row < imageHeight; row++) {
            for (int column = 0; column < imageWidth; column++) {
                // getRGB returns a packed ARGB int; the image is assumed to be at least 32x32
                int pixel = image.getRGB(column, row);
                int red = (pixel >> 16) & 0xff;
                int green = (pixel >> 8) & 0xff;
                int blue = pixel & 0xff;

                featuresTensorData[i][row][column][0] = red;
                featuresTensorData[i][row][column][1] = green;
                featuresTensorData[i][row][column][2] = blue;
            }
        }
    } catch (IOException e) {
        logger.log(Level.WARNING, e.getMessage());
        System.exit(1);
    }
}

// Generate features TensorProto
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();

for (int i = 0; i < featuresTensorData.length; ++i) {
    for (int j = 0; j < featuresTensorData[i].length; ++j) {
        for (int k = 0; k < featuresTensorData[i][j].length; ++k) {
            for (int l = 0; l < featuresTensorData[i][j][k].length; ++l) {
                featuresTensorBuilder.addFloatVal(featuresTensorData[i][j][k][l]);
            }
        }
    }
}

TensorShapeProto.Dim featuresDim1 = TensorShapeProto.Dim.newBuilder().setSize(2).build();
TensorShapeProto.Dim featuresDim2 = TensorShapeProto.Dim.newBuilder().setSize(32).build();
TensorShapeProto.Dim featuresDim3 = TensorShapeProto.Dim.newBuilder().setSize(32).build();
TensorShapeProto.Dim featuresDim4 = TensorShapeProto.Dim.newBuilder().setSize(3).build();

TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).addDim(featuresDim2).addDim(featuresDim3).addDim(featuresDim4).build();
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT).setTensorShape(featuresShape);
TensorProto featuresTensorProto = featuresTensorBuilder.build();
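The listing above reads pixels 0..31 in each direction, so it assumes every input image is at least 32x32. As a minimal sketch of one way to handle arbitrary sizes, the hypothetical helper below (the class `ImagePixels` and its methods are illustrative names, not part of the example project) scales an image to a fixed size with the standard `java.awt` classes and then unpacks each pixel into red, green, and blue channels exactly as the loop above does. Whether resizing is appropriate depends on how the model was trained; that is an assumption here, not something the original project states.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

/** Hypothetical helper for illustration; not part of the example project. */
public class ImagePixels {

    /** Scales any image to width x height so that fixed-size pixel loops stay in bounds. */
    public static BufferedImage resize(BufferedImage source, int width, int height) {
        BufferedImage scaled = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D graphics = scaled.createGraphics();
        try {
            graphics.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                    RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            graphics.drawImage(source, 0, 0, width, height, null);
        } finally {
            graphics.dispose();
        }
        return scaled;
    }

    /** Unpacks every pixel into [row][column][channel], channels ordered red, green, blue. */
    public static int[][][] toRgbArray(BufferedImage image) {
        int[][][] pixels = new int[image.getHeight()][image.getWidth()][3];
        for (int row = 0; row < image.getHeight(); row++) {
            for (int column = 0; column < image.getWidth(); column++) {
                int packed = image.getRGB(column, row); // packed as 0xAARRGGBB
                pixels[row][column][0] = (packed >> 16) & 0xff;
                pixels[row][column][1] = (packed >> 8) & 0xff;
                pixels[row][column][2] = packed & 0xff;
            }
        }
        return pixels;
    }

    public static void main(String[] args) {
        // A synthetic 64x48 image stands in for a file loaded with ImageIO.read.
        BufferedImage source = new BufferedImage(64, 48, BufferedImage.TYPE_INT_RGB);
        BufferedImage scaled = resize(source, 32, 32);
        int[][][] pixels = toRgbArray(scaled);
        System.out.println(pixels.length + " x " + pixels[0].length + " x " + pixels[0][0].length);
    }
}
```

In the listing above, the result of `toRgbArray(resize(image, 32, 32))` could be copied into `featuresTensorData[i]` in place of the inner pixel loops.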