用 Java 编写 PaddleServing 的 YOLOv3 客户端代码

曾景龙
2023-12-01

大家好！我断断续续学习人工智能已有 3 个月，一直使用 Paddle（PaddlePaddle）深度学习框架，把数据集准备、模型训练、模型导出到 Serving 部署的整体流程都学了一遍。今天为大家补充 PaddleServing 的 Java 示例（java example）中缺少的 YOLOv3 客户端代码，已经测试过，运行没有问题。

/**
 * Java client for a YOLOv3 detection model deployed with PaddleServing.
 *
 * <p>Preprocesses one image file into the {@code image}, {@code im_shape} and
 * {@code scale_factor} feeds, sends them to the serving endpoint and returns the raw
 * prediction string.
 */
public class Yolov3Client
{
    /** Height the image is resized to before being fed to the network. */
    private static final int TARGET_HEIGHT = 608;
    /** Width the image is resized to before being fed to the network. */
    private static final int TARGET_WIDTH = 608;
    /** Number of colour channels fed to the network. */
    private static final int CHANNELS = 3;

    /** Per-channel mean in RGB order, applied after scaling pixels to [0, 1]. */
    private static final float[] MEAN = {0.485F, 0.456F, 0.406F};
    /** Per-channel standard deviation in RGB order, applied after the mean. */
    private static final float[] STD = {0.229F, 0.224F, 0.225F};

    /** Serving endpoint used by the two-argument overload (kept from the original code). */
    private static final String DEFAULT_SERVER_IP = "192.168.10.166";
    private static final String DEFAULT_SERVER_PORT = "9494";

    /**
     * Runs YOLOv3 detection on one image against the default serving endpoint.
     *
     * @param model_config_path path passed to {@code Client.loadClientConfig}
     * @param filename path of the image file to detect on
     * @return the prediction string returned by {@code Client.predict}
     * @throws Exception if the image cannot be loaded (the {@code IOException} is kept as
     *     the cause), or if the serving client fails
     */
    public String yolov3(String model_config_path, String filename) throws Exception
    {
        return yolov3(model_config_path, filename, DEFAULT_SERVER_IP, DEFAULT_SERVER_PORT);
    }

    /**
     * Runs YOLOv3 detection on one image against the given serving endpoint.
     *
     * @param model_config_path path passed to {@code Client.loadClientConfig}
     * @param filename path of the image file to detect on
     * @param ip serving host passed to {@code Client.setIP}
     * @param port serving port passed to {@code Client.setPort}
     * @return the prediction string returned by {@code Client.predict}
     * @throws Exception if the image cannot be loaded (the {@code IOException} is kept as
     *     the cause), or if the serving client fails
     */
    public String yolov3(String model_config_path, String filename, String ip, String port)
            throws Exception
    {
        System.out.println("===================yolov3===================");
        File imageFile = new File(filename);

        INDArray originalImage;
        INDArray resizedImage;
        try {
            // Decode once at native resolution, only to learn the original size for
            // scale_factor. Using the same decoder as the resized load keeps the size
            // consistent with the pixels actually fed to the model.
            originalImage = new NativeImageLoader().asMatrix(imageFile);
            resizedImage = new NativeImageLoader(TARGET_HEIGHT, TARGET_WIDTH, CHANNELS)
                    .asMatrix(imageFile);
        } catch (java.io.IOException e) {
            System.out.println("load image fail.");
            throw new Exception("load image fail.", e);
        }

        // asMatrix yields NCHW with a batch of one: [1, C, H, W].
        long originalHeight = originalImage.size(2);
        long originalWidth = originalImage.size(3);

        // Drop the batch dimension; the data is still channel-major (CHW), BGR order.
        INDArray bgrChw = resizedImage.reshape(CHANNELS, TARGET_HEIGHT, TARGET_WIDTH);

        // Normalize each channel as (pixel / 255 - mean) / std. The work is done plane by
        // plane on the CHW array: reshaping to HWC would only reinterpret the buffer, not
        // transpose it. The BGR -> RGB swap happens at the same time, because MEAN/STD are
        // given in RGB order.
        INDArray image = Nd4j.zeros(CHANNELS, TARGET_HEIGHT, TARGET_WIDTH);
        for (int c = 0; c < CHANNELS; c++) {
            INDArray plane = bgrChw.slice(CHANNELS - 1 - c).div(255.0);
            image.slice(c).assign(plane.subi(MEAN[c]).divi(STD[c]));
        }

        HashMap<String, Object> feed_data = new HashMap<>();
        feed_data.put("image", image);
        feed_data.put("im_shape", Nd4j.create(new float[]{TARGET_HEIGHT, TARGET_WIDTH}));
        // [resized_h / original_h, resized_w / original_w]. This maps predicted boxes back
        // to the original image; it was previously hard-coded for a 474x800 image.
        feed_data.put("scale_factor", Nd4j.create(new float[]{
                (float) TARGET_HEIGHT / originalHeight,
                (float) TARGET_WIDTH / originalWidth}));
        List<String> fetch = Arrays.asList("multiclass_nms3_0.tmp_0");

        Client client = new Client();
        client.setIP(ip);
        client.setPort(port);
        client.set_use_grpc_client(false);
        client.set_http_proto(false);
        client.loadClientConfig(model_config_path);
        String result = client.predict(feed_data, fetch, false, 0);
        System.out.println(result);
        return result;
    }
}
 类似资料: