现在的位置: 首页 > 综合 > 正文

贝叶斯的JAVA分类器实现

2012年11月28日 ⁄ 综合 ⁄ 共 10049字 ⁄ 字号 评论关闭

注:本算法的实现仅仅适用于小规模数据集的实验与测试,不适合用于工程应用

算法假定训练数据各属性列的值均是离散类型的。若是非离散类型的数据,需要首先进行数据的预处理,将非离散型的数据离散化。

算法中使用到了DecimalCaculate类,该类是java中BigDecimal类的扩展,用于高精度浮点数的运算。该类的实现同本人转载的一篇博文:对BigDecimal常用方法的归类中的Arith类相同。

算法实现的代码如下

  1. package Bayes; 
  2. import java.util.ArrayList; 
  3. import java.util.HashMap; 
  4. import java.util.Map; 
  5. import util.DecimalCalculate; 
  6. /**
  7. * 贝叶斯主体类
  8. * @author Rowen
  9. * @qq 443773264
  10. * @mail luowen3405@163.com
  11. * @blog blog.csdn.net/luowen3405
  12. * @data 2011.03.15
  13. */ 
  14. public class Bayes { 
  15.     /**
  16.      * 将原训练元组按类别划分
  17.      * @param datas 训练元组
  18.      * @return Map<类别,属于该类别的训练元组>
  19.      */ 
  20.     Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas){ 
  21.         Map<String, ArrayList<ArrayList<String>>> map =
    new
    HashMap<String, ArrayList<ArrayList<String>>>(); 
  22.         ArrayList<String> t = null
  23.         String c = ""
  24.         for (int i =
    0; i < datas.size(); i++) { 
  25.             t = datas.get(i); 
  26.             c = t.get(t.size() - 1); 
  27.             if (map.containsKey(c)) { 
  28.                 map.get(c).add(t); 
  29.             } else
  30.                 ArrayList<ArrayList<String>> nt =
    new
    ArrayList<ArrayList<String>>(); 
  31.                 nt.add(t); 
  32.                 map.put(c, nt); 
  33.             } 
  34.         } 
  35.         return map; 
  36.     } 
  37.      
  38.     /**
  39.      * 在训练数据的基础上预测测试元组的类别
  40.      * @param datas 训练元组
  41.      * @param testT 测试元组
  42.      * @return 测试元组的类别
  43.      */ 
  44.     public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) { 
  45.         Map<String, ArrayList<ArrayList<String>>> doc =
    this
    .datasOfClass(datas); 
  46.         Object classes[] = doc.keySet().toArray(); 
  47.         double maxP =
    0.00
  48.         int maxPIndex = -1
  49.         for (int i =
    0; i < doc.size(); i++) { 
  50.             String c = classes[i].toString();  
  51.             ArrayList<ArrayList<String>> d = doc.get(c); 
  52.             double pOfC = DecimalCalculate.div(d.size(), datas.size(),
    3); 
  53.             for (int j =
    0; j < testT.size(); j++) { 
  54.                 double pv = this.pOfV(d, testT.get(j), j); 
  55.                 pOfC = DecimalCalculate.mul(pOfC, pv); 
  56.             } 
  57.             if(pOfC > maxP){ 
  58.                 maxP = pOfC; 
  59.                 maxPIndex = i; 
  60.             } 
  61.         } 
  62.         return classes[maxPIndex].toString(); 
  63.     } 
  64.     /**
  65.      * 计算指定属性列上指定值出现的概率
  66.      * @param d 属于某一类的训练元组
  67.      * @param value 列值
  68.      * @param index 属性列索引
  69.      * @return 概率
  70.      */ 
  71.     private double pOfV(ArrayList<ArrayList<String>> d, String value,
    int index) { 
  72.         double p = 0.00
  73.         int count =
    0
  74.         int total = d.size(); 
  75.         ArrayList<String> t = null
  76.         for (int i =
    0; i < total; i++) { 
  77.             if(d.get(i).get(index).equals(value)){ 
  78.                 count++; 
  79.             } 
  80.         } 
  81.         p = DecimalCalculate.div(count, total,
    3
    ); 
  82.         return p; 
  83.     } 

算法测试类:

  1. package Bayes; 
  2. import java.io.BufferedReader; 
  3. import java.io.IOException; 
  4. import java.io.InputStreamReader; 
  5. import java.util.ArrayList; 
  6. import java.util.StringTokenizer; 
  7. /**
  8. * 贝叶斯算法测试类
  9. * @author Rowen
  10. * @qq 443773264
  11. * @mail luowen3405@163.com
  12. * @blog blog.csdn.net/luowen3405
  13. * @data 2011.03.15
  14. */ 
  15. public class TestBayes { 
  16.     /**
  17.      * 读取测试元组
  18.      * @return 一条测试元组
  19.      * @throws IOException
  20.      */ 
  21.     public ArrayList<String> readTestData()
    throws IOException{ 
  22.         ArrayList<String> candAttr = new ArrayList<String>(); 
  23.         BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); 
  24.         String str = ""
  25.         while (!(str = reader.readLine()).equals("")) { 
  26.             StringTokenizer tokenizer = new StringTokenizer(str); 
  27.             while (tokenizer.hasMoreTokens()) { 
  28.                 candAttr.add(tokenizer.nextToken()); 
  29.             } 
  30.         } 
  31.         return candAttr; 
  32.     } 
  33.      
  34.     /**
  35.      * 读取训练元组
  36.      * @return 训练元组集合
  37.      * @throws IOException
  38.      */ 
  39.     public ArrayList<ArrayList<String>> readData()
    throws IOException { 
  40.         ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>(); 
  41.         BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); 
  42.         String str = ""
  43.         while (!(str = reader.readLine()).equals("")) { 
  44.             StringTokenizer tokenizer = new StringTokenizer(str); 
  45.             ArrayList<String> s = new ArrayList<String>(); 
  46.             while (tokenizer.hasMoreTokens()) { 
  47.                 s.add(tokenizer.nextToken()); 
  48.             } 
  49.             datas.add(s); 
  50.         } 
  51.         return datas; 
  52.     } 
  53.      
  54.     public static
    void main(String[] args) { 
  55.         TestBayes tb = new TestBayes(); 
  56.         ArrayList<ArrayList<String>> datas = null
  57.         ArrayList<String> testT = null
  58.         Bayes bayes = new Bayes(); 
  59.         try
  60.             System.out.println("请输入训练数据"); 
  61.             datas = tb.readData(); 
  62.             while (true) { 
  63.                 System.out.println("请输入测试元组"); 
  64.                 testT = tb.readTestData(); 
  65.                 String c = bayes.predictClass(datas, testT); 
  66.                 System.out.println("The class is: " + c); 
  67.             } 
  68.         } catch (IOException e) { 
  69.             e.printStackTrace(); 
  70.         } 
  71.     } 

训练数据:

  1. youth high no fair no 
  2. youth high no excellent no 
  3. middle_aged high no fair yes 
  4. senior medium no fair yes 
  5. senior low yes fair yes 
  6. senior low yes excellent no 
  7. middle_aged low yes excellent yes 
  8. youth medium no fair no 
  9. youth low yes fair yes 
  10. senior medium yes fair yes 
  11. youth medium yes excellent yes 
  12. middle_aged medium no excellent yes 
  13. middle_aged high yes fair yes 
  14. senior medium no excellent no 

对原训练数据进行测试,测试如果如下:

  1. 请输入测试元组 
  2. youth high no fair 
  3. The class is: no 
  4. 请输入测试元组 
  5. youth high no excellent 
  6. The class is: no 
  7. 请输入测试元组 
  8. middle_aged high no fair 
  9. The class is: yes 
  10. 请输入测试元组 
  11. senior medium no fair 
  12. The class is: yes 
  13. 请输入测试元组 
  14. senior low yes fair 
  15. The class is: yes 
  16. 请输入测试元组 
  17. senior low yes excellent 
  18. The class is: yes 
  19. 请输入测试元组 
  20. middle_aged low yes excellent 
  21. The class is: yes 
  22. 请输入测试元组 
  23. youth medium no fair 
  24. The class is: no 
  25. 请输入测试元组 
  26. youth low yes fair 
  27. The class is: yes 
  28. 请输入测试元组 
  29. senior medium yes fair 
  30. The class is: yes 
  31. 请输入测试元组 
  32. youth medium yes excellent 
  33. The class is: yes 
  34. 请输入测试元组 
  35. middle_aged medium no excellent 
  36. The class is: yes 
  37. 请输入测试元组 
  38. middle_aged high yes fair 
  39. The class is: yes 
  40. 请输入测试元组 
  41. senior medium no excellent 
  42. The class is: no 

测试结果显示14个测试实例中有13个分类是正确的,正确率为93%,说明算法能够给出一个准确的预测与分类,但是算法还需改进以提高正确率。

改进的可选方法之一:

为避免单个属性值对分类结果的权重过大,例如当某属性值在某一类中出现0次时,该属性值就决定了测试实例已经不可能属于该类了,这就可能会造成误差,因此在计算概率时可能进行如下改进:

将原先的P(Xk|Ci)=|Xk| / |Ci| 改为 P(Xk|Ci)=(|Xk|+mp)
/ (|Ci|+m),其中m可设定为训练元组的个数,p为等可能假设的先验概率。

抱歉!评论已关闭.