【Tensorflow】数据及模型的保存和恢复

如果你是一个深度学习的初学者,那么我相信你应该会跟着教材或者视频敲上那么一遍代码,搭建最简单的神经网络去完成针对 MNIST 数据库的数字识别任务。通常,随意构建 3 层神经网络就可以很快地完成任务,得到比较高的准确率。这时候,你信心大增,准备挑战更难的任务。

你准备进行针对彩色图片做类型识别,那么选 CIFAR-10 就好了。于是,你也基于自己的理解,搭建了一个较为复杂的神经网络,于是,问题可能来了。你自行搭建的神经网络的准确率实在是太低了,有可能 30% 都达不到,没有办法,你只能做各种调试,加深网络,增大卷积核的数量,降低学习率等等,你会发现识别效果会得到改善,但是,训练时间却被拉长了,如果你自己学习的电脑没有 GPU 或者是 GPU 性能不好,那么训练的时间会让你绝望,因此,你渴望神经网络训练的过程可以保存和重载,就像下载软件断点续传一般,这样你就可以在晚上睡觉的时候,让机器训练,早上的时候保存结果,然后下次训练时又在上一次基础上进行。

Tensorflow 是当前最流行的机器学习框架,它自然支持这种需求。

Tensorflow 通过 tf.train.Saver 这个模块进行数据的保存和恢复。它有 2 个核心方法。

save()

restore()

顾名思义,save() 就是用来保存变量,restore() 就是用来恢复的。

它们的用法非常简单。下面,我们用示例来说明。

假设我们程序的计算图是 a * b + c
在这里插入图片描述

a、b、d、e 都是变量,现在要保存它们的值,怎么用 Tensorflow 的代码实现呢?

数据的保存

import tensorflow as tf

a = tf.get_variable("a",[1])
b = tf.get_variable("b",[1])
c = tf.get_variable("c",[1])


d = tf.multiply(a,b,name="d")

e = tf.add(d,c,name="e")

saver = tf.train.Saver()

创建标量,然后创建 Saver() 对象就好了。

接下来怎么保存这些变量呢?

def test_save(saver):

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        saver.save(sess,"model/weights")
		print("a %f" % a.eval())
        print("b %f" % b.eval())
        print("c %f" % c.eval())
        print("e %f" % e.eval())
        
test_save(saver)

先初始化变量,然后调用 Saver.save() 方法就好了,第一个参数是 session 对象,第二个参数是变量存放的路径。

运行程序后,当前目录下会生成存储文件。
在这里插入图片描述

并且,程序代码有打印变量存储时本身的值。

a -1.723781
b 0.387082
c -1.321383
e -1.988627

现在编写程序代码让它恢复这些值。

数据的恢复

同样很简单。

def test_restore(saver):

    with tf.Session() as sess:
        saver.restore(sess, "model/weights")

        print("a %f" % a.eval())
        print("b %f" % b.eval())
        print("c %f" % c.eval())
        print("e %f" % e.eval())
        
test_restore(saver)

调用 Saver.restore() 方法就可以了,同样需要传递一个 session 对象,第二个参数是被保存的模型数据的路径。

当调用 Saver.restore() 时,不需要初始化所需要的变量。

大家可以仔细比较保存时的代码,和恢复时的代码。

运行程序后,会在控制台打印恢复过来的变量。

a -1.723781
b 0.387082
c -1.321383
e -1.988627

这和之前的值,一模一样,这说明程序代码有正确保存和恢复变量。

上面是最简单的变量保存例子,在实际工作当中,模型当中的变量会更多,但基本上的流程不会脱离这个最简化的流程。

frank909 CSDN认证博客专家 CV(computer vision)
爱阅读的程序员,专注于技术思考和分享。关注架构设计、Android 开发、AI、数学、自动驾驶领域,个人公号:Frankcall
已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页
实付 19.90元
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值