There is always a pleasure in reading code and typing it out line by line, staying focused and trying to understand it. Below is the Keras Siamese network (contrastive loss) example, with my own comments and observations added for my reference.
#https://keras.io/examples/vision/siamese_contrastive/
#added comments for my understanding
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

#hyperparameters
epochs = 10
batch_size = 16
#margin for contrastive loss
margin = 1

#load mnist dataset
(x_train_val, y_train_val), (x_test, y_test) = keras.datasets.mnist.load_data()

#change the data type to floating point format
x_train_val = x_train_val.astype("float32")
x_test = x_test.astype("float32")

#define training and validation splits
x_train, x_val = x_train_val[:30000], x_train_val[30000:]
y_train, y_val = y_train_val[:30000], y_train_val[30000:]
del x_train_val, y_train_val
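#quick shape check (my addition): 30k/30k train/val split of the 60k MNIST training images
print(x_train.shape, x_val.shape, x_test.shape)  # (30000, 28, 28) (30000, 28, 28) (10000, 28, 28)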
#creating pairs of images
def make_pairs(x, y):
    num_classes = max(y) + 1
    digit_indices = [np.where(y == i)[0] for i in range(num_classes)]
    pairs = []
    labels = []
    for idx1 in range(len(x)):
        #add a matching pair
        x1 = x[idx1]
        label1 = y[idx1]
        #for the same y label, pick another sample of x
        idx2 = random.choice(digit_indices[label1])
        x2 = x[idx2]
        #x1, x2 are a pair with the same y label
        pairs += [[x1, x2]]
        #assign 1 for same labels, i.e. similar pairs
        labels += [1]
        #add a non-matching pair
        label2 = random.randint(0, num_classes - 1)
        while label2 == label1:
            #if it matches label1, keep picking until it does not
            label2 = random.randint(0, num_classes - 1)
        idx2 = random.choice(digit_indices[label2])
        x2 = x[idx2]
        #for the same x1, we add a non-matching x2
        pairs += [[x1, x2]]
        #for a non-matching pair the label is set to zero
        #interesting random pick logic
        labels += [0]
    return np.array(pairs), np.array(labels).astype("float32")
#make train pairs
pairs_train, labels_train = make_pairs(x_train, y_train)
#make validation pairs
pairs_val, labels_val = make_pairs(x_val, y_val)
#make test pairs
pairs_test, labels_test = make_pairs(x_test, y_test)
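#quick shape check (my addition): each sample yields one positive and one negative pair
print(pairs_train.shape, labels_train.shape)  # (60000, 2, 28, 28) (60000,)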
x_train_1 = pairs_train[:, 0]
x_train_2 = pairs_train[:, 1]
x_val_1 = pairs_val[:, 0]
x_val_2 = pairs_val[:, 1]
x_test_1 = pairs_test[:, 0]
x_test_2 = pairs_test[:, 1]
def visualize(pairs, labels, to_show=6, num_col=3, predictions=None, test=False):
    num_row = to_show // num_col if to_show // num_col != 0 else 1
    to_show = num_row * num_col
    fig, axes = plt.subplots(num_row, num_col, figsize=(5, 5))
    for i in range(to_show):
        if num_row == 1:
            ax = axes[i % num_col]
        else:
            ax = axes[i // num_col, i % num_col]
        ax.imshow(tf.concat([pairs[i][0], pairs[i][1]], axis=1), cmap='gray')
        ax.set_axis_off()
        if test:
            ax.set_title('True: {} | pred: {:.5f}'.format(labels[i], predictions[i][0]))
        else:
            ax.set_title('label: {}'.format(labels[i]))
    if test:
        plt.tight_layout(rect=(0, 0, 1.9, 1.9), w_pad=0.0)
    else:
        plt.tight_layout(rect=(0, 0, 1.5, 1.5))
    plt.show()

visualize(pairs_train[:-1], labels_train[:-1], to_show=4, num_col=4)
visualize(pairs_val[:-1], labels_val[:-1], to_show=4, num_col=4)
visualize(pairs_test[:-1], labels_test[:-1], to_show=4, num_col=4)
#euclidean distance between the two embedding vectors
def euclidean_distance(vects):
    x, y = vects
    sum_square = tf.math.reduce_sum(tf.math.square(x - y), axis=1, keepdims=True)
    return tf.math.sqrt(tf.math.maximum(sum_square, tf.keras.backend.epsilon()))
#base embedding network, applied to both inputs
input = layers.Input((28, 28, 1))
x = tf.keras.layers.BatchNormalization()(input)
x = layers.Conv2D(4, (5, 5), activation='tanh')(x)
x = layers.AveragePooling2D(pool_size=(2, 2))(x)
x = layers.Conv2D(16, (5, 5), activation='tanh')(x)
x = layers.AveragePooling2D(pool_size=(2, 2))(x)
x = layers.Flatten()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = layers.Dense(10, activation='tanh')(x)
embedding_network = keras.Model(input, x)
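#note (my addition): each 28x28 image is mapped to a 10-dimensional embedding
print(embedding_network.output_shape)  # (None, 10)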
input_1 = layers.Input((28, 28, 1))
input_2 = layers.Input((28, 28, 1))
#the same embedding network is applied to both inputs (shared weights)
tower_1 = embedding_network(input_1)
tower_2 = embedding_network(input_2)
#distance computed between the two embedding vectors
merge_layer = layers.Lambda(euclidean_distance)([tower_1, tower_2])
normal_layer = tf.keras.layers.BatchNormalization()(merge_layer)
#sigmoid squashes the normalized distance into a (0, 1) similarity score - match vs non-match
output_layer = layers.Dense(1, activation='sigmoid')(normal_layer)
siamese = keras.Model(inputs=[input_1, input_2], outputs=output_layer)
#custom contrastive loss for similar / dissimilar pairs
def loss(margin=1):
    def contrastive_loss(y_true, y_pred):
        square_pred = tf.math.square(y_pred)
        margin_square = tf.math.square(tf.math.maximum(margin - (y_pred), 0))
        return tf.math.reduce_mean((1 - y_true) * square_pred + (y_true) * margin_square)
    return contrastive_loss
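#worked example (my addition), using the label convention above (1 = similar, 0 = dissimilar):
#  y_true = 1 (similar):    loss = max(margin - y_pred, 0)^2  -> pushes the output towards 1
#  y_true = 0 (dissimilar): loss = y_pred^2                   -> pushes the output towards 0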
siamese.compile(loss=loss(margin=margin), optimizer='rmsprop', metrics=['accuracy'])
siamese.summary()
history = siamese.fit(
    [x_train_1, x_train_2],
    labels_train,
    validation_data=([x_val_1, x_val_2], labels_val),
    batch_size=batch_size,
    epochs=epochs,
)
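#optional (my addition): plot the training curves captured in history;
#the keys below follow from metrics=['accuracy'] in compile above
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.show()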
#pass two images and check whether they are similar or not
predictions = siamese.predict([x_test_1, x_test_2])
visualize(pairs_test, labels_test, to_show=3, predictions=predictions, test=True)
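As a quick follow-up (my own addition, assuming the variables above are still in scope), the trained model can also be scored on the full set of test pairs with Keras' evaluate:

results = siamese.evaluate([x_test_1, x_test_2], labels_test)
print("test loss, test accuracy:", results)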
- Siamese Network Keras for Image and Text similarity.
- Multi-class Classification using Deep Neural Networks
- Siamese Networks
- Image Similarity with Siamese Networks
- machine-learning-experiments
- Generating Images of Clothes Using Deep Convolutional Generative Adversarial Network (DCGAN)
- Plant Disease Using Siamese Network - Keras
- Face Recognition Using Siamese Network
- Image similarity estimation using a Siamese Network with a contrastive loss
- Image similarity estimation using a Siamese Network with a triplet loss
- Near-duplicate image search
- Siamese Neural Networks for Few Shot Learning
Keep Exploring!!!