A guide on how to train a model to solve Where’s Wally puzzles
Deep learning provides yet another way to solve the Where’s Wally puzzle. Unlike traditional image-processing computer vision methods, it works with only a handful of labelled examples that include the location of Wally in an image.
Why look for Wally if a neural network can do it for you?
The final trained model, along with evaluation images and detection scripts, is published on my GitHub repo.
This post describes the process of training a neural network using Tensorflow Object Detection API and using a Python script built around it to find Wally. It consists of the following steps:
- Preparing the dataset by creating a set of labelled training images, where each label gives the (x, y) location of Wally in an image
- Fetching and configuring the model to use with Tensorflow Object Detection API
- Training the model on our dataset
- Testing the model on evaluation images using the exported graph
Before starting, make sure to install Tensorflow Object Detection API as per the instructions.
Preparing the Dataset
While working with neural networks is the most celebrated part of deep learning, it sadly turns out that the step data scientists spend the most time on is preparing and formatting the training data.
Target values for the simplest machine learning problems are usually scalars (as in a digit detector) or categorical strings. Tensorflow Object Detection API training data uses a combination of both: a set of images, each accompanied by labels naming the desired objects and the locations where they appear in the image. A location is defined by two points, since in 2D space two points are enough to draw a bounding box around an object.
So in order to create the training set, we need to come up with a set of Where’s Wally puzzle images with locations of where Wally appears.
The final step of preparing our dataset involves packing our labels (saved as a text file) and images (.jpeg) into a single binary .tfrecord file. The process of doing that is explained here, but you can find both train and eval Where’s Wally .tfrecord files on my GitHub repo.
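The .tfrecord format stores box coordinates normalized to the [0, 1] range rather than in pixels. As a rough sketch of that conversion step (the label-line format, field names, and class name below are illustrative assumptions, not taken from the actual dataset):

```python
# Sketch: convert a pixel-space label line into the normalized box fields
# that a Tensorflow Object Detection API tf.Example expects. The
# 'filename xmin ymin xmax ymax' line format and the feature keys are
# illustrative, not from the original post.

def label_to_features(line, image_width, image_height):
    """Parse 'filename xmin ymin xmax ymax' into normalized box features."""
    filename, xmin, ymin, xmax, ymax = line.split()
    return {
        "image/filename": filename,
        # Coordinates are normalized to [0, 1] relative to image size.
        "image/object/bbox/xmin": float(xmin) / image_width,
        "image/object/bbox/ymin": float(ymin) / image_height,
        "image/object/bbox/xmax": float(xmax) / image_width,
        "image/object/bbox/ymax": float(ymax) / image_height,
        # Single-class problem: every box is labelled 'waldo' with id 1.
        "image/object/class/text": "waldo",
        "image/object/class/label": 1,
    }

features = label_to_features("puzzle1.jpg 120 80 200 160", 800, 400)
print(features["image/object/bbox/xmin"])  # 0.15
print(features["image/object/bbox/ymax"])  # 0.4
```

A dict like this would then be wrapped in a tf.train.Example together with the encoded image bytes and written out with a TFRecord writer.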
Preparing the Model
Tensorflow Object Detection API provides a set of pre-trained models with different performance characteristics (usually a speed-accuracy tradeoff), trained on several public datasets.
While the model could be trained from scratch starting with randomly initialised network weights, that would probably take weeks. Instead, we use a method called transfer learning: taking a model trained to solve some general problem and retraining it to solve ours. The idea is that instead of reinventing the wheel by training from scratch, we can reuse the knowledge captured in the pre-trained model and transfer it to our new one. This saves a lot of time, so the training time we do invest goes into learning only what is specific to our problem.
After downloading the configuration file, make sure to replace the “PATH_TO_BE_CONFIGURED” fields with paths pointing to your checkpoint file, your training and eval .tfrecord files, and the label map file.
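For illustration, the filled-in parts of a pipeline config might look roughly like this (the paths are placeholders and the surrounding fields depend on the model you chose):

```
train_config {
  fine_tune_checkpoint: "PATH_TO_YOUR_MODEL/model.ckpt"
}
train_input_reader {
  tf_record_input_reader {
    input_path: "PATH_TO_YOUR_DATA/train.tfrecord"
  }
  label_map_path: "PATH_TO_YOUR_DATA/labels.txt"
}
```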
The final file that needs to be configured is the labels.txt label map file, which lists all our different object classes. Since we’re only looking for one type of object, our label map looks like this:
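A single-entry label map in the API’s protobuf text format would be (the class name here is our choice):

```
item {
  id: 1
  name: 'waldo'
}
```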
Finally, we should end up with:
- A pretrained model with a .ckpt checkpoint file
- Training and evaluation .tfrecord dataset
- Label map file
- Pipeline configuration file pointing to the files above
Now, we’re ready to start training.
Tensorflow Object Detection API provides a simple-to-use Python script to retrain our model locally. It is located in models/research/object_detection and can be run with:
python train.py --logtostderr --pipeline_config_path=PATH_TO_PIPELINE_CONFIG --train_dir=PATH_TO_TRAIN_DIR
Here PATH_TO_PIPELINE_CONFIG is the path to our pipeline config file and PATH_TO_TRAIN_DIR is a newly created directory where our new checkpoints and model will be stored.
The output of train.py should look something like this:
The most important value to watch is the loss: a summation of the errors made for each example in the training or validation set. You, of course, want it to be as low as possible; if it is steadily decreasing, your model is learning (…or overfitting your training data).
You can also use Tensorboard to display training progress in more detail by pointing it at the training directory with tensorboard --logdir=PATH_TO_TRAIN_DIR.
The script automatically stores a checkpoint file after a certain number of steps, so you can restore a saved checkpoint at any time in case your computer crashes during training. It also means that when you want to finish training the model, you can simply terminate the script.
But when should you stop training? The general rule is to stop when the loss on our evaluation set stops decreasing or is already very low (below 0.01 in our example).
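That stopping rule can be sketched as a small helper; the 0.01 threshold matches the example above, while the patience window is an illustrative choice:

```python
# Sketch of the stopping rule above: stop once the evaluation loss is very
# low, or once it has stopped decreasing over the last few evaluations.
# The threshold matches the text's example; `patience` is illustrative.

def should_stop(eval_losses, threshold=0.01, patience=3):
    """Return True if the latest eval loss is below `threshold`, or if the
    last `patience` evaluations have not improved on earlier ones."""
    if not eval_losses:
        return False
    if eval_losses[-1] < threshold:
        return True  # loss is already very low
    if len(eval_losses) <= patience:
        return False  # not enough history to judge a plateau
    # Plateau: the recent window never beat the best loss seen before it.
    return min(eval_losses[-patience:]) >= min(eval_losses[:-patience])

print(should_stop([0.5, 0.3, 0.2, 0.005]))        # True  (below threshold)
print(should_stop([0.5, 0.3, 0.1]))               # False (still improving)
print(should_stop([0.5, 0.1, 0.12, 0.11, 0.13]))  # True  (plateaued)
```

In practice you would feed this the loss values reported by the evaluation runs rather than monitoring them by eye.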
Now we can actually use our model in practice by testing it on a few example images.
First, we need to export an inference graph from the stored checkpoint (located in our train directory) using a script in models/research/object_detection:
python export_inference_graph.py --pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECKPOINT --output_directory OUTPUT_PATH
The exported inference graph is now what our Python script can use to find Wally.
I wrote a few simple Python scripts (based on the Tensorflow Object Detection API) that you can use to perform object detection with your own models and either draw boxes around the detected objects or expose them.
When using scripts on your own model or own evaluation images, make sure to modify the model_path and image_path variables.
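The detector returns normalized boxes with confidence scores, so any such script ends with a filtering and rescaling step. A minimal sketch of that step (the function and variable names are illustrative assumptions, not the actual scripts’ API):

```python
# Sketch of the post-processing a detection script performs on raw model
# output: keep boxes above a confidence threshold and convert them from
# normalized [0, 1] coordinates back to pixel coordinates. Names here are
# illustrative, not taken from the published scripts.

def filter_detections(boxes, scores, image_width, image_height, min_score=0.5):
    """boxes: list of (ymin, xmin, ymax, xmax) in [0, 1]; returns pixel boxes
    (xmin, ymin, xmax, ymax) for detections scoring at least `min_score`."""
    kept = []
    for (ymin, xmin, ymax, xmax), score in zip(boxes, scores):
        if score < min_score:
            continue  # discard low-confidence detections
        kept.append((
            int(xmin * image_width), int(ymin * image_height),
            int(xmax * image_width), int(ymax * image_height),
        ))
    return kept

boxes = [(0.1, 0.2, 0.3, 0.4), (0.5, 0.5, 0.6, 0.6)]
scores = [0.9, 0.2]
print(filter_detections(boxes, scores, 1000, 500))  # [(200, 50, 400, 150)]
```

The surviving pixel boxes are what get drawn onto the puzzle image around Wally.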
The model published on my GitHub repo performed surprisingly well.
It managed to find Wally in the evaluation images and did decently on some extra random examples from the web. It failed to find Wally where he appeared really large, which intuitively should be even easier than finding him where he’s really small. This suggests that our model overfit the training data, mostly as a result of using only a handful of training images.
If you would like to improve the model’s performance by hand-labelling some extra images from the web, feel free to submit a PR on my GitHub repo and help improve the training set.
Gurupriyan is a Software Engineer and a technology enthusiast; he has been working in the field for the last 6 years, currently focusing on mobile app development and IoT.