当前位置: 首页 > 文档资料 > 数据挖掘算法 >

Apriori 算法 关联规则挖掘

优质
小牛编辑
122浏览
2023-12-01

我的数据挖掘算法代码:https://github.com/linyiqun/DataMiningAlgorithm

介绍

Apriori算法是一个经典的数据挖掘算法,Apriori的单词的意思是"先验的",说明这个算法是具有先验性质的,就是说要通过上一次的结果推导出下一次的结果,这个如何体现将会在下面的分析中会慢慢的体现出来。Apriori算法的用处是挖掘频繁项集的,频繁项集粗俗的理解就是找出经常出现的组合,然后根据这些组合最终推出我们的关联规则。

Apriori算法原理

Apriori算法是一种逐层搜索的迭代式算法,其中k项集用于挖掘(k+1)项集,这是依靠他的先验性质的:

频繁项集的所有非空子集一定是也是频繁的。

通过这个性质可以对候选集进行剪枝。用k项集如何生成(k+1)项集呢,这个是算法里面最难也是最核心的部分。

通过2个步骤

1、连接步,将频繁项自己与自己进行连接运算。

2、剪枝步,去除候选集项中的不符合要求的候选项,不符合要求指的是这个候选项的子集并非都是频繁项,要遵守上文提到的先验性质。

3、通过1,2步骤还不够,在后面还要根据支持度计数筛选掉不满足最小支持度数的候选集。

算法实例

首先是测试数据:

交易ID

商品ID列表

T100

I1,I2,I5

T200

I2,I4

T300

I2,I3

T400

I1,I2,I4

T500

I1,I3

T600

I2,I3

T700

I1,I3

T800

I1,I2,I3,I5

T900

I1,I2,I3

算法的步骤图:

最后我们可以看到频繁3项集的结果为{1, 2, 3}和{1, 2, 5},然后我们去后者{1, 2, 5}作为频繁项集来生产他的关联规则,但是在这之前得先知道一些概念,怎么样才能够成为一条关联规则,关有频繁项集还是不够的。

关联规则

confidence(置信度)

confidence的中文意思为自信的,在这里其实表示的是一种条件概率,当在A条件下,B发生的概率就可以表示为confidence(A->B)=p(B|A),意为在A的情况下,推出B的概率。那么关联规则与有什么关系呢,请继续往下看。

最小置信度阈值

按照字面上的意思就是限制置信度值的一个限制条件嘛,这个很好理解。

强规则

强规则就是指的是置信度满足最小置信度(就是>=最小置信度)的推断就是一个强规则,也就是文中所说的关联规则了。这个在下面的程序中会有所体现。

算法的代码实现

我自己写的算法实现可能会让你有点晦涩难懂,不过重在理解算法的整个思路即可,尤其是连接步和剪枝步是最难点所在,可能还存在bug。

输入数据:

T1 1 2 5
T2 2 4
T3 2 3
T4 1 2 4
T5 1 3
T6 2 3
T7 1 3
T8 1 2 3 5
T9 1 2 3
频繁项类:
/**
 * 频繁项集
 * 
 * @author lyq
 * 
 */
public class FrequentItem implements Comparable<FrequentItem>{
  // 频繁项集的集合ID
  private String[] idArray;
  // 频繁项集的支持度计数
  private int count;
  //频繁项集的长度,1项集或是2项集,亦或是3项集
  private int length;

  public FrequentItem(String[] idArray, int count){
    this.idArray = idArray;
    this.count = count;
    length = idArray.length;
  }

  public String[] getIdArray() {
    return idArray;
  }

  public void setIdArray(String[] idArray) {
    this.idArray = idArray;
  }

  public int getCount() {
    return count;
  }

  public void setCount(int count) {
    this.count = count;
  }

  public int getLength() {
    return length;
  }

  public void setLength(int length) {
    this.length = length;
  }

  @Override
  public int compareTo(FrequentItem o) {
    // TODO Auto-generated method stub
    return this.getIdArray()[0].compareTo(o.getIdArray()[0]);
  }
  
}
主程序类:
package DataMining_Apriori;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * apriori算法工具类
 * 
 * @author lyq
 * 
 */
public class AprioriTool {
  // 最小支持度计数
  private int minSupportCount;
  // 测试数据文件地址
  private String filePath;
  // 每个事务中的商品ID
  private ArrayList<String[]> totalGoodsIDs;
  // 过程中计算出来的所有频繁项集列表
  private ArrayList<FrequentItem> resultItem;
  // 过程中计算出来频繁项集的ID集合
  private ArrayList<String[]> resultItemID;

  public AprioriTool(String filePath, int minSupportCount) {
    this.filePath = filePath;
    this.minSupportCount = minSupportCount;
    readDataFile();
  }

  /**
   * 从文件中读取数据
   */
  private void readDataFile() {
    File file = new File(filePath);
    ArrayList<String[]> dataArray = new ArrayList<String[]>();

    try {
      BufferedReader in = new BufferedReader(new FileReader(file));
      String str;
      String[] tempArray;
      while ((str = in.readLine()) != null) {
        tempArray = str.split(" ");
        dataArray.add(tempArray);
      }
      in.close();
    } catch (IOException e) {
      e.getStackTrace();
    }

    String[] temp = null;
    totalGoodsIDs = new ArrayList<>();
    for (String[] array : dataArray) {
      temp = new String[array.length - 1];
      System.arraycopy(array, 1, temp, 0, array.length - 1);

      // 将事务ID加入列表吧中
      totalGoodsIDs.add(temp);
    }
  }

  /**
   * 判读字符数组array2是否包含于数组array1中
   * 
   * @param array1
   * @param array2
   * @return
   */
  public boolean iSStrContain(String[] array1, String[] array2) {
    if (array1 == null || array2 == null) {
      return false;
    }

    boolean iSContain = false;
    for (String s : array2) {
      // 新的字母比较时,重新初始化变量
      iSContain = false;
      // 判读array2中每个字符,只要包括在array1中 ,就算包含
      for (String s2 : array1) {
        if (s.equals(s2)) {
          iSContain = true;
          break;
        }
      }

      // 如果已经判断出不包含了,则直接中断循环
      if (!iSContain) {
        break;
      }
    }

    return iSContain;
  }

  /**
   * 项集进行连接运算
   */
  private void computeLink() {
    // 连接计算的终止数,k项集必须算到k-1子项集为止
    int endNum = 0;
    // 当前已经进行连接运算到几项集,开始时就是1项集
    int currentNum = 1;
    // 商品,1频繁项集映射图
    HashMap<String, FrequentItem> itemMap = new HashMap<>();
    FrequentItem tempItem;
    // 初始列表
    ArrayList<FrequentItem> list = new ArrayList<>();
    // 经过连接运算后产生的结果项集
    resultItem = new ArrayList<>();
    resultItemID = new ArrayList<>();
    // 商品ID的种类
    ArrayList<String> idType = new ArrayList<>();
    for (String[] a : totalGoodsIDs) {
      for (String s : a) {
        if (!idType.contains(s)) {
          tempItem = new FrequentItem(new String[] { s }, 1);
          idType.add(s);
          resultItemID.add(new String[] { s });
        } else {
          // 支持度计数加1
          tempItem = itemMap.get(s);
          tempItem.setCount(tempItem.getCount() + 1);
        }
        itemMap.put(s, tempItem);
      }
    }
    // 将初始频繁项集转入到列表中,以便继续做连接运算
    for (Map.Entry entry : itemMap.entrySet()) {
      list.add((FrequentItem) entry.getValue());
    }
    // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少
    Collections.sort(list);
    resultItem.addAll(list);

    String[] array1;
    String[] array2;
    String[] resultArray;
    ArrayList<String> tempIds;
    ArrayList<String[]> resultContainer;
    // 总共要算到endNum项集
    endNum = list.size() - 1;

    while (currentNum < endNum) {
      resultContainer = new ArrayList<>();
      for (int i = 0; i < list.size() - 1; i++) {
        tempItem = list.get(i);
        array1 = tempItem.getIdArray();
        for (int j = i + 1; j < list.size(); j++) {
          tempIds = new ArrayList<>();
          array2 = list.get(j).getIdArray();
          for (int k = 0; k < array1.length; k++) {
            // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作
            if (array1[k].equals(array2[k])) {
              tempIds.add(array1[k]);
            } else {
              tempIds.add(array1[k]);
              tempIds.add(array2[k]);
            }
          }
          resultArray = new String[tempIds.size()];
          tempIds.toArray(resultArray);

          boolean isContain = false;
          // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的
          if (resultArray.length == (array1.length + 1)) {
            isContain = isIDArrayContains(resultContainer,
                resultArray);
            if (!isContain) {
              resultContainer.add(resultArray);
            }
          }
        }
      }

      // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集
      list = cutItem(resultContainer);
      currentNum++;
    }

    // 输出频繁项集
    for (int k = 1; k <= currentNum; k++) {
      System.out.println("频繁" + k + "项集:");
      for (FrequentItem i : resultItem) {
        if (i.getLength() == k) {
          System.out.print("{");
          for (String t : i.getIdArray()) {
            System.out.print(t + ",");
          }
          System.out.print("},");
        }
      }
      System.out.println();
    }
  }

  /**
   * 判断列表结果中是否已经包含此数组
   * 
   * @param container
   *            ID数组容器
   * @param array
   *            待比较数组
   * @return
   */
  private boolean isIDArrayContains(ArrayList<String[]> container,
      String[] array) {
    boolean isContain = true;
    if (container.size() == 0) {
      isContain = false;
      return isContain;
    }

    for (String[] s : container) {
      // 比较的视乎必须保证长度一样
      if (s.length != array.length) {
        continue;
      }

      isContain = true;
      for (int i = 0; i < s.length; i++) {
        // 只要有一个id不等,就算不相等
        if (s[i] != array[i]) {
          isContain = false;
          break;
        }
      }

      // 如果已经判断是包含在容器中时,直接退出
      if (isContain) {
        break;
      }
    }

    return isContain;
  }

  /**
   * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集
   */
  private ArrayList<FrequentItem> cutItem(ArrayList<String[]> resultIds) {
    String[] temp;
    // 忽略的索引位置,以此构建子集
    int igNoreIndex = 0;
    FrequentItem tempItem;
    // 剪枝生成新的频繁项集
    ArrayList<FrequentItem> newItem = new ArrayList<>();
    // 不符合要求的id
    ArrayList<String[]> deleteIdArray = new ArrayList<>();
    // 子项集是否也为频繁子项集
    boolean isContain = true;

    for (String[] array : resultIds) {
      // 列举出其中的一个个的子项集,判断存在于频繁项集列表中
      temp = new String[array.length - 1];
      for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) {
        isContain = true;
        for (int j = 0, k = 0; j < array.length; j++) {
          if (j != igNoreIndex) {
            temp[k] = array[j];
            k++;
          }
        }

        if (!isIDArrayContains(resultItemID, temp)) {
          isContain = false;
          break;
        }
      }

      if (!isContain) {
        deleteIdArray.add(array);
      }
    }

    // 移除不符合条件的ID组合
    resultIds.removeAll(deleteIdArray);

    // 移除支持度计数不够的id集合
    int tempCount = 0;
    for (String[] array : resultIds) {
      tempCount = 0;
      for (String[] array2 : totalGoodsIDs) {
        if (isStrArrayContain(array2, array)) {
          tempCount++;
        }
      }

      // 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中
      if (tempCount >= minSupportCount) {
        tempItem = new FrequentItem(array, tempCount);
        newItem.add(tempItem);
        resultItemID.add(array);
        resultItem.add(tempItem);
      }
    }

    return newItem;
  }

  /**
   * 数组array2是否包含于array1中,不需要完全一样
   * 
   * @param array1
   * @param array2
   * @return
   */
  private boolean isStrArrayContain(String[] array1, String[] array2) {
    boolean isContain = true;
    for (String s2 : array2) {
      isContain = false;
      for (String s1 : array1) {
        // 只要s2字符存在于array1中,这个字符就算包含在array1中
        if (s2.equals(s1)) {
          isContain = true;
          break;
        }
      }

      // 一旦发现不包含的字符,则array2数组不包含于array1中
      if (!isContain) {
        break;
      }
    }

    return isContain;
  }

  /**
   * 根据产生的频繁项集输出关联规则
   * 
   * @param minConf
   *            最小置信度阈值
   */
  public void printAttachRule(double minConf) {
    // 进行连接和剪枝操作
    computeLink();

    int count1 = 0;
    int count2 = 0;
    ArrayList<String> childGroup1;
    ArrayList<String> childGroup2;
    String[] group1;
    String[] group2;
    // 以最后一个频繁项集做关联规则的输出
    String[] array = resultItem.get(resultItem.size() - 1).getIdArray();
    // 子集总数,计算的时候除去自身和空集
    int totalNum = (int) Math.pow(2, array.length);
    String[] temp;
    // 二进制数组,用来代表各个子集
    int[] binaryArray;
    // 除去头和尾部
    for (int i = 1; i < totalNum - 1; i++) {
      binaryArray = new int[array.length];
      numToBinaryArray(binaryArray, i);

      childGroup1 = new ArrayList<>();
      childGroup2 = new ArrayList<>();
      count1 = 0;
      count2 = 0;
      // 按照二进制位关系取出子集
      for (int j = 0; j < binaryArray.length; j++) {
        if (binaryArray[j] == 1) {
          childGroup1.add(array[j]);
        } else {
          childGroup2.add(array[j]);
        }
      }

      group1 = new String[childGroup1.size()];
      group2 = new String[childGroup2.size()];

      childGroup1.toArray(group1);
      childGroup2.toArray(group2);

      for (String[] a : totalGoodsIDs) {
        if (isStrArrayContain(a, group1)) {
          count1++;

          // 在group1的条件下,统计group2的事件发生次数
          if (isStrArrayContain(a, group2)) {
            count2++;
          }
        }
      }

      // {A}-->{B}的意思为在A的情况下发生B的概率
      System.out.print("{");
      for (String s : group1) {
        System.out.print(s + ", ");
      }
      System.out.print("}-->");
      System.out.print("{");
      for (String s : group2) {
        System.out.print(s + ", ");
      }
      System.out.print(MessageFormat.format(
          "},confidence(置信度):{0}/{1}={2}", count2, count1, count2
              * 1.0 / count1));
      if (count2 * 1.0 / count1 < minConf) {
        // 不符合要求,不是强规则
        System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则");
      } else {
        System.out.println("为强规则");
      }
    }

  }

  /**
   * 数字转为二进制形式
   * 
   * @param binaryArray
   *            转化后的二进制数组形式
   * @param num
   *            待转化数字
   */
  private void numToBinaryArray(int[] binaryArray, int num) {
    int index = 0;
    while (num != 0) {
      binaryArray[index] = num % 2;
      index++;
      num /= 2;
    }
  }

}
调用类:
/**
 * apriori关联规则挖掘算法调用类
 * @author lyq
 *
 */
public class Client {
  public static void main(String[] args){
    String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
    
    AprioriTool tool = new AprioriTool(filePath, 2);
    tool.printAttachRule(0.7);
  }
}
输出的结果:
频繁1项集:
{1,},{2,},{3,},{4,},{5,},
频繁2项集:
{1,2,},{1,3,},{1,5,},{2,3,},{2,4,},{2,5,},
频繁3项集:
{1,2,3,},{1,2,5,},
频繁4项集:

{1, }-->{2, 5, },confidence(置信度):2/6=0.333由于此规则置信度未达到最小置信度的要求,不是强规则
{2, }-->{1, 5, },confidence(置信度):2/7=0.286由于此规则置信度未达到最小置信度的要求,不是强规则
{1, 2, }-->{5, },confidence(置信度):2/4=0.5由于此规则置信度未达到最小置信度的要求,不是强规则
{5, }-->{1, 2, },confidence(置信度):2/2=1为强规则
{1, 5, }-->{2, },confidence(置信度):2/2=1为强规则
{2, 5, }-->{1, },confidence(置信度):2/2=1为强规则

程序算法的问题和技巧

在实现Apiori算法的时候,碰到的一些问题和待优化的点特别要提一下:

1、首先程序的运行效率不高,里面有大量的for嵌套循环叠加上循环,当然这有本身算法的原因(连接运算所致)还有我的各个的方法选择,很多一部分用来比较字符串数组。

2、这个是我觉得会是程序的一个漏洞,当生成的候选项集加入resultItemId时,会出现{1, 2, 3}和{3, 2, 1}会被当成不同的侯选集,未做顺序的判断。

3、程序的调试过程中由于未按照从小到大的排序,导致,生成的候选集与真实值不一致的情况,所以这里必须在频繁1项集的时候就应该是有序的。

4、在输出关联规则的时候,用到了数字转二进制数组的形式,输出他的各个非空子集,然后最出关联规则的判断。

Apriori算法的缺点

此算法的的应用非常广泛,但是他在运算的过程中会产生大量的侯选集,而且在匹配的时候要进行整个数据库的扫描,因为要做支持度计数的统计操作,在小规模的数据上操作还不会有大问题,如果是大型的数据库上呢,他的效率还是有待提高的。