现在的位置: 首页 > 综合 > 正文

数据挖掘–频繁集测试–Apriori算法–java实现

2018年05月12日 ⁄ 综合 ⁄ 共 6059字 ⁄ 字号 评论关闭

2013年11月19日注:以下算法中,combine算法实现不正确,应该是从已有的频繁中来产生。需要进一步修改

=================================================================================

Apriori算法原理:

如果某个项集是频繁的,那么它所有的子集也是频繁的。如果一个项集是非频繁的,那么它所有的超集也是非频繁的。

示意图

图一:

图二:

package cn.ffr.frequent.apriori;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Apriori的核心代码实现
 * @author neu_fufengrui@163.com
 */
public class Apriori {
	
	public static final String STRING_SPLIT = ",";
	
	/**
	 * 主要的计算方法
	 * @param data 数据集
	 * @param minSupport 最小支持度
	 * @param maxLoop 最大执行次数,设NULL为获取最终结果
	 * @param containSet 结果中必须包含的子集
	 * @return
	 */
	public Map<String, Double> compute(List<String[]> data, Double minSupport, Integer maxLoop, String[] containSet){
	
		//校验
		if(data == null || data.size() <= 0){
			return null;
		}
		
		//初始化
		Map<String, Double> result = new HashMap<String, Double>();
		Object[] itemSet = getDataUnitSet(data);
		int loop = 0;
		//核心循环处理过程
		while(true){
			//重要步骤一:合并,产生新的频繁集
			Set<String> keys = combine(result.keySet(), itemSet);
			result.clear();//移除之前的结果
			for(String key : keys){				
				result.put(key, computeSupport(data, key.split(STRING_SPLIT)));
			}
			//重要步骤二:修剪,去除支持度小于条件的。
			cut(result, minSupport, containSet);
			loop++;
			//输出计算过程
			System.out.println("loop ["+loop+"], result : "+result);
			//循环结束条件
			if(result.size() <= 0){
				break;
			}
			if(maxLoop != null && maxLoop > 0 && loop >= maxLoop){//可控制循环执行次数
				break;
			}
		}
		return result;
	}

	/**
	 * 计算子集的支持度
	 * 
	 * 支持度 = 子集在数据集中的数据项 / 总的数据集的数据项
	 * 
	 * 数据项的意思是一条数据。
	 * @param data 数据集
	 * @param subSet 子集 
	 * @return
	 */
	public Double computeSupport(List<String[]> data, String[] subSet){
		Integer value = 0;
		for(int i = 0; i < data.size(); i++){
			if(contain(data.get(i), subSet)){
				value ++;
			}
		}
		return value*1.0/data.size();
	}
	/**
	 * 获得初始化唯一的数据集,用于初始化
	 * @param data
	 * @return
	 */
	public Object[] getDataUnitSet(List<String[]> data){
		List<String> uniqueKeys = new ArrayList<String>();
		for(String[] dat : data){
			for(String da : dat){
				if(!uniqueKeys.contains(da)){
					uniqueKeys.add(da);
				}
			}
		}
		return uniqueKeys.toArray();
	}
	/**
	 * 合并src和target来获取频繁集
	 * 增加频繁集的计算维度
	 * @param src
	 * @param target
	 * @return
	 */
	public Set<String> combine(Set<String> src, Object[] target){
		Set<String> dest = new HashSet<String>();
		if(src == null || src.size() <= 0){
			for(Object t : target){
				dest.add(t.toString());
			}
			return dest;
		}
		for(String s : src){
			for(Object t : target){
				if(s.indexOf(t.toString())<0){
					String key = s+STRING_SPLIT+t;
					if(!contain(dest, key)){
						dest.add(key);
					}
				}
			}
		}
		return dest;
	}
	/**
	 * dest集中是否包含了key
	 * @param dest
	 * @param key
	 * @return
	 */
	public boolean contain(Set<String> dest, String key){
		for(String d : dest){
			if(equal(d.split(STRING_SPLIT), key.split(STRING_SPLIT))){
				return true;
			}
		}
		return false;
	}
	/**
	 * 移除结果中,支持度小于所需要的支持度的结果。
	 * @param result
	 * @param minSupport
	 * @return
	 */
	public Map<String, Double> cut(Map<String, Double> result, Double minSupport, String[] containSet){
		for(Object key : result.keySet().toArray()){//防止 java.util.ConcurrentModificationException,使用keySet().toArray()
			if(minSupport != null && minSupport > 0 && minSupport < 1 && result.get(key) < minSupport){//比较支持度
				result.remove(key);
			}
			if(containSet != null && containSet.length > 0 && !contain(key.toString().split(STRING_SPLIT), containSet)){
				result.remove(key);
			}
		}
		return result;
	}
	/**
	 * src中是否包含dest,需要循环遍历查询
	 * @param src
	 * @param dest
	 * @return
	 */
	public static boolean contain(String[] src, String[] dest){
		for(int i = 0; i < dest.length; i++){
			int j = 0;
			for(; j < src.length; j++){
				if(src[j].equals(dest[i])){
					break;
				}
			}
			if(j == src.length){
				return false;//can not find
			}
		}
		return true;
	}
	/**
	 * src是否与dest相等
	 * @param src
	 * @param dest
	 * @return
	 */
	public boolean equal(String[] src, String[] dest){
		if(src.length == dest.length && contain(src, dest)){
			return true;
		}
		return false;
	}
	/**
	 * 主测试方法
	 * 测试方法,挨个去掉注释,进行测试。
	 */
	public static void main(String[] args) throws Exception{
		//test 1
//		List<String[]> data = loadSmallData();
//		Long start = System.currentTimeMillis();
//		Map<String, Double> result = new Apriori().compute(data, 0.5, 3, null);//求支持度大于指定值
//		Long end = System.currentTimeMillis();
//		System.out.println("Apriori Result [costs:"+(end-start)+"ms]: ");
//		for(String key : result.keySet()){
//			System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];");
//		}
		
		//test 2
//		List<String[]> data = loadMushRoomData();
//		Long start = System.currentTimeMillis();
//		Map<String, Double> result = new Apriori().compute(data, 0.3, 4, new String[]{"2"});//求支持度大于指定值
//		Long end = System.currentTimeMillis();
//		System.out.println("Apriori Result [costs:"+(end-start)+"ms]: ");
//		for(String key : result.keySet()){
//			System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];");
//		}
		
		//test 3
		List<String[]> data = loadChessData();
		Long start = System.currentTimeMillis();
		Map<String, Double> result = new Apriori().compute(data, 0.95, 3, null);//求支持度大于指定值
		Long end = System.currentTimeMillis();
		System.out.println("Apriori Result [costs:"+(end-start)+"ms]: ");
		for(String key : result.keySet()){
			System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];");
		}
	}
	/*
	 *	SmallData: minSupport 0.5, maxLoop 3, containSet null, [costs: 16ms]
	 *	MushRoomData: minSupport 0.3, maxLoop 4, containSet {"2"}, [costs: 103250ms]	
	 *	ChessData: minSupport 0.95, maxLoop 34, containSet {null, [costs: 9718ms]
	 */
	
	//测试数据集-1
	public static List<String[]> loadSmallData() throws Exception{
		List<String[]> data = new ArrayList<String[]>();
		data.add(new String[]{"d1","d3","d4"});
		data.add(new String[]{"d2","d3","d5"});
		data.add(new String[]{"d1","d2","d3","d5"});
		data.add(new String[]{"d2","d5"});
		return data;
	}
	
	//测试数据集-2
	public static List<String[]> loadMushRoomData() throws Exception{
		String link = "http://fimi.ua.ac.be/data/mushroom.dat";
		URL url = new URL(link);
		BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream()));
		String temp = reader.readLine();
		List<String[]> result = new ArrayList<String[]>();
		int lineNumber = 0;
		while(temp != null){
			System.out.println("reading data... [No."+(++lineNumber)+"]");
			String[] item = temp.split(" ");
			result.add(item);
			temp = reader.readLine();
		}
		reader.close();
		return result;
	}
	
	//测试数据集-3
	public static List<String[]> loadChessData() throws Exception{
		String link = "http://fimi.ua.ac.be/data/chess.dat";
		URL url = new URL(link);
		BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream()));
		String temp = reader.readLine();
		List<String[]> result = new ArrayList<String[]>();
		int lineNumber = 0;
		while(temp != null){
			System.out.println("reading data... [No."+(++lineNumber)+"]");
			String[] item = temp.split(" ");
			result.add(item);
			temp = reader.readLine();
		}
		reader.close();
		return result;
	}
}

算法原理:

抱歉!评论已关闭.