如何解决tensorflow反复调用预测函数导致内存爆炸的问题?

各位大神,这是我的预测函数,输入的是batch_size50030的array,输出的是batch_size*2的包含预测分类和概率的array。
这个函数极其消耗内存,监测发现内存总是突然暴涨然后迅速下降,迅速下降应该是释放函数内的临时变量,但是整体来看的话还是迅速在占用内存(下降的没有暴涨的多),导致这个函数循环调用30多次,我16G的内存就爆掉了,请问该如何解决?
我猜测是由于with tf.Graph().as_default():导致什么东西没有被当做临时变量释放掉,然后不用这种写法又会导致ckpt无法重复调用,报Key <variable_name_#n> not found in checkpoint,请问这个又该如何解决?

def evaluate_all_oneday_data(data,batch_size):
    with tf.Graph().as_default():
        BATCH_SIZE = batch_size
        N_CLASSES = 7
        n_inputs=30
        n_steps=500
        n_hidden_units=100
        
        tf_data=tf.reshape(data,[BATCH_SIZE,n_steps,n_inputs])
        
        #inference是个RNN的网络,最后的全连接层没有加act_fun,所以这里加了softmax。
        logit = tf.nn.softmax(model_1206.inference((tf_data), BATCH_SIZE, N_CLASSES,n_inputs,n_steps,n_hidden_units))        
        x = tf.placeholder(tf.float32, shape=[BATCH_SIZE,500, 30])        
        logs_train_dir = 'G:\\first\\train501\\'
        
        saver = tf.train.Saver()
        
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                #global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                pass
            prediction = sess.run(logit, feed_dict={x: data})
            pred = np.column_stack((np.argmax(prediction,axis=1),np.max(prediction,axis=1)))
    return pred
阅读 11.3k
1 个回答

你可以在外面加载一下模型,预测的时候只是调用模型好了,不要每次都加载,应该就可以避免这个问题。我一般是加载一次模型,然后预测所用数据,完全不用再加载。

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题