tf.estimator.train_and_evaluate
-
tf.estimator.train_and_evaluate简介
train_and_evaluate
API用来train然后evaluate一个estimator,调用如下:tf.estimator.train_and_evaluate( estimator, train_spec, eval_spec )
除了train和evaluate,还提供了模型导出功能
函数参数
estimator
: 一个Estimator
实例train_spec
: 一个TrainSpec
实例,用来配置训练过程eval_spec
: 一个TestSpec
实例,用来配置评估过程,(可选)模型的导出
返回值
Estimator.evaluate
的结果
为什么要用
train_and_evaluate
呢?官方文档的说法是这个函数可以保持本地和分布式的配置一致性,使用Estimator
和train_and_evaluate
编写的程序同时支持本地、集群上的训练,而不需要修改任何代码。当然,方便的背后一般都有代价。为了保证代码在本地和集群上都可以正常终止,所以只能使用
Estimator
的max_steps
参数设定终止条件。所以,如果想使用别的方式终止训练,可能就需要一些“技巧”了。参数说明
train_spec
参数接收一个TrainSpec
实例# TrainSpec的参数 __new__( cls, # 这个参数不用指定,忽略即可。 input_fn, max_steps=None, hooks=None )
其中:
input_fn
:
参数用来指定数据输入。max_steps
:
参数用来指定训练的最大步数,这是训练的唯一终止条件。hooks
:
参数用来挂一些tf.train.SessionRunHook
,用来在session
运行的时候做一些额外的操作,比如记录一些TensorBoard
日志什么的。
eval_spec
参数接收一个EvalSpec
实例,相比TrainSpec
,EvalSpec
的参数多很多。因为EvalSpec
不仅可以指定评估过程,还可以指定导出模型的功能(可选)__new__( cls, # 这个参数不用指定,忽略即可。 input_fn, steps=100, # 评估的迭代步数,如果为None,则在整个数据集上评估。 name=None, hooks=None, exporters=None, start_delay_secs=120, throttle_secs=600 )
其中:
input_fn
:
含义同2.1。steps
:
用来指定评估的迭代步数,如果为None
,则在整个数据集上评估。name
:
如果要在多个数据集上进行评估,通过name
参数可以保证不同数据集上的评估日志保存在不同的文件夹中,从而区分不同数据集上的评估日志。
不同的评估日志保存在独立的文件夹中,在 TensorBoard 中从而独立的展现。hooks
:
含义同2.1exporters
:
一个tf.estimator.export
模块中的类的实例。start_delay_secs
:
调用train_and_evaluate
函数后,多少秒之后开始评估。第一次评估发生在start_delay_secs + throttle_secs
秒后。throttle_secs
:
多少秒后又开始评估,如果没有新的checkpoints
产生,则不评估,所以这个间隔是最小值。