Java使用OpenCV类库实现简单的KNN Machine Learning.

姚星河
2023-12-01

OpenCV目前在图像处理和识别领域里, 是最强大的开源类库. (maybe我是井底之蛙, 大家有好东西别藏着掖着), 传送门在这: http://opencv.org


OpenCV在图像处理和识别领域覆盖的面也非常广泛, 图像的各种操作(缩放, 旋转, 变灰, ColorSpace转换, 门槛值等等等等; 最大的亮点是它实现了非常多的图像抓取和识别算法, 如GrabCut, SIFT, SURF (额, 这几个算法我是看了一周啊,,, 没看懂原理. 现在看到高数一样的公式就头大. 当然有能力有信心的同学可以观摩一下SIFT的原著paper, 咳, 谷歌的到)  另外也实现了Machine Learning的算法, KNN, SVM. (KNN是看懂了, 我承认, SVM具体原理完全没懂, 啥事Logistics Regression?? 大学的东西完全没印象)  没看懂归没看懂, 对于我们这样的伸手党来说, 原理神马的不重要, 重要的是会用!!! 


好.. 我们先从最简单的KNN说起吧, KNN=k nearest neighbours, 原理比较简单, 这里就不阐述了, 想看原理的同学请直接进传送门: http://docs.opencv.org/3.1.0/d5/d26/tutorial_py_knn_understanding.html#gsc.tab=0  


OpenCV虽然很强大, 但...对于我这种早就把C++还给老师 也没有跟随潮流Python的同学来说, 要找个Java用OpenCV的例子, 就TMD难啊~~~ 所以我也花了挺久研究Java 类库和C++/Python类库的对应, 这里写了一个简单的KNN例子, 以供大家参考:


程序第一步是load native dll. (OpenCV是C++写的)

然后是准备训练数据和训练数据标签 (标签是告诉KNN算法, 某个数据是啥), 例如我这个例子是小于100的两个数是0, 大于100的两个数是1. 数都是随机生成滴. 各训练50个样本.

(50, 10) 是0,  (1, 99)是0, (33, 34)是0,   (155, 133)是1, (111, 199)是1, 等等等等.

KNN有了这100个样本后, 就像在一个2D平面上, 画了100个点, 50个标签是0, 50个标签是1.


接着再生成100个随机测试数据, 他们的值是 (n, n)  其中0<n<200, 那么我们认为很简单的判断当n<100, 分类结果是0, 反之是1.

使用KNN的find_nearest()方法对100个测试数据进行分类, 得到3个结果集, 分别是:

results: 100个0或1, 分表代表100个测试数据的分类结果.

neighborResponses: 100个5D数组, 代表5个最近的样本label.

dists: 100个5D数组, 代表5个最近样本的距离.



具体实现如下:

import java.util.Date;
import java.util.Random;

import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Scalar;
import org.opencv.ml.CvKNearest;

public class S4_KNN {
	public static final int K = 5;

	public static void main(String[] args) {
		//Must: to load native opencv library (you must add the DLL to the library path first)
		System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
		
		Random random = new Random(new Date().getTime());

		//prepare trainData and trainLabel
		Mat trainData = new Mat(100, 2, CvType.CV_32FC1, new Scalar(1));
		Mat trainLabel = new Mat(100, 1, CvType.CV_32FC1, new Scalar(1));
		for (int i = 0; i < 50; i++) {
			trainData.put(i, 0, random.nextInt(100));
			trainData.put(i, 1, random.nextInt(100));
			trainLabel.put(i, 0, 0);
		}
		for (int i = 50; i < 100; i++) {
			trainData.put(i, 0, random.nextInt(100) + 100);
			trainData.put(i, 1, random.nextInt(100) + 100);
			trainLabel.put(i, 0, 1);
		}

		// System.out.println("trainData:\n" + trainData.dump());
		// System.out.println("trainLabel:\n" + trainLabel.dump());

		//train data using KNN
		CvKNearest knn = new CvKNearest();
		boolean success = knn.train(trainData, trainLabel);
		System.out.println("training result: " + success);

		//prepare test data
		Mat testData = new Mat(100, 2, CvType.CV_32FC1, new Scalar(1));
		for (int i = 0; i < 100; i++) {
			int r = random.nextInt(200);
			testData.put(i, 0, r);
			testData.put(i, 1, r);
		}

		//find the nearest neighbours of test data
		Mat results = new Mat();
		Mat neighborResponses = new Mat();
		Mat dists = new Mat();
		knn.find_nearest(testData, K, results, neighborResponses, dists);

		// print out the results
		System.out.println("testData:\n" + testData.dump());
		System.out.println("results:\n" + results.dump());
		System.out.println("neighborResponses:\n" + neighborResponses.dump());
		System.out.println("dists:\n" + dists.dump());

	}
}

 类似资料: