Saving and restoring TensorFlow variables and graphs

In [3]:
import tensorflow as tf
import numpy as np

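Remember:

1) Non-variable tensors you want to look up after restoring should be added to a collection (e.g. with `tf.add_to_collection`); otherwise you must know their exact names, such as `'x:0'`.

2) Variables are added to the `variables` / `trainable_variables` collections automatically, and `tf.train.Saver()` saves them by default.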
In [4]:
# Build a tiny graph computing x = w @ b, where b is fed at run time.
tf.reset_default_graph()

# A Variable, so tf.train.Saver will save and restore it.
w = tf.Variable(
    initial_value=tf.random_normal(shape=(5, 5), seed=101),
    name="w",
)
# A placeholder is not saved in the checkpoint; it must be fed on every run.
b = tf.placeholder(dtype=tf.float32, shape=(5,), name="b")
# expand_dims turns the (5,) vector into a (5, 1) column, so x has shape (5, 1).
x = tf.matmul(w, tf.expand_dims(input=b, axis=1), name="x")
# Register x under a collection key so it can be found again after a restore.
tf.add_to_collection(name='output', value=x)
In [5]:
# Created after the variables exist; with no var_list it covers all saveable variables (here: w).
saver = tf.train.Saver(var_list=None)
In [6]:
import glob
import os

# Start each run from an empty checkpoint directory.
# The previous `!rm models/*` failed on a fresh checkout (no `models/` yet),
# and tf.train.Saver.save expects the parent directory to already exist.
# Creating it first makes the notebook survive Restart & Run All on any OS.
os.makedirs('models', exist_ok=True)
for stale_path in glob.glob(os.path.join('models', '*')):
    # Like `rm` without -r: only plain files are removed, subdirectories are kept.
    if os.path.isfile(stale_path):
        os.remove(stale_path)
In [7]:
init = tf.global_variables_initializer()
with tf.Session() as sess:
    # Variables have no value until the initializer op has run.
    sess.run(init)
    # Feeding b = ones makes each entry of x the sum of the matching row of w.
    W, o = sess.run(fetches=[w, x], feed_dict={b: np.ones((5,))})
    print(W, o)
    # Writes the checkpoint and, next to it, the .meta graph file used after the restart.
    saver.save(sess, save_path='models/my_first_model.ckpt')
[[-2.3159974  -0.07912122 -0.8998717   0.31664512 -1.1119006 ]
 [ 1.3949379   0.49928963 -1.4618131  -0.7174765  -0.4182692 ]
 [-2.1247718   1.751722   -2.4979882  -0.38874432 -0.49987432]
 [-0.29406643  0.18566869 -0.14782837  0.5535406  -0.51702005]
 [ 0.32753068  0.259901   -1.3543104   1.0400561   0.855376  ]] [[-4.0902452 ]
 [-0.70333123]
 [-3.7596567 ]
 [-0.21970558]
 [ 1.1285534 ]]

Restoring in the same session (the graph is still in memory)

In [9]:
with tf.Session() as sess:
    # No initializer needed: restoring assigns the saved values to the variables.
    saver.restore(sess, save_path='models/my_first_model.ckpt')
    W, o = sess.run(fetches=[w, x], feed_dict={b: np.ones((5,))})
    print(W, o)
INFO:tensorflow:Restoring parameters from models/my_first_model.ckpt
[[-2.3159974  -0.07912122 -0.8998717   0.31664512 -1.1119006 ]
 [ 1.3949379   0.49928963 -1.4618131  -0.7174765  -0.4182692 ]
 [-2.1247718   1.751722   -2.4979882  -0.38874432 -0.49987432]
 [-0.29406643  0.18566869 -0.14782837  0.5535406  -0.51702005]
 [ 0.32753068  0.259901   -1.3543104   1.0400561   0.855376  ]] [[-4.0902452 ]
 [-0.70333123]
 [-3.7596567 ]
 [-0.21970558]
 [ 1.1285534 ]]

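Restoring from a different notebook or process

Restart the kernel first, so that nothing defined above remains in memory and the graph must be rebuilt from the saved files.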

In [1]:
import tensorflow as tf
import numpy as np

Load the graph structure from the saved `.meta` file

In [18]:
# Start from an empty graph, then rebuild it from the saved .meta file.
# import_meta_graph returns a Saver for the re-imported graph.
tf.reset_default_graph()
saver = tf.train.import_meta_graph(
    meta_graph_or_file='models/my_first_model.ckpt.meta')
In [19]:
# Kept open (not a `with` block) because the following cells reuse `sess`.
# NOTE(review): call sess.close() once you are finished with it.
sess = tf.Session()
# Restore the newest checkpoint found in models/ (here models/my_first_model.ckpt).
saver.restore(sess, save_path=tf.train.latest_checkpoint(checkpoint_dir='models/'))
INFO:tensorflow:Restoring parameters from models/my_first_model.ckpt

Check that the collections, including our `output` collection, were restored

In [20]:
# List the collection keys of the restored graph. The custom 'output' collection
# comes back alongside the automatic variable collections.
sess.graph.get_all_collection_keys()
Out[20]:
['variables', 'trainable_variables', 'output']
In [21]:
# The 'output' collection holds the tensor `x`, so it can be fetched without
# knowing its name.
sess.graph.get_collection('output')
Out[21]:
[<tf.Tensor 'x:0' shape=(5, 1) dtype=float32>]

Run the restored graph

In [25]:
# After a restart there are no Python handles to w, b or x, so fetch and feed
# tensors by name ('<op name>:0').
W, o = sess.run(
    fetches=['w:0', 'x:0'],
    feed_dict={'b:0': np.ones((5,))},
)
print(W, o)
[[-2.3159974  -0.07912122 -0.8998717   0.31664512 -1.1119006 ]
 [ 1.3949379   0.49928963 -1.4618131  -0.7174765  -0.4182692 ]
 [-2.1247718   1.751722   -2.4979882  -0.38874432 -0.49987432]
 [-0.29406643  0.18566869 -0.14782837  0.5535406  -0.51702005]
 [ 0.32753068  0.259901   -1.3543104   1.0400561   0.855376  ]] [[-4.0902452 ]
 [-0.70333123]
 [-3.7596567 ]
 [-0.21970558]
 [ 1.1285534 ]]

Inspecting checkpoints

In [26]:
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
In [28]:
# Dump every tensor stored in the checkpoint. Only `w` is listed: the
# placeholder `b` and the result `x` are not saved.
chkp.print_tensors_in_checkpoint_file(
    file_name="models/my_first_model.ckpt",
    tensor_name='',
    all_tensors=True,
    all_tensor_names=True,
)
tensor_name:  w
[[-2.3159974  -0.07912122 -0.8998717   0.31664512 -1.1119006 ]
 [ 1.3949379   0.49928963 -1.4618131  -0.7174765  -0.4182692 ]
 [-2.1247718   1.751722   -2.4979882  -0.38874432 -0.49987432]
 [-0.29406643  0.18566869 -0.14782837  0.5535406  -0.51702005]
 [ 0.32753068  0.259901   -1.3543104   1.0400561   0.855376  ]]