当前位置: 首页 > 编程笔记 >

K均值聚类算法的Java版实现代码示例

穆鸿波
2023-03-14
本文向大家介绍K均值聚类算法的Java版实现代码示例,包括了K均值聚类算法的Java版实现代码示例的使用技巧和注意事项,需要的朋友参考一下

1.简介

K均值聚类算法是先随机选取K个对象作为初始的聚类中心。然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。一旦全部对象都被分配了,每个聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

2.什么是聚类

聚类是一个将数据集中在某些方面相似的数据成员进行分类组织的过程,聚类就是一种发现这种内在结构的技术,聚类技术经常被称为无监督学习。

3.什么是k均值聚类

k均值聚类是最著名的划分聚类算法,由于简洁和效率使得他成为所有聚类算法中最广泛使用的。给定一个数据点集合和需要的聚类数目k,k由用户指定,k均值算法根据某个距离函数反复把数据分入k个聚类中。

4.实现

Java代码如下:

package org.algorithm;
import java.util.ArrayList;
import java.util.Random;
/** 
 * K均值聚类算法 
 */
public class Kmeans {
	private int k;
	// 分成多少簇 
	private int m;
	// 迭代次数 
	private int dataSetLength;
	// 数据集元素个数,即数据集的长度 
	private ArrayList<float[]> dataSet;
	// 数据集链表 
	private ArrayList<float[]> center;
	// 中心链表 
	private ArrayList<ArrayList<float[]>> cluster;
	// 簇 
	private ArrayList<float> jc;
	// 误差平方和,k越接近dataSetLength,误差越小 
	private Random random;
	/** 
   * 设置需分组的原始数据集 
   * 
   * @param dataSet 
   */
	public void setDataSet(ArrayList<float[]> dataSet) {
		this.dataSet = dataSet;
	}
	/** 
   * 获取结果分组 
   * 
   * @return 结果集 
   */
	public ArrayList<ArrayList<float[]>> getCluster() {
		return cluster;
	}
	/** 
   * 构造函数,传入需要分成的簇数量 
   * 
   * @param k 
   *      簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度 
   */
	public Kmeans(int k) {
		if (k <= 0) {
			k = 1;
		}
		this.k = k;
	}
	/** 
   * 初始化 
   */
	private void init() {
		m = 0;
		random = new Random();
		if (dataSet == null || dataSet.size() == 0) {
			initDataSet();
		}
		dataSetLength = dataSet.size();
		if (k > dataSetLength) {
			k = dataSetLength;
		}
		center = initCenters();
		cluster = initCluster();
		jc = new ArrayList<float>();
	}
	/** 
   * 如果调用者未初始化数据集,则采用内部测试数据集 
   */
	private void initDataSet() {
		dataSet = new ArrayList<float[]>();
		// 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0 
		float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 }, 
		        { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 }, 
		        { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };
		for (int i = 0; i < dataSetArray.length; i++) {
			dataSet.add(dataSetArray[i]);
		}
	}
	/** 
   * 初始化中心数据链表,分成多少簇就有多少个中心点 
   * 
   * @return 中心点集 
   */
	private ArrayList<float[]> initCenters() {
		ArrayList<float[]> center = new ArrayList<float[]>();
		int[] randoms = new int[k];
		Boolean flag;
		int temp = random.nextint(dataSetLength);
		randoms[0] = temp;
		for (int i = 1; i < k; i++) {
			flag = true;
			while (flag) {
				temp = random.nextint(dataSetLength);
				int j = 0;
				// 不清楚for循环导致j无法加1 
				// for(j=0;j<i;++j) 
				// { 
				// if(temp==randoms[j]); 
				// { 
				// break; 
				// } 
				// } 
				while (j < i) {
					if (temp == randoms[j]) {
						break;
					}
					j++;
				}
				if (j == i) {
					flag = false;
				}
			}
			randoms[i] = temp;
		}
		// 测试随机数生成情况 
		// for(int i=0;i<k;i++) 
		// { 
		// System.out.println("test1:randoms["+i+"]="+randoms[i]); 
		// } 
		// System.out.println(); 
		for (int i = 0; i < k; i++) {
			center.add(dataSet.get(randoms[i]));
			// 生成初始化中心链表
		}
		return center;
	}
	/** 
   * 初始化簇集合 
   * 
   * @return 一个分为k簇的空数据的簇集合 
   */
	private ArrayList<ArrayList<float[]>> initCluster() {
		ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
		for (int i = 0; i < k; i++) {
			cluster.add(new ArrayList<float[]>());
		}
		return cluster;
	}
	/** 
   * 计算两个点之间的距离 
   * 
   * @param element 
   *      点1 
   * @param center 
   *      点2 
   * @return 距离 
   */
	private float distance(float[] element, float[] center) {
		float distance = 0.0f;
		float x = element[0] - center[0];
		float y = element[1] - center[1];
		float z = x * x + y * y;
		distance = (float) Math.sqrt(z);
		return distance;
	}
	/** 
   * 获取距离集合中最小距离的位置 
   * 
   * @param distance 
   *      距离数组 
   * @return 最小距离在距离数组中的位置 
   */
	private int minDistance(float[] distance) {
		float minDistance = distance[0];
		int minLocation = 0;
		for (int i = 1; i < distance.length; i++) {
			if (distance[i] < minDistance) {
				minDistance = distance[i];
				minLocation = i;
			} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置 
			{
				if (random.nextint(10) < 5) {
					minLocation = i;
				}
			}
		}
		return minLocation;
	}
	/** 
   * 核心,将当前元素放到最小距离中心相关的簇中 
   */
	private void clusterSet() {
		float[] distance = new float[k];
		for (int i = 0; i < dataSetLength; i++) {
			for (int j = 0; j < k; j++) {
				distance[j] = distance(dataSet.get(i), center.get(j));
				// System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);
			}
			int minLocation = minDistance(distance);
			// System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation); 
			// System.out.println(); 
			cluster.get(minLocation).add(dataSet.get(i));
			// 核心,将当前元素放到最小距离中心相关的簇中
		}
	}
	/** 
   * 求两点误差平方的方法 
   * 
   * @param element 
   *      点1 
   * @param center 
   *      点2 
   * @return 误差平方 
   */
	private float errorSquare(float[] element, float[] center) {
		float x = element[0] - center[0];
		float y = element[1] - center[1];
		float errSquare = x * x + y * y;
		return errSquare;
	}
	/** 
   * 计算误差平方和准则函数方法 
   */
	private void countRule() {
		float jcF = 0;
		for (int i = 0; i < cluster.size(); i++) {
			for (int j = 0; j < cluster.get(i).size(); j++) {
				jcF += errorSquare(cluster.get(i).get(j), center.get(i));
			}
		}
		jc.add(jcF);
	}
	/** 
   * 设置新的簇中心方法 
   */
	private void setNewCenter() {
		for (int i = 0; i < k; i++) {
			int n = cluster.get(i).size();
			if (n != 0) {
				float[] newCenter = { 0, 0 };
				for (int j = 0; j < n; j++) {
					newCenter[0] += cluster.get(i).get(j)[0];
					newCenter[1] += cluster.get(i).get(j)[1];
				}
				// 设置一个平均值 
				newCenter[0] = newCenter[0] / n;
				newCenter[1] = newCenter[1] / n;
				center.set(i, newCenter);
			}
		}
	}
	/** 
   * 打印数据,测试用 
   * 
   * @param dataArray 
   *      数据集 
   * @param dataArrayName 
   *      数据集名称 
   */
	public void printDataArray(ArrayList<float[]> dataArray, 
	      String dataArrayName) {
		for (int i = 0; i < dataArray.size(); i++) {
			System.out.println("print:" + dataArrayName + "[" + i + "]={" 
			          + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
		}
		System.out.println("===================================");
	}
	/** 
   * Kmeans算法核心过程方法 
   */
	private void kmeans() {
		init();
		// printDataArray(dataSet,"initDataSet"); 
		// printDataArray(center,"initCenter"); 
		// 循环分组,直到误差不变为止 
		while (true) {
			clusterSet();
			// for(int i=0;i<cluster.size();i++) 
			// { 
			// printDataArray(cluster.get(i),"cluster["+i+"]"); 
			// } 
			countRule();
			// System.out.println("count:"+"jc["+m+"]="+jc.get(m)); 
			// System.out.println(); 
			// 误差不变了,分组完成 
			if (m != 0) {
				if (jc.get(m) - jc.get(m - 1) == 0) {
					break;
				}
			}
			setNewCenter();
			// printDataArray(center,"newCenter"); 
			m++;
			cluster.clear();
			cluster = initCluster();
		}
		// System.out.println("note:the times of repeat:m="+m);//输出迭代次数
	}
	/** 
   * 执行算法 
   */
	public void execute() {
		long startTime = System.currentTimeMillis();
		System.out.println("kmeans begins");
		kmeans();
		long endTime = System.currentTimeMillis();
		System.out.println("kmeans running time=" + (endTime - startTime) 
		        + "ms");
		System.out.println("kmeans ends");
		System.out.println();
	}
}

5.说明:

具体代码是从网上找的,根据自己的理解加了注释和进行部分修改,若注释有误还望指正

6.测试

package org.test;
import java.util.ArrayList;
import org.algorithm.Kmeans;
public class KmeansTest {
	public static void main(String[] args) 
	  {
		//初始化一个Kmean对象,将k置为10 
		Kmeans k=new Kmeans(10);
		ArrayList<float[]> dataSet=new ArrayList<float[]>();
		dataSet.add(new float[]{1,2});
		dataSet.add(new float[]{3,3});
		dataSet.add(new float[]{3,4});
		dataSet.add(new float[]{5,6});
		dataSet.add(new float[]{8,9});
		dataSet.add(new float[]{4,5});
		dataSet.add(new float[]{6,4});
		dataSet.add(new float[]{3,9});
		dataSet.add(new float[]{5,9});
		dataSet.add(new float[]{4,2});
		dataSet.add(new float[]{1,9});
		dataSet.add(new float[]{7,8});
		//设置原始数据集 
		k.setDataSet(dataSet);
		//执行算法 
		k.execute();
		//得到聚类结果 
		ArrayList<ArrayList<float[]>> cluster=k.getCluster();
		//查看结果 
		for (int i=0;i<cluster.size();i++) 
		    {
			k.printDataArray(cluster.get(i), "cluster["+i+"]");
		}
	}
}

总结:测试代码已经通过。并对聚类的结果进行了查看,结果基本上符合要求。至于有没有更精确的算法有待发现。具体的实践还有待挖掘

总结

以上就是本文关于K均值聚类算法的Java版实现代码示例的全部内容,希望对大家有所帮助。感兴趣的朋友可以继续参阅本站其他相关专题。如有不足之处,欢迎留言指出。感谢朋友们对本站的支持!

 类似资料:
  • $k$均值聚类算法(k-means clustering algorithm) 在聚类的问题中,我们得到了一组训练样本集 ${x^{(1)},...,x^{(m)}}$,然后想要把这些样本划分成若干个相关的“类群(clusters)”。其中的 $x^{(i)}\in R^n$,而并未给出分类标签 $y^{(i)}$ 。所以这就是一个无监督学习的问题了。 $K$ 均值聚类算法如下所示: 随机初始化(

  • 本文向大家介绍k-means 聚类算法与Python实现代码,包括了k-means 聚类算法与Python实现代码的使用技巧和注意事项,需要的朋友参考一下 k-means 聚类算法思想先随机选择k个聚类中心,把集合里的元素与最近的聚类中心聚为一类,得到一次聚类,再把每一个类的均值作为新的聚类中心重新聚类,迭代n次得到最终结果分步解析 一、初始化聚类中心 首先随机选择集合里的一个元素作为第一个聚类中

  • 聚类 聚类,简单来说,就是将一个庞杂数据集中具有相似特征的数据自动归类到一起,称为一个簇,簇内的对象越相似,聚类的效果越好。它是一种无监督的学习(Unsupervised Learning)方法,不需要预先标注好的训练集。聚类与分类最大的区别就是分类的目标事先已知,例如猫狗识别,你在分类之前已经预先知道要将它分为猫、狗两个种类;而在你聚类之前,你对你的目标是未知的,同样以动物为例,对于一个动物集来

  • 本文向大家介绍Python聚类算法之基本K均值实例详解,包括了Python聚类算法之基本K均值实例详解的使用技巧和注意事项,需要的朋友参考一下 本文实例讲述了Python聚类算法之基本K均值运算技巧。分享给大家供大家参考,具体如下: 基本K均值 :选择 K 个初始质心,其中 K 是用户指定的参数,即所期望的簇的个数。每次循环中,每个点被指派到最近的质心,指派到同一个质心的点集构成一个。然后,根据指

  • 本文向大家介绍python实现k-means聚类算法,包括了python实现k-means聚类算法的使用技巧和注意事项,需要的朋友参考一下 k-means聚类算法 k-means是发现给定数据集的k个簇的算法,也就是将数据集聚合为k类的算法。 算法过程如下: 1)从N个文档随机选取K个文档作为质心 2)对剩余的每个文档测量其到每个质心的距离,并把它归到最近的质心的类,我们一般取欧几里得距离 3)重

  • 目标 在本章中,我们将了解K-Means聚类的概念,其工作原理等。 理论 我们将用一个常用的例子来处理这个问题。 T-shirt尺寸问题 考虑一家公司,该公司将向市场发布新型号的T恤。显然,他们将不得不制造不同尺寸的模型,以满足各种规模的人们的需求。因此,该公司会记录人们的身高和体重数据,并将其绘制到图形上,如下所示: 公司无法制作所有尺寸的T恤。取而代之的是,他们将人划分为小,中和大,并仅制造这