import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import glob
from pathlib import Path
%matplotlib inline
ls datasets/cifar-10-batches-bin/
# Glob pattern for the five CIFAR-10 binary training batches
# (data_batch_1.bin ... data_batch_5.bin).
file_path = "datasets/cifar-10-batches-bin/data_batch_*"
train_bins = glob.glob(file_path)
# Notebook cell expression: display the matched file list.
train_bins
Record layout of each batch file (repeated back to back, one record per image):
<1 byte label><3*32*32 bytes of pixel data>
...
<1 byte label><3*32*32 bytes of pixel data>
Each image is stored in (C, H, W) channel-first order.
# Open the first training batch as a raw byte stream.
# NOTE(review): `ft` is deliberately left open — the second read below
# continues from the current stream position to get record #2.
ft = Path(train_bins[0]).open('rb')
# First byte of a record is the class label (unsigned 8-bit).
lbl = np.frombuffer(ft.read(1),np.dtype('u1'))
print(lbl)
# Next 3*32*32 bytes are the pixels, channel-first (C, H, W).
img = np.frombuffer(ft.read(3*32*32),np.dtype('u1')).reshape(3,32,32)
# Reorder to (H, W, C) so PIL can render it as RGB.
img = np.transpose(img,(1,2,0))
print(img.shape)
Image.fromarray(img)
# Second record: same layout, read from the same open stream.
lbl = np.frombuffer(ft.read(1),np.dtype('u1'))
img = np.frombuffer(ft.read(32*32*3),np.dtype('u1')).reshape(3,32,32)
img = np.transpose(img,(1,2,0))
print(lbl)
Image.fromarray(img)
def cifar_dataset(files_list: list) -> tf.data.Dataset:
    """Build a tf.data pipeline over CIFAR-10 binary batch files.

    Each fixed-length record is <1 byte label><3*32*32 bytes CHW pixels>.

    Args:
        files_list: paths to CIFAR-10 ``.bin`` batch files.

    Returns:
        A dataset of ``(image, label)`` pairs where ``image`` is a float32
        grayscale ``(32, 32, 1)`` tensor in [0, 1] and ``label`` has shape
        ``(1,)`` to match the MNIST pipeline this graph was built for.
    """
    record_bytes = 1 + 3 * 32 * 32

    def _parse(record):
        # One fused parse function instead of six chained map() calls:
        # identical output, fewer per-element function dispatches.
        raw = tf.decode_raw(record, tf.uint8)
        label = tf.expand_dims(raw[0], 0)            # shape (1,) — matches MNIST
        img = tf.reshape(raw[1:], (3, 32, 32))       # stored channel-first
        img = tf.transpose(img, (1, 2, 0))           # -> (H, W, C)
        img = tf.image.rgb_to_grayscale(img)         # -> (32, 32, 1)
        img = tf.image.convert_image_dtype(img, tf.float32)  # uint8 -> [0, 1]
        return img, label

    data = tf.data.FixedLengthRecordDataset(files_list, record_bytes)
    return data.map(_parse, num_parallel_calls=4)
with tf.device('/cpu:0'):
    # Keep the input pipeline on the CPU: shuffle with a 20k buffer,
    # run 20 epochs, batch by 10, and prefetch to overlap with training.
    train_dataset = (cifar_dataset(train_bins)
                     .shuffle(20000)
                     .repeat(20)
                     .batch(10)
                     .prefetch(2))
# Notebook cell expression: inspect the pipeline's static shapes/dtypes.
train_dataset.output_shapes, train_dataset.output_types
train_iterator = train_dataset.make_initializable_iterator()
Check (restart the kernel, and don't run this check before training)
# Sanity check: pull one batch from the pipeline and visualize it.
im, l = train_iterator.get_next()
with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    imr, lr = sess.run([im, l])
lr
imr.shape
# Grayscale images carry a trailing channel dim; drop it for imshow.
plt.imshow(imr[0].reshape(32, 32), cmap='gray')
from tensorflow.python.tools import inspect_checkpoint as chkp
# List the tensor names stored in the pretrained MNIST checkpoint
# (names only — all_tensors=False avoids dumping the actual values).
chkp.print_tensors_in_checkpoint_file('models/MNIST_CNN/mnist_model.ckpt',tensor_name='', all_tensors=False,all_tensor_names=True)
Calling `import_meta_graph` adds all the saved variables to the default graph
# Global variables before importing the meta graph (expected: empty list).
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
# import_meta_graph re-creates the saved graph (ops, variables and the
# custom collections such as 'classifier'/'data_handle') in the default graph.
saver = tf.train.import_meta_graph('models/MNIST_CNN/mnist_model.ckpt.meta')
# Same query now shows the variables brought in from the checkpoint's graph.
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
Check
with tf.Session() as sess:
    # Inspecting collection keys does not require restoring weights.
    #saver.restore(sess,tf.train.latest_checkpoint('models/MNIST_CNN/'))
    print(sess.graph.get_all_collection_keys())
    print(sess.graph.get_collection('classifier'))
    print(sess.graph.get_collection('data_handle'))
Check loading
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('models/MNIST_CNN/'))
    sess.run(train_iterator.initializer)
    hdl = sess.run(train_iterator.string_handle())
    # Feed the CIFAR iterator's string handle into the imported graph's
    # feedable-iterator placeholder and run its classifier output.
    classifier_op = sess.graph.get_collection('classifier')[0]
    handle_ph = sess.graph.get_collection('data_handle')[0]
    out = sess.run(classifier_op, {handle_ph: hdl})
out
!rm models/CIFAR10_grey/*
import time

# A fresh Saver sees every variable in the (possibly extended) graph;
# the meta-graph's own saver only knows the originally saved tensors.
saver_train = tf.train.Saver()

with tf.Session() as sess:
    # Warm-start from the pretrained MNIST weights.
    saver.restore(sess, tf.train.latest_checkpoint('models/MNIST_CNN/'))
    sess.run(train_iterator.initializer)
    hdl = sess.run(train_iterator.string_handle())

    loss = sess.graph.get_collection('loss')[0]
    train = sess.graph.get_collection('train')[0]
    handle = sess.graph.get_collection('data_handle')[0]

    start = time.time()
    try:
        i = 1
        tmp = []
        # Run until the repeated dataset raises OutOfRangeError.
        while True:
            i = i + 1
            l, _ = sess.run([loss, train], {handle: hdl})
            tmp.append(l)
            if i % 5000 == 0:
                avg_loss = np.array(tmp).mean()
                print("Batch: ", i, avg_loss)
                tmp = []
    except tf.errors.OutOfRangeError:
        pass
    end = time.time()
    elapsed = end - start
    print("Elapsed time : ", elapsed, " s")
    saver_train.save(sess, 'models/CIFAR10_grey/cifar_model.ckpt')
with tf.device('/cpu:0'):
    # Evaluation pipeline: no shuffle/repeat — one pass over the test set.
    test_dataset = cifar_dataset(['datasets/cifar-10-batches-bin/test_batch.bin'])
    test_dataset = test_dataset.batch(10).prefetch(2)
test_iterator = test_dataset.make_initializable_iterator()
def get_accuracy(predict, true):
    """Build an op computing the fraction of matching class predictions.

    Args:
        predict: integer class predictions, e.g. ``[2, 4, 1, ...]``.
        true: ground-truth classes of the same shape, e.g. ``[2, 4, 1, ...]``.

    Returns:
        A rank-0 float32 tensor — the mean accuracy in [0, 1].
        (The original ``-> int`` annotation was wrong: this returns a
        tensor, not a Python int.)
    """
    correct_pred = tf.equal(predict, true)
    # Cast [True, False, ...] -> [1., 0., ...] so reduce_mean yields the rate.
    acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    return acc
with tf.Session() as sess:
    saver_train.restore(sess, 'models/CIFAR10_grey/cifar_model.ckpt')
    sess.run(test_iterator.initializer)
    hdl = sess.run(test_iterator.string_handle())

    label = sess.graph.get_collection('target')[0]
    classifier = sess.graph.get_collection('classifier')[0]
    handle = sess.graph.get_collection('data_handle')[0]

    # IMPORTANT: build the accuracy op ONCE, outside the loop — creating
    # ops per iteration grows the graph and slows everything down.
    acc = get_accuracy(tf.argmax(classifier, axis=1),
                       tf.transpose(tf.argmax(tf.one_hot(label, 10), axis=2)))
    try:
        i = 0
        acc_list = []
        while True:
            i = i + 1
            a = sess.run(acc, {handle: hdl})
            acc_list.append(a)
            if i % 100 == 0:
                print(i, "Mean Acc : ", np.array(acc_list).mean())
                acc_list = []
    except tf.errors.OutOfRangeError:
        pass
Just add `tf.global_variables_initializer` to the previous code
!rm models/CIFAR10_grey/*
import time

#saver = tf.train.import_meta_graph('models/MNIST_CNN/mnist_model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('models/MNIST_CNN/'))
    # Re-initialize all variables: this deliberately discards the restored
    # MNIST weights so the same architecture trains from scratch on CIFAR.
    sess.run(tf.global_variables_initializer())
    sess.run(train_iterator.initializer)
    hdl = sess.run(train_iterator.string_handle())

    loss = sess.graph.get_collection('loss')[0]
    train = sess.graph.get_collection('train')[0]
    handle = sess.graph.get_collection('data_handle')[0]

    start = time.time()
    try:
        i = 1
        tmp = []
        # Run until the repeated dataset raises OutOfRangeError.
        while True:
            i = i + 1
            l, _ = sess.run([loss, train], {handle: hdl})
            tmp.append(l)
            if i % 5000 == 0:
                avg_loss = np.array(tmp).mean()
                print("Batch: ", i, avg_loss)
                tmp = []
    except tf.errors.OutOfRangeError:
        pass
    end = time.time()
    elapsed = end - start
    print("Elapsed time : ", elapsed, " s")
    saver.save(sess, 'models/CIFAR10_grey/cifar_model.ckpt')
with tf.Session() as sess:
    saver.restore(sess, 'models/CIFAR10_grey/cifar_model.ckpt')
    #sess.run(tf.global_variables_initializer()) #0.01 acc for non trained model
    sess.run(test_iterator.initializer)
    hdl = sess.run(test_iterator.string_handle())

    label = sess.graph.get_collection('target')[0]
    classifier = sess.graph.get_collection('classifier')[0]

    # IMPORTANT: create the accuracy op once, before the loop — adding
    # ops inside the loop would keep growing the graph and slow it down.
    # NOTE(review): `handle` is reused from the earlier training cell.
    acc = get_accuracy(tf.argmax(classifier, axis=1),
                       tf.transpose(tf.argmax(tf.one_hot(label, 10), axis=2)))
    try:
        i = 0
        acc_list = []
        while True:
            i = i + 1
            a = sess.run(acc, {handle: hdl})
            acc_list.append(a)
            if i % 100 == 0:
                print(i, "Mean Acc : ", np.array(acc_list).mean())
                acc_list = []
    except tf.errors.OutOfRangeError:
        pass
IMPORTANT:
1) Do not call `get_next` on any child iterator when using feedable iterators. Make sure they are not called before training.
2) Calling `import_meta_graph` adds all the saved variables to the default graph.