当前位置: 首页 > 工具软件 > ari4java > 使用案例 >

Java 实现聚类评价指标的计算：Purity、NMI、RI、Precision、Recall 和 F 值。

慕容嘉荣
2023-12-01

第一个:计算NMI的:

package clusters;

import java.io.*;

import java.util.ArrayList;

import java.util.HashMap;

import java.util.List;

import java.util.Map;

/**
 * Computes the normalized mutual information (NMI) between a clustering and the gold-standard
 * classes of its members.
 *
 * <p>Input file format (as read by {@link #loadData(List, String)}): one cluster per non-blank
 * line; every whitespace-separated integer on a line is the class label of one member of that
 * cluster.
 *
 * <p>Reference: http://www-nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
 *
 * <p>DATE: 16-6-18 TIME: 10:00 AM
 */
public class NormalizedMutualInformation {

    /** Default data file read by {@link #loadData(List)} and {@link #main(String[])}. */
    public static String path = "/home/fhqplzj/IdeaProjects/Vein/src/main/resources/nmi_data";

    /**
     * Appends the clusters stored in the default file {@link #path} to {@code lists}.
     *
     * @param lists receives one list of class labels per cluster
     * @throws UncheckedIOException if the file cannot be read
     */
    public static void loadData(List<List<Integer>> lists) {
        loadData(lists, path);
    }

    /**
     * Appends the clusters stored in {@code file} to {@code lists}, one cluster per non-blank line.
     *
     * @param lists receives one list of class labels per cluster
     * @param file path of the UTF-8 text file to read
     * @throws UncheckedIOException if the file cannot be read
     * @throws NumberFormatException if a token is not an integer
     */
    public static void loadData(List<List<Integer>> lists, String file) {
        // try-with-resources: the reader is closed even when parsing fails midway.
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new FileInputStream(file), java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Trim first: split("\\s+") on " 1 2" or "" would yield an empty token.
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                List<Integer> labels = new ArrayList<>();
                for (String token : trimmed.split("\\s+")) {
                    labels.add(Integer.parseInt(token));
                }
                lists.add(labels);
            }
        } catch (IOException e) {
            // Fail fast: continuing with no data would only print NaN metrics.
            throw new UncheckedIOException("Failed to read cluster data from " + file, e);
        }
    }

    /**
     * Computes NMI(Omega, C) = I(Omega; C) / ((H(Omega) + H(C)) / 2), using base-2 logarithms.
     *
     * @param clusters one list per cluster, holding the class label of each member
     * @return the NMI in [0, 1]; 1.0 when both entropies are zero (e.g. empty input, or a single
     *     cluster containing a single class), following the usual convention for that case
     */
    public static double computeNmi(List<? extends List<Integer>> clusters) {
        int n = 0;
        Map<Integer, Integer> classSizes = new HashMap<>();
        for (List<Integer> cluster : clusters) {
            n += cluster.size();
            for (Integer label : cluster) {
                classSizes.merge(label, 1, Integer::sum);
            }
        }

        double clusterEntropy = 0;
        for (List<Integer> cluster : clusters) {
            clusterEntropy += entropyTerm(cluster.size(), n);
        }

        double classEntropy = 0;
        for (int size : classSizes.values()) {
            classEntropy += entropyTerm(size, n);
        }

        double mutualInformation = 0;
        Map<Integer, Integer> overlap = new HashMap<>();
        for (List<Integer> cluster : clusters) {
            overlap.clear();
            for (Integer label : cluster) {
                overlap.merge(label, 1, Integer::sum);
            }
            // Doubles, not ints: |w_k| * |c_j| overflows int for large clusters.
            double clusterSize = cluster.size();
            for (Map.Entry<Integer, Integer> entry : overlap.entrySet()) {
                double joint = entry.getValue();
                double classSize = classSizes.get(entry.getKey());
                mutualInformation += joint / n * log2(n * joint / (clusterSize * classSize));
            }
        }

        double denominator = clusterEntropy + classEntropy;
        if (denominator == 0) {
            return 1.0;
        }
        return 2 * mutualInformation / denominator;
    }

    /** Returns -p * log2(p) for p = count / total, treating 0 * log 0 as 0. */
    private static double entropyTerm(int count, int total) {
        if (count == 0) {
            return 0;
        }
        double p = (double) count / total;
        return -p * log2(p);
    }

    /** Base-2 logarithm. */
    private static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }

    /** Loads the default data file and prints its NMI with two decimals. */
    public static void main(String[] args) {
        List<List<Integer>> lists = new ArrayList<>();
        loadData(lists);
        System.out.println(String.format("nmi = %.2f", computeNmi(lists)));
    }

}

//

第二个,一些工具类:

package clusters;

import java.util.Arrays;

import java.util.List;

import java.util.Map;

/**
 * Pair-counting helpers for external clustering evaluation (Rand index, precision, recall and
 * F-measure).
 *
 * <p>Terminology: a "pair" is an unordered pair of distinct items. TP = same cluster and same
 * class, FP = same cluster but different class, FN = same class but different cluster.
 *
 * <p>DATE: 16-6-18 TIME: 11:07 AM
 */
public class ClusterUtils {

    /**
     * Returns the binomial coefficient C(n, k).
     *
     * <p>Uses the multiplicative formula, so it runs in O(min(k, n - k)) time instead of building a
     * full row of Pascal's triangle (which is O(n^2) for C(n, 2)).
     *
     * @param n size of the set
     * @param k size of the subsets to count
     * @return C(n, k), or 0 when {@code k} is negative or greater than {@code n}
     * @throws ArithmeticException if the result does not fit in an {@code int}
     */
    public static int combination(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        int r = Math.min(k, n - k);
        long result = 1;
        for (int i = 1; i <= r; i++) {
            // Invariant: result == C(n - r + i - 1, i - 1), so the product is divisible by i exactly.
            result = Math.multiplyExact(result, (long) (n - r + i)) / i;
        }
        return Math.toIntExact(result);
    }

    /**
     * Returns TP + FP, the number of pairs placed in the same cluster.
     *
     * @param clusters size of each cluster
     * @throws ArithmeticException on int overflow
     */
    public static int computeTPAndFP(int[] clusters) {
        int result = 0;
        for (int cluster : clusters) {
            result = Math.addExact(result, combination(cluster, 2));
        }
        return result;
    }

    /**
     * Returns TP, the number of pairs that share both their cluster and their class.
     *
     * @param mapList one map per cluster, from class label to the number of members with that label
     * @throws ArithmeticException on int overflow
     */
    public static int computeTP(List<Map<Integer, Integer>> mapList) {
        int tp = 0;
        for (Map<Integer, Integer> histogram : mapList) {
            for (int count : histogram.values()) {
                tp = Math.addExact(tp, combination(count, 2));
            }
        }
        return tp;
    }

    /**
     * Returns FP, the number of pairs in the same cluster whose members have different classes.
     *
     * <p>Per cluster this is C(clusterSize, 2) minus the same-class pairs in that cluster.
     *
     * @param mapList one map per cluster, from class label to the number of members with that label
     * @throws ArithmeticException on int overflow
     */
    public static int computeFP(List<Map<Integer, Integer>> mapList) {
        int fp = 0;
        for (Map<Integer, Integer> histogram : mapList) {
            int clusterSize = 0;
            int samePairs = 0;
            for (int count : histogram.values()) {
                clusterSize += count;
                samePairs += combination(count, 2);
            }
            fp = Math.addExact(fp, combination(clusterSize, 2) - samePairs);
        }
        return fp;
    }

    /**
     * Returns the number of pairs of one class that are split across different clusters.
     *
     * @param list how many members of the class fall into each cluster that contains it
     * @return the sum of counts[i] * counts[j] over all i < j
     */
    public static int computeOneClass(List<Integer> list) {
        int n = list.size();
        if (n == 0) {
            return 0;
        }
        int result = 0;
        for (int i = 0; i < n - 1; i++) {
            for (int j = i + 1; j < n; j++) {
                result += list.get(i) * list.get(j);
            }
        }
        return result;
    }

    /**
     * Returns FN, the number of same-class pairs placed in different clusters.
     *
     * @param lists one entry per class, listing that class's member count in each cluster
     */
    public static int computeFN(List<List<Integer>> lists) {
        int result = 0;
        for (List<Integer> list : lists) {
            result += computeOneClass(list);
        }
        return result;
    }

    /**
     * Returns the F-measure F_beta = (beta^2 + 1) * P * R / (beta^2 * P + R).
     *
     * <p>A larger {@code beta} weights recall more heavily than precision.
     */
    public static double computeFValue(double P, double R, double beta) {
        return (beta * beta + 1) * P * R / (beta * beta * P + R);
    }

    /** Small manual check of {@link #computeOneClass(List)}. */
    public static void main(String[] args) {
        List<Integer> list = Arrays.asList(1, 4, 0);
        System.out.println("computeOneClass(list) = " + computeOneClass(list));
    }

}

第三个：计算 RI、P、R、F 以及 Purity 的类，同时调用 NMI 的计算，一并打印输出。如 Stanford 文章所述，beta 分别取 1 和 5，即计算 F1 和 F5。

package clusters;

import java.util.*;

/**
 * Prints quality measures for the clustering stored at {@link NormalizedMutualInformation#path}:
 * purity, NMI, Rand index (RI), precision (P), recall (R), and the F-measure with beta = 1 and
 * beta = 5.
 *
 * <p>Reference: http://www-nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
 *
 * <p>DATE: 16-6-18 TIME: 11:05 AM
 */
public class RandIndex {

    public static void main(String[] args) {
        List<List<Integer>> lists = new ArrayList<>();
        NormalizedMutualInformation.loadData(lists);

        // One pass: class histogram per cluster, overall class sizes, and total item count.
        List<Map<Integer, Integer>> mapList = new ArrayList<>();
        Map<Integer, Integer> classSizes = new HashMap<>();
        long n = 0;
        for (List<Integer> cluster : lists) {
            Map<Integer, Integer> histogram = new HashMap<>();
            for (Integer label : cluster) {
                histogram.merge(label, 1, Integer::sum);
                classSizes.merge(label, 1, Integer::sum);
            }
            mapList.add(histogram);
            n += cluster.size();
        }

        // Pair counts, in long so that large inputs cannot overflow:
        //   pairs in the same cluster               = TP + FP
        //   pairs in the same cluster and class     = TP
        //   pairs in the same class                 = TP + FN
        long sameCluster = 0;
        long tp = 0;
        long totalMax = 0;
        for (int k = 0; k < lists.size(); k++) {
            sameCluster += pairs(lists.get(k).size());
            int clusterMax = 0;
            for (int count : mapList.get(k).values()) {
                tp += pairs(count);
                clusterMax = Math.max(clusterMax, count);
            }
            // Purity credits each cluster with its most frequent class (0 for an empty cluster).
            totalMax += clusterMax;
        }
        long sameClass = 0;
        for (int size : classSizes.values()) {
            sameClass += pairs(size);
        }
        long fp = sameCluster - tp;
        long fn = sameClass - tp;
        long tn = pairs(n) - tp - fp - fn;

        double ri = 1.0 * (tp + tn) / (tp + fp + fn + tn);

        double purity = 1.0 * totalMax / n;
        System.out.println(String.format("purity = %.2f", purity));

        // NMI is computed and printed by NormalizedMutualInformation, which re-reads the same file.
        NormalizedMutualInformation.main(null);

        System.out.println(String.format("RI = %.2f", ri));

        // Precision, recall, then F1 and F5 (beta = 5 weights recall more heavily).
        double p = 1.0 * tp / (tp + fp);
        double r = 1.0 * tp / (tp + fn);
        System.out.println(String.format("P = %.2f", p));
        System.out.printf("R = %.3f\n", r);
        System.out.println(String.format("beta = 1, F = %.2f", ClusterUtils.computeFValue(p, r, 1)));
        System.out.println(String.format("beta = 5, F = %.3f", ClusterUtils.computeFValue(p, r, 5)));
    }

    /** Returns C(m, 2), the number of unordered pairs among {@code m} items. */
    private static long pairs(long m) {
        return m * (m - 1) / 2;
    }

}

输入数据就是 Stanford 文中的 3 个类簇（数据文件中每行表示一个类簇，行内每个整数是该簇中一个样本的真实类别标签）：

1 1 1 1 1 2

1 2 2 2 2 3

1 1 3 3 3

本文来自http://blog.csdn.net/asd991936157/article/details/51705958,只为学习

 类似资料: