## TensorFlow 使用 pb 文件保存（恢复）模型计算图和参数实例详解

2020年02月18日 编程语言 ⁄ 共 3937字 ⁄ 字号 评论关闭

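`graph_util.convert_variables_to_constants` 会把当前 session 计算图中的变量替换为保存其当前取值的常量，并返回一个新的 `GraphDef`；再调用该 `GraphDef` 的 `SerializeToString()`，即可把计算图串行化成一个字节流（二进制）写入 pb 文件。这个函数包含三个参数：参数1：当前活动的 session，它含有各变量的当前值；参数2：计算图的定义，即 `sess.graph_def`；参数3：需要保留的输出节点名称列表（如 `['sum_operation']`），只有这些输出节点所依赖的部分会被保留下来。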

# Freeze the graph: replace every variable with a constant holding its
# current value in `sess`, keeping the part of the graph needed to compute
# the listed output node(s).
constant_graph = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, ['sum_operation'])
# Serialize the resulting GraphDef to bytes and write it to the .pb file.
with open(pbName, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

# Read the serialized GraphDef back from the .pb file.
graph0 = tf.GraphDef()
with open(pbName, mode='rb') as f:
    graph0.ParseFromString(f.read())
# name='' imports the nodes without a name prefix, so the original tensor
# names (e.g. 'sum_operation:0') can be looked up unchanged.
tf.import_graph_def(graph0, name='')

import csv

import tensorflow as tf
from tensorflow.python.framework import graph_util

pbName = 'graphA.pb'


def _dump_operations(operations, path):
    """Write one CSV row per operation: name, type, then each output tensor."""
    # csv.writer quotes fields when needed: str(tensor) itself contains
    # commas (e.g. ', dtype=...'), which a hand-joined line would split
    # into extra columns.
    with open(path, 'w', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(['name', 'type', 'output'])
        for op in operations:
            writer.writerow([op.name, op.type] + [str(t) for t in op.outputs])


def graphCreate():
    """Build a small demo graph, run it, and freeze it into ``pbName``.

    The graph computes ``var1 + var2 + var3 + var4`` as the op
    'sum_operation', where ``var1`` is an int32 placeholder and
    ``var2``..``var4`` are variables (20, 30, 40).  ``var4`` is reassigned
    to 1000 before freezing, so the frozen graph stores that value.

    Side effects: writes ``operation.csv`` (one row per graph operation),
    writes the frozen graph to ``pbName``, and prints the variable names and
    the result of evaluating 'sum_operation' with ``var1 = 1``.
    """
    # Build in a fresh graph so repeated calls do not add auto-renamed
    # duplicates (e.g. 'sum_operation_1') to the global default graph, which
    # would make the export below pick up a stale 'sum_operation'.
    with tf.Graph().as_default(), tf.Session() as sess:
        var1 = tf.placeholder(tf.int32, name='var1')
        # name='var2' sets the *operation* name; the tensor it returns is
        # named by appending the output index, i.e. 'var2:0'.
        var2 = tf.Variable(20, name='var2')
        var3 = tf.Variable(30, name='var3')
        var4 = tf.Variable(40, name='var4')
        var4op = tf.assign(var4, 1000, name='var4op1')
        # Extra variable that is not an input of 'sum_operation'; it still
        # shows up in the variable listing below.  (Not bound to `sum`, which
        # would shadow the builtin.)
        tf.Variable(4, name='sum')
        total = tf.add(var1, var2, name='var1_var2')
        total = tf.add(total, var3, name='sum_var3')
        sumOps = tf.add(total, var4, name='sum_operation')

        _dump_operations(tf.get_default_graph().get_operations(),
                         'operation.csv')

        # Tensors produced by ops such as tf.Variable / tf.add are named
        # '<op name>:<output index>', e.g. 'var2:0'.
        for var in tf.global_variables():
            print('variable=> ', var.name)

        sess.run(tf.global_variables_initializer())
        sess.run(var4op)  # var4 becomes 1000 before freezing
        print('sum_operation result is Tensor ',
              sess.run(sumOps, feed_dict={var1: 1}))

        # Replace variables with constants holding their current values and
        # keep what 'sum_operation' needs, then serialize to the .pb file.
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['sum_operation'])
        with open(pbName, mode='wb') as f:
            f.write(constant_graph.SerializeToString())


def graphGet():
    """Load the frozen graph from ``pbName`` and evaluate 'sum_operation'.

    Feeds ``var1 = 1`` and prints the tensor handle and its computed value.
    """
    print("start get:")
    with tf.Graph().as_default():
        graph0 = tf.GraphDef()
        with open(pbName, mode='rb') as f:
            graph0.ParseFromString(f.read())
        # name='' imports without a name prefix, so the original tensor
        # names ('var1:0', 'sum_operation:0', ...) resolve unchanged.
        tf.import_graph_def(graph0, name='')
        with tf.Session() as sess:
            # The frozen graph holds constants rather than variables, so no
            # variable initializer is needed here.
            v1 = sess.graph.get_tensor_by_name('var1:0')
            sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
            print('sumTensor is : ', sumTensor)
            print(sess.run(sumTensor, feed_dict={v1: 1}))


if __name__ == '__main__':
    graphCreate()
    graphGet()