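Previously, we trained a CIFAR model on grayscale images. Now we want to reuse that model for RGB images. The kernel of the first convolution depends on the number of input channels, so it cannot be carried over; we therefore reuse the weights from the 2nd CNN layer onwards.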
Here, we'll demonstrate this using tf.layers.conv2d, so that we can leverage the convenience of its regularizer arguments. Otherwise, this can be done more easily with tf.nn.conv2d.
import tensorflow as tf
import numpy as np
The output shapes and data types of the iterator have to be fixed up front.
# Bare expression (notebook display) of the two static shapes that are passed
# to the iterator: (None, 32, 32, 3) for the model input and (None, 1) for
# the label.
tf.TensorShape((None,32,32,3)),tf.TensorShape((None,1))
Create a feedable iterator
# Scalar string placeholder that is fed with an iterator's string handle.
handle = tf.placeholder(dtype=tf.string, shape=[])

# Feedable iterator: the element dtypes and static shapes are declared here,
# independently of whichever dataset iterator is fed through `handle`.
iterator = tf.data.Iterator.from_string_handle(
    string_handle=handle,
    output_types=(tf.float32, tf.uint8),
    output_shapes=(
        tf.TensorShape((None, 32, 32, 3)),
        tf.TensorShape((None, 1)),
    ),
)

# First component is the model input, second component is the label.
model_input, label = iterator.get_next()
# IPython shell escape: delete the files currently in tensorboard_logs/.
!rm tensorboard_logs/*
Load the model in a separate graph so that it does not interfere with the current graph.
# Import the saved grayscale model into its own graph, away from the default one.
g = tf.Graph()
with g.as_default():
    saver = tf.train.import_meta_graph('models/CIFAR10_grey/cifar_model.ckpt.meta')

# Bind the session to that graph explicitly, restore the latest checkpoint and
# dump the graph definition for TensorBoard.
with tf.Session(graph=g) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('models/CIFAR10_grey/'))
    writer = tf.summary.FileWriter('./tensorboard_logs', graph=sess.graph)
    writer.close()
!tensorboard --logdir=tensorboard_logs/
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file('models/CIFAR10_grey/cifar_model.ckpt',tensor_name='', all_tensors=False,all_tensor_names=True)
The weights are read out as NumPy arrays because tf.initializers.constant needs a NumPy array.
Get the weights and store them in NumPy arrays under a separate graph. Thus, when we save the new model, the old graph won't interfere.
# Get weights as numpy
g = tf.Graph()
with g.as_default():
    saver = tf.train.import_meta_graph('models/CIFAR10_grey/cifar_model.ckpt.meta')
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('models/CIFAR10_grey/'))

        def _variable_value(variable_name):
            """Return the restored value of the first global variable matching ``variable_name``."""
            matches = sess.graph.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope=variable_name)
            return sess.run(matches[0])

        # Pull every layer that has a counterpart in the grayscale model.
        conv2_bias = _variable_value('Conv2/bias:0')
        conv2_kernel = _variable_value('Conv2/kernel:0')
        conv3_bias = _variable_value('Conv3/bias:0')
        conv3_kernel = _variable_value('Conv3/kernel:0')
        dense84_bias = _variable_value('Dense_84/bias:0')
        dense84_kernel = _variable_value('Dense_84/kernel:0')
        dense10_bias = _variable_value('Dense_10/bias:0')
        dense10_kernel = _variable_value('Dense_10/kernel:0')

        # Quick sanity print of one of the extracted arrays.
        print(conv2_bias)
Check the extracted value of the Conv2 bias.
# Bare expression: the notebook displays the extracted Conv2 bias array.
conv2_bias
def cnn_features(model_input):
    """Build the convolutional feature extractor on top of ``model_input``.

    ``Conv1`` is created with the default initializers. ``Conv2``, ``Conv3``
    and ``Dense_84`` are initialized from the module-level NumPy arrays
    ``conv2_*``, ``conv3_*`` and ``dense84_*``, which must already be defined
    when this function is called.

    Args:
        model_input: Input tensor fed to the first convolution.

    Returns:
        The output tensor of the ``Dense_84`` layer (84 units, tanh activation).
    """
    # Only graph ops are defined here, so no tf.Session is opened: a session
    # is required to run ops, not to build them.

    # First convolution is created fresh (no pretrained initializer is given).
    conv1 = tf.layers.conv2d(model_input, 6, (5, 5), (1, 1),
                             activation=tf.nn.tanh, name="Conv1")
    avgpool1 = tf.layers.average_pooling2d(conv1, (2, 2), (2, 2), name="AvgPool1")

    # From here on, every trainable layer starts from the transferred weights.
    conv2 = tf.layers.conv2d(avgpool1, 16, (5, 5), (1, 1),
                             activation=tf.nn.tanh, name="Conv2",
                             bias_initializer=tf.initializers.constant(conv2_bias),
                             kernel_initializer=tf.initializers.constant(conv2_kernel))
    avgpool2 = tf.layers.average_pooling2d(conv2, (2, 2), (2, 2), name="AvgPool2")
    conv3 = tf.layers.conv2d(avgpool2, 120, (5, 5), (1, 1),
                             activation=tf.nn.tanh, name="Conv3",
                             bias_initializer=tf.initializers.constant(conv3_bias),
                             kernel_initializer=tf.initializers.constant(conv3_kernel))
    flatten = tf.layers.Flatten()(conv3)
    features = tf.layers.dense(flatten, 84, activation=tf.nn.tanh, name="Dense_84",
                               bias_initializer=tf.initializers.constant(dense84_bias),
                               kernel_initializer=tf.initializers.constant(dense84_kernel))
    return features
# Build the feature extractor on the input tensor yielded by the iterator.
features = cnn_features(model_input)
Check the global variables created
# Bare expression: the notebook displays every global variable registered in
# the default graph.
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
Check if the weights were properly initialized
with tf.Session() as sess:
    # Run the initializers, which include the constant initializers built
    # from the transferred weights.
    sess.run(tf.global_variables_initializer())

    # Print the value of the first global variable whose name matches
    # "Conv2/bias".
    conv2_bias_variable = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="Conv2/bias")[0]
    print(sess.run(conv2_bias_variable))
Save model
!rm models/WT_TRANSFER/*
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.add_to_collection('handle',handle)
tf.add_to_collection('target',label)
tf.add_to_collection('model_input',model_input)
tf.add_to_collection('features',features)
saver.save(sess,'models/WT_TRANSFER/wt_transfer_eg.chpt')
To do this without saving an intermediate model, use the low-level convolution API (tf.nn.conv2d).