现在的位置: 首页 > 综合 > 正文

plsa代码实现

2013年10月18日 ⁄ 综合 ⁄ 共 5775字 ⁄ 字号 评论关闭

plsa的代码实现,plsa的原理可参考这个:http://luxinxin.is-programmer.com/user_files/luxinxin/File/plsanote.pdf

plsa这里使用EM算法来估计其中的参数。已知变量是:文档、单词;未知变量是:主题;待估计的假设(参数)是:p(w|z)、p(z|d)。

用EM算法来估计参数主要有两步:在plsa中,E步是根据当前假设求后验概率P(z|w,d);M步是通过最大化似然函数来求p(w|z)、p(z|d),即重新估计假设;然后再用新的假设求后验概率……如此循环。具体的公式可参考上面的文档。

具体的代码为:

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.bj58.data.dataming.machinelearning.util.FileUtil;
import com.bj58.data.dataming.machinelearning.util.IKParticiple;
import com.bj58.data.dataming.machinelearning.util.Participle;
import com.bj58.data.dataming.machinelearning.util.Probability;

/**
 * Command-line demo that fits a PLSA (probabilistic latent semantic analysis)
 * model to a directory of text documents with the EM algorithm and prints,
 * for every topic, the words whose highest p(w|z) falls on that topic.
 *
 * <p>Observed variables are documents and words; the latent variable is the
 * topic. The estimated parameters are p(z|d) ({@code dtmap}) and p(w|z)
 * ({@code twmap}); the E step computes the posterior p(z|d,w)
 * ({@code dtwmap}).
 *
 * <p>All probability tables are keyed by the comma-joined string produced by
 * {@code Probability.toString()}. All state is static and mutable, so the
 * class is meant to be run once per JVM via {@link #main(String[])}.
 */
public class Plsa2 {
	
	/** Corpus directory used when no path is passed on the command line. */
	private static final String DEFAULT_TRAIN_PATH = "E:\\学习\\文本处理\\data\\SogouC.reduced.20061127\\SogouC.reduced\\plsa";
	/** Number of E/M rounds executed by {@link #main(String[])}. */
	private static final int ITERATIONS = 100;
	
	/** Tokenizer used to split document text into words. */
	private static Participle participle = new IKParticiple();
	/** Vocabulary: every distinct word seen in any document. */
	private static Set<String> wordSet = new HashSet<String>();
	/** Topic identifiers (latent variable values). */
	private static List<String> topicList = new ArrayList<String>();
	/** Document identifiers, in the order they were read. */
	private static List<String> docList = new ArrayList<String>();
	/** p(z|d), keyed by "doc,topic". */
	private static Map<String,Double> dtmap = new HashMap<String,Double>(); 
	/** p(w|z), keyed by "topic,word". */
	private static Map<String,Double> twmap = new HashMap<String,Double>(); 
	/** Distinct words of each document, keyed by document identifier. */
	private static Map<String,Set<String>> dwSetMap = new HashMap<String,Set<String>>();
	/** Posterior p(z|d,w), keyed by "doc,topic,word". */
	private static Map<String,Double> dtwmap = new HashMap<String,Double>();
	/** Occurrence count n(d,w), keyed by "doc,word". */
	private static Map<String,Integer> dwcount = new HashMap<String,Integer>();
	
	static{
		topicList.add("1");
		topicList.add("2");
		topicList.add("3");
	}
	
	/**
	 * Reads the corpus, randomly initialises the parameters, runs EM and
	 * prints the resulting topic/word grouping to standard output.
	 *
	 * @param args optional; {@code args[0]} overrides the default corpus
	 *             directory. The directory is expected to contain one
	 *             sub-directory per category, each holding the documents.
	 */
	public static void main(String[] args){
		
		String trainPath = (args != null && args.length > 0) ? args[0] : DEFAULT_TRAIN_PATH;
		readFile(trainPath);
		
		initDocTopic();
		initTopicWord();
		
		for(int i=0;i<ITERATIONS;i++){
			E();
			M();
		}
		
		printTopicWords();
	}
	
	/** Initialises p(z|d) with random values normalised to sum to 1 per document. */
	private static void initDocTopic(){
		for(String doc : docList){
			double all = 0d;
			for(String topic : topicList){
				double random = Math.random();
				all += random;
				dtmap.put(new Probability(doc,topic).toString(), random);
			}
			for(String topic : topicList){
				String key = new Probability(doc,topic).toString();
				dtmap.put(key, dtmap.get(key)/all);
			}
		}
	}
	
	/** Initialises p(w|z) with random values normalised to sum to 1 per topic. */
	private static void initTopicWord(){
		for(String topic : topicList){
			double all = 0d;
			for(String word : wordSet){
				double random = Math.random();
				twmap.put(new Probability(topic,word).toString(), random);
				all += random;
			}
			for(String word : wordSet){
				String key = new Probability(topic,word).toString();
				twmap.put(key, twmap.get(key)/all);
			}
		}
	}
	
	/**
	 * Sorts all p(w|z) entries in descending order, assigns every word to the
	 * topic of its first (highest-probability) entry and prints one
	 * "topic=[words...]" line per topic.
	 */
	private static void printTopicWords(){
		List<Map.Entry<String, Double>> orderList = new ArrayList<Map.Entry<String, Double>>(twmap.entrySet());
		Collections.sort(orderList, new Comparator<Map.Entry<String, Double>>() {
			public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
				// Double.compare returns 0 for ties, so the Comparator contract
				// holds and the sort cannot fail on equal probabilities.
				return Double.compare(o2.getValue(), o1.getValue());
			}
		});
		
		Map<String,List<String>> ml = new HashMap<String,List<String>>();
		Set<String> set = new HashSet<String>();
		for(Map.Entry<String, Double> me : orderList){
			// Keys are "topic,word". The limit of 2 keeps an empty word as a
			// second element and leaves commas inside the word untouched.
			String[] words = me.getKey().split(",", 2);
			if(!set.add(words[1]))
				continue; // word already claimed by a higher-ranked entry
			List<String> list = ml.get(words[0]);
			if(list == null){
				list = new ArrayList<String>();
				ml.put(words[0], list);
			}
			list.add(words[1]);
		}
		for(Map.Entry<String, List<String>> me : ml.entrySet()){
			System.out.println(me.getKey()+"="+me.getValue().toString());
		}
	}
	
	
	/**
	 * E step: for every (document, word) pair that occurs in the corpus,
	 * computes the posterior p(z|d,w) proportional to p(z|d) * p(w|z).
	 */
	private static void E(){
		
		int topicCount = topicList.size();
		for(Map.Entry<String, Set<String>> me : dwSetMap.entrySet()){
			String doc = me.getKey();
			for(String word : me.getValue()){
				double fenmu = 0d; // normaliser: sum over topics
				for(String topic : topicList){
					double fenzi = dtmap.get(new Probability(doc,topic).toString())
							* twmap.get(new Probability(topic,word).toString());
					fenmu += fenzi;
					dtwmap.put(new Probability(doc,topic,word).toString(), fenzi);
				}
				for(String topic : topicList){
					String key = new Probability(doc,topic,word).toString();
					// If every product underflowed to 0, fall back to a uniform
					// posterior instead of storing NaN (0/0).
					double posterior = fenmu > 0d ? dtwmap.get(key)/fenmu : 1d/topicCount;
					dtwmap.put(key, posterior);
				}
			}
		}
	}
	
	/**
	 * M step: re-estimates p(z|d) and p(w|z) from the word counts n(d,w)
	 * weighted by the posteriors computed in the E step.
	 */
	private static void M(){
		
		// Re-estimate p(z|d)
		for(String doc : docList){
			
			Set<String> docWords = dwSetMap.get(doc);
			int fenmu = 0; // total number of word occurrences in the document
			for(String word : docWords){
				fenmu += dwcount.get(new Probability(doc,word).toString());
			}
			if(fenmu == 0)
				continue; // empty document: keep the previous estimate, avoid NaN
			
			for(String topic : topicList){
				double fenzi = 0d;
				for(String word : docWords){
					fenzi += dwcount.get(new Probability(doc,word).toString())
							* dtwmap.get(new Probability(doc,topic,word).toString());
				}
				dtmap.put(new Probability(doc,topic).toString(), fenzi/fenmu);
			}
		}
		
		// Re-estimate p(w|z)
		for(String topic : topicList){
			
			// Per-word numerators, accumulated in one pass over the (doc, word)
			// pairs that actually occur rather than over every word x doc pair.
			Map<String,Double> weightByWord = new HashMap<String,Double>();
			for(String word : wordSet){
				weightByWord.put(word, 0d);
			}
			
			double fenmu = 0d;
			for(String doc : docList){
				for(String word : dwSetMap.get(doc)){
					Integer count = dwcount.get(new Probability(doc,word).toString());
					Double posterior = dtwmap.get(new Probability(doc,topic,word).toString());
					if(count == null || posterior == null)
						continue;
					double weighted = count*posterior;
					fenmu += weighted;
					Double soFar = weightByWord.get(word);
					weightByWord.put(word, (soFar == null ? 0d : soFar)+weighted);
				}
			}
			if(!(fenmu > 0d))
				continue; // topic carries no mass: keep the previous estimate, avoid NaN
			
			for(Map.Entry<String,Double> entry : weightByWord.entrySet()){
				twmap.put(new Probability(topic,entry.getKey()).toString(), entry.getValue()/fenmu);
			}
		}
	}
	
	
	/**
	 * Loads the corpus from {@code path}, which is expected to contain one
	 * sub-directory per category with the document files inside. Fills
	 * {@code docList}, {@code wordSet}, {@code dwSetMap} and {@code dwcount}.
	 * Does nothing when {@code path} is not a readable directory; entries that
	 * are not directories (first level) or not regular files (second level)
	 * are skipped.
	 *
	 * @param path corpus root directory
	 */
	private static void readFile(String path){
		
		File ppfile = new File(path);
		File[] pfiles = ppfile.listFiles(); // null when path is not a directory
		if(pfiles == null)
			return;
		for(File categoryDir : pfiles){
			File[] files = categoryDir.listFiles();
			if(files == null)
				continue; // stray regular file next to the category directories
			String categoryName = categoryDir.getName();
			for(File file : files){
				if(!file.isFile())
					continue;
				String docName = categoryName+","+file.getName();
				docList.add(docName);
				List<String> fileContent = FileUtil.readFile(file.getPath(),"GB2312");
				// NOTE(review): List.toString() wraps the lines in "[...]" and joins
				// them with ", " before tokenizing; presumably the tokenizer drops
				// that punctuation - confirm.
				List<String> participleResult = participle.getParticipleResultList(fileContent.toString());
				dwSetMap.put(docName, new HashSet<String>(participleResult));
				wordSet.addAll(participleResult);
				for(String word : participleResult){
					String key = new Probability(docName,word).toString();
					Integer previous = dwcount.get(key);
					dwcount.put(key, previous == null ? 1 : previous+1);
				}
			}
		}
	}
}

其中一个实体类为:

/**
 * Holds two or three strings and renders them as a comma-separated key
 * through {@link #toString()}.
 */
public class Probability {

	private final String str1;
	private final String str2;
	private final String str3;

	/**
	 * Creates a three-part key.
	 *
	 * @param doc   first component
	 * @param topic second component
	 * @param word  third component
	 */
	public Probability(String doc, String topic,String word){
		this.str1 = doc;
		this.str2 = topic;
		this.str3 = word;
	}

	/**
	 * Creates a two-part key; the third component is left {@code null}.
	 *
	 * @param str1 first component
	 * @param str2 second component
	 */
	public Probability(String str1, String str2){
		this(str1, str2, null);
	}

	/**
	 * Returns {@code "str1,str2,str3"} when all three components are
	 * non-null, otherwise {@code "str1,str2"} (a null component among the
	 * first two is rendered as the text {@code "null"}).
	 */
	@Override
	public String toString() {
		boolean allPresent = str1 != null && str2 != null && str3 != null;
		StringBuilder key = new StringBuilder();
		key.append(str1).append(",").append(str2);
		if (allPresent) {
			key.append(",").append(str3);
		}
		return key.toString();
	}
}

抱歉!评论已关闭.