Tic商业评论

关注微信公众号【站长自定义模块】,定时推送前沿、专业、深度的商业资讯。

 找回密码
 立即注册

QQ登录

只需一步,快速开始

微信登录

微信扫码,快速开始

  • QQ空间
  • 回复
  • 收藏

TensorFlow2 保存和恢复模型

lijingle 深度学习框架 2022-1-27 21:10 1837人围观

Keras API 提供内置类,用于在模型拟合期间定期保存模型。 为了保存模型并在以后恢复它,我们可以创建一个回调 ModelCheckPoint 传递给 model.fit,模型会定期保存。


在上面的例子中,模型被保存了 epoch。 在下面的配置中,save_best_only 为 True。 因此,仅当验证损失目前为止最低时才保存模型。


在上面的配置中,新的checkpoint会覆盖旧的checkpoint,因为它使用相同的checkpoint名称。 这是另一个示例,其中checkpoint文件包含时代编号,因此不会被覆盖。


这是training_2目录下保存的内容。 上面的配置每 5 个 epoch 保存一次模型。


保存整个模型与仅仅保存权重

保存模型有两种选择——仅权重或包括训练状态以及模型架构。 如果在创建 ModelCheckpoint 时 save_weights_only 标志为 True,则模型将保存为 model.save_weights(filepath)。 这仅保存模型权重。 如果为 False,则以 SavedModel 格式保存完整模型。

默认情况下,每个 epoch 都会保存一个模型。 但它可以被 ModelCheckpoint 中的 save_freq 覆盖。

save_freq= int(NUM_OF_EPOCHS * STEPS_PER_EPOCH)


model.save_weights

如果模型只保存权重,我们需要先实例化一个新模型,然后再恢复权重。 很可能,我们调用原始 Python 代码(在我们的示例中为 create_model)来创建模型实例。 然后我们用model.load_weights加载模型的权重。 最新的检查点可以通过 tf.train.latest_checkpoint 定位。


如果没有 ModelCheckpoint 回调,我们可以调用 model.save_weights 手动保存模型权重。


model.save

为了保存完整的模型,我们使用 model.save(filepath) 将其保存为 SavedModel。 正如稍后解释的,它包含优化器和数据集迭代器的状态,以便整个训练可以从最后保存的点恢复。 由于还保存了模型架构和配置,因此可以直接恢复模型,而无需创建模型实例。


保存模型时,模型的所有 tf.Variable 都会保存,所有 @tf.function 注释的方法也保存为图形。 下面是保存为 dnn_model 的模型。


我们不再需要原始的 Python 代码。 TF 直接执行图。 事实上,这减少了生产部署期间可能出现的错误。 下面是 my_model 现在包含的目录:


但这需要 @tf.function 注释覆盖任何自定义层所需的所有方法。


CheckpointManager

如果我们想使用较低级别的 Keras API,我们也可以使用 CheckpointManager 来保存模型。 下面的代码是用于创建玩具数据集和模型的样板代码。 它还包含训练步骤的代码。


为了保存Checkpoint,我们创建一个带有Checkpoint的 CheckpointManager。 该Checkpoint包含模型、优化器、训练状态(步骤)和数据集迭代器。 所以在训练开始之前,我们可以用最新存储的Checkpoint来恢复Checkpoint。 这会加载模型权重并恢复优化器、数据集迭代器和训练步骤的状态。 简而言之,我们在上次保存模型时恢复训练状态——而不仅仅是模型权重。


Restore a training session

最后,我们将更深入地了解 SavedModel 中保存的内容以及如何恢复训练session。 上一节中的检查点不仅仅保存模型参数。 它还包含优化器的状态(学习率、衰减)和与可训练参数相关的任何参数,例如动量 (m)。 它还包含训练的状态,包括训练步骤和附加到检查点文件名称的保存计数器。 因此,当检查点恢复时,它也会恢复优化器的状态和检查点的状态。 它还检查数据集迭代器的进度。 因此,迭代器可以从它停止的地方恢复,而不是从头开始。


checkpoint.restore 从Checkpoint对象恢复任何匹配路径的变量值,即我们可以只加载Checkpoint的一个子部分。 例如,我们可以只重新创建模型的一部分,在下面的示例中,我们只是从 self.l1  dense层检查点加载偏差权重。


复制权重

下面的代码将权重从一层复制到另一层。


在下面的代码中,即使functional_model_with_dropout 与functional_model 相比包含额外的dropout 层,dropout 层也不包含任何权重。 所以我们仍然可以使用 model.set_weights 从functional_model 复制权重到functional_model_with_dropout。



路过

雷人

握手

鲜花

鸡蛋
我有话说......
电话咨询: 135xxxxxxx
关注微信