Transferring weights from selected layers

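Previously, we trained a CNN on grayscale CIFAR-10 images. Now we want to reuse that model for RGB images. The first convolutional layer's kernel depends on the number of input channels, so its weights cannot be reused. Instead, we reuse the weights from the second convolutional layer onward (Conv2, Conv3 and Dense_84).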
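Listing the kernel shapes makes it clear why the reuse starts at Conv2. The sketch below assumes the grayscale model had the same layer sizes as the model built later in this notebook, with a single input channel. `conv_kernel_shapes` is a hypothetical helper written for this illustration, not a TensorFlow function.

```python
# (layer name, number of filters) for the convolutional stack in this notebook.
CONV_LAYERS = [("Conv1", 6), ("Conv2", 16), ("Conv3", 120)]


def conv_kernel_shapes(in_channels, kernel_size=5):
    """Kernel shapes in the layout tf.layers.conv2d uses: (h, w, in_channels, filters)."""
    shapes = {}
    channels = in_channels
    for name, filters in CONV_LAYERS:
        shapes[name + "/kernel"] = (kernel_size, kernel_size, channels, filters)
        channels = filters  # this layer's filters become the next layer's input channels
    return shapes


# Assumption: the grayscale model took single-channel input.
grey = conv_kernel_shapes(in_channels=1)
rgb = conv_kernel_shapes(in_channels=3)

# Layers whose kernel shapes agree can be initialized from the grayscale model.
reusable = [name for name in grey if grey[name] == rgb[name]]
print(grey["Conv1/kernel"], rgb["Conv1/kernel"])  # (5, 5, 1, 6) (5, 5, 3, 6)
print(reusable)  # ['Conv2/kernel', 'Conv3/kernel']
```

Only Conv1 changes shape. Every later layer sees the same 6-channel input whether the images are grayscale or RGB.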

Here, we demonstrate this with tf.layers.conv2d so that we can take advantage of its convenient regularizer arguments (such as kernel_regularizer). Apart from that, the same thing can be done more easily with tf.nn.conv2d.

In [1]:
import tensorflow as tf
import numpy as np

Create a data pipeline

The output shapes and data types of the pipeline have to be fixed up front:

In [2]:
tf.TensorShape((None,32,32,3)),tf.TensorShape((None,1))
Out[2]:
(TensorShape([Dimension(None), Dimension(32), Dimension(32), Dimension(3)]),
 TensorShape([Dimension(None), Dimension(1)]))
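The following small NumPy check shows which batches these fixed signatures accept. `matches` and `batch_matches` are hypothetical helpers written for this example. As in `tf.TensorShape`, `None` in a shape stands for a dimension of any size, so any batch size fits, but the images must be 32x32x3 float32 and the labels a uint8 column.

```python
import numpy as np

# (dtype, shape) that each iterator output must have; None means "any size".
SIGNATURE = [(np.float32, (None, 32, 32, 3)), (np.uint8, (None, 1))]


def matches(array, dtype, shape):
    """True if `array` has the given dtype and fits a shape that may contain None."""
    if array.dtype != dtype or array.ndim != len(shape):
        return False
    return all(want is None or want == got for want, got in zip(shape, array.shape))


def batch_matches(batch):
    """Check an (images, labels) tuple against SIGNATURE."""
    return len(batch) == len(SIGNATURE) and all(
        matches(array, dtype, shape) for array, (dtype, shape) in zip(batch, SIGNATURE)
    )


rgb_batch = (np.zeros((8, 32, 32, 3), np.float32), np.zeros((8, 1), np.uint8))
grey_batch = (np.zeros((8, 32, 32, 1), np.float32), np.zeros((8, 1), np.uint8))
print(batch_matches(rgb_batch))   # True
print(batch_matches(grey_batch))  # False
```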

Create a feedable iterator

In [3]:
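# The string fed into `handle` at run time selects which dataset's iterator supplies the data
handle = tf.placeholder(tf.string, shape=[])
# Output types and shapes must match every dataset plugged in later; None = any batch size
iterator = tf.data.Iterator.from_string_handle(handle,
                                               (tf.float32,tf.uint8),
                                               (tf.TensorShape((None,32,32,3)),tf.TensorShape((None,1))))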

model_input,label = iterator.get_next()
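A toy, pure-Python analogy helps picture the handle mechanism: a single `get_next()` whose data source is chosen by a string supplied at call time. This is not TensorFlow's API. In TensorFlow, the handle strings would come from `Iterator.string_handle()` and be passed in through `feed_dict={handle: ...}`. The class and batch names below are made up for illustration.

```python
import itertools


class HandleIterator:
    """Toy analogy of a feedable iterator (not the TensorFlow API)."""

    def __init__(self):
        self._sources = {}

    def register(self, handle, batches):
        # Each handle owns its own underlying iterator, so switching handles
        # does not reset the position of the other sources.
        self._sources[handle] = iter(batches)

    def get_next(self, handle):
        return next(self._sources[handle])


it = HandleIterator()
it.register("train", itertools.cycle(["train_batch_0", "train_batch_1"]))
it.register("valid", ["valid_batch_0"])
print(it.get_next("train"))  # train_batch_0
print(it.get_next("valid"))  # valid_batch_0
print(it.get_next("train"))  # train_batch_1
```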

Load the model and view it in TensorBoard

In [4]:
!rm tensorboard_logs/*

Load the model in a separate graph so that it does not interfere with the current graph.

In [5]:
g = tf.Graph()
with g.as_default():
    saver = tf.train.import_meta_graph('models/CIFAR10_grey/cifar_model.ckpt.meta')
    with tf.Session() as sess:
        saver.restore(sess,tf.train.latest_checkpoint('models/CIFAR10_grey/'))
        writer = tf.summary.FileWriter('./tensorboard_logs', sess.graph)
        writer.close()
INFO:tensorflow:Restoring parameters from models/CIFAR10_grey/cifar_model.ckpt
In [6]:
!tensorboard --logdir=tensorboard_logs/
TensorBoard 1.12.0 at http://a7d515f6a632:6006 (Press CTRL+C to quit)
^C

Get variable names in each layer

In [4]:
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file('models/CIFAR10_grey/cifar_model.ckpt',tensor_name='', all_tensors=False,all_tensor_names=True)
tensor_name:  Conv1/bias
tensor_name:  Conv1/kernel
tensor_name:  Conv2/bias
tensor_name:  Conv2/kernel
tensor_name:  Conv3/bias
tensor_name:  Conv3/kernel
tensor_name:  Dense_10/bias
tensor_name:  Dense_10/kernel
tensor_name:  Dense_84/bias
tensor_name:  Dense_84/kernel
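Because the names above follow a `layer/parameter` pattern, the list of variables to transfer can be derived from them rather than typed by hand. In the sketch below, the names are copied from the output above and `group_by_layer` is a hypothetical helper. Skipping Conv1 and Dense_10 reflects the choice made in this notebook: Conv1 depends on the number of input channels, and Dense_10 is not rebuilt in the new model.

```python
# Tensor names as printed by print_tensors_in_checkpoint_file above.
names = ["Conv1/bias", "Conv1/kernel", "Conv2/bias", "Conv2/kernel",
         "Conv3/bias", "Conv3/kernel", "Dense_10/bias", "Dense_10/kernel",
         "Dense_84/bias", "Dense_84/kernel"]


def group_by_layer(tensor_names):
    """Map each layer name to the list of its parameter names."""
    layers = {}
    for full_name in tensor_names:
        layer, param = full_name.split("/", 1)
        layers.setdefault(layer, []).append(param)
    return layers


layers = group_by_layer(names)
to_transfer = [layer for layer in layers if layer not in {"Conv1", "Dense_10"}]

# Variable names in the form used with get_collection(..., scope=...) below.
variable_names = [f"{layer}/{param}:0" for layer in to_transfer for param in layers[layer]]
print(to_transfer)  # ['Conv2', 'Conv3', 'Dense_84']
```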

Get weights as NumPy arrays

We need NumPy arrays because tf.initializers.constant expects a concrete value (a Python scalar, list, or NumPy array).

We read the weights into NumPy arrays inside a separate graph. This way, the old graph's variables do not end up in the new model when we save it.

In [5]:
# Get the weights as NumPy arrays
g = tf.Graph()
with g.as_default():
    saver = tf.train.import_meta_graph('models/CIFAR10_grey/cifar_model.ckpt.meta')
    with tf.Session() as sess:
        saver.restore(sess,tf.train.latest_checkpoint('models/CIFAR10_grey/'))

        def get_weights(name):
            # Fetch the value of the restored variable `name` as a NumPy array
            return sess.run(sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope=name)[0])

        conv2_bias = get_weights('Conv2/bias:0')
        conv2_kernel = get_weights('Conv2/kernel:0')

        conv3_bias = get_weights('Conv3/bias:0')
        conv3_kernel = get_weights('Conv3/kernel:0')

        dense84_bias = get_weights('Dense_84/bias:0')
        dense84_kernel = get_weights('Dense_84/kernel:0')

        # Dense_10 is read as well, although cnn_features below does not use it
        dense10_bias = get_weights('Dense_10/bias:0')
        dense10_kernel = get_weights('Dense_10/kernel:0')

        print(conv2_bias)
INFO:tensorflow:Restoring parameters from models/CIFAR10_grey/cifar_model.ckpt
[-0.14805889  0.6004731   0.14357984 -0.3337978  -0.48210445 -0.07447348
 -0.3848596  -0.07660984  0.13621555  0.2962938  -0.15078765 -0.14430566
  0.3739807  -0.3766448   0.60352474  0.4924152 ]

Check the restored values

In [6]:
conv2_bias
Out[6]:
array([-0.14805889,  0.6004731 ,  0.14357984, -0.3337978 , -0.48210445,
       -0.07447348, -0.3848596 , -0.07660984,  0.13621555,  0.2962938 ,
       -0.15078765, -0.14430566,  0.3739807 , -0.3766448 ,  0.60352474,
        0.4924152 ], dtype=float32)

Build the model and set the initial weights

In [7]:
def cnn_features(model_input):
    # No session is needed here: tf.layers only adds ops and variables to the default graph.

    # Conv1 is trained from scratch: its kernel shape depends on the number of input channels
    conv1 = tf.layers.conv2d(model_input,6,(5,5),(1,1),activation=tf.nn.tanh,name="Conv1")
    avgpool1 = tf.layers.average_pooling2d(conv1,(2,2),(2,2),name="AvgPool1")

    # Conv2, Conv3 and Dense_84 start from the weights of the grayscale model
    conv2 = tf.layers.conv2d(avgpool1,16,(5,5),(1,1),activation=tf.nn.tanh,name="Conv2",
                             bias_initializer=tf.initializers.constant(conv2_bias),
                             kernel_initializer=tf.initializers.constant(conv2_kernel))
    avgpool2 = tf.layers.average_pooling2d(conv2,(2,2),(2,2),name="AvgPool2")

    conv3 = tf.layers.conv2d(avgpool2,120,(5,5),(1,1),activation=tf.nn.tanh,name="Conv3",
                             bias_initializer=tf.initializers.constant(conv3_bias),
                             kernel_initializer=tf.initializers.constant(conv3_kernel))

    flatten = tf.layers.Flatten()(conv3)

    features = tf.layers.dense(flatten,84,activation=tf.nn.tanh,name="Dense_84",
                               bias_initializer=tf.initializers.constant(dense84_bias),
                               kernel_initializer=tf.initializers.constant(dense84_kernel))

    return features
In [8]:
features = cnn_features(model_input)

Check the global variables created

In [9]:
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
Out[9]:
[<tf.Variable 'Conv1/kernel:0' shape=(5, 5, 3, 6) dtype=float32_ref>,
 <tf.Variable 'Conv1/bias:0' shape=(6,) dtype=float32_ref>,
 <tf.Variable 'Conv2/kernel:0' shape=(5, 5, 6, 16) dtype=float32_ref>,
 <tf.Variable 'Conv2/bias:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'Conv3/kernel:0' shape=(5, 5, 16, 120) dtype=float32_ref>,
 <tf.Variable 'Conv3/bias:0' shape=(120,) dtype=float32_ref>,
 <tf.Variable 'Dense_84/kernel:0' shape=(120, 84) dtype=float32_ref>,
 <tf.Variable 'Dense_84/bias:0' shape=(84,) dtype=float32_ref>]
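The shapes above can be checked by hand. Both tf.layers.conv2d and tf.layers.average_pooling2d use `padding='valid'` by default, so each layer shrinks the feature map. The plain-Python arithmetic below uses a hypothetical helper, `valid_output_size`. It shows that Conv3 outputs a 1x1x120 map, which is why Dense_84/kernel has 120 input units. It also shows why the grayscale Dense_84 weights still fit: the spatial sizes do not depend on the number of input channels.

```python
def valid_output_size(size, window, stride):
    """Output length of an unpadded ('valid') convolution or pooling window."""
    return (size - window) // stride + 1


# (layer, window size, stride) in the order used by cnn_features.
LAYERS = [("Conv1", 5, 1), ("AvgPool1", 2, 2),
          ("Conv2", 5, 1), ("AvgPool2", 2, 2),
          ("Conv3", 5, 1)]

sizes = {}
size = 32  # CIFAR images are 32x32
for name, window, stride in LAYERS:
    size = valid_output_size(size, window, stride)
    sizes[name] = size
print(sizes)  # {'Conv1': 28, 'AvgPool1': 14, 'Conv2': 10, 'AvgPool2': 5, 'Conv3': 1}

flatten_units = sizes["Conv3"] * sizes["Conv3"] * 120  # 120 = number of Conv3 filters
print(flatten_units)  # 120
```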

Check if the weights were properly initialized

In [10]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="Conv2/bias")[0]))
[-0.14805889  0.6004731   0.14357984 -0.3337978  -0.48210445 -0.07447348
 -0.3848596  -0.07660984  0.13621555  0.2962938  -0.15078765 -0.14430566
  0.3739807  -0.3766448   0.60352474  0.4924152 ]
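Printing one bias checks only one variable. The sketch below compares every transferred array. It assumes the expected arrays and the freshly initialized values have been collected into dictionaries keyed by variable name. In the notebook, `actual` could be filled with `sess.run` on the new variables; here, both dictionaries hold synthetic random data, and `mismatched_weights` is a hypothetical helper. Exact equality should be expected, because the constant initializer copies the stored float32 values.

```python
import numpy as np


def mismatched_weights(expected, actual):
    """Names whose arrays are missing from `actual`, differently shaped, or not identical."""
    bad = []
    for name, want in expected.items():
        got = actual.get(name)
        if got is None or not np.array_equal(got, want):
            bad.append(name)
    return bad


rng = np.random.default_rng(0)
expected = {
    "Conv2/bias": rng.standard_normal(16).astype(np.float32),
    "Conv2/kernel": rng.standard_normal((5, 5, 6, 16)).astype(np.float32),
}

actual = {name: value.copy() for name, value in expected.items()}
print(mismatched_weights(expected, actual))  # []

actual["Conv2/kernel"] = actual["Conv2/kernel"] + 1.0  # simulate a failed transfer
print(mismatched_weights(expected, actual))  # ['Conv2/kernel']
```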

Save model

In [11]:
!rm models/WT_TRANSFER/*
In [12]:
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.add_to_collection('handle',handle)
    tf.add_to_collection('target',label)
    tf.add_to_collection('model_input',model_input)
    tf.add_to_collection('features',features)
    saver.save(sess,'models/WT_TRANSFER/wt_transfer_eg.chpt')

Summary

To do this without saving an intermediate model, use the low-level convolution API (tf.nn.conv2d).