This notebook demonstrates how to use TensorFlow Similarity to train a SimilarityModel() on a fraction of the MNIST classes and still have the model index and retrieve similar-looking images for all MNIST classes. It walks through five steps:
1. train() a similarity model on a subset of the 10 MNIST classes so that it learns how to project digits within a cosine space.
2. index() a few examples of each of the 10 classes present in the train dataset (e.g. 10 images per class) to make them searchable.
3. lookup() a few test images to check that the trained model, despite having only a few examples of seen and unseen classes in its index, is able to efficiently retrieve similar-looking examples for all classes.
4. calibrate() the model to estimate the best distance threshold for separating matching elements from elements belonging to other classes.
5. match() the test dataset to evaluate how well the calibrated model works for classification purposes.
Code link: https://github.com/MaoXianXin/Tensorflow_tutorial/blob/main/tensorflow_practice/tf_similarity_hello_world.py
Along the way you can try the following things to improve the model performance:
We are going to load the MNIST dataset and restrict our training data to only N of the 10 classes (6 by default) to showcase how the model is able to find similar examples from classes unseen during training. The model’s ability to generalize the matching to unseen classes, without retraining, is one of the main reasons you would want to use metric learning.
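As a rough illustration of restricting the training data to a subset of classes, here is a minimal plain-Python sketch. The helper name filter_by_class is hypothetical (it is not part of TensorFlow Similarity), and the toy integers stand in for images. The list of seen classes simply mirrors this notebook's defaults, where 8, 5, 0, and 4 are the unseen classes.

```python
def filter_by_class(examples, labels, allowed_classes):
    """Keep only the (example, label) pairs whose label is in allowed_classes."""
    allowed = set(allowed_classes)
    kept = [(x, y) for x, y in zip(examples, labels) if y in allowed]
    return [x for x, _ in kept], [y for _, y in kept]


# Toy stand-in for MNIST: 20 "images" (just integers here), two per digit class.
toy_labels = list(range(10)) * 2
toy_examples = list(range(20))

# Assumption: the 6 seen classes are the digits other than 8, 5, 0, and 4.
seen_classes = [1, 2, 3, 6, 7, 9]
x_seen, y_seen = filter_by_class(toy_examples, toy_labels, seen_classes)
print(len(x_seen), sorted(set(y_seen)))
```

The validation and index data are deliberately left unfiltered in the notebook, which is what makes it possible to test retrieval on unseen classes later.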
TensorFlow Similarity expects y_train to be an IntTensor containing the class ids for each example, instead of the standard categorical (one-hot) encoding traditionally used for multi-class classification.
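If your labels are already one-hot encoded, they can be converted back to integer class ids by taking the position of the largest entry in each row. The sketch below uses plain Python lists for clarity; the helper name one_hot_to_class_ids is hypothetical, and with real tensors you would typically use an argmax over the last axis instead.

```python
def one_hot_to_class_ids(one_hot_rows):
    """Convert one-hot rows to integer class ids (index of the largest entry)."""
    return [max(range(len(row)), key=row.__getitem__) for row in one_hot_rows]


y_one_hot = [
    [0, 0, 1, 0],  # class 2
    [1, 0, 0, 0],  # class 0
    [0, 0, 0, 1],  # class 3
]
y_ids = one_hot_to_class_ids(y_one_hot)
print(y_ids)  # [2, 0, 3]
```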
For a similarity model to learn efficiently, each batch must contain at least 2 examples of each class.
To make this easy, tf_similarity offers Samplers() that enable you to set both the number of classes and the minimum number of examples of each class per batch. Here we are creating a MultiShotMemorySampler() which allows you to sample an in-memory dataset and provides multiple examples per class.
TensorFlow Similarity provides various samplers to accommodate different requirements, including a SingleShotMemorySampler() for single-shot learning, a TFDatasetMultiShotMemorySampler() that integrates directly with the TensorFlow Datasets catalogue, and a TFRecordDatasetSampler() that allows you to sample from very large datasets stored on disk as TFRecord shards.
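To make the idea of a class-balanced batch concrete, here is a minimal sketch of what such a sampler has to guarantee. It is not the library's implementation: the function name sample_balanced_batch and its parameters are hypothetical, and it only returns example indices for a single batch.

```python
import random
from collections import defaultdict


def sample_balanced_batch(labels, classes_per_batch, examples_per_class, rng):
    """Return indices for one batch with exactly examples_per_class per chosen class."""
    # Group example indices by their class id.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Only classes with enough examples can appear in a batch.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= examples_per_class]
    chosen_classes = rng.sample(eligible, classes_per_batch)

    batch = []
    for c in chosen_classes:
        batch.extend(rng.sample(by_class[c], examples_per_class))
    return batch


# Toy labels: 6 classes with 10 examples each.
labels = [i % 6 for i in range(60)]
batch = sample_balanced_batch(
    labels, classes_per_batch=3, examples_per_class=2, rng=random.Random(0)
)
print([labels[i] for i in batch])
```

Every class that appears in the batch appears at least twice, so the loss always has at least one positive pair per class to work with.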
SimilarityModel() models extend tensorflow.keras.Model with additional features and functionality that allow you to index and search for similar-looking examples.
As visible in the model definition below, similarity models output a 64-dimensional float embedding using the MetricEmbedding() layer. This layer is a Dense layer with L2 normalization. Thanks to the loss, the model learns to minimize the distance between similar examples and maximize the distance between dissimilar examples. As a result, the distance between examples in the embedding space is meaningful: the smaller the distance, the more similar the examples are.
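The effect of L2 normalization on distances can be sketched in a few lines. Once every embedding has unit length, the dot product of two embeddings is their cosine similarity, so one common definition of cosine distance reduces to 1 minus the dot product. The helper names below are hypothetical, and the sketch assumes non-zero input vectors.

```python
import math


def l2_normalize(v):
    """Scale a non-zero vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_distance(a, b):
    """Cosine distance between two vectors that are ALREADY unit length."""
    return 1.0 - sum(x * y for x, y in zip(a, b))


a = l2_normalize([3.0, 4.0])    # same direction as the vector it is compared to itself
b = l2_normalize([-4.0, 3.0])   # orthogonal to a

print(cosine_distance(a, a))    # approximately 0: identical direction
print(cosine_distance(a, b))    # approximately 1: orthogonal directions
```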
Being able to use a distance as a meaningful proxy for how similar two examples are is what enables the fast ANN (approximate nearest neighbor) search. Using a sub-linear ANN search instead of a standard quadratic NN search is what allows deep similarity search to scale to millions of items. The built-in memory index used in this notebook scales to a million indexed examples very easily, provided you have enough RAM.
Overall what makes metric losses different from traditional losses is that:
In this example we are using the MultiSimilarityLoss(). This loss takes a weighted combination of all valid positive and negative pairs, making it one of the best losses that you can use for similarity training.
TensorFlow Similarity uses an extended compile() method that allows you to optionally specify distance_metrics (metrics that are computed over the distance between the embeddings) and the distance to use for the indexer.
By default the compile() method tries to infer what type of distance you are using by looking at the first loss specified. If you use multiple losses, and the distance loss is not the first one, then you need to specify the distance function used as the distance= parameter of the compile() method.
Similarity models are trained like normal models.
Don’t expect the validation loss to decrease too much here, because we only use a subset of the classes within the train data but include all classes in the validation data.
Indexing is where things get different from traditional classification models. Because the model learned to output an embedding that represents the example’s position within the learned metric space, we need a way to find which known example(s) are the closest in order to determine the class of the query example (aka nearest neighbors classification).
To do so, we are creating an index of known examples from all the classes present in the dataset. We do this by taking a total of 200 examples from the train dataset, which amounts to 20 examples per class, and we use the index() method of the model to build the index.
We store the images (x_index) as data in the index (data=x_index) so that we can display them later. Here the images are small so it’s not an issue, but in general be careful when storing a lot of data in the index to avoid blowing up your memory. You might consider using a different Store() backend if you have to store and serve very large indexes.
Indexing more examples per class will help increase the accuracy/generalization, as having more variations improves the classifier’s “knowledge” of what variations to expect.
Resetting the index is not needed for the first run; however, we always call it to ensure we start the evaluation with a clean index in case of a partial re-run.
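Picking a fixed number of examples per class for the index can be sketched as follows. This is a plain-Python illustration, not the notebook's own helper: the name select_per_class is hypothetical, and integers stand in for images.

```python
def select_per_class(examples, labels, per_class):
    """Take the first per_class examples encountered for every class."""
    counts = {}
    picked_x, picked_y = [], []
    for x, y in zip(examples, labels):
        if counts.get(y, 0) < per_class:
            picked_x.append(x)
            picked_y.append(y)
            counts[y] = counts.get(y, 0) + 1
    return picked_x, picked_y


# Toy dataset: 500 examples spread evenly over 10 classes.
all_labels = [i % 10 for i in range(500)]
all_examples = list(range(500))

x_index, y_index = select_per_class(all_examples, all_labels, per_class=20)
print(len(x_index))  # 200 examples: 20 for each of the 10 classes
```

Raising per_class trades memory for accuracy, in line with the note above about indexing more examples per class.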
To “classify” examples, we need to look up their k nearest neighbors in the index.
Here we are going to query a single random example for each class from the test dataset using select_examples() and then find their nearest neighbors using the lookup() function.
By default the classes 8, 5, 0, and 4 were not seen during training, but we still get reasonable matches as visible in the image below.
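Conceptually, a lookup ranks the indexed embeddings by their distance to the query and returns the k closest ones. The sketch below does this by brute force with cosine distance; it only illustrates the idea and is not how the library's approximate search is implemented. The function name knn_lookup is hypothetical, and all embeddings are assumed to be unit length.

```python
import heapq


def knn_lookup(query, index_embeddings, index_labels, k):
    """Return the k closest (distance, label) pairs, nearest first.

    Assumes every embedding is L2-normalized, so cosine distance is 1 - dot product.
    """
    scored = (
        (1.0 - sum(q * e for q, e in zip(query, emb)), label)
        for emb, label in zip(index_embeddings, index_labels)
    )
    return heapq.nsmallest(k, scored)


# Tiny 2-D index with three unit-length embeddings.
toy_index = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
toy_index_labels = [0, 1, 2]

neighbors = knn_lookup([1.0, 0.0], toy_index, toy_index_labels, k=2)
print([label for _, label in neighbors])  # [0, 2]
```

A brute-force scan like this touches every indexed item for each query, which is exactly the cost an ANN index is designed to avoid.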
To be able to tell if an example matches a given class, we first need to calibrate() the model to find the optimal cut point. This cut point is the maximum distance below which returned neighbors are of the same class. Increasing the threshold improves the recall at the expense of the precision.
By default, the calibration uses the F1Score classification metric to optimally balance out the precision and recall; however, you can specify your own target and change the calibration metric to better suit your use case.
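The idea behind calibration can be sketched as a simple threshold sweep: try each candidate distance threshold, compute precision, recall, and F1 for the matches it would accept, and keep the threshold with the best F1. This is a simplified illustration under stated assumptions, not the library's calibration code. The function name calibrate_threshold is hypothetical, a query counts as accepted when its nearest-neighbor distance is at or below the threshold, and a rejected query whose nearest neighbor had the right label counts as a false negative.

```python
def calibrate_threshold(distances, is_correct):
    """Return the (threshold, f1) pair that maximizes F1 over candidate thresholds.

    distances[i]  : distance from query i to its nearest indexed neighbor
    is_correct[i] : True if that neighbor has the same class as query i
    """
    best_threshold, best_f1 = None, -1.0
    for t in sorted(set(distances)):
        pairs = list(zip(distances, is_correct))
        tp = sum(1 for d, ok in pairs if d <= t and ok)      # accepted, right class
        fp = sum(1 for d, ok in pairs if d <= t and not ok)  # accepted, wrong class
        fn = sum(1 for d, ok in pairs if d > t and ok)       # rejected, right class
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        if f1 > best_f1:
            best_threshold, best_f1 = t, f1
    return best_threshold, best_f1


toy_distances = [0.05, 0.10, 0.15, 0.30, 0.40, 0.60]
toy_correct = [True, True, True, False, True, False]

cutpoint, f1 = calibrate_threshold(toy_distances, toy_correct)
print(cutpoint, round(f1, 3))  # 0.4 0.889
```

In this toy data, precision and recall are equal (0.75) at a threshold of 0.30, yet the best F1 is reached at 0.40, where recall has risen to 1.0 while precision is still 0.8.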
Let’s plot the performance metrics to see how they evolve as the distance threshold increases.
We clearly see an inflection point where the precision and recall intersect. However, this is not the optimal_cutpoint because the recall continues to increase faster than the precision decreases. Different use cases will have different performance profiles, which is why each model needs to be calibrated.
We can see in the precision/recall curve below that the curve is not smooth. This is because the recall can improve independently of the precision, causing a seesaw pattern.
Additionally, the model does extremely well on known classes and less well on the unseen ones, which contributes to the flat curve at the beginning followed by a sharp decline as the distance threshold increases and examples are further away from the indexed examples.
The purpose of match() is to allow you to use your similarity models to make classification predictions. It accomplishes this by finding the nearest neighbors for a set of query examples and returning an inferred label based on the neighbors’ labels and the matching strategy used (MatchNearest by default).
Unlike traditional models, the match() method potentially returns -1 when there are no indexed examples below the cutpoint threshold. The -1 class should be treated as “unknown”.
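A nearest-neighbor matching rule with an "unknown" fallback can be sketched as below. This illustrates the behavior described above rather than the library's code: the function name match_nearest is hypothetical, neighbors are given as (distance, label) pairs, and a distance exactly equal to the cutpoint is treated as a match, which is an assumption.

```python
UNKNOWN = -1


def match_nearest(neighbors, cutpoint):
    """Return the label of the closest neighbor, or UNKNOWN if it is too far away.

    neighbors: list of (distance, label) pairs for one query, in any order.
    """
    if not neighbors:
        return UNKNOWN
    distance, label = min(neighbors)
    return label if distance <= cutpoint else UNKNOWN


print(match_nearest([(0.40, 2), (0.10, 7)], cutpoint=0.2))  # 7
print(match_nearest([(0.40, 2), (0.35, 7)], cutpoint=0.2))  # -1
```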
Let’s now match 10 examples to see how you can use the model’s match() method in practice.
Now that we have a better sense of what the match() method does, let’s scale up to a few thousand samples per class and evaluate how good our model is at predicting the correct classes.
As expected, while the model’s prediction performance is very good, it’s not competitive with a classification model. However, this lower accuracy comes with the unique advantage that the model is able to classify classes that were not seen during training.
tf.math.confusion_matrix doesn’t support negative classes, so we are going to use class 10 as our unknown class. As mentioned earlier, unknown examples are any testing example for which the closest neighbor distance is greater than the cutpoint threshold.
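The remapping of -1 to an extra class can be sketched in plain Python. The helper below is hypothetical and builds the matrix by hand instead of calling tf.math.confusion_matrix, so that the remapping step is easy to see. Rows are true classes and columns are predicted classes, with the last column holding the "unknown" predictions.

```python
def confusion_matrix_with_unknown(y_true, y_pred, num_classes):
    """Build a (num_classes + 1) square matrix where -1 predictions map to the last id."""
    unknown_id = num_classes          # e.g. class 10 for the 10 MNIST digits
    size = num_classes + 1
    matrix = [[0] * size for _ in range(size)]
    for true_label, predicted in zip(y_true, y_pred):
        if predicted == -1:
            predicted = unknown_id
        matrix[true_label][predicted] += 1
    return matrix


# Toy example with 3 real classes, so "unknown" becomes class 3.
toy_true = [0, 1, 2, 2]
toy_pred = [0, -1, 2, 1]
cm = confusion_matrix_with_unknown(toy_true, toy_pred, num_classes=3)
for row in cm:
    print(row)
```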
Saving and reloading the model works as you would expect: