Is Loading In Eager Tensorflow Broken Right Now?
Solution 1:
As it turns out, TensorFlow restores a checkpoint in one of three ways, depending on what is being checkpointed:
1. The checkpointed object is just a variable. This is restored immediately upon calling checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path)).
2. The checkpointed object is a model with its input shape defined. This is also restored immediately.
3. The checkpointed object is a model without its input shape defined. This is where the behaviour changes: TensorFlow does a "delayed" restore and will NOT restore the model weights until input is passed to the model.
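Cases 1 and 3 can be reproduced side by side without Keras. The sketch below is a minimal illustration, assuming TensorFlow 2.x (eager by default) rather than the TF 1.x API used in this post. LazyLinear is a hypothetical tf.Module that only creates its weight on its first call, standing in for a Keras model with no input shape; its weight is assigned from the checkpoint only once it is created.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution on by default


class LazyLinear(tf.Module):
    """Hypothetical stand-in for a Keras model with no input shape:
    its weight does not exist until the first call."""

    def __init__(self, units, init_value):
        super().__init__()
        self.units = units
        self.init_value = init_value
        self.w = None  # nothing to restore into yet

    def __call__(self, x):
        if self.w is None:
            # The weight's shape depends on the input, so it is built lazily.
            self.w = tf.Variable(tf.fill([x.shape[-1], self.units], self.init_value))
        return tf.matmul(x, self.w)


ckpt_dir = tempfile.mkdtemp()
x = tf.ones([1, 4])

# Case 1: a plain variable is assigned as soon as restore() is called.
var_path = tf.train.Checkpoint(v=tf.Variable(3.0)).save(os.path.join(ckpt_dir, "var"))
v = tf.Variable(0.0)
tf.train.Checkpoint(v=v).restore(var_path)
print(v.numpy())  # 3.0

# Case 3: the module's weight is created lazily, so its restore is deferred.
trained = LazyLinear(units=2, init_value=5.0)
trained(x)  # builds trained.w, filled with 5.0
model_path = tf.train.Checkpoint(model=trained).save(os.path.join(ckpt_dir, "model"))

fresh = LazyLinear(units=2, init_value=0.0)
status = tf.train.Checkpoint(model=fresh).restore(model_path)
w_was_missing = fresh.w is None
print(w_was_missing)  # True: the saved weight is queued, not yet applied

fresh(x)  # creating fresh.w triggers the queued restore
print(fresh.w.numpy())  # filled with 5.0, not 0.0
status.assert_consumed()  # passes now that every saved value has been matched
```

The design choice here is deliberate: the weight's shape is only known from the input, which is exactly why TensorFlow cannot restore it earlier.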
Here is an example:
import os
import tensorflow as tf
import numpy as np
# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()
# Create model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])
# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)
# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
    logits = model(img)
    loss = tf.losses.mean_squared_error(truth, logits)
# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)
# Save model
checkpoint_path = './ckpt/'
checkpoint.save(os.path.join(checkpoint_path, 'ckpt'))  # writes ./ckpt/ckpt-1.*
# Check if weights updated
print("Are weights empty after training?", model.weights == [])
# Reset model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])
# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))
# This next line is REQUIRED to restore
# model(img)
print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()
With output:
Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6256b4ddd8>
Traceback (most recent call last):
  File "test.py", line 58, in <module>
    status.assert_consumed()
  File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  name: "VARIABLE_VALUE"
  full_name: "sequential/conv2d/kernel"
  checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}
However, uncommenting the line model(img) produces the following output:
Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7ff62320fe48>
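So, for a model whose input shape is not defined, input data has to be passed through it (which creates its variables) before the checkpointed weights are actually restored; until then, the restore is only queued.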
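In practice this means making sure the variables exist when restore() runs: either give the model an input shape (case 2), or push one dummy batch through it first. The sketch below shows the dummy-batch route under stated assumptions: TensorFlow 2.x, and a hypothetical LazyLinear tf.Module standing in for a Keras model defined without an input shape.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution on by default


class LazyLinear(tf.Module):
    """Hypothetical stand-in for a Keras model with no input shape:
    its weight is created on the first call."""

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.w = None

    def __call__(self, x):
        if self.w is None:
            self.w = tf.Variable(tf.random.normal([x.shape[-1], self.units]))
        return tf.matmul(x, self.w)


ckpt_dir = tempfile.mkdtemp()
dummy = tf.zeros([1, 4])  # only the trailing dimension matters for building

trained = LazyLinear(units=2)
trained(dummy)
tf.train.Checkpoint(model=trained).save(os.path.join(ckpt_dir, "ckpt"))

# Build the new model's variables first, so restore() has something to assign
# into and no restoration is left pending.
fresh = LazyLinear(units=2)
fresh(dummy)
status = tf.train.Checkpoint(model=fresh).restore(tf.train.latest_checkpoint(ckpt_dir))
status.assert_consumed()  # succeeds right away, unlike the delayed case

print(np.allclose(fresh.w.numpy(), trained.w.numpy()))  # True
```

Building first also makes assert_consumed() a useful sanity check immediately after restore(), instead of only after the first real forward pass.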
References:
https://www.tensorflow.org/alpha/guide/checkpoints#delayed_restorations
https://github.com/tensorflow/tensorflow/issues/27937