现在的位置: 首页 > 综合 > 正文

隐马尔可夫模型（HMM）Python代码

2017年12月12日 ⁄ 综合 ⁄ 共 13266字 ⁄ 字号 评论关闭
logprobs.py
from math import exp, log, log1p

import numpy as np

def sumLogProb2(log_prob1, log_prob2):
    """Return log(exp(log_prob1) + exp(log_prob2)), computed stably in log space.

    The larger argument is factored out so ``exp`` never overflows, and
    ``log1p`` keeps precision when the smaller term is tiny relative to
    the larger one.

    :param log_prob1: first natural-log value (may be -inf for probability 0).
    :param log_prob2: second natural-log value (may be -inf for probability 0).
    :returns: the natural log of the sum of the two exponentiated values.
    """
    if np.isinf(log_prob1) and np.isinf(log_prob2):
        # The difference of two infinities is nan, so answer directly:
        # the sum is +inf if either term is +inf, otherwise -inf.
        return max(log_prob1, log_prob2)
    if log_prob1 > log_prob2:
        return log_prob1 + log1p(exp(log_prob2 - log_prob1))
    return log_prob2 + log1p(exp(log_prob1 - log_prob2))
    
def sumLogProb1(log_probs):
    """Return log(sum(exp(x) for x in log_probs)) via the log-sum-exp trick.

    The largest element is factored out before exponentiating so the sum
    cannot overflow.

    :param log_probs: sequence of natural-log values.
    :returns: the log of the summed probabilities; -1e20 (whose exp is 0)
        for an empty sequence; the largest element itself when it is
        infinite.
    """
    if len(log_probs) == 0:
        return -1e20
    peak = max(log_probs)
    if np.isinf(peak):
        return peak
    # Accumulate term by term (not sum()) to keep the exact float results.
    scaled_total = 0.0
    for value in log_probs:
        scaled_total += exp(value - peak)
    return peak + log(scaled_total)


str2idmap.py
class Str2IdMap(object):
    """Bidirectional mapping between strings and dense integer ids.

    Ids are handed out in first-seen order starting at 0, so
    ``getStr(getId(s)) == s`` for every registered string.
    """

    def __init__(self):
        self._to_id = {}   # string -> id
        self._to_str = []  # id -> string (list index is the id)

    def getStr(self, id):
        """Return the string registered under integer ``id``.

        Raises IndexError if ``id`` was never assigned.
        """
        return self._to_str[id]

    def getId(self, str):
        """Return the id of ``str``, registering it under the next free id if new."""
        # Parameter names shadow builtins but are kept for caller compatibility.
        existing = self._to_id.get(str)
        if existing is not None:
            return existing
        # len() of the list is O(1); the next id equals the number of ids so far.
        new_id = len(self._to_str)
        self._to_id[str] = new_id
        self._to_str.append(str)
        return new_id
import str2idmap
import math
from logprobs import *

# Value that OneDTable.getValue() returns for an event that was never added.
# Table entries are natural-log values (TwoDTable.load stores math.log(prob)),
# so -40 stands for a probability of about exp(-40), roughly 4e-18.
SMOOTHEDZEROCOUNT=-40
class OneDTable(dict):
    """A dict from events to natural-log counts/probabilities.

    Missing events report a smoothed log value instead of raising
    KeyError, and repeated additions for one event are summed in
    probability space (log(exp(old) + exp(new))).
    """

    def __init__(self):
        super(OneDTable, self).__init__()
        # Log value reported for events that have never been added.
        self._smoothed_zero_count = SMOOTHEDZEROCOUNT

    def smoothedZeroCount(self):
        """Return the log value used for unseen events."""
        return self._smoothed_zero_count

    def getValue(self, event):
        """Return the stored log value of ``event`` or the smoothed zero count."""
        if event in self:
            return self[event]
        return self.smoothedZeroCount()

    def add(self, event, count):
        """Accumulate the log value ``count`` into the entry for ``event``.

        The first addition stores ``count`` as is; later additions combine
        it with the existing value via sumLogProb2.
        """
        if event in self:
            self[event] = sumLogProb2(self[event], count)
        else:
            self[event] = count

    def rand(self, next):
        """Sample an event with probability exp(stored value).

        :param next: value handed back when no event is drawn (mirrors an
            out-parameter; the name is kept for caller compatibility).
        :returns: ``[event, True]`` on success, otherwise ``[next, False]``
            (e.g. the table is empty or its probabilities sum to less than
            the random threshold).
        """
        import random  # not among this module's top-level imports

        threshold = random.random()
        total = 0.0
        # Iterate the dict itself; the original iterated the unbound
        # ``self.keys`` method, which raises TypeError.
        for key in self:
            total += math.exp(self[key])
            if total >= threshold:
                return [key, True]
        return [next, False]
	
tables.py（续；该文件实际从上方 `import str2idmap` 一行开始）
class TwoDTable(dict):
    """A two-level table: context -> OneDTable(event -> natural-log value).

    It also keeps a reverse index ``_possibleContexts`` from each event to
    the set of contexts it has been added under.
    """

    def __init__(self):
        super(TwoDTable, self).__init__()
        self._possibleContexts = {}  # event -> set of contexts
        # Consulted by get() for unknown contexts.  Nothing in this class
        # adds to it, so it always answers with the smoothed zero count.
        self._backoff = OneDTable()

    def get(self, event, context):
        """Return the log value of ``event`` under ``context``.

        NOTE: this shadows ``dict.get`` with a different signature; use
        ``self[context]`` for plain dictionary access.
        """
        table = self[context] if context in self else self._backoff
        return table.getValue(event)

    def add(self, context, event, count):
        """Accumulate the log value ``count`` for ``event`` under ``context``."""
        if context not in self:
            self[context] = OneDTable()
        self[context].add(event, count)
        self._possibleContexts.setdefault(event, set()).add(context)

    def getCntx(self, event):
        """Return the set of contexts seen with ``event``, or 0 if there are none.

        The 0 sentinel (rather than an empty set) is kept because callers
        compare the result with 0.
        """
        return self._possibleContexts.get(event, 0)

    def load(self, input_file, str2id):
        """Read ``context event prob`` lines until EOF or the first blank line.

        Fields may be separated by any run of spaces or tabs.  Strings are
        mapped to ids with ``str2id.getId``; probabilities are stored as
        natural logs and lines whose probability is not positive are skipped.
        Raises IndexError/ValueError for lines without three valid fields.
        """
        while True:
            line = input_file.readline()
            if not line.strip():
                break
            fields = line.split()
            prob = float(fields[2])
            if prob > 0.0:
                self.add(str2id.getId(fields[0]), str2id.getId(fields[1]),
                         math.log(prob))

    def save(self, output_file, str2id):
        """Write each entry as a ``context event prob`` line (prob = exp(value))."""
        for context, events in self.items():
            context_str = str2id.getStr(context)
            for event, log_value in events.items():
                output_file.write('%s %s %s\n' % (
                    context_str, str2id.getStr(event), str(math.exp(log_value))))

    def rand(self, curr, next):
        """Sample an event from the distribution stored under context ``curr``.

        :returns: False if ``curr`` is unknown, otherwise the
            ``[event, found]`` list produced by ``OneDTable.rand``.
        """
        if curr not in self:
            return False
        # Index directly: self.get() is overridden and needs two arguments.
        return self[curr].rand(next)

"""
Copyright ChaoZhong superclocks@163.com
"""
import math
import random
import numpy as np
from logprobs import  * #from my code
from tables import OneDTable,TwoDTable#from my code
from str2idmap import Str2IdMap
from sys import stderr

class Transition():
    """A directed edge between two HmmNode objects, labelled with an observation.

    When both endpoints are truthy, the new edge registers itself in the
    source node's outgoing list and the target node's incoming list.
    """

    def __init__(self, from_node, to_node, obs):
        self._from, self._to, self._obs = from_node, to_node, obs
        linked = self._from and self._to
        if linked:
            from_node.outs().append(self)
            to_node.ins().append(self)
class HmmNode():
    """One (time, state) cell of the HMM trellis.

    Stores the forward value (log alpha), the backward value (log beta),
    the Viterbi back-pointer (psi) and the incoming/outgoing Transitions.
    """

    def __init__(self, time, state, hmm):
        # Where this node sits.
        self._time = time      # index of the time slot holding this node
        self._state = state    # state id represented by this node
        self._hmm = hmm        # the Hmm object this node belongs to
        # Dynamic-programming values (natural-log space).
        self._logAlpha = 0     # log P(e_1:t, x_t = s); viterbi() reuses it for the best score
        self._logBeta = 0      # log P(e_t+1:T | x_t = s)
        self._psi = 0          # best incoming Transition chosen by viterbi (0 = none)
        # Trellis edges.
        self._ins = []         # Transitions ending at this node
        self._outs = []        # Transitions starting at this node

    # -- position ------------------------------------------------------
    def time(self):
        """Return the time-slot index."""
        return self._time

    def state(self):
        """Return the state id."""
        return self._state

    # -- forward / backward / Viterbi values ---------------------------
    def setLogAlpha(self, logAlpha):
        self._logAlpha = logAlpha

    def getLogAlpha(self):
        return self._logAlpha

    def setLogBeta(self, logBeta):
        self._logBeta = logBeta

    def getLogBeta(self):
        return self._logBeta

    def setPsi(self, psi):
        self._psi = psi

    def getPsi(self):
        return self._psi

    # -- edges -----------------------------------------------------------
    def ins(self):
        """Return the (mutable) list of incoming Transitions."""
        return self._ins

    def outs(self):
        """Return the (mutable) list of outgoing Transitions."""
        return self._outs

    def _print(self):
        print('HmmNode')
class PseudoCounts():
    """Expected (pseudo) counts accumulated by Baum-Welch, in log space.

    ``_stateCount`` is keyed by state, ``_transCount`` by
    from-state -> to-state and ``_emitCount`` by state -> observation.
    """

    def __init__(self):
        self._stateCount = OneDTable()
        self._transCount = TwoDTable()
        self._emitCount = TwoDTable()

    def getStateCount(self):
        """Return the per-state count table."""
        return self._stateCount

    def getTransCount(self):
        """Return the transition count table."""
        return self._transCount

    def getEmitCount(self):
        """Return the emission count table."""
        return self._emitCount

    def _print(self, str2id):
        """Print a header only; dumping the tables is not implemented (str2id unused)."""
        print('TRANSTION \n')
class TimeSlot(list):
    """The list of HmmNode objects that are possible at one time step."""

    def __init__(self):
        # Takes no constructor arguments: every time slot starts out empty.
        super(TimeSlot, self).__init__()
	
hmm.py
class Hmm:
    """Discrete hidden Markov model with Baum-Welch training and Viterbi decoding.

    Model parameters are natural-log probabilities held in two TwoDTables:
    ``_transition`` (keyed from-state -> to-state) and ``_emission``
    (keyed state -> observation).  States and observations are interned
    as integer ids by a Str2IdMap.  The sequence currently being processed
    is unrolled into ``_time_slots``: slot 0 holds a single node for the
    initial state and slot t (t >= 1) holds one HmmNode per candidate
    state for the t-th observation.
    """

    def __init__(self, init_state=0, min_log_prob=0.0000001):
        self._init_state = init_state      # id of the initial state
        self._transition = TwoDTable()     # transition log probabilities
        self._emission = TwoDTable()       # emission log probabilities
        self._str2id = Str2IdMap()         # mapping between strings and integer ids
        self._time_slots = []              # trellis (list of TimeSlot)
        self._min_log_prob = min_log_prob  # stored only; not read by this class

    def loadProbs(self, path):
        """Read the model from ``path + '.trans'`` and ``path + '.emit'``.

        The first line of the .trans file names the initial state; the
        remaining lines, and all lines of the .emit file, are parsed by
        TwoDTable.load.
        """
        with open(path + '.trans', 'r') as trans_prob_reader:
            init_state = trans_prob_reader.readline().rstrip('\r\n')
            self._init_state = self._str2id.getId(init_state)
            self._transition.load(trans_prob_reader, self._str2id)
        with open(path + '.emit', 'r') as emit_prob_reader:
            self._emission.load(emit_prob_reader, self._str2id)

    def readSeqs(self, input_file, sequences):
        """Append one list of observation ids per input line to ``sequences``.

        Tokens are separated by whitespace.  Reading stops at EOF or at the
        first blank line.
        """
        while True:
            line = input_file.readline()
            if not line.strip():
                break
            # split() drops empty tokens; the original split(' ') kept them.
            sequences.append([self.getId(token) for token in line.split()])

    def getId(self, str):
        """Return the integer id of ``str``, registering it if new."""
        return self._str2id.getId(str)

    def getStr(self, id):
        """Return the string registered under integer ``id``."""
        return self._str2id.getStr(id)

    def addObservation(self, o):
        """Append a time slot for observation id ``o`` to the trellis.

        Candidate states are those that have emitted ``o`` in the emission
        table, or every state with emission entries if ``o`` is unseen.
        A Transition is created from each node of the previous slot whose
        state is a known predecessor of the candidate state.  The first call
        also creates slot 0, holding only the initial state.
        """
        cntx = self._emission.getCntx(o)
        if cntx == 0:  # getCntx's sentinel for "never seen"
            state_ids = list(self._emission.keys())
        else:
            state_ids = list(cntx)

        if not self._time_slots:
            t0 = TimeSlot()
            t0.append(HmmNode(0, self._init_state, self))
            self._time_slots.append(t0)

        time = len(self._time_slots)
        prev = self._time_slots[time - 1]
        ts = TimeSlot()
        for state in state_ids:
            node = HmmNode(time, state, self)
            ts.append(node)
            possible_src = self._transition.getCntx(state)
            # 0 means no transition leads into this state; the original
            # called len(0) here and raised TypeError.
            if possible_src == 0:
                continue
            for prev_node in prev:
                if prev_node.state() in possible_src:
                    Transition(prev_node, node, o)
        self._time_slots.append(ts)

    def getTransProb(self, trans):
        """Return the log transition probability along ``trans``."""
        return self._transition.get(trans._to.state(), trans._from.state())

    def getEmitProb(self, trans):
        """Return the log probability that ``trans``'s target state emits ``trans._obs``."""
        return self._emission.get(trans._obs, trans._to.state())

    def forward(self):
        """Compute log alpha_t(s) = log P(e_1:t, X_t = s) for every node."""
        self._time_slots[0][0].setLogAlpha(0.0)
        for ts in self._time_slots[1:]:
            for node in ts:
                node.setLogAlpha(sumLogProb1([
                    trans._from.getLogAlpha() + self.getTransProb(trans) + self.getEmitProb(trans)
                    for trans in node.ins()]))

    def backward(self):
        """Compute log beta_t(s) = log P(e_t+1:T | X_t = s) for every node."""
        T = len(self._time_slots) - 1
        if T < 1:  # no observation
            return
        for node in self._time_slots[T]:
            node.setLogBeta(0.0)
        for t in range(T - 1, -1, -1):
            for node in self._time_slots[t]:
                node.setLogBeta(sumLogProb1([
                    trans._to.getLogBeta() + self.getTransProb(trans) + self.getEmitProb(trans)
                    for trans in node.outs()]))

    def getPseudoCounts(self, counts):
        """Accumulate Baum-Welch expected counts for the current sequence into ``counts``.

        :returns: log P(e_1:T) under the current model.
        """
        p_of_obs = self.obsProb()  # also runs forward()
        self.backward()
        state_counts = counts.getStateCount()
        trans_counts = counts.getTransCount()
        emit_counts = counts.getEmitCount()
        for ts in self._time_slots:
            # P(X_t = s | e_1:T) = alpha_t(s) beta_t(s) / sum_s' alpha_t(s') beta_t(s')
            _sum = sumLogProb1([n.getLogAlpha() + n.getLogBeta() for n in ts])
            for node in ts:
                state_counts.add(node.state(),
                                 node.getLogAlpha() + node.getLogBeta() - _sum)
                for trans in node.ins():
                    # Expected count of node.state() emitting trans._obs via this edge.
                    emit_count = (trans._from.getLogAlpha() + self.getTransProb(trans)
                                  + self.getEmitProb(trans) + node.getLogBeta() - p_of_obs)
                    emit_counts.add(node.state(), trans._obs, emit_count)
                for trans in node.outs():
                    trans_count = (node.getLogAlpha() + self.getTransProb(trans)
                                   + self.getEmitProb(trans) + trans._to.getLogBeta() - p_of_obs)
                    trans_counts.add(node.state(), trans._to.state(), trans_count)
        return p_of_obs

    def viterbi(self, path):
        """Find the most probable state path for the current observation sequence.

        The Transitions of the best path are appended to ``path`` in time
        order (callers pass an empty list).

        :returns: log P(x_1:T, e_1:T) of the best path.
        """
        self._time_slots[0][0].setLogAlpha(0.0)
        for ts in self._time_slots[1:]:
            for node in ts:
                max_prob = -1e20
                best_trans = 0  # 0 marks "no incoming transition"
                for trans in node.ins():
                    log_prob = (trans._from.getLogAlpha() + self.getTransProb(trans)
                                + self.getEmitProb(trans))
                    if best_trans == 0 or max_prob < log_prob:
                        best_trans = trans
                        max_prob = log_prob
                node.setLogAlpha(max_prob)  # logAlpha holds the best score here
                node.setPsi(best_trans)

        # The best node of the last slot ends the best path.
        best = 0
        for node in self._time_slots[-1]:
            if best == 0 or best.getLogAlpha() < node.getLogAlpha():
                best = node

        nd = best
        while nd:
            psi = nd.getPsi()
            if not psi:
                break
            path.append(psi)
            nd = psi._from
        # Back-pointers were collected from the end.  The original manual
        # swap loop left even-length paths partly unreversed and raised
        # IndexError on an empty path.
        path.reverse()
        return best.getLogAlpha()

    def obsProb(self):
        """Return log P(e_1:T) for the current sequence (runs forward()).

        With no observations the probability is 1, so 0.0 is returned.
        """
        if not self._time_slots:
            return 0.0
        self.forward()
        return sumLogProb1([node.getLogAlpha() for node in self._time_slots[-1]])

    def reset(self):
        """Clear all time slots to get ready to deal with another sequence."""
        del self._time_slots[:]

    def updateProbs(self, counts):
        """Replace the model tables with re-estimates from ``counts`` (log space).

        P(to | from) = count(from -> to) / count(from) and
        P(obs | state) = count(state emits obs) / count(state).
        """
        # Fresh tables rather than clear(): dict.clear() would leave the
        # TwoDTable reverse index (_possibleContexts) holding stale entries.
        self._transition = TwoDTable()
        self._emission = TwoDTable()
        state_counts = counts.getStateCount()

        trans_counts = counts.getTransCount()
        for from_state in trans_counts:
            from_count = state_counts.getValue(from_state)
            for to_state, log_count in trans_counts[from_state].items():
                self._transition.add(from_state, to_state, log_count - from_count)

        emit_counts = counts.getEmitCount()
        for state in emit_counts:
            # getValue (not dict.get) so a missing state cannot yield None.
            state_count = state_counts.getValue(state)
            for obs, log_count in emit_counts[state].items():
                self._emission.add(state, obs, log_count - state_count)

    def baumWelch(self, sequence, max_iterations):
        """Train the model on ``sequence`` (lists of observation ids) with Baum-Welch.

        Runs at most ``max_iterations`` iterations and stops early once the
        total log-likelihood improves by less than 1 over the previous one.
        """
        print('Training with Baum-Welch for up to %d  iterations, using %d sequences.'
              % (max_iterations, len(sequence)))
        prev_total_log_prob = None
        for k in range(max_iterations):
            counts = PseudoCounts()
            total_log_prob = 0
            for i, seq in enumerate(sequence):
                for o in seq:
                    self.addObservation(o)
                # accumulate the pseudo counts
                total_log_prob += self.getPseudoCounts(counts)
                self.reset()
                if (i + 1) % 1000 == 0:
                    print('Processed %d sequences' % (i + 1))
            print('Iteration %d total_log_prob = %f' % (k, total_log_prob))
            if (prev_total_log_prob is not None
                    and total_log_prob - prev_total_log_prob < 1):
                break
            prev_total_log_prob = total_log_prob
            self.updateProbs(counts)

    def saveProbs(self, name):
        """Write the model to ``name + '.trans'`` / ``name + '.emit'``.

        If ``name`` is '' the tables are written to stderr instead.
        """
        if name == '':
            stderr.write('transition probalities: \n')
            self._transition.save(stderr, self._str2id)
            stderr.write('-----------------------\n')
            stderr.write('emission probabilities: \n')
            self._emission.save(stderr, self._str2id)
            return
        with open(name + '.trans', 'w') as trans_prob_writer:
            trans_prob_writer.write(self._str2id.getStr(self._init_state))
            trans_prob_writer.write('\n')
            self._transition.save(trans_prob_writer, self._str2id)
        with open(name + '.emit', 'w') as emit_prob_writer:
            self._emission.save(emit_prob_writer, self._str2id)
def testTraining():
    """Re-estimate the ./phone/phone-init1 model on ./phone/phone.train.

    Runs up to 10 Baum-Welch iterations and writes the result to
    ./phone/rphone-init1.trans and ./phone/rphone-init1.emit.
    """
    hmm = Hmm()
    hmm.loadProbs('./phone/phone-init1')
    sequences = list()
    # open() + with: file() is Python 2-only and the handle was never closed.
    with open('./phone/phone.train', 'r') as input_reader:
        hmm.readSeqs(input_reader, sequences)
    hmm.baumWelch(sequences, 10)
    hmm.saveProbs('./phone/rphone-init1')
def testPrediction():
    """Decode each sequence of ./phone/pos/phone.train with the ./phone/pos/pos model.

    For every sequence, prints P(best path | observations) and the
    observation/state pairs along the Viterbi path.
    """
    hmm = Hmm()
    hmm.loadProbs('./phone/pos/pos')
    sequences = list()
    # open() + with: file() is Python 2-only and the handle was never closed.
    with open('./phone/pos/phone.train', 'r') as input_reader:
        hmm.readSeqs(input_reader, sequences)
    for seq in sequences:
        for o in seq:
            hmm.addObservation(o)
        path = list()
        joint_prob = hmm.viterbi(path)
        # viterbi gives log P(path, obs); subtracting log P(obs) yields log P(path | obs).
        print('P(path)=%f ' % exp(joint_prob - hmm.obsProb()))
        print('path:  ')
        for trans in path:
            if trans == 0:
                continue
            print('%s \t %s' % (hmm.getStr(trans._obs), hmm.getStr(trans._to.state())))
        hmm.reset()
if __name__ == "__main__":
    # Run the decoding demo first, then the training demo.
    for demo in (testPrediction, testTraining):
        demo()
	

抱歉!评论已关闭.