
Is Loading In Eager Tensorflow Broken Right Now?

Weights in classes inheriting from tf.keras.Model seem unable to load at the moment. I am unable to load the weights from Example() outside of the class using checkpointing, so I t

Solution 1:

As it turns out, TensorFlow restores checkpoints in three different ways, depending on what is being checkpointed.

  1. The checkpointed object is just a variable. This is restored immediately upon calling checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path)).

  2. The checkpointed object is a model with its input shape defined (for example, input_shape passed to its first layer), so its weights already exist. This is also restored immediately.

  3. The checkpointed object is a model without an input shape defined. This is where the behaviour changes: because such a model only creates its weights on its first call, TensorFlow does a "delayed" restore and will NOT restore the model weights until input is passed to the model.
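To make the difference between the immediate and the delayed cases concrete, here is a toy, pure-Python model of the mechanism. `ToyCheckpoint` and `ToyLayer` are hypothetical classes written only for illustration, not TensorFlow's API: a restore assigns saved values to variables that already exist (cases 1 and 2) and keeps the rest pending until the layer creates its variable on its first call (case 3).

```python
class ToyCheckpoint:
    """Hypothetical stand-in for an object-based checkpoint (not TensorFlow's API)."""

    def __init__(self, saved):
        # Saved values that have not yet been matched to a live variable
        self.pending = dict(saved)

    def restore_into(self, layer):
        layer.checkpoint = self
        if layer.value is not None and layer.name in self.pending:
            # Variable already exists (cases 1 and 2): restore immediately
            layer.value = self.pending.pop(layer.name)
        # Otherwise the value stays pending until the layer creates its variable (case 3)

    def assert_consumed(self):
        if self.pending:
            raise AssertionError("Unresolved object in checkpoint: %s" % sorted(self.pending))


class ToyLayer:
    """Hypothetical layer that creates its variable up front only if input_dim is known."""

    def __init__(self, name, input_dim=None):
        self.name = name
        self.value = None
        self.checkpoint = None
        if input_dim is not None:
            self._create_variable(input_dim)

    def _create_variable(self, input_dim):
        self.value = [0.0] * input_dim  # fresh initial value
        if self.checkpoint is not None and self.name in self.checkpoint.pending:
            # Deferred restore: the pending saved value is applied as soon as the variable exists
            self.value = self.checkpoint.pending.pop(self.name)

    def __call__(self, x):
        if self.value is None:
            self._create_variable(len(x))  # the input shape is only known now
        return [a * b for a, b in zip(x, self.value)]


saved = {"conv": [2.0, 3.0]}

# Case 2: input shape known up front, so the restore takes effect immediately
built = ToyLayer("conv", input_dim=2)
ToyCheckpoint(saved).restore_into(built)
print(built.value)  # [2.0, 3.0]

# Case 3: no input shape, so the restore is deferred until the first call
lazy = ToyLayer("conv")
ckpt = ToyCheckpoint(saved)
ckpt.restore_into(lazy)
print(lazy.value)  # None
lazy([1.0, 1.0])
print(lazy.value)  # [2.0, 3.0]
ckpt.assert_consumed()
```

Nothing is lost in case 3: the saved values simply wait until the matching variable exists, which is why checking the weights straight after the restore call can be misleading.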

Here is an example of the third case:

import os

# Disable logging (set before importing TensorFlow so it takes effect)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf

# This example uses the TensorFlow 1.x API, where eager execution must be enabled explicitly
tf.enable_eager_execution()

# Create model; no input shape is given, so its weights are only created on the first call
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])

# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)

# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)

# Train (the first call to the model creates its weights)
with tf.GradientTape() as tape:
    logits = model(img)
    loss = tf.losses.mean_squared_error(truth, logits)

# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)

# Save model
checkpoint_path = './ckpt/'
checkpoint.save(os.path.join(checkpoint_path, 'ckpt'))

# Check if weights were created
print("Are weights empty after training?", model.weights == [])

# Reset model (again without an input shape, so it has no weights yet)
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])

# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint (deferred, because the weights don't exist yet)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))

# This next line is REQUIRED to restore: calling the model creates its weights,
# and the pending checkpoint values are assigned to them as they are created
# model(img)
print("Are weights empty after restoring from checkpoint?", model.weights == [])

print(status)
# Raises AssertionError while values in the checkpoint are still unmatched
status.assert_consumed()

With output:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
< object at 0x7f6256b4ddd8>
Traceback (most recent call last):
  File "", line 58, in <module>
  File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/", line 1013, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  full_name: "sequential/conv2d/kernel"
  checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"

However, uncommenting the line model(img) will produce the following output:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
< object at 0x7ff62320fe48>

So for a model without a defined input shape, input data must be passed through it (which creates its weights) before the values from the checkpoint are actually restored.
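For readers on TensorFlow 2.x, the same delayed behaviour can be reproduced without Keras. This is a sketch assuming TensorFlow 2.x; `LazyScale` is a hypothetical tf.Module that, like a Keras model built without an input shape, creates its variable only on its first call. Restoring into a fresh instance leaves the saved value pending (so `status.assert_consumed()` fails) until that first call creates the variable.

```python
import os
import tempfile

import tensorflow as tf


class LazyScale(tf.Module):
    """Hypothetical module that, like a Keras model built without an input
    shape, creates its variable only on its first call."""

    def __init__(self):
        super().__init__()
        self.scale = None

    def __call__(self, x):
        if self.scale is None:
            # Assigning a new variable to a tracked attribute is the point at
            # which a pending (deferred) checkpoint value is applied to it.
            self.scale = tf.Variable(tf.ones([x.shape[-1]]), name="scale")
        return x * self.scale


# Build and "train" (here: just set) a module, then save it.
trained = LazyScale()
trained(tf.ones([1, 2]))
trained.scale.assign([2.0, 3.0])
save_path = tf.train.Checkpoint(model=trained).save(
    os.path.join(tempfile.mkdtemp(), "ckpt"))

# Restore into a fresh, unbuilt module: there is no variable to assign to yet.
fresh = LazyScale()
status = tf.train.Checkpoint(model=fresh).restore(save_path)

pending_before_call = False
try:
    status.assert_consumed()
except AssertionError:
    pending_before_call = True  # the saved value has not been used yet
print("Restore pending before first call?", pending_before_call)

# The first call creates the variable, and the saved value is assigned to it.
fresh(tf.ones([1, 2]))
print("Restored scale:", fresh.scale.numpy())
status.assert_consumed()  # passes now that every saved value has been matched
```

Giving the model its input shape up front (case 2), or calling it once on dummy input before restoring, means the variables already exist, so the values are assigned as soon as `restore` runs.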

